Skip to content

Commit

Permalink
negative patch oversampling 추가 (issue #30)
Browse files Browse the repository at this point in the history
  • Loading branch information
4pygmalion committed Jul 19, 2024
1 parent fdfda5b commit d7575e4
Show file tree
Hide file tree
Showing 4 changed files with 216 additions and 23 deletions.
102 changes: 102 additions & 0 deletions cosas/datasets.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
import random
from typing import Tuple, List


import tqdm
import torch
import numpy as np
import albumentations as A
Expand Down Expand Up @@ -138,8 +141,107 @@ def __getitem__(self, idx) -> Tuple[torch.Tensor, torch.Tensor]:
return image, mask


class PreAugDataset(Dataset):
    """Image/mask dataset that oversamples synthetic negative patches up front.

    At construction time every source image gets its positive pixels
    (``mask == 1``) painted white; a random crop of the result is passed
    through ``A.SafeRotate`` and pasted at a random position onto a black
    canvas of the original size.  Each canvas is stored with an all-zero mask
    and appended after the original samples.

    Args:
        images: Sequence of images; each is unpacked as a 3-D array
            (rows, cols, channels).
        masks: Sequence of masks aligned with ``images``.
        transform: Optional albumentations pipeline applied in
            ``__getitem__``.
        device: Stored on the instance; not used by this class itself.
        multiple: Number of pre-augmentation passes over the source samples.
    """

    def __init__(
        self,
        images,
        masks,
        # Quoted so the union is not evaluated at definition time (the ``|``
        # syntax on classes needs Python 3.10+).
        transform: "A.Compose | None" = None,
        device="cuda",
        multiple: int = 2,
    ):
        self.images = images
        self.masks = masks
        self.transform = transform
        self.device = device
        self._pre_augmentation(multiple)

    def _sample_random_size(self, image_shape) -> Tuple[int, int]:
        """Sample a crop size strictly smaller than the image.

        Args:
            image_shape: ``(rows, cols, channels)`` of the source image.

        Returns:
            ``(crop_rows, crop_cols)``, each at least half of (rounded down)
            and strictly less than the matching image dimension.  The order
            matches the positional ``(height, width)`` arguments of
            ``A.RandomCrop``.

        Raises:
            ValueError: If a spatial dimension is smaller than 2, in which
                case no strictly smaller positive crop exists.
        """
        rows, cols, _ = image_shape
        if rows < 2 or cols < 2:
            # Rejection sampling below could never succeed and would spin
            # forever.
            raise ValueError(
                f"image is too small to crop: shape={tuple(image_shape)}"
            )

        crop_rows = int(random.uniform(rows / 2, rows))
        crop_cols = int(random.uniform(cols / 2, cols))
        # random.uniform may return its upper endpoint, so re-draw until both
        # sizes are strictly inside (0, dimension).
        while not (0 < crop_rows < rows and 0 < crop_cols < cols):
            crop_rows = int(random.uniform(rows / 2, rows))
            crop_cols = int(random.uniform(cols / 2, cols))

        return crop_rows, crop_cols

    def _patch_to_original(self, patch, original_shape):
        """Paste ``patch`` at a random position on a black canvas.

        Args:
            patch: 3-D array (rows, cols, channels) to paste.
            original_shape: Shape of the canvas to create.

        Returns:
            A ``uint8`` array of shape ``original_shape`` that is zero
            everywhere except where ``patch`` was pasted.

        Raises:
            ValueError: If ``patch`` does not fit inside ``original_shape``.
        """
        canvas = np.zeros(original_shape, dtype=np.uint8)
        canvas_h, canvas_w, _ = canvas.shape
        patch_h, patch_w, _ = patch.shape
        if patch_h > canvas_h or patch_w > canvas_w:
            raise ValueError(
                f"patch {patch.shape} does not fit into {tuple(original_shape)}"
            )

        # Pick a random top-left corner.  randint's upper bound is exclusive,
        # so +1 lets the patch touch the right/bottom edge and also covers the
        # case where the patch is exactly as large as the canvas.
        x = np.random.randint(0, canvas_w - patch_w + 1)
        y = np.random.randint(0, canvas_h - patch_h + 1)

        # Paste the patch onto the background.
        canvas[y : y + patch_h, x : x + patch_w, :] = patch

        return canvas

    def _pre_transform(self, image, mask, max_attempts: int = 100):
        """Build one synthetic negative sample from ``image``.

        Args:
            image: Source image, 3-D array (rows, cols, channels).
            mask: Mask aligned with ``image``; pixels equal to 1 are
                overwritten with 255 before cropping.
            max_attempts: Upper bound on crop attempts before giving up.

        Returns:
            ``(patched_image, zero_mask)`` where ``zero_mask`` is a ``uint8``
            array of zeros shaped like ``mask``.

        Raises:
            RuntimeError: If every attempt produced a single-valued crop.
        """
        # Replace the positive region with white.
        negative_image = image.copy()
        negative_image[np.where(mask == 1)] = 255

        crop_size = self._sample_random_size(image.shape)
        transform = A.Compose([A.RandomCrop(*crop_size), A.SafeRotate()])

        # Reject crops holding a single pixel value.  The retry count is
        # bounded so a uniform source image cannot hang dataset construction.
        for _ in range(max_attempts):
            cropped_image = transform(image=negative_image, mask=mask)["image"]
            if len(np.unique(cropped_image)) != 1:
                break
        else:
            raise RuntimeError(
                f"no non-uniform crop found after {max_attempts} attempts"
            )

        patched_image = self._patch_to_original(cropped_image, image.shape)
        return patched_image, np.zeros_like(mask, dtype=np.uint8)

    def _pre_augmentation(self, multiple: int = 2):
        """Append synthetic negative samples to ``self.images``/``self.masks``.

        Args:
            multiple: Number of passes over the source samples; each pass
                adds at most one synthetic sample per source image.
        """
        aug_images = []
        aug_masks = []
        for _ in range(multiple):
            for image, mask in tqdm.tqdm(
                zip(self.images, self.masks),
                desc="Pre-augmentation",
                total=len(self.images),
            ):
                # A fully positive mask leaves no negative region to sample.
                if np.all(mask == 1):
                    continue

                try:
                    patched_image, patched_mask = self._pre_transform(image, mask)
                except RuntimeError:
                    # No usable (non-uniform) crop for this image; skip it.
                    continue

                # An all-black canvas carries no information.
                if patched_image.sum() == 0:
                    continue

                aug_images.append(patched_image)
                aug_masks.append(patched_mask)

        # list(...) guarantees concatenation; with ndarray inputs ``+`` would
        # be element-wise addition instead.
        self.images = list(self.images) + aug_images
        self.masks = list(self.masks) + aug_masks

    def __len__(self):
        """Return the number of samples, synthetic ones included."""
        return len(self.images)

    def __getitem__(self, idx) -> Tuple[torch.Tensor, torch.Tensor]:
        """Return the ``(image, mask)`` pair at ``idx`` as tensors.

        ``self.transform``, when set, is applied first.  Anything that is
        still a NumPy array afterwards is copied and converted to a tensor.
        """
        image = self.images[idx]
        mask = self.masks[idx]

        # self.transform: np.ndarray -> Dict[str, Tensor]
        if self.transform:
            augmented = self.transform(image=image, mask=mask)
            image, mask = augmented["image"], augmented["mask"]

        # Copy before wrapping: torch.from_numpy rejects arrays with negative
        # strides, and a copy is always contiguous.
        if isinstance(mask, torch.Tensor):
            mask = torch.from_numpy(mask.numpy().copy())
        else:
            mask = torch.from_numpy(mask.copy())
        if not isinstance(image, torch.Tensor):
            image = torch.from_numpy(image.copy())

        return image, mask


# Maps a dataset name to the Dataset class that implements it.
DATASET_REGISTRY = {
    "patch": Patchdataset,
    "whole": WholeSizeDataset,
    # Alias: the --dataset CLI option lists "wholesize" rather than "whole",
    # so accept both spellings for the same class.
    "wholesize": WholeSizeDataset,
    "image_mask": ImageMaskDataset,
    "pre_aug": PreAugDataset,
}
2 changes: 1 addition & 1 deletion cosas/misc.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ def get_config() -> argparse.ArgumentParser:
parser.add_argument("--batch_size", type=int, default=32, help="Batch size")
parser.add_argument("--device", type=str, default="cuda", help="Device to use")
parser.add_argument(
"--dataset", type=str, choices=["patch", "image_mask", "wholesize"]
"--dataset", type=str, choices=["patch", "image_mask", "wholesize", "pre_aug"]
)
parser.add_argument("--lr", type=float, default=0.001, help="Learning rate")
parser.add_argument(
Expand Down
129 changes: 108 additions & 21 deletions debug.ipynb

Large diffs are not rendered by default.

6 changes: 5 additions & 1 deletion experiments/train_baseline.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,11 @@
train_dataset = dataset(
train_images, train_masks, train_transform, device=args.device
)
train_dataloader = DataLoader(train_dataset, batch_size=args.batch_size)
train_dataloader = DataLoader(
train_dataset, batch_size=args.batch_size, shuffle=True
)
if args.dataset == "pre_aug":
dataset = DATASET_REGISTRY["image_mask"]
val_dataset = dataset(val_images, val_masks, test_transform, device=args.device)
val_dataloader = DataLoader(val_dataset, batch_size=args.batch_size)
test_dataset = dataset(
Expand Down

0 comments on commit d7575e4

Please sign in to comment.