Skip to content

Commit

Permalink
[FIX] Update regexp dwi connectome utility function (#1070)
Browse files Browse the repository at this point in the history
* update regexp for dwi preproc reading in DWIConnectome utility functions

* add unit tests

* refactor a bit

* add unit tests

* add some more tests
  • Loading branch information
NicolasGensollen authored Feb 6, 2024
1 parent a62db61 commit bf13349
Show file tree
Hide file tree
Showing 2 changed files with 217 additions and 77 deletions.
140 changes: 63 additions & 77 deletions clinica/pipelines/dwi_connectome/dwi_connectome_utils.py
Original file line number Diff line number Diff line change
@@ -1,115 +1,101 @@
def get_luts():
import os
from pathlib import Path

from clinica.utils.exceptions import ClinicaException

try:
# For aparc+aseg.mgz file:
default = os.path.join(os.environ["FREESURFER_HOME"], "FreeSurferColorLUT.txt")
# For aparc.a2009s+aseg.mgz file:
a2009s = os.path.join(os.environ["FREESURFER_HOME"], "FreeSurferColorLUT.txt")
def get_luts() -> list:
from pathlib import Path

# TODO: Add custom Lausanne2008 LUTs here.
except KeyError:
raise ClinicaException("Could not find FREESURFER_HOME environment variable.")
return [default, a2009s]
from clinica.utils.check_dependency import check_environment_variable

freesurfer_home = Path(check_environment_variable("FREESURFER_HOME", "Freesurfer"))

def get_conversion_luts_offline():
# TODO: use this function if no internet connect found in client (need to upload files to clinica repository)
return
return [
str(freesurfer_home / "FreeSurferColorLUT.txt"),
str(freesurfer_home / "FreeSurferColorLUT.txt"),
]


def get_conversion_luts():
from os import pardir
from os.path import abspath, dirname, join
def get_conversion_luts() -> list:
from pathlib import Path

from clinica.utils.inputs import RemoteFileStructure, fetch_file
from clinica.utils.stream import cprint

root = dirname(abspath(join(abspath(__file__), pardir, pardir)))

path_to_mappings = Path(root) / "resources" / "mappings"

url_mrtrix = "https://raw.githubusercontent.com/MRtrix3/mrtrix3/master/share/mrtrix3/labelconvert/"

fs_default = RemoteFileStructure(
filename="fs_default.txt",
url=url_mrtrix,
checksum="6ee07088915fdbcf52b05147ddae86e5fcaf3efc63db5b0ba8f361637dfa11ef",
path_to_mappings = (
Path(__file__).resolve().parent.parent.parent / "resources" / "mappings"
)
resulting_paths = []
for filename in ("fs_default.txt", "fs_a2009s.txt"):
file_path = path_to_mappings / filename
if not file_path.is_file():
file_path = _download_mrtrix3_file(filename, path_to_mappings)
resulting_paths.append(str(file_path))
return resulting_paths

fs_a2009s = RemoteFileStructure(
filename="fs_a2009s.txt",
url=url_mrtrix,
checksum="b472f09cfe92ac0b6694fb6b00a87baf15dd269566e4a92b8a151ff1080bf170",
)

ref_fs_default = path_to_mappings / Path(fs_default.filename)
ref_fs_a2009 = path_to_mappings / Path(fs_a2009s.filename)
def _download_mrtrix3_file(filename: str, path_to_mappings: Path) -> str:
from clinica.utils.inputs import RemoteFileStructure, fetch_file
from clinica.utils.stream import cprint

if not (ref_fs_default.is_file()):
try:
ref_fs_default = fetch_file(fs_default, path_to_mappings)
except IOError as err:
cprint(
msg=f"Unable to download required MRTRIX mapping (fs_default.txt) for processing: {err}",
lvl="error",
)
if not (ref_fs_a2009.is_file()):
try:
ref_fs_a2009 = fetch_file(fs_a2009s, path_to_mappings)
except IOError as err:
cprint(
msg=f"Unable to download required MRTRIX mapping (fs_a2009s.txt) for processing: {err}",
lvl="error",
)
try:
return fetch_file(
RemoteFileStructure(
filename=filename,
url="https://raw.githubusercontent.com/MRtrix3/mrtrix3/master/share/mrtrix3/labelconvert/",
checksum=_get_checksum_for_filename(filename),
),
str(path_to_mappings),
)
except IOError as err:
error_msg = f"Unable to download required MRTRIX mapping ({filename}) for processing: {err}"
cprint(msg=error_msg, lvl="error")
raise IOError(error_msg)

return [ref_fs_default, ref_fs_a2009]

def _get_checksum_for_filename(filename: str) -> str:
if filename == "fs_default.txt":
return "a8d561694887a1ca8d9df223aa5ef861b6c79d43ce9ed93835b9ce8aadc331b1"
if filename == "fs_a2009s.txt":
return "40b0d4d77bde7e1d265439347af5b30cc973748c1a88d203d7044cb35b3863e1"
raise ValueError(f"File name {filename} is not supported.")

def get_containers(subjects, sessions):
import os

def get_containers(subjects: list, sessions: list) -> list:
from pathlib import Path

return [
os.path.join("subjects", subjects[i], sessions[i], "dwi")
for i in range(len(subjects))
str(Path("subjects") / subject / session / "dwi")
for subject, session in zip(subjects, sessions)
]


def get_caps_filenames(dwi_file: str):
def get_caps_filenames(dwi_file: str) -> tuple:
import re

m = re.search(r"/(sub-[a-zA-Z0-9]+_ses-[a-zA-Z0-9]+.*)_preproc", dwi_file)
if not m:
raise ValueError(
f"Input filename {dwi_file} is not in a CAPS compliant format."
error_msg = f"Input filename {dwi_file} is not in a CAPS compliant format."
if (
m := re.search(
r"/(sub-[a-zA-Z0-9]+_ses-[a-zA-Z0-9]+.*_desc-preproc*)_dwi", dwi_file
)
) is None:
raise ValueError(error_msg)
source_file_caps = m.group(1)

m = re.search(
r"/(sub-[a-zA-Z0-9]+_ses-[a-zA-Z0-9]+.*)_space-[a-zA-Z0-9]+_preproc", dwi_file
)
if not m:
raise ValueError(
f"Input filename {dwi_file} is not in a CAPS compliant format."
if (
m := re.search(
r"/(sub-[a-zA-Z0-9]+_ses-[a-zA-Z0-9]+.*)_space-[a-zA-Z0-9]+_desc-preproc_dwi",
dwi_file,
)
) is None:
raise ValueError(error_msg)
source_file_bids = m.group(1)

response = f"{source_file_caps}_model-CSD_responseFunction.txt"
fod = f"{source_file_caps}_model-CSD_diffmodel.nii.gz"
tracts = f"{source_file_caps}_model-CSD_tractography.tck"
nodes = [
f"{source_file_caps}_atlas-desikan_parcellation.nii.gz",
f"{source_file_caps}_atlas-destrieux_parcellation.nii.gz",
f"{source_file_caps}_atlas-{atlas}_parcellation.nii.gz"
for atlas in ("desikan", "destrieux")
]
# TODO: Add custom Lausanne2008 node files here.
connectomes = [
f"{source_file_bids}_model-CSD_atlas-desikan_connectivity.tsv",
f"{source_file_bids}_model-CSD_atlas-destrieux_connectivity.tsv",
f"{source_file_bids}_model-CSD_atlas-{atlas}_connectivity.tsv"
for atlas in ("desikan", "destrieux")
]
# TODO: Add custom Lausanne2008 connectome files here.

return response, fod, tracts, nodes, connectomes

Expand Down
154 changes: 154 additions & 0 deletions test/unittests/pipelines/dwi_connectome/test_dwi_connectome_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,154 @@
import pytest


def test_get_luts(mocker):
from clinica.pipelines.dwi_connectome.dwi_connectome_utils import get_luts

mocked_freesurfer_home = "/Applications/freesurfer/7.2.0"
mocker.patch(
"clinica.utils.check_dependency.check_environment_variable",
return_value=mocked_freesurfer_home,
)
assert get_luts() == [f"{mocked_freesurfer_home}/FreeSurferColorLUT.txt"] * 2


@pytest.mark.parametrize(
"filename,expected_checksum",
[
(
"fs_default.txt",
"a8d561694887a1ca8d9df223aa5ef861b6c79d43ce9ed93835b9ce8aadc331b1",
),
(
"fs_a2009s.txt",
"40b0d4d77bde7e1d265439347af5b30cc973748c1a88d203d7044cb35b3863e1",
),
],
)
def test_get_checksum_for_filename(filename, expected_checksum):
from clinica.pipelines.dwi_connectome.dwi_connectome_utils import (
_get_checksum_for_filename,
)

assert _get_checksum_for_filename(filename) == expected_checksum


def test_get_checksum_for_filename_error():
from clinica.pipelines.dwi_connectome.dwi_connectome_utils import (
_get_checksum_for_filename,
)

with pytest.raises(ValueError, match="File name foo.txt is not supported."):
_get_checksum_for_filename("foo.txt")


@pytest.mark.parametrize(
"filename,expected_length", [("fs_default.txt", 112), ("fs_a2009s.txt", 192)]
)
def test_download_mrtrix3_file(tmp_path, filename, expected_length):
"""Atm this test needs an internet connection to download the files.
TODO: Use mocking in the fetch_file function to remove this necessity.
"""
from clinica.pipelines.dwi_connectome.dwi_connectome_utils import (
_download_mrtrix3_file,
)

_download_mrtrix3_file(filename, tmp_path)

assert [f.name for f in tmp_path.iterdir()] == [filename]
assert len((tmp_path / filename).read_text().split("\n")) == expected_length


def test_download_mrtrix3_file_error(tmp_path, mocker):
import re

from clinica.pipelines.dwi_connectome.dwi_connectome_utils import (
_download_mrtrix3_file,
)

mocker.patch(
"clinica.pipelines.dwi_connectome.dwi_connectome_utils._get_checksum_for_filename",
return_value="foo",
)
mocker.patch("clinica.utils.inputs.fetch_file", side_effect=IOError)

with pytest.raises(
IOError,
match=re.escape(
"Unable to download required MRTRIX mapping (foo.txt) for processing"
),
):
_download_mrtrix3_file("foo.txt", tmp_path)


def test_get_conversion_luts():
from pathlib import Path

from clinica.pipelines.dwi_connectome.dwi_connectome_utils import (
get_conversion_luts,
)

luts = [Path(_) for _ in get_conversion_luts()]

assert [p.name for p in luts] == ["fs_default.txt", "fs_a2009s.txt"]
assert all([p.is_file() for p in luts])


@pytest.mark.parametrize(
"filename",
[
"foo.txt",
"dwi.nii.gz",
"sub-01_ses-M000_dwi.nii.gz",
"sub-01_ses-M000_preproc.nii.gz",
"sub-01_ses-M000_space-T1w_preproc.nii.gz",
"sub-01_ses-M000_space-b0_preproc.nii.gz",
],
)
def test_get_caps_filenames_error(tmp_path, filename):
from clinica.pipelines.dwi_connectome.dwi_connectome_utils import get_caps_filenames

with pytest.raises(ValueError, match="is not in a CAPS compliant format."):
get_caps_filenames(str(tmp_path / filename))


def test_get_caps_filenames(tmp_path):
from clinica.pipelines.dwi_connectome.dwi_connectome_utils import get_caps_filenames

dwi_caps = tmp_path / "dwi" / "preprocessing"
dwi_caps.mkdir(parents=True)

assert get_caps_filenames(
str(dwi_caps / "sub-01_ses-M000_space-b0_desc-preproc_dwi.nii.gz")
) == (
"sub-01_ses-M000_space-b0_desc-preproc_model-CSD_responseFunction.txt",
"sub-01_ses-M000_space-b0_desc-preproc_model-CSD_diffmodel.nii.gz",
"sub-01_ses-M000_space-b0_desc-preproc_model-CSD_tractography.tck",
[
"sub-01_ses-M000_space-b0_desc-preproc_atlas-desikan_parcellation.nii.gz",
"sub-01_ses-M000_space-b0_desc-preproc_atlas-destrieux_parcellation.nii.gz",
],
[
"sub-01_ses-M000_model-CSD_atlas-desikan_connectivity.tsv",
"sub-01_ses-M000_model-CSD_atlas-destrieux_connectivity.tsv",
],
)


@pytest.mark.parametrize(
"subjects,sessions,expected",
[
([], [], []),
(["foo"], ["bar"], ["subjects/foo/bar/dwi"]),
(
["sub-01", "sub-02"],
["ses-M000", "ses-M006"],
["subjects/sub-01/ses-M000/dwi", "subjects/sub-02/ses-M006/dwi"],
),
],
)
def test_get_containers(subjects, sessions, expected):
from clinica.pipelines.dwi_connectome.dwi_connectome_utils import get_containers

assert get_containers(subjects, sessions) == expected

0 comments on commit bf13349

Please sign in to comment.