Skip to content

Commit 063e1f0

Browse files
oestebaneffigies
andcommitted
enh: port from process pool into asyncio concurrent
Co-authored-by: Chris Markiewicz <[email protected]>
1 parent 7c7608f commit 063e1f0

File tree

2 files changed

+112
-75
lines changed

2 files changed

+112
-75
lines changed

nitransforms/resampling.py

+111-61
Original file line numberDiff line numberDiff line change
@@ -8,10 +8,11 @@
88
### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ##
99
"""Resampling utilities."""
1010

11+
import asyncio
1112
from os import cpu_count
12-
from concurrent.futures import ProcessPoolExecutor, as_completed
13+
from functools import partial
1314
from pathlib import Path
14-
from typing import Tuple
15+
from typing import Callable, TypeVar
1516

1617
import numpy as np
1718
from nibabel.loadsave import load as _nbload
@@ -27,30 +28,58 @@
2728
_as_homogeneous,
2829
)
2930

31+
R = TypeVar("R")
32+
3033
SERIALIZE_VOLUME_WINDOW_WIDTH: int = 8
3134
"""Minimum number of volumes to automatically serialize 4D transforms."""
3235

3336

34-
def _apply_volume(
35-
index: int,
37+
async def worker(job: Callable[[], R], semaphore) -> R:
38+
async with semaphore:
39+
loop = asyncio.get_running_loop()
40+
return await loop.run_in_executor(None, job)
41+
42+
43+
async def _apply_serial(
3644
data: np.ndarray,
45+
spatialimage: SpatialImage,
3746
targets: np.ndarray,
47+
transform: TransformBase,
48+
ref_ndim: int,
49+
ref_ndcoords: np.ndarray,
50+
n_resamplings: int,
51+
output: np.ndarray,
52+
input_dtype: np.dtype,
3853
order: int = 3,
3954
mode: str = "constant",
4055
cval: float = 0.0,
4156
prefilter: bool = True,
42-
) -> Tuple[int, np.ndarray]:
57+
max_concurrent: int = min(cpu_count(), 12),
58+
):
4359
"""
44-
Decorate :obj:`~scipy.ndimage.map_coordinates` to return an order index for parallelization.
60+
Resample through a given transform serially, in a 3D+t setting.
4561
4662
Parameters
4763
----------
48-
index : :obj:`int`
49-
The index of the volume to apply the interpolation to.
5064
data : :obj:`~numpy.ndarray`
5165
The input data array.
66+
spatialimage : :obj:`~nibabel.spatialimages.SpatialImage` or `os.pathlike`
67+
The image object containing the data to be resampled in reference
68+
space
5269
targets : :obj:`~numpy.ndarray`
5370
The target coordinates for mapping.
71+
transform : :obj:`~nitransforms.base.TransformBase`
72+
The 3D, 3D+t, or 4D transform through which data will be resampled.
73+
ref_ndim : :obj:`int`
74+
Dimensionality of the resampling target (reference image).
75+
ref_ndcoords : :obj:`~numpy.ndarray`
76+
Physical coordinates (RAS+) where data will be interpolated, if the resampling
77+
target is a grid, the scanner coordinates of all voxels.
78+
n_resamplings : :obj:`int`
79+
Total number of 3D resamplings (can be defined by the input image, the transform,
80+
or be matched, that is, same number of volumes in the input and number of transforms).
81+
output : :obj:`~numpy.ndarray`
82+
The output data array where resampled values will be stored volume-by-volume.
5483
order : :obj:`int`, optional
5584
The order of the spline interpolation, default is 3.
5685
The order has to be in the range 0-5.
@@ -71,18 +100,46 @@ def _apply_volume(
71100
72101
Returns
73102
-------
74-
(:obj:`int`, :obj:`~numpy.ndarray`)
75-
The index and the array resulting from the interpolation.
103+
np.ndarray
104+
Data resampled on the 3D+t array of input coordinates.
76105
77106
"""
78-
return index, ndi.map_coordinates(
79-
data,
80-
targets,
81-
order=order,
82-
mode=mode,
83-
cval=cval,
84-
prefilter=prefilter,
85-
)
107+
tasks = []
108+
semaphore = asyncio.Semaphore(max_concurrent)
109+
110+
for t in range(n_resamplings):
111+
xfm_t = transform if n_resamplings == 1 else transform[t]
112+
113+
if targets is None:
114+
targets = ImageGrid(spatialimage).index( # data should be an image
115+
_as_homogeneous(xfm_t.map(ref_ndcoords), dim=ref_ndim)
116+
)
117+
118+
data_t = (
119+
data
120+
if data is not None
121+
else spatialimage.dataobj[..., t].astype(input_dtype, copy=False)
122+
)
123+
124+
tasks.append(
125+
asyncio.create_task(
126+
worker(
127+
partial(
128+
ndi.map_coordinates,
129+
data_t,
130+
targets,
131+
output=output[..., t],
132+
order=order,
133+
mode=mode,
134+
cval=cval,
135+
prefilter=prefilter,
136+
),
137+
semaphore,
138+
)
139+
)
140+
)
141+
await asyncio.gather(*tasks)
142+
return output
86143

87144

88145
def apply(
@@ -94,15 +151,17 @@ def apply(
94151
cval: float = 0.0,
95152
prefilter: bool = True,
96153
output_dtype: np.dtype = None,
97-
serialize_nvols: int = SERIALIZE_VOLUME_WINDOW_WIDTH,
98-
njobs: int = None,
99154
dtype_width: int = 8,
155+
serialize_nvols: int = SERIALIZE_VOLUME_WINDOW_WIDTH,
156+
max_concurrent: int = min(cpu_count(), 12),
100157
) -> SpatialImage | np.ndarray:
101158
"""
102159
Apply a transformation to an image, resampling on the reference spatial object.
103160
104161
Parameters
105162
----------
163+
transform: :obj:`~nitransforms.base.TransformBase`
164+
The 3D, 3D+t, or 4D transform through which data will be resampled.
106165
spatialimage : :obj:`~nibabel.spatialimages.SpatialImage` or `os.pathlike`
107166
The image object containing the data to be resampled in reference
108167
space
@@ -118,15 +177,15 @@ def apply(
118177
or ``'wrap'``. Default is ``'constant'``.
119178
cval : :obj:`float`, optional
120179
Constant value for ``mode='constant'``. Default is 0.0.
121-
prefilter: :obj:`bool`, optional
180+
prefilter : :obj:`bool`, optional
122181
Determines if the image's data array is prefiltered with
123182
a spline filter before interpolation. The default is ``True``,
124183
which will create a temporary *float64* array of filtered values
125184
if *order > 1*. If setting this to ``False``, the output will be
126185
slightly blurred if *order > 1*, unless the input is prefiltered,
127186
i.e. it is the result of calling the spline filter on the original
128187
input.
129-
output_dtype: :obj:`~numpy.dtype`, optional
188+
output_dtype : :obj:`~numpy.dtype`, optional
130189
The dtype of the returned array or image, if specified.
131190
If ``None``, the default behavior is to use the effective dtype of
132191
the input image. If slope and/or intercept are defined, the effective
@@ -135,10 +194,17 @@ def apply(
135194
If ``reference`` is defined, then the return value is an image, with
136195
a data array of the effective dtype but with the on-disk dtype set to
137196
the input image's on-disk dtype.
138-
dtype_width: :obj:`int`
197+
dtype_width : :obj:`int`
139198
Cap the width of the input data type to the given number of bytes.
140199
This argument is intended to work as a way to implement lower memory
141200
requirements in resampling.
201+
serialize_nvols : :obj:`int`
202+
Minimum number of volumes in a 3D+t (that is, a series of 3D transformations
203+
independent in time) to resample on a one-by-one basis.
204+
Serialized resampling can be executed concurrently (parallelized) with
205+
the argument ``max_concurrent``.
206+
max_concurrent : :obj:`int`
207+
Maximum number of 3D resamplings to be executed concurrently.
142208
143209
Returns
144210
-------
@@ -201,46 +267,30 @@ def apply(
201267
else None
202268
)
203269

204-
njobs = cpu_count() if njobs is None or njobs < 1 else njobs
205-
206-
with ProcessPoolExecutor(max_workers=min(njobs, n_resamplings)) as executor:
207-
results = []
208-
for t in range(n_resamplings):
209-
xfm_t = transform if n_resamplings == 1 else transform[t]
210-
211-
if targets is None:
212-
targets = ImageGrid(spatialimage).index( # data should be an image
213-
_as_homogeneous(xfm_t.map(ref_ndcoords), dim=_ref.ndim)
214-
)
215-
216-
data_t = (
217-
data
218-
if data is not None
219-
else spatialimage.dataobj[..., t].astype(input_dtype, copy=False)
220-
)
221-
222-
results.append(
223-
executor.submit(
224-
_apply_volume,
225-
t,
226-
data_t,
227-
targets,
228-
order=order,
229-
mode=mode,
230-
cval=cval,
231-
prefilter=prefilter,
232-
)
233-
)
270+
# Order F ensures individual volumes are contiguous in memory
271+
# Also matches NIfTI, making final save more efficient
272+
resampled = np.zeros(
273+
(len(ref_ndcoords), len(transform)), dtype=input_dtype, order="F"
274+
)
234275

235-
# Order F ensures individual volumes are contiguous in memory
236-
# Also matches NIfTI, making final save more efficient
237-
resampled = np.zeros(
238-
(len(ref_ndcoords), len(transform)), dtype=input_dtype, order="F"
276+
resampled = asyncio.run(
277+
_apply_serial(
278+
data,
279+
spatialimage,
280+
targets,
281+
transform,
282+
_ref.ndim,
283+
ref_ndcoords,
284+
n_resamplings,
285+
resampled,
286+
input_dtype,
287+
order=order,
288+
mode=mode,
289+
cval=cval,
290+
prefilter=prefilter,
291+
max_concurrent=max_concurrent,
239292
)
240-
241-
for future in as_completed(results):
242-
t, resampled_t = future.result()
243-
resampled[..., t] = resampled_t
293+
)
244294
else:
245295
data = np.asanyarray(spatialimage.dataobj, dtype=input_dtype)
246296

nitransforms/tests/test_resampling.py

+1-14
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
from nitransforms import nonlinear as nitnl
1616
from nitransforms import manip as nitm
1717
from nitransforms import io
18-
from nitransforms.resampling import apply, _apply_volume
18+
from nitransforms.resampling import apply
1919

2020
RMSE_TOL_LINEAR = 0.09
2121
RMSE_TOL_NONLINEAR = 0.05
@@ -363,16 +363,3 @@ def test_LinearTransformsMapping_apply(
363363
reference=testdata_path / "sbref.nii.gz",
364364
serialize_nvols=2 if serialize_4d else np.inf,
365365
)
366-
367-
368-
@pytest.mark.parametrize("t", list(range(4)))
369-
def test_apply_helper(monkeypatch, t):
370-
"""Ensure the apply helper function correctly just decorates with index."""
371-
from nitransforms.resampling import ndi
372-
373-
def _retval(*args, **kwargs):
374-
return 1
375-
376-
monkeypatch.setattr(ndi, "map_coordinates", _retval)
377-
378-
assert _apply_volume(t, None, None) == (t, 1)

0 commit comments

Comments
 (0)