
Commit f0ce343

Fix first, last again (#381)
* Fix first, last again; add more first, last tests
* Fix
* Fix type ignores
* Add one more property test
* Support cohorts and grouped_combine
* Fix docs
* Fix profile
1 parent b05586c commit f0ce343

4 files changed: +143 -20 lines

flox/core.py (+39 -13)
@@ -170,7 +170,9 @@ def _is_minmax_reduction(func: T_Agg) -> bool:


 def _is_first_last_reduction(func: T_Agg) -> bool:
-    return isinstance(func, str) and func in ["nanfirst", "nanlast", "first", "last"]
+    if isinstance(func, Aggregation):
+        func = func.name
+    return func in ["nanfirst", "nanlast", "first", "last"]


 def _get_expected_groups(by: T_By, sort: bool) -> T_ExpectIndex:
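
For context: the predicate previously matched only strings; it now also accepts an Aggregation object by normalizing to its name first. A minimal sketch of the new behaviour (assuming flox.aggregations exposes a nanfirst Aggregation instance, which is how flox resolves string names internally):

    from flox.aggregations import nanfirst
    from flox.core import _is_first_last_reduction

    assert _is_first_last_reduction("nanfirst")  # plain string, as before
    assert _is_first_last_reduction(nanfirst)    # Aggregation object, matched via .name
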
@@ -1642,7 +1644,12 @@ def dask_groupby_agg(
     # This allows us to discover groups at compute time, support argreductions, lower intermediate
     # memory usage (but method="cohorts" would also work to reduce memory in some cases)
     labels_are_unknown = is_duck_dask_array(by_input) and expected_groups is None
-    do_simple_combine = not _is_arg_reduction(agg) and not labels_are_unknown
+    do_grouped_combine = (
+        _is_arg_reduction(agg)
+        or labels_are_unknown
+        or (_is_first_last_reduction(agg) and array.dtype.kind != "f")
+    )
+    do_simple_combine = not do_grouped_combine

     if method == "blockwise":
         # use the "non dask" code path, but applied blockwise
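
The dtype check is the crux of the fix: float intermediates can be reindexed using NaN as the fill value, which nanfirst/nanlast then skip during the simple combine, but integer (and other NaN-less) dtypes have no such sentinel, so a fill value would be indistinguishable from real data and the grouped combine is required. A rough illustration of the kind test:

    import numpy as np

    assert np.dtype("float64").kind == "f"  # NaN exists: simple combine is safe
    assert np.dtype("int8").kind == "i"     # no NaN sentinel: use the grouped combine
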
@@ -1698,7 +1705,7 @@ def dask_groupby_agg(

     tree_reduce = partial(
         dask.array.reductions._tree_reduce,
-        name=f"{name}-reduce",
+        name=f"{name}-simple-reduce",
         dtype=array.dtype,
         axis=axis,
         keepdims=True,
@@ -1733,14 +1740,20 @@ def dask_groupby_agg(
         groups_ = []
         for blks, cohort in chunks_cohorts.items():
             cohort_index = pd.Index(cohort)
-            reindexer = partial(reindex_intermediates, agg=agg, unique_groups=cohort_index)
+            reindexer = (
+                partial(reindex_intermediates, agg=agg, unique_groups=cohort_index)
+                if do_simple_combine
+                else identity
+            )
             reindexed = subset_to_blocks(intermediate, blks, block_shape, reindexer)
             # now that we have reindexed, we can set reindex=True explicitly
             reduced_.append(
                 tree_reduce(
                     reindexed,
-                    combine=partial(combine, agg=agg, reindex=True),
-                    aggregate=partial(aggregate, expected_groups=cohort_index, reindex=True),
+                    combine=partial(combine, agg=agg, reindex=do_simple_combine),
+                    aggregate=partial(
+                        aggregate, expected_groups=cohort_index, reindex=do_simple_combine
+                    ),
                 )
             )
         # This is done because pandas promotes to 64-bit types when an Index is created
@@ -1986,8 +1999,13 @@ def _validate_reindex(
     expected_groups,
     any_by_dask: bool,
     is_dask_array: bool,
+    array_dtype: Any,
 ) -> bool | None:
     # logger.debug("Entering _validate_reindex: reindex is {}".format(reindex))  # noqa
+    def first_or_last():
+        return func in ["first", "last"] or (
+            _is_first_last_reduction(func) and array_dtype.kind != "f"
+        )

     all_numpy = not is_dask_array and not any_by_dask
     if reindex is True and not all_numpy:
@@ -1997,7 +2015,7 @@ def _validate_reindex(
            raise ValueError(
                "reindex=True is not a valid choice for method='blockwise' or method='cohorts'."
            )
-        if func in ["first", "last"]:
+        if first_or_last():
            raise ValueError("reindex must be None or False when func is 'first' or 'last'.")

     if reindex is None:
@@ -2008,9 +2026,10 @@ def _validate_reindex(
        if all_numpy:
            return True

-        if func in ["first", "last"]:
+        if first_or_last():
            # have to do the grouped_combine since there's no good fill_value
-            reindex = False
+            # Also needed for nanfirst, nanlast with no-NaN dtypes
+            return False

        if method == "blockwise":
            # for grouping by dask arrays, we set reindex=True
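
Condensed, this is the decision table the updated tests below pin down for method="blockwise" with dask group labels (a sketch calling the private helper directly; in normal use groupby_reduce does this for you):

    import numpy as np
    from flox.core import _validate_reindex

    common = dict(
        method="blockwise",
        expected_groups=np.array([1, 2, 3]),
        any_by_dask=True,
        is_dask_array=True,
    )
    # nanfirst/nanlast: grouped combine only when the dtype lacks a NaN sentinel
    assert not _validate_reindex(None, "nanfirst", array_dtype=np.dtype("int32"), **common)
    assert _validate_reindex(None, "nanfirst", array_dtype=np.dtype("float32"), **common)
    # first/last: never reindex, regardless of dtype
    assert not _validate_reindex(None, "first", array_dtype=np.dtype("float32"), **common)
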
@@ -2412,12 +2431,19 @@ def groupby_reduce(
     if method == "cohorts" and any_by_dask:
         raise ValueError(f"method={method!r} can only be used when grouping by numpy arrays.")

+    if not is_duck_array(array):
+        array = np.asarray(array)
+
     reindex = _validate_reindex(
-        reindex, func, method, expected_groups, any_by_dask, is_duck_dask_array(array)
+        reindex,
+        func,
+        method,
+        expected_groups,
+        any_by_dask,
+        is_duck_dask_array(array),
+        array.dtype,
     )

-    if not is_duck_array(array):
-        array = np.asarray(array)
     is_bool_array = np.issubdtype(array.dtype, bool)
     array = array.astype(np.intp) if is_bool_array else array
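The reordering matters because _validate_reindex now inspects array.dtype: a plain Python list has no dtype, so the np.asarray coercion must run before validation rather than after. A minimal illustration:

    import numpy as np

    array = [0, 0, 0]          # plain Python list: no .dtype attribute
    array = np.asarray(array)  # coerce first ...
    print(array.dtype)         # ... so validation can inspect the dtype
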
@@ -2601,7 +2627,7 @@ def groupby_reduce(

     # TODO: clean this up
     reindex = _validate_reindex(
-        reindex, func, method, expected_, any_by_dask, is_duck_dask_array(array)
+        reindex, func, method, expected_, any_by_dask, is_duck_dask_array(array), array.dtype
     )

     if TYPE_CHECKING:

tests/conftest.py (+2 -1)
@@ -10,11 +10,12 @@
     suppress_health_check=[HealthCheck.filter_too_much, HealthCheck.too_slow],
 )
 settings.register_profile(
-    "local",
+    "default",
     max_examples=300,
     suppress_health_check=[HealthCheck.filter_too_much, HealthCheck.too_slow],
     verbosity=Verbosity.verbose,
 )
+settings.load_profile("default")


 @pytest.fixture(
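
In Hypothesis, register_profile only names a configuration; load_profile activates it. Renaming the verbose profile from "local" to "default" and loading it explicitly makes the 300-example verbose settings apply to every test run, not just runs that pass --hypothesis-profile. A condensed sketch of the pattern:

    from hypothesis import HealthCheck, Verbosity, settings

    settings.register_profile(
        "default",
        max_examples=300,
        suppress_health_check=[HealthCheck.filter_too_much, HealthCheck.too_slow],
        verbosity=Verbosity.verbose,
    )
    settings.load_profile("default")  # activate it: registering alone has no effect
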

tests/test_core.py (+87 -6)
@@ -613,6 +613,33 @@ def test_dask_reduce_axis_subset():
     )


+@pytest.mark.parametrize("group_idx", [[0, 1, 0], [0, 0, 1], [1, 0, 0], [1, 1, 0]])
+@pytest.mark.parametrize(
+    "func",
+    [
+        # "first", "last",
+        "nanfirst",
+        "nanlast",
+    ],
+)
+@pytest.mark.parametrize(
+    "chunks",
+    [
+        None,
+        pytest.param(1, marks=pytest.mark.skipif(not has_dask, reason="no dask")),
+        pytest.param(2, marks=pytest.mark.skipif(not has_dask, reason="no dask")),
+        pytest.param(3, marks=pytest.mark.skipif(not has_dask, reason="no dask")),
+    ],
+)
+def test_first_last_useless(func, chunks, group_idx):
+    array = np.array([[0, 0, 0], [0, 0, 0]], dtype=np.int8)
+    if chunks is not None:
+        array = dask.array.from_array(array, chunks=chunks)
+    actual, _ = groupby_reduce(array, np.array(group_idx), func=func, engine="numpy")
+    expected = np.array([[0, 0], [0, 0]], dtype=np.int8)
+    assert_equal(actual, expected)
+
+
 @pytest.mark.parametrize("func", ["first", "last", "nanfirst", "nanlast"])
 @pytest.mark.parametrize("axis", [(0, 1)])
 def test_first_last_disallowed(axis, func):
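
The "useless" in the test name is the point: on data containing no NaNs, nanfirst/nanlast must degenerate to the plain first/last element of each group, with no fill value leaking into the int8 result. A condensed numpy-only repro using one of the parametrized cases:

    import numpy as np
    from flox.core import groupby_reduce

    array = np.array([[0, 0, 0], [0, 0, 0]], dtype=np.int8)
    actual, groups = groupby_reduce(array, np.array([0, 1, 0]), func="nanfirst", engine="numpy")
    # two groups -> one value per group per row, still all-zero int8
    np.testing.assert_array_equal(actual, np.zeros((2, 2), dtype=np.int8))
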
@@ -1563,18 +1590,36 @@ def test_validate_reindex_map_reduce(
     dask_expected, reindex, func, expected_groups, any_by_dask
 ) -> None:
     actual = _validate_reindex(
-        reindex, func, "map-reduce", expected_groups, any_by_dask, is_dask_array=True
+        reindex,
+        func,
+        "map-reduce",
+        expected_groups,
+        any_by_dask,
+        is_dask_array=True,
+        array_dtype=np.dtype("int32"),
     )
     assert actual is dask_expected

     # always reindex with all numpy inputs
     actual = _validate_reindex(
-        reindex, func, "map-reduce", expected_groups, any_by_dask=False, is_dask_array=False
+        reindex,
+        func,
+        "map-reduce",
+        expected_groups,
+        any_by_dask=False,
+        is_dask_array=False,
+        array_dtype=np.dtype("int32"),
     )
     assert actual

     actual = _validate_reindex(
-        True, func, "map-reduce", expected_groups, any_by_dask=False, is_dask_array=False
+        True,
+        func,
+        "map-reduce",
+        expected_groups,
+        any_by_dask=False,
+        is_dask_array=False,
+        array_dtype=np.dtype("int32"),
     )
     assert actual
@@ -1584,19 +1629,37 @@ def test_validate_reindex() -> None:
     for method in methods:
         with pytest.raises(NotImplementedError):
             _validate_reindex(
-                True, "argmax", method, expected_groups=None, any_by_dask=False, is_dask_array=True
+                True,
+                "argmax",
+                method,
+                expected_groups=None,
+                any_by_dask=False,
+                is_dask_array=True,
+                array_dtype=np.dtype("int32"),
             )

     methods: list[T_Method] = ["blockwise", "cohorts"]
     for method in methods:
         with pytest.raises(ValueError):
             _validate_reindex(
-                True, "sum", method, expected_groups=None, any_by_dask=False, is_dask_array=True
+                True,
+                "sum",
+                method,
+                expected_groups=None,
+                any_by_dask=False,
+                is_dask_array=True,
+                array_dtype=np.dtype("int32"),
             )

         for func in ["sum", "argmax"]:
             actual = _validate_reindex(
-                None, func, method, expected_groups=None, any_by_dask=False, is_dask_array=True
+                None,
+                func,
+                method,
+                expected_groups=None,
+                any_by_dask=False,
+                is_dask_array=True,
+                array_dtype=np.dtype("int32"),
             )
             assert actual is False
@@ -1608,6 +1671,7 @@ def test_validate_reindex() -> None:
         expected_groups=np.array([1, 2, 3]),
         any_by_dask=False,
         is_dask_array=True,
+        array_dtype=np.dtype("int32"),
     )

     assert _validate_reindex(
@@ -1617,6 +1681,7 @@ def test_validate_reindex() -> None:
         expected_groups=np.array([1, 2, 3]),
         any_by_dask=True,
         is_dask_array=True,
+        array_dtype=np.dtype("int32"),
     )
     assert _validate_reindex(
         None,
@@ -1625,8 +1690,24 @@ def test_validate_reindex() -> None:
         expected_groups=np.array([1, 2, 3]),
         any_by_dask=True,
         is_dask_array=True,
+        array_dtype=np.dtype("int32"),
+    )
+
+    kwargs = dict(
+        method="blockwise",
+        expected_groups=np.array([1, 2, 3]),
+        any_by_dask=True,
+        is_dask_array=True,
     )

+    for func in ["nanfirst", "nanlast"]:
+        assert not _validate_reindex(None, func, array_dtype=np.dtype("int32"), **kwargs)  # type: ignore[arg-type]
+        assert _validate_reindex(None, func, array_dtype=np.dtype("float32"), **kwargs)  # type: ignore[arg-type]
+
+    for func in ["first", "last"]:
+        assert not _validate_reindex(None, func, array_dtype=np.dtype("int32"), **kwargs)  # type: ignore[arg-type]
+        assert not _validate_reindex(None, func, array_dtype=np.dtype("float32"), **kwargs)  # type: ignore[arg-type]


 @requires_dask
 def test_1d_blockwise_sort_optimization():

tests/test_properties.py (+15)
@@ -9,6 +9,7 @@
 pytest.importorskip("cftime")

 import dask
+import hypothesis.extra.numpy as npst
 import hypothesis.strategies as st
 import numpy as np
 from hypothesis import assume, given, note
@@ -19,6 +20,7 @@

 from . import assert_equal
 from .strategies import by_arrays, chunked_arrays, func_st, numeric_arrays
+from .strategies import chunks as chunks_strategy

 dask.config.set(scheduler="sync")
@@ -208,3 +210,16 @@ def test_first_last(data, array: dask.array.Array, func: str) -> None:
     first, *_ = groupby_reduce(array, by, func=func, engine="flox")
     second, *_ = groupby_reduce(array, by, func=mate, engine="flox")
     assert_equal(first, second)
+
+
+@given(data=st.data(), func=st.sampled_from(["nanfirst", "nanlast"]))
+def test_first_last_useless(data, func):
+    shape = data.draw(npst.array_shapes())
+    by = data.draw(by_arrays(shape=shape[slice(-1, None)]))
+    chunks = data.draw(chunks_strategy(shape=shape))
+    array = np.zeros(shape, dtype=np.int8)
+    if chunks is not None:
+        array = dask.array.from_array(array, chunks=chunks)
+    actual, groups = groupby_reduce(array, by, axis=-1, func=func, engine="numpy")
+    expected = np.zeros(shape[:-1] + (len(groups),), dtype=array.dtype)
+    assert_equal(actual, expected)
