-
Notifications
You must be signed in to change notification settings - Fork 0
/
train_forward.py
114 lines (94 loc) · 4.25 KB
/
train_forward.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
import os
import argparse
import torch
import wandb
from utils import get_device, get_parameter_count
from dataset import load_symbol_dataset
import models
# Command-line hyperparameters for this training stage; every flag is optional
# (argparse default None) and is parsed once at import time, as in the original.
_ARG_SPECS = (
    ("-i", str, "Wandb run id"),
    ("-n", int, "Number of hidden units"),
    ("-l", int, "Number of layers"),
    ("-e", int, "Number of epochs"),
    ("-b", int, "Batch size"),
    ("-lr", float, "Learning rate"),
)
parser = argparse.ArgumentParser("Train symbol forward model")
for flag, flag_type, description in _ARG_SPECS:
    parser.add_argument(flag, type=flag_type, help=description)
args = parser.parse_args()
# Resume an existing wandb run (resume="must" fails if the id is unknown) so
# config values written by earlier pipeline stages — latent_dim, action_dim,
# n_attention_heads, save_folder — can be read back below.
run = wandb.init(entity="colorslab", project="active_exploration", resume="must", id=args.i)
device = get_device()
wandb.config.update({"device": device}, allow_val_change=True)
# Forward model input: object latent vector concatenated with the action vector.
input_dim = run.config["latent_dim"] + run.config["action_dim"]
model = models.SymbolForward(input_dim=input_dim, hidden_dim=args.n,
                             output_dim=run.config["latent_dim"], num_layers=args.l,
                             num_heads=run.config["n_attention_heads"]).to(device)
print(model)
print(f"Number of parameters: {get_parameter_count(model)}")
# Record this stage's hyperparameters under their own key in the run config.
wandb.config.update({"forward_model":
                     {"hidden_unit": args.n,
                      "layer": args.l,
                      "epoch": args.e,
                      "batch_size": args.b,
                      "learning_rate": args.lr}}, allow_val_change=True)
save_path = os.path.join(run.config["save_folder"], "symbol_forward.pt")
# Ensure the checkpoint directory exists before training starts.
if not os.path.exists(run.config["save_folder"]):
    os.makedirs(run.config["save_folder"])
# NOTE(review): load_symbol_dataset takes `device`, so batches presumably come
# back already on the training device — confirm against dataset.py.
train_set = load_symbol_dataset("train", run, device)
val_set = load_symbol_dataset("val", run, device)
train_loader = torch.utils.data.DataLoader(train_set, batch_size=args.b, shuffle=True)
val_loader = torch.utils.data.DataLoader(val_set, batch_size=args.b, shuffle=False)
# reduction="none" keeps the per-element BCE so the masks in the training loop
# can zero out invalid entries before the manual sum/mean reduction.
criterion = torch.nn.BCEWithLogitsLoss(reduction="none")
optimizer = torch.optim.Adam(lr=args.lr, params=model.parameters())
def _batch_losses(batch):
    """Compute masked object- and relation-prediction losses for one batch.

    ``batch`` unpacks to (zo_i, zr_i, a, zo_f, zr_f, m): initial object and
    relation symbols, the action, the target final symbols, and a mask `m`.
    `m` appears to be a per-object validity mask — TODO confirm against
    dataset.py; the relation mask is its outer product, so a relation entry
    counts only when both of its objects are valid. Per-element BCE (the
    module-level ``criterion``) is masked, summed per sample, then averaged
    over the batch. Returns (obj_loss, rel_loss) as scalar tensors.
    """
    zo_i, zr_i, a, zo_f, zr_f, m = (t.float() for t in batch)
    zi_cat = torch.cat([zo_i, a], dim=-1)
    zo_f_bar, zo_r_bar = model(zi_cat, zr_i)
    m = m.unsqueeze(2)
    m_rel = (m @ m.permute(0, 2, 1)).unsqueeze(1)
    obj_loss = (criterion(zo_f_bar, zo_f) * m).sum(dim=[1, 2]).mean()
    rel_loss = (criterion(zo_r_bar, zr_f) * m_rel).sum(dim=[1, 2, 3]).mean()
    return obj_loss, rel_loss


for e in range(args.e):
    # --- training epoch ---
    model.train()
    train_obj_loss = 0.0
    train_rel_loss = 0.0
    for batch in train_loader:
        obj_loss, rel_loss = _batch_losses(batch)
        loss = obj_loss + rel_loss
        optimizer.zero_grad()
        loss.backward()
        # Clip gradient norm to 10 to guard against exploding gradients.
        torch.nn.utils.clip_grad_norm_(model.parameters(), 10)
        optimizer.step()
        train_obj_loss += obj_loss.item()
        train_rel_loss += rel_loss.item()
    train_obj_loss /= len(train_loader)
    train_rel_loss /= len(train_loader)

    # --- validation epoch ---
    # Fix: switch to eval mode so any dropout/normalization layers inside
    # SymbolForward use inference behavior; the original validated in train
    # mode. Gradients are disabled for the whole pass.
    model.eval()
    val_obj_loss = 0.0
    val_rel_loss = 0.0
    with torch.no_grad():
        for batch in val_loader:
            obj_loss, rel_loss = _batch_losses(batch)
            val_obj_loss += obj_loss.item()
            val_rel_loss += rel_loss.item()
    val_obj_loss /= len(val_loader)
    val_rel_loss /= len(val_loader)

    # Report per-epoch averages to wandb and to the console.
    wandb.log({"train_obj_loss": train_obj_loss,
               "train_rel_loss": train_rel_loss,
               "val_obj_loss": val_obj_loss,
               "val_rel_loss": val_rel_loss})
    print(f"Epoch={e}, Train obj loss={train_obj_loss:.5f}, Train rel loss={train_rel_loss:.5f}, "
          f"Val obj loss={val_obj_loss:.5f}, Val rel loss={val_rel_loss:.5f}")
# Persist the trained forward model. The state dict is taken after moving the
# model to CPU so the checkpoint loads on machines without the training device.
sd = model.eval().cpu().state_dict()
torch.save(sd, save_path)
# Restore train mode and the original device so the in-memory model stays usable.
model.train().to(device)
# Upload the checkpoint file to the wandb run.
wandb.save(save_path)