-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathmain_celebA.py
111 lines (83 loc) · 3.47 KB
/
main_celebA.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
# main.py
import sys
import traceback
import torch
import random
import config
import utils
from model import Model
from train_celebA import Trainer
from test_celebA import Tester
# from dataloader import Dataloader
from dataloader import PrivacyDataLoader
from checkpoints import Checkpoints
def main():
    """Run the CelebA privacy experiment sweep.

    Outer loop: `args.niters` random-seeded repetitions. Inner loop: one run
    per trade-off weight `lam` in `args.alpha`. Each run has two phases:
      1. train the privacy Encoder for `args.nepochs_e` epochs, checkpointing
         the encoder whenever the test loss improves;
      2. with the (frozen-encoder) setup, train Target and Adversary heads for
         `args.nepochs` epochs, checkpointing each whenever its own test loss
         improves.
    All configuration comes from `config.parse_args()`; side effects are
    checkpoint files and saved args (when `args.save_results` is set).
    """
    # parse the arguments
    args = config.parse_args()
    # Select device: first CUDA GPU if requested and available, else CPU.
    if (args.ngpu > 0 and torch.cuda.is_available()):
        device = "cuda:0"
    else:
        device = "cpu"
    # alpha is assumed to be an iterable of trade-off weights (lam) — TODO confirm in config
    alpha = args.alpha
    args.device = torch.device(device)
    # m counts completed (iteration, lam) runs; only used by the commented-out
    # accuracy bookkeeping below, but kept for compatibility.
    m = 0
    for k in range(args.niters):
        dataloader = PrivacyDataLoader(args)
        for lam in alpha:
            if args.save_results:
                utils.saveargs(args)
            # Re-seed per outer iteration k so each repetition is reproducible.
            # NOTE(review): seed does not depend on lam, so every lam in the same
            # iteration starts from the same RNG state — presumably intentional.
            random.seed(args.manual_seed + k)
            torch.manual_seed(args.manual_seed + k)
            # initialize the checkpoint class
            checkpoints = Checkpoints(args)
            # Create Model
            models = Model(args)
            model, criterion, evaluation = models.setup(checkpoints)
            loaders_train = dataloader.create("Train")
            loaders_test = dataloader.create("Test")
            # encoder=True → phase-1 (encoder training); encoder=False → phase-2
            # (target/adversary evaluation) — semantics defined in train/test_celebA.
            trainer_train = Trainer (args, model, criterion, evaluation, lam, k, encoder=True)
            tester_train = Tester (args, model, criterion, evaluation, lam, k, encoder=True)
            trainer_test = Trainer (args, model, criterion, evaluation, lam, k, encoder=False)
            tester_test = Tester (args, model, criterion, evaluation, lam, k, encoder=False)
            loss_best = 1e10  # best encoder test loss seen so far (sentinel: large)
            step = 0  # NOTE(review): never used below — candidate for removal
            # ---- Phase 1: encoder training ----
            for epoch in range(int(args.nepochs_e)):
                lam1 = lam
                print('\nEpoch %d/%d\n' % (epoch + 1, args.nepochs_e))
                loss_train_E = trainer_train.train(epoch, loaders_train, lam1, args.reg)
                # no_grad: evaluation only, no autograd bookkeeping needed
                with torch.no_grad():
                    loss_test_E = tester_train.test(epoch, loaders_test, lam1, args.reg)
                # Checkpoint the encoder on best (lowest) test loss.
                if loss_best > loss_test_E['Loss']:
                    loss_best = loss_test_E['Loss']
                    if args.save_results:
                        checkpoints.save(k, lam, 'Encoder', model['Encoder'])
            # ---- Phase 2: target / adversary training ----
            loss_best_t = 1e10  # best target test loss
            loss_best_a = 1e10  # best adversary test loss
            for epoch in range(int(args.nepochs)):
                # print('\nEpoch %d/%d\n' % (epoch + 1, args.nepochs))
                loss_train = trainer_test.train(epoch, loaders_train, lam, args.reg)
                # A_train[m] = loss_train['Accuracy_Adversary']
                # T_train[m] = loss_train['Accuracy_Target']
                with torch.no_grad():
                    loss_test = tester_test.test(epoch, loaders_test, lam, args.reg)
                # Checkpoint adversary and target independently, each on its
                # own best (lowest) test loss.
                if loss_best_a > loss_test['Loss_Adversary']:
                    loss_best_a = loss_test['Loss_Adversary']
                    if args.save_results:
                        checkpoints.save(k, lam, 'Adv', model['Adversary'])
                if loss_best_t > loss_test['Loss_Target']:
                    loss_best_t = loss_test['Loss_Target']
                    if args.save_results:
                        checkpoints.save(k, lam, 'tgt', model['Target'])
            m = m+1
if __name__ == "__main__":
    # Install signal handlers so ctrl-c / SIGTERM exit cleanly.
    utils.setup_graceful_exit()
    try:
        main()
    except (KeyboardInterrupt, SystemExit):
        # do not print stack trace when ctrl-c is pressed
        pass
    except Exception:
        # Report unexpected failures to stdout, then fall through to cleanup.
        traceback.print_exc(file=sys.stdout)
    finally:
        # BUGFIX: the original also called traceback.print_exc() here, which
        # runs on every exit path — it printed "NoneType: None" after a
        # successful run and duplicated the traceback after a handled
        # exception. Exception reporting belongs in the handler above;
        # `finally` only does cleanup.
        utils.cleanup()