From 579be23db58d3542c215168d84906a2f09a5da2f Mon Sep 17 00:00:00 2001 From: Dok11 Date: Fri, 8 Jan 2021 01:39:28 +0300 Subject: [PATCH] Made generated files in separate files --- README.md | 9 ++++++++- lightweight_gan/cli.py | 6 ++++-- lightweight_gan/lightweight_gan.py | 30 ++++++++++++++++++++++++++++++ 3 files changed, 42 insertions(+), 3 deletions(-) diff --git a/README.md b/README.md index 2bbfbf5..2d9d133 100644 --- a/README.md +++ b/README.md @@ -152,9 +152,16 @@ Also one flag to use `--multi-gpus` Once you have finished training, you can generate samples with one command. You can select which checkpoint number to load from. If `--load-from` is not specified, will default to the latest. ```bash -$ lightweight_gan --name {name of run} --load-from {checkpoint num} --generate +$ lightweight_gan \ + --name {name of run} \ + --load-from {checkpoint num} \ + --generate \ + --generate-types {types of result, default: [default,ema]} \ + --num-image-tiles {count of image result} ``` +After run this command you will get folder near results image folder with postfix "-generated-{checkpoint num}". + You can also generate interpolations ```bash diff --git a/lightweight_gan/cli.py b/lightweight_gan/cli.py index 7d475af..e601e04 100644 --- a/lightweight_gan/cli.py +++ b/lightweight_gan/cli.py @@ -85,6 +85,7 @@ def train_from_folder( save_every = 1000, evaluate_every = 1000, generate = False, + generate_types = ['default', 'ema'], generate_interpolation = False, aug_test = False, attn_res_layers = [32], @@ -131,8 +132,9 @@ def train_from_folder( model = Trainer(**model_args) model.load(load_from) samples_name = timestamped_filename() - model.evaluate(samples_name, num_image_tiles) - print(f'sample images generated at {results_dir}/{name}/{samples_name}') + checkpoint = model.checkpoint_num + dir_result = model.generate(samples_name, num_image_tiles, checkpoint, generate_types) + print(f'sample images generated at {dir_result}') return if generate_interpolation: diff --git a/lightweight_gan/lightweight_gan.py b/lightweight_gan/lightweight_gan.py index 81b8123..ff59eaf 100644 --- a/lightweight_gan/lightweight_gan.py +++ b/lightweight_gan/lightweight_gan.py @@ -1088,6 +1088,36 @@ def evaluate(self, num = 0, num_image_tiles = 8, trunc = 1.0): generated_images = self.generate_truncated(self.GAN.GE, latents) torchvision.utils.save_image(generated_images, str(self.results_dir / self.name / f'{str(num)}-ema.{ext}'), nrow=num_rows) + @torch.no_grad() + def generate(self, num=0, num_image_tiles=4, checkpoint=None, types=['default', 'ema']): + self.GAN.eval() + + latent_dim = self.GAN.latent_dim + dir_name = self.name + str('-generated-') + str(checkpoint) + dir_full = Path().absolute() / self.results_dir / dir_name + ext = self.image_extension + + if not dir_full.exists(): + os.mkdir(dir_full) + + # regular + if 'default' in types: + for i in tqdm(range(num_image_tiles), desc='Saving generated default images'): + latents = torch.randn((1, latent_dim)).cuda(self.rank) + generated_image = self.generate_truncated(self.GAN.G, latents) + path = str(self.results_dir / dir_name / f'{str(num)}-{str(i)}.{ext}') + torchvision.utils.save_image(generated_image[0], path, nrow=1) + + # moving averages + if 'ema' in types: + for i in tqdm(range(num_image_tiles), desc='Saving generated EMA images'): + latents = torch.randn((1, latent_dim)).cuda(self.rank) + generated_image = self.generate_truncated(self.GAN.GE, latents) + path = str(self.results_dir / dir_name / f'{str(num)}-{str(i)}-ema.{ext}') + torchvision.utils.save_image(generated_image[0], path, nrow=1) + + return dir_full + @torch.no_grad() def calculate_fid(self, num_batches): from pytorch_fid import fid_score