-
Notifications
You must be signed in to change notification settings - Fork 7
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Move sampler logic to
samplers
module and add unit tests (#420)
- Loading branch information
Showing
13 changed files
with
494 additions
and
236 deletions.
There are no files selected for viewing
This file was deleted.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,8 @@ | ||
"""Patch Sampler API.""" | ||
|
||
from eva.vision.data.wsi.patching.samplers.base import ForegroundSampler, Sampler | ||
from eva.vision.data.wsi.patching.samplers.foreground_grid import ForegroundGridSampler | ||
from eva.vision.data.wsi.patching.samplers.grid import GridSampler | ||
from eva.vision.data.wsi.patching.samplers.random import RandomSampler | ||
|
||
__all__ = ["Sampler", "ForegroundSampler", "RandomSampler", "GridSampler", "ForegroundGridSampler"] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,50 @@ | ||
import random | ||
from typing import Tuple | ||
|
||
import numpy as np | ||
|
||
|
||
def set_seed(seed: int) -> None: | ||
random.seed(seed) | ||
np.random.seed(seed) | ||
|
||
|
||
def get_grid_coords_and_indices( | ||
layer_shape: Tuple[int, int], | ||
width: int, | ||
height: int, | ||
overlap: Tuple[int, int], | ||
shuffle: bool = True, | ||
seed: int = 42, | ||
): | ||
"""Get grid coordinates and indices. | ||
Args: | ||
layer_shape: The shape of the layer. | ||
width: The width of the patches. | ||
height: The height of the patches. | ||
overlap: The overlap between patches in the grid. | ||
shuffle: Whether to shuffle the indices. | ||
seed: The random seed. | ||
""" | ||
x_range = range(0, layer_shape[0] - width + 1, width - overlap[0]) | ||
y_range = range(0, layer_shape[1] - height + 1, height - overlap[1]) | ||
x_y = [(x, y) for x in x_range for y in y_range] | ||
|
||
indices = list(range(len(x_y))) | ||
if shuffle: | ||
set_seed(seed) | ||
np.random.shuffle(indices) | ||
return x_y, indices | ||
|
||
|
||
def validate_dimensions(width: int, height: int, layer_shape: Tuple[int, int]) -> None: | ||
"""Checks if the width / height is bigger than the layer shape. | ||
Args: | ||
width: The width of the patches. | ||
height: The height of the patches. | ||
layer_shape: The shape of the layer. | ||
""" | ||
if width > layer_shape[0] or height > layer_shape[1]: | ||
raise ValueError("The width / height cannot be bigger than the layer shape.") |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,48 @@ | ||
"""Base classes for samplers.""" | ||
|
||
import abc | ||
from typing import Generator, Tuple | ||
|
||
from eva.vision.data.wsi.patching.mask import Mask | ||
|
||
|
||
class Sampler(abc.ABC): | ||
"""Base class for samplers.""" | ||
|
||
@abc.abstractmethod | ||
def sample( | ||
self, | ||
width: int, | ||
height: int, | ||
layer_shape: Tuple[int, int], | ||
mask: Mask | None = None, | ||
) -> Generator[Tuple[int, int], None, None]: | ||
"""Sample patche coordinates. | ||
Args: | ||
width: The width of the patches. | ||
height: The height of the patches. | ||
layer_shape: The shape of the layer. | ||
mask: Tuple containing the mask array and the scaling factor with respect to the | ||
provided layer_shape. Optional, only required for samplers with foreground | ||
filtering. | ||
Returns: | ||
A generator producing sampled patch coordinates. | ||
""" | ||
|
||
|
||
class ForegroundSampler(Sampler): | ||
"""Base class for samplers with foreground filtering capabilities.""" | ||
|
||
@abc.abstractmethod | ||
def is_foreground( | ||
self, | ||
mask: Mask, | ||
x: int, | ||
y: int, | ||
width: int, | ||
height: int, | ||
min_foreground_ratio: float, | ||
) -> bool: | ||
"""Check if a patch contains sufficient foreground.""" |
Oops, something went wrong.