Commit 74d7bb0 (parent: 231533c)

return xarray object from distributed_shuffle

File tree: 2 files changed (+45, -24)
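The gist of the change: distributed_shuffle no longer returns a new GroupBy object; it returns the shuffled DataArray or Dataset, and callers re-create the groupby themselves. A minimal usage sketch based on the updated docstring example below (the chunk size here is illustrative, not from the diff):

import dask.array
import xarray as xr

# Ten values whose "x" groups are scattered across chunks.
da = xr.DataArray(
    dims="x",
    data=dask.array.arange(10, chunks=3),
    coords={"x": [1, 2, 3, 1, 2, 3, 1, 2, 3, 0]},
    name="a",
)

# As of this commit, distributed_shuffle() returns the shuffled DataArray,
# with the members of each group moved into adjacent positions.
shuffled = da.groupby("x").distributed_shuffle()

# Re-create the groupby on the shuffled object before reducing.
shuffled.groupby("x").quantile(q=0.5).compute()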

xarray/core/groupby.py (+4, -14)
@@ -682,7 +682,7 @@ def sizes(self) -> Mapping[Hashable, int]:
         self._sizes = self._obj.isel({self._group_dim: index}).sizes
         return self._sizes

-    def distributed_shuffle(self, chunks: T_Chunks = None):
+    def distributed_shuffle(self, chunks: T_Chunks = None) -> T_Xarray:
         """
         Sort or "shuffle" the underlying object.

@@ -711,8 +711,8 @@ def distributed_shuffle(self, chunks: T_Chunks = None):
         ...     coords={"x": [1, 2, 3, 1, 2, 3, 1, 2, 3, 0]},
         ...     name="a",
         ... )
-        >>> shuffled = da.groupby("x").shuffle()
-        >>> shuffled.quantile(q=0.5).compute()
+        >>> shuffled = da.groupby("x").distributed_shuffle()
+        >>> shuffled.groupby("x").quantile(q=0.5).compute()
         <xarray.DataArray 'a' (x: 4)> Size: 32B
         array([9., 3., 4., 5.])
         Coordinates:
@@ -725,17 +725,7 @@ def distributed_shuffle(self, chunks: T_Chunks = None):
         dask.array.shuffle
         """
         self._raise_if_by_is_chunked()
-        new_groupers = {
-            # Using group.name handles the BinGrouper case
-            # It does *not* handle the TimeResampler case,
-            # so we just override this method in Resample
-            grouper.group.name: grouper.grouper.reset()
-            for grouper in self.groupers
-        }
-        return self._shuffle_obj(chunks).groupby(
-            new_groupers,
-            restore_coord_dims=self._restore_coord_dims,
-        )
+        return self._shuffle_obj(chunks)

     def _shuffle_obj(self, chunks: T_Chunks) -> T_Xarray:
         from xarray.core.dataarray import DataArray
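In other words, the re-grouping that the removed block performed internally now happens at the call site. A small sketch of that equivalence under this commit's API (same illustrative `da` as above):

import dask.array
import xarray as xr

da = xr.DataArray(
    dims="x",
    data=dask.array.arange(10, chunks=3),
    coords={"x": [1, 2, 3, 1, 2, 3, 1, 2, 3, 0]},
    name="a",
)

# The method now returns a plain xarray object, not a GroupBy...
shuffled = da.groupby("x").distributed_shuffle()
assert isinstance(shuffled, xr.DataArray)

# ...so what the old method returned (a GroupBy over the shuffled data)
# is recovered by grouping again:
gb = shuffled.groupby("x")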

xarray/tests/test_groupby.py (+41, -10)
@@ -672,7 +672,7 @@ def test_groupby_drops_nans(shuffle: bool, chunk: Literal[False] | dict) -> None
     ds["variable"] = ds["variable"].chunk(chunk)
     grouped = ds.groupby(ds.id)
     if shuffle:
-        grouped = grouped.distributed_shuffle()
+        grouped = grouped.distributed_shuffle().groupby(ds.id)

     # non reduction operation
     expected1 = ds.copy()
@@ -1418,7 +1418,7 @@ def test_groupby_reductions(
     with raise_if_dask_computes():
         grouped = array.groupby("abc")
         if shuffle:
-            grouped = grouped.distributed_shuffle()
+            grouped = grouped.distributed_shuffle().groupby("abc")

     with xr.set_options(use_flox=use_flox):
         actual = getattr(grouped, method)(dim="y")
@@ -1687,13 +1687,16 @@ def test_groupby_bins(

     with xr.set_options(use_flox=use_flox):
         gb = array.groupby_bins("dim_0", bins=bins, **cut_kwargs)
+        shuffled = gb.distributed_shuffle().groupby_bins(
+            "dim_0", bins=bins, **cut_kwargs
+        )
         actual = gb.sum()
     assert_identical(expected, actual)
-    assert_identical(expected, gb.distributed_shuffle().sum())
+    assert_identical(expected, shuffled.sum())

     actual = gb.map(lambda x: x.sum())
     assert_identical(expected, actual)
-    assert_identical(expected, gb.distributed_shuffle().map(lambda x: x.sum()))
+    assert_identical(expected, shuffled.map(lambda x: x.sum()))

     # make sure original array dims are unchanged
     assert len(array.dim_0) == 4
@@ -1877,17 +1880,18 @@ def resample_as_pandas(array, *args, **kwargs):
     array = DataArray(np.arange(10), [("time", times)])

     rs = array.resample(time=resample_freq)
+    shuffled = rs.distributed_shuffle().resample(time=resample_freq)
     actual = rs.mean()
     expected = resample_as_pandas(array, resample_freq)
     assert_identical(expected, actual)
-    assert_identical(expected, rs.distributed_shuffle().mean())
+    assert_identical(expected, shuffled.mean())

     assert_identical(expected, rs.reduce(np.mean))
-    assert_identical(expected, rs.distributed_shuffle().reduce(np.mean))
+    assert_identical(expected, shuffled.reduce(np.mean))

     rs = array.resample(time="24h", closed="right")
     actual = rs.mean()
-    shuffled = rs.distributed_shuffle().mean()
+    shuffled = rs.distributed_shuffle().resample(time="24h", closed="right").mean()
     expected = resample_as_pandas(array, "24h", closed="right")
     assert_identical(expected, actual)
     assert_identical(expected, shuffled)
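The resample path follows the same pattern: the shuffle returns a plain DataArray, so the resampler is re-created on the result. A self-contained sketch mirroring the test above (assuming, as the test does, that resample objects expose distributed_shuffle at this commit):

import numpy as np
import pandas as pd
import xarray as xr

times = pd.date_range("2000-01-01", freq="6h", periods=10)
array = xr.DataArray(np.arange(10), [("time", times)])

rs = array.resample(time="24h")
# distributed_shuffle() yields the shuffled DataArray; resample it again
# before reducing.
shuffled = rs.distributed_shuffle().resample(time="24h")
assert rs.mean().identical(shuffled.mean())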
@@ -3168,20 +3172,47 @@ def test_groupby_multiple_bin_grouper_missing_groups():


 @requires_dask_ge_2024_08_1
-def test_shuffle_by_simple() -> None:
+def test_shuffle_simple() -> None:
     import dask

     da = xr.DataArray(
         dims="x",
         data=dask.array.from_array([1, 2, 3, 4, 5, 6], chunks=2),
         coords={"label": ("x", "a b c a b c".split(" "))},
     )
-    actual = da.distributed_shuffle_by(label=UniqueGrouper())
+    actual = da.groupby(label=UniqueGrouper()).distributed_shuffle()
     expected = da.isel(x=[0, 3, 1, 4, 2, 5])
     assert_identical(actual, expected)

     with pytest.raises(ValueError):
-        da.chunk(x=2, eagerly_load_group=False).distributed_shuffle_by("label")
+        da.chunk(x=2, eagerly_load_group=False).groupby("label").distributed_shuffle()
+
+
+@requires_dask_ge_2024_08_1
+@pytest.mark.parametrize(
+    "chunks, expected_chunks",
+    [
+        ((1,), (1, 3, 3, 3)),
+        ((10,), (10,)),
+    ],
+)
+def test_shuffle_by(chunks, expected_chunks):
+    import dask.array
+
+    from xarray.groupers import UniqueGrouper
+
+    da = xr.DataArray(
+        dims="x",
+        data=dask.array.arange(10, chunks=chunks),
+        coords={"x": [1, 2, 3, 1, 2, 3, 1, 2, 3, 0]},
+        name="a",
+    )
+    ds = da.to_dataset()
+
+    for obj in [ds, da]:
+        actual = obj.groupby(x=UniqueGrouper()).distributed_shuffle()
+        assert_identical(actual, obj.sortby("x"))
+        assert actual.chunksizes["x"] == expected_chunks


 @requires_dask
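As test_shuffle_by asserts, the shuffle is equivalent to sorting by the group variable, and the output chunks follow the group sizes: size-1 input chunks become (1, 3, 3, 3) for groups 0, 1, 2, 3, while a single size-10 chunk stays whole. A condensed version of that check:

import dask.array
import xarray as xr
from xarray.groupers import UniqueGrouper

da = xr.DataArray(
    dims="x",
    data=dask.array.arange(10, chunks=(1,)),
    coords={"x": [1, 2, 3, 1, 2, 3, 1, 2, 3, 0]},
    name="a",
)

# Shuffling a groupby now returns the shuffled DataArray itself.
actual = da.groupby(x=UniqueGrouper()).distributed_shuffle()

# Same values as sorting by the group coordinate...
xr.testing.assert_identical(actual, da.sortby("x"))
# ...with one chunk per group (group 0 has one member, the others three).
assert actual.chunksizes["x"] == (1, 3, 3, 3)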
