Skip to content

Commit

Permalink
Move sampler logic to samplers module and add unit tests (#420)
Browse files Browse the repository at this point in the history
  • Loading branch information
nkaenzig authored May 7, 2024
1 parent 15874f5 commit b6c5f52
Show file tree
Hide file tree
Showing 13 changed files with 494 additions and 236 deletions.
236 changes: 0 additions & 236 deletions src/eva/vision/data/wsi/patching/samplers.py

This file was deleted.

8 changes: 8 additions & 0 deletions src/eva/vision/data/wsi/patching/samplers/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
"""Patch Sampler API."""

from eva.vision.data.wsi.patching.samplers.base import ForegroundSampler, Sampler
from eva.vision.data.wsi.patching.samplers.foreground_grid import ForegroundGridSampler
from eva.vision.data.wsi.patching.samplers.grid import GridSampler
from eva.vision.data.wsi.patching.samplers.random import RandomSampler

# Names listed here form the package's star-import surface; each one is imported above.
__all__ = ["Sampler", "ForegroundSampler", "RandomSampler", "GridSampler", "ForegroundGridSampler"]
50 changes: 50 additions & 0 deletions src/eva/vision/data/wsi/patching/samplers/_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
import random
from typing import Tuple

import numpy as np


def set_seed(seed: int) -> None:
    """Seed the global random number generators.

    Python's built-in ``random`` module is seeded first, followed by NumPy's
    global generator, both with the same value.

    Args:
        seed: The seed value to apply.
    """
    for seed_fn in (random.seed, np.random.seed):
        seed_fn(seed)


def get_grid_coords_and_indices(
    layer_shape: Tuple[int, int],
    width: int,
    height: int,
    overlap: Tuple[int, int],
    shuffle: bool = True,
    seed: int = 42,
) -> Tuple[list, list]:
    """Get grid coordinates and indices.

    Args:
        layer_shape: The shape of the layer.
        width: The width of the patches.
        height: The height of the patches.
        overlap: The overlap between patches in the grid.
        shuffle: Whether to shuffle the indices.
        seed: The random seed.

    Returns:
        A tuple ``(x_y, indices)`` where ``x_y`` is the list of ``(x, y)`` grid
        coordinates and ``indices`` is a list of positions into ``x_y``, shuffled
        when ``shuffle`` is set.

    Raises:
        ValueError: If the overlap is not smaller than the patch width / height.
    """
    x_step = width - overlap[0]
    y_step = height - overlap[1]
    # A negative stride would make `range` silently yield an empty grid and a zero
    # stride fails with an unrelated-looking message, so reject both explicitly.
    if x_step <= 0 or y_step <= 0:
        raise ValueError("The overlap must be smaller than the patch width / height.")

    x_range = range(0, layer_shape[0] - width + 1, x_step)
    y_range = range(0, layer_shape[1] - height + 1, y_step)
    x_y = [(x, y) for x in x_range for y in y_range]

    indices = list(range(len(x_y)))
    if shuffle:
        # NOTE: this reseeds the global `random` / NumPy generators as a side effect.
        set_seed(seed)
        np.random.shuffle(indices)
    return x_y, indices


def validate_dimensions(width: int, height: int, layer_shape: Tuple[int, int]) -> None:
    """Ensure a patch of the requested size fits inside the layer.

    Args:
        width: The width of the patches.
        height: The height of the patches.
        layer_shape: The shape of the layer.

    Raises:
        ValueError: If the width exceeds ``layer_shape[0]`` or the height
            exceeds ``layer_shape[1]``.
    """
    patch_fits = width <= layer_shape[0] and height <= layer_shape[1]
    if patch_fits:
        return
    raise ValueError("The width / height cannot be bigger than the layer shape.")
48 changes: 48 additions & 0 deletions src/eva/vision/data/wsi/patching/samplers/base.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
"""Base classes for samplers."""

import abc
from typing import Generator, Tuple

from eva.vision.data.wsi.patching.mask import Mask


class Sampler(abc.ABC):
    """Base class for samplers."""

    @abc.abstractmethod
    def sample(
        self,
        width: int,
        height: int,
        layer_shape: Tuple[int, int],
        mask: Mask | None = None,
    ) -> Generator[Tuple[int, int], None, None]:
        """Sample patch coordinates.

        Args:
            width: The width of the patches.
            height: The height of the patches.
            layer_shape: The shape of the layer.
            mask: Optional `Mask`, only required for samplers with foreground
                filtering. NOTE(review): earlier wording described this as a tuple
                of the mask array and its scaling factor with respect to the
                provided layer_shape; confirm against the `Mask` definition.

        Returns:
            A generator producing sampled patch coordinates.
        """


class ForegroundSampler(Sampler):
    """Base class for samplers with foreground filtering capabilities."""

    @abc.abstractmethod
    def is_foreground(
        self,
        mask: Mask,
        x: int,
        y: int,
        width: int,
        height: int,
        min_foreground_ratio: float,
    ) -> bool:
        """Check if a patch contains sufficient foreground.

        Args:
            mask: The mask used to decide what counts as foreground.
            x: Presumably the x coordinate of the patch; the reference corner is
                not specified here (TODO confirm in implementations).
            y: Presumably the y coordinate of the patch (TODO confirm).
            width: The width of the patch.
            height: The height of the patch.
            min_foreground_ratio: Presumably the minimum fraction of the patch
                that must be foreground; whether the bound is inclusive is left
                to implementations (TODO confirm).

        Returns:
            Whether the patch is considered to contain sufficient foreground.
        """
Loading

0 comments on commit b6c5f52

Please sign in to comment.