8
8
import warnings
9
9
from collections import namedtuple
10
10
from collections .abc import Sequence
11
+ from concurrent .futures import ThreadPoolExecutor
11
12
from functools import partial , reduce
12
13
from itertools import product
13
14
from numbers import Integral
@@ -261,6 +262,7 @@ def make_bitmask(rows, cols):
261
262
assert isinstance (labels , np .ndarray )
262
263
shape = tuple (sum (c ) for c in chunks )
263
264
nchunks = math .prod (len (c ) for c in chunks )
265
+ approx_chunk_size = math .prod (c [0 ] for c in chunks )
264
266
265
267
# Shortcut for 1D with size-1 chunks
266
268
if shape == (nchunks ,):
@@ -271,21 +273,55 @@ def make_bitmask(rows, cols):
271
273
272
274
labels = np .broadcast_to (labels , shape [- labels .ndim :])
273
275
cols = []
274
- # Add one to handle the -1 sentinel value
275
- label_is_present = np .zeros ((nlabels + 1 ,), dtype = bool )
276
276
ilabels = np .arange (nlabels )
277
- for region in slices_from_chunks (chunks ):
277
+
278
def chunk_unique(labels, slicer, nlabels, label_is_present=None):
    """Return the sorted unique labels present in ``labels[slicer]``, skipping -1.

    Parameters
    ----------
    labels : np.ndarray
        Integer label array; valid labels are assumed to lie in
        ``[0, nlabels)`` with ``-1`` as a sentinel/fill value
        (the extra buffer slot below absorbs it — TODO confirm against caller).
    slicer : tuple of slice
        Region of ``labels`` to inspect (one chunk).
    nlabels : int
        Total number of valid labels.
    label_is_present : np.ndarray of bool, optional
        Scratch buffer of shape ``(nlabels + 1,)``, reused across calls in the
        serial path to avoid repeated allocation. When ``None`` (the threadpool
        path), a fresh buffer is allocated here.

    Returns
    -------
    np.ndarray
        Sorted 1-D integer array of the unique labels in the region.
    """
    if label_is_present is None:
        # Fresh, already-False buffer. The +1 slot at the end catches the -1
        # sentinel via negative indexing so it never pollutes real labels.
        label_is_present = np.zeros((nlabels + 1,), dtype=bool)
    else:
        # Reset the caller-provided scratch buffer before reuse.
        label_is_present[:] = False
    subset = labels[slicer]
    # This is a quite fast way to find unique integers, when we know how many
    # there are (inspired by a similar idea in numpy_groupies for first/last):
    # instead of explicitly finding uniques, repeatedly write True to the same
    # location.
    label_is_present[subset.reshape(-1)] = True
    # Slice off the last slot to skip the -1 sentinel; flatnonzero of the
    # boolean mask yields the sorted unique labels directly and is much faster
    # than np.argwhere. (Replaces the previous `ilabels[mask]` fancy-indexing,
    # which captured an enclosing-scope arange array.)
    uniques = np.flatnonzero(label_is_present[:-1])
    return uniques
291
+
292
+ # TODO: refine this heuristic.
293
+ # The general idea is that with the threadpool, we repeatedly allocate memory
294
+ # for `label_is_present`. We trade that off against the parallelism across number of chunks.
295
+ # For large enough number of chunks (relative to number of labels), it makes sense to
296
+ # suffer the extra allocation in exchange for parallelism.
297
+ THRESHOLD = 2
298
+ if nlabels < THRESHOLD * approx_chunk_size :
299
+ logger .debug (
300
+ "Using threadpool since num_labels %s < %d * chunksize %s" ,
301
+ nlabels ,
302
+ THRESHOLD ,
303
+ approx_chunk_size ,
304
+ )
305
+ with ThreadPoolExecutor () as executor :
306
+ futures = [
307
+ executor .submit (chunk_unique , labels , slicer , nlabels )
308
+ for slicer in slices_from_chunks (chunks )
309
+ ]
310
+ cols = tuple (f .result () for f in futures )
311
+
312
+ else :
313
+ logger .debug (
314
+ "Using serial loop since num_labels %s > %d * chunksize %s" ,
315
+ nlabels ,
316
+ THRESHOLD ,
317
+ approx_chunk_size ,
318
+ )
319
+ cols = []
320
+ # Add one to handle the -1 sentinel value
321
+ label_is_present = np .empty ((nlabels + 1 ,), dtype = bool )
322
+ for region in slices_from_chunks (chunks ):
323
+ uniques = chunk_unique (labels , region , nlabels , label_is_present )
324
+ cols .append (uniques )
289
325
rows_array = np .repeat (np .arange (nchunks ), tuple (len (col ) for col in cols ))
290
326
cols_array = np .concatenate (cols )
291
327
0 commit comments