Skip to content

Commit

Permalink
Modify carpet plot.
Browse files Browse the repository at this point in the history
  • Loading branch information
tsalo committed Nov 14, 2023
1 parent 267bc22 commit ada684f
Show file tree
Hide file tree
Showing 3 changed files with 59 additions and 235 deletions.
213 changes: 49 additions & 164 deletions aslprep/interfaces/plotting.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
"""Plotting interfaces."""
import nibabel as nb
import numpy as np
import pandas as pd
from nipype.interfaces.base import (
BaseInterfaceInputSpec,
Expand All @@ -9,14 +11,14 @@
traits,
)
from nipype.utils.filemanip import fname_presuffix

from aslprep.utils.plotting import ASLPlot, CBFPlot, CBFtsPlot
from niworkflows.utils.timeseries import _cifti_timeseries, _nifti_timeseries
from niworkflows.viz.plots import fMRIPlot


class _ASLSummaryInputSpec(BaseInterfaceInputSpec):
in_func = File(exists=True, mandatory=True, desc="input ASL time-series (4D file)")
in_mask = File(exists=True, desc="3D brain mask")
in_segm = File(exists=True, desc="resampled segmentation")
in_nifti = File(exists=True, mandatory=True, desc="input BOLD (4D NIfTI file)")
in_cifti = File(exists=True, desc="input BOLD (CIFTI dense timeseries)")
in_segm = File(exists=True, desc="volumetric segmentation corresponding to in_nifti")
confounds_file = File(exists=True, desc="BIDS' _confounds.tsv file")

str_or_tuple = traits.Either(
Expand All @@ -25,34 +27,56 @@ class _ASLSummaryInputSpec(BaseInterfaceInputSpec):
traits.Tuple(traits.Str, traits.Either(None, traits.Str), traits.Either(None, traits.Str)),
)
confounds_list = traits.List(
str_or_tuple,
minlen=1,
desc="list of headers to extract from the confounds_file",
str_or_tuple, minlen=1, desc="list of headers to extract from the confounds_file"
)
tr = traits.Either(None, traits.Float, usedefault=True, desc="the repetition time")
drop_trs = traits.Int(0, usedefault=True, desc="dummy scans")


class _ASLSummaryOutputSpec(TraitedSpec):
out_file = File(exists=True, desc="written file path")


class ASLSummary(SimpleInterface):
"""Copy the x-form matrices from `hdr_file` to `out_file`.
Clearly that's wrong.
"""
"""Copy the x-form matrices from `hdr_file` to `out_file`."""

input_spec = _ASLSummaryInputSpec
output_spec = _ASLSummaryOutputSpec

def _run_interface(self, runtime):
self._results["out_file"] = fname_presuffix(
self.inputs.in_func,
suffix="_aslplot.svg",
use_ext=False,
newpath=runtime.cwd,
self.inputs.in_nifti, suffix="_fmriplot.svg", use_ext=False, newpath=runtime.cwd
)

has_cifti = isdefined(self.inputs.in_cifti)

# Read input object and create timeseries + segments object
seg_file = self.inputs.in_segm if isdefined(self.inputs.in_segm) else None
dataset, segments = _nifti_timeseries(
nb.load(self.inputs.in_nifti),
nb.load(seg_file),
remap_rois=False,
labels=(
("WM+CSF", "Edge")
if has_cifti
else ("Ctx GM", "dGM", "sWM+sCSF", "dWM+dCSF", "Cb", "Edge")
),
)

# Process CIFTI
if has_cifti:
cifti_data, cifti_segments = _cifti_timeseries(nb.load(self.inputs.in_cifti))

if seg_file is not None:
# Append WM+CSF and Edge masks
cifti_length = cifti_data.shape[0]
dataset = np.vstack((cifti_data, dataset))
segments = {k: np.array(v) + cifti_length for k, v in segments.items()}
cifti_segments.update(segments)
segments = cifti_segments
else:
dataset, segments = cifti_data, cifti_segments

dataframe = pd.read_csv(
self.inputs.confounds_file,
sep="\t",
Expand Down Expand Up @@ -83,157 +107,18 @@ def _run_interface(self, runtime):
else:
data = dataframe[headers]

colnames = data.columns.ravel().tolist()

for name, newname in list(names.items()):
colnames[colnames.index(name)] = newname
data = data.rename(columns=names)

data.columns = colnames

fig = ASLPlot(
self.inputs.in_func,
mask_file=self.inputs.in_mask if isdefined(self.inputs.in_mask) else None,
seg_file=(self.inputs.in_segm if isdefined(self.inputs.in_segm) else None),
fig = fMRIPlot(
dataset,
segments=segments,
tr=self.inputs.tr,
data=data,
confounds=data,
units=units,
nskip=self.inputs.drop_trs,
paired_carpet=has_cifti,
# The main change from fMRIPrep's usage is that detrend is False for ASL.
detrend=False,
).plot()
fig.savefig(self._results["out_file"], bbox_inches="tight")
return runtime


class _CBFSummaryInputSpec(BaseInterfaceInputSpec):
cbf = File(exists=True, mandatory=True, desc="")
label = traits.Str(exists=True, mandatory=True, desc="label")
vmax = traits.Int(exists=True, default_value=90, mandatory=True, desc="max value of asl")
ref_vol = File(exists=True, mandatory=True, desc="")


class _CBFSummaryOutputSpec(TraitedSpec):
out_file = File(exists=True, desc="written file path")


class CBFSummary(SimpleInterface):
"""Prepare an CBF summary plot for the report.
This plot restricts CBF values to -20 (if there are negative values) or 0 (if not) to 100.
"""

input_spec = _CBFSummaryInputSpec
output_spec = _CBFSummaryOutputSpec

def _run_interface(self, runtime):
self._results["out_file"] = fname_presuffix(
self.inputs.cbf,
suffix="_cbfplot.svg",
use_ext=False,
newpath=runtime.cwd,
)
CBFPlot(
cbf=self.inputs.cbf,
label=self.inputs.label,
ref_vol=self.inputs.ref_vol,
vmax=self.inputs.vmax,
outfile=self._results["out_file"],
).plot()
return runtime


class _CBFtsSummaryInputSpec(BaseInterfaceInputSpec):
cbf_ts = File(exists=True, mandatory=True, desc=" cbf time series")
confounds_file = File(exists=True, mandatory=False, desc="confound file ")
score_outlier_index = File(exists=True, mandatory=False, desc="scorexindex file ")
seg_file = File(exists=True, mandatory=True, desc="seg_file")
tr = traits.Float(desc="TR", mandatory=True)


class _CBFtsSummaryOutputSpec(TraitedSpec):
out_file = File(exists=True, desc="written file path")


class CBFtsSummary(SimpleInterface):
"""Prepare an CBF summary plot for the report."""

input_spec = _CBFtsSummaryInputSpec
output_spec = _CBFtsSummaryOutputSpec

def _run_interface(self, runtime):
self._results["out_file"] = fname_presuffix(
self.inputs.cbf_ts,
suffix="_cbfcarpetplot.svg",
use_ext=False,
newpath=runtime.cwd,
)
fig = CBFtsPlot(
cbf_file=self.inputs.cbf_ts,
seg_file=self.inputs.seg_file,
score_outlier_index=self.inputs.score_outlier_index,
tr=self.inputs.tr,
).plot()
fig.savefig(self._results["out_file"], bbox_inches="tight")
return runtime


class _CBFByTissueTypePlotInputSpec(BaseInterfaceInputSpec):
cbf = File(exists=True, mandatory=True, desc="")
seg_file = File(exists=True, mandatory=True, desc="Segmentation file")


class _CBFByTissueTypePlotOutputSpec(TraitedSpec):
out_file = File(exists=True, desc="written file path")


class CBFByTissueTypePlot(SimpleInterface):
"""Prepare an CBF summary plot for the report."""

input_spec = _CBFByTissueTypePlotInputSpec
output_spec = _CBFByTissueTypePlotOutputSpec

def _run_interface(self, runtime):
import matplotlib.pyplot as plt
import seaborn as sns
from nilearn import image, masking

self._results["out_file"] = fname_presuffix(
self.inputs.cbf,
suffix="_cbfplot.svg",
use_ext=False,
newpath=runtime.cwd,
)

dfs = []
for i_tissue_type, tissue_type in enumerate(["GM", "WM", "CSF"]):
tissue_type_val = i_tissue_type + 1
mask_img = image.math_img(
f"(img == {tissue_type_val}).astype(int)",
img=self.inputs.seg_file,
)
tissue_type_vals = masking.apply_mask(self.inputs.cbf, mask_img)
df = pd.DataFrame(
columns=["CBF\n(mL/100 g/min)", "Tissue Type"],
data=list(
map(list, zip(*[tissue_type_vals, [tissue_type] * tissue_type_vals.size]))
),
)
dfs.append(df)

df = pd.concat(dfs, axis=0)

# Create the plot
with sns.axes_style("whitegrid"), sns.plotting_context(font_scale=3):
fig, ax = plt.subplots(figsize=(16, 8))
sns.despine(ax=ax, bottom=True, left=True)
sns.boxenplot(
x="Tissue Type",
y="CBF\n(mL/100 g/min)",
data=df,
width=0.6,
showfliers=True,
palette={"GM": "#1b60a5", "WM": "#2da467", "CSF": "#9d8f25"},
ax=ax,
)
fig.tight_layout()
fig.savefig(self._results["out_file"])
plt.close()

return runtime
70 changes: 7 additions & 63 deletions aslprep/workflows/asl/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
"""Preprocessing workflows for ASL data."""
import nibabel as nb
import numpy as np
from fmriprep.workflows.bold import confounds
from fmriprep.workflows.bold.base import get_estimator
from fmriprep.workflows.bold.registration import init_bold_reg_wf, init_bold_t1_trans_wf
from fmriprep.workflows.bold.resampling import (
Expand All @@ -26,7 +25,7 @@
from aslprep.utils.bids import collect_run_data
from aslprep.utils.misc import _create_mem_gb, _get_wf_name
from aslprep.workflows.asl.cbf import init_compute_cbf_wf, init_parcellate_cbf_wf
from aslprep.workflows.asl.confounds import init_asl_confounds_wf
from aslprep.workflows.asl.confounds import init_asl_confounds_wf, init_carpetplot_wf
from aslprep.workflows.asl.hmc import init_asl_hmc_wf
from aslprep.workflows.asl.outputs import init_asl_derivatives_wf
from aslprep.workflows.asl.plotting import init_plot_cbf_wf
Expand All @@ -35,60 +34,6 @@
from aslprep.workflows.asl.util import init_asl_reference_wf, init_validate_asl_wf


class OverrideConfoundsDerivativesDataSink:
"""A context manager for temporarily overriding the definition of SomeClass.
Parameters
----------
None
Attributes
----------
original_class (type): The original class that is replaced during the override.
Methods
-------
__enter__()
Enters the context manager and performs the class override.
__exit__(exc_type, exc_value, traceback)
Exits the context manager and restores the original class definition.
"""

def __enter__(self):
"""Enter the context manager and perform the class override.
Returns
-------
OverrideConfoundsDerivativesDataSink
The instance of the context manager.
"""
# Save the original class
self.original_class = confounds.DerivativesDataSink
# Replace SomeClass with YourOwnClass
confounds.DerivativesDataSink = DerivativesDataSink
return self

def __exit__(self, exc_type, exc_value, traceback): # noqa: U100
"""Exit the context manager and restore the original class definition.
Parameters
----------
exc_type : type
The type of the exception (if an exception occurred).
exc_value : Exception
The exception instance (if an exception occurred).
traceback : traceback
The traceback information (if an exception occurred).
Returns
-------
None
"""
# Restore the original class
confounds.DerivativesDataSink = self.original_class


def init_asl_preproc_wf(asl_file, has_fieldmap=False):
"""Perform the functional preprocessing stages of ASLPrep.
Expand Down Expand Up @@ -1065,13 +1010,12 @@ def init_asl_preproc_wf(asl_file, has_fieldmap=False):
# Standard-space outputs requested.
# Since ASLPrep automatically includes MNI152NLin2009cAsym, this should always be reached.
if spaces.get_spaces(nonstandard=False, dim=(3,)):
with OverrideConfoundsDerivativesDataSink():
carpetplot_wf = confounds.init_carpetplot_wf(
mem_gb=mem_gb["resampled"],
metadata=metadata,
cifti_output=config.workflow.cifti_output,
name="carpetplot_wf",
)
carpetplot_wf = init_carpetplot_wf(
mem_gb=mem_gb["resampled"],
metadata=metadata,
cifti_output=config.workflow.cifti_output,
name="carpetplot_wf",
)

# Xform to "MNI152NLin2009cAsym" is always computed.
carpetplot_select_std = pe.Node(
Expand Down
11 changes: 3 additions & 8 deletions aslprep/workflows/asl/confounds.py
Original file line number Diff line number Diff line change
Expand Up @@ -302,10 +302,8 @@ def init_carpetplot_wf(
cifti_output: bool,
name: str = "bold_carpet_wf",
):
"""Build a workflow to generate *carpet* plots.
XXX: Copied from fMRIPrep. Needs to be replaced with some version that works for ASLPrep.
TODO: Find a solution that directly uses fMRIPrep's.
"""
Build a workflow to generate *carpet* plots.
Resamples the MNI parcellation (ad-hoc parcellation derived from the
Harvard-Oxford template and others).
Expand Down Expand Up @@ -389,10 +387,7 @@ def init_carpetplot_wf(
)
ds_report_bold_conf = pe.Node(
DerivativesDataSink(
desc="carpetplot",
datatype="figures",
extension="svg",
dismiss_entities=("echo",),
desc="carpetplot", datatype="figures", extension="svg", dismiss_entities=("echo",)
),
name="ds_report_bold_conf",
run_without_submitting=True,
Expand Down

0 comments on commit ada684f

Please sign in to comment.