Skip to content

Commit

Permalink
Add lower bound for wsi resolution level during mask generation (#412)
Browse files Browse the repository at this point in the history
  • Loading branch information
nkaenzig authored May 7, 2024
1 parent ecaa5ee commit 15874f5
Show file tree
Hide file tree
Showing 6 changed files with 155 additions and 79 deletions.
9 changes: 5 additions & 4 deletions src/eva/vision/data/wsi/backends/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ def level_downsamples(self) -> Sequence[float]:
@property
@abc.abstractmethod
def mpp(self) -> float:
"""Microns per pixel at the highest resolution."""
"""Microns per pixel at the highest resolution (level 0)."""

@abc.abstractmethod
def read_region(
Expand All @@ -47,9 +47,10 @@ def read_region(
"""Reads and returns image data for a specified region and zoom level.
Args:
location: Top-left corner (x, y) to start reading.
size: Region size as (width, height), relative to <location>.
level: Zoom level, with 0 being the highest resolution.
location: Top-left corner (x, y) to start reading at level 0.
level: WSI level to read from.
size: Region size as (width, height) in pixels at the selected read level.
Remember to scale the size correctly.
"""

def get_closest_level(self, target_mpp: float) -> int:
Expand Down
27 changes: 20 additions & 7 deletions src/eva/vision/data/wsi/backends/openslide.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,11 +12,11 @@
class WsiOpenslide(base.Wsi):
"""Class for loading data from WSI files using the OpenSlide library."""

_wsi: openslide.OpenSlide | openslide.ImageSlide
_wsi: openslide.OpenSlide

@override
def open_file(self, file_path: str) -> openslide.OpenSlide | openslide.ImageSlide:
return openslide.open_slide(file_path)
def open_file(self, file_path: str) -> openslide.OpenSlide:
return openslide.OpenSlide(file_path)

@property
@override
Expand All @@ -40,8 +40,21 @@ def mpp(self) -> float:
def read_region(
self, location: Tuple[int, int], level: int, size: Tuple[int, int]
) -> np.ndarray:
x_max, y_max = self._wsi.level_dimensions[level]
if location[0] + size[0] > x_max or location[1] + size[1] > y_max:
x_max, y_max = self.level_dimensions[0]

x_scale = x_max / self._wsi.level_dimensions[level][0]
y_scale = y_max / self._wsi.level_dimensions[level][1]

if (
int(location[0] + x_scale * size[0]) > x_max
or int(location[1] + y_scale * size[1]) > y_max
):
raise ValueError(f"Out of bounds region: {location}, {size}, {level}")
data = self._wsi.read_region(location, level, size)
return np.array(data.convert("RGB"))

data = np.array(self._wsi.read_region(location, level, size))

if data.shape[2] == 4:
# Change color to white where the alpha channel is 0
data[data[:, :, 3] == 0] = 255

return data[:, :, :3]
38 changes: 21 additions & 17 deletions src/eva/vision/data/wsi/patching/coordinates.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

from eva.vision.data.wsi import backends
from eva.vision.data.wsi.patching import samplers
from eva.vision.utils.mask import get_mask
from eva.vision.data.wsi.patching.mask import Mask, get_mask, get_mask_level

LRU_CACHE_SIZE = 32

Expand All @@ -16,16 +16,18 @@ class PatchCoordinates:
"""A class to store coordinates of patches from a whole-slide image.
Args:
x_y: A list of (x, y) coordinates of the patches.
width: The width of the patches, in pixels (refers to x-dim).
height: The height of the patches, in pixels (refers to y-dim).
level_idx: The level index of the patches.
x_y: A list of (x, y) coordinates of the patches (refer to level 0).
width: The width of the patches, in pixels (refers to level_idx).
height: The height of the patches, in pixels (refers to level_idx).
level_idx: The level index at which to extract the patches.
mask: The foreground mask of the wsi.
"""

x_y: List[Tuple[int, int]]
width: int
height: int
level_idx: int
mask: Mask | None = None

@classmethod
def from_file(
Expand All @@ -50,24 +52,26 @@ def from_file(
backend: The backend to use for reading the whole-slide images.
"""
wsi = backends.wsi_backend(backend)(wsi_path)
level_idx = wsi.get_closest_level(target_mpp)
level_mpp = wsi.mpp * wsi.level_downsamples[level_idx]
mpp_ratio = target_mpp / level_mpp
scaled_width, scaled_height = int(mpp_ratio * width), int(mpp_ratio * height)

# Sample patch coordinates at level 0
mpp_ratio_0 = target_mpp / wsi.mpp
sample_args = {
"width": scaled_width,
"height": scaled_height,
"layer_shape": wsi.level_dimensions[level_idx],
"width": int(mpp_ratio_0 * width),
"height": int(mpp_ratio_0 * height),
"layer_shape": wsi.level_dimensions[0],
}
if isinstance(sampler, samplers.ForegroundSampler):
sample_args["mask"] = get_mask(wsi, level_idx)
mask_level_idx = get_mask_level(wsi, width, height, target_mpp)
sample_args["mask"] = get_mask(wsi, mask_level_idx)

x_y = list(sampler.sample(**sample_args))

x_y = []
for x, y in sampler.sample(**sample_args):
x_y.append((x, y))
# Scale dimensions to level that is closest to the target_mpp
level_idx = wsi.get_closest_level(target_mpp)
mpp_ratio = target_mpp / (wsi.mpp * wsi.level_downsamples[level_idx])
scaled_width, scaled_height = int(mpp_ratio * width), int(mpp_ratio * height)

return cls(x_y, scaled_width, scaled_height, level_idx)
return cls(x_y, scaled_width, scaled_height, level_idx, sample_args.get("mask"))


@functools.lru_cache(LRU_CACHE_SIZE)
Expand Down
98 changes: 98 additions & 0 deletions src/eva/vision/data/wsi/patching/mask.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,98 @@
"""Functions for extracting foreground masks."""

import dataclasses
from typing import Tuple

import cv2
import numpy as np

from eva.vision.data.wsi.backends.base import Wsi


@dataclasses.dataclass
class Mask:
"""A class to store the mask of a whole-slide image."""

mask_array: np.ndarray
"""Binary mask array where 1s represent the foreground and 0s represent the background."""

mask_level_idx: int
"""WSI level index at which the mask_array was extracted."""

scale_factors: Tuple[float, float]
"""Factors to scale x/y coordinates from mask_level_idx to level 0."""


def get_mask(
wsi: Wsi,
mask_level_idx: int,
kernel_size: Tuple[int, int] = (7, 7),
gray_threshold: int = 220,
fill_holes: bool = False,
) -> Mask:
"""Extracts a binary mask from an image.
Args:
wsi: The WSI object.
mask_level_idx: The level index of the WSI at which we want to extract the mask.
kernel_size: The size of the kernel for morphological operations.
gray_threshold: The threshold for the gray scale image.
fill_holes: Whether to fill holes in the mask.
"""
image = wsi.read_region((0, 0), mask_level_idx, wsi.level_dimensions[mask_level_idx])

kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, kernel_size)
gray = np.array(cv2.cvtColor(image, cv2.COLOR_RGB2GRAY), dtype=np.uint8)
mask_array = np.where(gray < gray_threshold, 1, 0).astype(np.uint8)

if fill_holes:
mask_array = cv2.dilate(mask_array, kernel, iterations=1)
contour, _ = cv2.findContours(mask_array, cv2.RETR_CCOMP, cv2.CHAIN_APPROX_SIMPLE)
for cnt in contour:
cv2.drawContours(mask_array, [cnt], 0, (1,), -1)

scale_factors = (
wsi.level_dimensions[0][0] / wsi.level_dimensions[mask_level_idx][0],
wsi.level_dimensions[0][1] / wsi.level_dimensions[mask_level_idx][1],
)

return Mask(mask_array=mask_array, mask_level_idx=mask_level_idx, scale_factors=scale_factors)


def get_mask_level(
wsi: Wsi,
width: int,
height: int,
target_mpp: float,
min_mask_patch_pixels: int = 3 * 3,
) -> int:
"""For performance reasons, we generate the mask at the lowest resolution level possible.
However, if minimum resolution level has too few pixels, the patches scaled to that level will
be too small or even collapse to a single pixel. This function allows to find the lowest
resolution level that yields mask patches with at least `min_mask_patch_pixels` pixels.
Args:
wsi: The WSI object.
width: The width of the patches to be extracted, in pixels (at target_mpp).
height: The height of the patches to be extracted, in pixels.
target_mpp: The target microns per pixel (mpp) for the patches.
min_mask_patch_pixels: The minimum number of pixels required for the mask patches.
Mask patch refers to width / height at target_mpp scaled down to the WSI level
at which the mask is generated.
"""
level_mpps = wsi.mpp * np.array(wsi.level_downsamples)
mask_level_idx = None

for level_idx, level_mpp in reversed(list(enumerate(level_mpps))):
mpp_ratio = target_mpp / level_mpp
scaled_width, scaled_height = int(mpp_ratio * width), int(mpp_ratio * height)

if scaled_width * scaled_height >= min_mask_patch_pixels:
mask_level_idx = level_idx
break

if mask_level_idx is None:
raise ValueError("No level with the specified minimum number of patch pixels available.")

return mask_level_idx
21 changes: 11 additions & 10 deletions src/eva/vision/data/wsi/patching/samplers.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@

import numpy as np

from eva.vision.data.wsi.patching.mask import Mask


class Sampler(abc.ABC):
"""Base class for samplers."""
Expand All @@ -16,7 +18,7 @@ def sample(
width: int,
height: int,
layer_shape: Tuple[int, int],
mask: Tuple[np.ndarray, float] | None = None,
mask: Mask | None = None,
) -> Generator[Tuple[int, int], None, None]:
"""Sample patche coordinates.
Expand All @@ -39,7 +41,7 @@ class ForegroundSampler(Sampler):
@abc.abstractmethod
def is_foreground(
self,
mask: Tuple[np.ndarray, float],
mask: Mask,
x: int,
y: int,
width: int,
Expand Down Expand Up @@ -150,7 +152,7 @@ def sample(
width: int,
height: int,
layer_shape: Tuple[int, int],
mask: Tuple[np.ndarray, float],
mask: Mask,
):
"""Sample patches from a grid containing foreground.
Expand All @@ -174,7 +176,7 @@ def sample(

def is_foreground(
self,
mask: Tuple[np.ndarray, float],
mask: Mask,
x: int,
y: int,
width: int,
Expand All @@ -191,14 +193,13 @@ def is_foreground(
height: The height of the patch.
min_foreground_ratio: The minimum amount of foreground in the patch.
"""
mask_array, mask_scale_factor = mask
x_, y_, width_, height_ = self._scale_coords(mask_scale_factor, x, y, width, height)
patch_mask = mask_array[y_ : y_ + height_, x_ : x_ + width_]
# TODO: look into warning "RuntimeWarning: invalid value encountered in divide"
x_, y_ = self._scale_coords(x, y, mask.scale_factors)
width_, height_ = self._scale_coords(width, height, mask.scale_factors)
patch_mask = mask.mask_array[y_ : y_ + height_, x_ : x_ + width_]
return patch_mask.sum() / patch_mask.size > min_foreground_ratio

def _scale_coords(self, scale_factor, *coords):
return tuple(int(coord * scale_factor) for coord in coords)
def _scale_coords(self, x: int, y: int, scale_factors: Tuple[float, float]) -> Tuple[int, int]:
return int(x / scale_factors[0]), int(y / scale_factors[1])


def _get_grid_coords_and_indices(
Expand Down
41 changes: 0 additions & 41 deletions src/eva/vision/utils/mask.py

This file was deleted.

0 comments on commit 15874f5

Please sign in to comment.