diff --git a/flox/core.py b/flox/core.py index 46124e2ab..8a12e0304 100644 --- a/flox/core.py +++ b/flox/core.py @@ -1602,7 +1602,7 @@ def dask_groupby_agg( engine: T_Engine = "numpy", sort: bool = True, chunks_cohorts=None, -) -> tuple[DaskArray, tuple[np.ndarray | DaskArray]]: +) -> tuple[DaskArray, tuple[pd.Index | np.ndarray | DaskArray]]: import dask.array from dask.array.core import slices_from_chunks from dask.highlevelgraph import HighLevelGraph @@ -1730,7 +1730,7 @@ def dask_groupby_agg( group_chunks = ((np.nan,),) else: assert expected_groups is not None - groups = (expected_groups.to_numpy(),) + groups = (expected_groups,) group_chunks = ((len(expected_groups),),) elif method == "cohorts": @@ -1846,7 +1846,7 @@ def cubed_groupby_agg( engine: T_Engine = "numpy", sort: bool = True, chunks_cohorts=None, -) -> tuple[CubedArray, tuple[np.ndarray | CubedArray]]: +) -> tuple[CubedArray, tuple[pd.Index | np.ndarray | CubedArray]]: import cubed import cubed.core.groupby @@ -1882,7 +1882,7 @@ def _reduction_func(a, by, axis, start_group, num_groups): result = cubed.core.groupby.groupby_blockwise( array, by, axis=axis, func=_reduction_func, num_groups=num_groups ) - groups = (expected_groups.to_numpy(),) + groups = (expected_groups,) return (result, groups) else: @@ -1964,7 +1964,7 @@ def _groupby_aggregate(a, **kwargs): num_groups=num_groups, ) - groups = (expected_groups.to_numpy(),) + groups = (expected_groups,) return (result, groups) diff --git a/flox/dask_array_ops.py b/flox/dask_array_ops.py index 15fbe08a0..4812229f4 100644 --- a/flox/dask_array_ops.py +++ b/flox/dask_array_ops.py @@ -4,7 +4,9 @@ from itertools import product from numbers import Integral +import pandas as pd from dask import config +from dask.base import normalize_token from dask.blockwise import lol_tuples from toolz import partition_all @@ -12,6 +14,12 @@ from .types import Graph +# workaround for https://github.com/dask/dask/issues/11862 +@normalize_token.register(pd.RangeIndex) +def normalize_range_index(x): + return normalize_token(type(x)), x.start, x.stop, x.step, x.dtype, x.name + + # _tree_reduce and partial_reduce are copied from dask.array.reductions # They have been modified to work purely with graphs, and without creating new Array layers # in the graph. The `block_index` kwarg is new and avoids a concatenation by simply setting the right diff --git a/tests/conftest.py b/tests/conftest.py index a81c825d2..ac0624ceb 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -30,15 +30,3 @@ ) def engine(request): return request.param - - -@pytest.fixture( - scope="module", - params=[ - "flox", - "numpy", - pytest.param("numbagg", marks=requires_numbagg), - ], -) -def engine_no_numba(request): - return request.param diff --git a/tests/test_xarray.py b/tests/test_xarray.py index 85f8453b2..737c74135 100644 --- a/tests/test_xarray.py +++ b/tests/test_xarray.py @@ -34,8 +34,7 @@ @pytest.mark.parametrize("min_count", [None, 1, 3]) @pytest.mark.parametrize("add_nan", [True, False]) @pytest.mark.parametrize("skipna", [True, False]) -def test_xarray_reduce(skipna, add_nan, min_count, engine_no_numba, reindex): - engine = engine_no_numba +def test_xarray_reduce(skipna, add_nan, min_count, engine, reindex): if skipna is False and min_count is not None: pytest.skip() @@ -91,11 +90,9 @@ def test_xarray_reduce(skipna, add_nan, min_count, engine_no_numba, reindex): # TODO: sort @pytest.mark.parametrize("pass_expected_groups", [True, False]) @pytest.mark.parametrize("chunk", (pytest.param(True, marks=requires_dask), False)) -def test_xarray_reduce_multiple_groupers(pass_expected_groups, chunk, engine_no_numba): +def test_xarray_reduce_multiple_groupers(pass_expected_groups, chunk, engine): if chunk and pass_expected_groups is False: pytest.skip() - engine = engine_no_numba - arr = np.ones((4, 12)) labels = np.array(["a", "a", "c", "c", "c", "b", "b", "c", "c", "b", "b", "f"]) labels2 = np.array([1, 2, 2, 1]) @@ -140,10 +137,9 @@ def test_xarray_reduce_multiple_groupers(pass_expected_groups, chunk, engine_no_ @pytest.mark.parametrize("pass_expected_groups", [True, False]) @pytest.mark.parametrize("chunk", (pytest.param(True, marks=requires_dask), False)) -def test_xarray_reduce_multiple_groupers_2(pass_expected_groups, chunk, engine_no_numba): +def test_xarray_reduce_multiple_groupers_2(pass_expected_groups, chunk, engine): if chunk and pass_expected_groups is False: pytest.skip() - engine = engine_no_numba arr = np.ones((2, 12)) labels = np.array(["a", "a", "c", "c", "c", "b", "b", "c", "c", "b", "b", "f"]) @@ -218,8 +214,7 @@ def test_xarray_reduce_cftime_var(engine, indexer, expected_groups, func): @requires_cftime @requires_dask -def test_xarray_reduce_single_grouper(engine_no_numba): - engine = engine_no_numba +def test_xarray_reduce_single_grouper(engine): # DataArray ds = xr.Dataset( { @@ -326,8 +321,7 @@ def test_rechunk_for_blockwise(inchunks, expected): # TODO: dim=None, dim=Ellipsis, groupby unindexed dim -def test_groupby_duplicate_coordinate_labels(engine_no_numba): - engine = engine_no_numba +def test_groupby_duplicate_coordinate_labels(engine): # fix for http://stackoverflow.com/questions/38065129 array = xr.DataArray([1, 2, 3], [("x", [1, 1, 2])]) expected = xr.DataArray([3, 3], [("x", [1, 2])]) @@ -335,8 +329,7 @@ def test_groupby_duplicate_coordinate_labels(engine_no_numba): assert_equal(expected, actual) -def test_multi_index_groupby_sum(engine_no_numba): - engine = engine_no_numba +def test_multi_index_groupby_sum(engine): # regression test for xarray GH873 ds = xr.Dataset( {"foo": (("x", "y", "z"), np.ones((3, 4, 2)))}, @@ -362,8 +355,7 @@ def test_multi_index_groupby_sum(engine_no_numba): @pytest.mark.parametrize("chunks", (None, pytest.param(2, marks=requires_dask))) -def test_xarray_groupby_bins(chunks, engine_no_numba): - engine = engine_no_numba +def test_xarray_groupby_bins(chunks, engine): array = xr.DataArray([1, 1, 1, 1, 1], dims="x") labels = xr.DataArray([1, 1.5, 1.9, 2, 3], dims="x", name="labels") @@ -532,11 +524,10 @@ def test_alignment_error(): @pytest.mark.parametrize("dtype_out", [np.float64, "float64", np.dtype("float64")]) @pytest.mark.parametrize("dtype", [np.float32, np.float64]) @pytest.mark.parametrize("chunk", (pytest.param(True, marks=requires_dask), False)) -def test_dtype(add_nan, chunk, dtype, dtype_out, engine_no_numba): - if engine_no_numba == "numbagg": +def test_dtype(add_nan, chunk, dtype, dtype_out, engine): + if engine == "numbagg": # https://github.com/numbagg/numbagg/issues/121 pytest.skip() - engine = engine_no_numba xp = dask.array if chunk else np data = xp.linspace(0, 1, 48, dtype=dtype).reshape((4, 12))