diff --git a/src/mpol/training.py b/src/mpol/training.py
index 43fa423d..70167972 100644
--- a/src/mpol/training.py
+++ b/src/mpol/training.py
@@ -273,11 +273,8 @@ def train(self, model, dataset):
             # update model parameters via gradient descent
             self._optimizer.step()
 
-            # store current training parameter values
-            # TODO: store hyperpar values, access in crossval.py
-            self._train_state["kfold"] = self._kfold
-            self._train_state["epoch"] = count
-            self._train_state["learn_rate"] = self._optimizer.state_dict()['param_groups'][0]['lr']
+            if self._scheduler is not None:
+                self._scheduler.step(loss)
 
             # generate optional fit diagnostics
             if self._train_diag_step is not None and (count % self._train_diag_step == 0 or count == self._epochs or self.loss_convergence(np.array(losses))):
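
For reference, calling `scheduler.step(loss)` with the current loss matches the metric-based signature of `torch.optim.lr_scheduler.ReduceLROnPlateau` (most other PyTorch schedulers take no argument to `step()`). Below is a minimal, self-contained sketch of how such a scheduler could be attached to an optimizer and stepped once per epoch, in the same position as the patched train loop above. The model, loss, and hyperparameter values here are placeholders chosen for illustration, not MPoL's actual training setup.

import torch

# Placeholder model and optimizer, standing in for whatever train() receives.
model = torch.nn.Linear(10, 1)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-2)

# ReduceLROnPlateau.step() expects the monitored metric, matching
# self._scheduler.step(loss) in the diff; it lowers the learning rate
# when the loss stops improving.
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer, mode="min", factor=0.5, patience=10
)

for epoch in range(100):
    optimizer.zero_grad()
    loss = model(torch.randn(4, 10)).pow(2).mean()  # placeholder loss
    loss.backward()

    # update model parameters via gradient descent
    optimizer.step()

    # step the scheduler with the current loss, as the patched train() does
    scheduler.step(loss)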