-
Notifications
You must be signed in to change notification settings - Fork 29
/
Train.py
43 lines (32 loc) · 1.04 KB
/
Train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
from Dataloaders import TrainDataset
from Model import SiameseConvNet, ContrastiveLoss, distance_metric
from torch.optim import RMSprop, Adam
from torch.utils.data import DataLoader
import numpy as np
from torch import save
# Network, loss and optimizer. Adam receives only the model parameters, so
# every other optimizer setting stays at the library default.
model = SiameseConvNet()
criterion = ContrastiveLoss()
optimizer = Adam(params=model.parameters())

# Training data, served in shuffled mini-batches of 16 samples.
train_dataset = TrainDataset()
train_loader = DataLoader(
    dataset=train_dataset,
    batch_size=16,
    shuffle=True,
)
def checkpoint(epoch):
    """Save the model's state dict to ``Models/model_epoch_<epoch>``.

    The target directory is created when it does not exist yet, so a fresh
    checkout does not fail with ``FileNotFoundError`` after an epoch has
    already been trained.

    Args:
        epoch: Integer epoch number; formatted into the file name with ``%d``.
    """
    import os  # local import keeps this block self-contained

    file_path = "Models/model_epoch_%d" % epoch
    os.makedirs(os.path.dirname(file_path), exist_ok=True)
    with open(file_path, 'wb') as f:
        save(model.state_dict(), f)
def train(epoch):
    """Run one pass over ``train_loader``, printing per-batch and mean loss.

    Operates on the module-level ``model``, ``criterion``, ``optimizer`` and
    ``train_loader``. For every batch it zeroes the gradients, runs the model
    on the two inputs, computes the loss, back-propagates and steps the
    optimizer.

    Args:
        epoch: Epoch number; used only to label the printed progress lines.
    """
    total_loss = 0
    num_batches = 0
    for batch_index, data in enumerate(train_loader):
        # Each batch is indexed as (first input, second input, label).
        A = data[0]
        B = data[1]
        label = data[2].float()

        optimizer.zero_grad()
        # Call the module itself rather than .forward() so that any
        # registered hooks are run as well.
        f_A, f_B = model(A, B)
        loss = criterion(f_A, f_B, label)
        batch_loss = loss.item()
        total_loss += batch_loss
        num_batches += 1
        print('Epoch {}, batch {}, loss={}'.format(epoch, batch_index, batch_loss))
        loss.backward()
        optimizer.step()

    if num_batches == 0:
        # Nothing was iterated; report it instead of dividing by zero.
        print('Epoch {}: no batches to train on'.format(epoch))
        return

    # Average over the batches actually processed. The previous divisor,
    # len(train_dataset) // 16, hard-coded the batch size and left out the
    # trailing partial batch, which skewed the reported mean.
    print('Average epoch loss={}'.format(total_loss / num_batches))
# Train for 20 epochs, numbered from 1, saving a checkpoint after each one.
NUM_EPOCHS = 20
for epoch in range(1, NUM_EPOCHS + 1):
    train(epoch)
    checkpoint(epoch)