Skip to content

Commit b15990e

Browse files
committed
Fix bug with NaNs in by and method='blockwise'
xref pydata/xarray#9320
1 parent f0ce343 commit b15990e

File tree

3 files changed

+32
-10
lines changed

3 files changed

+32
-10
lines changed

flox/core.py

+17-9
Original file line numberDiff line numberDiff line change
@@ -1797,12 +1797,13 @@ def dask_groupby_agg(
17971797
output_chunks = new_dims_shape + reduced.chunks[: -len(axis)] + group_chunks
17981798
new_axes = dict(zip(new_inds, new_dims_shape))
17991799

1800-
if method == "blockwise" and len(axis) > 1:
1801-
# The final results are available but the blocks along axes
1802-
# need to be reshaped to axis=-1
1803-
# I don't know that this is possible with blockwise
1804-
# All other code paths benefit from an unmaterialized Blockwise layer
1805-
reduced = _collapse_blocks_along_axes(reduced, axis, group_chunks)
1800+
if method == "blockwise":
1801+
if len(axis) > 1:
1802+
# The final results are available but the blocks along axes
1803+
# need to be reshaped to axis=-1
1804+
# I don't know that this is possible with blockwise
1805+
# All other code paths benefit from an unmaterialized Blockwise layer
1806+
reduced = _collapse_blocks_along_axes(reduced, axis, group_chunks)
18061807

18071808
# Can't use map_blocks because it forces concatenate=True along drop_axes,
18081809
result = dask.array.blockwise(
@@ -1817,7 +1818,6 @@ def dask_groupby_agg(
18171818
concatenate=False,
18181819
new_axes=new_axes,
18191820
)
1820-
18211821
return (result, groups)
18221822

18231823

@@ -2663,10 +2663,18 @@ def groupby_reduce(
26632663
groups = (groups[0][sorted_idx],)
26642664

26652665
if factorize_early:
2666+
assert len(groups) == 1
2667+
(groups_,) = groups
26662668
# nan group labels are factorized to -1, and preserved
26672669
# now we get rid of them by reindexing
2668-
# This also handles bins with no data
2669-
result = reindex_(result, from_=groups[0], to=expected_, fill_value=fill_value).reshape(
2670+
# First, for "blockwise", we can have -1 repeated in different blocks
2671+
# This breaks the reindexing so remove those first.
2672+
if method == "blockwise" and (mask := groups_ == -1).sum(axis=-1) > 1:
2673+
result = result[..., ~mask]
2674+
groups_ = groups_[..., ~mask]
2675+
2676+
# This reindex also handles bins with no data
2677+
result = reindex_(result, from_=groups_, to=expected_, fill_value=fill_value).reshape(
26702678
result.shape[:-1] + grp_shape
26712679
)
26722680
groups = final_groups

tests/test_core.py

+14
Original file line numberDiff line numberDiff line change
@@ -1928,3 +1928,17 @@ def test_ffill_bfill(chunks, size, add_nan_by, func):
19281928
expected = flox.groupby_scan(array.compute(), by, func=func)
19291929
actual = flox.groupby_scan(array, by, func=func)
19301930
assert_equal(expected, actual)
1931+
1932+
1933+
@requires_dask
1934+
def test_blockwise_nans():
1935+
array = dask.array.ones((1, 10), chunks=2)
1936+
by = np.array([-1, 0, -1, 1, -1, 2, -1, 3, 4, 4])
1937+
actual, actual_groups = flox.groupby_reduce(
1938+
array, by, func="sum", expected_groups=pd.RangeIndex(0, 5)
1939+
)
1940+
expected, expected_groups = flox.groupby_reduce(
1941+
array.compute(), by, func="sum", expected_groups=pd.RangeIndex(0, 5)
1942+
)
1943+
assert_equal(expected_groups, actual_groups)
1944+
assert_equal(expected, actual)

tests/test_properties.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -88,7 +88,7 @@ def test_groupby_reduce(data, array, func: str) -> None:
8888
flox_kwargs: dict[str, Any] = {}
8989
with np.errstate(invalid="ignore", divide="ignore"):
9090
actual, *_ = groupby_reduce(
91-
array, by, func=func, axis=axis, engine="numpy", **flox_kwargs, finalize_kwargs=kwargs
91+
array, by, func=func, axis=axis, engine="flox", **flox_kwargs, finalize_kwargs=kwargs
9292
)
9393

9494
# numpy-groupies always does the calculation in float64

0 commit comments

Comments
 (0)