training.py
import copy

import numpy as np
import torch
import tqdm

from evaluation import compute_metrics
from loss import focal_loss


def beyond_acc(y_true, y_pred):
    """Compute precision, recall, and F1 from binary label tensors."""
    tp = (y_true * y_pred).sum().to(torch.float32)
    fp = ((1 - y_true) * y_pred).sum().to(torch.float32)
    fn = (y_true * (1 - y_pred)).sum().to(torch.float32)
    epsilon = 1e-7  # guards against division by zero
    train_prec = tp / (tp + fp + epsilon)
    train_recall = tp / (tp + fn + epsilon)
    train_f1 = 2 * (train_prec * train_recall) / (train_prec + train_recall + epsilon)
    return train_prec, train_recall, train_f1
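

# Illustrative sanity check for `beyond_acc` (a sketch, not part of the
# training pipeline): with 2 true positives, 1 false positive, and 1 false
# negative, precision, recall, and F1 all come out to roughly 2/3.
def _beyond_acc_demo():
    y_true = torch.tensor([1.0, 0.0, 1.0, 1.0])
    y_pred = torch.tensor([1.0, 1.0, 0.0, 1.0])
    prec, recall, f1 = beyond_acc(y_true, y_pred)
    print(f"precision={prec.item():.3f} recall={recall.item():.3f} f1={f1.item():.3f}")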


class EarlyStopping:
    """Stop training when the validation metric stops improving.

    This tracker is used with a score where higher is better (validation F1
    in `train_model`), so it starts from -inf and treats increases as
    improvement.
    """

    def __init__(self, patience=3, min_delta=0.0):
        self.patience = patience  # epochs with no improvement allowed before stopping
        self.min_delta = min_delta  # minimum change to count as an improvement
        self.counter = 0  # consecutive epochs without improvement
        self.best_score = -np.inf
        self.best_model = None

    # Returns True when the metric has not improved by `min_delta` for `patience` epochs.
    def early_stop_check(self, val_score, model):
        if val_score > self.best_score + self.min_delta:
            self.best_score = val_score
            self.best_model = copy.deepcopy(model.state_dict())
            self.counter = 0  # reset the counter on improvement
        else:
            self.counter += 1
            if self.counter >= self.patience:
                return True
        return False
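

# Illustrative use of EarlyStopping (a sketch; the dummy model and the F1
# sequence below are hypothetical): with patience=2, the second consecutive
# non-improving score triggers the stop.
def _early_stopping_demo():
    dummy = torch.nn.Linear(2, 1)
    stopper = EarlyStopping(patience=2, min_delta=0.0)
    for epoch, f1 in enumerate([0.50, 0.62, 0.61, 0.60]):
        if stopper.early_stop_check(f1, dummy):
            print(f"stopped at epoch {epoch}, best score {stopper.best_score:.2f}")
            break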


def train_model(model, optimizer, X_train, y_train, X_val, y_val,
                train_prefixes=None, val_prefixes=None,
                batch_size=64, epochs=50, alpha=0.9):
    early_stopping = EarlyStopping(patience=3, min_delta=0.0)
    # Iterate over shuffled index batches; tensors are sliced per batch below.
    training_loader = torch.utils.data.DataLoader(
        range(len(X_train)),
        batch_size=batch_size,
        shuffle=True,
    )
    for t in range(epochs):
        curr_loss = 0.0
        curr_acc = 0.0
        curr_f1 = 0.0
        model.train()
        model.model.eval()  # keep the wrapped backbone in eval mode
        with tqdm.tqdm(total=len(training_loader)) as pbar:
            for i, indices in enumerate(training_loader):
                optimizer.zero_grad()
                X_batch = X_train[indices]
                y_batch = y_train[indices]
                if train_prefixes is not None:
                    batch_prefixes = train_prefixes[indices]
                else:
                    batch_prefixes = None
                # Forward pass: compute predicted y by passing the batch to the model.
                y_pred = model(X_batch, prefixes=batch_prefixes, training=True)
                # Focal loss with a higher weight on positives, plus the model's
                # regularization term.
                loss = focal_loss(y_pred, y_batch, alpha=alpha, reduction="mean")
                loss += model.reg()
                loss.backward()
                optimizer.step()
                y_labels = (y_pred > 0.5).to(torch.float32)
                # Use the actual batch length: the last batch may be smaller
                # than `batch_size`.
                train_acc = torch.sum(y_labels == y_batch) / len(indices) * 100
                _, _, train_f1 = beyond_acc(y_batch, y_labels)
                curr_loss += loss.item()
                curr_acc += train_acc.item()
                curr_f1 += train_f1.item()
                pbar.set_postfix(
                    training_loss=f"{curr_loss / (i + 1):.4f}",
                    train_f1=f"{curr_f1 / (i + 1):.4f}",
                    train_acc=f"{curr_acc / (i + 1):.2f}",
                )
                pbar.update(1)
        model.eval()
        with torch.no_grad():
            y_val_pred = model(X_val, prefixes=val_prefixes, training=False)
            val_results = compute_metrics(y_val.detach().cpu().numpy(),
                                          y_val_pred.detach().cpu().numpy())
        # Early stopping tracks validation F1 (higher is better).
        if early_stopping.early_stop_check(val_results["F1 Score"], model):
            print(f"Early stopping performed on epoch {t}")
            break
    model.load_state_dict(early_stopping.best_model)
    return model
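

# A hypothetical smoke test for `train_model` (a sketch under assumptions:
# `_ToyWrapper`, its shapes, and the hyperparameters below are illustrative,
# and the repo's `focal_loss` / `compute_metrics` behave as used above).
# The wrapper mimics the interface `train_model` expects: a `.model` backbone
# kept in eval mode, a `.reg()` penalty term, and a forward pass that accepts
# `prefixes` and `training` keyword arguments.
class _ToyWrapper(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.model = torch.nn.Identity()  # stand-in for the frozen backbone
        self.head = torch.nn.Linear(8, 1)

    def reg(self):
        return 1e-4 * self.head.weight.norm()  # simple norm penalty

    def forward(self, x, prefixes=None, training=False):
        return torch.sigmoid(self.head(self.model(x))).squeeze(-1)


def _train_model_demo():
    X_tr, y_tr = torch.randn(256, 8), torch.randint(0, 2, (256,)).float()
    X_va, y_va = torch.randn(64, 8), torch.randint(0, 2, (64,)).float()
    toy = _ToyWrapper()
    opt = torch.optim.Adam(toy.parameters(), lr=1e-3)
    return train_model(toy, opt, X_tr, y_tr, X_va, y_va, batch_size=32, epochs=2)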