diff --git a/aslprep/workflows/asl/fit.py b/aslprep/workflows/asl/fit.py index 31520a833..a764c0d94 100644 --- a/aslprep/workflows/asl/fit.py +++ b/aslprep/workflows/asl/fit.py @@ -26,7 +26,6 @@ import nibabel as nb from nipype.interfaces import utility as niu from nipype.pipeline import engine as pe -from niworkflows.func.util import init_enhance_and_skullstrip_asl_wf from niworkflows.interfaces.header import ValidateImage from niworkflows.interfaces.nitransforms import ConcatenateXFMs from niworkflows.interfaces.utility import KeySelect @@ -41,20 +40,20 @@ ResampleSeries, ) from fmriprep.utils.bids import extract_entities +from fmriprep.workflows.bold.registration import init_bold_reg_wf # ASL workflows from aslprep.interfaces.utility import ReduceASLFiles from aslprep.utils.asl import select_processing_target -from aslprep.utils.misc import estimate_asl_mem_usage from aslprep.workflows.asl.hmc import init_asl_hmc_wf from aslprep.workflows.asl.outputs import ( init_ds_aslref_wf, init_ds_hmc_wf, init_ds_registration_wf, - init_func_fit_reports_wf, + init_asl_fit_reports_wf, ) from aslprep.workflows.asl.reference import init_raw_aslref_wf -from aslprep.workflows.asl.registration import init_asl_reg_wf +from aslprep.workflows.asl.util import init_enhance_and_skullstrip_asl_wf def get_sbrefs( @@ -75,7 +74,7 @@ def get_sbrefs( Returns ------- - sbref_files + sbref_file List of absolute paths to sbref files associated with input ASL files, sorted by EchoTime """ @@ -88,12 +87,12 @@ def get_sbrefs( def init_asl_fit_wf( *, asl_file: str, + m0scan: str, fieldmap_id: ty.Optional[str] = None, omp_nthreads: int = 1, name: str = "asl_fit_wf", ) -> pe.Workflow: - """ - This workflow controls the minimal estimation steps for functional preprocessing. + """This workflow controls the minimal estimation steps for functional preprocessing. Workflow Graph .. workflow:: @@ -182,19 +181,14 @@ def init_asl_fit_wf( layout = config.execution.layout # Collect asl and sbref files, sorted by EchoTime - asl_files = sorted(asl_file, key=lambda fname: layout.get_metadata(fname).get("EchoTime")) - sbref_files = get_sbrefs( - asl_files, + sbref_file = get_sbrefs( + asl_file, entity_overrides=config.execution.get().get("bids_filters", {}).get("sbref", {}), layout=layout, ) - # Fitting operates on the shortest echo - # This could become more complicated in the future - asl_file = asl_files[0] - # Get metadata from ASL file(s) - entities = extract_entities(asl_files) + entities = extract_entities(asl_file) metadata = layout.get_metadata(asl_file) orientation = "".join(nb.aff2axcodes(nb.load(asl_file).affine)) @@ -278,7 +272,7 @@ def init_asl_fit_wf( ) summary.inputs.dummy_scans = config.workflow.dummy_scans - func_fit_reports_wf = init_func_fit_reports_wf( + asl_fit_reports_wf = init_asl_fit_reports_wf( sdc_correction=not (fieldmap_id is None), freesurfer=config.workflow.run_reconall, output_dir=config.execution.fmriprep_dir, @@ -300,7 +294,7 @@ def init_asl_fit_wf( ("movpar_file", "movpar_file"), ("rmsd_file", "rmsd_file"), ]), - (inputnode, func_fit_reports_wf, [ + (inputnode, asl_fit_reports_wf, [ ("asl_file", "inputnode.source_file"), ("t1w_preproc", "inputnode.t1w_preproc"), # May not need all of these @@ -309,11 +303,11 @@ def init_asl_fit_wf( ("subjects_dir", "inputnode.subjects_dir"), ("subject_id", "inputnode.subject_id"), ]), - (outputnode, func_fit_reports_wf, [ + (outputnode, asl_fit_reports_wf, [ ("coreg_aslref", "inputnode.coreg_aslref"), ("aslref2anat_xfm", "inputnode.aslref2anat_xfm"), ]), - (summary, func_fit_reports_wf, [("out_report", "inputnode.summary_report")]), + (summary, asl_fit_reports_wf, [("out_report", "inputnode.summary_report")]), ]) # fmt:on @@ -324,7 +318,7 @@ def init_asl_fit_wf( asl_file=asl_file, m0scan=(metadata["M0Type"] == "Separate"), ) - hmc_aslref_wf.inputs.inputnode.m0scan = run_data["m0scan"] + hmc_aslref_wf.inputs.inputnode.m0scan = m0scan hmc_aslref_wf.inputs.inputnode.dummy_scans = config.workflow.dummy_scans ds_hmc_aslref_wf = init_ds_aslref_wf( @@ -344,7 +338,7 @@ def init_asl_fit_wf( ]), (hmcref_buffer, ds_hmc_aslref_wf, [("aslref", "inputnode.aslref")]), (hmc_aslref_wf, summary, [("outputnode.algo_dummy_scans", "algo_dummy_scans")]), - (hmc_aslref_wf, func_fit_reports_wf, [ + (hmc_aslref_wf, asl_fit_reports_wf, [ ("outputnode.validation_report", "inputnode.validation_report"), ]), ]) @@ -406,7 +400,7 @@ def init_asl_fit_wf( config.loggers.workflow.info("Stage 3: Adding coregistration aslref workflow") # Select initial aslref, enhance contrast, and generate mask - fmapref_buffer.inputs.sbref_files = sbref_files + fmapref_buffer.inputs.sbref_file = sbref_file enhance_aslref_wf = init_enhance_and_skullstrip_asl_wf(omp_nthreads=omp_nthreads) ds_coreg_aslref_wf = init_ds_aslref_wf( @@ -422,7 +416,7 @@ def init_asl_fit_wf( (fmapref_buffer, enhance_aslref_wf, [("out", "inputnode.in_file")]), (fmapref_buffer, ds_coreg_aslref_wf, [("out", "inputnode.source_files")]), (ds_coreg_aslref_wf, regref_buffer, [("outputnode.aslref", "aslref")]), - (fmapref_buffer, func_fit_reports_wf, [("out", "inputnode.sdc_aslref")]), + (fmapref_buffer, asl_fit_reports_wf, [("out", "inputnode.sdc_aslref")]), ]) # fmt:on @@ -499,12 +493,12 @@ def init_asl_fit_wf( ]), (unwarp_wf, ds_coreg_aslref_wf, [("outputnode.corrected", "inputnode.aslref")]), (unwarp_wf, regref_buffer, [("outputnode.corrected_mask", "aslmask")]), - (fmap_select, func_fit_reports_wf, [("fmap_ref", "inputnode.fmap_ref")]), + (fmap_select, asl_fit_reports_wf, [("fmap_ref", "inputnode.fmap_ref")]), (fmap_select, summary, [("sdc_method", "distortion_correction")]), - (fmapreg_buffer, func_fit_reports_wf, [ + (fmapreg_buffer, asl_fit_reports_wf, [ ("aslref2fmap_xfm", "inputnode.aslref2fmap_xfm"), ]), - (unwarp_wf, func_fit_reports_wf, [("outputnode.fieldmap", "inputnode.fieldmap")]), + (unwarp_wf, asl_fit_reports_wf, [("outputnode.fieldmap", "inputnode.fieldmap")]), ]) # fmt:on else: @@ -518,9 +512,9 @@ def init_asl_fit_wf( # fmt:on # calculate ASL registration to T1w - asl_reg_wf = init_asl_reg_wf( - asl2t1w_dof=config.workflow.asl2t1w_dof, - asl2t1w_init=config.workflow.asl2t1w_init, + asl_reg_wf = init_bold_reg_wf( + bold2t1w_dof=config.workflow.asl2t1w_dof, + bold2t1w_init=config.workflow.asl2t1w_init, freesurfer=config.workflow.run_reconall, mem_gb=mem_gb["resampled"], name="asl_reg_wf", @@ -550,10 +544,10 @@ def init_asl_fit_wf( ("subject_id", "inputnode.subject_id"), ("fsnative2t1w_xfm", "inputnode.fsnative2t1w_xfm"), ]), - (regref_buffer, asl_reg_wf, [("aslref", "inputnode.ref_asl_brain")]), + (regref_buffer, asl_reg_wf, [("aslref", "inputnode.ref_bold_brain")]), # Incomplete sources (regref_buffer, ds_aslreg_wf, [("aslref", "inputnode.source_files")]), - (asl_reg_wf, ds_aslreg_wf, [("outputnode.itk_asl_to_t1", "inputnode.xform")]), + (asl_reg_wf, ds_aslreg_wf, [("outputnode.itk_bold_to_t1", "inputnode.xform")]), (ds_aslreg_wf, outputnode, [("outputnode.xform", "aslref2anat_xfm")]), (asl_reg_wf, summary, [("outputnode.fallback", "fallback")]), ]) @@ -777,9 +771,9 @@ def init_asl_native_wf( return workflow -def _select_ref(sbref_files, aslref_files): +def _select_ref(sbref_file, aslref_files): """Select first sbref or aslref file, preferring sbref if available""" from niworkflows.utils.connections import listify - refs = sbref_files or aslref_files + refs = sbref_file or aslref_files return listify(refs)[0] diff --git a/aslprep/workflows/asl/outputs.py b/aslprep/workflows/asl/outputs.py index 0a29c8812..27d9d5de1 100644 --- a/aslprep/workflows/asl/outputs.py +++ b/aslprep/workflows/asl/outputs.py @@ -8,6 +8,7 @@ from niworkflows.engine.workflows import LiterateWorkflow as Workflow from niworkflows.interfaces.fixes import FixHeaderApplyTransforms as ApplyTransforms from niworkflows.interfaces.utility import KeySelect +from niworkflows.utils.images import dseg_label from smriprep.workflows.outputs import _bids_relative from aslprep import config @@ -70,6 +71,331 @@ } +def init_asl_fit_reports_wf( + *, + sdc_correction: bool, + freesurfer: bool, + output_dir: str, + name="asl_fit_reports_wf", +) -> pe.Workflow: + """Set up a battery of datasinks to store reports in the right location. + + Parameters + ---------- + freesurfer : :obj:`bool` + FreeSurfer was enabled + output_dir : :obj:`str` + Directory in which to save derivatives + name : :obj:`str` + Workflow name (default: anat_reports_wf) + + Inputs + ------ + source_file + Input BOLD images + + std_t1w + T1w image resampled to standard space + std_mask + Mask of skull-stripped template + subject_dir + FreeSurfer SUBJECTS_DIR + subject_id + FreeSurfer subject ID + t1w_conform_report + Conformation report + t1w_preproc + The T1w reference map, which is calculated as the average of bias-corrected + and preprocessed T1w images, defining the anatomical space. + t1w_dseg + Segmentation in T1w space + t1w_mask + Brain (binary) mask estimated by brain extraction. + template + Template space and specifications + + """ + from niworkflows.interfaces.reportlets.registration import ( + SimpleBeforeAfterRPT as SimpleBeforeAfter, + ) + from sdcflows.interfaces.reportlets import FieldmapReportlet + + workflow = pe.Workflow(name=name) + + inputfields = [ + "source_file", + "sdc_aslref", + "coreg_aslref", + "aslref2anat_xfm", + "aslref2fmap_xfm", + "t1w_preproc", + "t1w_mask", + "t1w_dseg", + "fieldmap", + "fmap_ref", + # May be missing + "subject_id", + "subjects_dir", + # Report snippets + "summary_report", + "validation_report", + ] + inputnode = pe.Node(niu.IdentityInterface(fields=inputfields), name="inputnode") + + ds_summary = pe.Node( + DerivativesDataSink( + base_directory=output_dir, + desc="summary", + datatype="figures", + dismiss_entities=("echo",), + ), + name="ds_report_summary", + run_without_submitting=True, + mem_gb=config.DEFAULT_MEMORY_MIN_GB, + ) + + ds_validation = pe.Node( + DerivativesDataSink( + base_directory=output_dir, + desc="validation", + datatype="figures", + dismiss_entities=("echo",), + ), + name="ds_report_validation", + run_without_submitting=True, + mem_gb=config.DEFAULT_MEMORY_MIN_GB, + ) + + # Resample anatomical references into BOLD space for plotting + t1w_aslref = pe.Node( + ApplyTransforms( + dimension=3, + default_value=0, + float=True, + invert_transform_flags=[True], + interpolation="LanczosWindowedSinc", + ), + name="t1w_aslref", + mem_gb=1, + ) + + t1w_wm = pe.Node( + niu.Function(function=dseg_label), + name="t1w_wm", + mem_gb=config.DEFAULT_MEMORY_MIN_GB, + ) + t1w_wm.inputs.label = 2 # BIDS default is WM=2 + + aslref_wm = pe.Node( + ApplyTransforms( + dimension=3, + default_value=0, + invert_transform_flags=[True], + interpolation="NearestNeighbor", + ), + name="aslref_wm", + mem_gb=1, + ) + + # fmt:off + workflow.connect([ + (inputnode, ds_summary, [ + ("source_file", "source_file"), + ("summary_report", "in_file"), + ]), + (inputnode, ds_validation, [ + ("source_file", "source_file"), + ("validation_report", "in_file"), + ]), + (inputnode, t1w_aslref, [ + ("t1w_preproc", "input_image"), + ("coreg_aslref", "reference_image"), + ("aslref2anat_xfm", "transforms"), + ]), + (inputnode, t1w_wm, [("t1w_dseg", "in_seg")]), + (inputnode, aslref_wm, [ + ("coreg_aslref", "reference_image"), + ("aslref2anat_xfm", "transforms"), + ]), + (t1w_wm, aslref_wm, [("out", "input_image")]), + ]) + # fmt:on + + # Reportlets follow the structure of init_asl_fit_wf stages + # - SDC1: + # Before: Pre-SDC aslref + # After: Fieldmap reference resampled on aslref + # Three-way: Fieldmap resampled on aslref + # - SDC2: + # Before: Pre-SDC aslref with white matter mask + # After: Post-SDC aslref with white matter mask + # - EPI-T1 registration: + # Before: T1w brain with white matter mask + # After: Resampled aslref with white matter mask + + if sdc_correction: + fmapref_aslref = pe.Node( + ApplyTransforms( + dimension=3, + default_value=0, + float=True, + invert_transform_flags=[True], + interpolation="LanczosWindowedSinc", + ), + name="fmapref_aslref", + mem_gb=1, + ) + + # SDC1 + sdcreg_report = pe.Node( + FieldmapReportlet( + reference_label="BOLD reference", + moving_label="Fieldmap reference", + show="both", + ), + name="sdecreg_report", + mem_gb=0.1, + ) + + ds_sdcreg_report = pe.Node( + DerivativesDataSink( + base_directory=output_dir, + desc="fmapCoreg", + suffix="asl", + datatype="figures", + dismiss_entities=("echo",), + ), + name="ds_sdcreg_report", + ) + + # SDC2 + sdc_report = pe.Node( + SimpleBeforeAfter( + before_label="Distorted", + after_label="Corrected", + dismiss_affine=True, + ), + name="sdc_report", + mem_gb=0.1, + ) + + ds_sdc_report = pe.Node( + DerivativesDataSink( + base_directory=output_dir, + desc="sdc", + suffix="asl", + datatype="figures", + dismiss_entities=("echo",), + ), + name="ds_sdc_report", + ) + + # fmt:off + workflow.connect([ + (inputnode, fmapref_aslref, [ + ("fmap_ref", "input_image"), + ("coreg_aslref", "reference_image"), + ("aslref2fmap_xfm", "transforms"), + ]), + (inputnode, sdcreg_report, [ + ("sdc_aslref", "reference"), + ("fieldmap", "fieldmap") + ]), + (fmapref_aslref, sdcreg_report, [("output_image", "moving")]), + (inputnode, ds_sdcreg_report, [("source_file", "source_file")]), + (sdcreg_report, ds_sdcreg_report, [("out_report", "in_file")]), + (inputnode, sdc_report, [ + ("sdc_aslref", "before"), + ("coreg_aslref", "after"), + ]), + (aslref_wm, sdc_report, [("output_image", "wm_seg")]), + (inputnode, ds_sdc_report, [("source_file", "source_file")]), + (sdc_report, ds_sdc_report, [("out_report", "in_file")]), + ]) + # fmt:on + + # EPI-T1 registration + # Resample T1w image onto EPI-space + + epi_t1_report = pe.Node( + SimpleBeforeAfter( + before_label="T1w", + after_label="EPI", + dismiss_affine=True, + ), + name="epi_t1_report", + mem_gb=0.1, + ) + + ds_epi_t1_report = pe.Node( + DerivativesDataSink( + base_directory=output_dir, + desc="coreg", + suffix="asl", + datatype="figures", + dismiss_entities=("echo",), + ), + name="ds_epi_t1_report", + ) + + # fmt:off + workflow.connect([ + (inputnode, epi_t1_report, [("coreg_aslref", "after")]), + (t1w_aslref, epi_t1_report, [("output_image", "before")]), + (aslref_wm, epi_t1_report, [("output_image", "wm_seg")]), + (inputnode, ds_epi_t1_report, [("source_file", "source_file")]), + (epi_t1_report, ds_epi_t1_report, [("out_report", "in_file")]), + ]) + # fmt:on + + return workflow + + +def init_ds_aslref_wf( + *, + bids_root, + output_dir, + desc: str, + name="ds_aslref_wf", +) -> pe.Workflow: + workflow = pe.Workflow(name=name) + + inputnode = pe.Node( + niu.IdentityInterface(fields=["source_files", "aslref"]), + name="inputnode", + ) + outputnode = pe.Node(niu.IdentityInterface(fields=["aslref"]), name="outputnode") + + raw_sources = pe.Node(niu.Function(function=_bids_relative), name="raw_sources") + raw_sources.inputs.bids_root = bids_root + + ds_aslref = pe.Node( + DerivativesDataSink( + base_directory=output_dir, + desc=desc, + suffix="aslref", + compress=True, + dismiss_entities=("echo",), + ), + name="ds_aslref", + run_without_submitting=True, + ) + + # fmt:off + workflow.connect([ + (inputnode, raw_sources, [("source_files", "in_files")]), + (inputnode, ds_aslref, [ + ("aslref", "in_file"), + ("source_files", "source_file"), + ]), + (raw_sources, ds_aslref, [("out", "RawSources")]), + (ds_aslref, outputnode, [("out_file", "aslref")]), + ]) + # fmt:on + + return workflow + + def init_ds_registration_wf( *, bids_root: str, @@ -122,6 +448,52 @@ def init_ds_registration_wf( return workflow +def init_ds_hmc_wf( + *, + bids_root, + output_dir, + name="ds_hmc_wf", +) -> pe.Workflow: + workflow = pe.Workflow(name=name) + + inputnode = pe.Node( + niu.IdentityInterface(fields=["source_files", "xforms"]), + name="inputnode", + ) + outputnode = pe.Node(niu.IdentityInterface(fields=["xforms"]), name="outputnode") + + raw_sources = pe.Node(niu.Function(function=_bids_relative), name="raw_sources") + raw_sources.inputs.bids_root = bids_root + + ds_xforms = pe.Node( + DerivativesDataSink( + base_directory=output_dir, + desc="hmc", + suffix="xfm", + extension=".txt", + compress=True, + dismiss_entities=("echo",), + **{"from": "orig", "to": "aslref"}, + ), + name="ds_xforms", + run_without_submitting=True, + ) + + # fmt:off + workflow.connect([ + (inputnode, raw_sources, [("source_files", "in_files")]), + (inputnode, ds_xforms, [ + ("xforms", "in_file"), + ("source_files", "source_file"), + ]), + (raw_sources, ds_xforms, [("out", "RawSources")]), + (ds_xforms, outputnode, [("out_file", "xforms")]), + ]) + # fmt:on + + return workflow + + def init_ds_asl_native_wf( *, bids_root: str, @@ -441,7 +813,7 @@ def init_asl_derivatives_wf( is_multi_pld : :obj:`bool` True if data are multi-delay, False otherwise. name : :obj:`str` - This workflow's identifier (default: ``func_derivatives_wf``). + This workflow's identifier (default: ``asl_derivatives_wf``). """ nonstd_spaces = set(spaces.get_nonstandard()) workflow = Workflow(name=name) @@ -902,7 +1274,7 @@ def init_asl_derivatives_wf( ds_asl_surfs = pe.MapNode( DerivativesDataSink( base_directory=config.execution.aslprep_dir, - extension=".func.gii", + extension=".asl.gii", TaskName=metadata.get("TaskName"), ), iterfield=["in_file", "hemi"],