Skip to content

Commit 1df705e

Browse files
committed
Cleanup
1 parent 3bc51bd commit 1df705e

File tree

5 files changed

+25
-9
lines changed

5 files changed

+25
-9
lines changed

xarray/core/duck_array_ops.py

+3 -7
Original file line number | Diff line number | Diff line change
@@ -835,13 +835,9 @@ def chunked_nanlast(darray, axis):
835835

836836
def shuffle_array(array, indices: list[list[int]], axis: int):
837837
# TODO: do chunk manager dance here.
838-
if is_duck_dask_array(array):
839-
if not module_available("dask", minversion="2024.08.0"):
840-
raise ValueError(
841-
"This method is very inefficient on dask<2024.08.0. Please upgrade."
842-
)
843-
# TODO: handle dimensions
844-
return array.shuffle(indexer=indices, axis=axis)
838+
if is_chunked_array(array):
839+
chunkmanager = get_chunked_array_type(array)
840+
return chunkmanager.shuffle(array, indexer=indices, axis=axis)
845841
else:
846842
indexer = np.concatenate(indices)
847843
# TODO: Do the array API thing here.

xarray/core/groupby.py

+7 -1
Original file line number | Diff line number | Diff line change
@@ -538,6 +538,8 @@ def shuffle(self) -> None:
538538
if all(isinstance(idx, slice) for idx in self._group_indices):
539539
return
540540

541+
indices: tuple[list[int]] = self._group_indices # type: ignore[assignment]
542+
541543
was_array = isinstance(self._obj, DataArray)
542544
as_dataset = self._obj._to_temp_dataset() if was_array else self._obj
543545

@@ -547,20 +549,24 @@ def shuffle(self) -> None:
547549
shuffled[name] = var
548550
continue
549551
shuffled_data = shuffle_array(
550-
var._data, list(self._group_indices), axis=var.get_axis_num(dim)
552+
var._data, list(indices), axis=var.get_axis_num(dim)
551553
)
552554
shuffled[name] = var._replace(data=shuffled_data)
553555

554556
# Replace self._group_indices with slices
555557
slices = []
556558
start = 0
557559
for idxr in self._group_indices:
560+
if TYPE_CHECKING:
561+
assert not isinstance(idxr, slice)
558562
slices.append(slice(start, start + len(idxr)))
559563
start += len(idxr)
560564
# TODO: we have now broken the invariant
561565
# self._group_indices ≠ self.groupers[0].group_indices
562566
self._group_indices = tuple(slices)
563567
if was_array:
568+
if TYPE_CHECKING:
569+
assert isinstance(self._obj, DataArray)
564570
self._obj = self._obj._from_temp_dataset(shuffled)
565571
else:
566572
self._obj = shuffled

xarray/core/types.py

+1 -1
Original file line number | Diff line number | Diff line change
@@ -297,7 +297,7 @@ def copy(
297297
ZarrWriteModes = Literal["w", "w-", "a", "a-", "r+", "r"]
298298

299299
GroupKey = Any
300-
GroupIndex = Union[int, slice, list[int]]
300+
GroupIndex = Union[slice, list[int]]
301301
GroupIndices = tuple[GroupIndex, ...]
302302
Bins = Union[
303303
int, Sequence[int], Sequence[float], Sequence[pd.Timestamp], np.ndarray, pd.Index

xarray/namedarray/daskmanager.py

+9
Original file line number | Diff line number | Diff line change
@@ -251,3 +251,12 @@ def store(
251251
targets=targets,
252252
**kwargs,
253253
)
254+
255+
def shuffle(self, x: DaskArray, indexer: list[list[int]], axis: int) -> DaskArray:
256+
import dask.array
257+
258+
if not module_available("dask", minversion="2024.08.0"):
259+
raise ValueError(
260+
"This method is very inefficient on dask<2024.08.0. Please upgrade."
261+
)
262+
return dask.array.shuffle(x, indexer, axis)

xarray/namedarray/parallelcompat.py

+5
Original file line number | Diff line number | Diff line change
@@ -364,6 +364,11 @@ def compute(
364364
"""
365365
raise NotImplementedError()
366366

367+
def shuffle(
368+
self, x: T_ChunkedArray, indexer: list[list[int]], axis: int
369+
) -> T_ChunkedArray:
370+
raise NotImplementedError()
371+
367372
@property
368373
def array_api(self) -> Any:
369374
"""

0 commit comments

Comments
 (0)