Skip to content

Commit bccacfe

Browse files
committed
fix
1 parent c77d7c5 commit bccacfe

File tree

2 files changed

+12
-9
lines changed

2 files changed

+12
-9
lines changed

xarray/core/groupby.py

+7 -5
Original file line number | Diff line number | Diff line change
@@ -704,14 +704,14 @@ def distributed_shuffle(self, chunks: T_Chunks = None) -> T_Xarray:
704704
705705
Examples
706706
--------
707-
>>> import dask
707+
>>> import dask.array
708708
>>> da = xr.DataArray(
709709
... dims="x",
710710
... data=dask.array.arange(10, chunks=3),
711711
... coords={"x": [1, 2, 3, 1, 2, 3, 1, 2, 3, 0]},
712712
... name="a",
713713
... )
714-
>>> shuffled = da.groupby("x")
714+
>>> shuffled = da.groupby("x").distributed_shuffle()
715715
>>> shuffled.groupby("x").quantile(q=0.5).compute()
716716
<xarray.DataArray 'a' (x: 4)> Size: 32B
717717
array([9., 3., 4., 5.])
@@ -740,9 +740,11 @@ def _shuffle_obj(self, chunks: T_Chunks) -> T_Xarray:
740740
shuffled = as_dataset._shuffle(
741741
dim=self._group_dim, indices=self.encoded.group_indices, chunks=chunks
742742
)
743-
shuffled = self._maybe_unstack(shuffled)
744-
new_obj = self._obj._from_temp_dataset(shuffled) if was_array else shuffled
745-
return new_obj
743+
unstacked: Dataset = self._maybe_unstack(shuffled)
744+
if was_array:
745+
return self._obj._from_temp_dataset(unstacked)
746+
else:
747+
return unstacked # type: ignore[return-value]
746748

747749
def map(
748750
self,

xarray/tests/test_groupby.py

+5 -4
Original file line number | Diff line number | Diff line change
@@ -1891,10 +1891,10 @@ def resample_as_pandas(array, *args, **kwargs):
18911891

18921892
rs = array.resample(time="24h", closed="right")
18931893
actual = rs.mean()
1894-
shuffled = rs.distributed_shuffle().resample(time="24h", closed="right").mean()
1894+
shuffled = rs.distributed_shuffle().resample(time="24h", closed="right")
18951895
expected = resample_as_pandas(array, "24h", closed="right")
18961896
assert_identical(expected, actual)
1897-
assert_identical(expected, shuffled)
1897+
assert_identical(expected, shuffled.mean())
18981898

18991899
with pytest.raises(ValueError, match=r"Index must be monotonic"):
19001900
array[[2, 0, 1]].resample(time=resample_freq)
@@ -2883,9 +2883,10 @@ def test_multiple_groupers(use_flox: bool, shuffle: bool) -> None:
28832883
coords={"xy": (("x", "y"), [["a", "b", "c"], ["b", "c", "c"]], {"foo": "bar"})},
28842884
dims=["x", "y", "z"],
28852885
)
2886-
gb = b.groupby(x=UniqueGrouper(), y=UniqueGrouper())
2886+
groupers = dict(x=UniqueGrouper(), y=UniqueGrouper())
2887+
gb = b.groupby(groupers)
28872888
if shuffle:
2888-
gb = gb.distributed_shuffle()
2889+
gb = gb.distributed_shuffle().groupby(groupers)
28892890
repr(gb)
28902891
with xr.set_options(use_flox=use_flox):
28912892
assert_identical(gb.mean("z"), b.mean("z"))

0 commit comments

Comments
 (0)