Skip to content

Commit

Permalink
Merge pull request #1277 from nipreps/enh/parallelization-settings
Browse files Browse the repository at this point in the history
ENH: Annotate nodes with ``n_procs`` to allow safe parallelization
  • Loading branch information
oesteban authored Apr 16, 2024
2 parents 630ced8 + 8c205f2 commit f33031f
Showing 1 changed file with 73 additions and 29 deletions.
102 changes: 73 additions & 29 deletions mriqc/workflows/diffusion/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,8 +90,6 @@ def dmri_qc_workflow(name='dwiMRIQC'):

workflow = pe.Workflow(name=name)

mem_gb = config.workflow.biggest_file_gb

dataset = config.workflow.inputs.get('dwi', [])
full_data = []

Expand Down Expand Up @@ -135,24 +133,35 @@ def dmri_qc_workflow(name='dwiMRIQC'):
max_32bit=config.execution.float32,
),
name='sanitize',
mem_gb=mem_gb * 4.0,
mem_gb=4.0,
)

# Workflow --------------------------------------------------------

# Read metadata & bvec/bval, estimate number of shells, extract and split B0s
meta = pe.Node(ReadDWIMetadata(index_db=config.execution.bids_database_dir), name='metadata')
load_bmat = pe.Node(
ReadDWIMetadata(index_db=config.execution.bids_database_dir),
name='load_bmat',
)
shells = pe.Node(NumberOfShells(), name='shells')
drift = pe.Node(CorrectSignalDrift(), name='drift')
get_lowb = pe.Node(ExtractOrientations(), name='get_lowb')
get_lowb = pe.Node(
ExtractOrientations(),
name='get_lowb',
n_procs=max(1, config.nipype.omp_nthreads // 2),
)

# Generate B0 reference
dwi_ref = pe.Node(RobustAverage(mc_method=None), name='dwi_ref')
dwi_ref = pe.Node(
RobustAverage(mc_method=None),
name='dwi_ref',
n_procs=max(1, config.nipype.omp_nthreads // 2),
)

hmc_b0 = pe.Node(
Volreg(args='-Fourier -twopass', zpad=4, outputtype='NIFTI_GZ'),
name='hmc_b0',
mem_gb=mem_gb * 2.5,
mem_gb=3.0,
n_procs=config.nipype.omp_nthreads,
)

# Calculate brainmask
Expand All @@ -171,13 +180,13 @@ def dmri_qc_workflow(name='dwiMRIQC'):
averages = pe.MapNode(
WeightedStat(),
name='averages',
mem_gb=mem_gb * 1.5,
n_procs=max(1, config.nipype.omp_nthreads // 2),
iterfield=['in_weights'],
)
stddev = pe.MapNode(
WeightedStat(stat='std'),
name='stddev',
mem_gb=mem_gb * 1.5,
n_procs=max(1, config.nipype.omp_nthreads // 2),
iterfield=['in_weights'],
)

Expand All @@ -187,19 +196,39 @@ def dmri_qc_workflow(name='dwiMRIQC'):
nthreads=config.nipype.omp_nthreads,
),
name='dwidenoise',
nprocs=config.nipype.omp_nthreads,
mem_gb=mem_gb * 4,
n_procs=config.nipype.omp_nthreads,
)
drift = pe.Node(
CorrectSignalDrift(),
name='drift',
n_procs=max(1, config.nipype.omp_nthreads // 2),
)
sp_mask = pe.Node(
SpikingVoxelsMask(),
name='sp_mask',
n_procs=max(1, config.nipype.omp_nthreads // 2),
)
# Fit DTI/DKI model
dwimodel = pe.Node(DiffusionModel(), name='dwimodel')

sp_mask = pe.Node(SpikingVoxelsMask(), name='sp_mask')
# Fit DTI/DKI model
dwimodel = pe.Node(
DiffusionModel(),
name='dwimodel',
n_procs=max(1, config.nipype.omp_nthreads // 2),
)

# Calculate CC mask
cc_mask = pe.Node(CCSegmentation(), name='cc_mask')
cc_mask = pe.Node(
CCSegmentation(),
name='cc_mask',
n_procs=max(1, config.nipype.omp_nthreads // 2),
)

# Run PIESNO noise estimation
piesno = pe.Node(PIESNO(), name='piesno')
piesno = pe.Node(
PIESNO(),
name='piesno',
n_procs=max(1, config.nipype.omp_nthreads // 2),
)

# EPI to MNI registration
spatial_norm = epi_mni_align()
Expand All @@ -213,7 +242,7 @@ def dmri_qc_workflow(name='dwiMRIQC'):
# fmt: off
workflow.connect([
(inputnode, datalad_get, [('in_file', 'in_file')]),
(inputnode, meta, [('in_file', 'in_file')]),
(inputnode, load_bmat, [('in_file', 'in_file')]),
(inputnode, dwi_report_wf, [
('in_file', 'inputnode.name_source'),
]),
Expand All @@ -224,11 +253,11 @@ def dmri_qc_workflow(name='dwiMRIQC'):
(sanitize, piesno, [('out_file', 'in_file')]),
(shells, dwi_ref, [(('b_masks', _first), 't_mask')]),
(shells, sp_mask, [('b_masks', 'b_masks')]),
(meta, shells, [('out_bval_file', 'in_bvals')]),
(load_bmat, shells, [('out_bval_file', 'in_bvals')]),
(sanitize, drift, [('out_file', 'full_epi')]),
(shells, get_lowb, [(('b_indices', _first), 'indices')]),
(sanitize, get_lowb, [('out_file', 'in_file')]),
(meta, drift, [('out_bval_file', 'bval_file')]),
(load_bmat, drift, [('out_bval_file', 'bval_file')]),
(get_lowb, hmc_b0, [('out_file', 'in_file')]),
(dwi_ref, hmc_b0, [('out_file', 'basefile')]),
(hmc_b0, drift, [('out_file', 'in_file')]),
Expand All @@ -237,25 +266,25 @@ def dmri_qc_workflow(name='dwiMRIQC'):
(dmri_bmsk, sp_mask, [('outputnode.out_mask', 'brain_mask')]),
(dmri_bmsk, drift, [('outputnode.out_mask', 'brainmask_file')]),
(drift, hmcwf, [('out_full_file', 'inputnode.in_file')]),
(meta, hmcwf, [('out_bvec_file', 'inputnode.in_bvec')]),
(load_bmat, hmcwf, [('out_bvec_file', 'inputnode.in_bvec')]),
(drift, averages, [('out_full_file', 'in_file')]),
(drift, stddev, [('out_full_file', 'in_file')]),
(shells, averages, [('b_masks', 'in_weights')]),
(averages, hmcwf, [(('out_file', _first), 'inputnode.reference')]),
(shells, stddev, [('b_masks', 'in_weights')]),
(shells, dwimodel, [('out_data', 'bvals'),
('n_shells', 'n_shells')]),
(meta, dwimodel, [('out_bvec_file', 'bvec_file')]),
(load_bmat, dwimodel, [('out_bvec_file', 'bvec_file')]),
(drift, dwidenoise, [('out_full_file', 'in_file')]),
(dmri_bmsk, dwidenoise, [('outputnode.out_mask', 'mask')]),
(dwidenoise, dwimodel, [('out_file', 'in_file')]),
(dmri_bmsk, dwimodel, [('outputnode.out_mask', 'brain_mask')]),
(meta, get_hmc_shells, [('out_bvec_file', 'in_bvec_file')]),
(load_bmat, get_hmc_shells, [('out_bvec_file', 'in_bvec_file')]),
(shells, get_hmc_shells, [('b_indices', 'indices')]),
(hmcwf, get_hmc_shells, [('outputnode.out_file', 'in_file')]),
(dwimodel, cc_mask, [('out_fa', 'in_fa'),
('out_cfa', 'in_cfa')]),
(meta, iqms_wf, [
(load_bmat, iqms_wf, [
('out_bval_file', 'inputnode.b_values_file'),
('qspace_neighbors', 'inputnode.qspace_neighbors'),
]),
Expand Down Expand Up @@ -314,7 +343,6 @@ def compute_iqms(name='ComputeIQMs'):
from mriqc.interfaces.reports import AddProvenance

# from mriqc.workflows.utils import _tofloat, get_fwhmx
# mem_gb = config.workflow.biggest_file_gb

workflow = pe.Workflow(name=name)
inputnode = pe.Node(
Expand Down Expand Up @@ -418,7 +446,7 @@ def compute_iqms(name='ComputeIQMs'):
('acquisition', 'acq_id'),
('reconstruction', 'rec_id'),
('run', 'run_id'),
('out_dict', 'metadata')]),
(('out_dict', _filter_metadata), 'metadata')]),
(datasink, outputnode, [('out_file', 'out_file')]),
(meta, outputnode, [('out_dict', 'meta_sidecar')]),
(measures, datasink, [('out_qc', 'root')]),
Expand Down Expand Up @@ -448,8 +476,6 @@ def hmc_workflow(name='dMRI_HMC'):

from mriqc.interfaces.diffusion import RotateVectors

mem_gb = config.workflow.biggest_file_gb

workflow = pe.Workflow(name=name)

inputnode = pe.Node(
Expand Down Expand Up @@ -478,7 +504,8 @@ def hmc_workflow(name='dMRI_HMC'):
hmc = pe.Node(
Volreg(args='-Fourier -twopass', zpad=4, outputtype='NIFTI_GZ'),
name='motion_correct',
mem_gb=mem_gb * 2.5,
mem_gb=3.0,
n_procs=config.nipype.omp_nthreads,
)

bvec_rot = pe.Node(RotateVectors(), name='bvec_rot')
Expand Down Expand Up @@ -716,3 +743,20 @@ def _bvals_report(in_file):
return 'Likely DSI'

return bvals


def _filter_metadata(in_dict, keys=(
'global',
'dcmmeta_affine',
'dcmmeta_reorient_transform',
'dcmmeta_shape',
'dcmmeta_slice_dim',
'dcmmeta_version',
'time',
)):
"""Drop large and partially redundant objects generated by dcm2niix."""

for key in keys:
in_dict.pop(key, None)

return in_dict

0 comments on commit f33031f

Please sign in to comment.