Commit d557418

Add typing
1 parent e07ae31 commit d557418

File tree: 1 file changed, +75 −40 lines changed

xarray/core/groupby.py

Lines changed: 75 additions & 40 deletions
@@ -11,6 +11,7 @@
     Callable,
     Generic,
     Literal,
+    overload,
     TypeVar,
     Union,
     cast,
@@ -36,7 +37,7 @@
 )
 from xarray.core.options import _get_keep_attrs
 from xarray.core.pycompat import integer_types
-from xarray.core.types import Dims, QuantileMethods, T_Xarray
+from xarray.core.types import Dims, QuantileMethods, T_DataArray, T_Xarray
 from xarray.core.utils import (
     either_dict_or_kwargs,
     hashable,
@@ -56,9 +57,7 @@

 GroupKey = Any
 GroupIndex = Union[int, slice, list[int]]
-
-T_GroupIndicesListInt = list[list[int]]
-T_GroupIndices = Union[T_GroupIndicesListInt, list[slice], np.ndarray]
+T_GroupIndices = list[GroupIndex]


 def check_reduce_dims(reduce_dims, dimensions):
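Note (reviewer sketch, not part of the diff): with the simplified alias, a group-indices value is just a list whose entries may be ints, slices, or lists of ints. A minimal, hypothetical illustration of what the checker now accepts:

    # hypothetical illustration only, not code from this commit
    from typing import Union

    GroupIndex = Union[int, slice, list[int]]
    T_GroupIndices = list[GroupIndex]

    indices: T_GroupIndices = [0, slice(1, 4), [5, 7, 9]]  # all three member forms type-check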
@@ -99,8 +98,8 @@ def unique_value_groups(
     return values, groups, inverse


-def _codes_to_groups(inverse: np.ndarray, N: int) -> T_GroupIndicesListInt:
-    groups: T_GroupIndicesListInt = [[] for _ in range(N)]
+def _codes_to_groups(inverse: np.ndarray, N: int) -> T_GroupIndices:
+    groups: T_GroupIndices = [[] for _ in range(N)]
     for n, g in enumerate(inverse):
         if g >= 0:
             groups[g].append(n)
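Note (hypothetical worked example, not from the diff): _codes_to_groups collects element positions per integer code and skips negative codes.

    # standalone sketch of the loop above, with made-up data
    inverse = [0, 2, 0, -1, 1]      # code assigned to each element; -1 means "no group"
    N = 3                           # number of groups
    groups = [[] for _ in range(N)]
    for n, g in enumerate(inverse):
        if g >= 0:
            groups[g].append(n)
    print(groups)                   # [[0, 2], [4], [1]]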
@@ -147,7 +146,7 @@ def _is_one_or_none(obj) -> bool:

 def _consolidate_slices(slices: list[slice]) -> list[slice]:
     """Consolidate adjacent slices in a list of slices."""
-    result = []
+    result: list[slice] = []
     last_slice = slice(None)
     for slice_ in slices:
         if not isinstance(slice_, slice):
@@ -191,7 +190,7 @@ def _inverse_permutation_indices(positions, N: int | None = None) -> np.ndarray
     return newpositions[newpositions != -1]


-class _DummyGroup:
+class _DummyGroup(Generic[T_Xarray]):
     """Class for keeping track of grouped dimensions without coordinates.

     Should not be user visible.
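Note (generic sketch, not xarray code): parameterising a class with Generic lets the checker carry the concrete object type through attribute and return annotations, which is what Generic[T_Xarray] does above.

    from typing import Generic, TypeVar

    T = TypeVar("T")  # stands in for xarray's T_Xarray

    class Wrapper(Generic[T]):
        def __init__(self, obj: T) -> None:
            self.obj = obj

        def unwrap(self) -> T:
            return self.obj

    # Wrapper("a").unwrap() is inferred as str; Wrapper(1).unwrap() as int.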
@@ -247,18 +246,19 @@ def to_dataarray(self) -> DataArray:
     )


-T_Group = TypeVar("T_Group", bound=Union["DataArray", "IndexVariable", _DummyGroup])
+# T_Group = TypeVar("T_Group", bound=Union["DataArray", "IndexVariable", _DummyGroup])
+T_Group = Union["T_DataArray", "IndexVariable", _DummyGroup]


 def _ensure_1d(
     group: T_Group, obj: T_Xarray
-) -> tuple[T_Group, T_Xarray, Hashable | None, list[Hashable]]:
+) -> tuple[T_Group, T_Xarray, Hashable | None, list[Hashable],]:
     # 1D cases: do nothing
-    from xarray.core.dataarray import DataArray
-
     if isinstance(group, (IndexVariable, _DummyGroup)) or group.ndim == 1:
         return group, obj, None, []

+    from xarray.core.dataarray import DataArray
+
     if isinstance(group, DataArray):
         # try to stack the dims of the group into a single dim
         orig_dims = group.dims
@@ -267,7 +267,7 @@ def _ensure_1d(
         inserted_dims = [dim for dim in group.dims if dim not in group.coords]
         newgroup = group.stack({stacked_dim: orig_dims})
         newobj = obj.stack({stacked_dim: orig_dims})
-        return cast(T_Group, newgroup), newobj, stacked_dim, inserted_dims
+        return newgroup, newobj, stacked_dim, inserted_dims

     raise TypeError(
         f"group must be DataArray, IndexVariable or _DummyGroup, got {type(group)!r}."
@@ -311,25 +311,36 @@ def _apply_loffset(
     result.index = result.index + loffset


-class ResolvedGrouper(ABC):
-    def __init__(self, grouper: Grouper, group, obj):
-        self.labels = None
-        self._group_as_index: pd.Index | None = None
+@dataclass
+class ResolvedGrouper(ABC, Generic[T_Xarray]):
+    grouper: Grouper
+    group: T_Group
+    obj: T_Xarray
+
+    _group_as_index: pd.Index | None = field(default=None, init=False)
+
+    # Not used here:?
+    labels: Any | None = field(default=None, init=False)  # TODO: Typing?
+    codes: DataArray = field(init=False)
+    group_indices: T_GroupIndices = field(init=False)
+    unique_coord: IndexVariable | _DummyGroup = field(init=False)
+    full_index: pd.Index = field(init=False)

-        self.codes: DataArray
-        self.group_indices: list[int] | list[slice] | list[list[int]]
-        self.unique_coord: IndexVariable | _DummyGroup
-        self.full_index: pd.Index
+    # _ensure_1d:
+    group1d: T_Group = field(init=False)
+    stacked_obj: T_Xarray = field(init=False)
+    stacked_dim: Hashable | None = field(init=False)
+    inserted_dims: list[Hashable] = field(init=False)

-        self.grouper = grouper
-        self.group = _resolve_group(obj, group)
+    def __post_init__(self) -> None:
+        self.group: T_Group = _resolve_group(self.obj, self.group)

         (
             self.group1d,
             self.stacked_obj,
             self.stacked_dim,
             self.inserted_dims,
-        ) = _ensure_1d(self.group, obj)
+        ) = _ensure_1d(group=self.group, obj=self.obj)

     @property
     def name(self) -> Hashable:
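Note (minimal generic sketch of the pattern, not xarray code): a dataclass with field(init=False) declares attributes that are assigned later, in __post_init__ or elsewhere, rather than by the generated __init__.

    from __future__ import annotations

    from dataclasses import dataclass, field

    @dataclass
    class Resolved:
        label: str                               # constructor argument
        codes: list[int] = field(init=False)     # filled in later, absent from __init__
        _cache: dict | None = field(default=None, init=False)

        def __post_init__(self) -> None:
            # runs right after the generated __init__
            self.codes = []

    r = Resolved("time")   # only `label` is passed
    r.codes.append(3)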
@@ -340,7 +351,7 @@ def size(self) -> int:
         return len(self)

     def __len__(self) -> int:
-        return len(self.full_index)
+        return len(self.full_index)  # TODO: full_index not def, abstractmethod?

     @property
     def dims(self):
@@ -364,7 +375,10 @@ def group_as_index(self) -> pd.Index:
         return self._group_as_index


+@dataclass
 class ResolvedUniqueGrouper(ResolvedGrouper):
+    grouper: UniqueGrouper
+
     def factorize(self, squeeze) -> None:
         is_dimension = self.group.dims == (self.group.name,)
         if is_dimension and self.is_unique_and_monotonic:
407421
self.full_index = IndexVariable(self.name, self.group.values, self.group.attrs)
408422

409423

424+
@dataclass
410425
class ResolvedBinGrouper(ResolvedGrouper):
426+
grouper: BinGrouper
427+
411428
def factorize(self, squeeze: bool) -> None:
412429
from xarray.core.dataarray import DataArray
413430

@@ -438,21 +455,26 @@ def factorize(self, squeeze: bool) -> None:
438455
self.group_indices = group_indices
439456

440457

458+
@dataclass
441459
class ResolvedTimeResampleGrouper(ResolvedGrouper):
442-
def __init__(self, grouper, group, obj):
443-
from xarray import CFTimeIndex
444-
from xarray.core.resample_cftime import CFTimeGrouper
460+
grouper: TimeResampleGrouper
461+
462+
def __post_init__(self) -> None:
463+
super().__post_init__()
445464

446-
super().__init__(grouper, group, obj)
465+
from xarray import CFTimeIndex
447466

448-
self._group_as_index = safe_cast_to_index(group)
449-
group_as_index = self._group_as_index
467+
group_as_index = safe_cast_to_index(self.group)
468+
self._group_as_index = group_as_index
450469

451470
if not group_as_index.is_monotonic_increasing:
452471
# TODO: sort instead of raising an error
453472
raise ValueError("index must be monotonic for resampling")
454473

474+
grouper = self.grouper
455475
if isinstance(group_as_index, CFTimeIndex):
476+
from xarray.core.resample_cftime import CFTimeGrouper
477+
456478
index_grouper = CFTimeGrouper(
457479
freq=grouper.freq,
458480
closed=grouper.closed,
@@ -501,9 +523,9 @@ def first_items(self) -> tuple[pd.Series, np.ndarray]:
501523
def factorize(self, squeeze: bool) -> None:
502524
self.full_index, first_items, codes = self._get_index_and_items()
503525
sbins = first_items.values.astype(np.int64)
504-
self.group_indices = [slice(i, j) for i, j in zip(sbins[:-1], sbins[1:])] + [
505-
slice(sbins[-1], None)
506-
]
526+
self.group_indices = [slice(i, j) for i, j in zip(sbins[:-1], sbins[1:])]
527+
self.group_indices += [slice(sbins[-1], None)]
528+
507529
self.unique_coord = IndexVariable(
508530
self.group.name, first_items.index, self.group.attrs
509531
)
@@ -550,7 +572,7 @@ def _validate_groupby_squeeze(squeeze):
550572
raise TypeError(f"`squeeze` must be True or False, but {squeeze} was supplied")
551573

552574

553-
def _resolve_group(obj, group: T_Group | Hashable) -> T_Group:
575+
def _resolve_group(obj: T_Xarray, group: T_Group | Hashable) -> T_Group:
554576
from xarray.core.dataarray import DataArray
555577

556578
if isinstance(group, (DataArray, IndexVariable)):
@@ -625,6 +647,19 @@ class GroupBy(Generic[T_Xarray]):
         "_codes",
     )
     _obj: T_Xarray
+    groupers: tuple[ResolvedGrouper]
+    _squeeze: bool
+    _restore_coord_dims: bool
+
+    _original_obj: T_Xarray
+    _original_group: T_Group
+    _group_indices: T_GroupIndices
+    _codes: DataArray
+    _group_dim: Hashable
+
+    _groups: dict[GroupKey, GroupIndex] | None
+    _dims: tuple[Hashable, ...] | Frozen[Hashable, int] | None
+    _sizes: Frozen[Hashable, int] | None

     def __init__(
         self,
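Note (generic sketch, not xarray code): bare class-level annotations only inform the type checker; they create no class attributes, so they coexist with __slots__ and with assignment in __init__, which is why the annotations above can be added to the slotted GroupBy class.

    class SlottedExample:
        __slots__ = ("_obj", "_squeeze")

        _obj: str        # annotation only, no value assigned at class level
        _squeeze: bool

        def __init__(self, obj: str, squeeze: bool) -> None:
            self._obj = obj          # storage provided by __slots__
            self._squeeze = squeeze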
@@ -647,7 +682,7 @@ def __init__(
         """
         self.groupers = groupers

-        self._original_obj: T_Xarray = obj
+        self._original_obj = obj

         for grouper_ in self.groupers:
             grouper_.factorize(squeeze)
@@ -656,7 +691,7 @@ def __init__(
         self._original_group = grouper.group

         # specification for the groupby operation
-        self._obj: T_Xarray = grouper.stacked_obj
+        self._obj = grouper.stacked_obj
         self._restore_coord_dims = restore_coord_dims
         self._squeeze = squeeze

@@ -666,9 +701,9 @@ def __init__(

         (self._group_dim,) = grouper.group1d.dims
         # cached attributes
-        self._groups: dict[GroupKey, slice | int | list[int]] | None = None
-        self._dims: tuple[Hashable, ...] | Frozen[Hashable, int] | None = None
-        self._sizes: Frozen[Hashable, int] | None = None
+        self._groups = None
+        self._dims = None
+        self._sizes = None

     @property
     def sizes(self) -> Frozen[Hashable, int]:
