Callbacks Documentation
# torchmanager/callbacks
`Callback`: An empty basic training callback
- Methods:
    - `on_batch_end`: The callback when a batch ends
        - Parameters:
            - batch: An `int` of batch index
            - summary: A `dict` of summary with names in `str` and values in `float`
    - `on_batch_start`: The callback when a batch starts
        - Parameters:
            - batch: An `int` of batch index
    - `on_epoch_end`: The callback when an epoch ends
        - Parameters:
            - epoch: An `int` of epoch index
            - summary: A `dict` of training summary with names in `str` and values in `float`
            - val_summary: A `dict` of validation summary with names in `str` and values in `float`
    - `on_epoch_start`: The callback when an epoch starts
        - Parameters:
            - epoch: An `int` of epoch index
    - `on_train_end`: The callback when training ends
    - `on_train_start`: The callback when training starts
        - Parameters:
            - initial_epoch: An `int` of initial epoch index
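A minimal sketch of how the hooks above fit together, assuming a plain-Python stand-in `Callback` base class (not torchmanager's own) and a hypothetical `run_training` loop. The order in which `fit` actually invokes these hooks is not stated on this page, so the loop's order is an assumption.

```python
from typing import Dict, List, Optional


class Callback:
    """Minimal stand-in exposing the hooks listed above; not torchmanager's own class."""

    def on_train_start(self, initial_epoch: int = 0) -> None: ...
    def on_epoch_start(self, epoch: int) -> None: ...
    def on_batch_start(self, batch: int) -> None: ...
    def on_batch_end(self, batch: int, summary: Optional[Dict[str, float]] = None) -> None: ...
    def on_epoch_end(self, epoch: int, summary: Optional[Dict[str, float]] = None,
                     val_summary: Optional[Dict[str, float]] = None) -> None: ...
    def on_train_end(self) -> None: ...


class EventLogger(Callback):
    """Records the hooks it overrides so the call order can be inspected."""

    def __init__(self) -> None:
        self.events: List[str] = []

    def on_train_start(self, initial_epoch: int = 0) -> None:
        self.events.append(f"train_start({initial_epoch})")

    def on_epoch_start(self, epoch: int) -> None:
        self.events.append(f"epoch_start({epoch})")

    def on_batch_end(self, batch: int, summary: Optional[Dict[str, float]] = None) -> None:
        self.events.append(f"batch_end({batch})")

    def on_epoch_end(self, epoch: int, summary: Optional[Dict[str, float]] = None,
                     val_summary: Optional[Dict[str, float]] = None) -> None:
        self.events.append(f"epoch_end({epoch})")

    def on_train_end(self) -> None:
        self.events.append("train_end")


def run_training(callbacks: List[Callback], epochs: int, batches: int, initial_epoch: int = 0) -> None:
    """Hypothetical driver loop; this hook order is an assumption, not the library's fit code."""
    for cb in callbacks:
        cb.on_train_start(initial_epoch)
    for epoch in range(initial_epoch, epochs):
        for cb in callbacks:
            cb.on_epoch_start(epoch)
        for batch in range(batches):
            for cb in callbacks:
                cb.on_batch_start(batch)
            summary = {"loss": 1.0 / (batch + 1)}  # placeholder batch metric
            for cb in callbacks:
                cb.on_batch_end(batch, summary)
        for cb in callbacks:
            cb.on_epoch_end(epoch, {"loss": 0.5}, None)  # placeholder epoch summary
    for cb in callbacks:
        cb.on_train_end()


logger = EventLogger()
run_training([logger], epochs=2, batches=2, initial_epoch=1)
# logger.events holds one entry per overridden hook, in call order
```

A subclass overrides only the hooks it needs, as `EventLogger` does; hooks it leaves alone, such as `on_batch_start`, fall back to the no-op base.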
The early stop callback that raises a `StopTraining` error during training if the monitored metric has not improved for several steps
- extends: `Callback`
- Properties:
    - monitor: A `str` of the monitored metric
    - monitor_type: A `MonitorType` of either `MIN` or `MAX` mode for the best model
    - steps: An `int` of steps to monitor
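The early-stop rule described above can be sketched as follows. `EarlyStopSketch`, the stand-in `MonitorType` and `StopTraining`, and the choice to read `val_summary` first when it is given are assumptions for illustration, not torchmanager's implementation; here one "step" is counted as one epoch.

```python
from enum import Enum
from typing import Dict, Optional


class MonitorType(Enum):
    """Stand-in for the MIN/MAX monitor modes described above."""
    MIN = "min"
    MAX = "max"


class StopTraining(Exception):
    """Stand-in for the StopTraining error named in the description."""


class EarlyStopSketch:
    """Raise StopTraining when `monitor` fails to improve for `steps` consecutive epochs."""

    def __init__(self, monitor: str, monitor_type: MonitorType = MonitorType.MAX, steps: int = 10) -> None:
        self.monitor = monitor
        self.monitor_type = monitor_type
        self.steps = steps
        self._best: Optional[float] = None
        self._stale = 0  # epochs since the last improvement

    def _improved(self, value: float) -> bool:
        if self._best is None:
            return True  # the first observed value always counts as an improvement
        if self.monitor_type == MonitorType.MIN:
            return value < self._best
        return value > self._best

    def on_epoch_end(self, epoch: int, summary: Dict[str, float],
                     val_summary: Optional[Dict[str, float]] = None) -> None:
        # Prefer the validation summary when one is given (an assumption of this sketch).
        source = val_summary if val_summary is not None else summary
        value = source[self.monitor]
        if self._improved(value):
            self._best = value
            self._stale = 0
        else:
            self._stale += 1
            if self._stale >= self.steps:
                raise StopTraining(f"{self.monitor} has not improved for {self.steps} epochs")


stopper = EarlyStopSketch("loss", MonitorType.MIN, steps=2)
stopped_at = None
try:
    for epoch, loss in enumerate([0.9, 0.7, 0.8, 0.75, 0.6]):
        stopper.on_epoch_end(epoch, {"loss": loss})
except StopTraining:
    stopped_at = epoch  # epoch at which patience ran out
```

With `steps=2`, the loss last improves at epoch 1, so the error is raised at epoch 3 and epoch 4 is never reached.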
The callback that wraps the last and best checkpoints (saved in the `checkpoints` folder as `last.model` and `best_*.model`) together with the TensorBoard logs in the `data` folder into a single wrapped `*.exp` file
- extends: `.callback.Callback`
- requires: `tensorboard` package
- Properties:
    - best_ckpts: A `list` of `.ckpt.BestCheckpoint` callbacks that record the best checkpoints
    - last_ckpt: A `.ckpt.LastCheckpoint` callback that records the last checkpoint
    - tensorboard: A `.ckpt.TensorBoard` callback that records data to TensorBoard
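As a sketch of where the wrapped pieces are expected to sit, the hypothetical `experiment_layout` helper below only derives paths from the folder and file names given above. It treats `<name>.exp` as the root of the bundle, assumes the `*` in `best_*.model` is the monitored metric name, and does not model how the `.exp` file is actually packed.

```python
from pathlib import PurePosixPath
from typing import List


def experiment_layout(exp_name: str, best_metrics: List[str]) -> List[str]:
    """Hypothetical helper listing where each wrapped piece would live inside `<exp_name>.exp`."""
    root = PurePosixPath(f"{exp_name}.exp")
    ckpt_dir = root / "checkpoints"
    paths = [ckpt_dir / "last.model"]
    # One best_*.model per monitored metric; filling `*` with the metric name is an assumption.
    paths += [ckpt_dir / f"best_{metric}.model" for metric in best_metrics]
    paths.append(root / "data")  # folder holding the TensorBoard logs
    return [str(p) for p in paths]


layout = experiment_layout("mnist", ["accuracy"])
```

`PurePosixPath` keeps forward slashes on every platform, which makes the derived names easy to compare.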
`FrequencyCallback`: A callback with frequency control
- extends: `Callback`
- abstract class that needs an implementation of the `step` method
- Properties:
    - current_step: An `int` of the current step index
    - freq: A `WeightUpdateFreq` of the frequency type to update the weight
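A minimal sketch of frequency control, assuming a stand-in `UpdateFreq` enum in place of `WeightUpdateFreq` with hypothetical `EPOCH`/`BATCH` members. The base class decides when to call the abstract `step` method and advances `current_step` afterwards, so subclasses only implement `step`.

```python
from abc import ABC, abstractmethod
from enum import Enum
from typing import Dict, List, Optional


class UpdateFreq(Enum):
    """Hypothetical stand-in for `WeightUpdateFreq`; member names are assumed."""
    EPOCH = "epoch"
    BATCH = "batch"


class FrequencyCallbackSketch(ABC):
    """Calls the abstract `step` once per batch or once per epoch, depending on `freq`."""

    def __init__(self, freq: UpdateFreq = UpdateFreq.EPOCH, initial_step: int = 0) -> None:
        self.freq = freq
        self.current_step = initial_step

    @abstractmethod
    def step(self, summary: Dict[str, float]) -> None:
        """Per-step work implemented by subclasses."""

    def on_batch_end(self, batch: int, summary: Optional[Dict[str, float]] = None) -> None:
        if self.freq == UpdateFreq.BATCH:
            self.step(summary or {})
            self.current_step += 1

    def on_epoch_end(self, epoch: int, summary: Optional[Dict[str, float]] = None,
                     val_summary: Optional[Dict[str, float]] = None) -> None:
        if self.freq == UpdateFreq.EPOCH:
            self.step(summary or {})
            self.current_step += 1


class StepRecorder(FrequencyCallbackSketch):
    """Concrete subclass that records the step index each time `step` runs."""

    def __init__(self, freq: UpdateFreq) -> None:
        super().__init__(freq)
        self.calls: List[int] = []

    def step(self, summary: Dict[str, float]) -> None:
        self.calls.append(self.current_step)


per_epoch = StepRecorder(UpdateFreq.EPOCH)
per_batch = StepRecorder(UpdateFreq.BATCH)
for epoch in range(2):
    for batch in range(3):
        per_epoch.on_batch_end(batch)
        per_batch.on_batch_end(batch)
    per_epoch.on_epoch_end(epoch)
    per_batch.on_epoch_end(epoch)
```

Over 2 epochs of 3 batches, the per-epoch recorder steps twice and the per-batch recorder steps six times.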
`DynamicWeight`: An abstract dynamic weight callback that sets the weight dynamically
- extends: `.callback.FrequencyCallback`
- abstract class that needs an implementation of the `step` method
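A sketch of an abstract dynamic-weight callback whose `step` returns the new weight, which the base class then writes to a target's `weight` attribute once per epoch. `DynamicWeightSketch`, `WeightedLoss`, and the `LinearWarmup` subclass are hypothetical; only the "subclass implements `step`" contract comes from the description above.

```python
from abc import ABC, abstractmethod
from typing import List


class WeightedLoss:
    """Hypothetical target exposing a mutable `weight` attribute."""

    def __init__(self, weight: float = 1.0) -> None:
        self.weight = weight


class DynamicWeightSketch(ABC):
    """Abstract callback: at each epoch end, ask `step` for a new weight and write it to the target."""

    def __init__(self, target: WeightedLoss) -> None:
        self.target = target
        self.current_step = 0

    @abstractmethod
    def step(self, current_step: int) -> float:
        """Return the weight to use once `current_step` steps have completed."""

    def on_epoch_end(self, epoch: int, summary=None, val_summary=None) -> None:
        self.current_step += 1
        self.target.weight = self.step(self.current_step)


class LinearWarmup(DynamicWeightSketch):
    """Concrete subclass: raises the weight linearly up to 1 over `warmup` steps, then holds it."""

    def __init__(self, target: WeightedLoss, warmup: int) -> None:
        super().__init__(target)
        self.warmup = warmup

    def step(self, current_step: int) -> float:
        return min(1.0, current_step / self.warmup)


loss = WeightedLoss(weight=0.0)
callback = LinearWarmup(loss, warmup=4)
weights: List[float] = []
for epoch in range(6):
    callback.on_epoch_end(epoch)
    weights.append(loss.weight)
```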
`LambdaDynamicWeight`: A dynamic weight callback that sets the weight dynamically with a lambda function
- extends: `DynamicWeight`
- Targets any object that conforms to the `.protocol.Weighted` protocol:
```python
from torchmanager import losses
loss_fn = losses.Loss(...)  # where `torchmanager.losses.Loss` conforms to the `.protocols.Weighted` protocol
```
- Pass a defined function into the `LambdaDynamicWeight` callback:
```python
from torchmanager.callbacks import LambdaDynamicWeight

def weight_fn(step: int) -> int: ...
dynamic_weight_callback = LambdaDynamicWeight(weight_fn, loss_fn)
```
- Or use a Python lambda function:
```python
dynamic_weight_callback = LambdaDynamicWeight(lambda e: ..., loss_fn)
```
- Add the callback to the callbacks list and pass the list to the `fit` function:
```python
from torchmanager import Manager
manager = Manager(..., loss_fn=loss_fn, ...)
callbacks_list = [..., dynamic_weight_callback]
manager.fit(..., callbacks_list=callbacks_list)
```
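To make the `Weighted` requirement concrete, the sketch below models it as a `typing.Protocol` with a single `weight: float` attribute, which is an assumption about the protocol's shape. `LambdaDynamicWeightSketch` is a stand-in that applies `target.weight = fn(step)` once per epoch; it is not the library class, and counting steps from 0 is also assumed.

```python
from typing import Callable, Protocol, runtime_checkable


@runtime_checkable
class Weighted(Protocol):
    """Assumed shape of the Weighted protocol: an object with a settable float `weight`."""
    weight: float


class LambdaDynamicWeightSketch:
    """Sets `target.weight = fn(current_step)` at every epoch end (stand-in, not the library class)."""

    def __init__(self, fn: Callable[[int], float], target: Weighted) -> None:
        if not isinstance(target, Weighted):
            raise TypeError("target must expose a `weight` attribute")
        self.fn = fn
        self.target = target
        self.current_step = 0

    def on_epoch_end(self, epoch: int, summary=None, val_summary=None) -> None:
        self.target.weight = self.fn(self.current_step)
        self.current_step += 1


class ToyLoss:
    """Hypothetical loss object that satisfies the assumed protocol."""

    def __init__(self) -> None:
        self.weight = 1.0


loss_fn = ToyLoss()
callback = LambdaDynamicWeightSketch(lambda e: 0.5 ** e, loss_fn)  # halve the weight every epoch
for epoch in range(3):
    callback.on_epoch_end(epoch)
```

Checking the protocol at construction time surfaces a missing `weight` attribute immediately instead of in the middle of training.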
`LastCheckpoint`: The callback to save the last checkpoint during training
- extends: `Callback`
- Properties:
    - ckpt_path: A `str` of checkpoint path
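A sketch of the last-checkpoint behavior: a single file at `ckpt_path` is overwritten at every epoch end. To stay dependency-free it pickles a plain dict where a PyTorch setup would typically call `torch.save`, and `get_state` is a hypothetical callable supplying the state to save.

```python
import os
import pickle
import tempfile
from typing import Any, Callable, Dict


class LastCheckpointSketch:
    """Overwrite one checkpoint file at `ckpt_path` at the end of every epoch."""

    def __init__(self, get_state: Callable[[], Dict[str, Any]], ckpt_path: str) -> None:
        self.get_state = get_state
        self.ckpt_path = ckpt_path

    def on_epoch_end(self, epoch: int, summary=None, val_summary=None) -> None:
        state = {"epoch": epoch, **self.get_state()}
        # Write to a temporary file, then atomically swap it in place of the old checkpoint.
        tmp_path = self.ckpt_path + ".tmp"
        with open(tmp_path, "wb") as f:
            pickle.dump(state, f)
        os.replace(tmp_path, self.ckpt_path)


weights = {"w": 0.0}
with tempfile.TemporaryDirectory() as tmp:
    callback = LastCheckpointSketch(lambda: dict(weights), os.path.join(tmp, "last.model"))
    for epoch in range(3):
        weights["w"] += 1.0  # pretend training updated the model
        callback.on_epoch_end(epoch)
    with open(callback.ckpt_path, "rb") as f:
        restored = pickle.load(f)
```

Writing to a temporary file and then calling `os.replace` means an interrupted save never leaves a half-written `last.model` behind; this is a design choice of the sketch, not necessarily what `LastCheckpoint` does.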
The callback to step the learning rate scheduler
- extends: `Callback`
- Parameters:
    - freq: An `_lr.LrScheduleFreq` of the frequency to update the learning rate
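The effect of `freq` can be sketched with a stand-in `LrScheduleFreq` enum (member names assumed) and a hypothetical `HalvingScheduler`. With a real optimizer, the scheduler would be a `torch.optim.lr_scheduler` instance, which likewise exposes `step()`.

```python
from enum import Enum


class LrScheduleFreq(Enum):
    """Stand-in for `_lr.LrScheduleFreq`; the member names are assumptions."""
    EPOCH = "epoch"
    BATCH = "batch"


class HalvingScheduler:
    """Hypothetical scheduler: halves the learning rate on every `step()` call."""

    def __init__(self, lr: float) -> None:
        self.lr = lr

    def step(self) -> None:
        self.lr *= 0.5


class LrSchedulerCallbackSketch:
    """Calls `scheduler.step()` at the end of each batch or each epoch, depending on `freq`."""

    def __init__(self, scheduler: HalvingScheduler, freq: LrScheduleFreq = LrScheduleFreq.EPOCH) -> None:
        self.scheduler = scheduler
        self.freq = freq

    def on_batch_end(self, batch: int, summary=None) -> None:
        if self.freq == LrScheduleFreq.BATCH:
            self.scheduler.step()

    def on_epoch_end(self, epoch: int, summary=None, val_summary=None) -> None:
        if self.freq == LrScheduleFreq.EPOCH:
            self.scheduler.step()


def simulate(freq: LrScheduleFreq, epochs: int = 2, batches: int = 3) -> float:
    """Run a fake training loop and return the final learning rate."""
    scheduler = HalvingScheduler(lr=1.0)
    callback = LrSchedulerCallbackSketch(scheduler, freq)
    for epoch in range(epochs):
        for batch in range(batches):
            callback.on_batch_end(batch)
        callback.on_epoch_end(epoch)
    return scheduler.lr
```

Over 2 epochs of 3 batches, per-epoch stepping halves the learning rate twice (ending at 0.25), while per-batch stepping halves it six times (ending at 0.015625).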
`MonitorType`: The enum of monitor types (`MIN` or `MAX`)
`BestCheckpoint`: The callback to save the best checkpoint, according to the monitored summary, for each epoch
- extends: `LastCheckpoint`
- Properties:
    - best_score: A `float` of the best score to be monitored
    - monitor: A `str` of the summary name to be monitored
    - monitor_type: A `MonitorType` of the monitor
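A sketch of the best-checkpoint rule: the first epoch always counts as an improvement, and later epochs save only when the monitored value beats `best_score` in the direction given by `monitor_type`. `BestCheckpointSketch` records epoch indices instead of writing files, and the stand-in `MonitorType` mirrors the `MIN`/`MAX` modes; both are assumptions for illustration.

```python
from enum import Enum
from typing import Dict, List, Optional


class MonitorType(Enum):
    """Stand-in for the MIN/MAX monitor modes."""
    MIN = "min"
    MAX = "max"


class BestCheckpointSketch:
    """Save a checkpoint only when the monitored summary value beats `best_score`."""

    def __init__(self, monitor: str, monitor_type: MonitorType = MonitorType.MAX) -> None:
        self.monitor = monitor
        self.monitor_type = monitor_type
        # Start from the worst possible score so the first epoch always counts as an improvement.
        self.best_score = float("-inf") if monitor_type == MonitorType.MAX else float("inf")
        self.saved_epochs: List[int] = []  # stand-in for actually writing checkpoint files

    def on_epoch_end(self, epoch: int, summary: Dict[str, float],
                     val_summary: Optional[Dict[str, float]] = None) -> None:
        source = val_summary if val_summary is not None else summary
        score = source[self.monitor]
        if self.monitor_type == MonitorType.MAX:
            better = score > self.best_score
        else:
            better = score < self.best_score
        if better:
            self.best_score = score
            self.saved_epochs.append(epoch)  # a real callback would save the model here


best = BestCheckpointSketch("accuracy", MonitorType.MAX)
for epoch, acc in enumerate([0.60, 0.72, 0.70, 0.80]):
    best.on_epoch_end(epoch, {"accuracy": acc})
```

Epoch 2 is skipped because 0.70 does not beat the 0.72 reached at epoch 1.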
`TensorBoard`: The callback to record the summary to TensorBoard for each epoch
- extends: `FrequencyCallback`
- Properties:
    - writer: A `tensorboard.SummaryWriter` to record scalars
- Methods:
    - `add_graph`: Add graph to TensorBoard
        - Parameters:
            - model: A `torch.nn.Module` to add
            - input_shape: An optional `tuple` of `int` for the inputs
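A sketch of how per-epoch summaries might be written as scalars. `RecordingWriter` is a stand-in that mimics the `add_scalar(tag, scalar_value, global_step)` call of `torch.utils.tensorboard.SummaryWriter` so the example runs without TensorBoard; the `train/` and `val/` tag prefixes are an assumption, not the tags torchmanager uses.

```python
from typing import Dict, List, Optional, Tuple


class RecordingWriter:
    """Stand-in mimicking SummaryWriter.add_scalar so the sketch runs without TensorBoard."""

    def __init__(self) -> None:
        self.records: List[Tuple[str, float, Optional[int]]] = []

    def add_scalar(self, tag: str, scalar_value: float, global_step: Optional[int] = None) -> None:
        self.records.append((tag, scalar_value, global_step))


class TensorBoardSketch:
    """Write every training and validation summary value as a scalar at each epoch end."""

    def __init__(self, writer) -> None:
        self.writer = writer

    def on_epoch_end(self, epoch: int, summary: Dict[str, float],
                     val_summary: Optional[Dict[str, float]] = None) -> None:
        for name, value in summary.items():
            self.writer.add_scalar(f"train/{name}", value, epoch)
        for name, value in (val_summary or {}).items():
            self.writer.add_scalar(f"val/{name}", value, epoch)


writer = RecordingWriter()
callback = TensorBoardSketch(writer)
callback.on_epoch_end(0, {"loss": 0.9}, {"loss": 1.1})
callback.on_epoch_end(1, {"loss": 0.5})  # no validation summary this epoch
```

Passing a real `SummaryWriter()` in place of `RecordingWriter()` would send the same `add_scalar` calls to TensorBoard.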