Callbacks Documentation

Qisheng Robert He edited this page Jan 12, 2024 · 6 revisions

Classes

#torchmanager/callbacks

Callback

An empty basic training callback

  • Methods
    • on_batch_end

      • The callback when batch ends
      • Parameters:
        • batch: An int of batch index
        • summary: A dict of the batch summary mapping names (str) to values (float)
    • on_batch_start

      • The callback when batch starts
      • Parameters:
        • batch: An int of batch index
    • on_epoch_end

      • The callback when epoch ends
      • Parameters:
        • epoch: An int of epoch index
        • summary: A dict of the training summary mapping names (str) to values (float)
        • val_summary: A dict of the validation summary mapping names (str) to values (float)
    • on_epoch_start

      • The callback when epoch starts
      • Parameters:
        • epoch: An int of epoch index
    • on_train_end

      • The callback when training ends
    • on_train_start

      • The callback when training starts
      • Parameters:
        • initial_epoch: An int of initial epoch index
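To show how the hooks above fit together, here is a minimal sketch that subclasses a callback and logs every hook call. It assumes nothing from torchmanager itself: `Callback` below is a hypothetical stand-in that only mirrors the documented method names and parameters, and `run_fake_training` is a hypothetical loop that fires the hooks in one plausible order. In real code you would subclass `torchmanager.callbacks.Callback` and let `Manager.fit` call the hooks.

```python
from typing import Dict, List, Optional


class Callback:
    """Hypothetical stand-in mirroring the documented hooks; not the real torchmanager class."""

    def on_train_start(self, initial_epoch: int = 0) -> None:
        pass

    def on_epoch_start(self, epoch: int) -> None:
        pass

    def on_batch_start(self, batch: int) -> None:
        pass

    def on_batch_end(self, batch: int, summary: Optional[Dict[str, float]] = None) -> None:
        pass

    def on_epoch_end(self, epoch: int, summary: Optional[Dict[str, float]] = None,
                     val_summary: Optional[Dict[str, float]] = None) -> None:
        pass

    def on_train_end(self) -> None:
        pass


class EventLogger(Callback):
    """Records the name of every hook call so the firing order can be inspected."""

    def __init__(self) -> None:
        self.events: List[str] = []

    def on_train_start(self, initial_epoch: int = 0) -> None:
        self.events.append(f"train_start({initial_epoch})")

    def on_epoch_start(self, epoch: int) -> None:
        self.events.append(f"epoch_start({epoch})")

    def on_batch_start(self, batch: int) -> None:
        self.events.append(f"batch_start({batch})")

    def on_batch_end(self, batch: int, summary: Optional[Dict[str, float]] = None) -> None:
        self.events.append(f"batch_end({batch})")

    def on_epoch_end(self, epoch: int, summary: Optional[Dict[str, float]] = None,
                     val_summary: Optional[Dict[str, float]] = None) -> None:
        self.events.append(f"epoch_end({epoch})")

    def on_train_end(self) -> None:
        self.events.append("train_end")


def run_fake_training(callbacks: List[Callback], epochs: int, batches_per_epoch: int,
                      initial_epoch: int = 0) -> None:
    """Hypothetical training loop that fires the hooks in one plausible order."""
    for cb in callbacks:
        cb.on_train_start(initial_epoch)
    for epoch in range(initial_epoch, epochs):
        for cb in callbacks:
            cb.on_epoch_start(epoch)
        for batch in range(batches_per_epoch):
            for cb in callbacks:
                cb.on_batch_start(batch)
            for cb in callbacks:
                cb.on_batch_end(batch, {"loss": 0.0})  # placeholder batch summary
        for cb in callbacks:
            cb.on_epoch_end(epoch, {"loss": 0.0}, {"loss": 0.0})  # placeholder summaries
    for cb in callbacks:
        cb.on_train_end()


logger = EventLogger()
run_fake_training([logger], epochs=1, batches_per_epoch=2)
print(logger.events)
```

Because every hook on the base class is a no-op, a subclass only needs to override the hooks it cares about.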

EarlyStop

The early stop callback that raises a StopTraining error during training if the monitored metric has not improved for several steps

  • extends: Callback
  • Properties:
    • monitor: A str of the monitored metric
    • monitor_type: A MonitorType of either MIN or MAX mode for the best model
    • steps: An int of the number of steps to wait for an improvement before stopping
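The early-stop rule can be sketched in a few lines. This is an illustration under stated assumptions, not torchmanager's implementation: `MonitorType` and `StopTraining` are local stand-ins, `EarlyStopSketch` is a hypothetical class, one "step" is assumed to be one `on_epoch_end` call, and the validation summary is assumed to take priority when it is present.

```python
from enum import Enum
from typing import Dict, Optional


class MonitorType(Enum):
    """Stand-in enum: MIN means lower is better, MAX means higher is better."""
    MIN = "min"
    MAX = "max"


class StopTraining(Exception):
    """Stand-in for the StopTraining error raised to end training early."""


class EarlyStopSketch:
    """Hypothetical early stop: one step is assumed to be one on_epoch_end call."""

    def __init__(self, monitor: str, monitor_type: MonitorType = MonitorType.MIN, steps: int = 10) -> None:
        self.monitor = monitor
        self.monitor_type = monitor_type
        self.steps = steps
        self.best: Optional[float] = None
        self.steps_without_improvement = 0

    def _improved(self, score: float) -> bool:
        if self.best is None:  # the first score always counts as an improvement
            return True
        if self.monitor_type == MonitorType.MIN:
            return score < self.best
        return score > self.best

    def on_epoch_end(self, epoch: int, summary: Dict[str, float],
                     val_summary: Optional[Dict[str, float]] = None) -> None:
        # Assumption: prefer the validation summary when it is available
        source = val_summary if val_summary is not None else summary
        score = source[self.monitor]
        if self._improved(score):
            self.best = score
            self.steps_without_improvement = 0
            return
        self.steps_without_improvement += 1
        if self.steps_without_improvement >= self.steps:
            raise StopTraining(f"'{self.monitor}' did not improve for {self.steps} steps (epoch {epoch})")


early_stop = EarlyStopSketch("loss", MonitorType.MIN, steps=2)
stopped_at = None
try:
    for epoch, loss in enumerate([1.0, 0.8, 0.9, 0.85, 0.7]):
        early_stop.on_epoch_end(epoch, {"loss": loss})
except StopTraining:
    stopped_at = epoch  # two epochs without beating 0.8, so training stops before 0.7 is seen
```

The example shows the trade-off that `steps` controls: a small value stops quickly but can miss a later improvement such as the final 0.7.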

Experiment

The callback that bundles the last checkpoint (last.model) and the best checkpoints (best_*.model) from the checkpoints folder, together with the tensorboard logs in the data folder, into a single wrapped *.exp file

  • extends: .callback.Callback
  • requires: tensorboard package
  • Properties:
    • best_ckpts: A list of .ckpt.BestCheckpoint callbacks that record the best checkpoints
    • last_ckpt: A .ckpt.LastCheckpoint callback that records the last checkpoint
    • tensorboard: A .ckpt.TensorBoard callback that records data to tensorboard

FrequencyCallback

A callback with frequency control

  • extends: Callback
  • An abstract class; subclasses must implement the step method
  • Properties:
    • current_step: An int of the current step index
    • freq: A WeightUpdateFreq of the frequency type to update the weight
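A minimal sketch of frequency control, under assumptions: `Freq` is a hypothetical stand-in for `WeightUpdateFreq` with assumed members `EPOCH` and `BATCH`, `step()` is assumed to take no arguments, and `FrequencyCallbackSketch` is not the real class. It shows the idea that the same abstract `step` method runs once per batch or once per epoch depending on `freq`, while `current_step` counts how many times it has run.

```python
import abc
from enum import Enum
from typing import Dict, List, Optional


class Freq(Enum):
    """Hypothetical stand-in for WeightUpdateFreq (member names assumed)."""
    EPOCH = "epoch"
    BATCH = "batch"


class FrequencyCallbackSketch(abc.ABC):
    """Calls step() once per batch or once per epoch, then advances current_step."""

    def __init__(self, freq: Freq = Freq.EPOCH, initial_step: int = 0) -> None:
        self.freq = freq
        self.current_step = initial_step

    @abc.abstractmethod
    def step(self) -> None:
        """Subclasses put their per-step work here."""

    def on_batch_end(self, batch: int, summary: Optional[Dict[str, float]] = None) -> None:
        if self.freq == Freq.BATCH:
            self.step()
            self.current_step += 1

    def on_epoch_end(self, epoch: int, summary: Optional[Dict[str, float]] = None,
                     val_summary: Optional[Dict[str, float]] = None) -> None:
        if self.freq == Freq.EPOCH:
            self.step()
            self.current_step += 1


class StepRecorder(FrequencyCallbackSketch):
    """Concrete subclass that records the step index each time step() runs."""

    def __init__(self, freq: Freq) -> None:
        super().__init__(freq)
        self.seen: List[int] = []

    def step(self) -> None:
        self.seen.append(self.current_step)


def simulate(cb: FrequencyCallbackSketch, epochs: int = 2, batches: int = 3) -> None:
    """Drives the batch and epoch hooks the way a simple training loop might."""
    for epoch in range(epochs):
        for batch in range(batches):
            cb.on_batch_end(batch)
        cb.on_epoch_end(epoch)


per_epoch, per_batch = StepRecorder(Freq.EPOCH), StepRecorder(Freq.BATCH)
simulate(per_epoch)   # step() runs 2 times
simulate(per_batch)   # step() runs 2 * 3 = 6 times
```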

DynamicWeight

An abstract dynamic weight callback that sets the weight dynamically

  • extends: .callback.FrequencyCallback
  • An abstract class; subclasses must implement the step method

LambdaDynamicWeight

A dynamic weight callback that sets the weight dynamically with a lambda function

  • extends: DynamicWeight
  • Targets any object that conforms to the .protocols.Weighted protocol:
from torchmanager import losses
loss_fn = losses.Loss(...)  # torchmanager.losses.Loss conforms to the .protocols.Weighted protocol
  • Pass a defined function into the LambdaDynamicWeight callback:
def weight_fn(step: int) -> int: ...
dynamic_weight_callback = LambdaDynamicWeight(weight_fn, loss_fn)
  • Or use a Python lambda function:
dynamic_weight_callback = LambdaDynamicWeight(lambda e: ..., loss_fn)
  • Add the callback to the callbacks list and pass the list to the fit method:
from torchmanager import Manager
manager = Manager(..., loss_fn=loss_fn, ...)
callbacks_list = [..., dynamic_weight_callback]
manager.fit(..., callbacks_list=callbacks_list)
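What the callback does on each step can be sketched without torchmanager. Everything here is an assumption for illustration: the `Weighted` protocol is assumed to expose a mutable float `weight` attribute, `ToyLoss` is a hypothetical loss-like object, and `LambdaDynamicWeightSketch` is not the real `LambdaDynamicWeight`. The example uses a lambda for a linear warm-up of the loss weight.

```python
from typing import Callable, List, Protocol


class Weighted(Protocol):
    """Assumed shape of .protocols.Weighted: an object exposing a mutable float weight."""
    weight: float


class ToyLoss:
    """Hypothetical loss-like object that satisfies the Weighted protocol above."""

    def __init__(self, weight: float = 1.0) -> None:
        self.weight = weight


class LambdaDynamicWeightSketch:
    """Sets target.weight = weight_fn(current_step) each time step() is called."""

    def __init__(self, weight_fn: Callable[[int], float], target: Weighted) -> None:
        self.weight_fn = weight_fn
        self.target = target
        self.current_step = 0

    def step(self) -> None:
        self.target.weight = self.weight_fn(self.current_step)
        self.current_step += 1


# Linear warm-up: the weight ramps from 0.25 to 1.0 over four steps, then stays at 1.0
loss_fn = ToyLoss(weight=0.0)
warmup = LambdaDynamicWeightSketch(lambda e: min(1.0, (e + 1) / 4), loss_fn)
history: List[float] = []
for _ in range(5):
    warmup.step()
    history.append(loss_fn.weight)
```

Using `typing.Protocol` means any object with a `weight` attribute is accepted structurally, without inheriting from a common base class.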

LastCheckpoint

The callback to save the last checkpoint during training

  • extends: Callback
  • Properties:
    • ckpt_path: A str of checkpoint path

LrSchedueler

The callback to step a learning rate scheduler

  • extends: Callback
  • Parameters:
    • freq: An _lr.LrScheduleFreq of the frequency to update learning rate
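The effect of `freq` on a scheduler can be sketched with toy stand-ins. Assumptions: `ScheduleFreq` is a hypothetical stand-in for `_lr.LrScheduleFreq` with assumed members `EPOCH` and `BATCH`, `HalvingScheduler` is a toy scheduler, and `LrSchedulerCallbackSketch` is not the real `LrSchedueler`. In practice the scheduler would be a PyTorch scheduler such as `torch.optim.lr_scheduler.StepLR`, which also exposes `step()`.

```python
from enum import Enum
from typing import Dict, Optional


class ScheduleFreq(Enum):
    """Hypothetical stand-in for _lr.LrScheduleFreq (member names assumed)."""
    EPOCH = "epoch"
    BATCH = "batch"


class HalvingScheduler:
    """Toy scheduler: each step() multiplies the learning rate by gamma."""

    def __init__(self, lr: float, gamma: float = 0.5) -> None:
        self.lr = lr
        self.gamma = gamma

    def step(self) -> None:
        self.lr *= self.gamma


class LrSchedulerCallbackSketch:
    """Calls scheduler.step() at the end of every batch or every epoch, depending on freq."""

    def __init__(self, scheduler: HalvingScheduler, freq: ScheduleFreq = ScheduleFreq.EPOCH) -> None:
        self.scheduler = scheduler
        self.freq = freq

    def on_batch_end(self, batch: int, summary: Optional[Dict[str, float]] = None) -> None:
        if self.freq == ScheduleFreq.BATCH:
            self.scheduler.step()

    def on_epoch_end(self, epoch: int, summary: Optional[Dict[str, float]] = None,
                     val_summary: Optional[Dict[str, float]] = None) -> None:
        if self.freq == ScheduleFreq.EPOCH:
            self.scheduler.step()


def final_lr(freq: ScheduleFreq, epochs: int = 2, batches: int = 3) -> float:
    """Runs a fake loop and returns the learning rate left after training."""
    scheduler = HalvingScheduler(lr=1.0)
    callback = LrSchedulerCallbackSketch(scheduler, freq)
    for epoch in range(epochs):
        for batch in range(batches):
            callback.on_batch_end(batch)
        callback.on_epoch_end(epoch)
    return scheduler.lr
```

With 2 epochs of 3 batches, the epoch setting halves the rate twice while the batch setting halves it six times, so the choice of `freq` must match how the schedule was designed.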

MonitorType

The enum of monitor types: MIN or MAX

BestCheckpoint

The callback to save the best checkpoint, judged by the monitored metric, at the end of each epoch

  • extends: LastCheckpoint
  • Properties:
    • best_score: A float of the best monitored score so far
    • monitor: A str of the summary name to be monitored
    • monitor_type: A MonitorType of the monitor
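The best-checkpoint rule can be sketched as follows. Assumptions: `MonitorType` is a local stand-in, `BestCheckpointSketch` is a hypothetical class, and `save_fn` is a hypothetical hook standing in for writing a model file, so the sketch records which epochs would be saved instead of touching disk. The validation summary is assumed to take priority when present.

```python
from enum import Enum
from typing import Callable, Dict, List, Optional


class MonitorType(Enum):
    """Stand-in enum: MIN means lower is better, MAX means higher is better."""
    MIN = "min"
    MAX = "max"


class BestCheckpointSketch:
    """Calls save_fn only when the monitored score beats the best score so far."""

    def __init__(self, monitor: str, save_fn: Callable[[int, float], None],
                 monitor_type: MonitorType = MonitorType.MAX) -> None:
        self.monitor = monitor
        self.monitor_type = monitor_type
        self.save_fn = save_fn  # hypothetical hook; a real checkpoint callback writes a model file
        # Start from the worst possible score so the first epoch always counts as an improvement
        self.best_score = float("inf") if monitor_type == MonitorType.MIN else float("-inf")

    def on_epoch_end(self, epoch: int, summary: Dict[str, float],
                     val_summary: Optional[Dict[str, float]] = None) -> None:
        source = val_summary if val_summary is not None else summary  # assumption
        score = source[self.monitor]
        if self.monitor_type == MonitorType.MIN:
            improved = score < self.best_score
        else:
            improved = score > self.best_score
        if improved:
            self.best_score = score
            self.save_fn(epoch, score)


saved_epochs: List[int] = []
ckpt = BestCheckpointSketch("accuracy", lambda epoch, score: saved_epochs.append(epoch))
for epoch, acc in enumerate([0.6, 0.7, 0.65, 0.8]):
    ckpt.on_epoch_end(epoch, {"accuracy": acc})
```

Epoch 2 (accuracy 0.65) is skipped because it does not beat 0.7, which is the difference from a last-checkpoint callback that saves every epoch.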

TensorBoard

The callback to record summary to tensorboard for each epoch

  • extends: FrequencyCallback
  • Properties:
    • writer: A tensorboard.SummaryWriter to record scalars
  • Methods:
    • add_graph

      • Add graph to TensorBoard
      • Parameters:
        • model: A torch.nn.Module to add
        • input_shape: An optional tuple of int for the input shape
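Recording a summary to TensorBoard once per epoch amounts to one scalar per summary entry, with the epoch index as the global step. The sketch below assumes a tag naming scheme (`name/train`, `name/val`) purely for illustration, and `record_epoch` is a hypothetical helper, not the real callback's method. `RecordingWriter` is an in-memory stand-in whose `add_scalar(tag, scalar_value, global_step)` call matches the one on `torch.utils.tensorboard.SummaryWriter`, so a real writer could be passed in its place.

```python
from typing import Dict, List, Optional, Tuple


class RecordingWriter:
    """In-memory stand-in with the same add_scalar(tag, scalar_value, global_step)
    call as torch.utils.tensorboard.SummaryWriter; it only stores the calls."""

    def __init__(self) -> None:
        self.calls: List[Tuple[str, float, Optional[int]]] = []

    def add_scalar(self, tag: str, scalar_value: float, global_step: Optional[int] = None) -> None:
        self.calls.append((tag, scalar_value, global_step))


def record_epoch(writer: RecordingWriter, epoch: int, summary: Dict[str, float],
                 val_summary: Optional[Dict[str, float]] = None) -> None:
    """Writes one scalar per summary entry, using the epoch index as the global step.
    The 'name/train' and 'name/val' tag scheme is an assumption for illustration."""
    for name, value in summary.items():
        writer.add_scalar(f"{name}/train", value, epoch)
    for name, value in (val_summary or {}).items():
        writer.add_scalar(f"{name}/val", value, epoch)


writer = RecordingWriter()
record_epoch(writer, 0, {"loss": 0.5, "accuracy": 0.8}, {"loss": 0.6})
```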