Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Move sampler logic to samplers module and add unit tests #420

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
236 changes: 0 additions & 236 deletions src/eva/vision/data/wsi/patching/samplers.py

This file was deleted.

8 changes: 8 additions & 0 deletions src/eva/vision/data/wsi/patching/samplers/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
"""Patch Sampler API."""

from eva.vision.data.wsi.patching.samplers.base import ForegroundSampler, Sampler
from eva.vision.data.wsi.patching.samplers.foreground_grid import ForegroundGridSampler
from eva.vision.data.wsi.patching.samplers.grid import GridSampler
from eva.vision.data.wsi.patching.samplers.random import RandomSampler

# Public names re-exported by the samplers package.
__all__ = [
    "Sampler",
    "ForegroundSampler",
    "RandomSampler",
    "GridSampler",
    "ForegroundGridSampler",
]
50 changes: 50 additions & 0 deletions src/eva/vision/data/wsi/patching/samplers/_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
import random
from typing import Tuple

import numpy as np


def set_seed(seed: int) -> None:
    """Seed the global random number generators.

    Both the standard-library `random` module and NumPy's legacy global
    generator (`np.random`) are seeded with the same value.

    Args:
        seed: The random seed.
    """
    for seed_fn in (random.seed, np.random.seed):
        seed_fn(seed)


def get_grid_coords_and_indices(
    layer_shape: Tuple[int, int],
    width: int,
    height: int,
    overlap: Tuple[int, int],
    shuffle: bool = True,
    seed: int = 42,
) -> Tuple[list[Tuple[int, int]], list[int]]:
    """Get grid coordinates and indices.

    The grid starts at (0, 0) and advances by `width - overlap[0]` along the
    first axis and `height - overlap[1]` along the second one. Only patches
    that fit entirely inside the layer are kept.

    Args:
        layer_shape: The shape of the layer, indexed as (width, height).
        width: The width of the patches.
        height: The height of the patches.
        overlap: The overlap between patches in the grid, as (x, y).
        shuffle: Whether to shuffle the indices.
        seed: The random seed. Only used when `shuffle` is True.

    Returns:
        A tuple `(x_y, indices)` where `x_y` is the list of `(x, y)` patch
        coordinates (x varies slowest) and `indices` is a list of positions
        into `x_y`, in shuffled order if `shuffle` is True and ascending
        order otherwise.

    Raises:
        ValueError: If the overlap is not smaller than the patch size along
            either axis, i.e. the resulting grid stride would be zero or
            negative.
    """
    stride_x = width - overlap[0]
    stride_y = height - overlap[1]
    # A zero stride makes `range` fail with an unhelpful message and a negative
    # one silently yields an empty grid, so reject both explicitly.
    if stride_x <= 0 or stride_y <= 0:
        raise ValueError(
            "The overlap must be smaller than the patch width / height, "
            f"got overlap={overlap} for patches of size ({width}, {height})."
        )

    x_range = range(0, layer_shape[0] - width + 1, stride_x)
    y_range = range(0, layer_shape[1] - height + 1, stride_y)
    x_y = [(x, y) for x in x_range for y in y_range]

    indices = list(range(len(x_y)))
    if shuffle:
        # Re-seeding the global generators makes the shuffled order reproducible.
        set_seed(seed)
        np.random.shuffle(indices)
    return x_y, indices


def validate_dimensions(width: int, height: int, layer_shape: Tuple[int, int]) -> None:
    """Ensure that a patch of the given size fits inside the layer.

    Args:
        width: The width of the patches.
        height: The height of the patches.
        layer_shape: The shape of the layer, indexed as (width, height).

    Raises:
        ValueError: If the patch is wider or taller than the layer.
    """
    exceeds_layer = width > layer_shape[0] or height > layer_shape[1]
    if not exceeds_layer:
        return
    raise ValueError("The width / height cannot be bigger than the layer shape.")
48 changes: 48 additions & 0 deletions src/eva/vision/data/wsi/patching/samplers/base.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
"""Base classes for samplers."""

import abc
from typing import Generator, Tuple

from eva.vision.data.wsi.patching.mask import Mask


class Sampler(abc.ABC):
    """Base class for samplers.

    Defines the interface shared by all patch samplers: concrete subclasses
    must implement `sample`, which is declared abstract here.
    """

    @abc.abstractmethod
    def sample(
        self,
        width: int,
        height: int,
        layer_shape: Tuple[int, int],
        mask: Mask | None = None,
    ) -> Generator[Tuple[int, int], None, None]:
        """Sample patch coordinates.

        Args:
            width: The width of the patches.
            height: The height of the patches.
            layer_shape: The shape of the layer.
            mask: Optional `Mask` used for foreground filtering. According to the
                original documentation it carries the mask array and the scaling
                factor with respect to the provided layer_shape (TODO: confirm
                against the `Mask` definition). Only required for samplers with
                foreground filtering.

        Returns:
            A generator producing sampled patch coordinates as integer pairs
            (presumably `(x, y)` -- confirm against the concrete samplers).
        """


class ForegroundSampler(Sampler):
    """Base class for samplers with foreground filtering capabilities.

    Extends `Sampler` with an additional abstract method, `is_foreground`,
    that concrete subclasses must implement.
    """

    @abc.abstractmethod
    def is_foreground(
        self,
        mask: Mask,
        x: int,
        y: int,
        width: int,
        height: int,
        min_foreground_ratio: float,
    ) -> bool:
        """Check if a patch contains sufficient foreground.

        Args:
            mask: The `Mask` to check the patch against.
            x: The x coordinate of the patch (presumably its top-left corner --
                confirm against the concrete implementations).
            y: The y coordinate of the patch (same caveat as for `x`).
            width: The width of the patch.
            height: The height of the patch.
            min_foreground_ratio: The foreground threshold for the patch
                (presumably the minimum fraction of foreground area -- confirm
                against the concrete implementations).

        Returns:
            Whether the patch contains sufficient foreground.
        """
Loading