@@ -672,7 +672,7 @@ def test_groupby_drops_nans(shuffle: bool, chunk: Literal[False] | dict) -> None
     ds["variable"] = ds["variable"].chunk(chunk)
     grouped = ds.groupby(ds.id)
     if shuffle:
-        grouped = grouped.distributed_shuffle()
+        grouped = grouped.distributed_shuffle().groupby(ds.id)

     # non reduction operation
     expected1 = ds.copy()
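This hunk sets the pattern for the whole diff: `distributed_shuffle()` on a GroupBy now returns the shuffled Dataset/DataArray rather than another GroupBy, so each call site rebuilds the GroupBy from the shuffled result. A minimal sketch of the new call pattern, assuming the in-development `distributed_shuffle` API from this branch and dask-backed data (the tests gate on dask >= 2024.08.1):

```python
# Sketch only: `distributed_shuffle` is the in-development API this
# branch exercises; it requires a dask-backed variable.
import numpy as np
import xarray as xr

ds = xr.Dataset(
    {"variable": ("x", np.arange(6))},
    coords={"id": ("x", [0, 1, 0, 1, 0, 1])},
)
ds["variable"] = ds["variable"].chunk(2)

grouped = ds.groupby(ds.id)
# Previously: grouped = grouped.distributed_shuffle()  # returned a GroupBy
# Now: shuffle, then rebuild the GroupBy from the shuffled Dataset.
grouped = grouped.distributed_shuffle().groupby(ds.id)
print(grouped.mean().compute())
```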
@@ -1418,7 +1418,7 @@ def test_groupby_reductions(
     with raise_if_dask_computes():
         grouped = array.groupby("abc")
         if shuffle:
-            grouped = grouped.distributed_shuffle()
+            grouped = grouped.distributed_shuffle().groupby("abc")

     with xr.set_options(use_flox=use_flox):
         actual = getattr(grouped, method)(dim="y")
@@ -1687,13 +1687,16 @@ def test_groupby_bins(

     with xr.set_options(use_flox=use_flox):
         gb = array.groupby_bins("dim_0", bins=bins, **cut_kwargs)
+        shuffled = gb.distributed_shuffle().groupby_bins(
+            "dim_0", bins=bins, **cut_kwargs
+        )
         actual = gb.sum()
         assert_identical(expected, actual)
-        assert_identical(expected, gb.distributed_shuffle().sum())
+        assert_identical(expected, shuffled.sum())

         actual = gb.map(lambda x: x.sum())
         assert_identical(expected, actual)
-        assert_identical(expected, gb.distributed_shuffle().map(lambda x: x.sum()))
+        assert_identical(expected, shuffled.map(lambda x: x.sum()))

         # make sure original array dims are unchanged
         assert len(array.dim_0) == 4
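The same rewrite applied to `groupby_bins`: because the shuffled object is a plain DataArray, it must be re-binned with identical arguments so that `gb` and `shuffled` describe the same groups. A self-contained sketch under the same branch-API assumptions:

```python
import numpy as np
import xarray as xr

# Chunked 1-D array, binned over its dimension (no explicit coord needed).
array = xr.DataArray(np.arange(4), dims="dim_0", name="a").chunk(2)
bins = [0, 1.5, 5]

gb = array.groupby_bins("dim_0", bins=bins)
# Re-bin the shuffled array with the same bins to get matching groups.
shuffled = gb.distributed_shuffle().groupby_bins("dim_0", bins=bins)

# Both views of the data reduce to the same result.
xr.testing.assert_identical(gb.sum(), shuffled.sum())
```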
@@ -1877,17 +1880,18 @@ def resample_as_pandas(array, *args, **kwargs):
     array = DataArray(np.arange(10), [("time", times)])

     rs = array.resample(time=resample_freq)
+    shuffled = rs.distributed_shuffle().resample(time=resample_freq)
     actual = rs.mean()
     expected = resample_as_pandas(array, resample_freq)
     assert_identical(expected, actual)
-    assert_identical(expected, rs.distributed_shuffle().mean())
+    assert_identical(expected, shuffled.mean())

     assert_identical(expected, rs.reduce(np.mean))
-    assert_identical(expected, rs.distributed_shuffle().reduce(np.mean))
+    assert_identical(expected, shuffled.reduce(np.mean))

     rs = array.resample(time="24h", closed="right")
     actual = rs.mean()
-    shuffled = rs.distributed_shuffle().mean()
+    shuffled = rs.distributed_shuffle().resample(time="24h", closed="right").mean()
     expected = resample_as_pandas(array, "24h", closed="right")
     assert_identical(expected, actual)
     assert_identical(expected, shuffled)
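Resample follows the same shape: `rs.distributed_shuffle()` now hands back the shuffled DataArray, so the test re-resamples it with identical arguments before reducing. A sketch, same assumptions:

```python
import numpy as np
import pandas as pd
import xarray as xr

times = pd.date_range("2000-01-01", freq="6h", periods=10)
array = xr.DataArray(np.arange(10), [("time", times)]).chunk({"time": 4})

rs = array.resample(time="24h")
# Shuffle, then re-resample with the same frequency and options.
shuffled = rs.distributed_shuffle().resample(time="24h")
xr.testing.assert_identical(rs.mean(), shuffled.mean())
```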
@@ -2830,9 +2834,11 @@ def test_multiple_groupers(use_flox: bool, shuffle: bool) -> None:
         name="foo",
     )

-    gb = da.groupby(labels1=UniqueGrouper(), labels2=UniqueGrouper())
+    groupers: dict[str, Grouper]
+    groupers = dict(labels1=UniqueGrouper(), labels2=UniqueGrouper())
+    gb = da.groupby(groupers)
     if shuffle:
-        gb = gb.distributed_shuffle()
+        gb = gb.distributed_shuffle().groupby(groupers)
     repr(gb)

     expected = DataArray(
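The `test_multiple_groupers` hunks also switch from keyword arguments to a single `groupers` dict, so exactly the same mapping can be reused to rebuild the GroupBy after the shuffle. A sketch, assuming `Grouper` and `UniqueGrouper` from `xarray.groupers` as imported by this test module:

```python
import numpy as np
import xarray as xr
from xarray.groupers import Grouper, UniqueGrouper

da = xr.DataArray(
    np.arange(12).reshape(4, 3),
    dims=("x", "y"),
    coords={"labels1": ("x", list("abab")), "labels2": ("y", list("xyx"))},
    name="foo",
).chunk(x=2)

# One dict serves both the initial groupby and the post-shuffle rebuild.
groupers: dict[str, Grouper] = dict(labels1=UniqueGrouper(), labels2=UniqueGrouper())
gb = da.groupby(groupers)
gb = gb.distributed_shuffle().groupby(groupers)
print(gb.mean().compute())
```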
@@ -2851,9 +2857,10 @@ def test_multiple_groupers(use_flox: bool, shuffle: bool) -> None:
     # -------
     coords = {"a": ("x", [0, 0, 1, 1]), "b": ("y", [0, 0, 1, 1])}
     square = DataArray(np.arange(16).reshape(4, 4), coords=coords, dims=["x", "y"])
-    gb = square.groupby(a=UniqueGrouper(), b=UniqueGrouper())
+    groupers = dict(a=UniqueGrouper(), b=UniqueGrouper())
+    gb = square.groupby(groupers)
     if shuffle:
-        gb = gb.distributed_shuffle()
+        gb = gb.distributed_shuffle().groupby(groupers)
     repr(gb)
     with xr.set_options(use_flox=use_flox):
         actual = gb.mean()
@@ -2883,9 +2890,10 @@ def test_multiple_groupers(use_flox: bool, shuffle: bool) -> None:
     with xr.set_options(use_flox=use_flox):
         assert_identical(gb.mean("z"), b.mean("z"))

-    gb = b.groupby(x=UniqueGrouper(), xy=UniqueGrouper())
+    groupers = dict(x=UniqueGrouper(), xy=UniqueGrouper())
+    gb = b.groupby(groupers)
     if shuffle:
-        gb = gb.distributed_shuffle()
+        gb = gb.distributed_shuffle().groupby(groupers)
     repr(gb)
     with xr.set_options(use_flox=use_flox):
         actual = gb.mean()
@@ -2937,9 +2945,12 @@ def test_multiple_groupers_mixed(use_flox: bool, shuffle: bool) -> None:
         {"foo": (("x", "y"), np.arange(12).reshape((4, 3)))},
         coords={"x": [10, 20, 30, 40], "letters": ("x", list("abba"))},
     )
-    gb = ds.groupby(x=BinGrouper(bins=[5, 15, 25]), letters=UniqueGrouper())
+    groupers: dict[str, Grouper] = dict(
+        x=BinGrouper(bins=[5, 15, 25]), letters=UniqueGrouper()
+    )
+    gb = ds.groupby(groupers)
     if shuffle:
-        gb = gb.distributed_shuffle()
+        gb = gb.distributed_shuffle().groupby(groupers)
     expected_data = np.array(
         [
             [[0.0, np.nan], [np.nan, 3.0]],
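The mixed-grouper variant works the same way; a `BinGrouper` and a `UniqueGrouper` can share one annotated dict. A sketch, same assumptions as above:

```python
import numpy as np
import xarray as xr
from xarray.groupers import BinGrouper, Grouper, UniqueGrouper

ds = xr.Dataset(
    {"foo": (("x", "y"), np.arange(12).reshape((4, 3)))},
    coords={"x": [10, 20, 30, 40], "letters": ("x", list("abba"))},
).chunk(x=2)

# Bin along "x" while grouping "letters" by unique values.
groupers: dict[str, Grouper] = dict(
    x=BinGrouper(bins=[5, 15, 25]), letters=UniqueGrouper()
)
gb = ds.groupby(groupers)
gb = gb.distributed_shuffle().groupby(groupers)
print(gb.sum().compute())
```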
@@ -3168,20 +3179,47 @@ def test_groupby_multiple_bin_grouper_missing_groups():


 @requires_dask_ge_2024_08_1
-def test_shuffle_by_simple() -> None:
+def test_shuffle_simple() -> None:
     import dask

     da = xr.DataArray(
         dims="x",
         data=dask.array.from_array([1, 2, 3, 4, 5, 6], chunks=2),
         coords={"label": ("x", "a b c a b c".split(" "))},
     )
-    actual = da.distributed_shuffle_by(label=UniqueGrouper())
+    actual = da.groupby(label=UniqueGrouper()).distributed_shuffle()
     expected = da.isel(x=[0, 3, 1, 4, 2, 5])
     assert_identical(actual, expected)

     with pytest.raises(ValueError):
-        da.chunk(x=2, eagerly_load_group=False).distributed_shuffle_by("label")
+        da.chunk(x=2, eagerly_load_group=False).groupby("label").distributed_shuffle()
+
+
+@requires_dask_ge_2024_08_1
+@pytest.mark.parametrize(
+    "chunks, expected_chunks",
+    [
+        ((1,), (1, 3, 3, 3)),
+        ((10,), (10,)),
+    ],
+)
+def test_shuffle_by(chunks, expected_chunks):
+    import dask.array
+
+    from xarray.groupers import UniqueGrouper
+
+    da = xr.DataArray(
+        dims="x",
+        data=dask.array.arange(10, chunks=chunks),
+        coords={"x": [1, 2, 3, 1, 2, 3, 1, 2, 3, 0]},
+        name="a",
+    )
+    ds = da.to_dataset()
+
+    for obj in [ds, da]:
+        actual = obj.groupby(x=UniqueGrouper()).distributed_shuffle()
+        assert_identical(actual, obj.sortby("x"))
+        assert actual.chunksizes["x"] == expected_chunks


 @requires_dask
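The new `test_shuffle_by` pins down the observable contract: shuffling by a grouper is equivalent to sorting by the group variable, and with small input chunks the output gets one chunk per group (here `(1, 3, 3, 3)` for groups 0, 1, 2, 3). A sketch of that contract, same branch-API assumptions:

```python
import dask.array
import xarray as xr
from xarray.groupers import UniqueGrouper

da = xr.DataArray(
    dims="x",
    data=dask.array.arange(10, chunks=1),
    coords={"x": [1, 2, 3, 1, 2, 3, 1, 2, 3, 0]},
    name="a",
)

actual = da.groupby(x=UniqueGrouper()).distributed_shuffle()
# Shuffling by "x" orders the data like a sort on "x"...
xr.testing.assert_identical(actual, da.sortby("x"))
# ...and co-locates each group in its own chunk.
assert actual.chunksizes["x"] == (1, 3, 3, 3)
```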