-
Notifications
You must be signed in to change notification settings - Fork 0
/
simple_aug.py
40 lines (32 loc) · 1018 Bytes
/
simple_aug.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
import albumentations as A
from albumentations.pytorch import ToTensorV2
def train_augment(size, scale=(0.08, 1.0), ratio=(0.75, 1.3333333333333333)):
    """Build the training-time augmentation pipeline.

    Steps, in order:
    1. Resize to ``size`` x ``size``.
    2. Take a random crop and resize it back to ``size`` x ``size``.
    3. Apply random transpose, flips, shift/scale/rotate and 90-degree
       rotations.
    4. Normalize with the widely used ImageNet mean/std values.
    5. Convert to a tensor.

    Args:
        size: Side length in pixels of the square output image.
        scale: ``(min, max)`` range of the crop area, as a fraction of the
            image area, forwarded to ``A.RandomResizedCrop``. The default is
            albumentations' own default, so omitting it reproduces the
            previous behaviour.
        ratio: ``(min, max)`` range of the crop aspect ratio, forwarded to
            ``A.RandomResizedCrop``. The default is albumentations' own
            default.

    Returns:
        An ``A.Compose`` pipeline that is always applied (``p=1.0``).
    """
    try:
        # Newer albumentations take the output shape as ``size=(h, w)``.
        # 2.x no longer accepts positional ``height, width``: the second
        # positional argument would be bound to ``scale``.
        crop = A.RandomResizedCrop(size=(size, size), scale=scale, ratio=ratio)
    except TypeError:
        # Older albumentations have no ``size`` keyword and expect
        # ``(height, width)``.
        crop = A.RandomResizedCrop(size, size, scale=scale, ratio=ratio)

    return A.Compose([
        # The crop below is sampled from this already-resized square image,
        # so ``scale``/``ratio`` are relative to a size x size frame.
        A.Resize(size, size),
        crop,
        A.Transpose(p=0.5),
        A.HorizontalFlip(p=0.5),
        A.VerticalFlip(p=0.5),
        # NOTE(review): newer albumentations releases deprecate
        # ShiftScaleRotate in favour of A.Affine. It is kept here so the
        # augmentation behaviour stays the same.
        A.ShiftScaleRotate(p=0.25),
        A.RandomRotate90(p=0.25),
        A.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
            max_pixel_value=255.0,
            p=1.0,
        ),
        ToTensorV2(p=1.0),
    ], p=1.0)
def val_augment(size):
    """Build the deterministic validation pipeline.

    The image is resized to ``size`` x ``size`` and normalized with the widely
    used ImageNet mean/std values. Pixel values are divided by 255 before
    normalization. The result is converted to a tensor. No random transforms
    are involved.

    Args:
        size: Side length in pixels of the square output image.

    Returns:
        An ``A.Compose`` pipeline that is always applied (``p=1.0``).
    """
    normalize = A.Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225],
        max_pixel_value=255.0,
        p=1.0,
    )
    steps = [A.Resize(size, size), normalize, ToTensorV2(p=1.0)]
    return A.Compose(steps, p=1.0)