-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathgenerate_data.py
123 lines (107 loc) · 5.2 KB
/
generate_data.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
# python3.7
"""Generates a collection of images with specified model.
Commonly, this file is used for data preparation. More specifically, before
exploring the hidden semantics from the latent space, users need to prepare a
collection of images. These images can be used for further attribute prediction.
In this way, it is able to build a relationship between input latent codes and
the corresponding attribute scores.
"""
import os.path
import argparse
from collections import defaultdict
import cv2
import numpy as np
from tqdm import tqdm
from models.model_settings import MODEL_POOL
from models.pggan_generator import PGGANGenerator
from models.stylegan_generator import StyleGANGenerator
from models.stylegan2_generator import StyleGAN2Generator
from utils.logger import setup_logger
# _________________________________________________________________________
def parse_args():
    """Builds the command-line interface and parses `sys.argv`.

    Returns:
        An `argparse.Namespace` carrying `model_name`, `output_dir`,
        `latent_codes_path`, `num`, `latent_space_type`, `generate_style`,
        `generate_image`, `generate_class` and `img_classifier`.

    Note:
        `generate_image` and `generate_class` are `store_false` flags: they
        are `True` unless `-I` / `-C` is passed on the command line.
    """
    parser = argparse.ArgumentParser(
        description='Generate images with given model.',
    )
    # Options are registered in the order they should appear in `--help`.
    register = parser.add_argument

    register(
        '-m', '--model_name',
        type=str,
        required=True,
        choices=list(MODEL_POOL),
        help='Name of the model for generation. (required)',
    )
    register(
        '-o', '--output_dir',
        type=str,
        required=True,
        help='Directory to save the output results. (required)',
    )
    register(
        '-i', '--latent_codes_path',
        type=str,
        default='',
        help=('If specified, will load latent codes from given path '
              'instead of randomly sampling. (optional)'),
    )
    register(
        '-n', '--num',
        type=int,
        default=1,
        help=('Number of images to generate. This field will be ignored '
              'if `latent_codes_path` is specified. (default: 1)'),
    )
    register(
        '-s', '--latent_space_type',
        type=str,
        default='z',
        choices=['z', 'Z', 'w', 'W', 'wp', 'wP', 'Wp', 'WP'],
        help='Latent space used in Style GAN. (default: `Z`)',
    )
    register(
        '-S', '--generate_style',
        action='store_true',
        help=('If specified, will generate layer-wise style codes in '
              'Style GAN. (default: do not generate styles)'),
    )
    register(
        '-I', '--generate_image',
        action='store_false',
        help=('If specified, will skip generating images in Style GAN. '
              '(default: generate images)'),
    )
    register(
        '-C', '--generate_class',
        action='store_false',
        help=('If specified, will skip generating labels in Style GAN. '
              '(default: generate label)'),
    )
    register(
        '-p', '--img_classifier',
        type=str,
        default=None,
        help='path to pre-trained image classifier .PKL file for '
             'generating labels.',
    )

    parsed = parser.parse_args()
    return parsed
def main():
    """Generates samples with the selected GAN and saves them to disk.

    Steps:
      1. Parse the command line and validate `--latent_codes_path`.
      2. Set up the logger and build the generator matching the `gan_type`
         registered for the chosen model in `MODEL_POOL`.
      3. Load latent codes from `--latent_codes_path` (then preprocess
         them) or sample `--num` codes with the generator.
      4. Synthesize batch by batch. Every entry under the `image` output key
         is written to `<output_dir>/<index:06d>.jpg`; every other output is
         collected per key.
      5. Concatenate the collected outputs along axis 0 and save each one as
         `<output_dir>/<key>.npy`.

    Raises:
        FileNotFoundError: If `--latent_codes_path` is given but is not an
            existing file (previously this silently fell back to random
            sampling).
        NotImplementedError: If the model's `gan_type` is not one of
            `pggan`, `stylegan` or `stylegan2`.
        OSError: If an image could not be written to disk.

    NOTE(review): `--generate_class` and `--img_classifier` are parsed but
    never read in this function.
    """
    args = parse_args()

    # Fail fast, before the logger/model are created: a path that was given
    # but does not exist is almost certainly a typo and must not be silently
    # replaced by random sampling.
    if args.latent_codes_path and not os.path.isfile(args.latent_codes_path):
        raise FileNotFoundError(
            f'Latent codes file `{args.latent_codes_path}` does not exist!')

    logger = setup_logger(args.output_dir, logger_name='generate_data')

    logger.info('Initializing generator.')
    gan_type = MODEL_POOL[args.model_name]['gan_type']
    if gan_type == 'pggan':
        model = PGGANGenerator(args.model_name, logger)
        kwargs = {}
    elif gan_type == 'stylegan':
        model = StyleGANGenerator(args.model_name, logger)
        kwargs = {'latent_space_type': args.latent_space_type}
    elif gan_type == 'stylegan2':
        model = StyleGAN2Generator(args.model_name, logger)
        kwargs = {'latent_space_type': args.latent_space_type}
    else:
        raise NotImplementedError(f'Not implemented GAN type `{gan_type}`!')

    # Keyword arguments for `easy_synthesize` do not change between batches,
    # so build them once. PGGAN is called with the latent codes only.
    if gan_type == 'pggan':
        synthesis_kwargs = {}
    else:
        synthesis_kwargs = dict(kwargs,
                                generate_style=args.generate_style,
                                generate_image=args.generate_image)

    logger.info('Preparing latent codes.')
    if args.latent_codes_path:
        logger.info(f' Load latent codes from `{args.latent_codes_path}`.')
        latent_codes = np.load(args.latent_codes_path)
        latent_codes = model.preprocess(latent_codes, **kwargs)
    else:
        logger.info(' Sample latent codes randomly.')
        latent_codes = model.easy_sample(args.num, **kwargs)
    total_num = latent_codes.shape[0]

    logger.info(f'Generating {total_num} samples.')
    results = defaultdict(list)
    # The context manager closes the progress bar even if a batch fails.
    with tqdm(total=total_num, leave=False) as pbar:
        for latent_codes_batch in model.get_batch_inputs(latent_codes):
            outputs = model.easy_synthesize(latent_codes_batch,
                                            **synthesis_kwargs)
            for key, val in outputs.items():
                if key == 'image':
                    for image in val:
                        # `pbar.n` counts the images saved so far and thus
                        # doubles as the running file index.
                        save_path = os.path.join(args.output_dir,
                                                 f'{pbar.n:06d}.jpg')
                        # The channel axis is reversed before writing;
                        # presumably the generator emits RGB while OpenCV
                        # stores BGR -- confirm against the generator.
                        # `cv2.imwrite` signals failure through its return
                        # value, so check it explicitly.
                        if not cv2.imwrite(save_path, image[:, :, ::-1]):
                            raise OSError(
                                f'Failed to write image to `{save_path}`!')
                        pbar.update(1)
                else:
                    results[key].append(val)
            # Without images nothing advanced the bar above, so advance it
            # by the batch size here.
            if 'image' not in outputs:
                pbar.update(latent_codes_batch.shape[0])
            if pbar.n % 1000 == 0 or pbar.n == total_num:
                logger.debug(f' Finish {pbar.n:6d} samples.')

    logger.info('Saving results.')
    for key, val in results.items():
        save_path = os.path.join(args.output_dir, f'{key}.npy')
        np.save(save_path, np.concatenate(val, axis=0))
# Run the generation pipeline only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == '__main__':
    main()