Skip to content

Commit

Permalink
Added early stopping
Browse files Browse the repository at this point in the history
  • Loading branch information
sovit-123 committed Jul 29, 2024
1 parent 8221971 commit 3b7fdf1
Show file tree
Hide file tree
Showing 2 changed files with 47 additions and 2 deletions.
15 changes: 14 additions & 1 deletion train.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@
save_model, save_loss_plot,
show_tranformed_image,
save_mAP, save_model_state, SaveBestModel,
yaml_save, init_seeds
yaml_save, init_seeds, EarlyStopping
)
from utils.logging import (
set_log, coco_log,
Expand Down Expand Up @@ -189,6 +189,13 @@ def parse_opt():
action='store_true',
help='use automatic mixed precision'
)
parser.add_argument(
'--patience',
default=10,
help='number of epochs to wait for when mAP does not increase to \
trigger early stopping',
type=int
)
parser.add_argument(
'--seed',
default=0,
Expand Down Expand Up @@ -400,6 +407,7 @@ def main(args):
scheduler = None

save_best_model = SaveBestModel()
early_stopping = EarlyStopping()

for epoch in range(start_epochs, NUM_EPOCHS):
train_loss_hist.reset()
Expand Down Expand Up @@ -560,6 +568,11 @@ def main(args):
data_configs,
args['model']
)

# Early stopping check.
early_stopping(stats[0])
if early_stopping.early_stop:
break

# Save models to Weights&Biases.
if not args['disable_wandb']:
Expand Down
34 changes: 33 additions & 1 deletion utils/general.py
Original file line number Diff line number Diff line change
Expand Up @@ -337,4 +337,36 @@ def yaml_save(file_path=None, data={}):
{k: str(v) if isinstance(v, Path) else v for k, v in data.items()},
f,
sort_keys=False
)
)

class EarlyStopping():
"""
Early stopping to stop the training when the mAP does not improve after
certain epochs.
"""
def __init__(self, patience=10, min_delta=0):
"""
:param patience: how many epochs to wait before stopping mAP
is not improving
:param min_delta: minimum difference between new mAP and old mAP for
new mAP to be considered as an improvement
"""
self.patience = patience
self.min_delta = min_delta
self.counter = 0
self.best_map = None
self.early_stop = False

def __call__(self, map):
if self.best_map == None:
self.best_map = map
elif map - self.best_map > self.min_delta:
self.best_map = map
# reset counter if validation loss improves
self.counter = 0
elif map - self.best_map < self.min_delta:
self.counter += 1
print(f"INFO: Early stopping counter {self.counter} of {self.patience}")
if self.counter >= self.patience:
print('INFO: Early stopping')
self.early_stop = True

0 comments on commit 3b7fdf1

Please sign in to comment.