diff --git a/flox/core.py b/flox/core.py
index 1c9599a15..570c09cc0 100644
--- a/flox/core.py
+++ b/flox/core.py
@@ -329,7 +329,7 @@ def chunk_unique(labels, slicer, nlabels, label_is_present=None):
     rows_array = np.repeat(np.arange(nchunks), tuple(len(col) for col in cols))
     cols_array = np.concatenate(cols)
 
-    return make_bitmask(rows_array, cols_array)
+    return make_bitmask(rows_array, cols_array), nlabels, ilabels
 
 
 # @memoize
@@ -362,10 +362,9 @@ def find_group_cohorts(
     cohorts: dict_values
         Iterable of cohorts
     """
-    # To do this, we must have values in memory so casting to numpy should be safe
-    labels = np.asarray(labels)
+    if not is_duck_array(labels):
+        labels = np.asarray(labels)
 
-    shape = tuple(sum(c) for c in chunks)
     nchunks = math.prod(len(c) for c in chunks)
 
     # assumes that `labels` are factorized
@@ -378,8 +377,14 @@ def find_group_cohorts(
     if nchunks == 1:
         return "blockwise", {(0,): list(range(nlabels))}
 
-    labels = np.broadcast_to(labels, shape[-labels.ndim :])
-    bitmask = _compute_label_chunk_bitmask(labels, chunks, nlabels)
+    if is_duck_dask_array(labels):
+        import dask
+
+        ((bitmask, nlabels, ilabels),) = dask.compute(
+            dask.delayed(_compute_label_chunk_bitmask)(labels, chunks, nlabels)
+        )
+    else:
+        bitmask, nlabels, ilabels = _compute_label_chunk_bitmask(labels, chunks, nlabels)
 
     CHUNK_AXIS, LABEL_AXIS = 0, 1
     chunks_per_label = bitmask.sum(axis=CHUNK_AXIS)
@@ -726,6 +731,26 @@ def offset_labels(labels: np.ndarray, ngroups: int) -> tuple[np.ndarray, int]:
     return offset, size
 
 
+def fast_isin(ar1, ar2, invert):
+    rev_idx, ar1 = pd.factorize(ar1, sort=False)
+
+    ar = np.concatenate((ar1, ar2))
+    # We need this to be a stable sort, so always use 'mergesort'
+    # here. The values from the first array should always come before
+    # the values from the second array.
+    order = ar.argsort(kind="mergesort")
+    sar = ar[order]
+    if invert:
+        bool_ar = sar[1:] != sar[:-1]
+    else:
+        bool_ar = sar[1:] == sar[:-1]
+    flag = np.concatenate((bool_ar, [invert]))
+    ret = np.empty(ar.shape, dtype=bool)
+    ret[order] = flag
+
+    return ret[rev_idx]
+
+
 @overload
 def factorize_(
     by: T_Bys,
@@ -821,8 +846,18 @@ def factorize_(
         if expect is not None and reindex:
             sorter = np.argsort(expect)
             groups = expect[(sorter,)] if sort else expect
+
             idx = np.searchsorted(expect, flat, sorter=sorter)
-            mask = ~np.isin(flat, expect) | isnull(flat) | (idx == len(expect))
+            mask = fast_isin(flat, expect, invert=True)
+            if not np.issubdtype(flat.dtype, np.integer):
+                mask |= isnull(flat)
+            mask |= idx == len(expect)
+
+            # idx = np.full(flat.shape, -1)
+            # result = np.searchsorted(expect.values, flat[~mask], sorter=sorter)
+            # idx[~mask] = result
+            # idx = np.searchsorted(expect.values, flat, sorter=sorter)
+            # idx[mask] = -1
             if not sort:
                 # idx is the index in to the sorted array.
                 # if we didn't want sorting, unsort it back
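An aside on `fast_isin` above: it is the stable-mergesort membership kernel that `np.isin` uses internally, with one extra step — `ar1` is run through `pd.factorize` first, so the O(n log n) sort only sees `ar1`'s unique values plus `ar2` instead of every element of `flat`. A minimal standalone check of the equivalence (hypothetical usage; assumes this branch is installed, and uses integer inputs since NaNs are handled separately by the `isnull` branch in `factorize_`):

import numpy as np

from flox.core import fast_isin  # added by this PR

flat = np.array([3, 1, 4, 1, 5, 9, 2, 6, 5, 3])
expect = np.array([1, 2, 3, 4, 5])

for invert in (False, True):
    # for integer inputs, fast_isin must agree with np.isin exactly
    np.testing.assert_array_equal(
        fast_isin(flat, expect, invert=invert),
        np.isin(flat, expect, invert=invert),
    )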
@@ -2125,11 +2160,10 @@ def _factorize_multiple(
         for by_, expect in zip(by, expected_groups):
             if expect is None:
                 if is_duck_dask_array(by_):
-                    raise ValueError(
-                        "Please provide expected_groups when grouping by a dask array."
-                    )
-
-                found_group = pd.unique(by_.reshape(-1))
+                    # could be remote dataset, execute remotely in that case
+                    found_group = np.unique(by_.reshape(-1)).compute()
+                else:
+                    found_group = pd.unique(by_.reshape(-1))
             else:
                 found_group = expect.to_numpy()
 
@@ -2409,9 +2443,6 @@ def groupby_reduce(
             "Try engine='numpy' or engine='numba' instead."
         )
 
-    if method == "cohorts" and any_by_dask:
-        raise ValueError(f"method={method!r} can only be used when grouping by numpy arrays.")
-
     reindex = _validate_reindex(
         reindex, func, method, expected_groups, any_by_dask, is_duck_dask_array(array)
     )
@@ -2439,10 +2470,15 @@ def groupby_reduce(
 
     # Don't factorize early only when
     # grouping by dask arrays, and not having expected_groups
+    # except for cohorts
    factorize_early = not (
        # can't do it if we are grouping by dask array but don't have expected_groups
-        any(is_dask and ex_ is None for is_dask, ex_ in zip(by_is_dask, expected_groups))
+        any(
+            is_dask and ex_ is None and method != "cohorts"
+            for is_dask, ex_ in zip(by_is_dask, expected_groups)
+        )
    )
 
+    expected_: pd.RangeIndex | None
    if factorize_early:
        bys, final_groups, grp_shape = _factorize_multiple(
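The `dask.delayed` call in `find_group_cohorts` (first file above) is the heart of the change: instead of casting `labels` to numpy on the client, the entire bitmask construction is shipped to the scheduler as a single task, and only the small `(bitmask, nlabels, ilabels)` triple travels back — which is why grouping by a dask array now costs exactly one eager compute. A minimal sketch of the same pattern, with `summarize` standing in for `_compute_label_chunk_bitmask` (name and body are illustrative, not from the PR):

import dask
import dask.array
import numpy as np


def summarize(labels):
    # stand-in for _compute_label_chunk_bitmask: needs concrete label
    # values and returns several small results at once
    uniques = np.unique(labels)
    return uniques, len(uniques)


labels = dask.array.from_array(np.array([0, 1, 1, 2, 2, 2]), chunks=2)

# dask.delayed materializes the dask-array argument on the cluster and passes
# the concrete numpy array to `summarize`; dask.compute returns a 1-tuple,
# hence the extra level of unpacking, mirroring the diff above
((uniques, nlabels),) = dask.compute(dask.delayed(summarize)(labels))
print(uniques, nlabels)  # [0 1 2] 3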
diff --git a/tests/test_core.py b/tests/test_core.py
index e12e695db..392835357 100644
--- a/tests/test_core.py
+++ b/tests/test_core.py
@@ -862,15 +862,13 @@ def test_groupby_bins(chunk_labels, kwargs, chunks, engine, method) -> None:
     array = [1, 1, 1, 1, 1, 1]
     labels = [0.2, 1.5, 1.9, 2, 3, 20]
 
-    if method == "cohorts" and chunk_labels:
-        pytest.xfail()
-
     if chunks:
         array = dask.array.from_array(array, chunks=chunks)
         if chunk_labels:
             labels = dask.array.from_array(labels, chunks=chunks)
 
-    with raise_if_dask_computes():
+    max_computes = 1 if method == "cohorts" else 0
+    with raise_if_dask_computes(max_computes):
         actual, *groups = groupby_reduce(
             array, labels, func="count", fill_value=0, engine=engine, method=method, **kwargs
         )
@@ -1063,10 +1061,12 @@ def test_cohorts_map_reduce_consistent_dtypes(method, dtype, labels_dtype):
 
 
 @requires_dask
-@pytest.mark.parametrize("func", ALL_FUNCS)
+@pytest.mark.parametrize("func", ["sum"])
 @pytest.mark.parametrize("axis", (-1, None))
 @pytest.mark.parametrize("method", ["blockwise", "cohorts", "map-reduce"])
-def test_cohorts_nd_by(func, method, axis, engine):
+@pytest.mark.parametrize("by_is_dask", [True, False])
+def test_cohorts_nd_by(by_is_dask, func, method, axis):
+    engine = "numpy"
     if (
         ("arg" in func and (axis is None or engine in ["flox", "numbagg"]))
         or (method != "blockwise" and func in BLOCKWISE_FUNCS)
@@ -1074,16 +1074,20 @@
     ):
         pytest.skip()
     if axis is not None and method != "map-reduce":
-        pytest.xfail()
+        pytest.skip()
+    if by_is_dask and method == "blockwise":
+        pytest.skip()
 
     o = dask.array.ones((3,), chunks=-1)
     o2 = dask.array.ones((2, 3), chunks=-1)
 
     array = dask.array.block([[o, 2 * o], [3 * o2, 4 * o2]])
-    by = array.compute().astype(np.int64)
+    by = array.astype(np.int64)
     by[0, 1] = 30
     by[2, 1] = 40
     by[0, 4] = 31
+    if not by_is_dask:
+        by = by.compute()
     array = np.broadcast_to(array, (2, 3) + array.shape)
 
     if func in ["any", "all"]:
@@ -1092,6 +1096,9 @@
         fill_value = -123
 
     kwargs = dict(func=func, engine=engine, method=method, axis=axis, fill_value=fill_value)
+    if by_is_dask and axis is not None and method == "map-reduce":
+        kwargs["expected_groups"] = pd.Index([1, 2, 3, 4, 30, 31, 40])
+
     if "quantile" in func:
         kwargs["finalize_kwargs"] = {"q": DEFAULT_QUANTILE}
     actual, groups = groupby_reduce(array, by, **kwargs)
@@ -1099,10 +1106,20 @@
     assert_equal(groups, sorted_groups)
     assert_equal(actual, expected)
 
-    actual, groups = groupby_reduce(array, by, sort=False, **kwargs)
-    assert_equal(groups, np.array([1, 30, 2, 31, 3, 4, 40], dtype=np.int64))
-    reindexed = reindex_(actual, groups, pd.Index(sorted_groups))
-    assert_equal(reindexed, expected)
+    if isinstance(by, dask.array.Array):
+        cache.clear()
+        actual_cohorts = find_group_cohorts(by, array.chunks[-by.ndim :])
+        cache.clear()
+        expected_cohorts = find_group_cohorts(by.compute(), array.chunks[-by.ndim :])
+        assert actual_cohorts == expected_cohorts
+        # assert cache.nbytes
+
+    if not isinstance(by, dask.array.Array):
+        # Always sorting groups with cohorts and dask array
+        actual, groups = groupby_reduce(array, by, sort=False, **kwargs)
+        assert_equal(groups, np.array([1, 30, 2, 31, 3, 4, 40], dtype=np.int64))
+        reindexed = reindex_(actual, groups, pd.Index(sorted_groups))
+        assert_equal(reindexed, expected)
 
 
 @pytest.mark.parametrize("func", ["sum", "count"])
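Taken together, the user-visible effect is that `expected_groups` is no longer mandatory for a dask `by`, and `method="cohorts"` is no longer rejected outright — at the price of the single compute that discovers the labels (hence `max_computes = 1` in the test above). A hypothetical end-to-end sketch of the newly allowed call (exact outputs depend on this branch; before this PR, the call raised `ValueError`):

import dask.array
import numpy as np

from flox.core import groupby_reduce

array = dask.array.ones((12,), chunks=3)
# dask `by` with no expected_groups: the case this PR enables
by = dask.array.from_array(np.repeat(np.arange(4), 3), chunks=3)

result, groups = groupby_reduce(array, by, func="sum", method="cohorts")
print(groups)            # [0 1 2 3]
print(result.compute())  # [3. 3. 3. 3.]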