-
Notifications
You must be signed in to change notification settings - Fork 4
/
train.py
132 lines (110 loc) · 3.84 KB
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
"""
This script trains a VAEMario using early stopping.
"""
from time import time
import torch
import torch.optim as optim
from torch.optim import Optimizer
from torch.utils.data import TensorDataset, DataLoader
import numpy as np
from vae import VAEMario, load_data
def fit(
model: VAEMario,
optimizer: Optimizer,
data_loader: DataLoader,
device: str,
) -> torch.Tensor:
"""
Runs a training epoch: evaluating the model in
the data provided by the data_loader, computing
the ELBO loss inside the model, and propagating
the error backwards to the parameters.
"""
model.train()
running_loss = 0.0
for (levels,) in data_loader:
levels = levels.to(device)
optimizer.zero_grad()
q_z_given_x, p_x_given_z = model.forward(levels)
loss = model.elbo_loss_function(levels, q_z_given_x, p_x_given_z)
running_loss += loss.item()
loss.backward()
optimizer.step()
return running_loss / len(data_loader)
def test(
model: VAEMario,
test_loader: DataLoader,
device: str,
epoch: int = 0,
) -> torch.Tensor:
"""
Evaluates the current model on the test set,
returning the average loss.
"""
model.eval()
running_loss = 0.0
with torch.no_grad():
for (levels,) in test_loader:
levels.to(device)
q_z_given_x, p_x_given_z = model.forward(levels)
loss = model.elbo_loss_function(levels, q_z_given_x, p_x_given_z)
running_loss += loss.item()
print(f"Epoch {epoch}. Loss in test: {running_loss / len(test_loader)}")
return running_loss / len(test_loader)
def run(
max_epochs: int = 500,
batch_size: int = 64,
lr: int = 1e-3,
save_every: int = None,
overfit: bool = False,
):
"""
Trains a VAEMario on the dataset for the provided hyperparameters.
This training uses early stopping with a patience of 25 epochs,
by which we mean that we maintain the model with lowest test loss
and, if we don't see any improvement on it for 25 epochs in a row,
we stop the training. The model can be forced to overfit if you
pass overfit=True.
"""
# Defining the name of the experiment
timestamp = str(time()).replace(".", "")
comment = f"{timestamp}_mariovae"
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Loading the data.
training_tensors, test_tensors = load_data()
# Creating datasets.
dataset = TensorDataset(training_tensors)
data_loader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
test_dataset = TensorDataset(test_tensors)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=True)
# Loading the model and optimizer
print("Model:")
vae = VAEMario()
print(vae)
optimizer = optim.Adam(vae.parameters(), lr=lr)
# Training and testing.
print(f"Training experiment {comment}")
best_loss = np.Inf
n_without_improvement = 0
for epoch in range(max_epochs):
print(f"Epoch {epoch + 1} of {max_epochs}.")
_ = fit(vae, optimizer, data_loader, device)
test_loss = test(vae, test_loader, device, epoch)
if test_loss < best_loss:
best_loss = test_loss
n_without_improvement = 0
# Saving the best model so far.
torch.save(vae.state_dict(), f"./models/{comment}_final.pt")
else:
if not overfit:
n_without_improvement += 1
if save_every is not None and epoch % save_every == 0 and epoch != 0:
# Saving the model
print(f"Saving the model at checkpoint {epoch}.")
torch.save(vae.state_dict(), f"./models/{comment}_epoch_{epoch}.pt")
# Early stopping:
if n_without_improvement == 25:
print("Stopping early")
break
if __name__ == "__main__":
run()