-
Notifications
You must be signed in to change notification settings - Fork 32
/
Copy pathmain.py
104 lines (82 loc) · 4.58 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
import os
from tqdm import tqdm
import pickle
import argparse
import time
import torch
from torch.optim import SGD
from torch.nn import CrossEntropyLoss
from utils import set_seed, load_model, save, get_model, update_optimizer, get_data
from epoch import train_epoch, val_epoch, test_epoch
from cli import add_all_parsers
def train(args):
    """Train a classifier, keep its best-validation checkpoint, and test it.

    Runs ``args.n_epochs`` epochs of SGD (Nesterov, momentum 0.9, weight
    decay ``args.mu``). After every epoch the model and optimizer are saved to
    ``results/<save_name>/<save_name>_weights.tar`` (relative to the current
    working directory). Whenever validation accuracy strictly improves, they
    are also saved to ``<save_name>_weights_best_acc.tar``.

    After the last epoch, the best-accuracy weights from *this* run are
    reloaded and evaluated on the test split. The per-epoch metric lists, the
    test metrics and ``vars(args)`` are then pickled to ``<save_name>.pkl`` in
    the same directory.

    Args:
        args: Parsed command-line namespace (built via
            ``cli.add_all_parsers``). Attributes read directly here:
            ``root``, ``image_size``, ``crop_size``, ``batch_size``,
            ``num_workers``, ``pretrained``, ``use_gpu``, ``lr``, ``mu``,
            ``save_name_xp``, ``k``, ``n_epochs`` and ``epoch_decay``.
            ``set_seed`` and ``get_model`` receive the whole namespace.

    Raises:
        RuntimeError: If no epoch produced a best-accuracy checkpoint (e.g.
            ``args.n_epochs <= 0`` or validation accuracy was always NaN),
            so there is no checkpoint from this run to evaluate on test.
    """
    set_seed(args, use_gpu=torch.cuda.is_available())
    # NOTE(review): seeding keys off CUDA availability, while everything below
    # keys off args.use_gpu -- confirm the two are meant to be able to differ.
    train_loader, val_loader, test_loader, dataset_attributes = get_data(
        args.root, args.image_size, args.crop_size,
        args.batch_size, args.num_workers, args.pretrained)

    model = get_model(args, n_classes=dataset_attributes['n_classes'])
    criteria = CrossEntropyLoss()
    if args.use_gpu:
        print('USING GPU')
        torch.cuda.set_device(0)
        model.cuda()
        criteria.cuda()
    optimizer = SGD(model.parameters(), lr=args.lr, momentum=0.9,
                    weight_decay=args.mu, nesterov=True)

    # Containers for storing metrics over epochs. They are handed to the epoch
    # helpers, which presumably append to them in place -- confirm in epoch.py.
    loss_train, acc_train, topk_acc_train = [], [], []
    loss_val, acc_val, topk_acc_val, avgk_acc_val, class_acc_val = [], [], [], [], []

    save_name = args.save_name_xp.strip()
    save_dir = os.path.join(os.getcwd(), 'results', save_name)
    # exist_ok avoids the check-then-create race of os.path.exists + makedirs.
    os.makedirs(save_dir, exist_ok=True)
    last_ckpt_path = os.path.join(save_dir, save_name + '_weights.tar')
    best_ckpt_path = os.path.join(save_dir, save_name + '_weights_best_acc.tar')

    print('args.k : ', args.k)
    lmbda_best_acc = None
    best_val_acc = float('-inf')
    # Epoch whose weights this run wrote to best_ckpt_path; None until then.
    best_epoch = None
    for epoch in tqdm(range(args.n_epochs), desc='epoch', position=0):
        # perf_counter is monotonic, unlike time.time, so it suits durations.
        t = time.perf_counter()
        optimizer = update_optimizer(optimizer, lr_schedule=args.epoch_decay, epoch=epoch)
        loss_epoch_train, acc_epoch_train, topk_acc_epoch_train = train_epoch(
            model, optimizer, train_loader, criteria,
            loss_train, acc_train, topk_acc_train, args.k,
            dataset_attributes['n_train'], args.use_gpu)
        (loss_epoch_val, acc_epoch_val, topk_acc_epoch_val,
         avgk_acc_epoch_val, lmbda_val) = val_epoch(
            model, val_loader, criteria,
            loss_val, acc_val, topk_acc_val, avgk_acc_val,
            class_acc_val, args.k, dataset_attributes, args.use_gpu)

        # save model at every epoch
        save(model, optimizer, epoch, last_ckpt_path)
        # save model with best val accuracy (a NaN accuracy never compares greater)
        if acc_epoch_val > best_val_acc:
            best_val_acc = acc_epoch_val
            lmbda_best_acc = lmbda_val
            best_epoch = epoch
            save(model, optimizer, epoch, best_ckpt_path)

        print()
        print(f'epoch {epoch} took {time.perf_counter() - t:.2f}')
        print(f'loss_train : {loss_epoch_train}')
        print(f'loss_val : {loss_epoch_val}')
        print(f'acc_train : {acc_epoch_train} / topk_acc_train : {topk_acc_epoch_train}')
        print(f'acc_val : {acc_epoch_val} / topk_acc_val : {topk_acc_epoch_val} / '
              f'avgk_acc_val : {avgk_acc_epoch_val}')

    # Without this guard, a best-acc file left over from an earlier run with the
    # same save name would be silently evaluated (paired with lmbda=None), or
    # load_model would be pointed at a file that was never written.
    if best_epoch is None:
        raise RuntimeError(
            f'No best-accuracy checkpoint was written for "{save_name}" '
            f'(n_epochs={args.n_epochs}); nothing to evaluate on the test set.')

    # load weights corresponding to best val accuracy and evaluate on test
    load_model(model, best_ckpt_path, args.use_gpu)
    (loss_test_ba, acc_test_ba, topk_acc_test_ba,
     avgk_acc_test_ba, class_acc_test) = test_epoch(
        model, test_loader, criteria, args.k,
        lmbda_best_acc, args.use_gpu, dataset_attributes)

    # Save the results as a dictionary and save it as a pickle file in desired location
    results = {'loss_train': loss_train, 'acc_train': acc_train, 'topk_acc_train': topk_acc_train,
               'loss_val': loss_val, 'acc_val': acc_val, 'topk_acc_val': topk_acc_val,
               'class_acc_val': class_acc_val,
               'avgk_acc_val': avgk_acc_val,
               'test_results': {'loss': loss_test_ba,
                                'accuracy': acc_test_ba,
                                'topk_accuracy': topk_acc_test_ba,
                                'avgk_accuracy': avgk_acc_test_ba,
                                'class_acc_dict': class_acc_test},
               'params': vars(args)}
    with open(os.path.join(save_dir, save_name + '.pkl'), 'wb') as f:
        pickle.dump(results, f)
if __name__ == '__main__':
    # Script entry point: assemble the CLI from the project's shared argument
    # definitions, parse sys.argv, and hand the resulting namespace to train().
    cli_parser = argparse.ArgumentParser()
    add_all_parsers(cli_parser)
    cli_args = cli_parser.parse_args()
    train(cli_args)