From 8c205f260a347f3b14cd9cecd182e98b7dc3e0bf Mon Sep 17 00:00:00 2001
From: Oscar Esteban
Date: Tue, 16 Apr 2024 11:51:28 +0200
Subject: [PATCH] enh: annotate nodes with ``n_procs`` to allow safe
 parallelization

---
 mriqc/workflows/diffusion/base.py | 102 +++++++++++++++++++++---------
 1 file changed, 73 insertions(+), 29 deletions(-)

diff --git a/mriqc/workflows/diffusion/base.py b/mriqc/workflows/diffusion/base.py
index a1317145..2244558b 100644
--- a/mriqc/workflows/diffusion/base.py
+++ b/mriqc/workflows/diffusion/base.py
@@ -90,8 +90,6 @@ def dmri_qc_workflow(name='dwiMRIQC'):
 
     workflow = pe.Workflow(name=name)
 
-    mem_gb = config.workflow.biggest_file_gb
-
     dataset = config.workflow.inputs.get('dwi', [])
     full_data = []
 
@@ -135,24 +133,35 @@ def dmri_qc_workflow(name='dwiMRIQC'):
             max_32bit=config.execution.float32,
         ),
         name='sanitize',
-        mem_gb=mem_gb * 4.0,
+        mem_gb=4.0,
     )
 
     # Workflow --------------------------------------------------------
     # Read metadata & bvec/bval, estimate number of shells, extract and split B0s
-    meta = pe.Node(ReadDWIMetadata(index_db=config.execution.bids_database_dir), name='metadata')
+    load_bmat = pe.Node(
+        ReadDWIMetadata(index_db=config.execution.bids_database_dir),
+        name='load_bmat',
+    )
     shells = pe.Node(NumberOfShells(), name='shells')
 
-    drift = pe.Node(CorrectSignalDrift(), name='drift')
-    get_lowb = pe.Node(ExtractOrientations(), name='get_lowb')
+    get_lowb = pe.Node(
+        ExtractOrientations(),
+        name='get_lowb',
+        n_procs=max(1, config.nipype.omp_nthreads // 2),
+    )
 
     # Generate B0 reference
-    dwi_ref = pe.Node(RobustAverage(mc_method=None), name='dwi_ref')
+    dwi_ref = pe.Node(
+        RobustAverage(mc_method=None),
+        name='dwi_ref',
+        n_procs=max(1, config.nipype.omp_nthreads // 2),
+    )
 
     hmc_b0 = pe.Node(
         Volreg(args='-Fourier -twopass', zpad=4, outputtype='NIFTI_GZ'),
         name='hmc_b0',
-        mem_gb=mem_gb * 2.5,
+        mem_gb=3.0,
+        n_procs=config.nipype.omp_nthreads,
     )
 
     # Calculate brainmask
@@ -171,13 +180,13 @@ def dmri_qc_workflow(name='dwiMRIQC'):
     averages = pe.MapNode(
         WeightedStat(),
         name='averages',
-        mem_gb=mem_gb * 1.5,
+        n_procs=max(1, config.nipype.omp_nthreads // 2),
         iterfield=['in_weights'],
     )
     stddev = pe.MapNode(
         WeightedStat(stat='std'),
         name='stddev',
-        mem_gb=mem_gb * 1.5,
+        n_procs=max(1, config.nipype.omp_nthreads // 2),
         iterfield=['in_weights'],
     )
 
@@ -187,19 +196,39 @@ def dmri_qc_workflow(name='dwiMRIQC'):
             nthreads=config.nipype.omp_nthreads,
         ),
         name='dwidenoise',
-        nprocs=config.nipype.omp_nthreads,
-        mem_gb=mem_gb * 4,
+        n_procs=config.nipype.omp_nthreads,
     )
+    drift = pe.Node(
+        CorrectSignalDrift(),
+        name='drift',
+        n_procs=max(1, config.nipype.omp_nthreads // 2),
+    )
+    sp_mask = pe.Node(
+        SpikingVoxelsMask(),
+        name='sp_mask',
+        n_procs=max(1, config.nipype.omp_nthreads // 2),
+    )
 
-    # Fit DTI/DKI model
-    dwimodel = pe.Node(DiffusionModel(), name='dwimodel')
-    sp_mask = pe.Node(SpikingVoxelsMask(), name='sp_mask')
+    # Fit DTI/DKI model
+    dwimodel = pe.Node(
+        DiffusionModel(),
+        name='dwimodel',
+        n_procs=max(1, config.nipype.omp_nthreads // 2),
+    )
 
     # Calculate CC mask
-    cc_mask = pe.Node(CCSegmentation(), name='cc_mask')
+    cc_mask = pe.Node(
+        CCSegmentation(),
+        name='cc_mask',
+        n_procs=max(1, config.nipype.omp_nthreads // 2),
+    )
 
     # Run PIESNO noise estimation
-    piesno = pe.Node(PIESNO(), name='piesno')
+    piesno = pe.Node(
+        PIESNO(),
+        name='piesno',
+        n_procs=max(1, config.nipype.omp_nthreads // 2),
+    )
 
     # EPI to MNI registration
     spatial_norm = epi_mni_align()
@@ -213,7 +242,7 @@ def dmri_qc_workflow(name='dwiMRIQC'):
     # fmt: off
     workflow.connect([
         (inputnode, datalad_get, [('in_file', 'in_file')]),
-        (inputnode, meta, [('in_file', 'in_file')]),
+        (inputnode, load_bmat, [('in_file', 'in_file')]),
         (inputnode, dwi_report_wf, [
             ('in_file', 'inputnode.name_source'),
         ]),
@@ -224,11 +253,11 @@ def dmri_qc_workflow(name='dwiMRIQC'):
         (sanitize, piesno, [('out_file', 'in_file')]),
         (shells, dwi_ref, [(('b_masks', _first), 't_mask')]),
         (shells, sp_mask, [('b_masks', 'b_masks')]),
-        (meta, shells, [('out_bval_file', 'in_bvals')]),
+        (load_bmat, shells, [('out_bval_file', 'in_bvals')]),
         (sanitize, drift, [('out_file', 'full_epi')]),
         (shells, get_lowb, [(('b_indices', _first), 'indices')]),
         (sanitize, get_lowb, [('out_file', 'in_file')]),
-        (meta, drift, [('out_bval_file', 'bval_file')]),
+        (load_bmat, drift, [('out_bval_file', 'bval_file')]),
         (get_lowb, hmc_b0, [('out_file', 'in_file')]),
         (dwi_ref, hmc_b0, [('out_file', 'basefile')]),
         (hmc_b0, drift, [('out_file', 'in_file')]),
@@ -237,7 +266,7 @@ def dmri_qc_workflow(name='dwiMRIQC'):
         (dmri_bmsk, sp_mask, [('outputnode.out_mask', 'brain_mask')]),
         (dmri_bmsk, drift, [('outputnode.out_mask', 'brainmask_file')]),
         (drift, hmcwf, [('out_full_file', 'inputnode.in_file')]),
-        (meta, hmcwf, [('out_bvec_file', 'inputnode.in_bvec')]),
+        (load_bmat, hmcwf, [('out_bvec_file', 'inputnode.in_bvec')]),
         (drift, averages, [('out_full_file', 'in_file')]),
         (drift, stddev, [('out_full_file', 'in_file')]),
         (shells, averages, [('b_masks', 'in_weights')]),
@@ -245,17 +274,17 @@ def dmri_qc_workflow(name='dwiMRIQC'):
         (shells, stddev, [('b_masks', 'in_weights')]),
         (shells, dwimodel, [('out_data', 'bvals'),
                             ('n_shells', 'n_shells')]),
-        (meta, dwimodel, [('out_bvec_file', 'bvec_file')]),
+        (load_bmat, dwimodel, [('out_bvec_file', 'bvec_file')]),
         (drift, dwidenoise, [('out_full_file', 'in_file')]),
         (dmri_bmsk, dwidenoise, [('outputnode.out_mask', 'mask')]),
         (dwidenoise, dwimodel, [('out_file', 'in_file')]),
         (dmri_bmsk, dwimodel, [('outputnode.out_mask', 'brain_mask')]),
-        (meta, get_hmc_shells, [('out_bvec_file', 'in_bvec_file')]),
+        (load_bmat, get_hmc_shells, [('out_bvec_file', 'in_bvec_file')]),
         (shells, get_hmc_shells, [('b_indices', 'indices')]),
         (hmcwf, get_hmc_shells, [('outputnode.out_file', 'in_file')]),
         (dwimodel, cc_mask, [('out_fa', 'in_fa'),
                              ('out_cfa', 'in_cfa')]),
-        (meta, iqms_wf, [
+        (load_bmat, iqms_wf, [
            ('out_bval_file', 'inputnode.b_values_file'),
            ('qspace_neighbors', 'inputnode.qspace_neighbors'),
        ]),
@@ -314,7 +343,6 @@ def compute_iqms(name='ComputeIQMs'):
     from mriqc.interfaces.reports import AddProvenance
 
     # from mriqc.workflows.utils import _tofloat, get_fwhmx
-    # mem_gb = config.workflow.biggest_file_gb
 
     workflow = pe.Workflow(name=name)
     inputnode = pe.Node(
@@ -418,7 +446,7 @@ def compute_iqms(name='ComputeIQMs'):
             ('acquisition', 'acq_id'),
             ('reconstruction', 'rec_id'),
             ('run', 'run_id'),
-            ('out_dict', 'metadata')]),
+            (('out_dict', _filter_metadata), 'metadata')]),
         (datasink, outputnode, [('out_file', 'out_file')]),
         (meta, outputnode, [('out_dict', 'meta_sidecar')]),
         (measures, datasink, [('out_qc', 'root')]),
@@ -448,8 +476,6 @@ def hmc_workflow(name='dMRI_HMC'):
 
     from mriqc.interfaces.diffusion import RotateVectors
 
-    mem_gb = config.workflow.biggest_file_gb
-
     workflow = pe.Workflow(name=name)
 
     inputnode = pe.Node(
@@ -478,7 +504,8 @@ def hmc_workflow(name='dMRI_HMC'):
     hmc = pe.Node(
         Volreg(args='-Fourier -twopass', zpad=4, outputtype='NIFTI_GZ'),
         name='motion_correct',
-        mem_gb=mem_gb * 2.5,
+        mem_gb=3.0,
+        n_procs=config.nipype.omp_nthreads,
     )
 
     bvec_rot = pe.Node(RotateVectors(), name='bvec_rot')
@@ -716,3 +743,20 @@ def _bvals_report(in_file):
         return 'Likely DSI'
 
     return bvals
+
+
+def _filter_metadata(in_dict, keys=(
+    'global',
+    'dcmmeta_affine',
+    'dcmmeta_reorient_transform',
+    'dcmmeta_shape',
+    'dcmmeta_slice_dim',
+    'dcmmeta_version',
+    'time',
+)):
+    """Drop large and partially redundant objects generated by dcm2niix."""
+
+    for key in keys:
+        in_dict.pop(key, None)
+
+    return in_dict
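
For reviewers who want to see how these annotations take effect at run time, below is a minimal, self-contained sketch (not part of the patch; the toy nodes and numbers are illustrative only) of how Nipype's MultiProc plugin consumes ``n_procs``: a node that declares ``n_procs=2`` reserves two of the scheduler's processor slots, so annotating resource-hungry nodes keeps the number of concurrently running processes within the overall budget instead of oversubscribing the machine.

# Minimal sketch: ``n_procs`` annotations and Nipype's MultiProc scheduler.
# Toy nodes and values are assumptions for illustration, not MRIQC defaults.
from nipype.interfaces.utility import Function
from nipype.pipeline import engine as pe


def _double(x):
    """Toy workload standing in for a multi-threaded interface."""
    return 2 * x


wf = pe.Workflow(name='toy_nprocs')
for i in range(4):
    node = pe.Node(
        Function(function=_double, input_names=['x'], output_names=['out']),
        name=f'double_{i}',
        n_procs=2,   # reserve two scheduler slots, mirroring the annotations above
        mem_gb=0.1,  # declare an upper bound on the node's memory footprint
    )
    node.inputs.x = i
    wf.add_nodes([node])

# With a 4-processor budget and n_procs=2 per node, MultiProc runs at most two
# of these (independent) nodes at any given time.
wf.run(plugin='MultiProc', plugin_args={'n_procs': 4, 'memory_gb': 2})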