diff --git a/mace/tools/train.py b/mace/tools/train.py index 1ab86f82..8e293bee 100644 --- a/mace/tools/train.py +++ b/mace/tools/train.py @@ -279,16 +279,17 @@ def train( if rank == 0: if valid_loss >= lowest_loss: patience_counter += 1 - if patience_counter >= patience and epoch < swa.start: - logging.info( - f"Stopping optimization after {patience_counter} epochs without improvement and starting Stage Two" - ) - epoch = swa.start - elif patience_counter >= patience and epoch >= swa.start: - logging.info( - f"Stopping optimization after {patience_counter} epochs without improvement" - ) - break + if patience_counter >= patience: + if swa is not None and epoch < swa.start: + logging.info( + f"Stopping optimization after {patience_counter} epochs without improvement and starting Stage Two" + ) + epoch = swa.start + else: + logging.info( + f"Stopping optimization after {patience_counter} epochs without improvement" + ) + break if save_all_checkpoints: param_context = ( ema.average_parameters()