From dcaa0ae5287732955b4ce9123aeb5d17b00cfd4c Mon Sep 17 00:00:00 2001 From: Hongyu Yu <74477906+Hongyu-yu@users.noreply.github.com> Date: Thu, 5 Sep 2024 09:24:33 +0800 Subject: [PATCH] Fix bug about undefined swa Bug will come out when swa is not used at the end of training. ``` mace/tools/train.py", line 262, in train if patience_counter >= patience and epoch < swa.start: AttributeError: 'NoneType' object has no attribute 'start' ``` --- mace/tools/train.py | 21 +++++++++++---------- 1 file changed, 11 insertions(+), 10 deletions(-) diff --git a/mace/tools/train.py b/mace/tools/train.py index b38bce16..5af25456 100644 --- a/mace/tools/train.py +++ b/mace/tools/train.py @@ -259,16 +259,17 @@ def train( if valid_loss >= lowest_loss: patience_counter += 1 - if patience_counter >= patience and epoch < swa.start: - logging.info( - f"Stopping optimization after {patience_counter} epochs without improvement and starting Stage Two" - ) - epoch = swa.start - elif patience_counter >= patience and epoch >= swa.start: - logging.info( - f"Stopping optimization after {patience_counter} epochs without improvement" - ) - break + if patience_counter >= patience: + if swa is not None and epoch < swa.start: + logging.info( + f"Stopping optimization after {patience_counter} epochs without improvement and starting Stage Two" + ) + epoch = swa.start + else: + logging.info( + f"Stopping optimization after {patience_counter} epochs without improvement" + ) + break if save_all_checkpoints: param_context = ( ema.average_parameters()