-
Notifications
You must be signed in to change notification settings - Fork 1
/
earlystop.py
56 lines (49 loc) · 2.1 KB
/
earlystop.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
"""
This code is part of an adaptation/modification from the original project available at:
https://github.com/peterwang512/CNNDetection
The original code was created by Wang et al. and is used here under the terms of the license
specified in the original project's repository. Any use of this adapted/modified code
must respect the terms of such license.
Adaptations and modifications made by: Daniel Cabanas Gonzalez
Modification date: 08/04/2024
"""
import numpy as np
import torch
class EarlyStopping:
    """Stop training once a monitored score stops improving for `patience` calls.

    The monitored score is treated as "higher is better" (the checkpoint
    message refers to validation accuracy). Each call compares the new score
    with the current reference ``best_score``; a non-improving score increments
    ``counter``, and once ``counter`` reaches ``patience`` the ``early_stop``
    flag is set. Whenever a score counts as an improvement the model is
    checkpointed via ``model.save_networks('best')`` and ``counter`` is reset.

    Note: ``best_score`` is not guaranteed to be a running maximum when
    ``delta > 0`` -- see the ``delta`` description in ``__init__``.
    """

    def __init__(self, patience=1, verbose=False, delta=0):
        """
        Args:
            patience (int): Number of consecutive non-improving calls tolerated
                before ``early_stop`` is set.
                Default: 1
            verbose (bool): If True, prints a message each time a checkpoint
                is saved. (The counter message is printed regardless.)
                Default: False
            delta (float): Tolerance applied to the comparison. A score counts
                as an improvement when ``score >= best_score - delta``. A
                NEGATIVE value (e.g. -0.001) therefore requires the score to
                rise by at least ``|delta|``; a positive value also accepts
                scores slightly below the best (and then lowers ``best_score``).
                Default: 0
        """
        self.patience = patience
        self.verbose = verbose
        self.counter = 0          # consecutive non-improving calls
        self.best_score = None    # reference score; None until the first call
        self.early_stop = False   # becomes True once counter reaches patience
        # Last checkpointed score; only used in the verbose message.
        # `np.Inf` was removed in NumPy 2.0, `np.inf` works on all versions.
        self.score_max = -np.inf
        self.delta = delta

    def __call__(self, score, model):
        """Record a new score, checkpointing or counting as appropriate.

        Args:
            score (float): Latest value of the monitored (higher-is-better)
                metric.
            model: Object exposing ``save_networks(name)``; it is called with
                ``'best'`` whenever ``score`` counts as an improvement.
        """
        if self.best_score is None:
            # The first observation always becomes the reference and is saved.
            self.best_score = score
            self.save_checkpoint(score, model)
        elif score < self.best_score - self.delta:
            # Not an improvement: count it and stop once patience is used up.
            self.counter += 1
            print(f'EarlyStopping counter: {self.counter} out of {self.patience}')
            if self.counter >= self.patience:
                self.early_stop = True
        else:
            # Improvement (within the `delta` tolerance): save and reset.
            self.best_score = score
            self.save_checkpoint(score, model)
            self.counter = 0

    def save_checkpoint(self, score, model):
        """Save ``model`` as the 'best' checkpoint and remember ``score``.

        Args:
            score (float): Score that triggered the checkpoint.
            model: Object exposing ``save_networks(name)``.
        """
        if self.verbose:
            print(f'Validation accuracy increased ({self.score_max:.6f} --> {score:.6f}). Saving model ...')
        model.save_networks('best')
        self.score_max = score