Commit c398f4e
Use threadpool for finding labels in chunk (#327)
* Use threadpool for finding labels in chunk

  Great when we have lots of decent-size chunks, particularly the NWM county groupby: 600ms -> 400ms.

  ```
  | Before [0cccb90] <optimize-again> | After [38fe8a6c] <threadpool> | Ratio | Benchmark (Parameter)                        |
  |-----------------------------------|-------------------------------|-------|----------------------------------------------|
  | 3.50±0.2ms                        | 2.93±0.07ms                   | 0.84  | cohorts.PerfectMonthly.time_graph_construct  |
  | 20.0±1ms                          | 9.66±1ms                      | 0.48  | cohorts.NWMMidwest.time_find_group_cohorts   |
  ```

* Add threshold
* Fix + comment
* Fix benchmark.
* Tweak threshold
* Small cleanup
* Comment
* Try single allocation
* Revert "Try single allocation"

  This reverts commit c6b93367e2024e60d77af24a69d177670a040dfc.

* cleanup
1 parent eb3c0ef commit c398f4e
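
Before diving into the diff, a minimal sketch of the pattern this commit introduces may help: find the unique labels in each chunk by writing `True` into a boolean scratch array, and fan the per-chunk work out to a `ThreadPoolExecutor`. Only `chunk_unique` and the executor pattern mirror the actual diff below; the toy `labels`/`slicers` setup is invented for illustration.

```python
# Sketch only (not flox's exact code): per-chunk unique-label finding,
# fanned out across chunks with a thread pool.
from concurrent.futures import ThreadPoolExecutor

import numpy as np


def chunk_unique(labels: np.ndarray, slicer: slice, nlabels: int) -> np.ndarray:
    # One extra slot at the end absorbs the -1 "missing" sentinel.
    label_is_present = np.zeros((nlabels + 1,), dtype=bool)
    # Repeatedly writing True to the same location is a fast way to find
    # unique integers when the number of possible labels is known up front.
    label_is_present[labels[slicer].reshape(-1)] = True
    # Drop the sentinel slot, then recover the unique integer labels.
    return np.arange(nlabels)[label_is_present[:-1]]


# Toy problem: a 1D labels array split into 10 equal chunks, -1 marking missing.
nlabels = 1_000
labels = np.random.default_rng(0).integers(-1, nlabels, size=10_000)
slicers = [slice(i, i + 1_000) for i in range(0, 10_000, 1_000)]

with ThreadPoolExecutor() as executor:
    futures = [executor.submit(chunk_unique, labels, s, nlabels) for s in slicers]
    cols = tuple(f.result() for f in futures)
```

In the actual change, a `THRESHOLD` heuristic chooses between this threaded path and a serial loop that reuses a single scratch buffer.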

File tree

1 file changed: +43 -7 lines changed

flox/core.py

Lines changed: 43 additions & 7 deletions
```diff
@@ -8,6 +8,7 @@
 import warnings
 from collections import namedtuple
 from collections.abc import Sequence
+from concurrent.futures import ThreadPoolExecutor
 from functools import partial, reduce
 from itertools import product
 from numbers import Integral
@@ -261,6 +262,7 @@ def make_bitmask(rows, cols):
     assert isinstance(labels, np.ndarray)
     shape = tuple(sum(c) for c in chunks)
     nchunks = math.prod(len(c) for c in chunks)
+    approx_chunk_size = math.prod(c[0] for c in chunks)

     # Shortcut for 1D with size-1 chunks
     if shape == (nchunks,):
@@ -271,21 +273,55 @@ def make_bitmask(rows, cols):

     labels = np.broadcast_to(labels, shape[-labels.ndim :])
     cols = []
-    # Add one to handle the -1 sentinel value
-    label_is_present = np.zeros((nlabels + 1,), dtype=bool)
     ilabels = np.arange(nlabels)
-    for region in slices_from_chunks(chunks):
+
+    def chunk_unique(labels, slicer, nlabels, label_is_present=None):
+        if label_is_present is None:
+            label_is_present = np.empty((nlabels + 1,), dtype=bool)
+        label_is_present[:] = False
+        subset = labels[slicer]
         # This is a quite fast way to find unique integers, when we know how many there are
         # inspired by a similar idea in numpy_groupies for first, last
         # instead of explicitly finding uniques, repeatedly write True to the same location
-        subset = labels[region]
-        # The reshape is not strictly necessary but is about 100ms faster on a test problem.
         label_is_present[subset.reshape(-1)] = True
         # skip the -1 sentinel by slicing
         # Faster than np.argwhere by a lot
         uniques = ilabels[label_is_present[:-1]]
-        cols.append(uniques)
-        label_is_present[:] = False
+        return uniques
+
+    # TODO: refine this heuristic.
+    # The general idea is that with the threadpool, we repeatedly allocate memory
+    # for `label_is_present`. We trade that off against the parallelism across number of chunks.
+    # For large enough number of chunks (relative to number of labels), it makes sense to
+    # suffer the extra allocation in exchange for parallelism.
+    THRESHOLD = 2
+    if nlabels < THRESHOLD * approx_chunk_size:
+        logger.debug(
+            "Using threadpool since num_labels %s < %d * chunksize %s",
+            nlabels,
+            THRESHOLD,
+            approx_chunk_size,
+        )
+        with ThreadPoolExecutor() as executor:
+            futures = [
+                executor.submit(chunk_unique, labels, slicer, nlabels)
+                for slicer in slices_from_chunks(chunks)
+            ]
+            cols = tuple(f.result() for f in futures)
+
+    else:
+        logger.debug(
+            "Using serial loop since num_labels %s > %d * chunksize %s",
+            nlabels,
+            THRESHOLD,
+            approx_chunk_size,
+        )
+        cols = []
+        # Add one to handle the -1 sentinel value
+        label_is_present = np.empty((nlabels + 1,), dtype=bool)
+        for region in slices_from_chunks(chunks):
+            uniques = chunk_unique(labels, region, nlabels, label_is_present)
+            cols.append(uniques)
     rows_array = np.repeat(np.arange(nchunks), tuple(len(col) for col in cols))
     cols_array = np.concatenate(cols)
```
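As a follow-up, a quick sanity check of the `chunk_unique` trick: on integer labels where `-1` is a skip sentinel, the boolean-mask approach should agree with `np.unique`. The sizes and seed below are arbitrary assumptions, not from the commit.

```python
# Check the boolean-mask unique trick against np.unique; -1 is a sentinel
# that must not appear in the result.
import numpy as np

nlabels = 50
labels = np.random.default_rng(1).integers(-1, nlabels, size=1_000)

label_is_present = np.zeros((nlabels + 1,), dtype=bool)
label_is_present[labels] = True  # -1 writes into the extra last slot
uniques = np.arange(nlabels)[label_is_present[:-1]]

assert np.array_equal(uniques, np.unique(labels[labels >= 0]))
```

This also shows why the diff allocates `nlabels + 1` slots: without the extra slot, fancy indexing with `-1` would wrap around and incorrectly mark the largest label as present.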