Skip to content

Commit

Permalink
run_crossval: use local model for train/test
Browse files Browse the repository at this point in the history
jeffjennings committed Nov 29, 2023
1 parent ed07674 commit 01858b2
Showing 1 changed file with 6 additions and 1 deletion.
7 changes: 6 additions & 1 deletion src/mpol/crossval.py
Original file line number Diff line number Diff line change
@@ -200,8 +200,13 @@ def run_crossval(self, dataset):
)

# run training
loss, loss_history = trainer.train(self._model, train_set)
loss, loss_history = trainer.train(model, train_set)

# run testing
all_scores.append(trainer.test(model, test_set))

# store objects from the most recent kfold for diagnostics
self._model = model
if self._store_cv_diagnostics:
self._diagnostics["loss_histories"].append(loss_history)
# update regularizer strength values

0 comments on commit 01858b2

Please sign in to comment.