
[Feature Request] EarlyStopping for torchrl.trainers.Trainer #2370


Description

@jkrude

Motivation

Oftentimes we only want to train an algorithm until it has learned the intended behavior, and a total number of frames is only a proxy for that stopping condition.
I would like to propose adding a new callback that makes early stopping possible, much like the EarlyStopping callback in Lightning.

Currently, no callback can influence the end of the training loop, so my current workaround is setting self.total_frames to 1, which isn't great.
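For reference, a minimal sketch of that workaround, assuming the function can be attached via register_op at a destination that receives the collected batch (I use "post_steps_log" here); the reward threshold, the key layout and the function name are illustrative only, not torchrl API:

from torchrl.trainers import Trainer

REWARD_TARGET = 200.0  # illustrative stopping condition

trainer: Trainer = ...  # built as usual

def maybe_stop(batch):
    # the reward key layout depends on the environment/collector
    if batch.get(("next", "reward")).mean().item() >= REWARD_TARGET:
        # workaround: shrink the frame budget so the training loop's
        # frame check trips and the loop exits after this iteration
        trainer.total_frames = 1

trainer.register_op("post_steps_log", maybe_stop)
trainer.train()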

Solution

  • Adding a new early-stopping hook
  • Creating a default hook implementation based on TrainerHookBase (see the sketch after this list)
  • Stopping the training loop if the hook returns a positive (stop) signal
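
A rough sketch of what such a default hook could look like, assuming the usual TrainerHookBase interface (__call__, register, state_dict, load_state_dict); the class name, constructor arguments, call signature and the "early_stopping" destination are part of the proposal, not existing torchrl API:

from typing import Optional

from tensordict import TensorDictBase
from torchrl.trainers import TrainerHookBase


class EarlyStoppingHook(TrainerHookBase):
    """Signals stopping once the summed loss stays below a threshold."""

    def __init__(self, loss_threshold: float, patience: int = 5):
        self.loss_threshold = loss_threshold
        self.patience = patience
        self._hits = 0

    def __call__(
        self, batch: TensorDictBase, losses: Optional[TensorDictBase] = None
    ) -> bool:
        if losses is None:
            # still within init_random_frames, nothing to judge yet
            return False
        # sum all "loss_*" entries of the detached losses TensorDict
        total_loss = sum(
            value.item() for key, value in losses.items() if key.startswith("loss")
        )
        self._hits = self._hits + 1 if total_loss < self.loss_threshold else 0
        return self._hits >= self.patience

    def state_dict(self):
        return {"hits": self._hits}

    def load_state_dict(self, state_dict):
        self._hits = state_dict["hits"]

    def register(self, trainer, name: str = "early_stopping"):
        # "early_stopping" would be the new destination proposed above
        trainer.register_op("early_stopping", self)
        trainer.register_module(name, self)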

The main question for me is where to call the hook and what it can observe.
I would argue that both the batch and the losses_detached are relevant signals for the stopping hook, as one might want to stop once a certain return has been seen n times or once the loss falls below a threshold.
Therefore, either the hook is called inside optim_steps(self, batch: TensorDictBase) -> None, which then returns the stopping signal to train(self), or the hook is called in train itself and optim_steps returns the losses so that the early-stopping hook can observe them.

An example of the second approach could look something like this:

def train(self):
    ...
    average_losses = None
    if self.collected_frames > self.collector.init_random_frames:
        average_losses = self.optim_steps(batch)

    if self.early_stopping_hook(batch, average_losses):
        self.save_trainer(force_save=True)
        break

    self._post_steps_hook()
    ...
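
Registration could then follow the existing register_op pattern; the "early_stopping" destination used below is exactly the new hook point this issue proposes, so it does not exist yet:

trainer = Trainer(...)  # collector, loss_module, optimizer, ... as usual
early_stopping = EarlyStoppingHook(loss_threshold=1e-2, patience=10)
early_stopping.register(trainer)  # or: trainer.register_op("early_stopping", early_stopping)
trainer.train()  # would now exit as soon as the hook returns True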

I am happy to open a PR if this is something of interest to you.

Checklist

  • I have checked that there is no similar issue in the repo (required)

Labels

enhancement (New feature or request)
