Skip to content

Commit

Permalink
train: track learn rate
Browse files Browse the repository at this point in the history
jeffjennings committed Nov 29, 2023
1 parent bacdeec commit 16e9499
Showing 1 changed file with 2 additions and 1 deletion.
3 changes: 2 additions & 1 deletion src/mpol/training.py
Original file line number Diff line number Diff line change
@@ -223,7 +223,7 @@ def train(self, model, dataset):

count = 0
losses = []
self._train_state = {}
learn_rates = []

# guess initial strengths for regularizers in `self._regularizers`
# that have 'guess':True
@@ -275,6 +275,7 @@ def train(self, model, dataset):

if self._scheduler is not None:
self._scheduler.step(loss)
learn_rates.append(self._optimizer.param_groups[0]['lr'])

# generate optional fit diagnostics
if self._train_diag_step is not None and (count % self._train_diag_step == 0 or count == self._epochs or self.loss_convergence(np.array(losses))):

0 comments on commit 16e9499

Please sign in to comment.