-
Notifications
You must be signed in to change notification settings - Fork 2
/
cli.py
417 lines (342 loc) · 20.8 KB
/
cli.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
"""
Package for launching training and validation flows
"""
import numpy as np
import random
from typing_extensions import Annotated
from pathlib import Path
from matplotlib import pyplot
# Internal imports
from src.logging import create_logger, VerboseMode
from src.datasets.mnist import MNISTLoader
from src.display.pca import PCAPlotter
from src.utils.color import lighten_color
from src.embeddings.training import EmbeddingTrainer
from src.embeddings.hyperparameters import EmbeddingHyperparameters
from src.embeddings.models.mnist import MNISTEmbeddingModel
from src.generator.hyperparameters import GeneratorHyperparameters
from src.generator.models.mnist import MNISTGeneratorModel
from src.generator.training import GeneratorTrainer
from src.evaluation.generator import GeneratorEvaluator
from datasets import DatasetHandler
import tensorflow_io as tfio
# Creating a CLI client
import typer
# CLI application object all commands below are registered on.
# BUG FIX: the original passed a 3-element *tuple* of strings to ``help=``
# (commas between the literals), so Typer displayed the tuple's repr instead
# of a sentence. Implicit string-literal concatenation (no commas) yields the
# intended single help string.
app = typer.Typer(help=('CLI for interacting with '
                        'implementations of models described '
                        'in the research paper'))
@app.command()
def train_embedding_model(
        dataset: Annotated[int, typer.Option(help='Dataset to use. 0 for MNIST, 1 for LFW')] = 0,
        hyperparams_path: Annotated[Path, typer.Option(help='Path to the hyperparameters file')] = Path('./hyperparams_embedding.json'),
        model_save_path: Annotated[Path, typer.Option(help='Path to the model file')] = None,
        history_path: Annotated[Path, typer.Option(help='Path to the history file')] = None,
        pca_save_path: Annotated[Path, typer.Option(help='Path to the PCA plot file')] = None,
        verbose: Annotated[int, typer.Option(help='Whether to print the logs. 0 to set WARNING level only, 1 for INFO, 2 for showing model summary and debug')] = 1,
) -> None:
    """
    Train the embedding model and save it.

    After the training is successful, the PCA plot is generated and saved
    together with the model's weights and history.

    Parameters:
    - dataset (int): Dataset to use. 0 for MNIST, 1 for LFW. Only MNIST is supported.
    - hyperparams_path (Path): Path to the hyperparameters file. Defaults to './hyperparams_embedding.json'
    - model_save_path (Path): Path to the model directory. Defaults to './models/embedding/v{version}.{subversion}'
    - history_path (Path): Path to the history plot file. Defaults to './images/embedding/v{version}.{subversion}/history.png'
    - pca_save_path (Path): Path to the PCA plot file. Defaults to './images/embedding/v{version}.{subversion}/pca.png'
    - verbose (int): 0 for WARNING level only, 1 for INFO, 2 for model summary and debug. Defaults to 1

    Raises:
    - NotImplementedError: if a dataset other than MNIST is requested.
    """
    if DatasetHandler(dataset) != DatasetHandler.MNIST:
        raise NotImplementedError('Datasets except for MNIST are not supported yet')
    verbose = VerboseMode(verbose)
    logger = create_logger(verbose)
    # Getting the dataset
    logger.info('Loading the MNIST dataset...')
    mnist_loader = MNISTLoader(expand_dims=False)
    logger.info('Successfully loaded the MNIST dataset')
    if verbose == VerboseMode.DEBUG:
        logger.info('Showing example images from the MNIST dataset...')
        mnist_loader.show_examples()
    # Getting hyperparameters
    logger.info('Loading hyperparameters for the embedding model...')
    hyperparams = EmbeddingHyperparameters(json_path=hyperparams_path)
    logger.info('Successfully loaded the hyperparameters.')
    # Creating the embedding model
    embedding_model = MNISTEmbeddingModel(hyperparams)
    # FIX: the summary was previously printed unconditionally; per the
    # documented --verbose semantics (and every other command here) it
    # belongs to DEBUG (2) only.
    if verbose == VerboseMode.DEBUG:
        logger.info('Using the following embedding model:')
        embedding_model.summary()
    logger.info('Launching the trainer...')
    trainer = EmbeddingTrainer(logger, embedding_model, mnist_loader, hyperparams)
    # Setting the save paths if not provided
    if model_save_path is None:
        model_save_path = Path(f'./models/embedding/v{hyperparams.meta.version}.{hyperparams.meta.subversion}')
    if history_path is None:
        history_path = Path(f'./images/embedding/v{hyperparams.meta.version}.{hyperparams.meta.subversion}/history.png')
    if pca_save_path is None:
        pca_save_path = Path(f'./images/embedding/v{hyperparams.meta.version}.{hyperparams.meta.subversion}/pca.png')
    # Creating the directories if they do not exist
    model_save_path.mkdir(parents=True, exist_ok=True)
    history_path.parent.mkdir(parents=True, exist_ok=True)
    pca_save_path.parent.mkdir(parents=True, exist_ok=True)
    # Training the model and persisting weights, history plot and PCA plot
    trainer.train(
        model_save_path=model_save_path,
        history_save_path=history_path,
        pca_save_path=pca_save_path)
    # Keep the hyperparameters alongside the weights so the model can be reloaded later
    hyperparams.save(model_save_path / 'hyperparams.json')
    logger.info('Loader successfully finished the training process. Exiting...')
@app.command()
def train_generator_model(
        embedding_model_path: Annotated[Path, typer.Option(help='Path to the embedding model file')] = None,
        dataset: Annotated[int, typer.Option(help='Dataset to use. 0 for MNIST, 1 for LFW')] = 0,
        hyperparams_path: Annotated[Path, typer.Option(help='Path to the hyperparameters file')] = Path('./hyperparams_generator.json'),
        model_save_path: Annotated[Path, typer.Option(help='Path to the model file')] = None,
        history_path: Annotated[Path, typer.Option(help='Path to the history file')] = None,
        image_base_path: Annotated[Path, typer.Option(help='Path where to save images')] = None,
        verbose: Annotated[int, typer.Option(help='Whether to print the logs. 0 to set WARNING level only, 1 for INFO, 2 for showing model summary and debug')] = 1,
) -> None:
    """
    Train the generator model and save it.

    After the training is successful, the example images are generated and
    saved together with the model's weights and history.

    Parameters:
    - embedding_model_path (Path): Path to the embedding model file. None selects the pretrained default.
    - dataset (int): Dataset to use. 0 for MNIST, 1 for LFW
    - hyperparams_path (Path): Path to the hyperparameters file. Defaults to './hyperparams_generator.json'
    - model_save_path (Path): Path to the model directory. Defaults to './models/generator/{dataset}/v{version}.{subversion}'
    - history_path (Path): Path to the history plot file. Defaults to './images/generator/{dataset}/v{version}.{subversion}/history.png'
    - image_base_path (Path): Directory where generated images are saved. Defaults to './images/generator/{dataset}/v{version}.{subversion}'
    - verbose (int): 0 for WARNING level only, 1 for INFO, 2 for model summary and debug. Defaults to 1
    """
    verbose = VerboseMode(verbose)
    logger = create_logger(verbose)
    # Getting hyperparameters
    logger.info('Loading hyperparameters for the generator model...')
    hyperparams = GeneratorHyperparameters(json_path=hyperparams_path)
    logger.info('Successfully loaded the generator hyperparameters.')
    dataset_handler = DatasetHandler(dataset)
    logger.info('Preparing the dataset loader...')
    loader = dataset_handler.dataset_loader(grayscale=hyperparams.grayscale)
    logger.info('Successfully prepared the dataset loader')
    # Initializing the generator model
    logger.info('Loading the generator model...')
    generator_model = dataset_handler.generator_model(grayscale=hyperparams.grayscale)
    logger.info('Using the following generator model:')
    if verbose == VerboseMode.DEBUG:
        logger.info('Its summary:')
        generator_model.summary()
    # Loading the (frozen) embedding model used as the training target
    logger.info('Loading the embedding model...')
    embedding_model = dataset_handler.pretrained_embedding_model(embedding_model_path, trainable=False)
    logger.info('Successfully loaded the embedding model.')
    if verbose == VerboseMode.DEBUG:
        logger.info('Its summary:')
        embedding_model._model.summary()
    logger.info('Setting the trainer...')
    trainer = GeneratorTrainer(logger=logger,
                               generator_model=generator_model,
                               embedding_model=embedding_model,
                               dataset_loader=loader,
                               hyperparams=hyperparams)
    # Setting the save paths if not provided
    if model_save_path is None:
        model_save_path = Path(f'./models/generator/{hyperparams.meta.dataset}/v{hyperparams.meta.version}.{hyperparams.meta.subversion}')
    if history_path is None:
        history_path = Path(f'./images/generator/{hyperparams.meta.dataset}/v{hyperparams.meta.version}.{hyperparams.meta.subversion}/history.png')
    if image_base_path is None:
        image_base_path = Path(f'./images/generator/{hyperparams.meta.dataset}/v{hyperparams.meta.version}.{hyperparams.meta.subversion}')
    # Creating the directories if they do not exist.
    # BUG FIX: images are saved directly INTO image_base_path (it is passed as
    # image_save_base_path below), so the directory itself must exist — the
    # original only created its parent.
    model_save_path.mkdir(parents=True, exist_ok=True)
    history_path.parent.mkdir(parents=True, exist_ok=True)
    image_base_path.mkdir(parents=True, exist_ok=True)
    # Training the model
    trainer.train(
        model_save_path=model_save_path,
        history_save_path=history_path,
        image_save_base_path=image_base_path,
        grayscale=hyperparams.grayscale)
    # Keep the hyperparameters alongside the weights so the model can be reloaded later
    hyperparams.save(model_save_path / 'hyperparams.json')
    logger.info('Loader successfully finished the training process. Exiting...')
@app.command()
def show_pca_comparison(
        generator_model_path: Annotated[Path, typer.Option(help='Path to the generator model file')],
        embedding_model_path: Annotated[Path, typer.Option(help='Path to the embedding model file')] = None,
        dataset: Annotated[int, typer.Option(help='Dataset to use. 0 for MNIST, 1 for LFW')] = 0,
        pca_save_path: Annotated[Path, typer.Option(help='Path to the PCA plot file')] = None,
        classes_to_display: Annotated[int, typer.Option(help='Number of classes to display. Defaults to 3')] = 3,
        verbose: Annotated[int, typer.Option(help='Whether to print the logs. 0 to set WARNING level only, 1 for INFO, 2 for showing model summary and debug')] = 1,
) -> None:
    """
    Compare embeddings of real and generated images on a single PCA plot.

    This command does several things:
    1. Loads the dataset.
    2. Finds embeddings of real images from the dataset itself.
    3. Finds embeddings of generated images from the generator model.
    4. Plots the PCA graph of both embeddings.

    Arguments:
    - generator_model_path (Path): Path to the generator model file
    - embedding_model_path (Path): Path to the embedding model file. None selects the pretrained 'facenet' default.
    - dataset (int): Dataset to use. 0 for MNIST, 1 for LFW
    - pca_save_path (Path): Path to the PCA plot file. Defaults to
      './images/evaluation/{dataset}/pca_embedding_{emb_version}_generator_{gen_version}.png'
    - classes_to_display (int): Number of classes to display. Defaults to 3
    - verbose (int): 0 for WARNING level only, 1 for INFO, 2 for model summary and debug. Defaults to 1.
    """
    verbose = VerboseMode(verbose)
    logger = create_logger(verbose)
    # Loading the hyperparameters saved next to the generator weights
    generator_hyperparams = GeneratorHyperparameters(generator_model_path / 'hyperparams.json')
    # Getting the dataset
    dataset_handler = DatasetHandler(dataset)
    logger.info(f'Loading the {dataset_handler.as_str()} dataset...')
    loader = dataset_handler.dataset_loader(grayscale=generator_hyperparams.grayscale)
    logger.info(f'Successfully loaded the {dataset_handler.as_str()} dataset')
    # Loading the embedding model
    logger.info('Loading the embedding model...')
    embedding_model = dataset_handler.pretrained_embedding_model(embedding_model_path, trainable=False)
    logger.info('Successfully loaded the embedding model.')
    if verbose == VerboseMode.DEBUG:
        logger.info('Its summary:')
        embedding_model._model.summary()
    # Loading the generator model
    logger.info('Loading the generator model...')
    generator_model = MNISTGeneratorModel.from_path(generator_model_path, trainable=False)
    logger.info('Successfully loaded the generator model.')
    if verbose == VerboseMode.DEBUG:
        logger.info('Its summary:')
        generator_model.summary()
    # Setting the save path if not provided; embedding version falls back to
    # 'facenet' when no custom embedding model was given
    if pca_save_path is None:
        embedding_version = 'facenet'
        generator_version = f'{generator_hyperparams.meta.version}.{generator_hyperparams.meta.subversion}'
        if embedding_model_path is not None:
            embedding_hyperparams = EmbeddingHyperparameters(embedding_model_path / 'hyperparams.json')
            embedding_version = f'{embedding_hyperparams.meta.version}.{embedding_hyperparams.meta.subversion}'
        pca_save_path = Path(f'./images/evaluation/{dataset_handler.as_str()}/pca_embedding_{embedding_version}_generator_{generator_version}.png')
    # BUG FIX: the evaluation directory was never created (unlike the train
    # commands), so saving the plot could fail on a fresh checkout.
    pca_save_path.parent.mkdir(parents=True, exist_ok=True)
    # Loading the test split and keeping the classes_to_display most populous labels
    (X_test, y_test), _ = loader.get()
    y_uniques = np.unique(y_test)
    y_uniques = sorted(y_uniques, key=lambda x: len(X_test[y_test == x]))
    y_selected = y_uniques[-classes_to_display:]
    X_selected = [X_test[y_test == label] for label in y_selected]
    num_selected = [len(X_selected[i]) for i in range(len(X_selected))]
    X_selected = np.array([item for batch in X_selected for item in batch], dtype=np.float32)
    # Finding embeddings of real images
    y_batches_real = [[f'{label} (Real)'] * num_selected[i] for i, label in enumerate(y_selected)]
    y_batches_real = [item for batch in y_batches_real for item in batch]
    X_real = embedding_model.raw.predict(X_selected)
    # Finding embeddings of generated images
    y_predicted_batches = [[f'{label} (Generated)'] * num_selected[i] for i, label in enumerate(y_selected)]
    y_predicted_batches = [item for batch in y_predicted_batches for item in batch]
    X_generated_images = generator_model.raw.predict(X_selected)
    X_generated = embedding_model.raw.predict(X_generated_images)
    # Picking colors: one base color per class, a lightened variant for its
    # generated counterpart
    cmap = pyplot.cm.rainbow(np.linspace(0, 1, classes_to_display))
    color_pairs = [(cmap[i], lighten_color(cmap[i], 0.5)) for i in range(classes_to_display)]
    colors_to_display = [color for pair in color_pairs for color in pair]
    # Applying PCA
    logger.info('Launching PCA...')
    pca = PCAPlotter(
        X=np.array([*X_real, *X_generated], dtype=np.float64),
        y=[*y_batches_real, *y_predicted_batches],
        colors_to_display=colors_to_display
    )
    # FIX: this log line used to fire after plotting; it announces the
    # plotting step, so it now precedes pca.plot().
    logger.info('Successfully launched PCA. Plotting...')
    pca.plot(save_path=pca_save_path)
@app.command()
def analyze_generator_distances(
        generator_model_path: Annotated[Path, typer.Option(help='Path to the generator model file')],
        embedding_model_path: Annotated[Path, typer.Option(help='Path to the embedding model file')] = None,
        dataset: Annotated[int, typer.Option(help='Dataset to use. 0 for MNIST, 1 for LFW')] = 0,
        verbose: Annotated[int, typer.Option(help='Whether to print the logs. 0 to set WARNING level only, 1 for INFO, 2 for showing model summary and debug')] = 1,
) -> None:
    """
    Evaluate the generator model by measuring embedding distances.

    The evaluator:
    - Takes a random pair of real images with the same label and evaluates the distance between them
    - Takes a random pair of generated images with the same label and evaluates the distance between them
    - Takes a random pair of real and generated images with the same label and evaluates the distance between them
    Prints the results using the rich library.

    Arguments:
    - generator_model_path (Path): Path to the generator model file
    - embedding_model_path (Path): Path to the embedding model file. None selects the pretrained default.
    - dataset (int): Dataset to use. 0 for MNIST, 1 for LFW
    - verbose (int): 0 for WARNING level only, 1 for INFO, 2 for model summary and debug. Defaults to 1.
    """
    verbose = VerboseMode(verbose)
    logger = create_logger(verbose)
    dataset_handler = DatasetHandler(dataset)
    # Getting the dataset
    logger.info(f'Loading the {dataset_handler.as_str()} dataset...')
    loader = dataset_handler.dataset_loader(grayscale=False)
    logger.info(f'Successfully loaded the {dataset_handler.as_str()} dataset')
    # Loading the embedding model
    logger.info('Loading the embedding model...')
    embedding_model = dataset_handler.pretrained_embedding_model(embedding_model_path, trainable=False)
    logger.info('Successfully loaded the embedding model.')
    if verbose == VerboseMode.DEBUG:
        logger.info('Its summary:')
        embedding_model.summary()
    # Loading the generator model
    logger.info('Loading the generator model...')
    generator_model = dataset_handler.pretrained_generator_model(generator_model_path, trainable=False)
    # FIX: the message used to claim 'Its summary:' unconditionally even when
    # the summary is suppressed below; aligned with the embedding-model log above.
    logger.info('Successfully loaded the generator model.')
    if verbose == VerboseMode.DEBUG:
        logger.info('Its summary:')
        generator_model.summary()
    # Creating an evaluator
    logger.info('Creating an evaluator...')
    # We are explicitly ignoring hyperparameters here because we do not need them
    evaluator = GeneratorEvaluator(generator_model, embedding_model, loader, hyperparams=None)
    evaluator.evaluate_image_distances()
    logger.info('Successfully evaluated the generator model. Exiting...')
@app.command()
def analyze_generator_roc(
        generator_model_path: Annotated[Path, typer.Option(help='Path to the generator model file')],
        embedding_model_path: Annotated[Path, typer.Option(help='Path to the embedding model file')] = None,
        dataset: Annotated[int, typer.Option(help='Dataset to use. 0 for MNIST, 1 for LFW')] = 0,
        roc_image_path: Annotated[Path, typer.Option(help='Path to save ROC curve in')] = None,
        classes_to_test: Annotated[int, typer.Option(help='Number of classes to test with')] = 3,
        verbose: Annotated[int, typer.Option(help='Whether to print the logs. 0 to set WARNING level only, 1 for INFO, 2 for showing model summary and debug')] = 1,
) -> None:
    """
    Evaluate the generator model with a ROC curve.

    This command does several things:
    1. Loads the dataset.
    2. Finds embeddings of real images from the dataset itself.
    3. Finds embeddings of generated images from the generator model.
    4. Evaluates the generator model by calculating the ROC curve.

    Arguments:
    - generator_model_path (Path): Path to the generator model file
    - embedding_model_path (Path, optional): Path to the embedding model file. Defaults to None
    - dataset (int): Dataset to use. 0 for MNIST, 1 for LFW
    - roc_image_path (Path, optional): Path to save ROC curve in. Defaults to
      './images/evaluation/roc_embedding_{emb_version}_generator_{gen_version}.png'
    - classes_to_test (int, optional): Number of classes to test the authentication system with
    - verbose (int): 0 for WARNING level only, 1 for INFO, 2 for model summary and debug. Defaults to 1.
    """
    verbose = VerboseMode(verbose)
    logger = create_logger(verbose)
    # Getting the dataset
    dataset_handler = DatasetHandler(dataset)
    logger.info(f'Loading the {dataset_handler.as_str()} dataset...')
    loader = dataset_handler.dataset_loader(grayscale=False)
    logger.info(f'Successfully loaded the {dataset_handler.as_str()} dataset')
    # Loading the embedding model
    logger.info('Loading the embedding model...')
    embedding_model = dataset_handler.pretrained_embedding_model(embedding_model_path, trainable=False)
    # FIX: summaries were printed unconditionally here, contradicting the
    # documented --verbose semantics and every other command in this CLI;
    # both are now gated on DEBUG (verbose == 2).
    logger.info('Successfully loaded the embedding model.')
    if verbose == VerboseMode.DEBUG:
        logger.info('Its summary:')
        embedding_model.summary()
    # Loading the generator model
    logger.info('Loading the generator model...')
    generator_model = dataset_handler.pretrained_generator_model(generator_model_path, trainable=False)
    logger.info('Successfully loaded the generator model.')
    if verbose == VerboseMode.DEBUG:
        logger.info('Its summary:')
        generator_model.summary()
    # Setting the ROC save path if not provided; embedding version falls back
    # to 'facenet' when no custom embedding model was given
    if roc_image_path is None:
        generator_hyperparams = GeneratorHyperparameters(generator_model_path / 'hyperparams.json')
        generator_version = f'{generator_hyperparams.meta.version}.{generator_hyperparams.meta.subversion}'
        embedding_version = 'facenet'
        if embedding_model_path is not None:
            embedding_hyperparams = EmbeddingHyperparameters(embedding_model_path / 'hyperparams.json')
            embedding_version = f'{embedding_hyperparams.meta.version}.{embedding_hyperparams.meta.subversion}'
        roc_image_path = Path(f'./images/evaluation/roc_embedding_{embedding_version}_generator_{generator_version}.png')
    # BUG FIX: the output directory was never created (unlike the train
    # commands), so saving the ROC curve could fail on a fresh checkout.
    roc_image_path.parent.mkdir(parents=True, exist_ok=True)
    # Creating an evaluator
    logger.info('Creating an evaluator...')
    # We are explicitly ignoring hyperparameters here because we do not need them
    evaluator = GeneratorEvaluator(generator_model, embedding_model, loader, hyperparams=None)
    evaluator.build_roc(roc_save_path=roc_image_path, classes_to_test=classes_to_test)
    logger.info('Successfully evaluated the generator model. Exiting...')
# Script entry point: dispatch to the Typer CLI application when executed directly.
if __name__ == '__main__':
    app()