Commit 6a8bbf7

Avoid realizing a potentially very large RangeIndex into memory
xref #428
1 parent 9e82b66 commit 6a8bbf7
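
The motivation, sketched with illustrative numbers (none of these appear in the diff): a RangeIndex only stores start/stop/step, so returning it directly in `groups` is essentially free, whereas calling `.to_numpy()` on it realizes every value as an int64 array.

import pandas as pd

idx = pd.RangeIndex(10**9)
print(idx.memory_usage())  # a small constant: only start/stop/step are stored
# idx.to_numpy() would allocate roughly 8 GB of int64 values here; this commit
# avoids that by keeping the pd.Index in `groups` instead of converting it.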

2 files changed: +13 -5 lines
flox/core.py (+5 -5)

@@ -1602,7 +1602,7 @@ def dask_groupby_agg(
     engine: T_Engine = "numpy",
     sort: bool = True,
     chunks_cohorts=None,
-) -> tuple[DaskArray, tuple[np.ndarray | DaskArray]]:
+) -> tuple[DaskArray, tuple[pd.Index | np.ndarray | DaskArray]]:
     import dask.array
     from dask.array.core import slices_from_chunks
     from dask.highlevelgraph import HighLevelGraph
@@ -1730,7 +1730,7 @@ def dask_groupby_agg(
             group_chunks = ((np.nan,),)
         else:
             assert expected_groups is not None
-            groups = (expected_groups.to_numpy(),)
+            groups = (expected_groups,)
             group_chunks = ((len(expected_groups),),)

     elif method == "cohorts":
@@ -1846,7 +1846,7 @@ def cubed_groupby_agg(
     engine: T_Engine = "numpy",
     sort: bool = True,
     chunks_cohorts=None,
-) -> tuple[CubedArray, tuple[np.ndarray | CubedArray]]:
+) -> tuple[CubedArray, tuple[pd.Index | np.ndarray | CubedArray]]:
     import cubed
     import cubed.core.groupby

@@ -1882,7 +1882,7 @@ def _reduction_func(a, by, axis, start_group, num_groups):
         result = cubed.core.groupby.groupby_blockwise(
             array, by, axis=axis, func=_reduction_func, num_groups=num_groups
         )
-        groups = (expected_groups.to_numpy(),)
+        groups = (expected_groups,)
         return (result, groups)

     else:
@@ -1964,7 +1964,7 @@ def _groupby_aggregate(a, **kwargs):
             num_groups=num_groups,
         )

-        groups = (expected_groups.to_numpy(),)
+        groups = (expected_groups,)

         return (result, groups)

flox/dask_array_ops.py (+8)

@@ -4,14 +4,22 @@
 from itertools import product
 from numbers import Integral

+import pandas as pd
 from dask import config
+from dask.base import normalize_token
 from dask.blockwise import lol_tuples
 from toolz import partition_all

 from .lib import ArrayLayer
 from .types import Graph


+# workaround for https://github.com/dask/dask/issues/11862
+@normalize_token.register(pd.RangeIndex)
+def normalize_range_index(x):
+    return normalize_token(type(x)), x.start, x.stop, x.step, x.dtype, x.name
+
+
 # _tree_reduce and partial_reduce are copied from dask.array.reductions
 # They have been modified to work purely with graphs, and without creating new Array layers
 # in the graph. The `block_index` kwarg is new and avoids a concatenation by simply setting the right
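
A minimal sketch of what the normalize_token registration above buys, assuming it has been applied (e.g. by importing flox.dask_array_ops); tokenize is dask's public deterministic-hashing helper:

import pandas as pd
import flox.dask_array_ops  # noqa: F401  (runs the normalize_token registration above)
from dask.base import tokenize

idx = pd.RangeIndex(10**9)
# Only (type, start, stop, step, dtype, name) are hashed, so tokenizing stays cheap
# and never materializes the underlying values.
assert tokenize(idx) == tokenize(pd.RangeIndex(10**9))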
