Description
Motivation
Often we only want to train an algorithm until it has learned the intended behavior, and the total number of frames is only a proxy for that stopping condition.
I would like to propose adding a new callback that makes early stopping possible, much like the example in Lightning.
Currently, no callback can influence when the training loop ends, so my workaround is to set self.total_frames to 1, which isn't great.
Solution
- Add a new early-stopping hook
- Create a default hook implementation based on TrainerHookBase (a rough sketch follows below)
- Stop the training loop when the hook returns a positive (truthy) result
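To make the second bullet concrete, here is a minimal sketch of what such a default hook could look like. Everything here is an assumption on my side: the class name EarlyStoppingHook, the constructor arguments, the loss-threshold criterion and the registration name are placeholders, and I am assuming the hook follows the same __call__/state_dict/load_state_dict/register pattern as the existing trainer hooks.

```python
from tensordict import TensorDictBase
from torchrl.trainers import Trainer, TrainerHookBase


class EarlyStoppingHook(TrainerHookBase):
    """Tells the trainer to stop once the summed loss stays below a threshold.

    Sketch only: names, arguments and criterion are placeholders, not a final API.
    """

    def __init__(self, loss_threshold: float = 1e-2, patience: int = 5):
        self.loss_threshold = loss_threshold
        self.patience = patience
        self._hits = 0

    def __call__(self, batch: TensorDictBase, average_losses=None) -> bool:
        # average_losses would be the detached losses handed back by optim_steps;
        # while only random warm-up frames are collected there are no losses yet,
        # so we never stop in that phase.
        if average_losses is None:
            return False
        total_loss = sum(float(value) for value in average_losses.values())
        self._hits = self._hits + 1 if total_loss < self.loss_threshold else 0
        return self._hits >= self.patience

    def state_dict(self) -> dict:
        return {"hits": self._hits}

    def load_state_dict(self, state_dict: dict) -> None:
        self._hits = state_dict["hits"]

    def register(self, trainer: Trainer, name: str = "early_stopping"):
        # There is no "early_stopping" registration point yet; where exactly the
        # hook gets attached is the open question discussed below.
        trainer.register_module(name, self)
```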
The main question for me is where to call the hook and what it can observe. I would argue that both the batch and the losses_detached are relevant signals for the stopping hook, since one might want to stop once a certain return has been seen n times or the loss falls below a threshold.
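As a purely illustrative example of a batch-based signal (the ("next", "reward") key and the target value are assumptions and may differ between versions), a return-style criterion could be as simple as:

```python
def reached_target_return(batch, target_return: float = 100.0) -> bool:
    # Hypothetical criterion: stop-worthy once the mean reward in the collected
    # batch exceeds a target; counting how often this happened would live in the
    # hook itself (e.g. the patience counter in the sketch above).
    return batch.get(("next", "reward")).mean().item() >= target_return
```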
Therefore, either the call happens inside optim_steps(self, batch: TensorDictBase) and that function returns the stopping signal to train() (first approach), or the hook is called in train() and optim_steps returns the losses so that the early-stopping hook can observe them (second approach).
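A rough sketch of the first approach (the bool return type and the early_stopping_hook attribute are hypothetical) could be:

```python
def optim_steps(self, batch: TensorDictBase) -> bool:
    ...
    # losses_detached is produced by the existing optimization logic above;
    # the method now also evaluates the hook and reports whether to stop.
    return bool(self.early_stopping_hook(batch, losses_detached))
```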
An example of the second approach could look something like this:

```python
def train(self):
    ...
    for batch in self.collector:
        ...
        average_losses = None
        if self.collected_frames > self.collector.init_random_frames:
            average_losses = self.optim_steps(batch)
        if self.early_stopping_hook(batch, average_losses):
            self.save_trainer(force_save=True)
            break
        self._post_steps_hook()
        ...
```
I am happy to open a PR if this is something of interest to you.
Checklist
- I have checked that there is no similar issue in the repo (required)