Skip to content

Commit

Permalink
autoset augmentation probability for user
Browse files Browse the repository at this point in the history
  • Loading branch information
lucidrains committed Nov 23, 2020
1 parent 62437d8 commit 01c9a0c
Show file tree
Hide file tree
Showing 3 changed files with 10 additions and 4 deletions.
2 changes: 1 addition & 1 deletion lightweight_gan/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,7 @@ def train_from_folder(
save_frames = False,
num_image_tiles = 8,
trunc_psi = 0.75,
aug_prob = 0.,
aug_prob = None,
aug_types = ['cutout', 'translation'],
dataset_aug_prob = 0.,
multi_gpus = False,
Expand Down
10 changes: 8 additions & 2 deletions lightweight_gan/lightweight_gan.py
Original file line number Diff line number Diff line change
Expand Up @@ -709,7 +709,7 @@ def __init__(
save_every = 1000,
evaluate_every = 1000,
trunc_psi = 0.6,
aug_prob = 0.,
aug_prob = None,
aug_types = ['translation', 'cutout'],
dataset_aug_prob = 0.,
calculate_fid_every = None,
Expand Down Expand Up @@ -872,6 +872,12 @@ def set_data_src(self, folder):
dataloader = DataLoader(self.dataset, num_workers = NUM_CORES, batch_size = math.ceil(self.batch_size / self.world_size), sampler = sampler, shuffle = not self.is_ddp, drop_last = True, pin_memory = True)
self.loader = cycle(dataloader)

# auto set augmentation prob for user if dataset is detected to be low
num_samples = len(self.dataset)
if not exists(self.aug_prob) and num_samples < 1e5:
self.aug_prob = min(0.5, (1e5 - num_samples) * 3e-6)
print(f'autosetting augmentation probability to {round(self.aug_prob * 100)}%')

def train(self):
assert exists(self.loader), 'You must first initialize the data source with `.set_data_src(<folder of images>)`'
device = torch.device(f'cuda:{self.rank}')
Expand All @@ -888,7 +894,7 @@ def train(self):
image_size = self.GAN.image_size
latent_dim = self.GAN.latent_dim

aug_prob = self.aug_prob
aug_prob = default(self.aug_prob, 0)
aug_types = self.aug_types
aug_kwargs = {'prob': aug_prob, 'types': aug_types}

Expand Down
2 changes: 1 addition & 1 deletion lightweight_gan/version.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
__version__ = '0.11.1'
__version__ = '0.11.2'

0 comments on commit 01c9a0c

Please sign in to comment.