diff --git a/neutorch/cli/train.py b/neutorch/cli/train.py index 092ff3a..a2eb2fc 100644 --- a/neutorch/cli/train.py +++ b/neutorch/cli/train.py @@ -123,18 +123,6 @@ def train(seed: int, training_split_ratio: float, patch_size: tuple, log_tensor(writer, 'train/prediction', predict, iter_idx) log_tensor(writer, 'train/target', target, iter_idx) - # target2d, _ = torch.max(target, dim=2, keepdim=False) - # slices = torch.cat((image[:, :, 32, :, :], predict[:, :, 32, :, :], target2d)) - # image_path = os.path.expanduser('~/Downloads/patches.png') - # print('save a batch of patches to ', image_path) - # torchvision.utils.save_image( - # slices, - # image_path, - # nrow=1, - # normalize=True, - # scale_each=True, - # ) - if iter_idx % validation_interval == 0: fname = os.path.join(output_dir, f'model_{iter_idx}.chkpt') print(f'save model to {fname}') diff --git a/neutorch/dataset/tbar.py b/neutorch/dataset/tbar.py index 29ac8b1..7cd1bad 100644 --- a/neutorch/dataset/tbar.py +++ b/neutorch/dataset/tbar.py @@ -170,10 +170,12 @@ def _prepare_transform(self): AdjustBrightness(), AdjustContrast(), Gamma(), - Noise(), - GaussianBlur2D(), + OneOf([ + Noise(), + GaussianBlur2D(), + ]), BlackBox(), - Perspective(), + Perspective2D(), # RotateScale(probability=1.), DropSection(), Flip(), @@ -200,7 +202,7 @@ def _prepare_transform(self): for n in range(10000): ping = time() patch = dataset.random_training_patch - print(f'generating a patch takes {int(time()-ping)} seconds.') + print(f'generating a patch takes {round(time()-ping, 3)} seconds.') image = patch.image label = patch.label with h5py.File('/tmp/image.h5', 'w') as file: diff --git a/neutorch/dataset/transform.py b/neutorch/dataset/transform.py index 487ba2d..d083171 100644 --- a/neutorch/dataset/transform.py +++ b/neutorch/dataset/transform.py @@ -1,6 +1,7 @@ from abc import ABC, abstractmethod import random from functools import lru_cache +from copy import deepcopy import numpy as np @@ -124,6 +125,24 @@ def __call__(self, patch: Patch): patch.label = patch.label.copy() +class OneOf(AbstractTransform): + def __init__(self, transforms: list, + probability: float = DEFAULT_PROBABILITY) -> None: + super().__init__(probability=probability) + assert len(transforms) > 1 + self.transforms = transforms + + shrink_size = np.zeros((6,), dtype=np.int64) + for transform in transforms: + if isinstance(transform, SpatialTransform): + shrink_size += np.asarray(transform.shrink_size) + self.shrink_size = tuple(x for x in shrink_size) + + def transform(self, patch: Patch): + # select one of the transforms + transform = random.choice(self.transforms) + transform(patch) + class DropSection(SpatialTransform): def __init__(self, probability: float = DEFAULT_PROBABILITY): @@ -392,7 +411,7 @@ def shrink_size(self): return (self.max_displacement,) * 6 -class Perspective(SpatialTransform): +class Perspective2D(SpatialTransform): def __init__(self, probability: float=DEFAULT_PROBABILITY, corner_ratio: float=0.2): """Warp image using Perspective transform @@ -445,11 +464,22 @@ def _transform2d(self, arr: np.ndarray, interpolation: int): random.randint(sy-round(sy*corner_ratio/2), sy-1), random.randint(sx-round(sx*corner_ratio/2), sx-1) ] - pts1 = np.asarray([ + pts1 = [ upper_left_point, upper_right_point, lower_left_point, - lower_right_point], dtype=np.float32) + lower_right_point + ] + # push the list order to get rotation effect + # for example, push one position will rotate about 90 degrees + # push_index = random.randint(0, 3) + # if push_index > 0: + # tmp = deepcopy(pts1) + # pts1[push_index:] = tmp[:4-push_index] + # # the pushed out elements should be reversed + # pts1[:push_index] = tmp[4-push_index:][::-1] + + pts1 = np.asarray(pts1, dtype=np.float32) pts2 =np.float32([[0, 0], [0, sx], [sy, 0], [sy, sx]]) M = cv2.getPerspectiveTransform(pts1, pts2) @@ -457,33 +487,34 @@ def _transform2d(self, arr: np.ndarray, interpolation: int): return dst -class RotateScale(SpatialTransform): - def __init__(self, probability: float=DEFAULT_PROBABILITY, - max_scaling: float=1.3): - super().__init__(probability=probability) - self.max_scaling = max_scaling - - def transform(self, patch: Patch): - # because we do not know the rotation angle - # we should apply the shrinking first - patch.apply_delayed_shrink_size() - - # if the rotation is close to diagnal, for example 45 degree - # the label could be outside the volume and be black! - # angle = random.choice([0, 90, 180, -90, -180]) + random.randint(-5, 5) - angle = random.randint(0, 180) - scale = random.uniform(1.1, self.max_scaling) - center = patch.center[-2:] - mat = cv2.getRotationMatrix2D( center, angle, scale ) - # breakpoint() - for batch in range(patch.shape[0]): - for channel in range(patch.shape[1]): - for z in range(patch.shape[2]): - patch.image[batch, channel, z, ...] = cv2.warpAffine( - patch.image[batch, channel, z, ...], - mat, patch.shape[-2:], flags=cv2.INTER_LINEAR - ) - patch.label[batch, channel, z, ...] = cv2.warpAffine( - patch.label[batch, channel, z, ...], - mat, patch.shape[-2:], flags=cv2.INTER_NEAREST - ) +# class RotateScale(SpatialTransform): +# def __init__(self, probability: float=DEFAULT_PROBABILITY, +# max_scaling: float=1.3): +# super().__init__(probability=probability) +# raise NotImplementedError('this augmentation is not working correctly yet. The image and label could have patchy effect.We are not sure why.') +# self.max_scaling = max_scaling + +# def transform(self, patch: Patch): +# # because we do not know the rotation angle +# # we should apply the shrinking first +# patch.apply_delayed_shrink_size() + +# # if the rotation is close to diagnal, for example 45 degree +# # the label could be outside the volume and be black! +# # angle = random.choice([0, 90, 180, -90, -180]) + random.randint(-5, 5) +# angle = random.randint(0, 180) +# scale = random.uniform(1.1, self.max_scaling) +# center = patch.center[-2:] +# mat = cv2.getRotationMatrix2D( center, angle, scale ) + +# for batch in range(patch.shape[0]): +# for channel in range(patch.shape[1]): +# for z in range(patch.shape[2]): +# patch.image[batch, channel, z, ...] = cv2.warpAffine( +# patch.image[batch, channel, z, ...], +# mat, patch.shape[-2:], flags=cv2.INTER_LINEAR +# ) +# patch.label[batch, channel, z, ...] = cv2.warpAffine( +# patch.label[batch, channel, z, ...], +# mat, patch.shape[-2:], flags=cv2.INTER_NEAREST +# )