Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[ENH] start using mypy and updating types #86

Merged
merged 23 commits into from
Mar 25, 2024
Merged
Show file tree
Hide file tree
Changes from 16 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .github/workflows/test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -125,6 +125,7 @@ jobs:
runs-on: ubuntu-latest
needs: [build, download-test-data, check_skip_flags]
strategy:
fail-fast: false
matrix:
python-version: ['3.8', '3.9', '3.10', '3.11', '3.12']
steps:
Expand Down
6 changes: 6 additions & 0 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,12 @@ repos:
- id: end-of-file-fixer
- id: check-yaml
- id: check-added-large-files
- repo: https://github.com/pre-commit/mirrors-mypy
rev: v1.7.1
hooks:
- id: mypy
additional_dependencies: [types-all, pandas-stubs, types-tqdm]
args: [--config-file=pyproject.toml]
- repo: https://github.com/psf/black
rev: 23.12.1
hooks:
Expand Down
2 changes: 1 addition & 1 deletion docs/source/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import os
import sys

from giga_connectome import __version__, __packagename__, __copyright__
from giga_connectome import __copyright__, __packagename__, __version__

sys.path.insert(0, os.path.abspath(".."))
# -- Project information -----------------------------------------------------
Expand Down
4 changes: 2 additions & 2 deletions giga_connectome/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,10 @@
except ImportError:
pass

from .mask import generate_gm_mask_atlas
from .atlas import load_atlas_setting
from .postprocess import run_postprocessing_dataset
from .denoise import get_denoise_strategy
from .mask import generate_gm_mask_atlas
from .postprocess import run_postprocessing_dataset

__all__ = [
"__copyright__",
Expand Down
78 changes: 54 additions & 24 deletions giga_connectome/atlas.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,15 @@
import os
import json
from typing import Union, List
from __future__ import annotations

htwangtw marked this conversation as resolved.
Show resolved Hide resolved
import json
import os
from pathlib import Path
from tqdm import tqdm
from typing import Any, Dict, List, TypedDict

import nibabel as nib
from nilearn.image import resample_to_img
from nibabel import Nifti1Image
from nilearn.image import resample_to_img
from pkg_resources import resource_filename
from tqdm import tqdm

from giga_connectome.logger import gc_logger

Expand All @@ -16,8 +18,25 @@

PRESET_ATLAS = ["DiFuMo", "MIST", "Schaefer20187Networks"]


def load_atlas_setting(atlas: Union[str, Path, dict]):
ATLAS_CONFIG_TYPE = TypedDict(
"ATLAS_CONFIG_TYPE",
{
"name": str,
"parameters": Dict[str, str],
"desc": List[str],
"templateflow_dir": Any,
},
)

ATLAS_SETTING_TYPE = TypedDict(
"ATLAS_SETTING_TYPE",
{"name": str, "file_paths": Dict[str, List[Path]], "type": str},
)


def load_atlas_setting(
atlas: str | Path | dict[str, Any],
) -> ATLAS_SETTING_TYPE:
"""Load atlas details for templateflow api to fetch.
The setting file can be configured for atlases not included in the
templateflow collections, but user has to organise their files to
Expand Down Expand Up @@ -59,19 +78,16 @@ def load_atlas_setting(atlas: Union[str, Path, dict]):

import templateflow

if isinstance(atlas_config["desc"], str):
desc = [atlas_config["desc"]]
else:
desc = atlas_config["desc"]

parcellation = {}
for d in desc:
for d in atlas_config["desc"]:
p = templateflow.api.get(
**atlas_config["parameters"],
raise_empty=True,
desc=d,
extension="nii.gz",
)
if isinstance(p, Path):
p = [p]
parcellation[d] = p
return {
"name": atlas_config["name"],
Expand All @@ -82,10 +98,10 @@ def load_atlas_setting(atlas: Union[str, Path, dict]):

def resample_atlas_collection(
template: str,
atlas_config: dict,
atlas_config: ATLAS_SETTING_TYPE,
group_mask_dir: Path,
group_mask: Nifti1Image,
) -> List[Path]:
) -> list[Path]:
"""Resample an atlas collection to the group grey matter mask.

Parameters
Expand All @@ -105,7 +121,7 @@ def resample_atlas_collection(

Returns
-------
List of pathlib.Path
list of pathlib.Path
Paths to atlases sampled to group level grey matter mask.
"""
gc_log.info("Resample atlas to group grey matter mask.")
Expand All @@ -128,12 +144,14 @@ def resample_atlas_collection(
return resampled_atlases


def _check_altas_config(atlas: Union[str, Path, dict]) -> dict:
def _check_altas_config(
atlas: str | Path | dict[str, Any]
) -> ATLAS_CONFIG_TYPE:
"""Load the configuration file.

Parameters
----------
atlas : Union[str, Path, dict]
atlas : str | Path | dict
Atlas name or configuration file path.

Returns
Expand All @@ -149,23 +167,35 @@ def _check_altas_config(atlas: Union[str, Path, dict]) -> dict:
# load the file first if the input is not already a dictionary
if isinstance(atlas, (str, Path)):
if atlas in PRESET_ATLAS:
config_path = resource_filename(
"giga_connectome", f"data/atlas/{atlas}.json"
config_path = Path(
resource_filename(
"giga_connectome", f"data/atlas/{atlas}.json"
)
)
elif Path(atlas).exists():
config_path = Path(atlas)

with open(config_path, "r") as file:
atlas = json.load(file)
atlas_config = json.load(file)
else:
atlas_config = atlas

keys = list(atlas.keys())
minimal_keys = ["name", "parameters", "desc", "templateflow_dir"]
keys = list(atlas_config.keys())
common_keys = set(minimal_keys).intersection(set(keys))
if isinstance(atlas, dict) and common_keys != set(minimal_keys):
if common_keys != set(minimal_keys):
raise KeyError(
"Invalid dictionary input. Input should"
" contain minimally the following keys: 'name', "
"'parameters', 'desc', 'templateflow_dir'. Found "
f"{keys}"
)
return atlas

# cast to list of string
if isinstance(atlas_config["desc"], (str, int)):
desc = [atlas_config["desc"]]
else:
desc = atlas_config["desc"]
atlas_config["desc"] = [str(x) for x in desc]

return atlas_config
39 changes: 22 additions & 17 deletions giga_connectome/connectome.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,18 @@
from __future__ import annotations

from pathlib import Path
from typing import Union, Tuple
from typing import Any

import numpy as np
from nilearn.maskers import NiftiMasker
from nilearn.image import load_img
from nibabel import Nifti1Image
from nilearn.connectome import ConnectivityMeasure
from nilearn.image import load_img
from nilearn.maskers import NiftiMasker


def build_size_roi(mask: np.ndarray, labels_roi: np.ndarray) -> np.ndarray:
def build_size_roi(
mask: np.ndarray[Any, Any], labels_roi: np.ndarray[Any, Any]
) -> np.ndarray[Any, np.dtype[Any]]:
"""Extract labels and sizes of ROIs given an atlas.
The atlas parcels must be discrete segmentations.

Expand Down Expand Up @@ -41,12 +46,12 @@ def build_size_roi(mask: np.ndarray, labels_roi: np.ndarray) -> np.ndarray:


def calculate_intranetwork_correlation(
correlation_matrix: np.array,
masker_labels: np.array,
time_series_atlas: np.array,
group_mask: Union[str, Path, Nifti1Image],
atlas_image: Union[str, Path, Nifti1Image],
) -> Tuple[np.ndarray, np.ndarray]:
correlation_matrix: np.ndarray[Any, Any],
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is this how to declare the dimension of the numpy array? That's cool

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I need to investigate further how to use types with numpy; this is a bit of a brute-force approach.

masker_labels: np.ndarray[Any, Any],
time_series_atlas: np.ndarray[Any, Any],
group_mask: str | Path | Nifti1Image,
atlas_image: str | Path | Nifti1Image,
) -> tuple[np.ndarray[Any, Any], np.ndarray[Any, Any]]:
"""Calculate the average functional correlation within each parcel.
Currently we only support discrete segmentations.

Expand All @@ -61,15 +66,15 @@ def calculate_intranetwork_correlation(
time_series_atlas : np.array
Time series extracted from each parcel.

group_mask : Union[str, Path, Nifti1Image]
group_mask : str | Path | Nifti1Image
The group grey matter mask.

atlas_image : Union[str, Path, Nifti1Image]
atlas_image : str | Path | Nifti1Image
3D atlas image.

Returns
-------
Tuple[np.ndarray, np.ndarray]
tuple[np.ndarray, np.ndarray]
A tuple containing the modified Pearson's correlation matrix with
the diagonal replaced by the average correlation within each parcel,
and an array of the computed average intranetwork correlations for
Expand Down Expand Up @@ -106,10 +111,10 @@ def calculate_intranetwork_correlation(
def generate_timeseries_connectomes(
masker: NiftiMasker,
denoised_img: Nifti1Image,
group_mask: Union[str, Path],
group_mask: str | Path,
correlation_measure: ConnectivityMeasure,
calculate_average_correlation: bool,
) -> Tuple[np.ndarray, np.ndarray]:
) -> tuple[np.ndarray[Any, Any], np.ndarray[Any, Any]]:
"""Generate timeseries-based connectomes from functional data.

Parameters
Expand All @@ -120,7 +125,7 @@ def generate_timeseries_connectomes(
denoised_img : Nifti1Image
Denoised functional image.

group_mask : Union[str, Path]
group_mask : str | Path
Path to the group grey matter mask.

correlation_measure : ConnectivityMeasure
Expand All @@ -131,7 +136,7 @@ def generate_timeseries_connectomes(

Returns
-------
Tuple[np.ndarray, np.ndarray]
tuple[np.ndarray, np.ndarray]
A tuple containing the correlation matrix and time series atlas.
"""
time_series_atlas = masker.fit_transform(denoised_img)
Expand Down
Loading
Loading