
Commit c77d7c5

return xarray object from distributed_shuffle
1 parent 231533c commit c77d7c5

File tree

2 files changed: +60 -32 lines changed


xarray/core/groupby.py (+4 -14)
@@ -682,7 +682,7 @@ def sizes(self) -> Mapping[Hashable, int]:
             self._sizes = self._obj.isel({self._group_dim: index}).sizes
         return self._sizes
 
-    def distributed_shuffle(self, chunks: T_Chunks = None):
+    def distributed_shuffle(self, chunks: T_Chunks = None) -> T_Xarray:
         """
         Sort or "shuffle" the underlying object.
 
@@ -711,8 +711,8 @@ def distributed_shuffle(self, chunks: T_Chunks = None):
         ...     coords={"x": [1, 2, 3, 1, 2, 3, 1, 2, 3, 0]},
         ...     name="a",
         ... )
-        >>> shuffled = da.groupby("x").shuffle()
-        >>> shuffled.quantile(q=0.5).compute()
+        >>> shuffled = da.groupby("x").distributed_shuffle()
+        >>> shuffled.groupby("x").quantile(q=0.5).compute()
         <xarray.DataArray 'a' (x: 4)> Size: 32B
         array([9., 3., 4., 5.])
         Coordinates:
@@ -725,17 +725,7 @@ def distributed_shuffle(self, chunks: T_Chunks = None):
         dask.array.shuffle
         """
         self._raise_if_by_is_chunked()
-        new_groupers = {
-            # Using group.name handles the BinGrouper case
-            # It does *not* handle the TimeResampler case,
-            # so we just override this method in Resample
-            grouper.group.name: grouper.grouper.reset()
-            for grouper in self.groupers
-        }
-        return self._shuffle_obj(chunks).groupby(
-            new_groupers,
-            restore_coord_dims=self._restore_coord_dims,
-        )
+        return self._shuffle_obj(chunks)
 
     def _shuffle_obj(self, chunks: T_Chunks) -> T_Xarray:
         from xarray.core.dataarray import DataArray
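
After this change, distributed_shuffle() hands back the shuffled DataArray or
Dataset itself instead of a new GroupBy, so callers regroup the result before
reducing. A minimal sketch of the new calling pattern, adapted from the
docstring above (assumes dask >= 2024.08.1 for dask.array.shuffle):

    import dask.array
    import xarray as xr

    da = xr.DataArray(
        dims="x",
        data=dask.array.arange(10, chunks=1),
        coords={"x": [1, 2, 3, 1, 2, 3, 1, 2, 3, 0]},
        name="a",
    )

    # Shuffle so each group's rows land in contiguous chunks; the result
    # is a plain DataArray, not a GroupBy.
    shuffled = da.groupby("x").distributed_shuffle()

    # Group the shuffled object again to run the reduction.
    print(shuffled.groupby("x").quantile(q=0.5).compute())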

xarray/tests/test_groupby.py (+56 -18)
@@ -672,7 +672,7 @@ def test_groupby_drops_nans(shuffle: bool, chunk: Literal[False] | dict) -> None
         ds["variable"] = ds["variable"].chunk(chunk)
     grouped = ds.groupby(ds.id)
     if shuffle:
-        grouped = grouped.distributed_shuffle()
+        grouped = grouped.distributed_shuffle().groupby(ds.id)
 
     # non reduction operation
     expected1 = ds.copy()
@@ -1418,7 +1418,7 @@ def test_groupby_reductions(
     with raise_if_dask_computes():
         grouped = array.groupby("abc")
         if shuffle:
-            grouped = grouped.distributed_shuffle()
+            grouped = grouped.distributed_shuffle().groupby("abc")
 
     with xr.set_options(use_flox=use_flox):
         actual = getattr(grouped, method)(dim="y")
@@ -1687,13 +1687,16 @@ def test_groupby_bins(
 
     with xr.set_options(use_flox=use_flox):
         gb = array.groupby_bins("dim_0", bins=bins, **cut_kwargs)
+        shuffled = gb.distributed_shuffle().groupby_bins(
+            "dim_0", bins=bins, **cut_kwargs
+        )
         actual = gb.sum()
         assert_identical(expected, actual)
-        assert_identical(expected, gb.distributed_shuffle().sum())
+        assert_identical(expected, shuffled.sum())
 
         actual = gb.map(lambda x: x.sum())
         assert_identical(expected, actual)
-        assert_identical(expected, gb.distributed_shuffle().map(lambda x: x.sum()))
+        assert_identical(expected, shuffled.map(lambda x: x.sum()))
 
         # make sure original array dims are unchanged
         assert len(array.dim_0) == 4
@@ -1877,17 +1880,18 @@ def resample_as_pandas(array, *args, **kwargs):
     array = DataArray(np.arange(10), [("time", times)])
 
     rs = array.resample(time=resample_freq)
+    shuffled = rs.distributed_shuffle().resample(time=resample_freq)
     actual = rs.mean()
     expected = resample_as_pandas(array, resample_freq)
     assert_identical(expected, actual)
-    assert_identical(expected, rs.distributed_shuffle().mean())
+    assert_identical(expected, shuffled.mean())
 
     assert_identical(expected, rs.reduce(np.mean))
-    assert_identical(expected, rs.distributed_shuffle().reduce(np.mean))
+    assert_identical(expected, shuffled.reduce(np.mean))
 
     rs = array.resample(time="24h", closed="right")
     actual = rs.mean()
-    shuffled = rs.distributed_shuffle().mean()
+    shuffled = rs.distributed_shuffle().resample(time="24h", closed="right").mean()
     expected = resample_as_pandas(array, "24h", closed="right")
     assert_identical(expected, actual)
     assert_identical(expected, shuffled)
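
The resample changes follow the same contract: distributed_shuffle() on a
Resample returns the shuffled object, which must be resampled again before
reducing. A hedged sketch of that pattern (assumes dask >= 2024.08.1; the
data and frequency here are illustrative, not taken from the commit):

    import numpy as np
    import pandas as pd
    import xarray as xr

    times = pd.date_range("2000-01-01", freq="6h", periods=10)
    array = xr.DataArray(np.arange(10), [("time", times)]).chunk(time=2)

    rs = array.resample(time="24h")
    # Shuffle returns a DataArray; resample it again to reduce per window.
    shuffled = rs.distributed_shuffle().resample(time="24h")
    assert rs.mean().equals(shuffled.mean())
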
@@ -2830,9 +2834,11 @@ def test_multiple_groupers(use_flox: bool, shuffle: bool) -> None:
         name="foo",
     )
 
-    gb = da.groupby(labels1=UniqueGrouper(), labels2=UniqueGrouper())
+    groupers: dict[str, Grouper]
+    groupers = dict(labels1=UniqueGrouper(), labels2=UniqueGrouper())
+    gb = da.groupby(groupers)
     if shuffle:
-        gb = gb.distributed_shuffle()
+        gb = gb.distributed_shuffle().groupby(groupers)
     repr(gb)
 
     expected = DataArray(
@@ -2851,9 +2857,10 @@ def test_multiple_groupers(use_flox: bool, shuffle: bool) -> None:
     # -------
     coords = {"a": ("x", [0, 0, 1, 1]), "b": ("y", [0, 0, 1, 1])}
     square = DataArray(np.arange(16).reshape(4, 4), coords=coords, dims=["x", "y"])
-    gb = square.groupby(a=UniqueGrouper(), b=UniqueGrouper())
+    groupers = dict(a=UniqueGrouper(), b=UniqueGrouper())
+    gb = square.groupby(groupers)
     if shuffle:
-        gb = gb.distributed_shuffle()
+        gb = gb.distributed_shuffle().groupby(groupers)
     repr(gb)
     with xr.set_options(use_flox=use_flox):
         actual = gb.mean()
@@ -2883,9 +2890,10 @@ def test_multiple_groupers(use_flox: bool, shuffle: bool) -> None:
     with xr.set_options(use_flox=use_flox):
         assert_identical(gb.mean("z"), b.mean("z"))
 
-    gb = b.groupby(x=UniqueGrouper(), xy=UniqueGrouper())
+    groupers = dict(x=UniqueGrouper(), xy=UniqueGrouper())
+    gb = b.groupby(groupers)
     if shuffle:
-        gb = gb.distributed_shuffle()
+        gb = gb.distributed_shuffle().groupby(groupers)
     repr(gb)
     with xr.set_options(use_flox=use_flox):
         actual = gb.mean()
@@ -2937,9 +2945,12 @@ def test_multiple_groupers_mixed(use_flox: bool, shuffle: bool) -> None:
         {"foo": (("x", "y"), np.arange(12).reshape((4, 3)))},
         coords={"x": [10, 20, 30, 40], "letters": ("x", list("abba"))},
     )
-    gb = ds.groupby(x=BinGrouper(bins=[5, 15, 25]), letters=UniqueGrouper())
+    groupers: dict[str, Grouper] = dict(
+        x=BinGrouper(bins=[5, 15, 25]), letters=UniqueGrouper()
+    )
+    gb = ds.groupby(groupers)
     if shuffle:
-        gb = gb.distributed_shuffle()
+        gb = gb.distributed_shuffle().groupby(groupers)
     expected_data = np.array(
         [
             [[0.0, np.nan], [np.nan, 3.0]],
@@ -3168,20 +3179,47 @@ def test_groupby_multiple_bin_grouper_missing_groups():
 
 
 @requires_dask_ge_2024_08_1
-def test_shuffle_by_simple() -> None:
+def test_shuffle_simple() -> None:
     import dask
 
     da = xr.DataArray(
         dims="x",
         data=dask.array.from_array([1, 2, 3, 4, 5, 6], chunks=2),
         coords={"label": ("x", "a b c a b c".split(" "))},
     )
-    actual = da.distributed_shuffle_by(label=UniqueGrouper())
+    actual = da.groupby(label=UniqueGrouper()).distributed_shuffle()
     expected = da.isel(x=[0, 3, 1, 4, 2, 5])
     assert_identical(actual, expected)
 
     with pytest.raises(ValueError):
-        da.chunk(x=2, eagerly_load_group=False).distributed_shuffle_by("label")
+        da.chunk(x=2, eagerly_load_group=False).groupby("label").distributed_shuffle()
+
+
+@requires_dask_ge_2024_08_1
+@pytest.mark.parametrize(
+    "chunks, expected_chunks",
+    [
+        ((1,), (1, 3, 3, 3)),
+        ((10,), (10,)),
+    ],
+)
+def test_shuffle_by(chunks, expected_chunks):
+    import dask.array
+
+    from xarray.groupers import UniqueGrouper
+
+    da = xr.DataArray(
+        dims="x",
+        data=dask.array.arange(10, chunks=chunks),
+        coords={"x": [1, 2, 3, 1, 2, 3, 1, 2, 3, 0]},
+        name="a",
+    )
+    ds = da.to_dataset()
+
+    for obj in [ds, da]:
+        actual = obj.groupby(x=UniqueGrouper()).distributed_shuffle()
+        assert_identical(actual, obj.sortby("x"))
+        assert actual.chunksizes["x"] == expected_chunks
 
 
 @requires_dask
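
Because the shuffled result is a plain Dataset or DataArray, the tests above
build their groupers once as a dict and reuse the same dict to regroup after
shuffling. A self-contained sketch of that pattern, reduced to a single
grouper (assumes dask >= 2024.08.1):

    import dask.array
    import xarray as xr
    from xarray.groupers import UniqueGrouper

    da = xr.DataArray(
        dims="x",
        data=dask.array.arange(10, chunks=1),
        coords={"x": [1, 2, 3, 1, 2, 3, 1, 2, 3, 0]},
        name="a",
    )

    # Build the grouper spec once so it can regroup the shuffled object.
    groupers = dict(x=UniqueGrouper())
    gb = da.groupby(groupers)
    gb = gb.distributed_shuffle().groupby(groupers)  # shuffle, then regroup
    print(gb.sum().compute())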
