Skip to content

Commit

Permalink
feat: unit based misalignments
Browse files Browse the repository at this point in the history
  • Loading branch information
supersergiy committed Feb 1, 2024
1 parent 29d83ae commit 66dd746
Show file tree
Hide file tree
Showing 3 changed files with 133 additions and 73 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
.vscode
tmp/**/*

# Training data logging
**/tmp.json
Expand Down
83 changes: 61 additions & 22 deletions tests/unit/augmentations/test_misalign.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,10 @@
def test_write_exc(mocker):
idx = mocker.MagicMock()
data = mocker.MagicMock()
proc = MisalignProcessor(prob=1.0, disp_min=1, disp_max=1)
proc.prepared_disp = mocker.MagicMock()
proc = MisalignProcessor(
prob=1.0, disp_min_in_unit=1, disp_max_in_unit=1, disp_in_unit_must_be_divisible_by=1
)
proc.prepared_disp_fraction = mocker.MagicMock()
with pytest.raises(RuntimeError):
proc.process_index(idx, mode="write")

Expand All @@ -28,8 +30,14 @@ def test_tensor_process_data_slip_pos(mocker):
for z in range(5):
data_padded[0, x, y, z] = 100 * z + 10 * y + x

proc = MisalignProcessor(prob=1.0, disp_min=1, disp_max=1, mode="slip")
proc.prepared_disp = (1, 2)
proc = MisalignProcessor(
prob=1.0,
disp_min_in_unit=1,
disp_max_in_unit=1,
disp_in_unit_must_be_divisible_by=1,
mode="slip",
)
proc.prepared_disp_fraction = (1 / 5, 2 / 5)
chosen_z = 3
mocker.patch("random.randint", return_value=chosen_z)
result = proc.process_data(data_padded.clone(), mode="read")
Expand All @@ -49,8 +57,14 @@ def test_tensor_process_data_slip_neg(mocker):
for z in range(5):
data_padded[0, x, y, z] = 100 * z + 10 * y + x

proc = MisalignProcessor(prob=1.0, disp_min=1, disp_max=1, mode="slip")
proc.prepared_disp = (-1, -2)
proc = MisalignProcessor(
prob=1.0,
disp_min_in_unit=1,
disp_max_in_unit=1,
disp_in_unit_must_be_divisible_by=1,
mode="slip",
)
proc.prepared_disp_fraction = (-1 / 5, -2 / 5)
chosen_z = 3
mocker.patch("random.randint", return_value=chosen_z)
result = proc.process_data(data_padded.clone(), mode="read")
Expand All @@ -70,8 +84,14 @@ def test_tensor_process_data_step_pos(mocker):
for z in range(5):
data_padded[0, x, y, z] = 100 * z + 10 * y + x

proc = MisalignProcessor(prob=1.0, disp_min=1, disp_max=1, mode="step")
proc.prepared_disp = (1, 2)
proc = MisalignProcessor(
prob=1.0,
disp_min_in_unit=1,
disp_max_in_unit=1,
disp_in_unit_must_be_divisible_by=1,
mode="step",
)
proc.prepared_disp_fraction = (1 / 5, 2 / 5)
chosen_z = 3
mocker.patch("random.randint", return_value=chosen_z)
result = proc.process_data(data_padded.clone(), mode="read")
Expand All @@ -98,9 +118,14 @@ def test_dict_process_data_slip_pos(mocker):
}
keys_to_apply = ["key1", "key2"]
proc = MisalignProcessor[dict[str, torch.Tensor]](
prob=1.0, disp_min=1, disp_max=1, mode="slip", keys_to_apply=keys_to_apply
prob=1.0,
disp_min_in_unit=1,
disp_max_in_unit=1,
disp_in_unit_must_be_divisible_by=1,
mode="slip",
keys_to_apply=keys_to_apply,
)
proc.prepared_disp = (1, 2)
proc.prepared_disp_fraction = (1 / 5, 2 / 5)
chosen_z = 3
mocker.patch("random.randint", return_value=chosen_z)

Expand All @@ -122,30 +147,43 @@ def test_dict_process_data_slip_pos(mocker):
def test_dict_process_no_keys_exc():
data = {"key": torch.ones((1, 5, 5, 5))}
proc = MisalignProcessor[dict[str, torch.Tensor]](
prob=1.0, disp_min=1, disp_max=1, mode="slip"
prob=1.0,
disp_min_in_unit=1,
disp_max_in_unit=1,
disp_in_unit_must_be_divisible_by=1,
)
proc.prepared_disp = (1, 1)
proc.prepared_disp_fraction = (1 / 5, 1 / 5)
with pytest.raises(ValueError):
proc.process_data(data, mode="read")


def test_dict_process_diff_size_exc():
data = {"key0": torch.ones((1, 5, 5, 5)), "key1": torch.ones((1, 4, 4, 4))}
def test_dict_process_diff_size():
data = {"key0": torch.ones((1, 5, 5, 5)), "key1": torch.ones((1, 10, 10, 5))}
proc = MisalignProcessor[dict[str, torch.Tensor]](
prob=1.0, disp_min=1, disp_max=1, mode="slip", keys_to_apply=["key0", "key1"]
prob=1.0,
disp_min_in_unit=1,
disp_max_in_unit=1,
disp_in_unit_must_be_divisible_by=1,
mode="slip",
keys_to_apply=["key0", "key1"],
)
proc.prepared_disp = (1, 1)
with pytest.raises(ValueError):
proc.process_data(data, mode="read")
proc.prepared_disp_fraction = (1 / 5, 1 / 5)
result = proc.process_data(data, mode="read")
assert result["key0"].shape == (1, 4, 4, 5)
assert result["key1"].shape == (1, 8, 8, 5)


def test_process_index_pos(mocker):
idx_in = VolumetricIndex(
resolution=Vec3D(1, 1, 1), bbox=BBox3D(bounds=((1, 2), (10, 20), (100, 200)))
)
proc = MisalignProcessor(prob=1.0, disp_min=1, disp_max=1)
proc = MisalignProcessor(
prob=1.0,
disp_min_in_unit=1,
disp_max_in_unit=1,
disp_in_unit_must_be_divisible_by=1,
)
mocker.patch("random.choice", return_value=1)
mocker.patch("random.randint", return_value=1)
idx_out = proc.process_index(idx_in, mode="read")
assert idx_out == VolumetricIndex(
resolution=Vec3D(1, 1, 1), bbox=BBox3D(bounds=((0, 2), (9, 20), (100, 200)))
Expand All @@ -156,9 +194,10 @@ def test_process_index_neg(mocker):
idx_in = VolumetricIndex(
resolution=Vec3D(1, 1, 1), bbox=BBox3D(bounds=((1, 2), (10, 20), (100, 200)))
)
proc = MisalignProcessor(prob=1.0, disp_min=1, disp_max=1)
proc = MisalignProcessor(
prob=1.0, disp_min_in_unit=1, disp_max_in_unit=1, disp_in_unit_must_be_divisible_by=1
)
mocker.patch("random.choice", return_value=-1)
mocker.patch("random.randint", return_value=1)
idx_out = proc.process_index(idx_in, mode="read")
assert idx_out == VolumetricIndex(
resolution=Vec3D(1, 1, 1), bbox=BBox3D(bounds=((1, 3), (10, 21), (100, 200)))
Expand Down
122 changes: 71 additions & 51 deletions zetta_utils/augmentations/misalign.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from __future__ import annotations

import math
import random
from typing import Any, Literal, TypeVar

Expand All @@ -19,91 +20,111 @@
@typechecked
@attrs.mutable
class MisalignProcessor(JointIndexDataProcessor[T, VolumetricIndex]):
"""
Minimum and maximum displacements are specified in units.
The selected displacement is rounded so that it is always divisible
by `disp_in_unit_must_be_divisible_by`, which should correspond to
the lowest resolution of the layers being cropped.
Data handling is implemented by recording the misalignment as a
fraction of the total bbox size.
"""

prob: float
disp_min: int
disp_max: int
disp_min_in_unit: float
disp_max_in_unit: float
disp_in_unit_must_be_divisible_by: float
mode: Literal["slip", "step"] = "slip"
keys_to_apply: list[str] | None = None

prepared_disp: tuple[int, int] | None = attrs.field(init=False, default=None)
prepared_disp_fraction: tuple[float, float] | None = attrs.field(init=False, default=None)

def process_index(
self, idx: VolumetricIndex, mode: Literal["read", "write"]
) -> VolumetricIndex:
if mode != "read":
raise NotImplementedError()
disp_x = random.randint(self.disp_min, self.disp_max) * random.choice([1, -1])
disp_y = random.randint(self.disp_min, self.disp_max) * random.choice([1, -1])
self.prepared_disp = (disp_x, disp_y)
disp_x_in_unit_magn = random.uniform(self.disp_min_in_unit, self.disp_max_in_unit)
disp_y_in_unit_magn = random.uniform(self.disp_min_in_unit, self.disp_max_in_unit)
disp_x_in_unit_magn_rounded = math.floor(
disp_x_in_unit_magn / self.disp_in_unit_must_be_divisible_by
)
disp_y_in_unit_magn_rounded = math.floor(
disp_y_in_unit_magn / self.disp_in_unit_must_be_divisible_by
)
disp_x_in_unit = disp_x_in_unit_magn_rounded * random.choice([1, -1])
disp_y_in_unit = disp_y_in_unit_magn_rounded * random.choice([1, -1])

disp_x_in_idx_res = disp_x_in_unit / idx.resolution[0]
disp_y_in_idx_res = disp_y_in_unit / idx.resolution[1]

# self.prepared_disp = (disp_x, disp_y)

start_offset = [0, 0, 0]
end_offset = [0, 0, 0]
if disp_x > 0:
start_offset[0] = -disp_x
if disp_x_in_idx_res > 0:
start_offset[0] = -disp_x_in_idx_res
else:
end_offset[0] = -disp_x
if disp_y > 0:
start_offset[1] = -disp_y
end_offset[0] = -disp_x_in_idx_res
if disp_y_in_idx_res > 0:
start_offset[1] = -disp_y_in_idx_res
else:
end_offset[1] = -disp_y
end_offset[1] = -disp_y_in_idx_res

idx = idx.translated_start(Vec3D[int](*start_offset)).translated_end(
Vec3D[int](*end_offset)
)
self.prepared_disp_fraction = (
disp_x_in_idx_res / idx.shape[0],
disp_y_in_idx_res / idx.shape[0],
)
return idx

def _get_tensor_shape(self, data: T) -> torch.Size:
def process_data(self, data: T, mode: Literal["read", "write"]) -> T:
if mode != "read":
raise NotImplementedError()

if isinstance(data, torch.Tensor):
result = data.shape
z_size = data.shape[-1]
else:
assert isinstance(data, dict)
if not self.keys_to_apply:
raise ValueError(
"`keys_to_apply` must be a non-empty list of springs when "
"applying to data of type `dict`"
)
tensor_shapes = [data[k].shape for k in self.keys_to_apply]
if not all(e == tensor_shapes[0] for e in tensor_shapes):
raise ValueError(
"Tensor shapes to be processed with misalignment augmentation "
f"must all have the same shape. Got keys: {self.keys_to_apply} "
f"shapes: {tensor_shapes}"
)
result = tensor_shapes[0]
return result

def process_data(self, data: T, mode: Literal["read", "write"]) -> T:
assert self.prepared_disp is not None
if mode != "read":
raise NotImplementedError()
z_sizes = [data[k].shape[-1] for k in self.keys_to_apply]
assert all(e == z_sizes[0] for e in z_sizes)
z_size = z_sizes[0]

tensor_shape = self._get_tensor_shape(data)
z_chosen = random.randint(0, z_size - 1)

z_chosen = random.randint(0, tensor_shape[-1] - 1)
if self.mode == "slip":
z_misal_slice = slice(z_chosen, z_chosen + 1)
else:
z_misal_slice = slice(z_chosen, tensor_shape[-1])

x_size = tensor_shape[1] - abs(self.prepared_disp[0])
y_size = tensor_shape[2] - abs(self.prepared_disp[1])

x_start = 0
y_start = 0
x_start_misal = 0
y_start_misal = 0

if self.prepared_disp[0] > 0:
x_start += self.prepared_disp[0]
else:
x_start_misal += abs(self.prepared_disp[0])

if self.prepared_disp[1] > 0:
y_start += self.prepared_disp[1]
else:
y_start_misal += abs(self.prepared_disp[1])
z_misal_slice = slice(z_chosen, z_size)

def _process_tensor(tensor: torch.Tensor) -> torch.Tensor:
assert self.prepared_disp_fraction is not None
x_offset = int(tensor.shape[-3] * self.prepared_disp_fraction[0])
y_offset = int(tensor.shape[-2] * self.prepared_disp_fraction[1])

x_size = tensor.shape[-3] - abs(x_offset)
y_size = tensor.shape[-2] - abs(y_offset)

x_start = 0
y_start = 0
x_start_misal = 0
y_start_misal = 0

if x_offset > 0:
x_start += x_offset
else:
x_start_misal += abs(x_offset)

if y_offset > 0:
y_start += y_offset
else:
y_start_misal += abs(y_offset)
# Shift the data of the misaligned portion
tensor[
:, x_start : x_start + x_size, y_start : y_start + y_size, z_misal_slice
Expand All @@ -113,7 +134,6 @@ def _process_tensor(tensor: torch.Tensor) -> torch.Tensor:
y_start_misal : y_start_misal + y_size,
z_misal_slice,
].clone() # clone necessary bc inplace operation

# Crop the data from the pad
result = tensor[:, x_start : x_start + x_size, y_start : y_start + y_size, :]
return result
Expand Down

0 comments on commit 66dd746

Please sign in to comment.