[Feature Request] EarlyStopping for torchrl.trainers.Trainer #2370
Labels: enhancement (New feature or request)
Motivation
Oftentimes we only want to train an algorithm until it has learned the intended behavior, and a total number of frames is only a proxy for that stopping condition. I would like to propose adding a new callback that makes early stopping possible, much like the EarlyStopping example in Lightning. Currently, no callback can influence the end of the training loop, so my current workaround is setting `self.total_frames` to 1, which isn't great (a sketch of this workaround is shown below).
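For reference, this is roughly what the workaround amounts to. A minimal sketch: the `should_stop` predicate and the `register_stop_hack` helper are placeholders of mine, while `register_op` and the "post_steps" destination are existing `Trainer` API:

```python
from typing import Callable

from torchrl.trainers import Trainer


def register_stop_hack(trainer: Trainer, should_stop: Callable[[], bool]) -> None:
    """Current workaround: end training by exhausting the frame budget.

    `should_stop` is a user-supplied, zero-argument predicate (a placeholder
    here) that returns True once the intended behavior has been learned.
    """

    def _hook() -> None:
        if should_stop():
            # Shrink the frame budget so the trainer's
            # `collected_frames >= total_frames` check exits the loop.
            trainer.total_frames = 1

    # "post_steps" is an existing hook destination, called once per
    # collected batch after the optimization steps.
    trainer.register_op("post_steps", _hook)
```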
Solution

An early-stopping hook built on `TrainerHookBase`.
The main question for me is where to call the hook and what it can observe. I would argue that both the `batch` and the `losses_detached` are relevant signals for a stopping hook, as one might want to stop once a certain return has been seen n times or once the loss drops below a threshold. Therefore, either the call happens in `optim_steps(self, batch: TensorDictBase) -> None` and that function returns the stopping signal to `def train(self)`, or the hook is called in `train` and `optim_steps` returns the losses so that the early-stopping hook can observe them. An example of the second approach could look something like the sketch below.
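A minimal sketch, assuming `optim_steps` is changed to return `losses_detached` and that `train` breaks out of its loop when a hook registered on a new "post_optim_stop" destination returns True; that destination name and the threshold/patience logic are illustrative, not existing torchrl API:

```python
from tensordict import TensorDictBase
from torchrl.trainers import TrainerHookBase


class EarlyStopping(TrainerHookBase):
    """Stop training once the detached loss stays below a threshold.

    Assumes `optim_steps` returns `losses_detached` to `train`, and that
    `train` ends its loop when a "post_optim_stop" hook returns True.
    """

    def __init__(self, loss_threshold: float, patience: int = 10):
        self.loss_threshold = loss_threshold
        self.patience = patience
        self._below_count = 0

    def __call__(self, batch: TensorDictBase, losses_detached: TensorDictBase) -> bool:
        # Aggregate all scalar "loss_*" entries into one training loss.
        total_loss = sum(
            value.item()
            for key, value in losses_detached.items()
            if isinstance(key, str) and key.startswith("loss_")
        )
        # Require `patience` consecutive optim steps below the threshold
        # before signalling the trainer to stop.
        self._below_count = self._below_count + 1 if total_loss < self.loss_threshold else 0
        return self._below_count >= self.patience

    def register(self, trainer, name: str = "early_stopping"):
        # "post_optim_stop" is the new destination this issue proposes;
        # it does not exist in the Trainer today.
        trainer.register_op("post_optim_stop", self)
        trainer.register_module(name, self)

    def state_dict(self) -> dict:
        return {"below_count": self._below_count}

    def load_state_dict(self, state_dict: dict) -> None:
        self._below_count = state_dict["below_count"]
```

On the trainer side, `train` would then call `losses_detached = self.optim_steps(batch)` and break out of the collector loop as soon as one of the registered stop hooks returns True.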
I am happy to open a PR if this is something of interest to you.