Kison Ho edited this page Jun 9, 2022 · 23 revisions

Packages

callbacks:

Contains callbacks for the fit method of the manager

core:

Contains the required and core packages

losses:

Contains wrapped loss functions

metrics:

Contains wrapped metric functions

train:

Contains other basic utilities for training

Classes

BaseManager

  • The basic manager
  • Properties:
    • compiled: A bool flag indicating whether the manager has been compiled
    • loss_fn: A Callable that takes the truth and predictions as torch.Tensor and returns a loss torch.Tensor
    • metrics: A dict mapping str names to Callables that take the truth and predictions as torch.Tensor and return a metric torch.Tensor
    • model: The target torch.nn.Module to be trained
    • optimizer: A torch.optim.Optimizer to train the model
  • Methods:
    • Constructor
      • Creates a manager
        • Parameters:
          • loss_fn: An optional Loss object for a single loss, or a dict mapping str names to Loss objects for multiple losses
          • metrics: An optional dict mapping str names to Metric objects used to calculate the metrics
          • model: An optional target torch.nn.Module to be trained
          • optimizer: An optional torch.optim.Optimizer to train the model
    • compile
      • Compiles the manager
        • Parameters:
          • loss_fn: A Loss object for a single loss, or a dict mapping str names to Loss objects for multiple losses
          • metrics: A dict mapping str names to Metric objects used to calculate the metrics
          • optimizer: A torch.optim.Optimizer to train the model
    • from_checkpoint
      • Loads a manager from a saved Checkpoint. The loaded manager is not compiled with a loss function or metrics.
      • classmethod
        • Returns: A loaded Manager
    • to_checkpoint
      • Converts the current manager to a checkpoint
        • Returns: A Checkpoint whose model is a torch.nn.Module
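The compile and checkpoint contract of BaseManager can be illustrated with a framework-free sketch. Everything here is an assumption for illustration: SketchManager is a hypothetical class, losses and metrics are plain float functions rather than torchmanager Loss and Metric objects, a dict of losses is assumed to be combined by summation, and the checkpoint is a plain dict rather than a Checkpoint object.

```python
from typing import Callable, Dict, Optional, Union

# Hypothetical loss/metric signature: (prediction, target) -> float.
# torchmanager's Loss and Metric objects work on torch.Tensor instead.
ScoreFn = Callable[[float, float], float]


class SketchManager:
    """Framework-free sketch of the BaseManager contract (not the real class)."""

    def __init__(
        self,
        model: object = None,
        optimizer: object = None,
        loss_fn: Optional[Union[ScoreFn, Dict[str, ScoreFn]]] = None,
        metrics: Optional[Dict[str, ScoreFn]] = None,
    ) -> None:
        self.model = model
        self.optimizer = optimizer
        self.loss_fn: Optional[ScoreFn] = None
        self.metrics: Dict[str, ScoreFn] = {}
        self.compiled = False
        # Assumption: the constructor compiles only when a loss function is given.
        if loss_fn is not None:
            self.compile(loss_fn, metrics or {}, optimizer)

    def compile(
        self,
        loss_fn: Union[ScoreFn, Dict[str, ScoreFn]],
        metrics: Dict[str, ScoreFn],
        optimizer: object = None,
    ) -> None:
        if isinstance(loss_fn, dict):
            named_losses = dict(loss_fn)
            # Assumption: multiple named losses are combined by summation.
            self.loss_fn = lambda y_pred, y_true: sum(
                fn(y_pred, y_true) for fn in named_losses.values()
            )
        else:
            self.loss_fn = loss_fn
        self.metrics = dict(metrics)
        if optimizer is not None:
            self.optimizer = optimizer
        self.compiled = True

    def to_checkpoint(self) -> Dict[str, object]:
        # Hypothetical checkpoint format: stores model and optimizer, not losses or metrics.
        return {"model": self.model, "optimizer": self.optimizer}

    @classmethod
    def from_checkpoint(cls, checkpoint: Dict[str, object]) -> "SketchManager":
        # As described above, the restored manager is not compiled with a loss or metrics.
        return cls(model=checkpoint["model"], optimizer=checkpoint["optimizer"])
```

The sketch only mirrors what this page states: compiled becomes True after compile, and a manager restored with from_checkpoint comes back uncompiled, so compile has to be called again before training or testing.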

DataManager

  • The manager to load data during training or testing
  • Methods:
    • unpack_data
      • Unpacks data into an input and a target
        • Parameters:
          • data: Any kind of data object
        • Returns: A tuple of the input and the target, each of any type
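A minimal sketch of what unpack_data might do, assuming each batch is an (input, target) pair. Both the default behavior and the dict-based override with "image" and "label" keys are hypothetical, chosen only to show where a DataManager subclass would customize unpacking.

```python
from typing import Any, Tuple


def unpack_data(data: Any) -> Tuple[Any, Any]:
    """Hypothetical default: split an (input, target) pair.

    Assumes each batch is a 2-element tuple or list; any other format raises,
    which is where a DataManager subclass would override unpack_data.
    """
    if isinstance(data, (tuple, list)) and len(data) == 2:
        x, y = data
        return x, y
    raise TypeError(f"Cannot unpack {type(data).__name__}; override unpack_data for this format.")


def unpack_dict_batch(data: dict) -> Tuple[Any, Any]:
    # Hypothetical override for batches stored as dicts with "image"/"label" keys.
    return data["image"], data["label"]
```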

TestingManager

  • A testing manager, only used for testing
  • extends: .manager.Manager
  • Properties:
    • compiled_losses: The loss function in Loss, which must exist
    • compiled_metrics: The dict of Metric objects, excluding losses
  • Methods:
    • test
      • Tests the target model
        • Parameters:
          • dataset: Either SizedIterable or data.DataLoader to load the dataset
          • device: An optional torch.device to test on
          • use_multi_gpus: A bool flag of whether to use multiple GPUs during testing
          • show_verbose: A bool flag of whether to show the progress bar during testing
        • Returns: A dict of the validation summary
    • test_step
      • A single testing step
        • Parameters:
          • x_train: The testing data in torch.Tensor
          • y_train: The testing label in torch.Tensor
        • Returns: A dict of the validation summary
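The relationship between test and test_step can be sketched in plain Python. evaluate and evaluate_step are hypothetical stand-ins for test and test_step; the model is any callable, metrics are float functions, and averaging each metric over batches is an assumed reduction, not necessarily what TestingManager does. Device placement, multi-GPU wrapping and the progress bar are omitted.

```python
from typing import Any, Callable, Dict, Iterable, Tuple

# Hypothetical metric signature: (prediction, target) -> float.
MetricFn = Callable[[Any, Any], float]


def evaluate_step(
    model: Callable[[Any], Any], metrics: Dict[str, MetricFn], x_test: Any, y_test: Any
) -> Dict[str, float]:
    """One testing step: predict once, then score the batch with every metric."""
    y_pred = model(x_test)
    return {name: fn(y_pred, y_test) for name, fn in metrics.items()}


def evaluate(
    model: Callable[[Any], Any],
    metrics: Dict[str, MetricFn],
    dataset: Iterable[Tuple[Any, Any]],
) -> Dict[str, float]:
    """Run evaluate_step over every batch and average each metric (assumed reduction)."""
    totals: Dict[str, float] = {}
    num_batches = 0
    for x_test, y_test in dataset:
        for name, value in evaluate_step(model, metrics, x_test, y_test).items():
            totals[name] = totals.get(name, 0.0) + value
        num_batches += 1
    if num_batches == 0:
        raise ValueError("dataset must contain at least one batch")
    return {name: total / num_batches for name, total in totals.items()}
```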

Manager

  • A training manager
  • extends: .testing.Manager
  • [Deprecation Warning]: The train method became protected in v1.0.2, and the public method will be removed in v1.2.0. Override the _train method instead.
  • Properties:
    • current_epoch: The int index of current training epoch
    • compiled_optimizer: The torch.optim.Optimizer, which must exist
  • Methods:
    • _train
      • Runs the training steps for a single epoch
        • Parameters:
          • dataset: A SizedIterable training dataset
          • iterations: An optional int of total training iterations, which must be smaller than the size of the dataset
          • device: A torch.device where the data is moved to, which should be the same device as the model
          • use_multi_gpus: A bool flag of whether to use multiple GPUs
          • show_verbose: A bool flag of whether to show the progress bar
          • verbose_type: A view.VerboseType that controls how the verbose output is displayed
          • callbacks_list: A list of callbacks in Callback
        • Returns: A summary dict with str keys and float values
    • fit
      • Training algorithm
        • Parameters:
          • training_dataset: Any kind of training dataset, which must conform to SizedIterable
          • epochs: An optional int number of training epochs
          • iterations: An optional int number of training iterations
          • lr_scheduler: An optional torch.optim.lr_scheduler._LRScheduler to update the learning rate per epoch
          • is_dynamic_pruning: A bool flag of whether to use dynamic pruning
          • val_dataset: An optional validation dataset of any kind
          • device: An optional torch.device where the data is moved to; if not specified, a GPU is used when available
          • use_multi_gpus: A bool flag of whether to use multiple GPUs
          • callbacks_list: A list of callbacks in Callback
          • **kwargs: Additional keyword arguments passed to the train method
        • Returns: A trained torch.nn.Module
    • train_step
      • A single training step
        • Parameters:
          • x_train: The training data
          • y_train: The training label
        • Returns: A summary dict with str keys and float values
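The way fit could drive _train across epochs is sketched below. plan_epochs and fit_sketch are hypothetical helpers, and the rules they encode are assumptions: exactly one of epochs or iterations is given, iterations are split into full passes plus a shorter final pass (None meaning one full pass over the dataset), the learning-rate scheduler steps once per epoch as described for lr_scheduler, and validation metrics are merged into the epoch summary with a val_ prefix.

```python
from typing import Callable, Dict, List, Optional


def plan_epochs(
    dataset_size: int, epochs: Optional[int] = None, iterations: Optional[int] = None
) -> List[Optional[int]]:
    """Per-epoch iteration budgets; None means one full pass over the dataset."""
    if dataset_size <= 0:
        raise ValueError("dataset_size must be positive")
    # Assumption: exactly one of epochs / iterations must be specified.
    if (epochs is None) == (iterations is None):
        raise ValueError("specify exactly one of epochs or iterations")
    if epochs is not None:
        return [None] * epochs
    full_epochs, remainder = divmod(iterations, dataset_size)
    # The final partial epoch runs fewer iterations than the dataset size.
    return [None] * full_epochs + ([remainder] if remainder else [])


def fit_sketch(
    train_one_epoch: Callable[[int, Optional[int]], Dict[str, float]],
    dataset_size: int,
    epochs: Optional[int] = None,
    iterations: Optional[int] = None,
    lr_step: Optional[Callable[[], None]] = None,
    validate: Optional[Callable[[], Dict[str, float]]] = None,
) -> List[Dict[str, float]]:
    """Hypothetical fit loop; train_one_epoch plays the role of _train."""
    history: List[Dict[str, float]] = []
    for epoch, epoch_iterations in enumerate(plan_epochs(dataset_size, epochs, iterations)):
        summary = dict(train_one_epoch(epoch, epoch_iterations))
        if lr_step is not None:
            lr_step()  # update the learning rate once per epoch
        if validate is not None:
            summary.update({f"val_{name}": value for name, value in validate().items()})
        history.append(summary)
    return history
```

Passing None for full epochs keeps every explicit per-epoch budget strictly smaller than the dataset size, matching the constraint stated for the iterations parameter of _train.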