Skip to content

Commit

Permalink
Merge pull request #40 from Dok11/generator_update
Browse files Browse the repository at this point in the history
Made generated files in separate files
  • Loading branch information
lucidrains authored Jan 8, 2021
2 parents d79b5e1 + 579be23 commit 8913b44
Show file tree
Hide file tree
Showing 3 changed files with 42 additions and 3 deletions.
9 changes: 8 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -152,9 +152,16 @@ Also one flag to use `--multi-gpus`
Once you have finished training, you can generate samples with one command. You can select which checkpoint number to load from. If `--load-from` is not specified, will default to the latest.

```bash
$ lightweight_gan --name {name of run} --load-from {checkpoint num} --generate
$ lightweight_gan \
--name {name of run} \
--load-from {checkpoint num} \
--generate \
--generate-types {types of result, default: [default,ema]} \
--num-image-tiles {count of image result}
```

After run this command you will get folder near results image folder with postfix "-generated-{checkpoint num}".

You can also generate interpolations

```bash
Expand Down
6 changes: 4 additions & 2 deletions lightweight_gan/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,7 @@ def train_from_folder(
save_every = 1000,
evaluate_every = 1000,
generate = False,
generate_types = ['default', 'ema'],
generate_interpolation = False,
aug_test = False,
attn_res_layers = [32],
Expand Down Expand Up @@ -131,8 +132,9 @@ def train_from_folder(
model = Trainer(**model_args)
model.load(load_from)
samples_name = timestamped_filename()
model.evaluate(samples_name, num_image_tiles)
print(f'sample images generated at {results_dir}/{name}/{samples_name}')
checkpoint = model.checkpoint_num
dir_result = model.generate(samples_name, num_image_tiles, checkpoint, generate_types)
print(f'sample images generated at {dir_result}')
return

if generate_interpolation:
Expand Down
30 changes: 30 additions & 0 deletions lightweight_gan/lightweight_gan.py
Original file line number Diff line number Diff line change
Expand Up @@ -1088,6 +1088,36 @@ def evaluate(self, num = 0, num_image_tiles = 8, trunc = 1.0):
generated_images = self.generate_truncated(self.GAN.GE, latents)
torchvision.utils.save_image(generated_images, str(self.results_dir / self.name / f'{str(num)}-ema.{ext}'), nrow=num_rows)

@torch.no_grad()
def generate(self, num=0, num_image_tiles=4, checkpoint=None, types=['default', 'ema']):
self.GAN.eval()

latent_dim = self.GAN.latent_dim
dir_name = self.name + str('-generated-') + str(checkpoint)
dir_full = Path().absolute() / self.results_dir / dir_name
ext = self.image_extension

if not dir_full.exists():
os.mkdir(dir_full)

# regular
if 'default' in types:
for i in tqdm(range(num_image_tiles), desc='Saving generated default images'):
latents = torch.randn((1, latent_dim)).cuda(self.rank)
generated_image = self.generate_truncated(self.GAN.G, latents)
path = str(self.results_dir / dir_name / f'{str(num)}-{str(i)}.{ext}')
torchvision.utils.save_image(generated_image[0], path, nrow=1)

# moving averages
if 'ema' in types:
for i in tqdm(range(num_image_tiles), desc='Saving generated EMA images'):
latents = torch.randn((1, latent_dim)).cuda(self.rank)
generated_image = self.generate_truncated(self.GAN.GE, latents)
path = str(self.results_dir / dir_name / f'{str(num)}-{str(i)}-ema.{ext}')
torchvision.utils.save_image(generated_image[0], path, nrow=1)

return dir_full

@torch.no_grad()
def calculate_fid(self, num_batches):
from pytorch_fid import fid_score
Expand Down

0 comments on commit 8913b44

Please sign in to comment.