diff --git a/pypots/classification/base.py b/pypots/classification/base.py index 9dcb8a52..ca587c29 100644 --- a/pypots/classification/base.py +++ b/pypots/classification/base.py @@ -353,7 +353,7 @@ def _train_model( # save the model if necessary self._auto_save_model_if_necessary( - confirm_saving=mean_loss < self.best_loss, + confirm_saving=self.best_epoch == epoch, saving_name=f"{self.__class__.__name__}_epoch{epoch}_loss{mean_loss}", ) diff --git a/pypots/clustering/crli/model.py b/pypots/clustering/crli/model.py index 330270bd..8210fc6d 100644 --- a/pypots/clustering/crli/model.py +++ b/pypots/clustering/crli/model.py @@ -309,7 +309,7 @@ def _train_model( # save the model if necessary self._auto_save_model_if_necessary( - confirm_saving=mean_loss < self.best_loss, + confirm_saving=self.best_epoch == epoch, saving_name=f"{self.__class__.__name__}_epoch{epoch}_loss{mean_loss:.4f}", ) diff --git a/pypots/clustering/vader/model.py b/pypots/clustering/vader/model.py index f306ffa0..cb5f3201 100644 --- a/pypots/clustering/vader/model.py +++ b/pypots/clustering/vader/model.py @@ -317,7 +317,7 @@ def _train_model( # save the model if necessary self._auto_save_model_if_necessary( - confirm_saving=mean_loss < self.best_loss, + confirm_saving=self.best_epoch == epoch, saving_name=f"{self.__class__.__name__}_epoch{epoch}_loss{mean_loss}", ) diff --git a/pypots/forecasting/base.py b/pypots/forecasting/base.py index 1806272c..1cf41c1b 100644 --- a/pypots/forecasting/base.py +++ b/pypots/forecasting/base.py @@ -350,7 +350,7 @@ def _train_model( # save the model if necessary self._auto_save_model_if_necessary( - confirm_saving=mean_loss < self.best_loss, + confirm_saving=self.best_epoch == epoch, saving_name=f"{self.__class__.__name__}_epoch{epoch}_loss{mean_loss}", ) diff --git a/pypots/forecasting/csdi/model.py b/pypots/forecasting/csdi/model.py index e25c5ff7..734d3870 100644 --- a/pypots/forecasting/csdi/model.py +++ b/pypots/forecasting/csdi/model.py @@ -313,7 +313,7 @@ def _train_model( # save the model if necessary self._auto_save_model_if_necessary( - confirm_saving=mean_loss < self.best_loss, + confirm_saving=self.best_epoch == epoch, saving_name=f"{self.__class__.__name__}_epoch{epoch}_loss{mean_loss}", ) diff --git a/pypots/imputation/base.py b/pypots/imputation/base.py index 23739638..6ca8bcb2 100644 --- a/pypots/imputation/base.py +++ b/pypots/imputation/base.py @@ -350,7 +350,7 @@ def _train_model( # save the model if necessary self._auto_save_model_if_necessary( - confirm_saving=mean_loss < self.best_loss, + confirm_saving=self.best_epoch == epoch, saving_name=f"{self.__class__.__name__}_epoch{epoch}_loss{mean_loss}", ) diff --git a/pypots/imputation/csdi/model.py b/pypots/imputation/csdi/model.py index 148ba3af..7d7138e1 100644 --- a/pypots/imputation/csdi/model.py +++ b/pypots/imputation/csdi/model.py @@ -293,7 +293,7 @@ def _train_model( # save the model if necessary self._auto_save_model_if_necessary( - confirm_saving=mean_loss < self.best_loss, + confirm_saving=self.best_epoch == epoch, saving_name=f"{self.__class__.__name__}_epoch{epoch}_loss{mean_loss}", ) diff --git a/pypots/imputation/gpvae/model.py b/pypots/imputation/gpvae/model.py index 0062fc38..85314b28 100644 --- a/pypots/imputation/gpvae/model.py +++ b/pypots/imputation/gpvae/model.py @@ -320,7 +320,7 @@ def _train_model( # save the model if necessary self._auto_save_model_if_necessary( - confirm_saving=mean_loss < self.best_loss, + confirm_saving=self.best_epoch == epoch, saving_name=f"{self.__class__.__name__}_epoch{epoch}_loss{mean_loss}", ) diff --git a/pypots/imputation/usgan/model.py b/pypots/imputation/usgan/model.py index e5c607ab..f29eaced 100644 --- a/pypots/imputation/usgan/model.py +++ b/pypots/imputation/usgan/model.py @@ -334,7 +334,7 @@ def _train_model( # save the model if necessary self._auto_save_model_if_necessary( - confirm_saving=mean_loss < self.best_loss, + confirm_saving=self.best_epoch == epoch, saving_name=f"{self.__class__.__name__}_epoch{epoch}_loss{mean_loss}", )