Skip to content

Commit

Permalink
Lots of changes.
Browse files Browse the repository at this point in the history
  • Loading branch information
tsalo committed Nov 20, 2023
1 parent 80477b5 commit e5cdab4
Show file tree
Hide file tree
Showing 10 changed files with 123 additions and 120 deletions.
13 changes: 9 additions & 4 deletions aslprep/interfaces/cbf.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,11 @@
class _RefineMaskInputSpec(BaseInterfaceInputSpec):
t1w_mask = File(exists=True, mandatory=True, desc="t1 mask")
asl_mask = File(exists=True, mandatory=True, desct="asl mask")
transforms = File(exists=True, mandatory=True, desc="transfom")
aslref2anat_xfm = File(
exists=True,
mandatory=True,
desc="Transform from reference to anatomical space.",
)


class _RefineMaskOutputSpec(TraitedSpec):
Expand Down Expand Up @@ -67,7 +71,7 @@ def _run_interface(self, runtime):
refine_ref_mask(
t1w_mask=self.inputs.t1w_mask,
ref_asl_mask=self.inputs.asl_mask,
t12ref_transform=self.inputs.transforms,
aslref2anat_xfm=self.inputs.aslref2anat_xfm,
tmp_mask=self._results["out_tmp"],
refined_mask=self._results["out_mask"],
)
Expand Down Expand Up @@ -984,7 +988,7 @@ def regmotoasl(asl, m0file):
return flt_results.outputs.out_file


def refine_ref_mask(t1w_mask, ref_asl_mask, t12ref_transform, tmp_mask, refined_mask):
def refine_ref_mask(t1w_mask, ref_asl_mask, aslref2anat_xfm, tmp_mask, refined_mask):
"""Warp T1w mask to ASL space, then use it to mask the ASL mask.
TODO: This should not be a function. It uses interfaces, so it should be a workflow.
Expand All @@ -995,7 +999,8 @@ def refine_ref_mask(t1w_mask, ref_asl_mask, t12ref_transform, tmp_mask, refined_
input_image=t1w_mask,
interpolation="NearestNeighbor",
reference_image=ref_asl_mask,
transforms=[t12ref_transform],
transforms=[aslref2anat_xfm],
invert_transform_flags=[True],
input_image_type=3,
output_image=tmp_mask,
)
Expand Down
18 changes: 8 additions & 10 deletions aslprep/workflows/asl/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -402,13 +402,13 @@ def init_asl_preproc_wf(
]),
(asl_fit_wf, cbf_wf, [
("outputnode.asl_mask", "inputnode.asl_mask"),
("outputnode.aslref2anat_xfm", "inputnode.aslref2anat_xfm"),
]),
(asl_native_wf, cbf_wf, [
("outputnode.asl_native", "inputnode.asl_file"),
("outputnode.aslcontext", "inputnode.aslcontext"),
("outputnode.m0scan", "inputnode.m0scan"),
# XXX: doesn't exist!
("outputnode.anat_to_aslref_xfm", "inputnode.anat_to_aslref_xfm"),
("outputnode.aslref_to_anat_xfm", "inputnode.aslref_to_anat_xfm"),
("outputnode.m0scan_native", "inputnode.m0scan"),
]),
(asl_native_wf, cbf_wf, [("outputnode.asl_native", "inputnode.asl_file")]),
])
# fmt:on

Expand Down Expand Up @@ -517,8 +517,7 @@ def init_asl_preproc_wf(
]),
(asl_fit_wf, parcellate_cbf_wf, [
("outputnode.asl_mask", "inputnode.asl_mask"),
# XXX: doesn't exist!
("outputnode.anat_to_aslref_xfm", "inputnode.anat_to_aslref_xfm"),
("outputnode.aslref2anat_xfm", "inputnode.aslref2anat_xfm"),
]),
])
# fmt:on
Expand Down Expand Up @@ -692,7 +691,7 @@ def init_asl_preproc_wf(
]),
(asl_fit_wf, cbf_qc_wf, [
("outputnode.asl_mask", "inputnode.asl_mask"),
("outputnode.anat_to_aslref_xfm", "inputnode.anat_to_aslref_xfm"),
("outputnode.aslref2anat_xfm", "inputnode.aslref2anat_xfm"),
("outputnode.rmsd_file", "inputnode.rmsd_file"),
]),
(asl_confounds_wf, cbf_qc_wf, [("outputnode.confounds_file", "inputnode.confounds_file")]),
Expand Down Expand Up @@ -724,8 +723,7 @@ def init_asl_preproc_wf(
]),
(asl_fit_wf, plot_cbf_wf, [
("outputnode.coreg_aslref", "inputnode.aslref"),
# XXX: doesn't exist
("outputnode.anat_to_aslref_xfm", "inputnode.anat_to_aslref_xfm"),
("outputnode.aslref2anat_xfm", "inputnode.aslref2anat_xfm"),
# XXX: Used to use the one from refine_mask
("inputnode.asl_mask", "inputnode.asl_mask"),
]),
Expand Down
71 changes: 46 additions & 25 deletions aslprep/workflows/asl/cbf.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,9 +95,7 @@ def init_cbf_wf(
t1w probability maps
t1w_mask
t1w mask Nifti
anat_to_aslref_xfm
t1w to asl transformation file
aslref_to_anat_xfm
aslref2anat_xfm
asl to t1w transformation file
Outputs
Expand Down Expand Up @@ -228,9 +226,8 @@ def init_cbf_wf(
"asl_mask",
"t1w_tpms",
"t1w_mask",
"anat_to_aslref_xfm",
"aslref_to_anat_xfm",
]
"aslref2anat_xfm",
],
),
name="inputnode",
)
Expand Down Expand Up @@ -268,7 +265,7 @@ def init_cbf_wf(
(inputnode, refine_mask, [
("t1w_mask", "t1w_mask"),
("asl_mask", "asl_mask"),
("anat_to_aslref_xfm", "transforms"),
("aslref2anat_xfm", "aslref2anat_xfm"),
]),
])
# fmt:on
Expand All @@ -289,7 +286,11 @@ def _getfiledir(file):
return os.path.dirname(file)

gm_tfm = pe.Node(
ApplyTransforms(interpolation="NearestNeighbor", float=True),
ApplyTransforms(
interpolation="NearestNeighbor",
float=True,
invert_transform_flags=[True],
),
name="gm_tfm",
mem_gb=0.1,
)
Expand All @@ -298,14 +299,18 @@ def _getfiledir(file):
workflow.connect([
(inputnode, gm_tfm, [
("asl_mask", "reference_image"),
("anat_to_aslref_xfm", "transforms"),
("aslref2anat_xfm", "transforms"),
(("t1w_tpms", _pick_gm), "input_image"),
]),
])
# fmt:on

wm_tfm = pe.Node(
ApplyTransforms(interpolation="NearestNeighbor", float=True),
ApplyTransforms(
interpolation="NearestNeighbor",
float=True,
invert_transform_flags=[True],
),
name="wm_tfm",
mem_gb=0.1,
)
Expand All @@ -314,14 +319,18 @@ def _getfiledir(file):
workflow.connect([
(inputnode, wm_tfm, [
("asl_mask", "reference_image"),
("anat_to_aslref_xfm", "transforms"),
("aslref2anat_xfm", "transforms"),
(("t1w_tpms", _pick_wm), "input_image"),
]),
])
# fmt:on

csf_tfm = pe.Node(
ApplyTransforms(interpolation="NearestNeighbor", float=True),
ApplyTransforms(
interpolation="NearestNeighbor",
float=True,
invert_transform_flags=[True],
),
name="csf_tfm",
mem_gb=0.1,
)
Expand All @@ -330,7 +339,7 @@ def _getfiledir(file):
workflow.connect([
(inputnode, csf_tfm, [
("asl_mask", "reference_image"),
("anat_to_aslref_xfm", "transforms"),
("aslref2anat_xfm", "transforms"),
(("t1w_tpms", _pick_csf), "input_image"),
]),
])
Expand Down Expand Up @@ -563,8 +572,7 @@ def init_compute_cbf_ge_wf(
"asl_mask",
"t1w_tpms",
"t1w_mask",
"anat_to_aslref_xfm",
"aslref_to_anat_xfm",
"aslref2anat_xfm",
"m0_file",
"m0tr",
]
Expand Down Expand Up @@ -609,7 +617,11 @@ def _getfiledir(file):
# convert tmps to asl_space
# extract probability maps
csf_tfm = pe.Node(
ApplyTransforms(interpolation="NearestNeighbor", float=True),
ApplyTransforms(
interpolation="NearestNeighbor",
float=True,
invert_transform_flags=[True],
),
name="csf_tfm",
mem_gb=0.1,
)
Expand All @@ -618,14 +630,18 @@ def _getfiledir(file):
workflow.connect([
(inputnode, csf_tfm, [
("asl_mask", "reference_image"),
("anat_to_aslref_xfm", "transforms"),
("aslref2anat_xfm", "transforms"),
(("t1w_tpms", _pick_csf), "input_image"),
]),
])
# fmt:on

wm_tfm = pe.Node(
ApplyTransforms(interpolation="NearestNeighbor", float=True),
ApplyTransforms(
interpolation="NearestNeighbor",
float=True,
invert_transform_flags=[True],
),
name="wm_tfm",
mem_gb=0.1,
)
Expand All @@ -634,14 +650,18 @@ def _getfiledir(file):
workflow.connect([
(inputnode, wm_tfm, [
("asl_mask", "reference_image"),
("anat_to_aslref_xfm", "transforms"),
("aslref2anat_xfm", "transforms"),
(("t1w_tpms", _pick_wm), "input_image"),
]),
])
# fmt:on

gm_tfm = pe.Node(
ApplyTransforms(interpolation="NearestNeighbor", float=True),
ApplyTransforms(
interpolation="NearestNeighbor",
float=True,
invert_transform_flags=[True],
),
name="gm_tfm",
mem_gb=0.1,
)
Expand All @@ -650,7 +670,7 @@ def _getfiledir(file):
workflow.connect([
(inputnode, gm_tfm, [
("asl_mask", "reference_image"),
("anat_to_aslref_xfm", "transforms"),
("aslref2anat_xfm", "transforms"),
(("t1w_tpms", _pick_gm), "input_image"),
]),
])
Expand Down Expand Up @@ -679,7 +699,7 @@ def _getfiledir(file):
(inputnode, refine_mask, [
("t1w_mask", "t1w_mask"),
("asl_mask", "asl_mask"),
("anat_to_aslref_xfm", "transforms"),
("aslref2anat_xfm", "aslref2anat_xfm"),
]),
])
# fmt:on
Expand Down Expand Up @@ -918,7 +938,7 @@ def init_parcellate_cbf_wf(
mean_cbf_basil : Undefined or str
mean_cbf_gm_basil : Undefined or str
asl_mask : str
anat_to_aslref_xfm : str
aslref2anat_xfm : str
MNI152NLin2009cAsym_to_anat_xfm : str
The transform from MNI152NLin2009cAsym to the subject's anatomical space.
Expand Down Expand Up @@ -965,7 +985,7 @@ def init_parcellate_cbf_wf(
"mean_cbf_basil",
"mean_cbf_gm_basil",
"asl_mask",
"anat_to_aslref_xfm",
"aslref2anat_xfm",
"MNI152NLin2009cAsym_to_anat_xfm",
],
),
Expand Down Expand Up @@ -1030,7 +1050,7 @@ def init_parcellate_cbf_wf(
workflow.connect([
(inputnode, merge_xforms, [
("MNI152NLin2009cAsym_to_anat_xfm", "in2"),
("anat_to_aslref_xfm", "in3"),
("aslref2anat_xfm", "in3"),
]),
])
# fmt:on
Expand All @@ -1041,6 +1061,7 @@ def init_parcellate_cbf_wf(
interpolation="GenericLabel",
input_image_type=3,
dimension=3,
invert_transform_flags=[False, False, True],
),
name="warp_atlases_to_asl_space",
iterfield=["input_image"],
Expand Down
12 changes: 6 additions & 6 deletions aslprep/workflows/asl/confounds.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ def init_asl_confounds_wf(
Mask of the skull-stripped template image
t1w_tpms
List of tissue probability maps in T1w space
anat_to_aslref_xfm
aslref2anat_xfm
Affine matrix that maps the T1w space into alignment with
the native asl space
Expand Down Expand Up @@ -110,7 +110,7 @@ def init_asl_confounds_wf(
"skip_vols",
"t1w_mask",
"t1w_tpms",
"anat_to_aslref_xfm",
"aslref2anat_xfm",
],
),
name="inputnode",
Expand Down Expand Up @@ -184,7 +184,7 @@ def init_asl_confounds_wf(

# Project T1w mask into BOLD space and merge with BOLD brainmask
t1w_mask_tfm = pe.Node(
ApplyTransforms(interpolation="MultiLabel"),
ApplyTransforms(interpolation="MultiLabel", invert_transform_flags=[True]),
name="t1w_mask_tfm",
)
union_mask = pe.Node(niu.Function(function=_binary_union), name="union_mask")
Expand All @@ -199,7 +199,7 @@ def init_asl_confounds_wf(
(inputnode, t1w_mask_tfm, [
("t1w_mask", "input_image"),
("asl_mask", "reference_image"),
("anat_to_aslref_xfm", "transforms"),
("aslref2anat_xfm", "transforms"),
]),
(inputnode, union_mask, [("asl_mask", "mask1")]),
(t1w_mask_tfm, union_mask, [("output_image", "mask2")]),
Expand All @@ -223,15 +223,15 @@ def init_asl_confounds_wf(

# Resample probseg maps in BOLD space via T1w-to-BOLD transform
acc_msk_tfm = pe.MapNode(
ApplyTransforms(interpolation="Gaussian"),
ApplyTransforms(interpolation="Gaussian", invert_transform_flags=[True]),
iterfield=["input_image"],
name="acc_msk_tfm",
mem_gb=0.1,
)
# fmt:off
workflow.connect([
(inputnode, acc_msk_tfm, [
("anat_to_aslref_xfm", "transforms"),
("aslref2anat_xfm", "transforms"),
("asl_mask", "reference_image"),
]),
(acc_masks, acc_msk_tfm, [("out_masks", "input_image")]),
Expand Down
6 changes: 3 additions & 3 deletions aslprep/workflows/asl/ge_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -184,7 +184,7 @@ def init_asl_reg_ge_wf(
name="inputnode",
)
outputnode = pe.Node(
niu.IdentityInterface(fields=["aslref_to_anat_xfm", "anat_to_aslref_xfm", "fallback"]),
niu.IdentityInterface(fields=["aslref2anat_xfm", "anat2aslref_xfm", "fallback"]),
name="outputnode",
)

Expand All @@ -203,8 +203,8 @@ def init_asl_reg_ge_wf(
("t1w_brain", "inputnode.t1w_brain"),
]),
(bbr_wf, outputnode, [
("outputnode.aslref_to_anat_xfm", "aslref_to_anat_xfm"),
("outputnode.anat_to_aslref_xfm", "anat_to_aslref_xfm"),
("outputnode.aslref2anat_xfm", "aslref2anat_xfm"),
("outputnode.anat2aslref_xfm", "anat2aslref_xfm"),
("outputnode.fallback", "fallback"),
]),
])
Expand Down
Loading

0 comments on commit e5cdab4

Please sign in to comment.