Skip to content

Commit

Permalink
default num image tiles to 8 if less than 1024 resolution, else 4
Browse files Browse the repository at this point in the history
  • Loading branch information
lucidrains committed Jan 9, 2021
1 parent 81769ff commit 7c16bed
Show file tree
Hide file tree
Showing 3 changed files with 17 additions and 5 deletions.
11 changes: 10 additions & 1 deletion lightweight_gan/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,12 @@

import numpy as np

def exists(val):
return val is not None

def default(val, d):
return val if exists(val) else d

def cast_list(el):
return el if isinstance(el, list) else [el]

Expand Down Expand Up @@ -93,7 +99,7 @@ def train_from_folder(
antialias = False,
interpolation_num_steps = 100,
save_frames = False,
num_image_tiles = 8,
num_image_tiles = None,
trunc_psi = 0.75,
aug_prob = None,
aug_types = ['cutout', 'translation'],
Expand All @@ -103,6 +109,8 @@ def train_from_folder(
seed = 42,
amp = False
):
num_image_tiles = default(num_image_tiles, 4 if image_size > 512 else 8)

model_args = dict(
name = name,
results_dir = results_dir,
Expand All @@ -113,6 +121,7 @@ def train_from_folder(
disc_output_size = disc_output_size,
antialias = antialias,
image_size = image_size,
num_image_tiles = num_image_tiles,
optimizer = optimizer,
fmap_max = fmap_max,
transparent = transparent,
Expand Down
9 changes: 6 additions & 3 deletions lightweight_gan/lightweight_gan.py
Original file line number Diff line number Diff line change
Expand Up @@ -722,6 +722,7 @@ def __init__(
optimizer="adam",
latent_dim = 256,
image_size = 128,
num_image_tiles = 8,
fmap_max = 512,
transparent = False,
greyscale = False,
Expand Down Expand Up @@ -763,9 +764,10 @@ def __init__(
assert is_power_of_two(image_size), 'image size must be a power of 2 (64, 128, 256, 512, 1024)'
assert all(map(is_power_of_two, attn_res_layers)), 'resolution layers of attention must all be powers of 2 (16, 32, 64, 128, 256, 512)'

self.optimizer = optimizer
self.latent_dim = latent_dim
self.image_size = image_size
self.num_image_tiles = num_image_tiles

self.latent_dim = latent_dim
self.fmap_max = fmap_max
self.transparent = transparent
self.greyscale = greyscale
Expand All @@ -776,6 +778,7 @@ def __init__(
self.aug_types = aug_types

self.lr = lr
self.optimizer = optimizer
self.ttur_mult = ttur_mult
self.batch_size = batch_size
self.gradient_accumulate_every = gradient_accumulate_every
Expand Down Expand Up @@ -1052,7 +1055,7 @@ def train(self):
self.save(self.checkpoint_num)

if self.steps % self.evaluate_every == 0 or (self.steps % 100 == 0 and self.steps < 20000):
self.evaluate(floor(self.steps / self.evaluate_every))
self.evaluate(floor(self.steps / self.evaluate_every), num_image_tiles = self.num_image_tiles)

if exists(self.calculate_fid_every) and self.steps % self.calculate_fid_every == 0 and self.steps != 0:
num_batches = math.ceil(CALC_FID_NUM_IMAGES / self.batch_size)
Expand Down
2 changes: 1 addition & 1 deletion lightweight_gan/version.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
__version__ = '0.16.1'
__version__ = '0.16.2'

0 comments on commit 7c16bed

Please sign in to comment.