Skip to content

Commit

Permalink
solved bug in test error history
Browse files Browse the repository at this point in the history
  • Loading branch information
bowen-bd committed Feb 22, 2024
1 parent 84226d7 commit af44916
Showing 1 changed file with 6 additions and 3 deletions.
9 changes: 6 additions & 3 deletions chgnet/trainer/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -263,7 +263,8 @@ def train(
print("---------Evaluate Model on Test Set---------------")
for file in os.listdir(save_dir):
if file.startswith("bestE_"):
best_checkpoint = torch.load(os.path.join(save_dir, file))
test_file = file
best_checkpoint = torch.load(os.path.join(save_dir, test_file))

self.model.load_state_dict(best_checkpoint["model"]["state_dict"])
if save_test_result:
Expand All @@ -274,8 +275,10 @@ def train(
test_mae = self._validate(
test_loader, is_test=True, test_result_save_path=None
)
self.training_history[key]["test"] = [test_mae[key] for key in self.targets]
self.save(filename=os.path.join(save_dir, file))

for key in self.targets:
self.training_history[key]["test"] = test_mae[key]
self.save(filename=os.path.join(save_dir, test_file))

def _train(self, train_loader: DataLoader, current_epoch: int) -> dict:
"""Train all data for one epoch.
Expand Down

0 comments on commit af44916

Please sign in to comment.