-
Notifications
You must be signed in to change notification settings - Fork 9
/
Copy pathtrain_cub.py
86 lines (73 loc) · 2.79 KB
/
train_cub.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
import torch
import torch.optim as optim
import torch.nn as nn
from model import TransZero
from dataset import CUBDataLoader
from helper_func import eval_zs_gzsl
import numpy as np
import wandb
# --- Experiment configuration ---------------------------------------------
# Start a wandb run whose hyper-parameters come from the CUB GZSL YAML file;
# every later setting is read back through ``wandb.config``.
wandb.init(config='wandb_config/cub_gzsl.yaml', project='TransZero')
config = wandb.config
print('Config file from wandb:', config)

# --- Data -----------------------------------------------------------------
# '.' is presumably the data root directory (TODO confirm against
# CUBDataLoader); tensors are placed on ``config.device``.
dataloader = CUBDataLoader('.', config.device, is_balance=False)

# --- Reproducibility ------------------------------------------------------
# Seeding happens after the dataset is built but before the model is
# constructed, so any NumPy/torch randomness in model construction is seeded.
seed = config.random_seed
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)

# --- Model & optimiser ----------------------------------------------------
# The model receives the config plus per-class data from the loader
# (presumably attribute vectors, their word embeddings, and the seen/unseen
# class ids — verify in dataset.py).
model = TransZero(
    config,
    dataloader.att,
    dataloader.w2v_att,
    dataloader.seenclasses,
    dataloader.unseenclasses,
).to(config.device)
# Adam with a fixed learning rate and weight decay, both 1e-4.
optimizer = optim.Adam(params=model.parameters(), weight_decay=1e-4, lr=1e-4)
# --- Main training loop ---------------------------------------------------
# ``niters`` is the number of mini-batch steps needed for ``config.epochs``
# passes over the ``dataloader.ntrain`` training samples.
niters = dataloader.ntrain * config.epochs // config.batch_size
# Evaluate roughly once per epoch. ``max(1, ...)`` keeps the interval positive
# when there are fewer steps than epochs (e.g. ntrain < batch_size); a zero
# interval would make ``i % report_interval`` raise ZeroDivisionError.
report_interval = max(1, niters // config.epochs)
# Results of the evaluation with the best harmonic mean H so far, stored as
# [acc_unseen, acc_seen, H, acc_zs] (all taken from that same evaluation).
best_performance = [0, 0, 0, 0]
# Best ZSL accuracy seen so far, tracked independently of the best H.
best_performance_zsl = 0
for i in range(0, niters):
    # Re-enter training mode every step (evaluation below may switch the
    # model to eval mode; that helper is not visible here).
    model.train()
    optimizer.zero_grad()
    batch_label, batch_feature, batch_att = dataloader.next_batch(
        config.batch_size)
    # The forward pass returns a dict; the batch labels are added to that
    # same dict, which is then passed to the model's loss computation.
    out_package = model(batch_feature)
    in_package = out_package
    in_package['batch_label'] = batch_label
    out_package = model.compute_loss(in_package)
    loss = out_package['loss']
    loss_CE = out_package['loss_CE']
    loss_cal = out_package['loss_cal']
    loss_reg = out_package['loss_reg']
    loss.backward()
    optimizer.step()
    # Report every ``report_interval`` steps, and also after the final step:
    # otherwise the updates made after the last periodic report would never
    # be evaluated or counted towards the best results.
    if i % report_interval == 0 or i == niters - 1:
        print('-'*30)
        acc_seen, acc_novel, H, acc_zs = eval_zs_gzsl(
            dataloader, model, config.device, batch_size=config.batch_size)
        # Best GZSL entry is chosen by H; ZSL accuracy is tracked separately.
        if H > best_performance[2]:
            best_performance = [acc_novel, acc_seen, H, acc_zs]
        if acc_zs > best_performance_zsl:
            best_performance_zsl = acc_zs
        # The console line shows the current losses but the *best*
        # accuracies so far; per-evaluation accuracies go to wandb below.
        print('iter/epoch=%d/%d | loss=%.3f, loss_CE=%.3f, loss_cal=%.3f, '
              'loss_reg=%.3f | acc_unseen=%.3f, acc_seen=%.3f, H=%.3f | '
              'acc_zs=%.3f' % (
                  i, int(i//report_interval),
                  loss.item(), loss_CE.item(), loss_cal.item(),
                  loss_reg.item(),
                  best_performance[0], best_performance[1],
                  best_performance[2], best_performance_zsl))
        wandb.log({
            'iter': i,
            'loss': loss.item(),
            'loss_CE': loss_CE.item(),
            'loss_cal': loss_cal.item(),
            'loss_reg': loss_reg.item(),
            'acc_unseen': acc_novel,
            'acc_seen': acc_seen,
            'H': H,
            'acc_zs': acc_zs,
            'best_acc_unseen': best_performance[0],
            'best_acc_seen': best_performance[1],
            'best_H': best_performance[2],
            'best_acc_zs': best_performance_zsl
        })