8
8
### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ##
9
9
"""Resampling utilities."""
10
10
11
+ import asyncio
11
12
from os import cpu_count
12
- from concurrent . futures import ProcessPoolExecutor , as_completed
13
+ from functools import partial
13
14
from pathlib import Path
14
- from typing import Tuple
15
+ from typing import Callable , TypeVar
15
16
16
17
import numpy as np
17
18
from nibabel .loadsave import load as _nbload
27
28
_as_homogeneous ,
28
29
)
29
30
31
+ R = TypeVar ("R" )
32
+
30
33
SERIALIZE_VOLUME_WINDOW_WIDTH : int = 8
31
34
"""Minimum number of volumes to automatically serialize 4D transforms."""
32
35
33
36
34
- def _apply_volume (
35
- index : int ,
37
+ async def worker (job : Callable [[], R ], semaphore ) -> R :
38
+ async with semaphore :
39
+ loop = asyncio .get_running_loop ()
40
+ return await loop .run_in_executor (None , job )
41
+
42
+
43
+ async def _apply_serial (
36
44
data : np .ndarray ,
45
+ spatialimage : SpatialImage ,
37
46
targets : np .ndarray ,
47
+ transform : TransformBase ,
48
+ ref_ndim : int ,
49
+ ref_ndcoords : np .ndarray ,
50
+ n_resamplings : int ,
51
+ output : np .ndarray ,
52
+ input_dtype : np .dtype ,
38
53
order : int = 3 ,
39
54
mode : str = "constant" ,
40
55
cval : float = 0.0 ,
41
56
prefilter : bool = True ,
42
- ) -> Tuple [int , np .ndarray ]:
57
+ max_concurrent : int = min (cpu_count (), 12 ),
58
+ ):
43
59
"""
44
- Decorate :obj:`~scipy.ndimage.map_coordinates` to return an order index for parallelization .
60
+ Resample through a given transform serially, in a 3D+t setting .
45
61
46
62
Parameters
47
63
----------
48
- index : :obj:`int`
49
- The index of the volume to apply the interpolation to.
50
64
data : :obj:`~numpy.ndarray`
51
65
The input data array.
66
+ spatialimage : :obj:`~nibabel.spatialimages.SpatialImage` or `os.pathlike`
67
+ The image object containing the data to be resampled in reference
68
+ space
52
69
targets : :obj:`~numpy.ndarray`
53
70
The target coordinates for mapping.
71
+ transform : :obj:`~nitransforms.base.TransformBase`
72
+ The 3D, 3D+t, or 4D transform through which data will be resampled.
73
+ ref_ndim : :obj:`int`
74
+ Dimensionality of the resampling target (reference image).
75
+ ref_ndcoords : :obj:`~numpy.ndarray`
76
+ Physical coordinates (RAS+) where data will be interpolated, if the resampling
77
+ target is a grid, the scanner coordinates of all voxels.
78
+ n_resamplings : :obj:`int`
79
+ Total number of 3D resamplings (can be defined by the input image, the transform,
80
+ or be matched, that is, same number of volumes in the input and number of transforms).
81
+ output : :obj:`~numpy.ndarray`
82
+ The output data array where resampled values will be stored volume-by-volume.
54
83
order : :obj:`int`, optional
55
84
The order of the spline interpolation, default is 3.
56
85
The order has to be in the range 0-5.
@@ -71,18 +100,46 @@ def _apply_volume(
71
100
72
101
Returns
73
102
-------
74
- (:obj:`int`, :obj:`~numpy .ndarray`)
75
- The index and the array resulting from the interpolation .
103
+ np .ndarray
104
+ Data resampled on the 3D+t array of input coordinates .
76
105
77
106
"""
78
- return index , ndi .map_coordinates (
79
- data ,
80
- targets ,
81
- order = order ,
82
- mode = mode ,
83
- cval = cval ,
84
- prefilter = prefilter ,
85
- )
107
+ tasks = []
108
+ semaphore = asyncio .Semaphore (max_concurrent )
109
+
110
+ for t in range (n_resamplings ):
111
+ xfm_t = transform if n_resamplings == 1 else transform [t ]
112
+
113
+ if targets is None :
114
+ targets = ImageGrid (spatialimage ).index ( # data should be an image
115
+ _as_homogeneous (xfm_t .map (ref_ndcoords ), dim = ref_ndim )
116
+ )
117
+
118
+ data_t = (
119
+ data
120
+ if data is not None
121
+ else spatialimage .dataobj [..., t ].astype (input_dtype , copy = False )
122
+ )
123
+
124
+ tasks .append (
125
+ asyncio .create_task (
126
+ worker (
127
+ partial (
128
+ ndi .map_coordinates ,
129
+ data_t ,
130
+ targets ,
131
+ output = output [..., t ],
132
+ order = order ,
133
+ mode = mode ,
134
+ cval = cval ,
135
+ prefilter = prefilter ,
136
+ ),
137
+ semaphore ,
138
+ )
139
+ )
140
+ )
141
+ await asyncio .gather (* tasks )
142
+ return output
86
143
87
144
88
145
def apply (
@@ -94,15 +151,17 @@ def apply(
94
151
cval : float = 0.0 ,
95
152
prefilter : bool = True ,
96
153
output_dtype : np .dtype = None ,
97
- serialize_nvols : int = SERIALIZE_VOLUME_WINDOW_WIDTH ,
98
- njobs : int = None ,
99
154
dtype_width : int = 8 ,
155
+ serialize_nvols : int = SERIALIZE_VOLUME_WINDOW_WIDTH ,
156
+ max_concurrent : int = min (cpu_count (), 12 ),
100
157
) -> SpatialImage | np .ndarray :
101
158
"""
102
159
Apply a transformation to an image, resampling on the reference spatial object.
103
160
104
161
Parameters
105
162
----------
163
+ transform: :obj:`~nitransforms.base.TransformBase`
164
+ The 3D, 3D+t, or 4D transform through which data will be resampled.
106
165
spatialimage : :obj:`~nibabel.spatialimages.SpatialImage` or `os.pathlike`
107
166
The image object containing the data to be resampled in reference
108
167
space
@@ -118,15 +177,15 @@ def apply(
118
177
or ``'wrap'``. Default is ``'constant'``.
119
178
cval : :obj:`float`, optional
120
179
Constant value for ``mode='constant'``. Default is 0.0.
121
- prefilter: :obj:`bool`, optional
180
+ prefilter : :obj:`bool`, optional
122
181
Determines if the image's data array is prefiltered with
123
182
a spline filter before interpolation. The default is ``True``,
124
183
which will create a temporary *float64* array of filtered values
125
184
if *order > 1*. If setting this to ``False``, the output will be
126
185
slightly blurred if *order > 1*, unless the input is prefiltered,
127
186
i.e. it is the result of calling the spline filter on the original
128
187
input.
129
- output_dtype: :obj:`~numpy.dtype`, optional
188
+ output_dtype : :obj:`~numpy.dtype`, optional
130
189
The dtype of the returned array or image, if specified.
131
190
If ``None``, the default behavior is to use the effective dtype of
132
191
the input image. If slope and/or intercept are defined, the effective
@@ -135,10 +194,17 @@ def apply(
135
194
If ``reference`` is defined, then the return value is an image, with
136
195
a data array of the effective dtype but with the on-disk dtype set to
137
196
the input image's on-disk dtype.
138
- dtype_width: :obj:`int`
197
+ dtype_width : :obj:`int`
139
198
Cap the width of the input data type to the given number of bytes.
140
199
This argument is intended to work as a way to implement lower memory
141
200
requirements in resampling.
201
+ serialize_nvols : :obj:`int`
202
+ Minimum number of volumes in a 3D+t (that is, a series of 3D transformations
203
+ independent in time) to resample on a one-by-one basis.
204
+ Serialized resampling can be executed concurrently (parallelized) with
205
+ the argument ``max_concurrent``.
206
+ max_concurrent : :obj:`int`
207
+ Maximum number of 3D resamplings to be executed concurrently.
142
208
143
209
Returns
144
210
-------
@@ -201,46 +267,30 @@ def apply(
201
267
else None
202
268
)
203
269
204
- njobs = cpu_count () if njobs is None or njobs < 1 else njobs
205
-
206
- with ProcessPoolExecutor (max_workers = min (njobs , n_resamplings )) as executor :
207
- results = []
208
- for t in range (n_resamplings ):
209
- xfm_t = transform if n_resamplings == 1 else transform [t ]
210
-
211
- if targets is None :
212
- targets = ImageGrid (spatialimage ).index ( # data should be an image
213
- _as_homogeneous (xfm_t .map (ref_ndcoords ), dim = _ref .ndim )
214
- )
215
-
216
- data_t = (
217
- data
218
- if data is not None
219
- else spatialimage .dataobj [..., t ].astype (input_dtype , copy = False )
220
- )
221
-
222
- results .append (
223
- executor .submit (
224
- _apply_volume ,
225
- t ,
226
- data_t ,
227
- targets ,
228
- order = order ,
229
- mode = mode ,
230
- cval = cval ,
231
- prefilter = prefilter ,
232
- )
233
- )
270
+ # Order F ensures individual volumes are contiguous in memory
271
+ # Also matches NIfTI, making final save more efficient
272
+ resampled = np .zeros (
273
+ (len (ref_ndcoords ), len (transform )), dtype = input_dtype , order = "F"
274
+ )
234
275
235
- # Order F ensures individual volumes are contiguous in memory
236
- # Also matches NIfTI, making final save more efficient
237
- resampled = np .zeros (
238
- (len (ref_ndcoords ), len (transform )), dtype = input_dtype , order = "F"
276
+ resampled = asyncio .run (
277
+ _apply_serial (
278
+ data ,
279
+ spatialimage ,
280
+ targets ,
281
+ transform ,
282
+ _ref .ndim ,
283
+ ref_ndcoords ,
284
+ n_resamplings ,
285
+ resampled ,
286
+ input_dtype ,
287
+ order = order ,
288
+ mode = mode ,
289
+ cval = cval ,
290
+ prefilter = prefilter ,
291
+ max_concurrent = max_concurrent ,
239
292
)
240
-
241
- for future in as_completed (results ):
242
- t , resampled_t = future .result ()
243
- resampled [..., t ] = resampled_t
293
+ )
244
294
else :
245
295
data = np .asanyarray (spatialimage .dataobj , dtype = input_dtype )
246
296
0 commit comments