From 8d4849c8f7e49c83bb2ec50b0d1479d80454410a Mon Sep 17 00:00:00 2001
From: Enno Hermann
Date: Tue, 5 Dec 2023 14:48:05 +0100
Subject: [PATCH] fix: make --continue_path work again (#131)

* fix: make --continue_path work again

There were errors when loading models with `--continue_path` because #121
changed
https://github.com/coqui-ai/Trainer/blob/47781f58d2714d8139dc00f57dbf64bcc14402b7/trainer/trainer.py#L1924
to save the `model_loss` as `{"train_loss": train_loss, "eval_loss": eval_loss}`
instead of just a float.

https://github.com/coqui-ai/Trainer/blob/47781f58d2714d8139dc00f57dbf64bcc14402b7/trainer/io.py#L195
still saves a float in `model_loss`, so loading the best model would still
work fine. Loading a model via `--restore_path` also works fine because in
that case the best loss is reset and not initialised from the saved model.

This fix:
- changes `save_best_model()` to also save a dict with train and eval loss,
  so that this is consistent everywhere
- ensures that the model loader can handle both float and dict `model_loss`
  for backwards compatibility
- adds relevant test cases

* fixup! fix: make --continue_path work again
---
 tests/test_continue_train.py | 13 ++++++++++++-
 trainer/io.py                |  5 ++++-
 trainer/trainer.py           | 15 +++++++++++----
 3 files changed, 27 insertions(+), 6 deletions(-)

diff --git a/tests/test_continue_train.py b/tests/test_continue_train.py
index 6bd158f..cc6632b 100644
--- a/tests/test_continue_train.py
+++ b/tests/test_continue_train.py
@@ -14,8 +14,19 @@ def test_continue_train():
     continue_path = max(glob.glob(os.path.join(output_path, "*/")), key=os.path.getmtime)
     number_of_checkpoints = len(glob.glob(os.path.join(continue_path, "*.pth")))
 
-    command_continue = f"python tests/utils/train_mnist.py --continue_path {continue_path}"
+    # Continue training from the best model
+    command_continue = f"python tests/utils/train_mnist.py --continue_path {continue_path} --coqpit.run_eval_steps=1"
     run_cli(command_continue)
 
     assert number_of_checkpoints < len(glob.glob(os.path.join(continue_path, "*.pth")))
+
+    # Continue training from the last checkpoint
+    for best in glob.glob(os.path.join(continue_path, "best_model*")):
+        os.remove(best)
+    run_cli(command_continue)
+
+    # Continue training from a specific checkpoint
+    restore_path = os.path.join(continue_path, "checkpoint_5.pth")
+    command_continue = f"python tests/utils/train_mnist.py --restore_path {restore_path}"
+    run_cli(command_continue)
     shutil.rmtree(continue_path)
diff --git a/trainer/io.py b/trainer/io.py
index 6e08aea..eb34082 100644
--- a/trainer/io.py
+++ b/trainer/io.py
@@ -180,7 +180,10 @@ def save_best_model(
     save_func=None,
     **kwargs,
 ):
-    if current_loss < best_loss:
+    use_eval_loss = current_loss["eval_loss"] is not None and best_loss["eval_loss"] is not None
+    if (use_eval_loss and current_loss["eval_loss"] < best_loss["eval_loss"]) or (
+        not use_eval_loss and current_loss["train_loss"] < best_loss["train_loss"]
+    ):
         best_model_name = f"best_model_{current_step}.pth"
         checkpoint_path = os.path.join(out_path, best_model_name)
         logger.info(" > BEST MODEL : %s", checkpoint_path)
diff --git a/trainer/trainer.py b/trainer/trainer.py
index cc74024..a62b2b1 100644
--- a/trainer/trainer.py
+++ b/trainer/trainer.py
@@ -451,7 +451,7 @@ def __init__(  # pylint: disable=dangerous-default-value
         self.epochs_done = 0
         self.restore_step = 0
         self.restore_epoch = 0
-        self.best_loss = float("inf")
+        self.best_loss = {"train_loss": float("inf"), "eval_loss": float("inf") if self.config.run_eval else None}
         self.train_loader = None
         self.test_loader = None
         self.eval_loader = None
@@ -1724,8 +1724,15 @@ def _restore_best_loss(self):
             logger.info(" > Restoring best loss from %s ...", os.path.basename(self.args.best_path))
             ch = load_fsspec(self.args.restore_path, map_location="cpu")
             if "model_loss" in ch:
-                self.best_loss = ch["model_loss"]
-                logger.info(" > Starting with loaded last best loss %f", self.best_loss)
+                if isinstance(ch["model_loss"], dict):
+                    self.best_loss = ch["model_loss"]
+                # For backwards-compatibility:
+                elif isinstance(ch["model_loss"], float):
+                    if self.config.run_eval:
+                        self.best_loss = {"train_loss": None, "eval_loss": ch["model_loss"]}
+                    else:
+                        self.best_loss = {"train_loss": ch["model_loss"], "eval_loss": None}
+                logger.info(" > Starting with loaded last best loss %s", self.best_loss)
 
     def test(self, model=None, test_samples=None) -> None:
         """Run evaluation steps on the test data split. You can either provide the model and the test samples
@@ -1907,7 +1914,7 @@ def save_best_model(self) -> None:
 
         # save the model and update the best_loss
         self.best_loss = save_best_model(
-            eval_loss if eval_loss else train_loss,
+            {"train_loss": train_loss, "eval_loss": eval_loss},
             self.best_loss,
             self.config,
             self.model,
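
The hunks above all revolve around one idea: `model_loss` in a checkpoint is now
a dict `{"train_loss": ..., "eval_loss": ...}`, while checkpoints written before
this change store a single float. The following is a minimal standalone sketch
(not taken from the Trainer code base; the helper name `normalize_model_loss` is
hypothetical) of how such a value can be normalised so that comparisons like the
one in `save_best_model()` always see the dict form:

    from typing import Dict, Optional, Union

    def normalize_model_loss(
        model_loss: Union[float, Dict[str, Optional[float]]], run_eval: bool
    ) -> Dict[str, Optional[float]]:
        """Return `model_loss` as {"train_loss": ..., "eval_loss": ...}.

        New checkpoints already store a dict; old checkpoints store a single
        float, which is read as the eval loss when evaluation is enabled and as
        the train loss otherwise (mirroring the backwards-compatibility branch
        added to `_restore_best_loss()` above).
        """
        if isinstance(model_loss, dict):
            return model_loss
        if run_eval:
            return {"train_loss": None, "eval_loss": model_loss}
        return {"train_loss": model_loss, "eval_loss": None}

    # Old-style checkpoint: `model_loss` was saved as a plain float.
    assert normalize_model_loss(0.42, run_eval=True) == {"train_loss": None, "eval_loss": 0.42}
    # New-style checkpoint: the dict is passed through unchanged.
    assert normalize_model_loss({"train_loss": 0.5, "eval_loss": 0.4}, run_eval=True) == {
        "train_loss": 0.5,
        "eval_loss": 0.4,
    }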