Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[ENH] start using mypy and updating types #86

Merged
merged 23 commits into from
Mar 25, 2024
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,12 @@ repos:
- id: end-of-file-fixer
- id: check-yaml
- id: check-added-large-files
- repo: https://github.com/pre-commit/mirrors-mypy
rev: v1.7.1
hooks:
- id: mypy
additional_dependencies: [types-all, pandas-stubs, types-tqdm]
args: [--config-file=pyproject.toml]
- repo: https://github.com/psf/black
rev: 22.12.0
hooks:
Expand Down
4 changes: 2 additions & 2 deletions giga_connectome/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,10 @@
except ImportError:
pass

from .mask import generate_gm_mask_atlas
from .atlas import load_atlas_setting
from .postprocess import run_postprocessing_dataset
from .denoise import get_denoise_strategy
from .mask import generate_gm_mask_atlas
from .postprocess import run_postprocessing_dataset

__all__ = [
"__copyright__",
Expand Down
51 changes: 29 additions & 22 deletions giga_connectome/atlas.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,15 @@
import os
import json
from typing import Union, List
from __future__ import annotations

htwangtw marked this conversation as resolved.
Show resolved Hide resolved
import json
import os
from pathlib import Path
from tqdm import tqdm
from typing import Any

import nibabel as nib
from nilearn.image import resample_to_img
from nibabel import Nifti1Image
from nilearn.image import resample_to_img
from pkg_resources import resource_filename
from tqdm import tqdm

from giga_connectome.logger import gc_logger

Expand All @@ -17,7 +19,9 @@
PRESET_ATLAS = ["DiFuMo", "MIST", "Schaefer20187Networks"]


def load_atlas_setting(atlas: Union[str, Path, dict]):
def load_atlas_setting(
atlas: str | Path | dict[str, str | Path | dict[str, str]],
) -> dict[str, Any]:
"""Load atlas details for templateflow api to fetch.
The setting file can be configured for atlases not included in the
templateflow collections, but user has to organise their files to
Expand Down Expand Up @@ -82,10 +86,10 @@ def load_atlas_setting(atlas: Union[str, Path, dict]):

def resample_atlas_collection(
template: str,
atlas_config: dict,
atlas_config: dict[str, Any],
group_mask_dir: Path,
group_mask: Nifti1Image,
) -> List[Path]:
) -> list[Path]:
"""Resample a atlas collection to group grey matter mask.

Parameters
Expand All @@ -105,7 +109,7 @@ def resample_atlas_collection(

Returns
-------
List of pathlib.Path
list of pathlib.Path
Paths to atlases sampled to group level grey matter mask.
"""
gc_log.info("Resample atlas to group grey matter mask.")
Expand All @@ -128,12 +132,12 @@ def resample_atlas_collection(
return resampled_atlases


def _check_altas_config(atlas: Union[str, Path, dict]) -> dict:
def _check_altas_config(atlas: str | Path | dict[str, Any]) -> dict[str, Any]:
"""Load the configuration file.

Parameters
----------
atlas : Union[str, Path, dict]
atlas : str | Path | dict
Atlas name or configuration file path.

Returns
Expand All @@ -149,23 +153,26 @@ def _check_altas_config(atlas: Union[str, Path, dict]) -> dict:
# load the file first if the input is not already a dictionary
if isinstance(atlas, (str, Path)):
if atlas in PRESET_ATLAS:
config_path = resource_filename(
"giga_connectome", f"data/atlas/{atlas}.json"
config_path = Path(
resource_filename(
"giga_connectom)e", f"data/atlas/{atlas}.json"
)
)
elif Path(atlas).exists():
config_path = Path(atlas)

with open(config_path, "r") as file:
atlas = json.load(file)

keys = list(atlas.keys())
minimal_keys = ["name", "parameters", "desc", "templateflow_dir"]
common_keys = set(minimal_keys).intersection(set(keys))
if isinstance(atlas, dict) and common_keys != set(minimal_keys):
raise KeyError(
"Invalid dictionary input. Input should"
" contain minimally the following keys: 'name', "
"'parameters', 'desc', 'templateflow_dir'. Found "
f"{keys}"
)
if isinstance(atlas, dict):
keys = list(atlas.keys())
common_keys = set(minimal_keys).intersection(set(keys))
if common_keys != set(minimal_keys):
raise KeyError(
"Invalid dictionary input. Input should"
" contain minimally the following keys: 'name', "
"'parameters', 'desc', 'templateflow_dir'. Found "
f"{keys}"
)
return atlas
40 changes: 23 additions & 17 deletions giga_connectome/connectome.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,19 @@
from __future__ import annotations

from pathlib import Path
from typing import Union, Tuple
from typing import Any

import numpy as np
from nilearn.maskers import NiftiMasker
from nilearn.image import load_img
from nibabel import Nifti1Image
from nilearn.connectome import ConnectivityMeasure
from nilearn.image import load_img
from nilearn.maskers import NiftiMasker
from numpy._typing import NDArray


def build_size_roi(mask: np.ndarray, labels_roi: np.ndarray) -> np.ndarray:
def build_size_roi(
mask: NDArray[Any], labels_roi: NDArray[Any]
) -> np.ndarray[Any, np.dtype[Any]]:
"""Extract labels and sizes of ROIs given an atlas.
The atlas parcels must be discrete segmentations.

Expand Down Expand Up @@ -41,12 +47,12 @@ def build_size_roi(mask: np.ndarray, labels_roi: np.ndarray) -> np.ndarray:


def calculate_intranetwork_correlation(
correlation_matrix: np.array,
masker_labels: np.array,
time_series_atlas: np.array,
group_mask: Union[str, Path, Nifti1Image],
atlas_image: Union[str, Path, Nifti1Image],
) -> Tuple[np.ndarray, np.ndarray]:
correlation_matrix: NDArray[Any],
masker_labels: NDArray[Any],
time_series_atlas: NDArray[Any],
group_mask: str | Path | Nifti1Image,
atlas_image: str | Path | Nifti1Image,
) -> tuple[NDArray[Any], NDArray[Any]]:
"""Calculate the average functional correlation within each parcel.
Currently we only support discrete segmentations.

Expand All @@ -61,15 +67,15 @@ def calculate_intranetwork_correlation(
time_series_atlas : np.array
Time series extracted from each parcel.

group_mask : Union[str, Path, Nifti1Image]
group_mask : str | Path | Nifti1Image
The group grey matter mask.

atlas_image : Union[str, Path, Nifti1Image]
atlas_image : str | Path | Nifti1Image
3D atlas image.

Returns
-------
Tuple[np.ndarray, np.ndarray]
tuple[np.ndarray, np.ndarray]
A tuple containing the modified Pearson's correlation matrix with
the diagonal replaced by the average correlation within each parcel,
and an array of the computed average intranetwork correlations for
Expand Down Expand Up @@ -106,10 +112,10 @@ def calculate_intranetwork_correlation(
def generate_timeseries_connectomes(
masker: NiftiMasker,
denoised_img: Nifti1Image,
group_mask: Union[str, Path],
group_mask: str | Path,
correlation_measure: ConnectivityMeasure,
calculate_average_correlation: bool,
) -> Tuple[np.ndarray, np.ndarray]:
) -> tuple[NDArray[Any], NDArray[Any]]:
"""Generate timeseries-based connectomes from functional data.

Parameters
Expand All @@ -120,7 +126,7 @@ def generate_timeseries_connectomes(
denoised_img : Nifti1Image
Denoised functional image.

group_mask : Union[str, Path]
group_mask : str | Path
Path to the group grey matter mask.

correlation_measure : ConnectivityMeasure
Expand All @@ -131,7 +137,7 @@ def generate_timeseries_connectomes(

Returns
-------
Tuple[np.ndarray, np.ndarray]
tuple[np.ndarray, np.ndarray]
A tuple containing the correlation matrix and time series atlas.
"""
time_series_atlas = masker.fit_transform(denoised_img)
Expand Down
28 changes: 13 additions & 15 deletions giga_connectome/denoise.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,14 @@
from typing import Union, Optional

import json
from pathlib import Path
import pandas as pd
from typing import Optional, Any

import numpy as np
import pandas as pd
from nibabel import Nifti1Image

from nilearn.interfaces import fmriprep
from nilearn.maskers import NiftiMasker

from pkg_resources import resource_filename


PRESET_STRATEGIES = [
"simple",
"simple+gsr",
Expand All @@ -26,7 +23,7 @@

def get_denoise_strategy(
strategy: str,
) -> dict:
) -> dict[str, str | dict[str, str]]:
"""
Select denoise strategies and associated parameters.
The strategy parameters are designed to pass to load_confounds_strategy.
Expand All @@ -47,7 +44,7 @@ def get_denoise_strategy(
Denosing strategy parameter to pass to load_confounds_strategy.
"""
if strategy in PRESET_STRATEGIES:
config_path = resource_filename(
config_path: str | Path = resource_filename(
"giga_connectome", f"data/denoise_strategy/{strategy}.json"
)
elif Path(strategy).exists():
Expand All @@ -63,7 +60,7 @@ def get_denoise_strategy(
return benchmark_strategy


def is_ica_aroma(strategy: str) -> bool:
def is_ica_aroma(strategy: dict[str, dict[str, str]]) -> bool:
"""Check if the current strategy is ICA AROMA.

Parameters
Expand All @@ -89,9 +86,9 @@ def is_ica_aroma(strategy: str) -> bool:


def denoise_nifti_voxel(
strategy: dict,
group_mask: Union[str, Path],
standardize: Union[str, bool],
strategy: dict[str, dict[str, str]],
group_mask: str | Path,
standardize: str | bool,
smoothing_fwhm: float,
img: str,
) -> Nifti1Image:
Expand All @@ -101,9 +98,9 @@ def denoise_nifti_voxel(
----------
strategy : dict
Denoising strategy parameter to pass to load_confounds_strategy.
group_mask : Union[str, Path]
group_mask : str | Path
Path to the group mask.
standardize : Union[str, bool]
standardize : str | bool
Standardize the data. If 'zscore', zscore the data. If 'psc', convert
the data to percent signal change. If False, do not standardize.
smoothing_fwhm : float
Expand Down Expand Up @@ -138,7 +135,8 @@ def denoise_nifti_voxel(


def _check_exclusion(
reduced_confounds: pd.DataFrame, sample_mask: Optional[np.ndarray]
reduced_confounds: pd.DataFrame,
sample_mask: Optional[np.ndarray[Any, Any]],
) -> bool:
"""For scrubbing based strategy, check if regression can be performed."""
if sample_mask is not None:
Expand Down
Loading
Loading