-
Notifications
You must be signed in to change notification settings - Fork 1
/
main.py
92 lines (74 loc) · 2.78 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
import os
import torch
import random
import argparse
from EncDec import *
from DLA3.EncDec.changerollno import *
from torch.utils.data import DataLoader
# Command-line interface: two required positional string arguments.
#   gpu   -- forwarded unchanged to AETrainer and VAETrainer.
#   bonus -- the CVAE bonus stage runs only when this is exactly "T".
# NOTE(review): parse_args() executes at import time, so importing this module
# from elsewhere parses (and may exit on) the importer's sys.argv.
P = argparse.ArgumentParser()
P.add_argument("gpu", type=str)
P.add_argument("bonus", type=str)
A = P.parse_args()

if __name__ == "__main__":
    # Shuffled training loader over AlteredMNIST; drop_last discards a final
    # short batch so every batch has BATCH_SIZE samples.
    train_loader = DataLoader(
        dataset=AlteredMNIST(),
        batch_size=BATCH_SIZE,
        shuffle=True,
        num_workers=2,
        drop_last=True,
        pin_memory=True,
    )

    encoder = Encoder()
    decoder = Decoder()
    ae_loss = AELossFn()
    vae_loss = VAELossFn()
    # One Adam optimizer, built over whatever ParameterSelector returns for
    # the encoder/decoder pair, is shared by every training stage below.
    optimizer = torch.optim.Adam(ParameterSelector(encoder, decoder), lr=LEARNING_RATE)

    enc_name = encoder.__class__.__name__
    dec_name = decoder.__class__.__name__

    # AE training runs first, then VAE training. The same encoder, decoder and
    # optimizer instances are handed to both trainers.
    # NOTE(review): unless the trainers reset them internally, VAE training
    # presumably continues from the AE-trained weights and Adam state —
    # confirm this is intended.
    for paradigm, trainer, loss_fn in (
        ("AE", AETrainer, ae_loss),
        ("VAE", VAETrainer, vae_loss),
    ):
        print(
            f"Training Encoder: {enc_name}, Decoder: {dec_name} "
            f"on Modified MNIST dataset in {paradigm} training paradigm"
        )
        trainer(train_loader, encoder, decoder, loss_fn, optimizer, A.gpu)
    print("AE, VAE Training Complete")

    if A.bonus != "T":
        print("Bonus Question not done")
    else:
        # NOTE(review): unlike AETrainer/VAETrainer, CVAE_Trainer is not given
        # A.gpu here — confirm against its signature in EncDec.
        cvae_loss = CVAELossFn()
        CVAE_Trainer(train_loader, encoder, decoder, cvae_loss, optimizer)

    # Both pipelines are constructed with gpu=False; presumably they wrap the
    # trained AE / VAE models for evaluation — confirm in EncDec.
    AE_Pipeline = AE_TRAINED(gpu=False)
    VAE_Pipeline = VAE_TRAINED(gpu=False)

    # For TAs Only: evaluation harness, intentionally left commented out.
    # TestData = TestMNIST()
    # AESSIM, VAESSIM = [], []
    # AEPSNR, VAEPSNR = [], []
    # for sample, original in TestData:
    #     AESSIM.append(AE_Pipeline.from_path(sample, original, type="SSIM"))
    #     VAESSIM.append(VAE_Pipeline.from_path(sample, original, type="SSIM"))
    #     AEPSNR.append(AE_Pipeline.from_path(sample, original, type="PSNR"))
    #     VAEPSNR.append(VAE_Pipeline.from_path(sample, original, type="PSNR"))
    # print("SSIM Score of AutoEncoder Training paradigm: {}".format(sum(AESSIM)/len(AESSIM)))
    # print("SSIM Score of Variational AutoEncoder Training paradigm: {}".format(sum(VAESSIM)/len(VAESSIM)))
    # print("PSNR Score of AutoEncoder Training paradigm: {}".format(sum(AEPSNR)/len(AEPSNR)))
    # print("PSNR Score of Variational AutoEncoder Training paradigm: {}".format(sum(VAEPSNR)/len(VAEPSNR)))
    # if A.bonus == "T":
    #     Generator = CVAE_Generator()
    #     for _ in range(24):
    #         Generator.save_image(digit=random.choice(range(10)), save_path=SAVEPATH)
    #
    #     print("Similarity Score for generated data is: {}".format(GeneratorSimilarity(SAVEPATH)))