From 66dd7462abd583c188f26c7de7a5f28f4ef9c1d7 Mon Sep 17 00:00:00 2001 From: Sergiy Popovych Date: Fri, 19 Jan 2024 16:41:23 +0000 Subject: [PATCH] feat: unit based misalignments --- .gitignore | 1 + tests/unit/augmentations/test_misalign.py | 83 +++++++++++---- zetta_utils/augmentations/misalign.py | 122 +++++++++++++--------- 3 files changed, 133 insertions(+), 73 deletions(-) diff --git a/.gitignore b/.gitignore index da35471d9..e9e3a7d67 100644 --- a/.gitignore +++ b/.gitignore @@ -1,4 +1,5 @@ .vscode +tmp/**/* # Traing data logging **/tmp.json diff --git a/tests/unit/augmentations/test_misalign.py b/tests/unit/augmentations/test_misalign.py index d1b0c87c7..434ca259b 100644 --- a/tests/unit/augmentations/test_misalign.py +++ b/tests/unit/augmentations/test_misalign.py @@ -12,8 +12,10 @@ def test_write_exc(mocker): idx = mocker.MagicMock() data = mocker.MagicMock() - proc = MisalignProcessor(prob=1.0, disp_min=1, disp_max=1) - proc.prepared_disp = mocker.MagicMock() + proc = MisalignProcessor( + prob=1.0, disp_min_in_unit=1, disp_max_in_unit=1, disp_in_unit_must_be_divisible_by=1 + ) + proc.prepared_disp_fraction = mocker.MagicMock() with pytest.raises(RuntimeError): proc.process_index(idx, mode="write") @@ -28,8 +30,14 @@ def test_tensor_process_data_slip_pos(mocker): for z in range(5): data_padded[0, x, y, z] = 100 * z + 10 * y + x - proc = MisalignProcessor(prob=1.0, disp_min=1, disp_max=1, mode="slip") - proc.prepared_disp = (1, 2) + proc = MisalignProcessor( + prob=1.0, + disp_min_in_unit=1, + disp_max_in_unit=1, + disp_in_unit_must_be_divisible_by=1, + mode="slip", + ) + proc.prepared_disp_fraction = (1 / 5, 2 / 5) chosen_z = 3 mocker.patch("random.randint", return_value=chosen_z) result = proc.process_data(data_padded.clone(), mode="read") @@ -49,8 +57,14 @@ def test_tensor_process_data_slip_neg(mocker): for z in range(5): data_padded[0, x, y, z] = 100 * z + 10 * y + x - proc = MisalignProcessor(prob=1.0, disp_min=1, disp_max=1, mode="slip") - 
proc.prepared_disp = (-1, -2) + proc = MisalignProcessor( + prob=1.0, + disp_min_in_unit=1, + disp_max_in_unit=1, + disp_in_unit_must_be_divisible_by=1, + mode="slip", + ) + proc.prepared_disp_fraction = (-1 / 5, -2 / 5) chosen_z = 3 mocker.patch("random.randint", return_value=chosen_z) result = proc.process_data(data_padded.clone(), mode="read") @@ -70,8 +84,14 @@ def test_tensor_process_data_step_pos(mocker): for z in range(5): data_padded[0, x, y, z] = 100 * z + 10 * y + x - proc = MisalignProcessor(prob=1.0, disp_min=1, disp_max=1, mode="step") - proc.prepared_disp = (1, 2) + proc = MisalignProcessor( + prob=1.0, + disp_min_in_unit=1, + disp_max_in_unit=1, + disp_in_unit_must_be_divisible_by=1, + mode="step", + ) + proc.prepared_disp_fraction = (1 / 5, 2 / 5) chosen_z = 3 mocker.patch("random.randint", return_value=chosen_z) result = proc.process_data(data_padded.clone(), mode="read") @@ -98,9 +118,14 @@ def test_dict_process_data_slip_pos(mocker): } keys_to_apply = ["key1", "key2"] proc = MisalignProcessor[dict[str, torch.Tensor]]( - prob=1.0, disp_min=1, disp_max=1, mode="slip", keys_to_apply=keys_to_apply + prob=1.0, + disp_min_in_unit=1, + disp_max_in_unit=1, + disp_in_unit_must_be_divisible_by=1, + mode="slip", + keys_to_apply=keys_to_apply, ) - proc.prepared_disp = (1, 2) + proc.prepared_disp_fraction = (1 / 5, 2 / 5) chosen_z = 3 mocker.patch("random.randint", return_value=chosen_z) @@ -122,30 +147,43 @@ def test_dict_process_data_slip_pos(mocker): def test_dict_process_no_keys_exc(): data = {"key": torch.ones((1, 5, 5, 5))} proc = MisalignProcessor[dict[str, torch.Tensor]]( - prob=1.0, disp_min=1, disp_max=1, mode="slip" + prob=1.0, + disp_min_in_unit=1, + disp_max_in_unit=1, + disp_in_unit_must_be_divisible_by=1, ) - proc.prepared_disp = (1, 1) + proc.prepared_disp_fraction = (1 / 5, 1 / 5) with pytest.raises(ValueError): proc.process_data(data, mode="read") -def test_dict_process_diff_size_exc(): - data = {"key0": torch.ones((1, 5, 5, 5)), "key1": 
torch.ones((1, 4, 4, 4))} +def test_dict_process_diff_size(): + data = {"key0": torch.ones((1, 5, 5, 5)), "key1": torch.ones((1, 10, 10, 5))} proc = MisalignProcessor[dict[str, torch.Tensor]]( - prob=1.0, disp_min=1, disp_max=1, mode="slip", keys_to_apply=["key0", "key1"] + prob=1.0, + disp_min_in_unit=1, + disp_max_in_unit=1, + disp_in_unit_must_be_divisible_by=1, + mode="slip", + keys_to_apply=["key0", "key1"], ) - proc.prepared_disp = (1, 1) - with pytest.raises(ValueError): - proc.process_data(data, mode="read") + proc.prepared_disp_fraction = (1 / 5, 1 / 5) + result = proc.process_data(data, mode="read") + assert result["key0"].shape == (1, 4, 4, 5) + assert result["key1"].shape == (1, 8, 8, 5) def test_process_index_pos(mocker): idx_in = VolumetricIndex( resolution=Vec3D(1, 1, 1), bbox=BBox3D(bounds=((1, 2), (10, 20), (100, 200))) ) - proc = MisalignProcessor(prob=1.0, disp_min=1, disp_max=1) + proc = MisalignProcessor( + prob=1.0, + disp_min_in_unit=1, + disp_max_in_unit=1, + disp_in_unit_must_be_divisible_by=1, + ) mocker.patch("random.choice", return_value=1) - mocker.patch("random.randint", return_value=1) idx_out = proc.process_index(idx_in, mode="read") assert idx_out == VolumetricIndex( resolution=Vec3D(1, 1, 1), bbox=BBox3D(bounds=((0, 2), (9, 20), (100, 200))) @@ -156,9 +194,10 @@ def test_process_index_neg(mocker): idx_in = VolumetricIndex( resolution=Vec3D(1, 1, 1), bbox=BBox3D(bounds=((1, 2), (10, 20), (100, 200))) ) - proc = MisalignProcessor(prob=1.0, disp_min=1, disp_max=1) + proc = MisalignProcessor( + prob=1.0, disp_min_in_unit=1, disp_max_in_unit=1, disp_in_unit_must_be_divisible_by=1 + ) mocker.patch("random.choice", return_value=-1) - mocker.patch("random.randint", return_value=1) idx_out = proc.process_index(idx_in, mode="read") assert idx_out == VolumetricIndex( resolution=Vec3D(1, 1, 1), bbox=BBox3D(bounds=((1, 3), (10, 21), (100, 200))) diff --git a/zetta_utils/augmentations/misalign.py b/zetta_utils/augmentations/misalign.py index 
6e5c25146..3e9abfa51 100644 --- a/zetta_utils/augmentations/misalign.py +++ b/zetta_utils/augmentations/misalign.py @@ -1,5 +1,6 @@ from __future__ import annotations +import math import random from typing import Any, Literal, TypeVar @@ -19,91 +20,111 @@ @typechecked @attrs.mutable class MisalignProcessor(JointIndexDataProcessor[T, VolumetricIndex]): + """ + Minimum and maximum displacements are specified in unit. + The selected displacement magnitude will be rounded down to a multiple + of `disp_in_unit_must_be_divisible_by`, which should correspond to + the lowest resolution of the layers being cropped. + + Data handling is implemented through recording the fraction + of the misalignment relative to the total bbox size. + """ + prob: float - disp_min: int - disp_max: int + disp_min_in_unit: float + disp_max_in_unit: float + disp_in_unit_must_be_divisible_by: float mode: Literal["slip", "step"] = "slip" keys_to_apply: list[str] | None = None - prepared_disp: tuple[int, int] | None = attrs.field(init=False, default=None) + prepared_disp_fraction: tuple[float, float] | None = attrs.field(init=False, default=None) def process_index( self, idx: VolumetricIndex, mode: Literal["read", "write"] ) -> VolumetricIndex: if mode != "read": raise NotImplementedError() - disp_x = random.randint(self.disp_min, self.disp_max) * random.choice([1, -1]) - disp_y = random.randint(self.disp_min, self.disp_max) * random.choice([1, -1]) - self.prepared_disp = (disp_x, disp_y) + disp_x_in_unit_magn = random.uniform(self.disp_min_in_unit, self.disp_max_in_unit) + disp_y_in_unit_magn = random.uniform(self.disp_min_in_unit, self.disp_max_in_unit) + disp_x_in_unit_magn_rounded = math.floor( + disp_x_in_unit_magn / self.disp_in_unit_must_be_divisible_by + ) * self.disp_in_unit_must_be_divisible_by + disp_y_in_unit_magn_rounded = math.floor( + disp_y_in_unit_magn / self.disp_in_unit_must_be_divisible_by + ) * self.disp_in_unit_must_be_divisible_by + disp_x_in_unit = disp_x_in_unit_magn_rounded * random.choice([1, -1]) + disp_y_in_unit = disp_y_in_unit_magn_rounded * 
random.choice([1, -1]) + + disp_x_in_idx_res = disp_x_in_unit / idx.resolution[0] + disp_y_in_idx_res = disp_y_in_unit / idx.resolution[1] + + # Grow the bbox by the displacement on one side per axis; `process_data` crops it back. start_offset = [0, 0, 0] end_offset = [0, 0, 0] - if disp_x > 0: - start_offset[0] = -disp_x + if disp_x_in_idx_res > 0: + start_offset[0] = -disp_x_in_idx_res else: - end_offset[0] = -disp_x - if disp_y > 0: - start_offset[1] = -disp_y + end_offset[0] = -disp_x_in_idx_res + if disp_y_in_idx_res > 0: + start_offset[1] = -disp_y_in_idx_res else: - end_offset[1] = -disp_y + end_offset[1] = -disp_y_in_idx_res idx = idx.translated_start(Vec3D[int](*start_offset)).translated_end( Vec3D[int](*end_offset) ) + self.prepared_disp_fraction = ( + disp_x_in_idx_res / idx.shape[0], + disp_y_in_idx_res / idx.shape[1], + ) return idx - def _get_tensor_shape(self, data: T) -> torch.Size: + def process_data(self, data: T, mode: Literal["read", "write"]) -> T: + if mode != "read": + raise NotImplementedError() + if isinstance(data, torch.Tensor): - result = data.shape + z_size = data.shape[-1] else: - assert isinstance(data, dict) if not self.keys_to_apply: raise ValueError( "`keys_to_apply` must be a non-empty list of springs when " "applying to data of type `dict`" ) - tensor_shapes = [data[k].shape for k in self.keys_to_apply] - if not all(e == tensor_shapes[0] for e in tensor_shapes): - raise ValueError( - "Tensor shapes to be processed with misalignment augmentation " - f"must all have the same shape. 
Got keys: {self.keys_to_apply} " - f"shapes: {tensor_shapes}" - ) - result = tensor_shapes[0] - return result - - def process_data(self, data: T, mode: Literal["read", "write"]) -> T: - assert self.prepared_disp is not None - if mode != "read": - raise NotImplementedError() + z_sizes = [data[k].shape[-1] for k in self.keys_to_apply] + assert all(e == z_sizes[0] for e in z_sizes) + z_size = z_sizes[0] - tensor_shape = self._get_tensor_shape(data) + z_chosen = random.randint(0, z_size - 1) - z_chosen = random.randint(0, tensor_shape[-1] - 1) if self.mode == "slip": z_misal_slice = slice(z_chosen, z_chosen + 1) else: - z_misal_slice = slice(z_chosen, tensor_shape[-1]) - - x_size = tensor_shape[1] - abs(self.prepared_disp[0]) - y_size = tensor_shape[2] - abs(self.prepared_disp[1]) - - x_start = 0 - y_start = 0 - x_start_misal = 0 - y_start_misal = 0 - - if self.prepared_disp[0] > 0: - x_start += self.prepared_disp[0] - else: - x_start_misal += abs(self.prepared_disp[0]) - - if self.prepared_disp[1] > 0: - y_start += self.prepared_disp[1] - else: - y_start_misal += abs(self.prepared_disp[1]) + z_misal_slice = slice(z_chosen, z_size) def _process_tensor(tensor: torch.Tensor) -> torch.Tensor: + assert self.prepared_disp_fraction is not None + x_offset = int(tensor.shape[-3] * self.prepared_disp_fraction[0]) + y_offset = int(tensor.shape[-2] * self.prepared_disp_fraction[1]) + + x_size = tensor.shape[-3] - abs(x_offset) + y_size = tensor.shape[-2] - abs(y_offset) + + x_start = 0 + y_start = 0 + x_start_misal = 0 + y_start_misal = 0 + + if x_offset > 0: + x_start += x_offset + else: + x_start_misal += abs(x_offset) + + if y_offset > 0: + y_start += y_offset + else: + y_start_misal += abs(y_offset) # Shift the data of the misaligned portion tensor[ :, x_start : x_start + x_size, y_start : y_start + y_size, z_misal_slice @@ -113,7 +134,6 @@ def _process_tensor(tensor: torch.Tensor) -> torch.Tensor: y_start_misal : y_start_misal + y_size, z_misal_slice, ].clone() # clone 
necessary bc inplace operation - # Crop the data from the pad result = tensor[:, x_start : x_start + x_size, y_start : y_start + y_size, :] return result