Skip to content

Commit

Permalink
train: update scheduler
Browse files Browse the repository at this point in the history
  • Loading branch information
jeffjennings committed Nov 29, 2023
1 parent c6e73e9 commit bacdeec
Showing 1 changed file with 2 additions and 5 deletions.
7 changes: 2 additions & 5 deletions src/mpol/training.py
Original file line number Diff line number Diff line change
Expand Up @@ -273,11 +273,8 @@ def train(self, model, dataset):
# update model parameters via gradient descent
self._optimizer.step()

# store current training parameter values
# TODO: store hyperpar values, access in crossval.py
self._train_state["kfold"] = self._kfold
self._train_state["epoch"] = count
self._train_state["learn_rate"] = self._optimizer.state_dict()['param_groups'][0]['lr']
if self._scheduler is not None:
self._scheduler.step(loss)

# generate optional fit diagnostics
if self._train_diag_step is not None and (count % self._train_diag_step == 0 or count == self._epochs or self.loss_convergence(np.array(losses))):
Expand Down

0 comments on commit bacdeec

Please sign in to comment.