11
11
Callable ,
12
12
Generic ,
13
13
Literal ,
14
+ overload ,
14
15
TypeVar ,
15
16
Union ,
16
17
cast ,
36
37
)
37
38
from xarray .core .options import _get_keep_attrs
38
39
from xarray .core .pycompat import integer_types
39
- from xarray .core .types import Dims , QuantileMethods , T_Xarray
40
+ from xarray .core .types import Dims , QuantileMethods , T_DataArray , T_Xarray
40
41
from xarray .core .utils import (
41
42
either_dict_or_kwargs ,
42
43
hashable ,
56
57
57
58
GroupKey = Any
58
59
GroupIndex = Union [int , slice , list [int ]]
59
-
60
- T_GroupIndicesListInt = list [list [int ]]
61
- T_GroupIndices = Union [T_GroupIndicesListInt , list [slice ], np .ndarray ]
60
+ T_GroupIndices = list [GroupIndex ]
62
61
63
62
64
63
def check_reduce_dims (reduce_dims , dimensions ):
@@ -99,8 +98,8 @@ def unique_value_groups(
99
98
return values , groups , inverse
100
99
101
100
102
- def _codes_to_groups (inverse : np .ndarray , N : int ) -> T_GroupIndicesListInt :
103
- groups : T_GroupIndicesListInt = [[] for _ in range (N )]
101
+ def _codes_to_groups (inverse : np .ndarray , N : int ) -> T_GroupIndices :
102
+ groups : T_GroupIndices = [[] for _ in range (N )]
104
103
for n , g in enumerate (inverse ):
105
104
if g >= 0 :
106
105
groups [g ].append (n )
@@ -147,7 +146,7 @@ def _is_one_or_none(obj) -> bool:
147
146
148
147
def _consolidate_slices (slices : list [slice ]) -> list [slice ]:
149
148
"""Consolidate adjacent slices in a list of slices."""
150
- result = []
149
+ result : list [ slice ] = []
151
150
last_slice = slice (None )
152
151
for slice_ in slices :
153
152
if not isinstance (slice_ , slice ):
@@ -191,7 +190,7 @@ def _inverse_permutation_indices(positions, N: int | None = None) -> np.ndarray
191
190
return newpositions [newpositions != - 1 ]
192
191
193
192
194
- class _DummyGroup :
193
+ class _DummyGroup ( Generic [ T_Xarray ]) :
195
194
"""Class for keeping track of grouped dimensions without coordinates.
196
195
197
196
Should not be user visible.
@@ -247,18 +246,19 @@ def to_dataarray(self) -> DataArray:
247
246
)
248
247
249
248
250
- T_Group = TypeVar ("T_Group" , bound = Union ["DataArray" , "IndexVariable" , _DummyGroup ])
249
+ # T_Group = TypeVar("T_Group", bound=Union["DataArray", "IndexVariable", _DummyGroup])
250
+ T_Group = Union ["T_DataArray" , "IndexVariable" , _DummyGroup ]
251
251
252
252
253
253
def _ensure_1d (
254
254
group : T_Group , obj : T_Xarray
255
- ) -> tuple [T_Group , T_Xarray , Hashable | None , list [Hashable ]]:
255
+ ) -> tuple [T_Group , T_Xarray , Hashable | None , list [Hashable ], ]:
256
256
# 1D cases: do nothing
257
- from xarray .core .dataarray import DataArray
258
-
259
257
if isinstance (group , (IndexVariable , _DummyGroup )) or group .ndim == 1 :
260
258
return group , obj , None , []
261
259
260
+ from xarray .core .dataarray import DataArray
261
+
262
262
if isinstance (group , DataArray ):
263
263
# try to stack the dims of the group into a single dim
264
264
orig_dims = group .dims
@@ -267,7 +267,7 @@ def _ensure_1d(
267
267
inserted_dims = [dim for dim in group .dims if dim not in group .coords ]
268
268
newgroup = group .stack ({stacked_dim : orig_dims })
269
269
newobj = obj .stack ({stacked_dim : orig_dims })
270
- return cast ( T_Group , newgroup ) , newobj , stacked_dim , inserted_dims
270
+ return newgroup , newobj , stacked_dim , inserted_dims
271
271
272
272
raise TypeError (
273
273
f"group must be DataArray, IndexVariable or _DummyGroup, got { type (group )!r} ."
@@ -311,25 +311,36 @@ def _apply_loffset(
311
311
result .index = result .index + loffset
312
312
313
313
314
- class ResolvedGrouper (ABC ):
315
- def __init__ (self , grouper : Grouper , group , obj ):
316
- self .labels = None
317
- self ._group_as_index : pd .Index | None = None
314
+ @dataclass
315
+ class ResolvedGrouper (ABC , Generic [T_Xarray ]):
316
+ grouper : Grouper
317
+ group : T_Group
318
+ obj : T_Xarray
319
+
320
+ _group_as_index : pd .Index | None = field (default = None , init = False )
321
+
322
+ # Not used here:?
323
+ labels : Any | None = field (default = None , init = False ) # TODO: Typing?
324
+ codes : DataArray = field (init = False )
325
+ group_indices : T_GroupIndices = field (init = False )
326
+ unique_coord : IndexVariable | _DummyGroup = field (init = False )
327
+ full_index : pd .Index = field (init = False )
318
328
319
- self .codes : DataArray
320
- self .group_indices : list [int ] | list [slice ] | list [list [int ]]
321
- self .unique_coord : IndexVariable | _DummyGroup
322
- self .full_index : pd .Index
329
+ # _ensure_1d:
330
+ group1d : T_Group = field (init = False )
331
+ stacked_obj : T_Xarray = field (init = False )
332
+ stacked_dim : Hashable | None = field (init = False )
333
+ inserted_dims : list [Hashable ] = field (init = False )
323
334
324
- self . grouper = grouper
325
- self .group = _resolve_group (obj , group )
335
+ def __post_init__ ( self ) -> None :
336
+ self .group : T_Group = _resolve_group (self . obj , self . group )
326
337
327
338
(
328
339
self .group1d ,
329
340
self .stacked_obj ,
330
341
self .stacked_dim ,
331
342
self .inserted_dims ,
332
- ) = _ensure_1d (self .group , obj )
343
+ ) = _ensure_1d (group = self .group , obj = self . obj )
333
344
334
345
@property
335
346
def name (self ) -> Hashable :
@@ -340,7 +351,7 @@ def size(self) -> int:
340
351
return len (self )
341
352
342
353
def __len__ (self ) -> int :
343
- return len (self .full_index )
354
+ return len (self .full_index ) # TODO: full_index not def, abstractmethod?
344
355
345
356
@property
346
357
def dims (self ):
@@ -364,7 +375,10 @@ def group_as_index(self) -> pd.Index:
364
375
return self ._group_as_index
365
376
366
377
378
+ @dataclass
367
379
class ResolvedUniqueGrouper (ResolvedGrouper ):
380
+ grouper : UniqueGrouper
381
+
368
382
def factorize (self , squeeze ) -> None :
369
383
is_dimension = self .group .dims == (self .group .name ,)
370
384
if is_dimension and self .is_unique_and_monotonic :
@@ -407,7 +421,10 @@ def _factorize_dummy(self, squeeze) -> None:
407
421
self .full_index = IndexVariable (self .name , self .group .values , self .group .attrs )
408
422
409
423
424
+ @dataclass
410
425
class ResolvedBinGrouper (ResolvedGrouper ):
426
+ grouper : BinGrouper
427
+
411
428
def factorize (self , squeeze : bool ) -> None :
412
429
from xarray .core .dataarray import DataArray
413
430
@@ -438,21 +455,26 @@ def factorize(self, squeeze: bool) -> None:
438
455
self .group_indices = group_indices
439
456
440
457
458
+ @dataclass
441
459
class ResolvedTimeResampleGrouper (ResolvedGrouper ):
442
- def __init__ (self , grouper , group , obj ):
443
- from xarray import CFTimeIndex
444
- from xarray .core .resample_cftime import CFTimeGrouper
460
+ grouper : TimeResampleGrouper
461
+
462
+ def __post_init__ (self ) -> None :
463
+ super ().__post_init__ ()
445
464
446
- super (). __init__ ( grouper , group , obj )
465
+ from xarray import CFTimeIndex
447
466
448
- self . _group_as_index = safe_cast_to_index (group )
449
- group_as_index = self ._group_as_index
467
+ group_as_index = safe_cast_to_index (self . group )
468
+ self ._group_as_index = group_as_index
450
469
451
470
if not group_as_index .is_monotonic_increasing :
452
471
# TODO: sort instead of raising an error
453
472
raise ValueError ("index must be monotonic for resampling" )
454
473
474
+ grouper = self .grouper
455
475
if isinstance (group_as_index , CFTimeIndex ):
476
+ from xarray .core .resample_cftime import CFTimeGrouper
477
+
456
478
index_grouper = CFTimeGrouper (
457
479
freq = grouper .freq ,
458
480
closed = grouper .closed ,
@@ -501,9 +523,9 @@ def first_items(self) -> tuple[pd.Series, np.ndarray]:
501
523
def factorize (self , squeeze : bool ) -> None :
502
524
self .full_index , first_items , codes = self ._get_index_and_items ()
503
525
sbins = first_items .values .astype (np .int64 )
504
- self .group_indices = [slice (i , j ) for i , j in zip (sbins [:- 1 ], sbins [1 :])] + [
505
- slice (sbins [- 1 ], None )
506
- ]
526
+ self .group_indices = [slice (i , j ) for i , j in zip (sbins [:- 1 ], sbins [1 :])]
527
+ self . group_indices += [ slice (sbins [- 1 ], None )]
528
+
507
529
self .unique_coord = IndexVariable (
508
530
self .group .name , first_items .index , self .group .attrs
509
531
)
@@ -550,7 +572,7 @@ def _validate_groupby_squeeze(squeeze):
550
572
raise TypeError (f"`squeeze` must be True or False, but { squeeze } was supplied" )
551
573
552
574
553
- def _resolve_group (obj , group : T_Group | Hashable ) -> T_Group :
575
+ def _resolve_group (obj : T_Xarray , group : T_Group | Hashable ) -> T_Group :
554
576
from xarray .core .dataarray import DataArray
555
577
556
578
if isinstance (group , (DataArray , IndexVariable )):
@@ -625,6 +647,19 @@ class GroupBy(Generic[T_Xarray]):
625
647
"_codes" ,
626
648
)
627
649
_obj : T_Xarray
650
+ groupers : tuple [ResolvedGrouper ]
651
+ _squeeze : bool
652
+ _restore_coord_dims : bool
653
+
654
+ _original_obj : T_Xarray
655
+ _original_group : T_Group
656
+ _group_indices : T_GroupIndices
657
+ _codes : DataArray
658
+ _group_dim : Hashable
659
+
660
+ _groups : dict [GroupKey , GroupIndex ] | None
661
+ _dims : tuple [Hashable , ...] | Frozen [Hashable , int ] | None
662
+ _sizes : Frozen [Hashable , int ] | None
628
663
629
664
def __init__ (
630
665
self ,
@@ -647,7 +682,7 @@ def __init__(
647
682
"""
648
683
self .groupers = groupers
649
684
650
- self ._original_obj : T_Xarray = obj
685
+ self ._original_obj = obj
651
686
652
687
for grouper_ in self .groupers :
653
688
grouper_ .factorize (squeeze )
@@ -656,7 +691,7 @@ def __init__(
656
691
self ._original_group = grouper .group
657
692
658
693
# specification for the groupby operation
659
- self ._obj : T_Xarray = grouper .stacked_obj
694
+ self ._obj = grouper .stacked_obj
660
695
self ._restore_coord_dims = restore_coord_dims
661
696
self ._squeeze = squeeze
662
697
@@ -666,9 +701,9 @@ def __init__(
666
701
667
702
(self ._group_dim ,) = grouper .group1d .dims
668
703
# cached attributes
669
- self ._groups : dict [ GroupKey , slice | int | list [ int ]] | None = None
670
- self ._dims : tuple [ Hashable , ...] | Frozen [ Hashable , int ] | None = None
671
- self ._sizes : Frozen [ Hashable , int ] | None = None
704
+ self ._groups = None
705
+ self ._dims = None
706
+ self ._sizes = None
672
707
673
708
@property
674
709
def sizes (self ) -> Frozen [Hashable , int ]:
0 commit comments