Skip to content

Commit

Permalink
Pass overlays without mutating view_kwargs
Browse files Browse the repository at this point in the history
Also restruture the ViewGenerator class to call various internal
functions (for long term maintenance).
  • Loading branch information
kaitj committed Sep 18, 2024
1 parent cc58dbb commit e16c33d
Show file tree
Hide file tree
Showing 5 changed files with 57 additions and 71 deletions.
2 changes: 0 additions & 2 deletions src/niclips/figures/bold.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,6 @@ def cluster_timeseries(
def carpet_plot(
bold: nib.Nifti1Image,
out: StrPath | None = None,
*,
label: nib.Nifti1Image | None = None,
n_voxels: int = 2000,
seed: int = 42,
Expand Down Expand Up @@ -157,7 +156,6 @@ def carpet_plot(
def bold_mean_std(
bold: nib.Nifti1Image,
out: StrPath | None = None,
*,
std_vmax_ratio: float = 0.1,
figure: str | None = None,
**kwargs,
Expand Down
3 changes: 0 additions & 3 deletions src/niclips/figures/dwi.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,6 @@ def _get_bval_indices(bvals: np.ndarray, bval: int) -> np.ndarray:
def visualize_qspace(
dwi: nib.Nifti1Image,
out: StrPath | None = None,
*,
thresh: int = 10,
figure: str | None = None,
) -> FuncAnimation:
Expand Down Expand Up @@ -93,7 +92,6 @@ def _rotate(angle: int) -> None:
def three_view_per_shell(
dwi: nib.Nifti1Image,
out: StrPath | None = None,
*,
thresh: int = 10,
replace_str: str = "bval",
) -> list[nib.Nifti1Image]:
Expand Down Expand Up @@ -125,7 +123,6 @@ def three_view_per_shell(
def signal_per_volume(
dwi: nib.Nifti1Image,
out: StrPath | None = None,
*,
fontsize: int = 14,
figure: str | None = None,
) -> None:
Expand Down
3 changes: 0 additions & 3 deletions src/niclips/figures/multi_view.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,6 @@ def multi_view_frame(
def three_view_frame(
img: NiftiLike,
out: StrPath | None = None,
*,
coord: tuple[float, float, float] | None = None,
idx: int | None = 0,
vmin: float | None = None,
Expand Down Expand Up @@ -138,7 +137,6 @@ def three_view_frame(
def three_view_video(
img: nib.Nifti1Image,
out: StrPath,
*,
coord: tuple[float, float, float] | None = None,
vmin: float | None = None,
vmax: float | None = None,
Expand Down Expand Up @@ -181,7 +179,6 @@ def three_view_video(
def slice_video(
img: NiftiLike,
out: StrPath,
*,
axis: int = 2,
idx: int | None = 0,
vmin: float | None = None,
Expand Down
116 changes: 55 additions & 61 deletions src/niftyone/figures/generator.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,16 @@
"""Generator classes for different views."""
"""Figure generation class for different views."""

import logging
from abc import ABC
from functools import reduce
from pathlib import Path
from types import MappingProxyType
from typing import Any, Callable, Generic, TypeVar

import matplotlib.pyplot as plt
import nibabel as nib
import pandas as pd
from bids2table import BIDSEntities, BIDSTable
from matplotlib.figure import Figure
from PIL.Image import Image

import niclips.image as noimg

Expand All @@ -37,45 +36,34 @@ def create_generator(
queries: list[str],
) -> "ViewGenerator":
"""Function to create generator."""
if view_kwargs is None:
view_kwargs = {}
view_kwargs = view_kwargs or {}

try:
generator_cls = generator_registry[view]
generator_instance = generator_cls(queries, join_entities, view_kwargs)
return generator_instance
return generator_cls(queries, join_entities, view_kwargs)
except KeyError:
msg = f"Generator for '{view}' for not found in registry."
raise KeyError(msg)
raise KeyError(f"Generator for '{view}' for not found in registry.")


def create_generators(config: dict[str, Any]) -> list["ViewGenerator"]:
"""Create selected generators dynamically from config with default settings."""
generators: list["ViewGenerator"] = []

for group in config.get("figures", {}).values():
queries = group.get("queries", [])
join_entities = group.get("join_entities", ["sub", "ses"])
views = group.get("views", {})

for view, view_kwargs in views.items():
generators.append(
create_generator(
view=view,
view_kwargs=view_kwargs,
join_entities=join_entities,
queries=queries,
)
)

return generators
return [
create_generator(
view=view,
view_kwargs=view_kwargs,
join_entities=group.get("join_entities", ["sub", "ses"]),
queries=group.get("queries", []),
)
for group in config.get("figures", {}).values()
for view, view_kwargs in group.get("views", {}).items()
]


class ViewGenerator(ABC, Generic[T]):
"""Base view generator class."""

entities: dict[str, Any] | None = None
view_fn: Callable[[nib.Nifti1Image, Path], Image | Figure | None] | None = None
view_fn: Callable | None = None

def __init__(
self,
Expand All @@ -84,7 +72,7 @@ def __init__(
view_kwargs: dict[str, Any],
) -> None:
self.queries = queries
self.view_kwargs = view_kwargs
self.view_kwargs = MappingProxyType(view_kwargs) # Immutable dict
self.join_entities = join_entities or []

def __call__(
Expand All @@ -93,15 +81,13 @@ def __call__(
out_dir: Path,
overwrite: bool,
) -> None:
# Filters by entity (via string query)
# Filters by entities via string query
# First query is for main image, subsequent are for overlays
query_dfs = [table.ent.query(q) for q in self.queries]
indexed_dfs = [
(
df.reset_index(names=f"index_{ii}", drop=False).loc[
:, [f"index_{ii}"] + self.join_entities
]
)
df.reset_index(names=f"index_{ii}", drop=False)[
[f"index_{ii}"] + self.join_entities
]
for ii, df in enumerate(query_dfs)
]
joined = reduce(
Expand All @@ -110,9 +96,7 @@ def __call__(
),
indexed_dfs,
)
indices = [
joined[col].values for col in joined.columns if col.startswith("index_")
]
indices = (joined[col].values for col in joined if col.startswith("index_"))

for inds in zip(*indices):
records = [table.nested.loc[ind] for ind in inds]
Expand All @@ -123,7 +107,31 @@ def _figure_name(self) -> None:
if "figure" in self.view_kwargs:
assert self.entities
self.entities["extra_entities"]["figure"] = self.view_kwargs["figure"]
del self.view_kwargs["figure"]
# del self.view_kwargs["figure"]

def _load_image(self, record: pd.Series, log: bool = False) -> nib.Nifti1Image:
"""Helper to load image."""
img_path = Path(record["finfo"]["file_path"])
if log:
logging.info("Processing %s", img_path)
img = nib.nifti1.load(img_path)

return noimg.to_iso_ras(img)

def _load_overlays(self, overlay_records: list[pd.Series]) -> list[nib.Nifti1Image]:
"""Helper to load overlays."""
overlays = []
for overlay_record in overlay_records:
overlays.append(self._load_image(record=overlay_record))
return overlays

def _update_out_path(self, record: pd.Series, out_dir: Path) -> Path:
"""Generates the output figure file path."""
self._figure_name()
existing_entities = BIDSEntities.from_dict(record["ent"])
out_path = existing_entities.with_update(self.entities).to_path(prefix=out_dir)
out_path.parent.mkdir(exist_ok=True, parents=True)
return out_path

def generate(
self,
Expand All @@ -135,30 +143,16 @@ def generate(
if not self.view_fn:
raise ValueError("View is not provided, unable to create generator.")

# Update figure name if necessary
self._figure_name()

img_path = Path(records[0]["finfo"]["file_path"])
logging.info("Processing: %s", img_path)

img = nib.nifti1.load(img_path)
img = noimg.to_iso_ras(img)

# Handle overlays
if len(records) > 1:
overlays: list[nib.Nifti1Image] = []
for overlay_record in records[1:]:
overlay_path = Path(overlay_record["finfo"]["file_path"])
overlay = nib.nifti1.load(overlay_path)
overlays.append(noimg.to_iso_ras(overlay))
self.view_kwargs["overlay"] = overlays

existing_entities = BIDSEntities.from_dict(records[0]["ent"])
out_path = existing_entities.with_update(self.entities).to_path(prefix=out_dir)
out_path.parent.mkdir(exist_ok=True, parents=True)
img = self._load_image(record=records[0], log=True)
overlays = (
self._load_overlays(overlay_records=records[1:])
if len(records) > 1
else None
)
out_path = self._update_out_path(records[0], out_dir)

if not out_path.exists() or overwrite:
logging.info("Generating %s", out_path)
self.view_fn(img, out_path, **self.view_kwargs)
self.view_fn(img, out_path, overlays=overlays, **self.view_kwargs)

plt.close("all")
4 changes: 2 additions & 2 deletions tests/unit/niftyone/figures/test_generator.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from collections.abc import Generator
from collections.abc import Generator, Mapping
from pathlib import Path
from typing import Any
from unittest.mock import MagicMock
Expand Down Expand Up @@ -90,7 +90,7 @@ def test_create_generators_view_kwargs(
)
assert isinstance(generator, ViewGenerator)
assert generator.queries == ["suffix == 'T1w'"]
assert isinstance(generator.view_kwargs, dict)
assert isinstance(generator.view_kwargs, Mapping)

def test_create_generators(self, setup_registry: Generator):
config = {
Expand Down

0 comments on commit e16c33d

Please sign in to comment.