diff --git a/src/niclips/figures/bold.py b/src/niclips/figures/bold.py index c9375a2..b4ac5bf 100644 --- a/src/niclips/figures/bold.py +++ b/src/niclips/figures/bold.py @@ -62,7 +62,6 @@ def cluster_timeseries( def carpet_plot( bold: nib.Nifti1Image, out: StrPath | None = None, - *, label: nib.Nifti1Image | None = None, n_voxels: int = 2000, seed: int = 42, @@ -157,7 +156,6 @@ def carpet_plot( def bold_mean_std( bold: nib.Nifti1Image, out: StrPath | None = None, - *, std_vmax_ratio: float = 0.1, figure: str | None = None, **kwargs, diff --git a/src/niclips/figures/dwi.py b/src/niclips/figures/dwi.py index d21d9d3..f96ce58 100644 --- a/src/niclips/figures/dwi.py +++ b/src/niclips/figures/dwi.py @@ -41,7 +41,6 @@ def _get_bval_indices(bvals: np.ndarray, bval: int) -> np.ndarray: def visualize_qspace( dwi: nib.Nifti1Image, out: StrPath | None = None, - *, thresh: int = 10, figure: str | None = None, ) -> FuncAnimation: @@ -93,7 +92,6 @@ def _rotate(angle: int) -> None: def three_view_per_shell( dwi: nib.Nifti1Image, out: StrPath | None = None, - *, thresh: int = 10, replace_str: str = "bval", ) -> list[nib.Nifti1Image]: @@ -125,7 +123,6 @@ def three_view_per_shell( def signal_per_volume( dwi: nib.Nifti1Image, out: StrPath | None = None, - *, fontsize: int = 14, figure: str | None = None, ) -> None: diff --git a/src/niclips/figures/multi_view.py b/src/niclips/figures/multi_view.py index cb0d7f3..a0224b5 100644 --- a/src/niclips/figures/multi_view.py +++ b/src/niclips/figures/multi_view.py @@ -89,7 +89,6 @@ def multi_view_frame( def three_view_frame( img: NiftiLike, out: StrPath | None = None, - *, coord: tuple[float, float, float] | None = None, idx: int | None = 0, vmin: float | None = None, @@ -138,7 +137,6 @@ def three_view_frame( def three_view_video( img: nib.Nifti1Image, out: StrPath, - *, coord: tuple[float, float, float] | None = None, vmin: float | None = None, vmax: float | None = None, @@ -181,7 +179,6 @@ def three_view_video( def slice_video( img: NiftiLike, out: StrPath, - *, axis: int = 2, idx: int | None = 0, vmin: float | None = None, diff --git a/src/niftyone/figures/generator.py b/src/niftyone/figures/generator.py index ce6ff65..01582e6 100644 --- a/src/niftyone/figures/generator.py +++ b/src/niftyone/figures/generator.py @@ -1,17 +1,16 @@ -"""Generator classes for different views.""" +"""Figure generation class for different views.""" import logging from abc import ABC from functools import reduce from pathlib import Path +from types import MappingProxyType from typing import Any, Callable, Generic, TypeVar import matplotlib.pyplot as plt import nibabel as nib import pandas as pd from bids2table import BIDSEntities, BIDSTable -from matplotlib.figure import Figure -from PIL.Image import Image import niclips.image as noimg @@ -37,45 +36,34 @@ def create_generator( queries: list[str], ) -> "ViewGenerator": """Function to create generator.""" - if view_kwargs is None: - view_kwargs = {} + view_kwargs = view_kwargs or {} try: generator_cls = generator_registry[view] - generator_instance = generator_cls(queries, join_entities, view_kwargs) - return generator_instance + return generator_cls(queries, join_entities, view_kwargs) except KeyError: - msg = f"Generator for '{view}' for not found in registry." - raise KeyError(msg) + raise KeyError(f"Generator for '{view}' for not found in registry.") def create_generators(config: dict[str, Any]) -> list["ViewGenerator"]: """Create selected generators dynamically from config with default settings.""" - generators: list["ViewGenerator"] = [] - - for group in config.get("figures", {}).values(): - queries = group.get("queries", []) - join_entities = group.get("join_entities", ["sub", "ses"]) - views = group.get("views", {}) - - for view, view_kwargs in views.items(): - generators.append( - create_generator( - view=view, - view_kwargs=view_kwargs, - join_entities=join_entities, - queries=queries, - ) - ) - - return generators + return [ + create_generator( + view=view, + view_kwargs=view_kwargs, + join_entities=group.get("join_entities", ["sub", "ses"]), + queries=group.get("queries", []), + ) + for group in config.get("figures", {}).values() + for view, view_kwargs in group.get("views", {}).items() + ] class ViewGenerator(ABC, Generic[T]): """Base view generator class.""" entities: dict[str, Any] | None = None - view_fn: Callable[[nib.Nifti1Image, Path], Image | Figure | None] | None = None + view_fn: Callable | None = None def __init__( self, @@ -84,7 +72,7 @@ def __init__( view_kwargs: dict[str, Any], ) -> None: self.queries = queries - self.view_kwargs = view_kwargs + self.view_kwargs = MappingProxyType(view_kwargs) # Immutable dict self.join_entities = join_entities or [] def __call__( @@ -93,15 +81,13 @@ def __call__( out_dir: Path, overwrite: bool, ) -> None: - # Filters by entity (via string query) + # Filters by entities via string query # First query is for main image, subsequent are for overlays query_dfs = [table.ent.query(q) for q in self.queries] indexed_dfs = [ - ( - df.reset_index(names=f"index_{ii}", drop=False).loc[ - :, [f"index_{ii}"] + self.join_entities - ] - ) + df.reset_index(names=f"index_{ii}", drop=False)[ + [f"index_{ii}"] + self.join_entities + ] for ii, df in enumerate(query_dfs) ] joined = reduce( @@ -110,9 +96,7 @@ def __call__( ), indexed_dfs, ) - indices = [ - joined[col].values for col in joined.columns if col.startswith("index_") - ] + indices = (joined[col].values for col in joined if col.startswith("index_")) for inds in zip(*indices): records = [table.nested.loc[ind] for ind in inds] @@ -123,7 +107,31 @@ def _figure_name(self) -> None: if "figure" in self.view_kwargs: assert self.entities self.entities["extra_entities"]["figure"] = self.view_kwargs["figure"] - del self.view_kwargs["figure"] + # del self.view_kwargs["figure"] + + def _load_image(self, record: pd.Series, log: bool = False) -> nib.Nifti1Image: + """Helper to load image.""" + img_path = Path(record["finfo"]["file_path"]) + if log: + logging.info("Processing %s", img_path) + img = nib.nifti1.load(img_path) + + return noimg.to_iso_ras(img) + + def _load_overlays(self, overlay_records: list[pd.Series]) -> list[nib.Nifti1Image]: + """Helper to load overlays.""" + overlays = [] + for overlay_record in overlay_records: + overlays.append(self._load_image(record=overlay_record)) + return overlays + + def _update_out_path(self, record: pd.Series, out_dir: Path) -> Path: + """Generates the output figure file path.""" + self._figure_name() + existing_entities = BIDSEntities.from_dict(record["ent"]) + out_path = existing_entities.with_update(self.entities).to_path(prefix=out_dir) + out_path.parent.mkdir(exist_ok=True, parents=True) + return out_path def generate( self, @@ -135,30 +143,16 @@ def generate( if not self.view_fn: raise ValueError("View is not provided, unable to create generator.") - # Update figure name if necessary - self._figure_name() - - img_path = Path(records[0]["finfo"]["file_path"]) - logging.info("Processing: %s", img_path) - - img = nib.nifti1.load(img_path) - img = noimg.to_iso_ras(img) - - # Handle overlays - if len(records) > 1: - overlays: list[nib.Nifti1Image] = [] - for overlay_record in records[1:]: - overlay_path = Path(overlay_record["finfo"]["file_path"]) - overlay = nib.nifti1.load(overlay_path) - overlays.append(noimg.to_iso_ras(overlay)) - self.view_kwargs["overlay"] = overlays - - existing_entities = BIDSEntities.from_dict(records[0]["ent"]) - out_path = existing_entities.with_update(self.entities).to_path(prefix=out_dir) - out_path.parent.mkdir(exist_ok=True, parents=True) + img = self._load_image(record=records[0], log=True) + overlays = ( + self._load_overlays(overlay_records=records[1:]) + if len(records) > 1 + else None + ) + out_path = self._update_out_path(records[0], out_dir) if not out_path.exists() or overwrite: logging.info("Generating %s", out_path) - self.view_fn(img, out_path, **self.view_kwargs) + self.view_fn(img, out_path, overlays=overlays, **self.view_kwargs) plt.close("all") diff --git a/tests/unit/niftyone/figures/test_generator.py b/tests/unit/niftyone/figures/test_generator.py index 5ec83a6..0df3e4d 100644 --- a/tests/unit/niftyone/figures/test_generator.py +++ b/tests/unit/niftyone/figures/test_generator.py @@ -1,4 +1,4 @@ -from collections.abc import Generator +from collections.abc import Generator, Mapping from pathlib import Path from typing import Any from unittest.mock import MagicMock @@ -90,7 +90,7 @@ def test_create_generators_view_kwargs( ) assert isinstance(generator, ViewGenerator) assert generator.queries == ["suffix == 'T1w'"] - assert isinstance(generator.view_kwargs, dict) + assert isinstance(generator.view_kwargs, Mapping) def test_create_generators(self, setup_registry: Generator): config = {