Packages

callbacks

Contains callbacks for the fit method in the manager

core

Contains required packages and core packages

losses

Contains wrapped loss functions

metrics

Contains wrapped metric functions

train

Contains other basic utilities for training
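
For orientation, a minimal import sketch under the assumption that the top-level package is named torchmanager and that the subpackages carry the names listed above:

```python
# Assumed package layout; the names follow the descriptions above.
from torchmanager import callbacks  # callbacks for the fit method in the manager
from torchmanager import losses     # wrapped loss functions
from torchmanager import metrics    # wrapped metric functions
from torchmanager import train      # other basic utilities for training
```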

Classes

BaseManager

  • The basic manager
  • Properties:
    • compiled: A bool flag of whether the manager has been compiled
    • loss_fn: A Callable method that takes the truth and predictions in torch.Tensor and returns a loss torch.Tensor
    • metrics: A dict of metrics with a name in str and a Callable method that takes the truth and predictions in torch.Tensor and returns a metric torch.Tensor
    • model: A target torch.nn.Module to be trained
    • optimizer: A torch.optim.Optimizer to train the model
  • Methods:
    • Constructor
      • Parameters:
        • loss_fn: An optional Loss object to calculate the loss for a single loss, or a dict of losses in Loss with their names in str to calculate multiple losses
        • metrics: An optional dict of metrics with a name in str and a Metric object to calculate the metric
        • model: An optional target torch.nn.Module to be trained
        • optimizer: An optional torch.optim.Optimizer to train the model
    • compile
      • Compiles the manager
        • Parameters:
          • loss_fn: A Loss object to calculate the loss for a single loss, or a dict of losses in Loss with their names in str to calculate multiple losses
          • metrics: A dict of metrics with a name in str and a Metric object to calculate the metric
          • optimizer: A torch.optim.Optimizer to train the model
    • from_checkpoint
      • A classmethod that loads a manager from a saved Checkpoint. The loaded manager will not be compiled with a loss function or metrics.
        • Returns: A loaded Manager
    • to_checkpoint
      • Convert the current manager to a checkpoint
        • Returns: A Checkpoint with its model in Module type
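
A minimal usage sketch of constructing and compiling a manager. The training Manager class documented below inherits this API and is used here so the same object can be reused in the later examples; the import paths and the Loss/Metric wrapper classes are assumptions based on the packages listed above, not verified signatures.

```python
import torch
from torchmanager import Manager          # assumed import path
from torchmanager import losses, metrics  # assumed subpackage names

# Target model and optimizer.
model = torch.nn.Linear(10, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)

# Wrapped loss and metric objects (class names are assumptions).
loss_fn = losses.CrossEntropy()
metric_fns = {"accuracy": metrics.SparseCategoricalAccuracy()}

# Either pass everything to the constructor ...
manager = Manager(model, optimizer=optimizer, loss_fn=loss_fn, metrics=metric_fns)

# ... or compile an existing manager later.
manager.compile(optimizer=optimizer, loss_fn=loss_fn, metrics=metric_fns)

# Round-trip through a checkpoint; the restored manager is not compiled,
# so it has no loss function or metrics until compile is called again.
# from_checkpoint is assumed here to accept the Checkpoint object directly.
ckpt = manager.to_checkpoint()
restored = Manager.from_checkpoint(ckpt)
```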

DataManager

  • The manager to load data during training or testing
  • Methods:
    • unpack_data
      • Unpacks data to input and target
        • Parameters:
          • data: Any kind of data object
        • Returns: A tuple of Any kind of input and Any kind of target
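
As an illustration, a subclass might override unpack_data to split dictionary batches into an input/target pair; the import path, the subclass, and the dictionary keys below are hypothetical.

```python
from typing import Any, Tuple

from torchmanager import DataManager  # assumed import path


class DictDataManager(DataManager):
    """Unpacks batches that arrive as a dict instead of an (input, target) tuple."""

    def unpack_data(self, data: Any) -> Tuple[Any, Any]:
        # The batch is assumed to be a dict with hypothetical "image" and "label" keys.
        return data["image"], data["label"]
```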

TestingManager

  • A testing manager, only used for testing
  • extends: .manager.Manager
  • Properties:
    • compiled_losses: The loss function in Loss that must exist
    • compiled_metrics: The dict of metrics in Metric that does not contain losses
  • Methods:
    • test

      • Test target model
        • Parameters:
          • dataset: Either SizedIterable or data.DataLoader to load the dataset
          • device: An optional torch.device to test on
          • use_multi_gpus: A bool flag to use multiple GPUs during testing
          • show_verbose: A bool flag to show the progress bar during testing
        • Returns: A dict of validation summary
    • test_step

      • A single testing step
      • Parameters:
        • x_train: The testing data in torch.Tensor
        • y_train: The testing label in torch.Tensor
      • Returns: A dict of validation summary
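
A short sketch of a test run, reusing the compiled manager from the BaseManager example above; the DataLoader contents are illustrative only.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

# An illustrative test set of 64 random samples.
test_data = TensorDataset(torch.randn(64, 10), torch.randint(0, 2, (64,)))
test_loader = DataLoader(test_data, batch_size=16)

# `manager` is the compiled manager from the BaseManager example above.
summary = manager.test(test_loader, show_verbose=True)
print(summary)  # a dict of validation summary, e.g. {"loss": ..., "accuracy": ...}
```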

Manager

  • A training manager
  • extends: .testing.Manager
  • [Deprecation Warning]: Method train has become protected since v1.0.2 and the public method will be removed in v1.2.0. Override the _train method instead.
  • Properties:
    • current_epoch: The int index of current training epoch
    • compiled_optimizer: The torch.optim.Optimizer that must exist
  • Methods:
    • _train
      • The training loop for a single epoch
        • Parameters:
          • dataset: A SizedIterable training dataset
          • iterations: An optional int of total training iterations, must be smaller than the size of the dataset
          • device: A torch.device where the data is moved to, should be the same device as the model
          • use_multi_gpus: A bool flag of whether to use multiple GPUs
          • show_verbose: A bool flag of whether to show the progress bar
          • verbose_type: A view.VerboseType that controls the verbose display
          • callbacks_list: A list of callbacks in Callback
        • Returns: A summary dict with keys in str and values in float
    • fit
      • Training algorithm
        • Parameters:
          • training_dataset: Any kind of training dataset, must conform to SizedIterable
          • epochs: An optional int number of training epochs
          • iterations: An optional int number of training iterations
          • lr_scheduler: An optional torch.optim.lr_scheduler._LRScheduler to update the learning rate per epoch
          • is_dynamic_pruning: A bool flag of whether to use dynamic pruning
          • val_dataset: An optional validation dataset of Any kind
          • device: An optional torch.device where the data is moved to; the GPU will be used when available if not specified
          • use_multi_gpus: A bool flag of whether to use multiple GPUs
          • callbacks_list: A list of callbacks in Callback
          • **kwargs: Additional keyword arguments that will be passed to the train method
        • Returns: A trained torch.nn.Module
    • train_step
      • A single training step
        • Parameters:
          • x_train: The training data
          • y_train: The training label
        • Returns: A summary dict with keys in str and values in float
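
A sketch of a full training run with fit, again reusing the compiled manager from the earlier examples; the epoch count and scheduler are arbitrary choices, and the keyword names follow the parameter lists above rather than verified code.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

# Illustrative training and validation loaders.
train_loader = DataLoader(
    TensorDataset(torch.randn(256, 10), torch.randint(0, 2, (256,))),
    batch_size=32,
    shuffle=True,
)
val_loader = DataLoader(
    TensorDataset(torch.randn(64, 10), torch.randint(0, 2, (64,))),
    batch_size=32,
)

# An optional per-epoch learning-rate scheduler built on the compiled optimizer.
scheduler = torch.optim.lr_scheduler.StepLR(manager.compiled_optimizer, step_size=5)

# fit returns the trained torch.nn.Module; show_verbose is forwarded to the
# training loop through **kwargs.
trained_model = manager.fit(
    train_loader,
    epochs=10,
    lr_scheduler=scheduler,
    val_dataset=val_loader,
    show_verbose=True,
)
```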