Skip to content
forked from pydata/xarray

Commit

Permalink
Some refactoring
Browse files Browse the repository at this point in the history
  • Loading branch information
dcherian committed Nov 7, 2024
1 parent 978fad9 commit d1a3fc1
Show file tree
Hide file tree
Showing 3 changed files with 21 additions and 11 deletions.
8 changes: 8 additions & 0 deletions xarray/core/dataarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,7 @@
Bins,
DaCompatible,
NetcdfWriteModes,
T_Chunks,
T_DataArray,
T_DataArrayOrSet,
ZarrWriteModes,
Expand Down Expand Up @@ -105,6 +106,7 @@
Dims,
ErrorOptions,
ErrorOptionsWithWarn,
GroupIndices,
GroupInput,
InterpOptions,
PadModeOptions,
Expand Down Expand Up @@ -1687,6 +1689,12 @@ def sel(
)
return self._from_temp_dataset(ds)

def _shuffle(
    self, dim: Hashable, *, indices: GroupIndices, chunks: T_Chunks
) -> Self:
    """Shuffle this array along ``dim`` by delegating to the temporary dataset.

    The array is converted with ``_to_temp_dataset``, that object's
    ``_shuffle`` is called with the same ``dim``, ``indices`` and ``chunks``
    arguments, and the result is converted back with ``_from_temp_dataset``.

    Parameters
    ----------
    dim : Hashable
        Name of the dimension to shuffle along.
    indices : GroupIndices
        Per-group indexers, forwarded unchanged.
    chunks : T_Chunks
        Chunking specification, forwarded unchanged.

    Returns
    -------
    Self
        The object produced by ``_from_temp_dataset`` from the shuffled
        temporary dataset.
    """
    temp_dataset = self._to_temp_dataset()
    shuffled = temp_dataset._shuffle(dim=dim, indices=indices, chunks=chunks)
    return self._from_temp_dataset(shuffled)

def head(
self,
indexers: Mapping[Any, int] | int | None = None,
Expand Down
15 changes: 12 additions & 3 deletions xarray/core/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -155,6 +155,7 @@
DsCompatible,
ErrorOptions,
ErrorOptionsWithWarn,
GroupIndices,
GroupInput,
InterpOptions,
JoinOptions,
Expand Down Expand Up @@ -3238,7 +3239,7 @@ def sel(
result = self.isel(indexers=query_results.dim_indexers, drop=drop)
return result._overwrite_indexes(*query_results.as_tuple()[1:])

def _shuffle(self, dim, *, indices: list[list[int]], chunks: T_Chunks) -> Self:
def _shuffle(self, dim, *, indices: GroupIndices, chunks: T_Chunks) -> Self:
# Shuffling is only different from `isel` for chunked arrays.
# Extract them out, and treat them specially. The rest, we route through isel.
# This makes it easy to ensure correct handling of indexes.
Expand All @@ -3249,14 +3250,22 @@ def _shuffle(self, dim, *, indices: list[list[int]], chunks: T_Chunks) -> Self:
}
subset = self[[name for name in self._variables if name not in is_chunked]]

no_slices: list[list[int]] = [
list(range(*idx.indices(self.sizes[dim])))
if isinstance(idx, slice)
else idx
for idx in indices
]
no_slices = [idx for idx in no_slices if idx]

shuffled = (
subset
if dim not in subset.dims
else subset.isel({dim: np.concatenate(indices)})
else subset.isel({dim: np.concatenate(no_slices)})
)
for name, var in is_chunked.items():
shuffled[name] = var._shuffle(
indices=indices,
indices=no_slices,
dim=dim,
chunks=chunks,
)
Expand Down
9 changes: 1 addition & 8 deletions xarray/core/groupby.py
Original file line number Diff line number Diff line change
Expand Up @@ -743,19 +743,12 @@ def _shuffle_obj(self, chunks: T_Chunks) -> T_Xarray:
was_array = isinstance(self._obj, DataArray)
as_dataset = self._obj._to_temp_dataset() if was_array else self._obj

size = self._obj.sizes[self._group_dim]
no_slices: list[list[int]] = [
list(range(*idx.indices(size))) if isinstance(idx, slice) else idx
for idx in self.encoded.group_indices
]
no_slices = [idx for idx in no_slices if idx]

for grouper in self.groupers:
if grouper.name not in as_dataset._variables:
as_dataset.coords[grouper.name] = grouper.group

shuffled = as_dataset._shuffle(
dim=self._group_dim, indices=no_slices, chunks=chunks
dim=self._group_dim, indices=self.encoded.group_indices, chunks=chunks
)
shuffled = self._maybe_unstack(shuffled)
new_obj = self._obj._from_temp_dataset(shuffled) if was_array else shuffled
Expand Down

0 comments on commit d1a3fc1

Please sign in to comment.