Skip to content

Commit

Permalink
More work.
Browse files Browse the repository at this point in the history
  • Loading branch information
tsalo committed Nov 17, 2023
1 parent 9837a0c commit 14d6300
Show file tree
Hide file tree
Showing 2 changed files with 401 additions and 35 deletions.
60 changes: 27 additions & 33 deletions aslprep/workflows/asl/fit.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,6 @@
import nibabel as nb
from nipype.interfaces import utility as niu
from nipype.pipeline import engine as pe
from niworkflows.func.util import init_enhance_and_skullstrip_asl_wf
from niworkflows.interfaces.header import ValidateImage
from niworkflows.interfaces.nitransforms import ConcatenateXFMs
from niworkflows.interfaces.utility import KeySelect
Expand All @@ -41,20 +40,20 @@
ResampleSeries,
)
from fmriprep.utils.bids import extract_entities
from fmriprep.workflows.bold.registration import init_bold_reg_wf

# ASL workflows
from aslprep.interfaces.utility import ReduceASLFiles
from aslprep.utils.asl import select_processing_target
from aslprep.utils.misc import estimate_asl_mem_usage
from aslprep.workflows.asl.hmc import init_asl_hmc_wf
from aslprep.workflows.asl.outputs import (
init_ds_aslref_wf,
init_ds_hmc_wf,
init_ds_registration_wf,
init_func_fit_reports_wf,
init_asl_fit_reports_wf,
)
from aslprep.workflows.asl.reference import init_raw_aslref_wf
from aslprep.workflows.asl.registration import init_asl_reg_wf
from aslprep.workflows.asl.util import init_enhance_and_skullstrip_asl_wf


def get_sbrefs(
Expand All @@ -75,7 +74,7 @@ def get_sbrefs(
Returns
-------
sbref_files
sbref_file
List of absolute paths to sbref files associated with input ASL files,
sorted by EchoTime
"""
Expand All @@ -88,12 +87,12 @@ def get_sbrefs(
def init_asl_fit_wf(
*,
asl_file: str,
m0scan: str,
fieldmap_id: ty.Optional[str] = None,
omp_nthreads: int = 1,
name: str = "asl_fit_wf",
) -> pe.Workflow:
"""
This workflow controls the minimal estimation steps for functional preprocessing.
"""This workflow controls the minimal estimation steps for functional preprocessing.
Workflow Graph
.. workflow::
Expand Down Expand Up @@ -182,19 +181,14 @@ def init_asl_fit_wf(
layout = config.execution.layout

# Collect asl and sbref files, sorted by EchoTime
asl_files = sorted(asl_file, key=lambda fname: layout.get_metadata(fname).get("EchoTime"))
sbref_files = get_sbrefs(
asl_files,
sbref_file = get_sbrefs(
asl_file,
entity_overrides=config.execution.get().get("bids_filters", {}).get("sbref", {}),
layout=layout,
)

# Fitting operates on the shortest echo
# This could become more complicated in the future
asl_file = asl_files[0]

# Get metadata from ASL file(s)
entities = extract_entities(asl_files)
entities = extract_entities(asl_file)
metadata = layout.get_metadata(asl_file)
orientation = "".join(nb.aff2axcodes(nb.load(asl_file).affine))

Expand Down Expand Up @@ -278,7 +272,7 @@ def init_asl_fit_wf(
)
summary.inputs.dummy_scans = config.workflow.dummy_scans

func_fit_reports_wf = init_func_fit_reports_wf(
asl_fit_reports_wf = init_asl_fit_reports_wf(
sdc_correction=not (fieldmap_id is None),
freesurfer=config.workflow.run_reconall,
output_dir=config.execution.fmriprep_dir,
Expand All @@ -300,7 +294,7 @@ def init_asl_fit_wf(
("movpar_file", "movpar_file"),
("rmsd_file", "rmsd_file"),
]),
(inputnode, func_fit_reports_wf, [
(inputnode, asl_fit_reports_wf, [
("asl_file", "inputnode.source_file"),
("t1w_preproc", "inputnode.t1w_preproc"),
# May not need all of these
Expand All @@ -309,11 +303,11 @@ def init_asl_fit_wf(
("subjects_dir", "inputnode.subjects_dir"),
("subject_id", "inputnode.subject_id"),
]),
(outputnode, func_fit_reports_wf, [
(outputnode, asl_fit_reports_wf, [
("coreg_aslref", "inputnode.coreg_aslref"),
("aslref2anat_xfm", "inputnode.aslref2anat_xfm"),
]),
(summary, func_fit_reports_wf, [("out_report", "inputnode.summary_report")]),
(summary, asl_fit_reports_wf, [("out_report", "inputnode.summary_report")]),
])
# fmt:on

Expand All @@ -324,7 +318,7 @@ def init_asl_fit_wf(
asl_file=asl_file,
m0scan=(metadata["M0Type"] == "Separate"),
)
hmc_aslref_wf.inputs.inputnode.m0scan = run_data["m0scan"]
hmc_aslref_wf.inputs.inputnode.m0scan = m0scan
hmc_aslref_wf.inputs.inputnode.dummy_scans = config.workflow.dummy_scans

ds_hmc_aslref_wf = init_ds_aslref_wf(
Expand All @@ -344,7 +338,7 @@ def init_asl_fit_wf(
]),
(hmcref_buffer, ds_hmc_aslref_wf, [("aslref", "inputnode.aslref")]),
(hmc_aslref_wf, summary, [("outputnode.algo_dummy_scans", "algo_dummy_scans")]),
(hmc_aslref_wf, func_fit_reports_wf, [
(hmc_aslref_wf, asl_fit_reports_wf, [
("outputnode.validation_report", "inputnode.validation_report"),
]),
])
Expand Down Expand Up @@ -406,7 +400,7 @@ def init_asl_fit_wf(
config.loggers.workflow.info("Stage 3: Adding coregistration aslref workflow")

# Select initial aslref, enhance contrast, and generate mask
fmapref_buffer.inputs.sbref_files = sbref_files
fmapref_buffer.inputs.sbref_file = sbref_file
enhance_aslref_wf = init_enhance_and_skullstrip_asl_wf(omp_nthreads=omp_nthreads)

ds_coreg_aslref_wf = init_ds_aslref_wf(
Expand All @@ -422,7 +416,7 @@ def init_asl_fit_wf(
(fmapref_buffer, enhance_aslref_wf, [("out", "inputnode.in_file")]),
(fmapref_buffer, ds_coreg_aslref_wf, [("out", "inputnode.source_files")]),
(ds_coreg_aslref_wf, regref_buffer, [("outputnode.aslref", "aslref")]),
(fmapref_buffer, func_fit_reports_wf, [("out", "inputnode.sdc_aslref")]),
(fmapref_buffer, asl_fit_reports_wf, [("out", "inputnode.sdc_aslref")]),
])
# fmt:on

Expand Down Expand Up @@ -499,12 +493,12 @@ def init_asl_fit_wf(
]),
(unwarp_wf, ds_coreg_aslref_wf, [("outputnode.corrected", "inputnode.aslref")]),
(unwarp_wf, regref_buffer, [("outputnode.corrected_mask", "aslmask")]),
(fmap_select, func_fit_reports_wf, [("fmap_ref", "inputnode.fmap_ref")]),
(fmap_select, asl_fit_reports_wf, [("fmap_ref", "inputnode.fmap_ref")]),
(fmap_select, summary, [("sdc_method", "distortion_correction")]),
(fmapreg_buffer, func_fit_reports_wf, [
(fmapreg_buffer, asl_fit_reports_wf, [
("aslref2fmap_xfm", "inputnode.aslref2fmap_xfm"),
]),
(unwarp_wf, func_fit_reports_wf, [("outputnode.fieldmap", "inputnode.fieldmap")]),
(unwarp_wf, asl_fit_reports_wf, [("outputnode.fieldmap", "inputnode.fieldmap")]),
])
# fmt:on
else:
Expand All @@ -518,9 +512,9 @@ def init_asl_fit_wf(
# fmt:on

# calculate ASL registration to T1w
asl_reg_wf = init_asl_reg_wf(
asl2t1w_dof=config.workflow.asl2t1w_dof,
asl2t1w_init=config.workflow.asl2t1w_init,
asl_reg_wf = init_bold_reg_wf(
bold2t1w_dof=config.workflow.asl2t1w_dof,
bold2t1w_init=config.workflow.asl2t1w_init,
freesurfer=config.workflow.run_reconall,
mem_gb=mem_gb["resampled"],
name="asl_reg_wf",
Expand Down Expand Up @@ -550,10 +544,10 @@ def init_asl_fit_wf(
("subject_id", "inputnode.subject_id"),
("fsnative2t1w_xfm", "inputnode.fsnative2t1w_xfm"),
]),
(regref_buffer, asl_reg_wf, [("aslref", "inputnode.ref_asl_brain")]),
(regref_buffer, asl_reg_wf, [("aslref", "inputnode.ref_bold_brain")]),
# Incomplete sources
(regref_buffer, ds_aslreg_wf, [("aslref", "inputnode.source_files")]),
(asl_reg_wf, ds_aslreg_wf, [("outputnode.itk_asl_to_t1", "inputnode.xform")]),
(asl_reg_wf, ds_aslreg_wf, [("outputnode.itk_bold_to_t1", "inputnode.xform")]),
(ds_aslreg_wf, outputnode, [("outputnode.xform", "aslref2anat_xfm")]),
(asl_reg_wf, summary, [("outputnode.fallback", "fallback")]),
])
Expand Down Expand Up @@ -777,9 +771,9 @@ def init_asl_native_wf(
return workflow


def _select_ref(sbref_files, aslref_files):
def _select_ref(sbref_file, aslref_files):
"""Select first sbref or aslref file, preferring sbref if available"""
from niworkflows.utils.connections import listify

refs = sbref_files or aslref_files
refs = sbref_file or aslref_files
return listify(refs)[0]
Loading

0 comments on commit 14d6300

Please sign in to comment.