Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

ENH: Parallelize serialized 3D+t transforms #220

Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
74 changes: 54 additions & 20 deletions nitransforms/resampling.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@
### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ##
"""Resampling utilities."""

from os import cpu_count
from concurrent.futures import ProcessPoolExecutor, as_completed
from pathlib import Path
import numpy as np
from nibabel.loadsave import load as _nbload
Expand All @@ -25,6 +27,25 @@
"""Minimum number of volumes to automatically serialize 4D transforms."""


def _apply_volume(
index,
data,
targets,
order=3,
mode="constant",
cval=0.0,
prefilter=True,
):
return index, ndi.map_coordinates(
data,
targets,
order=order,
mode=mode,
cval=cval,
prefilter=prefilter,
)


def apply(
transform,
spatialimage,
Expand Down Expand Up @@ -135,34 +156,47 @@ def apply(
else None
)

# Order F ensures individual volumes are contiguous in memory
# Also matches NIfTI, making final save more efficient
resampled = np.zeros(
(len(ref_ndcoords), len(transform)), dtype=input_dtype, order="F"
)
if njobs is None:
njobs = cpu_count()

for t in range(n_resamplings):
xfm_t = transform if n_resamplings == 1 else transform[t]
with ProcessPoolExecutor(max_workers=min(njobs, n_resamplings)) as executor:
results = []
for t in range(n_resamplings):
xfm_t = transform if n_resamplings == 1 else transform[t]

if targets is None:
targets = ImageGrid(spatialimage).index( # data should be an image
_as_homogeneous(xfm_t.map(ref_ndcoords), dim=_ref.ndim)
)
if targets is None:
targets = ImageGrid(spatialimage).index( # data should be an image
_as_homogeneous(xfm_t.map(ref_ndcoords), dim=_ref.ndim)
)

# Interpolate
resampled[..., t] = ndi.map_coordinates(
(
data_t = (
data
if data is not None
else spatialimage.dataobj[..., t].astype(input_dtype, copy=False)
),
targets,
order=order,
mode=mode,
cval=cval,
prefilter=prefilter,
)

results.append(
executor.submit(
_apply_volume,
t,
data_t,
targets,
order=order,
mode=mode,
cval=cval,
prefilter=prefilter,
)
)

# Order F ensures individual volumes are contiguous in memory
# Also matches NIfTI, making final save more efficient
resampled = np.zeros(
(len(ref_ndcoords), len(transform)), dtype=input_dtype, order="F"
)

for future in as_completed(results):
t, resampled_t = future.result()
resampled[..., t] = resampled_t
else:
data = np.asanyarray(spatialimage.dataobj, dtype=input_dtype)

Expand Down
Loading