Skip to content

Commit

Permalink
use type dict for atlas
Browse files Browse the repository at this point in the history
  • Loading branch information
Remi-Gau committed Dec 15, 2023
1 parent 8342490 commit 179ae8f
Show file tree
Hide file tree
Showing 3 changed files with 42 additions and 17 deletions.
52 changes: 37 additions & 15 deletions giga_connectome/atlas.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import json
import os
from pathlib import Path
from typing import Any
from typing import Any, Dict, List, TypedDict

import nibabel as nib
from nibabel import Nifti1Image
Expand All @@ -18,10 +18,25 @@

PRESET_ATLAS = ["DiFuMo", "MIST", "Schaefer20187Networks"]

ATLAS_CONFIG_TYPE = TypedDict(
"ATLAS_CONFIG_TYPE",
{
"name": str,
"parameters": Dict[str, str],
"desc": List[str],
"templateflow_dir": Any,
},
)

ATLAS_SETTING_TYPE = TypedDict(
"ATLAS_SETTING_TYPE",
{"name": str, "file_paths": Dict[str, List[Path]], "type": str},
)


def load_atlas_setting(
atlas: str | Path | dict[str, str | Path | dict[str, str]],
) -> dict[str, Any]:
atlas: str | Path | dict[str, Any],
) -> ATLAS_SETTING_TYPE:
"""Load atlas details for templateflow api to fetch.
The setting file can be configured for atlases not included in the
templateflow collections, but user has to organise their files to
Expand Down Expand Up @@ -63,19 +78,16 @@ def load_atlas_setting(

import templateflow

if isinstance(atlas_config["desc"], str):
desc = [atlas_config["desc"]]
else:
desc = atlas_config["desc"]

parcellation = {}
for d in desc:
for d in atlas_config["desc"]:
p = templateflow.api.get(
**atlas_config["parameters"],
raise_empty=True,
desc=d,
extension="nii.gz",
)
if isinstance(p, Path):
p = [p]
parcellation[d] = p
return {
"name": atlas_config["name"],
Expand All @@ -86,7 +98,7 @@ def load_atlas_setting(

def resample_atlas_collection(
template: str,
atlas_config: dict[str, Any],
atlas_config: ATLAS_SETTING_TYPE,
group_mask_dir: Path,
group_mask: Nifti1Image,
) -> list[Path]:
Expand Down Expand Up @@ -132,7 +144,9 @@ def resample_atlas_collection(
return resampled_atlases


def _check_altas_config(atlas: str | Path | dict[str, Any]) -> dict[str, Any]:
def _check_altas_config(
atlas: str | Path | dict[str, Any]
) -> ATLAS_CONFIG_TYPE:
"""Load the configuration file.
Parameters
Expand Down Expand Up @@ -162,12 +176,12 @@ def _check_altas_config(atlas: str | Path | dict[str, Any]) -> dict[str, Any]:
config_path = Path(atlas)

with open(config_path, "r") as file:
atlas_dct = json.load(file)
atlas_config = json.load(file)
else:
atlas_dct = atlas
atlas_config = atlas

minimal_keys = ["name", "parameters", "desc", "templateflow_dir"]
keys = list(atlas_dct.keys())
keys = list(atlas_config.keys())
common_keys = set(minimal_keys).intersection(set(keys))
if common_keys != set(minimal_keys):
raise KeyError(
Expand All @@ -176,4 +190,12 @@ def _check_altas_config(atlas: str | Path | dict[str, Any]) -> dict[str, Any]:
"'parameters', 'desc', 'templateflow_dir'. Found "
f"{keys}"
)
return atlas_dct

# cast to list of string
if isinstance(atlas_config["desc"], (str, int)):
desc = [atlas_config["desc"]]
else:
desc = atlas_config["desc"]
atlas_config["desc"] = [str(x) for x in desc]

return atlas_config
6 changes: 4 additions & 2 deletions giga_connectome/mask.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,12 +22,14 @@
from giga_connectome.atlas import resample_atlas_collection
from giga_connectome.logger import gc_logger

from giga_connectome.atlas import ATLAS_SETTING_TYPE

gc_log = gc_logger()


def generate_gm_mask_atlas(
working_dir: Path,
atlas: dict[str, Any],
atlas: ATLAS_SETTING_TYPE,
template: str,
masks: list[BIDSImageFile],
) -> tuple[Path, list[Path]]:
Expand Down Expand Up @@ -267,7 +269,7 @@ def _check_mask_affine(


def _check_pregenerated_masks(
template: str, working_dir: Path, atlas: dict[str, Any]
template: str, working_dir: Path, atlas: ATLAS_SETTING_TYPE
) -> tuple[Path | None, list[Path] | None]:
"""Check if the working directory is populated with needed files."""
output_dir = working_dir / "groupmasks" / f"tpl-{template}"
Expand Down
1 change: 1 addition & 0 deletions giga_connectome/workflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from giga_connectome.denoise import is_ica_aroma
from giga_connectome.logger import gc_logger


gc_log = gc_logger()


Expand Down

0 comments on commit 179ae8f

Please sign in to comment.