Skip to content

Commit

Permalink
[ENH] Simplify imports for DWI pipelines (#1072)
Browse files Browse the repository at this point in the history
  • Loading branch information
NicolasGensollen authored Feb 7, 2024
1 parent 1ad79ac commit f8c0aba
Show file tree
Hide file tree
Showing 24 changed files with 101 additions and 149 deletions.
2 changes: 1 addition & 1 deletion clinica/pipelines/dwi_connectome/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
from . import dwi_connectome_cli
from . import cli
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ def cli(

from clinica.utils.ux import print_end_pipeline

from .dwi_connectome_pipeline import DwiConnectome
from .pipeline import DwiConnectome

parameters = {"n_tracks": n_tracks}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -192,7 +192,7 @@ def _build_output_node(self):
import nipype.interfaces.utility as nutil
import nipype.pipeline.engine as npe

import clinica.pipelines.dwi_connectome.dwi_connectome_utils as utils
from .utils import get_containers

join_node = npe.JoinNode(
name="JoinOutputs",
Expand All @@ -212,7 +212,7 @@ def _build_output_node(self):
),
)
write_node.inputs.base_directory = str(self.caps_directory)
write_node.inputs.container = utils.get_containers(self.subjects, self.sessions)
write_node.inputs.container = get_containers(self.subjects, self.sessions)
write_node.inputs.substitutions = [("trait_added", "")]
write_node.inputs.parameterization = False

Expand Down Expand Up @@ -271,12 +271,19 @@ def _build_core_nodes(self):
Tractography,
)

import clinica.pipelines.dwi_connectome.dwi_connectome_utils as utils
from clinica.utils.exceptions import ClinicaCAPSError
from clinica.utils.mri_registration import (
convert_flirt_transformation_to_mrtrix_transformation,
)

from .utils import (
get_caps_filenames,
get_conversion_luts,
get_luts,
print_begin_pipeline,
print_end_pipeline,
)

# Nodes
# =====
# B0 Extraction (only if space=b0)
Expand Down Expand Up @@ -328,8 +335,8 @@ def _build_core_nodes(self):
iterfield=["in_file", "in_config", "in_lut", "out_file"],
interface=mrtrix3.LabelConvert(),
)
label_convert_node.inputs.in_config = utils.get_conversion_luts()
label_convert_node.inputs.in_lut = utils.get_luts()
label_convert_node.inputs.in_config = get_conversion_luts()
label_convert_node.inputs.in_lut = get_luts()

# FSL flirt matrix to MRtrix matrix Conversion (only if space=b0)
# --------------------------------------------
Expand Down Expand Up @@ -391,7 +398,7 @@ def _build_core_nodes(self):
print_begin_message = npe.MapNode(
interface=niu.Function(
input_names=["in_bids_or_caps_file"],
function=utils.print_begin_pipeline,
function=print_begin_pipeline,
),
iterfield="in_bids_or_caps_file",
name="WriteBeginMessage",
Expand All @@ -402,7 +409,7 @@ def _build_core_nodes(self):
print_end_message = npe.MapNode(
interface=niu.Function(
input_names=["in_bids_or_caps_file", "final_file"],
function=utils.print_end_pipeline,
function=print_end_pipeline,
),
iterfield=["in_bids_or_caps_file"],
name="WriteEndMessage",
Expand All @@ -415,7 +422,7 @@ def _build_core_nodes(self):
interface=niu.Function(
input_names="dwi_file",
output_names=self.get_output_fields(),
function=utils.get_caps_filenames,
function=get_caps_filenames,
),
)
self.connect(
Expand Down
File renamed without changes.
2 changes: 1 addition & 1 deletion clinica/pipelines/dwi_dti/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
from . import dwi_dti_cli
from . import cli
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ def cli(

from clinica.utils.ux import print_end_pipeline

from .dwi_dti_pipeline import DwiDti
from .pipeline import DwiDti

pipeline = DwiDti(
caps_directory=caps_directory,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -129,7 +129,7 @@ def _build_output_node(self):

from clinica.utils.nipype import container_from_filename, fix_join

from .dwi_dti_utils import rename_into_caps
from .utils import rename_into_caps

# Find container path from filename
container_path = npe.Node(
Expand Down Expand Up @@ -241,7 +241,6 @@ def _build_output_node(self):

def _build_core_nodes(self):
"""Build and connect the core nodes of the pipeline."""
import os
from pathlib import Path

import nipype.interfaces.fsl as fsl
Expand All @@ -255,7 +254,7 @@ def _build_core_nodes(self):
from clinica.utils.check_dependency import check_environment_variable
from clinica.utils.dwi import extract_bids_identifier_from_filename

from .dwi_dti_utils import (
from .utils import (
get_ants_transforms,
get_caps_filenames,
print_begin_pipeline,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ def statistics_on_atlases(in_registered_map, name_map, prefix_file=None):
Returns:
List of paths leading to the statistics TSV files.
"""
import os
from pathlib import Path

from nipype.utils.filemanip import split_filename

Expand Down Expand Up @@ -41,8 +41,7 @@ def statistics_on_atlases(in_registered_map, name_map, prefix_file=None):
f"_res-{atlas.get_spatial_resolution()}_map-{name_map}_statistics.tsv"
)

out_atlas_statistics = os.path.abspath(os.path.join(os.getcwd(), filename))

out_atlas_statistics = str((Path.cwd() / filename).resolve())
statistics_on_atlas(in_registered_map, atlas, out_atlas_statistics)
atlas_statistics_list.append(out_atlas_statistics)

Expand Down
2 changes: 1 addition & 1 deletion clinica/pipelines/dwi_preprocessing_using_fmap/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
from . import dwi_preprocessing_using_phasediff_fmap_cli
from . import cli
Original file line number Diff line number Diff line change
Expand Up @@ -41,9 +41,7 @@ def cli(

from clinica.utils.ux import print_end_pipeline

from .dwi_preprocessing_using_phasediff_fmap_pipeline import (
DwiPreprocessingUsingPhaseDiffFMap,
)
from .pipeline import DwiPreprocessingUsingPhaseDiffFMap

parameters = {
"low_bval": low_bval,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -175,7 +175,7 @@ def _build_output_node(self):

from clinica.utils.nipype import container_from_filename, fix_join

from .dwi_preprocessing_using_phasediff_fmap_utils import rename_into_caps
from .utils import rename_into_caps

container_path = npe.Node(
nutil.Function(
Expand Down Expand Up @@ -261,19 +261,13 @@ def _build_core_nodes(self):
import nipype.interfaces.utility as nutil
import nipype.pipeline.engine as npe

from clinica.pipelines.dwi_preprocessing_using_t1.dwi_preprocessing_using_t1_workflows import (
from clinica.pipelines.dwi_preprocessing_using_t1.workflows import (
eddy_fsl_pipeline,
)
from clinica.utils.dwi import compute_average_b0

from .dwi_preprocessing_using_phasediff_fmap_utils import (
init_input_node,
print_end_pipeline,
)
from .dwi_preprocessing_using_phasediff_fmap_workflows import (
calibrate_and_register_fmap,
compute_reference_b0,
)
from .utils import init_input_node, print_end_pipeline
from .workflows import calibrate_and_register_fmap, compute_reference_b0

init_node = npe.Node(
interface=nutil.Function(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -170,41 +170,39 @@ def rads2hz(in_file: str, delta_te: float, out_file: str = None) -> str:
Path to output file.
"""
import math
import os

import nibabel as nb
import numpy as np

if out_file is None:
fname, fext = os.path.splitext(os.path.basename(in_file))
if fext == ".gz":
fname, _ = os.path.splitext(fname)
out_file = os.path.abspath(f"./{fname}_radsec.nii.gz")

im = nb.load(in_file)
out_file = out_file or _get_output_file(in_file, "radsec")
data = im.get_fdata().astype(np.float32) * (1.0 / (float(delta_te) * 2 * math.pi))
nb.Nifti1Image(data, im.affine, im.header).to_filename(out_file)

return out_file


def _get_output_file(input_file: str, suffix: str) -> str:
from pathlib import Path

input_file = Path(input_file)
filename = input_file.name
if input_file.suffix == ".gz":
filename = Path(filename).name

return str(Path(f"./{filename}_{suffix}.nii.gz").resolve())


def demean_image(in_file: str, in_mask: str = None, out_file: str = None) -> str:
"""Demean image data inside mask.
This function was taken from: https://github.com/niflows/nipype1-workflows/
"""
import os.path as op

import nibabel as nb
import numpy as np

if out_file is None:
fname, fext = op.splitext(op.basename(in_file))
if fext == ".gz":
fname, _ = op.splitext(fname)
out_file = op.abspath("./%s_demean.nii.gz" % fname)

im = nb.load(in_file)
out_file = out_file or _get_output_file(in_file, "demean")
data = im.get_fdata().astype(np.float32)
mask = np.ones_like(data)

Expand All @@ -216,6 +214,7 @@ def demean_image(in_file: str, in_mask: str = None, out_file: str = None) -> str
mean = np.median(data[mask == 1].reshape(-1))
data[mask == 1] = data[mask == 1] - mean
nb.Nifti1Image(data, im.affine, im.header).to_filename(out_file)

return out_file


Expand All @@ -225,17 +224,11 @@ def siemens2rads(in_file: str, out_file: str = None):
This function was taken from: https://github.com/niflows/nipype1-workflows/
"""
import math
import os.path as op

import nibabel as nb
import numpy as np

if out_file is None:
fname, fext = op.splitext(op.basename(in_file))
if fext == ".gz":
fname, _ = op.splitext(fname)
out_file = op.abspath("./%s_rads.nii.gz" % fname)

out_file = out_file or _get_output_file(in_file, "rads")
in_file = np.atleast_1d(in_file).tolist()
im = nb.load(in_file[0])
data = im.get_fdata().astype(np.float32)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -110,15 +110,11 @@ def prepare_phasediff_fmap(
import nipype.interfaces.utility as nutil
import nipype.pipeline.engine as npe

from clinica.pipelines.dwi_preprocessing_using_fmap.dwi_preprocessing_using_phasediff_fmap_workflows import (
cleanup_edge_pipeline,
from clinica.pipelines.dwi_preprocessing_using_fmap.workflows import (
cleanup_edge_pipeline, # noqa
)

from .dwi_preprocessing_using_phasediff_fmap_utils import (
demean_image,
rads2hz,
siemens2rads,
)
from .utils import demean_image, rads2hz, siemens2rads

input_node = npe.Node(
nutil.IdentityInterface(
Expand Down Expand Up @@ -271,12 +267,10 @@ def compute_reference_b0(
import nipype.interfaces.utility as niu
import nipype.pipeline.engine as npe

from clinica.pipelines.dwi_preprocessing_using_t1.dwi_preprocessing_using_t1_workflows import (
eddy_fsl_pipeline,
)
from clinica.pipelines.dwi_preprocessing_using_t1.workflows import eddy_fsl_pipeline
from clinica.utils.dwi import compute_average_b0

from .dwi_preprocessing_using_phasediff_fmap_utils import get_grad_fsl
from .utils import get_grad_fsl

inputnode = npe.Node(
niu.IdentityInterface(
Expand Down
2 changes: 1 addition & 1 deletion clinica/pipelines/dwi_preprocessing_using_t1/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
from . import dwi_preprocessing_using_t1_cli
from . import cli
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ def cli(

from clinica.utils.ux import print_end_pipeline

from .dwi_preprocessing_using_t1_pipeline import DwiPreprocessingUsingT1
from .pipeline import DwiPreprocessingUsingT1

parameters = {
"low_bval": low_bval,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -148,7 +148,7 @@ def _build_output_node(self):

from clinica.utils.nipype import container_from_filename, fix_join

from .dwi_preprocessing_using_t1_utils import rename_into_caps
from .utils import rename_into_caps

container_path = npe.Node(
nutil.Function(
Expand Down Expand Up @@ -210,15 +210,12 @@ def _build_core_nodes(self):

from clinica.utils.dwi import compute_average_b0_task

from .dwi_preprocessing_using_t1_utils import (
from .utils import (
init_input_node,
prepare_reference_b0_task,
print_end_pipeline,
)
from .dwi_preprocessing_using_t1_workflows import (
eddy_fsl_pipeline,
epi_pipeline,
)
from .workflows import eddy_fsl_pipeline, epi_pipeline

# Nodes creation
# ==============
Expand Down
Loading

0 comments on commit f8c0aba

Please sign in to comment.