
Commit eb3c0ef

Manually fuse reindexing intermediates with blockwise reduction for cohorts. (#300)
* Manually fuse reindexing intermediates with blockwise reduction for cohorts.

  ```
  | Change | Before [627bf2b] <main> | After [9d710529] <optimize-cohorts-graph> | Ratio | Benchmark (Parameter)                            |
  |--------|-------------------------|-------------------------------------------|-------|--------------------------------------------------|
  | -      | 3.39±0.02ms             | 2.98±0.01ms                                | 0.88  | cohorts.PerfectMonthly.time_graph_construct      |
  | -      | 20                      | 17                                         | 0.85  | cohorts.PerfectMonthly.track_num_layers          |
  | -      | 23.0±0.07ms             | 19.0±0.1ms                                 | 0.83  | cohorts.ERA5Google.time_graph_construct          |
  | -      | 4878                    | 3978                                       | 0.82  | cohorts.ERA5Google.track_num_tasks               |
  | -      | 179±0.8ms               | 147±0.5ms                                  | 0.82  | cohorts.OISST.time_graph_construct               |
  | -      | 159                     | 128                                        | 0.81  | cohorts.ERA5Google.track_num_layers              |
  | -      | 936                     | 762                                        | 0.81  | cohorts.PerfectMonthly.track_num_tasks           |
  | -      | 1221                    | 978                                        | 0.8   | cohorts.OISST.track_num_layers                   |
  | -      | 4929                    | 3834                                       | 0.78  | cohorts.ERA5DayOfYear.track_num_tasks            |
  | -      | 351                     | 274                                        | 0.78  | cohorts.NWMMidwest.track_num_layers              |
  | -      | 4562                    | 3468                                       | 0.76  | cohorts.ERA5DayOfYear.track_num_tasks_optimized  |
  | -      | 164±1ms                 | 118±0.4ms                                  | 0.72  | cohorts.ERA5DayOfYear.time_graph_construct       |
  | -      | 1100                    | 735                                        | 0.67  | cohorts.ERA5DayOfYear.track_num_layers           |
  | -      | 3930                    | 2605                                       | 0.66  | cohorts.NWMMidwest.track_num_tasks               |
  | -      | 3715                    | 2409                                       | 0.65  | cohorts.NWMMidwest.track_num_tasks_optimized     |
  | -      | 28952                   | 18798                                      | 0.65  | cohorts.OISST.track_num_tasks                    |
  | -      | 27010                   | 16858                                      | 0.62  | cohorts.OISST.track_num_tasks_optimized          |
  ```

* fix typing
1 parent 2439c5c commit eb3c0ef
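The layer and task savings come from applying the reindexing callable inside the block-subsetting graph layer itself, rather than stacking a separate `map_blocks` layer on top. Below is a minimal stand-alone sketch of that pattern; the names `double`, `subset-*`, and `fused-*` are illustrative only and are not flox code.

```python
import dask.array as da
import numpy as np
from dask.base import tokenize
from dask.highlevelgraph import HighLevelGraph


def double(x):
    return 2 * x


arr = da.ones((4,), chunks=(1,))

# Before: one layer that merely re-keys the blocks, plus a separate
# map_blocks layer that applies the function -> two new layers.
name = "subset-" + tokenize(arr)
layer = {(name, i): (arr.name, i) for i in range(arr.numblocks[0])}
graph = HighLevelGraph.from_collections(name, layer, dependencies=[arr])
subset = da.Array(graph, name, arr.chunks, meta=arr._meta)
before = da.map_blocks(double, subset)

# After: wrap the callable around the existing block keys in the same
# layer -> only one new layer is added.
fused_name = "fused-" + tokenize(arr, double)
fused_layer = {(fused_name, i): (double, (arr.name, i)) for i in range(arr.numblocks[0])}
fused_graph = HighLevelGraph.from_collections(fused_name, fused_layer, dependencies=[arr])
after = da.Array(fused_graph, fused_name, arr.chunks, meta=arr._meta)

print(len(before.dask.layers), len(after.dask.layers))  # 3 layers vs 2
np.testing.assert_allclose(before.compute(), after.compute())
```

In the commit, `subset_to_blocks` plays the role of the fused layer, with `reindex_intermediates` partially bound to each cohort's index as the callable.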

File tree

3 files changed: +33 −19 lines

asv_bench/benchmarks/cohorts.py

Lines changed: 6 additions & 6 deletions
```diff
@@ -14,8 +14,8 @@ def setup(self, *args, **kwargs):
         raise NotImplementedError
 
     @cached_property
-    def dask(self):
-        return flox.groupby_reduce(self.array, self.by, func="sum", axis=self.axis)[0].dask
+    def result(self):
+        return flox.groupby_reduce(self.array, self.by, func="sum", axis=self.axis)[0]
 
     def containment(self):
         asfloat = self.bitmask().astype(float)
@@ -52,14 +52,14 @@ def time_graph_construct(self):
         flox.groupby_reduce(self.array, self.by, func="sum", axis=self.axis)
 
     def track_num_tasks(self):
-        return len(self.dask.to_dict())
+        return len(self.result.dask.to_dict())
 
     def track_num_tasks_optimized(self):
-        (opt,) = dask.optimize(self.dask)
-        return len(opt.to_dict())
+        (opt,) = dask.optimize(self.result)
+        return len(opt.dask.to_dict())
 
     def track_num_layers(self):
-        return len(self.dask.layers)
+        return len(self.result.dask.layers)
 
     track_num_tasks.unit = "tasks"  # type: ignore[attr-defined] # Lazy
     track_num_tasks_optimized.unit = "tasks"  # type: ignore[attr-defined] # Lazy
```
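For reference, the asv `track_*` benchmarks read these counts straight off the graph of the cached result. A rough stand-alone sketch of that pattern follows; the placeholder `result` below is an ordinary dask reduction, not flox's groupby.

```python
import dask
import dask.array as da

# Placeholder for the cached `result` property (any lazy dask collection).
result = da.ones((4,), chunks=(1,)).sum()

num_tasks = len(result.dask.to_dict())           # track_num_tasks
(opt,) = dask.optimize(result)                   # optimize the collection itself
num_tasks_optimized = len(opt.dask.to_dict())    # track_num_tasks_optimized
num_layers = len(result.dask.layers)             # track_num_layers
print(num_tasks, num_tasks_optimized, num_layers)
```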

flox/core.py

Lines changed: 21 additions & 11 deletions
```diff
@@ -17,6 +17,7 @@
     Callable,
     Literal,
     TypedDict,
+    TypeVar,
     Union,
     overload,
 )
@@ -96,6 +97,7 @@
 T_MethodOpt = None | Literal["map-reduce", "blockwise", "cohorts"]
 T_IsBins = Union[bool | Sequence[bool]]
 
+T = TypeVar("T")
 
 IntermediateDict = dict[Union[str, Callable], Any]
 FinalResultsDict = dict[str, Union["DaskArray", "CubedArray", np.ndarray]]
@@ -140,6 +142,10 @@ def _postprocess_numbagg(result, *, func, fill_value, size, seen_groups):
     return result
 
 
+def identity(x: T) -> T:
+    return x
+
+
 def _issorted(arr: np.ndarray) -> bool:
     return bool((arr[:-1] <= arr[1:]).all())
 
@@ -1438,7 +1444,10 @@ def _normalize_indexes(array: DaskArray, flatblocks, blkshape) -> tuple:
 
 
 def subset_to_blocks(
-    array: DaskArray, flatblocks: Sequence[int], blkshape: tuple[int] | None = None
+    array: DaskArray,
+    flatblocks: Sequence[int],
+    blkshape: tuple[int] | None = None,
+    reindexer=identity,
 ) -> DaskArray:
     """
     Advanced indexing of .blocks such that we always get a regular array back.
@@ -1464,20 +1473,21 @@ def subset_to_blocks(
     index = _normalize_indexes(array, flatblocks, blkshape)
 
     if all(not isinstance(i, np.ndarray) and i == slice(None) for i in index):
-        return array
+        return dask.array.map_blocks(reindexer, array, meta=array._meta)
 
     # These rest is copied from dask.array.core.py with slight modifications
     index = normalize_index(index, array.numblocks)
     index = tuple(slice(k, k + 1) if isinstance(k, Integral) else k for k in index)
 
-    name = "blocks-" + tokenize(array, index)
+    name = "groupby-cohort-" + tokenize(array, index)
     new_keys = array._key_array[index]
 
     squeezed = tuple(np.squeeze(i) if isinstance(i, np.ndarray) else i for i in index)
     chunks = tuple(tuple(np.array(c)[i].tolist()) for c, i in zip(array.chunks, squeezed))
 
     keys = itertools.product(*(range(len(c)) for c in chunks))
-    layer: Graph = {(name,) + key: tuple(new_keys[key].tolist()) for key in keys}
+    layer: Graph = {(name,) + key: (reindexer, tuple(new_keys[key].tolist())) for key in keys}
+
     graph = HighLevelGraph.from_collections(name, layer, dependencies=[array])
 
     return dask.array.Array(graph, name, chunks, meta=array)
@@ -1651,26 +1661,26 @@ def dask_groupby_agg(
 
     elif method == "cohorts":
         assert chunks_cohorts
+        block_shape = array.blocks.shape[-len(axis) :]
+
         reduced_ = []
         groups_ = []
         for blks, cohort in chunks_cohorts.items():
-            index = pd.Index(cohort)
-            subset = subset_to_blocks(intermediate, blks, array.blocks.shape[-len(axis) :])
-            reindexed = dask.array.map_blocks(
-                reindex_intermediates, subset, agg, index, meta=subset._meta
-            )
+            cohort_index = pd.Index(cohort)
+            reindexer = partial(reindex_intermediates, agg=agg, unique_groups=cohort_index)
+            reindexed = subset_to_blocks(intermediate, blks, block_shape, reindexer)
             # now that we have reindexed, we can set reindex=True explicitlly
             reduced_.append(
                 tree_reduce(
                     reindexed,
                     combine=partial(combine, agg=agg, reindex=True),
-                    aggregate=partial(aggregate, expected_groups=index, reindex=True),
+                    aggregate=partial(aggregate, expected_groups=cohort_index, reindex=True),
                 )
             )
             # This is done because pandas promotes to 64-bit types when an Index is created
            # So we use the index to generate the return value for consistency with "map-reduce"
            # This is important on windows
-            groups_.append(index.values)
+            groups_.append(cohort_index.values)
 
         reduced = dask.array.concatenate(reduced_, axis=-1)
         groups = (np.concatenate(groups_),)
```
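A rough usage sketch of the new argument (assumed from this diff, not an official flox example): the `reindexer` callable is applied to every selected block inside the new `groupby-cohort-*` layer, so per-block work is fused into the subsetting step.

```python
import dask.array as da
import numpy as np
from flox.core import subset_to_blocks

array = da.arange(10, chunks=2)  # five blocks of two elements each

# Select blocks 0, 2 and 4; the lambda stands in for reindex_intermediates
# and runs on each selected block inside the subsetting layer.
subset = subset_to_blocks(array, np.array([0, 2, 4]), reindexer=lambda blk: blk + 1)
print(subset.compute())  # expected roughly: [ 1  2  5  6  9 10]
```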

tests/test_core.py

Lines changed: 6 additions & 2 deletions
```diff
@@ -1465,14 +1465,18 @@ def test_normalize_block_indexing_2d(flatblocks, expected):
 
 @requires_dask
 def test_subset_block_passthrough():
+    from flox.core import identity
+
     # full slice pass through
     array = dask.array.ones((5,), chunks=(1,))
+    expected = dask.array.map_blocks(identity, array)
     subset = subset_to_blocks(array, np.arange(5))
-    assert subset.name == array.name
+    assert subset.name == expected.name
 
     array = dask.array.ones((5, 5), chunks=1)
+    expected = dask.array.map_blocks(identity, array)
     subset = subset_to_blocks(array, np.arange(25))
-    assert subset.name == array.name
+    assert subset.name == expected.name
 
 
 @requires_dask
```
