Skip to content

Commit

Permalink
add OneOf transform to compose transforms more flexibly
Browse files Browse the repository at this point in the history
  • Loading branch information
xiuliren committed May 7, 2021
1 parent 9743934 commit cc2679b
Show file tree
Hide file tree
Showing 3 changed files with 70 additions and 49 deletions.
12 changes: 0 additions & 12 deletions neutorch/cli/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,18 +123,6 @@ def train(seed: int, training_split_ratio: float, patch_size: tuple,
log_tensor(writer, 'train/prediction', predict, iter_idx)
log_tensor(writer, 'train/target', target, iter_idx)

# target2d, _ = torch.max(target, dim=2, keepdim=False)
# slices = torch.cat((image[:, :, 32, :, :], predict[:, :, 32, :, :], target2d))
# image_path = os.path.expanduser('~/Downloads/patches.png')
# print('save a batch of patches to ', image_path)
# torchvision.utils.save_image(
# slices,
# image_path,
# nrow=1,
# normalize=True,
# scale_each=True,
# )

if iter_idx % validation_interval == 0:
fname = os.path.join(output_dir, f'model_{iter_idx}.chkpt')
print(f'save model to {fname}')
Expand Down
10 changes: 6 additions & 4 deletions neutorch/dataset/tbar.py
Original file line number Diff line number Diff line change
Expand Up @@ -170,10 +170,12 @@ def _prepare_transform(self):
AdjustBrightness(),
AdjustContrast(),
Gamma(),
Noise(),
GaussianBlur2D(),
OneOf([
Noise(),
GaussianBlur2D(),
]),
BlackBox(),
Perspective(),
Perspective2D(),
# RotateScale(probability=1.),
DropSection(),
Flip(),
Expand All @@ -200,7 +202,7 @@ def _prepare_transform(self):
for n in range(10000):
ping = time()
patch = dataset.random_training_patch
print(f'generating a patch takes {int(time()-ping)} seconds.')
print(f'generating a patch takes {round(time()-ping, 3)} seconds.')
image = patch.image
label = patch.label
with h5py.File('/tmp/image.h5', 'w') as file:
Expand Down
97 changes: 64 additions & 33 deletions neutorch/dataset/transform.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from abc import ABC, abstractmethod
import random
from functools import lru_cache
from copy import deepcopy

import numpy as np

Expand Down Expand Up @@ -124,6 +125,24 @@ def __call__(self, patch: Patch):
patch.label = patch.label.copy()


class OneOf(AbstractTransform):
    """Apply one randomly selected transform out of a list of candidates.

    Each time the transform fires, a single candidate is drawn uniformly
    with ``random.choice`` and called on the patch.
    """

    def __init__(self, transforms: list,
                 probability: float = DEFAULT_PROBABILITY) -> None:
        """Store the candidates and precompute the combined shrink size.

        Args:
            transforms: candidate transforms; at least two are required.
            probability: forwarded unchanged to ``AbstractTransform``.

        Raises:
            ValueError: if fewer than two transforms are given.
        """
        super().__init__(probability=probability)
        # Explicit check instead of ``assert`` so that the validation is
        # still performed when Python runs with -O.
        if len(transforms) < 2:
            raise ValueError(
                'OneOf needs at least two transforms to choose from, '
                f'got {len(transforms)}.'
            )
        self.transforms = transforms

        # Accumulate the six-element shrink size of every spatial candidate.
        # NOTE(review): the sizes are summed although only one candidate runs
        # per call; confirm whether an element-wise maximum was intended.
        total_shrink = np.zeros((6,), dtype=np.int64)
        for candidate in transforms:
            if isinstance(candidate, SpatialTransform):
                total_shrink += np.asarray(candidate.shrink_size)
        self.shrink_size = tuple(total_shrink)

    def transform(self, patch: Patch):
        """Pick one candidate at random and call it on ``patch``."""
        # The candidate is invoked through its own ``__call__``.
        # Presumably its own probability gate still applies - TODO confirm.
        chosen = random.choice(self.transforms)
        chosen(patch)


class DropSection(SpatialTransform):
def __init__(self, probability: float = DEFAULT_PROBABILITY):
Expand Down Expand Up @@ -392,7 +411,7 @@ def shrink_size(self):
return (self.max_displacement,) * 6


class Perspective(SpatialTransform):
class Perspective2D(SpatialTransform):
def __init__(self, probability: float=DEFAULT_PROBABILITY,
corner_ratio: float=0.2):
"""Warp image using Perspective transform
Expand Down Expand Up @@ -445,45 +464,57 @@ def _transform2d(self, arr: np.ndarray, interpolation: int):
random.randint(sy-round(sy*corner_ratio/2), sy-1),
random.randint(sx-round(sx*corner_ratio/2), sx-1)
]
pts1 = np.asarray([
pts1 = [
upper_left_point,
upper_right_point,
lower_left_point,
lower_right_point], dtype=np.float32)
lower_right_point
]
# push the list order to get rotation effect
# for example, push one position will rotate about 90 degrees
# push_index = random.randint(0, 3)
# if push_index > 0:
# tmp = deepcopy(pts1)
# pts1[push_index:] = tmp[:4-push_index]
# # the pushed out elements should be reversed
# pts1[:push_index] = tmp[4-push_index:][::-1]

pts1 = np.asarray(pts1, dtype=np.float32)

pts2 =np.float32([[0, 0], [0, sx], [sy, 0], [sy, sx]])
M = cv2.getPerspectiveTransform(pts1, pts2)
dst = cv2.warpPerspective(arr, M, (sy, sx), flags=interpolation)
return dst


class RotateScale(SpatialTransform):
    """Rotate and scale every 2D section of a patch with one affine matrix.

    A single random angle and scale are drawn per call and applied to all
    sections of both the image and the label.
    """

    def __init__(self, probability: float = DEFAULT_PROBABILITY,
                 max_scaling: float = 1.3):
        """Keep the upper bound of the random scaling factor.

        Args:
            probability: forwarded unchanged to ``SpatialTransform``.
            max_scaling: upper bound passed to ``random.uniform`` for the
                scaling factor; the lower bound is fixed at 1.1.
        """
        super().__init__(probability=probability)
        self.max_scaling = max_scaling

    def transform(self, patch: Patch):
        """Warp the image and label of ``patch`` in place."""
        # The rotation angle is not known in advance, so the delayed
        # shrinking has to be applied before warping.
        patch.apply_delayed_shrink_size()

        # If the rotation is close to diagonal (for example 45 degrees) the
        # label could end up outside the volume and turn black.
        # Alternative that was considered:
        # random.choice([0, 90, 180, -90, -180]) + random.randint(-5, 5)
        angle = random.randint(0, 180)
        scale = random.uniform(1.1, self.max_scaling)
        # NOTE(review): OpenCV expects the centre as (x, y) and dsize as
        # (width, height); the last two axes are passed here as-is - confirm
        # the axis order, especially for non-square sections.
        affine = cv2.getRotationMatrix2D(patch.center[-2:], angle, scale)

        # Visit (batch, channel, z) in the same order as three nested loops.
        for section in np.ndindex(*patch.shape[:3]):
            # Linear interpolation for intensities.
            patch.image[section] = cv2.warpAffine(
                patch.image[section],
                affine, patch.shape[-2:], flags=cv2.INTER_LINEAR
            )
            # Nearest neighbour for the label.
            patch.label[section] = cv2.warpAffine(
                patch.label[section],
                affine, patch.shape[-2:], flags=cv2.INTER_NEAREST
            )
# class RotateScale(SpatialTransform):
# def __init__(self, probability: float=DEFAULT_PROBABILITY,
# max_scaling: float=1.3):
# super().__init__(probability=probability)
# raise NotImplementedError('this augmentation is not working correctly yet. The image and label could have patchy effect. We are not sure why.')
# self.max_scaling = max_scaling

# def transform(self, patch: Patch):
# # because we do not know the rotation angle
# # we should apply the shrinking first
# patch.apply_delayed_shrink_size()

# # if the rotation is close to diagonal, for example 45 degree
# # the label could be outside the volume and be black!
# # angle = random.choice([0, 90, 180, -90, -180]) + random.randint(-5, 5)
# angle = random.randint(0, 180)
# scale = random.uniform(1.1, self.max_scaling)
# center = patch.center[-2:]
# mat = cv2.getRotationMatrix2D( center, angle, scale )

# for batch in range(patch.shape[0]):
# for channel in range(patch.shape[1]):
# for z in range(patch.shape[2]):
# patch.image[batch, channel, z, ...] = cv2.warpAffine(
# patch.image[batch, channel, z, ...],
# mat, patch.shape[-2:], flags=cv2.INTER_LINEAR
# )
# patch.label[batch, channel, z, ...] = cv2.warpAffine(
# patch.label[batch, channel, z, ...],
# mat, patch.shape[-2:], flags=cv2.INTER_NEAREST
# )

0 comments on commit cc2679b

Please sign in to comment.