Skip to content

Commit

Permalink
Test time augmentations (#91)
Browse files Browse the repository at this point in the history
* first commit test time agumentations

* the ttas probably belong here.

* adding the rotations

* fixing bug that compose() doesn't catch

* fixes to deal with non-square inputs. the rotation functions were breaking

* adding cropping of output to original

* adding another TTA option for using the product of the stack.

* revert the changes to hcs since we dont need them and removing uncessary methods in engine.py

* formatting

* ruff formatting

* fixing docstring, abtracting tta method, removing inference to cpu.

* fix missing variable in prediction using ttas

* adding the rotations

* ruff
  • Loading branch information
edyoshikun committed Jul 19, 2024
1 parent 059ca38 commit 955da74
Show file tree
Hide file tree
Showing 4 changed files with 109 additions and 5 deletions.
1 change: 0 additions & 1 deletion viscy/data/ctmc_v1.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@


class CTMCv1ValidationDataset(SlidingWindowDataset):

def __len__(self, subsample_rate: int = 30) -> int:
# sample every 30th frame in the videos
return super().__len__() // self.subsample_rate
Expand Down
5 changes: 4 additions & 1 deletion viscy/data/hcs.py
Original file line number Diff line number Diff line change
Expand Up @@ -453,7 +453,10 @@ def _setup_test(self, dataset_settings: dict):
**dataset_settings,
)

def _setup_predict(self, dataset_settings: dict):
def _setup_predict(
self,
dataset_settings: dict,
):
"""Set up the predict stage."""
# track metadata for inverting transform
set_track_meta(True)
Expand Down
84 changes: 81 additions & 3 deletions viscy/light/engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from lightning.pytorch import LightningModule
from matplotlib.pyplot import get_cmap
from monai.optimizers import WarmupCosineSchedule
from monai.transforms import DivisiblePad
from monai.transforms import DivisiblePad, Rotate90
from skimage.exposure import rescale_intensity
from torch import Tensor, nn
from torch.nn import functional as F
Expand Down Expand Up @@ -114,6 +114,10 @@ class VSUNet(LightningModule):
:param bool test_evaluate_cellpose:
evaluate the performance of the CellPose model instead of the trained model
in test stage, defaults to False
:param bool test_time_augmentations:
apply test time augmentations in test stage, defaults to False
:param Literal['mean', 'median', 'product'] tta_type:
type of test time augmentations aggregation, defaults to "mean"
"""

def __init__(
Expand All @@ -131,6 +135,8 @@ def __init__(
test_cellpose_model_path: str = None,
test_cellpose_diameter: float = None,
test_evaluate_cellpose: bool = False,
test_time_augmentations: bool = False,
tta_type: Literal["mean", "median", "product"] = "mean",
) -> None:
super().__init__()
net_class = _UNET_ARCHITECTURE.get(architecture)
Expand Down Expand Up @@ -163,7 +169,10 @@ def __init__(
self.test_cellpose_model_path = test_cellpose_model_path
self.test_cellpose_diameter = test_cellpose_diameter
self.test_evaluate_cellpose = test_evaluate_cellpose
self.test_time_augmentations = test_time_augmentations
self.tta_type = tta_type
self.freeze_encoder = freeze_encoder
self._original_shape_yx = None
if ckpt_path is not None:
self.load_state_dict(
torch.load(ckpt_path)["state_dict"]
Expand Down Expand Up @@ -316,8 +325,50 @@ def _log_segmentation_metrics(
)

def predict_step(self, batch: Sample, batch_idx: int, dataloader_idx: int = 0):
source = self._predict_pad(batch["source"])
return self._predict_pad.inverse(self.forward(source))
source = batch["source"]
if self.test_time_augmentations:
prediction = self.perform_test_time_augmentations(source)
else:
source = self._predict_pad(source)
prediction = self.forward(source)
prediction = self._predict_pad.inverse(prediction)

return prediction

def perform_test_time_augmentations(self, source: Tensor) -> Tensor:
"""Perform test time augmentations on the input source
and aggregate the predictions using the specified method.
:param source: input tensor
:return: aggregated prediction
"""

# Save the yx coords to crop post rotations
self._original_shape_yx = source.shape[-2:]
predictions = []
for i in range(4):
augmented = self._rotate_volume(source, k=i, spatial_axes=(1, 2))
augmented = self._predict_pad(augmented)
augmented_prediction = self.forward(augmented)
de_augmented_prediction = self._predict_pad.inverse(augmented_prediction)
de_augmented_prediction = self._rotate_volume(
de_augmented_prediction, k=4 - i, spatial_axes=(1, 2)
)
de_augmented_prediction = self._crop_to_original(de_augmented_prediction)

# Undo rotation and padding
predictions.append(de_augmented_prediction)

if self.tta_type == "mean":
prediction = torch.stack(predictions).mean(dim=0)
elif self.tta_type == "median":
prediction = torch.stack(predictions).median(dim=0).values
elif self.tta_type == "product":
# Perform multiplication of predictions in logarithmic space for numerical stability adding epsion to avoid log(0) case
log_predictions = torch.stack([torch.log(p + 1e-9) for p in predictions])
log_prediction_sum = log_predictions.sum(dim=0)
prediction = torch.exp(log_prediction_sum)
return prediction

def on_train_epoch_end(self):
self._log_samples("train_samples", self.training_step_outputs)
Expand Down Expand Up @@ -404,6 +455,33 @@ def _log_samples(self, key: str, imgs: Sequence[Sequence[np.ndarray]]):
key, grid, self.current_epoch, dataformats="HWC"
)

def _rotate_volume(self, tensor: Tensor, k: int, spatial_axes: tuple) -> Tensor:
# Padding to ensure square shape
max_dim = max(tensor.shape[-2], tensor.shape[-1])
pad_transform = DivisiblePad((0, 0, max_dim, max_dim))
padded_tensor = pad_transform(tensor)

# Rotation
rotated_tensor = []
rotate = Rotate90(k=k, spatial_axes=spatial_axes)
for b in range(padded_tensor.shape[0]): # iterate over batch
rotated_tensor.append(rotate(padded_tensor[b]))

# Stack the list of tensors back into a single tensor
rotated_tensor = torch.stack(rotated_tensor)
del padded_tensor
# # Cropping to original shape
return rotated_tensor

def _crop_to_original(self, tensor: Tensor) -> Tensor:
original_y, original_x = self._original_shape_yx
pad_y = (tensor.shape[-2] - original_y) // 2
pad_x = (tensor.shape[-1] - original_x) // 2
cropped_tensor = tensor[
..., pad_y : pad_y + original_y, pad_x : pad_x + original_x
]
return cropped_tensor


class FcmaeUNet(VSUNet):
def __init__(self, fit_mask_ratio: float = 0.0, **kwargs):
Expand Down
24 changes: 24 additions & 0 deletions viscy/light/predict_writer.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,30 @@ def _resize_image(image: ImageArray, t_index: int, z_slice: slice) -> None:


def _blend_in(old_stack: NDArray, new_stack: NDArray, z_slice: slice) -> NDArray:
"""
Blend a new stack of images into an old stack over a specified range of slices.
This function blends the `new_stack` of images into the `old_stack` over the range
specified by `z_slice`. The blending is done using a weighted average where the
weights are determined by the position within the range of slices. If the start
of `z_slice` is 0, the function returns the `new_stack` unchanged.
Parameters:
----------
old_stack : NDArray
The original stack of images to be blended.
new_stack : NDArray
The new stack of images to blend into the original stack.
z_slice : slice
A slice object indicating the range of slices over which to perform the blending.
The start and stop attributes of the slice determine the range.
Returns:
-------
NDArray
The blended stack of images. If `z_slice.start` is 0, returns `new_stack` unchanged.
"""

if z_slice.start == 0:
return new_stack
depth = z_slice.stop - z_slice.start
Expand Down

0 comments on commit 955da74

Please sign in to comment.