@@ -672,7 +672,7 @@ def test_groupby_drops_nans(shuffle: bool, chunk: Literal[False] | dict) -> None
     ds["variable"] = ds["variable"].chunk(chunk)
     grouped = ds.groupby(ds.id)
     if shuffle:
-        grouped = grouped.distributed_shuffle()
+        grouped = grouped.distributed_shuffle().groupby(ds.id)

     # non reduction operation
     expected1 = ds.copy()
@@ -1418,7 +1418,7 @@ def test_groupby_reductions(
     with raise_if_dask_computes():
         grouped = array.groupby("abc")
         if shuffle:
-            grouped = grouped.distributed_shuffle()
+            grouped = grouped.distributed_shuffle().groupby("abc")

     with xr.set_options(use_flox=use_flox):
         actual = getattr(grouped, method)(dim="y")
@@ -1687,13 +1687,16 @@ def test_groupby_bins(

     with xr.set_options(use_flox=use_flox):
         gb = array.groupby_bins("dim_0", bins=bins, **cut_kwargs)
+        shuffled = gb.distributed_shuffle().groupby_bins(
+            "dim_0", bins=bins, **cut_kwargs
+        )
         actual = gb.sum()
         assert_identical(expected, actual)
-        assert_identical(expected, gb.distributed_shuffle().sum())
+        assert_identical(expected, shuffled.sum())

         actual = gb.map(lambda x: x.sum())
         assert_identical(expected, actual)
-        assert_identical(expected, gb.distributed_shuffle().map(lambda x: x.sum()))
+        assert_identical(expected, shuffled.map(lambda x: x.sum()))

         # make sure original array dims are unchanged
         assert len(array.dim_0) == 4
@@ -1877,17 +1880,18 @@ def resample_as_pandas(array, *args, **kwargs):
         array = DataArray(np.arange(10), [("time", times)])

         rs = array.resample(time=resample_freq)
+        shuffled = rs.distributed_shuffle().resample(time=resample_freq)
         actual = rs.mean()
         expected = resample_as_pandas(array, resample_freq)
         assert_identical(expected, actual)
-        assert_identical(expected, rs.distributed_shuffle().mean())
+        assert_identical(expected, shuffled.mean())

         assert_identical(expected, rs.reduce(np.mean))
-        assert_identical(expected, rs.distributed_shuffle().reduce(np.mean))
+        assert_identical(expected, shuffled.reduce(np.mean))

         rs = array.resample(time="24h", closed="right")
         actual = rs.mean()
-        shuffled = rs.distributed_shuffle().mean()
+        shuffled = rs.distributed_shuffle().resample(time="24h", closed="right").mean()
         expected = resample_as_pandas(array, "24h", closed="right")
         assert_identical(expected, actual)
         assert_identical(expected, shuffled)
@@ -3168,20 +3172,47 @@ def test_groupby_multiple_bin_grouper_missing_groups():


 @requires_dask_ge_2024_08_1
-def test_shuffle_by_simple() -> None:
+def test_shuffle_simple() -> None:
     import dask

     da = xr.DataArray(
         dims="x",
         data=dask.array.from_array([1, 2, 3, 4, 5, 6], chunks=2),
         coords={"label": ("x", "a b c a b c".split(" "))},
     )
-    actual = da.distributed_shuffle_by(label=UniqueGrouper())
+    actual = da.groupby(label=UniqueGrouper()).distributed_shuffle()
     expected = da.isel(x=[0, 3, 1, 4, 2, 5])
     assert_identical(actual, expected)

     with pytest.raises(ValueError):
-        da.chunk(x=2, eagerly_load_group=False).distributed_shuffle_by("label")
+        da.chunk(x=2, eagerly_load_group=False).groupby("label").distributed_shuffle()
+
+
+@requires_dask_ge_2024_08_1
+@pytest.mark.parametrize(
+    "chunks, expected_chunks",
+    [
+        ((1,), (1, 3, 3, 3)),
+        ((10,), (10,)),
+    ],
+)
+def test_shuffle_by(chunks, expected_chunks):
+    import dask.array
+
+    from xarray.groupers import UniqueGrouper
+
+    da = xr.DataArray(
+        dims="x",
+        data=dask.array.arange(10, chunks=chunks),
+        coords={"x": [1, 2, 3, 1, 2, 3, 1, 2, 3, 0]},
+        name="a",
+    )
+    ds = da.to_dataset()
+
+    for obj in [ds, da]:
+        actual = obj.groupby(x=UniqueGrouper()).distributed_shuffle()
+        assert_identical(actual, obj.sortby("x"))
+        assert actual.chunksizes["x"] == expected_chunks


 @requires_dask
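The pattern running through these hunks: `distributed_shuffle()` moves from a `distributed_shuffle_by` method on the array to a method on the `GroupBy`/`Resample` object, and it returns the shuffled `DataArray`/`Dataset` itself rather than a grouped object, so the tests re-apply `groupby`, `groupby_bins`, or `resample` before reducing. A minimal sketch of that usage, reusing the data from `test_shuffle_simple` above (assumes dask is installed and the `distributed_shuffle` API as it appears in this commit):

```python
import dask.array
import xarray as xr
from xarray.groupers import UniqueGrouper

# Chunked array whose group labels are interleaved across chunks.
da = xr.DataArray(
    dims="x",
    data=dask.array.from_array([1, 2, 3, 4, 5, 6], chunks=2),
    coords={"label": ("x", "a b c a b c".split(" "))},
)

# distributed_shuffle() returns the shuffled array with each group's
# elements made contiguous, so re-group it before reducing.
shuffled = da.groupby(label=UniqueGrouper()).distributed_shuffle()
result = shuffled.groupby("label").sum()
```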