diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml
index 3403384..fc7ab58 100644
--- a/.github/workflows/test.yml
+++ b/.github/workflows/test.yml
@@ -125,6 +125,7 @@ jobs:
     runs-on: ubuntu-latest
     needs: [build, download-test-data, check_skip_flags]
     strategy:
+      fail-fast: false
       matrix:
         python-version: ['3.8', '3.9', '3.10', '3.11', '3.12']
     steps:
diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
index 020351b..d42e759 100644
--- a/.pre-commit-config.yaml
+++ b/.pre-commit-config.yaml
@@ -8,6 +8,12 @@ repos:
     - id: end-of-file-fixer
     - id: check-yaml
     - id: check-added-large-files
+- repo: https://github.com/pre-commit/mirrors-mypy
+  rev: v1.7.1
+  hooks:
+    - id: mypy
+      additional_dependencies: [types-all, pandas-stubs, types-tqdm]
+      args: [--config-file=pyproject.toml]
 - repo: https://github.com/psf/black-pre-commit-mirror
   rev: 24.3.0
   hooks:
diff --git a/docs/source/conf.py b/docs/source/conf.py
index 4d28c2d..32f11ac 100644
--- a/docs/source/conf.py
+++ b/docs/source/conf.py
@@ -5,7 +5,7 @@
 import os
 import sys
 
-from giga_connectome import __version__, __packagename__, __copyright__
+from giga_connectome import __copyright__, __packagename__, __version__
 
 sys.path.insert(0, os.path.abspath(".."))
 # -- Project information -----------------------------------------------------
diff --git a/giga_connectome/__init__.py b/giga_connectome/__init__.py
index fd61600..549aee5 100644
--- a/giga_connectome/__init__.py
+++ b/giga_connectome/__init__.py
@@ -6,10 +6,10 @@
 except ImportError:
     pass
 
-from .mask import generate_gm_mask_atlas
 from .atlas import load_atlas_setting
-from .postprocess import run_postprocessing_dataset
 from .denoise import get_denoise_strategy
+from .mask import generate_gm_mask_atlas
+from .postprocess import run_postprocessing_dataset
 
 __all__ = [
     "__copyright__",
diff --git a/giga_connectome/atlas.py b/giga_connectome/atlas.py
index 063e356..2d4fe33 100644
--- a/giga_connectome/atlas.py
+++ b/giga_connectome/atlas.py
@@ -1,11 +1,13 @@
-import os
-import json
-from typing import Union, List
+from __future__ import annotations
 
+import json
+import os
 from pathlib import Path
+from typing import Any, Dict, List, TypedDict
+
 import nibabel as nib
-from nilearn.image import resample_to_img
 from nibabel import Nifti1Image
+from nilearn.image import resample_to_img
 from pkg_resources import resource_filename
 
 from giga_connectome.logger import gc_logger
@@ -16,8 +18,25 @@
 
 PRESET_ATLAS = ["DiFuMo", "MIST", "Schaefer20187Networks"]
 
-
-def load_atlas_setting(atlas: Union[str, Path, dict]):
+ATLAS_CONFIG_TYPE = TypedDict(
+    "ATLAS_CONFIG_TYPE",
+    {
+        "name": str,
+        "parameters": Dict[str, str],
+        "desc": List[str],
+        "templateflow_dir": Any,
+    },
+)
+
+ATLAS_SETTING_TYPE = TypedDict(
+    "ATLAS_SETTING_TYPE",
+    {"name": str, "file_paths": Dict[str, List[Path]], "type": str},
+)
+
+
+def load_atlas_setting(
+    atlas: str | Path | dict[str, Any],
+) -> ATLAS_SETTING_TYPE:
     """Load atlas details for templateflow api to fetch.
 
     The setting file can be configured for atlases not included in the
     templateflow collections, but user has to organise their files to
@@ -59,19 +78,16 @@ def load_atlas_setting(atlas: Union[str, Path, dict]):
     import templateflow
 
-    if isinstance(atlas_config["desc"], str):
-        desc = [atlas_config["desc"]]
-    else:
-        desc = atlas_config["desc"]
-
     parcellation = {}
-    for d in desc:
+    for d in atlas_config["desc"]:
         p = templateflow.api.get(
             **atlas_config["parameters"],
             raise_empty=True,
             desc=d,
             extension="nii.gz",
         )
+        if isinstance(p, Path):
+            p = [p]
         parcellation[d] = p
     return {
         "name": atlas_config["name"],
@@ -82,10 +98,10 @@
 
 def resample_atlas_collection(
     template: str,
-    atlas_config: dict,
+    atlas_config: ATLAS_SETTING_TYPE,
     group_mask_dir: Path,
     group_mask: Nifti1Image,
-) -> List[Path]:
+) -> list[Path]:
     """Resample a atlas collection to group grey matter mask.
 
     Parameters
     ----------
@@ -105,7 +121,7 @@
 
     Returns
     -------
-    List of pathlib.Path
+    list of pathlib.Path
         Paths to atlases sampled to group level grey matter mask.
     """
     gc_log.info("Resample atlas to group grey matter mask.")
@@ -137,12 +153,14 @@
     return resampled_atlases
 
 
-def _check_altas_config(atlas: Union[str, Path, dict]) -> dict:
+def _check_altas_config(
+    atlas: str | Path | dict[str, Any]
+) -> ATLAS_CONFIG_TYPE:
     """Load the configuration file.
 
     Parameters
     ----------
-    atlas : Union[str, Path, dict]
+    atlas : str | Path | dict
         Atlas name or configuration file path.
 
     Returns
     -------
@@ -158,23 +176,35 @@
     # load the file first if the input is not already a dictionary
     if isinstance(atlas, (str, Path)):
         if atlas in PRESET_ATLAS:
-            config_path = resource_filename(
-                "giga_connectome", f"data/atlas/{atlas}.json"
+            config_path = Path(
+                resource_filename(
+                    "giga_connectome", f"data/atlas/{atlas}.json"
+                )
             )
         elif Path(atlas).exists():
             config_path = Path(atlas)
 
         with open(config_path, "r") as file:
-            atlas = json.load(file)
+            atlas_config = json.load(file)
+    else:
+        atlas_config = atlas
 
-    keys = list(atlas.keys())
     minimal_keys = ["name", "parameters", "desc", "templateflow_dir"]
+    keys = list(atlas_config.keys())
     common_keys = set(minimal_keys).intersection(set(keys))
-    if isinstance(atlas, dict) and common_keys != set(minimal_keys):
+    if common_keys != set(minimal_keys):
         raise KeyError(
             "Invalid dictionary input. Input should"
             " contain minimally the following keys: 'name', "
             "'parameters', 'desc', 'templateflow_dir'. Found "
             f"{keys}"
         )
-    return atlas
+
+    # cast to list of string
+    if isinstance(atlas_config["desc"], (str, int)):
+        desc = [atlas_config["desc"]]
+    else:
+        desc = atlas_config["desc"]
+    atlas_config["desc"] = [str(x) for x in desc]
+
+    return atlas_config
diff --git a/giga_connectome/connectome.py b/giga_connectome/connectome.py
index 5d9628f..ee5d9a1 100644
--- a/giga_connectome/connectome.py
+++ b/giga_connectome/connectome.py
@@ -1,13 +1,18 @@
+from __future__ import annotations
+
 from pathlib import Path
-from typing import Union, Tuple
+from typing import Any
+
 import numpy as np
-from nilearn.maskers import NiftiMasker
-from nilearn.image import load_img
 from nibabel import Nifti1Image
 from nilearn.connectome import ConnectivityMeasure
+from nilearn.image import load_img
+from nilearn.maskers import NiftiMasker
 
 
-def build_size_roi(mask: np.ndarray, labels_roi: np.ndarray) -> np.ndarray:
+def build_size_roi(
+    mask: np.ndarray[Any, Any], labels_roi: np.ndarray[Any, Any]
+) -> np.ndarray[Any, np.dtype[Any]]:
     """Extract labels and sizes of ROIs given an atlas.
 
     The atlas parcels must be discrete segmentations.
@@ -41,12 +46,12 @@
 
 
 def calculate_intranetwork_correlation(
-    correlation_matrix: np.array,
-    masker_labels: np.array,
-    time_series_atlas: np.array,
-    group_mask: Union[str, Path, Nifti1Image],
-    atlas_image: Union[str, Path, Nifti1Image],
-) -> Tuple[np.ndarray, np.ndarray]:
+    correlation_matrix: np.ndarray[Any, Any],
+    masker_labels: np.ndarray[Any, Any],
+    time_series_atlas: np.ndarray[Any, Any],
+    group_mask: str | Path | Nifti1Image,
+    atlas_image: str | Path | Nifti1Image,
+) -> tuple[np.ndarray[Any, Any], np.ndarray[Any, Any]]:
     """Calculate the average functional correlation within each parcel.
 
     Currently we only support discrete segmentations.
@@ -61,15 +66,15 @@
     time_series_atlas : np.array
         Time series extracted from each parcel.
 
-    group_mask : Union[str, Path, Nifti1Image]
+    group_mask : str | Path | Nifti1Image
         The group grey matter mask.
 
-    atlas_image : Union[str, Path, Nifti1Image]
+    atlas_image : str | Path | Nifti1Image
         3D atlas image.
 
     Returns
     -------
-    Tuple[np.ndarray, np.ndarray]
+    tuple[np.ndarray, np.ndarray]
         A tuple containing the modified Pearson's correlation matrix with
         the diagonal replaced by the average correlation within each parcel,
         and an array of the computed average intranetwork correlations for
@@ -106,10 +111,10 @@
 def generate_timeseries_connectomes(
     masker: NiftiMasker,
     denoised_img: Nifti1Image,
-    group_mask: Union[str, Path],
+    group_mask: str | Path,
     correlation_measure: ConnectivityMeasure,
     calculate_average_correlation: bool,
-) -> Tuple[np.ndarray, np.ndarray]:
+) -> tuple[np.ndarray[Any, Any], np.ndarray[Any, Any]]:
     """Generate timeseries-based connectomes from functional data.
 
     Parameters
@@ -120,7 +125,7 @@
     denoised_img : Nifti1Image
         Denoised functional image.
 
-    group_mask : Union[str, Path]
+    group_mask : str | Path
         Path to the group grey matter mask.
 
     correlation_measure : ConnectivityMeasure
@@ -131,7 +136,7 @@
 
     Returns
     -------
-    Tuple[np.ndarray, np.ndarray]
+    tuple[np.ndarray, np.ndarray]
        A tuple containing the correlation matrix and time series atlas.
     """
     time_series_atlas = masker.fit_transform(denoised_img)
""" time_series_atlas = masker.fit_transform(denoised_img) diff --git a/giga_connectome/denoise.py b/giga_connectome/denoise.py index 85e7080..e2acb53 100644 --- a/giga_connectome/denoise.py +++ b/giga_connectome/denoise.py @@ -1,17 +1,16 @@ -from typing import Union, Optional +from __future__ import annotations import json from pathlib import Path -import pandas as pd +from typing import Any, Callable, Dict, List, TypedDict, Union + import numpy as np +import pandas as pd from nibabel import Nifti1Image - from nilearn.interfaces import fmriprep from nilearn.maskers import NiftiMasker - from pkg_resources import resource_filename - PRESET_STRATEGIES = [ "simple", "simple+gsr", @@ -23,10 +22,30 @@ "icaaroma", ] +# More refined type not possible with python <= 3.9? +# STRATEGY_TYPE = TypedDict( +# "STRATEGY_TYPE", +# { +# "name": str, +# "function": Callable[ +# ..., tuple[pd.DataFrame, Union[np.ndarray[Any, Any], None]] +# ], +# "parameters": dict[str, str | list[str]], +# }, +# ) +STRATEGY_TYPE = TypedDict( + "STRATEGY_TYPE", + { + "name": str, + "function": Callable[..., Any], + "parameters": Dict[str, Union[str, List[str]]], + }, +) + def get_denoise_strategy( strategy: str, -) -> dict: +) -> STRATEGY_TYPE: """ Select denoise strategies and associated parameters. The strategy parameters are designed to pass to load_confounds_strategy. @@ -47,7 +66,7 @@ def get_denoise_strategy( Denosing strategy parameter to pass to load_confounds_strategy. """ if strategy in PRESET_STRATEGIES: - config_path = resource_filename( + config_path: str | Path = resource_filename( "giga_connectome", f"data/denoise_strategy/{strategy}.json" ) elif Path(strategy).exists(): @@ -63,7 +82,7 @@ def get_denoise_strategy( return benchmark_strategy -def is_ica_aroma(strategy: str) -> bool: +def is_ica_aroma(strategy: STRATEGY_TYPE) -> bool: """Check if the current strategy is ICA AROMA. Parameters @@ -79,19 +98,17 @@ def is_ica_aroma(strategy: str) -> bool: strategy_preset = strategy["parameters"].get("denoise_strategy", False) strategy_user_define = strategy["parameters"].get("strategy", False) if strategy_preset or strategy_user_define: - return ( - strategy_preset == "ica_aroma" - if strategy_preset - else "ica_aroma" in strategy_user_define - ) + return strategy_preset == "ica_aroma" + elif isinstance(strategy_user_define, list): + return "ica_aroma" in strategy_user_define else: raise ValueError(f"Invalid input dictionary. {strategy['parameters']}") def denoise_nifti_voxel( - strategy: dict, - group_mask: Union[str, Path], - standardize: Union[str, bool], + strategy: STRATEGY_TYPE, + group_mask: str | Path, + standardize: str | bool, smoothing_fwhm: float, img: str, ) -> Nifti1Image: @@ -101,9 +118,9 @@ def denoise_nifti_voxel( ---------- strategy : dict Denoising strategy parameter to pass to load_confounds_strategy. - group_mask : Union[str, Path] + group_mask : str | Path Path to the group mask. - standardize : Union[str, bool] + standardize : str | bool Standardize the data. If 'zscore', zscore the data. If 'psc', convert the data to percent signal change. If False, do not standardize. 
smoothing_fwhm : float @@ -138,7 +155,8 @@ def denoise_nifti_voxel( def _check_exclusion( - reduced_confounds: pd.DataFrame, sample_mask: Optional[np.ndarray] + reduced_confounds: pd.DataFrame, + sample_mask: np.ndarray[Any, Any] | None, ) -> bool: """For scrubbing based strategy, check if regression can be performed.""" if sample_mask is not None: diff --git a/giga_connectome/mask.py b/giga_connectome/mask.py index 6a12398..6678975 100644 --- a/giga_connectome/mask.py +++ b/giga_connectome/mask.py @@ -1,35 +1,38 @@ +from __future__ import annotations + import os import re -from typing import Optional, Union, List, Tuple -from bids.layout import BIDSImageFile - from pathlib import Path +from typing import Any, Sequence + import nibabel as nib -from nilearn.masking import compute_multi_epi_mask +import numpy as np +from bids.layout import BIDSImageFile +from nibabel import Nifti1Image from nilearn.image import ( - resample_to_img, - new_img_like, get_data, - math_img, load_img, + math_img, + new_img_like, + resample_to_img, ) -from nibabel import Nifti1Image -import numpy as np +from nilearn.masking import compute_multi_epi_mask from scipy.ndimage import binary_closing from giga_connectome.atlas import resample_atlas_collection - from giga_connectome.logger import gc_logger +from giga_connectome.atlas import ATLAS_SETTING_TYPE + gc_log = gc_logger() def generate_gm_mask_atlas( working_dir: Path, - atlas: dict, + atlas: ATLAS_SETTING_TYPE, template: str, - masks: List[BIDSImageFile], -) -> Tuple[Path, List[Path]]: + masks: list[BIDSImageFile], +) -> tuple[Path, list[Path]]: """ """ # check masks; isolate this part and make sure to make it a validate # templateflow template with a config file @@ -63,9 +66,9 @@ def generate_gm_mask_atlas( def generate_group_mask( - imgs: list, + imgs: Sequence[Path | str | Nifti1Image], template: str = "MNI152NLin2009cAsym", - templateflow_dir: Optional[Path] = None, + templateflow_dir: Path | None = None, n_iter: int = 2, ) -> Nifti1Image: """ @@ -76,8 +79,8 @@ def generate_group_mask( Parameters ---------- - imgs : list of string - List of EPI masks or preprocessed BOLD data. + imgs : list of Path or str or Nifti1Image + list of EPI masks or preprocessed BOLD data. template : str, Default = MNI152NLin2009cAsym Template name from TemplateFlow to retrieve the grey matter template. @@ -169,8 +172,8 @@ def generate_group_mask( def _get_consistent_masks( - mask_imgs: List[Union[Path, str, Nifti1Image]], exclude: List[int] -) -> Tuple[List[int], List[str]]: + mask_imgs: Sequence[Path | str | Nifti1Image], exclude: list[int] +) -> tuple[list[Path | str | Any], list[str]]: """Create a list of masks that has the same affine. Parameters @@ -180,14 +183,14 @@ def _get_consistent_masks( The original list of functional masks exclude : - List of index to exclude. + list of index to exclude. Returns ------- - List of str + list of str Functional masks with the same affine. - List of str + list of str Identifiers of scans with a different affine. 
""" weird_mask_identifiers = [] @@ -196,14 +199,13 @@ def _get_consistent_masks( for odd_file in odd_masks: identifier = Path(odd_file).name.split("_space")[0] weird_mask_identifiers.append(identifier) - cleaned_func_masks = set(mask_imgs) - set(odd_masks) - cleaned_func_masks = list(cleaned_func_masks) + cleaned_func_masks = list(set(mask_imgs) - set(odd_masks)) return cleaned_func_masks, weird_mask_identifiers def _check_mask_affine( - mask_imgs: List[Union[Path, str, Nifti1Image]] -) -> Union[list, None]: + mask_imgs: Sequence[Path | str | Nifti1Image], +) -> list[int] | None: """Given a list of input mask images, show the most common affine matrix and subjects with different values. @@ -216,12 +218,12 @@ def _check_mask_affine( Returns ------- - List or None + list or None Index of masks with odd affine matrix. Return None when all masks have the same affine matrix. """ # save all header and affine info in hashable type... - header_info = {"affine": []} + header_info: dict[str, list[str]] = {"affine": []} key_to_header = {} for this_mask in mask_imgs: img = load_img(this_mask) @@ -232,9 +234,9 @@ def _check_mask_affine( key_to_header[affine_hashable] = affine if isinstance(mask_imgs[0], Nifti1Image): - mask_imgs = np.arange(len(mask_imgs)) + mask_arrays = np.arange(len(mask_imgs)) else: - mask_imgs = np.array(mask_imgs) + mask_arrays = np.array(mask_imgs) # get most common values common_affine = max( set(header_info["affine"]), key=header_info["affine"].count @@ -256,24 +258,26 @@ def _check_mask_affine( gc_log.debug( "The following subjects has a different affine matrix " f"({key_to_header[ob]}) comparing to the most common value: " - f"{mask_imgs[ob_index]}." + f"{mask_arrays[ob_index]}." ) exclude += ob_index gc_log.info( - f"{len(exclude)} out of {len(mask_imgs)} has " + f"{len(exclude)} out of {len(mask_arrays)} has " "different affine matrix. Ignore when creating group mask." ) return sorted(exclude) -def _check_pregenerated_masks(template, working_dir, atlas): +def _check_pregenerated_masks( + template: str, working_dir: Path, atlas: ATLAS_SETTING_TYPE +) -> tuple[Path | None, list[Path] | None]: """Check if the working directory is populated with needed files.""" output_dir = working_dir / "groupmasks" / f"tpl-{template}" - group_mask = ( + group_mask: Path | None = ( output_dir / f"tpl-{template}_res-dataset_label-GM_desc-group_mask.nii.gz" ) - if not group_mask.exists(): + if group_mask and not group_mask.exists(): group_mask = None else: gc_log.info( @@ -281,7 +285,7 @@ def _check_pregenerated_masks(template, working_dir, atlas): ) # atlas - resampled_atlases = [] + resampled_atlases: list[Path] = [] for desc in atlas["file_paths"]: filename = ( f"tpl-{template}_" @@ -293,7 +297,7 @@ def _check_pregenerated_masks(template, working_dir, atlas): resampled_atlases.append(output_dir / filename) all_exist = [file_path.exists() for file_path in resampled_atlases] if not all(all_exist): - resampled_atlases = None + return group_mask, None else: gc_log.info( f"Found resampled atlases:\n{[str(x) for x in resampled_atlases]}." 
diff --git a/giga_connectome/postprocess.py b/giga_connectome/postprocess.py
index 488b3f7..4019839 100644
--- a/giga_connectome/postprocess.py
+++ b/giga_connectome/postprocess.py
@@ -1,15 +1,17 @@
-from typing import Union, List
+from __future__ import annotations
+
 from pathlib import Path
+from typing import Any, Sequence
 
 import h5py
 import numpy as np
+from bids.layout import BIDSImageFile
 from nilearn.connectome import ConnectivityMeasure
 from nilearn.maskers import NiftiLabelsMasker, NiftiMapsMasker
-from bids.layout import BIDSImageFile
 
 from giga_connectome import utils
 from giga_connectome.connectome import generate_timeseries_connectomes
-from giga_connectome.denoise import denoise_nifti_voxel
+from giga_connectome.denoise import STRATEGY_TYPE, denoise_nifti_voxel
 from giga_connectome.logger import gc_logger
 from giga_connectome.utils import progress_bar
 
@@ -17,11 +19,11 @@
 
 
 def run_postprocessing_dataset(
-    strategy: dict,
-    resampled_atlases: List[Union[str, Path]],
-    images: List[BIDSImageFile],
-    group_mask: Union[str, Path],
-    standardize: Union[str, bool],
+    strategy: STRATEGY_TYPE,
+    resampled_atlases: Sequence[str | Path],
+    images: Sequence[BIDSImageFile],
+    group_mask: str | Path,
+    standardize: str | bool,
     smoothing_fwhm: float,
     output_path: Path,
     analysis_level: str,
@@ -88,7 +90,8 @@
         Whether to calculate average correlation within each parcel.
     """
     atlas = output_path.name.split("atlas-")[-1].split("_")[0]
-    atlas_maskers, connectomes = {}, {}
+    atlas_maskers: dict[str, (NiftiLabelsMasker | NiftiMapsMasker)] = {}
+    connectomes: dict[str, list[np.ndarray[Any, Any]]] = {}
     for atlas_path in resampled_atlases:
         if isinstance(atlas_path, str):
             atlas_path = Path(atlas_path)
@@ -175,33 +178,31 @@
 
 def _set_file_flag(output_path: Path) -> str:
     """Find out if new file needs to be created."""
-    flag = "w"
-    if output_path.exists():
-        flag = "a"
+    flag = "a" if output_path.exists() else "w"
     return flag
 
 
 def _fetch_h5_group(
-    f: h5py.File, subject: str, session: str
-) -> Union[h5py.File, h5py.Group]:
+    file: h5py.File, subject: str, session: str | None
+) -> h5py.File | h5py.Group:
     """Determine the level of grouping based on BIDS standard."""
-    if subject not in f:
+    if subject not in file:
         return (
-            f.create_group(f"{subject}/{session}")
+            file.create_group(f"{subject}/{session}")
             if session
-            else f.create_group(f"{subject}")
+            else file.create_group(subject)
         )
     elif session:
         return (
-            f[f"{subject}"].create_group(f"{session}")
-            if session not in f[f"{subject}"]
-            else f[f"{subject}/{session}"]
+            file[subject].create_group(session)
+            if session not in file[subject]
+            else file[f"{subject}/{session}"]
         )
     else:
-        return f[f"{subject}"]
+        return file[subject]
 
 
-def _get_masker(atlas_path: Path) -> Union[NiftiLabelsMasker, NiftiMapsMasker]:
+def _get_masker(atlas_path: Path) -> NiftiLabelsMasker | NiftiMapsMasker:
     """Get the masker object based on the templateflow file name suffix."""
     atlas_type = atlas_path.name.split("_")[-1].split(".nii")[0]
     if atlas_type == "dseg":
diff --git a/giga_connectome/run.py b/giga_connectome/run.py
index 2b1617d..901e37c 100644
--- a/giga_connectome/run.py
+++ b/giga_connectome/run.py
@@ -1,7 +1,11 @@
+from __future__ import annotations
+
 import argparse
 from pathlib import Path
-from giga_connectome.workflow import workflow
+from typing import Sequence
+
 from giga_connectome import __version__
+from giga_connectome.workflow import workflow
 
 
 def global_parser() -> argparse.ArgumentParser:
@@ -118,7 +122,7 @@
     return parser
 
 
-def main(argv=None):
+def main(argv: None | Sequence[str] = None) -> None:
     """Entry point."""
     parser = global_parser()
diff --git a/giga_connectome/tests/test_cli.py b/giga_connectome/tests/test_cli.py
index fead8b5..0864454 100644
--- a/giga_connectome/tests/test_cli.py
+++ b/giga_connectome/tests/test_cli.py
@@ -3,13 +3,13 @@
 """
 
 from pathlib import Path
-from pkg_resources import resource_filename
 
-from giga_connectome.run import main
-from giga_connectome import __version__
+import h5py
 import pytest
+from pkg_resources import resource_filename
 
-import h5py
+from giga_connectome import __version__
+from giga_connectome.run import main
 
 
 def test_version(capsys):
diff --git a/giga_connectome/tests/test_connectome.py b/giga_connectome/tests/test_connectome.py
index ac6f7a4..2336caf 100644
--- a/giga_connectome/tests/test_connectome.py
+++ b/giga_connectome/tests/test_connectome.py
@@ -1,8 +1,9 @@
 import numpy as np
-from giga_connectome.connectome import generate_timeseries_connectomes
 from nibabel import Nifti1Image
-from nilearn.maskers import NiftiMasker, NiftiLabelsMasker
 from nilearn.connectome import ConnectivityMeasure
+from nilearn.maskers import NiftiLabelsMasker, NiftiMasker
+
+from giga_connectome.connectome import generate_timeseries_connectomes
 
 
 def _extract_time_series_voxel(img, mask, confounds=None, smoothing_fwhm=None):
diff --git a/giga_connectome/tests/test_mask.py b/giga_connectome/tests/test_mask.py
index 9afac35..9bd11f6 100644
--- a/giga_connectome/tests/test_mask.py
+++ b/giga_connectome/tests/test_mask.py
@@ -1,8 +1,9 @@
-import pytest
 import numpy as np
-from giga_connectome import mask
-from nilearn import datasets
+import pytest
 from nibabel import Nifti1Image
+from nilearn import datasets
+
+from giga_connectome import mask
 
 
 def test_generate_group_mask():
diff --git a/giga_connectome/tests/test_utils.py b/giga_connectome/tests/test_utils.py
index a475f49..ca28a8c 100644
--- a/giga_connectome/tests/test_utils.py
+++ b/giga_connectome/tests/test_utils.py
@@ -1,9 +1,10 @@
 from pathlib import Path
+
+import pytest
 from bids.tests import get_test_data_path
-from giga_connectome import utils
 from pkg_resources import resource_filename
 
-import pytest
+from giga_connectome import utils
 
 
 def test_get_bids_images():
diff --git a/giga_connectome/utils.py b/giga_connectome/utils.py
index 02ea57a..e4b9596 100644
--- a/giga_connectome/utils.py
+++ b/giga_connectome/utils.py
@@ -1,8 +1,10 @@
-from typing import List, Tuple, Union
+from __future__ import annotations
+
 from pathlib import Path
-from nilearn.interfaces.bids import parse_bids_filename
-from bids.layout import Query
+
 from bids import BIDSLayout
+from bids.layout import BIDSFile, Query
+from nilearn.interfaces.bids import parse_bids_filename
 from rich.progress import (
     BarColumn,
@@ -21,12 +23,12 @@
 
 
 def get_bids_images(
-    subjects: List[str],
+    subjects: list[str],
     template: str,
     bids_dir: Path,
     reindex_bids: bool,
-    bids_filters: dict,
-) -> Tuple[dict, BIDSLayout]:
+    bids_filters: None | dict[str, dict[str, str]],
+) -> tuple[dict[str, list[BIDSFile]], BIDSLayout]:
     """
     Apply BIDS filter to the base filter we are using.
     Modified from fmripprep
@@ -81,7 +83,9 @@
     return subj_data, layout
 
 
-def check_filter(bids_filters: dict) -> dict:
+def check_filter(
+    bids_filters: None | dict[str, dict[str, str]]
+) -> dict[str, dict[str, str]]:
     """Should only have bold and mask."""
     if not bids_filters:
         return {}
@@ -97,63 +101,61 @@
     return bids_filters
 
 
-def _filter_pybids_none_any(dct: dict) -> dict:
-    import bids
-
+def _filter_pybids_none_any(
+    dct: dict[str, None | str]
+) -> dict[str, Query.NONE | Query.ANY]:
     return {
-        k: (
-            bids.layout.Query.NONE
-            if v is None
-            else (bids.layout.Query.ANY if v == "*" else v)
-        )
+        k: Query.NONE if v is None else (Query.ANY if v == "*" else v)
         for k, v in dct.items()
     }
 
 
-def parse_bids_filter(value: Path) -> dict:
+def parse_bids_filter(value: Path) -> None | dict[str, dict[str, str]]:
     from json import JSONDecodeError, loads
 
-    if value:
-        if value.exists():
-            try:
-                return loads(
-                    value.read_text(),
-                    object_hook=_filter_pybids_none_any,
-                )
-            except JSONDecodeError:
-                raise JSONDecodeError(f"JSON syntax error in: <{value}>.")
-        else:
-            raise FileNotFoundError(f"Path does not exist: <{value}>.")
+    if not value:
+        return None
+
+    if not value.exists():
+        raise FileNotFoundError(f"Path does not exist: <{value}>.")
+
+    try:
+        tmp = loads(
+            value.read_text(),
+            object_hook=_filter_pybids_none_any,
+        )
+    except JSONDecodeError as e:
+        raise ValueError(f"JSON syntax error in: <{value}>.") from e
+    return tmp
 
 
-def parse_standardize_options(standardize: str) -> Union[str, bool]:
+def parse_standardize_options(standardize: str) -> str | bool:
     if standardize not in ["zscore", "psc"]:
         raise ValueError(f"{standardize} is not a valid standardize strategy.")
-    if standardize == "psc":
-        return standardize
-    else:
-        return True
+    return standardize if standardize == "psc" else True
 
 
-def parse_bids_name(img: str) -> List[str]:
+def parse_bids_name(img: str) -> tuple[str, str | None, str]:
     """Get subject, session, and specifier for a fMRIPrep output."""
     reference = parse_bids_filename(img)
+
     subject = f"sub-{reference['sub']}"
-    session = reference.get("ses", None)
-    run = reference.get("run", None)
+
     specifier = f"task-{reference['task']}"
+    run = reference.get("run", None)
+    if isinstance(run, str):
+        specifier = f"{specifier}_run-{run}"
+
+    session = reference.get("ses", None)
     if isinstance(session, str):
         session = f"ses-{session}"
         specifier = f"{session}_{specifier}"
 
-    if isinstance(run, str):
-        specifier = f"{specifier}_run-{run}"
-
     return subject, session, specifier
 
 
 def get_subject_lists(
-    participant_label: List[str] = None, bids_dir: Path = None
-) -> List[str]:
+    participant_label: None | list[str] = None, bids_dir: None | Path = None
+) -> list[str]:
     """
     Parse subject list from user options.
@@ -172,7 +174,7 @@
 
     Return
     ------
-    List
+    list
         BIDS subject identifier without `sub-` prefix.
     """
     if participant_label:
@@ -184,15 +186,17 @@
             checked_labels.append(sub_id)
         return checked_labels
     # get all subjects, this is quicker than bids...
-    subject_dirs = bids_dir.glob("sub-*/")
-    return [
-        subject_dir.name.split("-")[-1]
-        for subject_dir in subject_dirs
-        if subject_dir.is_dir()
-    ]
+    if bids_dir:
+        subject_dirs = bids_dir.glob("sub-*/")
+        return [
+            subject_dir.name.split("-")[-1]
+            for subject_dir in subject_dirs
+            if subject_dir.is_dir()
+        ]
+    return []
 
 
-def check_path(path: Path):
+def check_path(path: Path) -> None:
     """Check if given path (file or dir) already exists.
     If so, a warning is logged.
diff --git a/giga_connectome/workflow.py b/giga_connectome/workflow.py
index 034e910..d70824c 100644
--- a/giga_connectome/workflow.py
+++ b/giga_connectome/workflow.py
@@ -4,15 +4,17 @@
 
 from __future__ import annotations
 
+import argparse
+
 from giga_connectome import (
     generate_gm_mask_atlas,
+    get_denoise_strategy,
     load_atlas_setting,
+    methods,
     run_postprocessing_dataset,
-    get_denoise_strategy,
+    utils,
 )
-
 from giga_connectome.denoise import is_ica_aroma
-from giga_connectome import utils, methods
 from giga_connectome.logger import gc_logger
 
@@ -32,7 +34,7 @@
         gc_log.setLevel("DEBUG")
 
 
-def workflow(args):
+def workflow(args: argparse.Namespace) -> None:
     gc_log.info(vars(args))
 
     # set file paths
diff --git a/pyproject.toml b/pyproject.toml
index df8608f..6793384 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -36,9 +36,14 @@
 giga_connectome = "giga_connectome.run:main"
 
 [project.optional-dependencies]
 dev = [
     "black",
+    "flake8",
     "pre-commit",
     "giga_connectome[test]",
     'tox',
+    'mypy',
+    'types-all',
+    'pandas-stubs',
+    'types-tqdm'
 ]
 test = [
     "pytest",
@@ -78,6 +83,46 @@
 target-version = ['py38']
 exclude = "giga_connectome/_version.py"
 line_length = 79
 
+
+[tool.mypy]
+check_untyped_defs = true
+disallow_any_generics = true
+disallow_incomplete_defs = true
+disallow_untyped_defs = true
+enable_error_code = ["ignore-without-code", "redundant-expr"]  # "truthy-bool"
+no_implicit_optional = true
+show_error_codes = true
+# strict = true
+warn_redundant_casts = true
+warn_unreachable = true
+warn_unused_ignores = true
+
+[[tool.mypy.overrides]]
+ignore_missing_imports = true
+module = [
+    "bids.*",
+    "giga_connectome._version",
+    "h5py.*",
+    "nibabel.*",
+    "nilearn.*",
+    "nilearn.connectome.*",
+    "nilearn.image.*",
+    "nilearn.interfaces.*",
+    "nilearn.maskers.*",
+    "nilearn.masking.*",
+    "rich.*",
+    "scipy.ndimage.*",
+    "templateflow.*",
+]
+
+[[tool.mypy.overrides]]
+ignore_errors = true
+module = [
+    'giga_connectome.tests.*',
+    'download_templates',
+    'conf',
+]
+
 [tool.pytest.ini_options]
 minversion = "7"
 log_cli_level = "INFO"
diff --git a/tools/download_templates.py b/tools/download_templates.py
index 20ee80a..3faa937 100644
--- a/tools/download_templates.py
+++ b/tools/download_templates.py
@@ -3,40 +3,38 @@
 Download atlases that are relevant.
""" +import importlib.util +import shutil +import sys + from pathlib import Path +import templateflow as tf + from giga_connectome.logger import gc_logger gc_log = gc_logger() -def fetch_tpl_atlas(): +def fetch_tpl_atlas() -> None: """Download datasets from templateflow.""" - import templateflow.api as tf - atlases = ["Schaefer2018", "DiFuMo"] for atlas in atlases: - tf_path = tf.get("MNI152NLin2009cAsym", atlas=atlas) + tf_path = tf.api.get("MNI152NLin2009cAsym", atlas=atlas) if isinstance(tf_path, list) and len(tf_path) > 0: gc_log.info(f"{atlas} exists.") # download MNI grey matter template - tf.get("MNI152NLin2009cAsym", label="GM") + tf.api.get("MNI152NLin2009cAsym", label="GM") -def download_mist(): +def download_mist() -> None: """Download mist atlas and convert to templateflow format.""" - import templateflow - - tf_path = templateflow.api.get("MNI152NLin2009bAsym", atlas="BASC") + tf_path = tf.api.get("MNI152NLin2009bAsym", atlas="BASC") if isinstance(tf_path, list) and len(tf_path) > 0: gc_log.info("BASC / MIST atlas exists.") return # download and convert - import importlib.util - import sys - import shutil - spec = importlib.util.spec_from_file_location( "mist2templateflow", Path(__file__).parent / "mist2templateflow/mist2templateflow.py", @@ -45,12 +43,12 @@ def download_mist(): sys.modules["module.name"] = mist2templateflow spec.loader.exec_module(mist2templateflow) mist2templateflow.convert_basc( - templateflow.conf.TF_HOME, Path(__file__).parent / "tmp" + tf.conf.TF_HOME, Path(__file__).parent / "tmp" ) shutil.rmtree(Path(__file__).parent / "tmp") -def main(): +def main() -> None: fetch_tpl_atlas() download_mist()