From af4491625b4b6b558bce3faab1683e1e3b6c17d0 Mon Sep 17 00:00:00 2001 From: BowenD-UCB <84425382+BowenD-UCB@users.noreply.github.com> Date: Wed, 21 Feb 2024 17:06:43 -0800 Subject: [PATCH] solved bug in test error history --- chgnet/trainer/trainer.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/chgnet/trainer/trainer.py b/chgnet/trainer/trainer.py index d11d9b34..0843ed42 100644 --- a/chgnet/trainer/trainer.py +++ b/chgnet/trainer/trainer.py @@ -263,7 +263,8 @@ def train( print("---------Evaluate Model on Test Set---------------") for file in os.listdir(save_dir): if file.startswith("bestE_"): - best_checkpoint = torch.load(os.path.join(save_dir, file)) + test_file = file + best_checkpoint = torch.load(os.path.join(save_dir, test_file)) self.model.load_state_dict(best_checkpoint["model"]["state_dict"]) if save_test_result: @@ -274,8 +275,10 @@ def train( test_mae = self._validate( test_loader, is_test=True, test_result_save_path=None ) - self.training_history[key]["test"] = [test_mae[key] for key in self.targets] - self.save(filename=os.path.join(save_dir, file)) + + for key in self.targets: + self.training_history[key]["test"] = test_mae[key] + self.save(filename=os.path.join(save_dir, test_file)) def _train(self, train_loader: DataLoader, current_epoch: int) -> dict: """Train all data for one epoch.