
Torchmanager Documentation

Packages

  • callbacks: Contains callbacks for the fit method in the manager
  • losses: Contains wrapped loss functions
  • metrics: Contains wrapped metric functions
  • train: Contains other basic utilities for training

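Assuming these subpackages are importable directly from torchmanager (an assumption based on the package names above, not confirmed by this page), a typical import block might look like:

```python
# A minimal import sketch; module paths are assumed from the package
# names listed above and may differ in the installed version.
from torchmanager import Manager                    # the training manager documented below
from torchmanager import callbacks, losses, metrics
```
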
Classes

Manager:

  • A training manager
  • Properties:
    • compiled_losses: The loss function, wrapped in Metric, that must exist (the manager must have been compiled)
    • compiled_optimizer: The torch.optim.Optimizer that must exist (the manager must have been compiled)
    • loss_fn: A Callable that takes the truth and predictions as torch.Tensor and returns a loss torch.Tensor
    • metrics: A dict of metrics with names in str and Callable methods that take the truth and predictions as torch.Tensor and return a metric value as a torch.Tensor
    • model: A target torch.nn.Module to be trained
    • optimizer: A torch.optim.Optimizer to train the model
  • Methods
    • __init__

      • Constructor
      • Parameters:
        • loss_fn: An optional Loss object for a single loss, or a dict of losses in Loss with their names in str for multiple losses
        • metrics: An optional dict of metrics with names in str and Metric objects to calculate the metric values
        • model: An optional target torch.nn.Module to be trained
        • optimizer: An optional torch.optim.Optimizer to train the model
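
For illustration, constructing a manager directly might look like the sketch below. This is a hedged example: the top-level Manager import and the specific Loss/Metric wrapper names are assumptions, not confirmed by this page.

```python
import torch
from torchmanager import Manager, losses, metrics  # import path is an assumption

model = torch.nn.Linear(10, 2)

# All four arguments are optional; loss_fn and metrics may instead be
# supplied later through compile. Keyword arguments are used because
# the exact positional order is not confirmed by this page.
manager = Manager(
    model=model,
    optimizer=torch.optim.SGD(model.parameters(), lr=0.01),
    loss_fn=losses.CrossEntropy(),  # assumed wrapper in torchmanager.losses
    metrics={"accuracy": metrics.SparseCategoricalAccuracy()},  # assumed wrapper
)
```
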
    • compile

      • Compiles the manager
      • Parameters:
        • loss_fn: A Loss object for a single loss, or a dict of losses in Loss with their names in str for multiple losses
        • metrics: A dict of metrics with names in str and Metric objects to calculate the metric values
        • optimizer: A torch.optim.Optimizer to train the model
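
A hedged sketch of compiling a manager that was constructed without a loss or optimizer (keyword names follow the parameter list above; the Loss/Metric wrapper names are assumptions):

```python
import torch
from torchmanager import Manager, losses, metrics  # import path is an assumption

model = torch.nn.Linear(10, 2)
manager = Manager(model=model)  # constructed without loss/optimizer

manager.compile(
    optimizer=torch.optim.Adam(model.parameters(), lr=1e-3),
    loss_fn=losses.CrossEntropy(),
    metrics={"accuracy": metrics.SparseCategoricalAccuracy()},
)
```
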
    • fit

      • Training algorithm
      • Parameters:
        • training_dataset: The data.DataLoader for training dataset
        • epochs: The int number of training epochs
        • lr_scheduler: An optional torch.optim.lr_scheduler._LRScheduler to update the learning rate per epoch
        • is_dynamic_pruning: A bool flag indicating whether to use dynamic pruning
        • show_verbose: A bool flag indicating whether to show the progress bar
        • val_dataset: An optional validation data.DataLoader
        • device: An optional torch.device where the data is moved to; the GPU will be used when available if not specified
        • use_multi_gpus: A bool flag indicating whether to use multiple GPUs
        • callbacks_list: A list of callbacks in Callback
        • **kwargs: Additional keyword arguments that will be passed to the train_step method. If given, the train method must be overridden to accept these arguments.
      • Returns: A trained torch.nn.Module
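
Putting these parameters together, a training run might look like the following sketch. The datasets are synthetic stand-ins, and the wrapper class names and the lr_scheduler keyword spelling are assumptions:

```python
import torch
from torch.utils import data
from torchmanager import Manager, losses, metrics  # import path is an assumption

# Build and compile a manager (wrapper class names are assumptions).
model = torch.nn.Linear(10, 2)
manager = Manager(
    model=model,
    optimizer=torch.optim.SGD(model.parameters(), lr=0.01),
    loss_fn=losses.CrossEntropy(),
    metrics={"accuracy": metrics.SparseCategoricalAccuracy()},
)

# Synthetic stand-in datasets for illustration.
train_set = data.TensorDataset(torch.randn(100, 10), torch.randint(0, 2, (100,)))
val_set = data.TensorDataset(torch.randn(20, 10), torch.randint(0, 2, (20,)))

trained_model = manager.fit(
    data.DataLoader(train_set, batch_size=16, shuffle=True),
    epochs=5,
    lr_scheduler=torch.optim.lr_scheduler.StepLR(manager.compiled_optimizer, step_size=2),
    show_verbose=True,
    val_dataset=data.DataLoader(val_set, batch_size=16),
)
```
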
    • (classmethod) from_checkpoint

      • Method to load a manager from a saved Checkpoint. The loaded manager will not be compiled with a loss function or metrics.

      • Returns: A loaded Manager
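
A sketch of restoring a manager. The checkpoint location is hypothetical, and whether the argument is a file path or a Checkpoint object is an assumption:

```python
from torchmanager import Manager  # import path is an assumption

# Hypothetical checkpoint location; the argument type (path vs.
# Checkpoint object) is an assumption.
manager = Manager.from_checkpoint("checkpoints/last.model")

# The loaded manager is not compiled, so set a loss function and
# metrics again (e.g. via compile) before resuming training.
```
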

    • train

      • The training routine for a single epoch
      • Parameters:
        • dataset: The data.DataLoader for the training dataset
        • device: A torch.device where the data is moved to; should be the same as the model's device
        • use_multi_gpus: A bool flag indicating whether to use multiple GPUs
        • show_verbose: A bool flag indicating whether to show the progress bar
        • callbacks_list: A list of callbacks in Callback
      • Returns: A summary dict with keys in str and values in float
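
The **kwargs note under fit above implies that extra keywords flow through train on their way to train_step; a minimal sketch of an override that intercepts one such keyword (grad_clip is a hypothetical argument):

```python
from torchmanager import Manager  # import path is an assumption

class ClippingManager(Manager):
    # Intercept a custom keyword forwarded from fit(..., grad_clip=...);
    # the documented arguments (device, use_multi_gpus, show_verbose,
    # callbacks_list) pass through unchanged in **kwargs.
    def train(self, dataset, grad_clip=None, **kwargs):
        self.grad_clip = grad_clip
        return super().train(dataset, **kwargs)
```
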
    • train_step

      • A single training step
      • Parameters:
        • x_train: The training data
        • y_train: The training label
      • Returns: A summary dict with keys in str and values in float
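
Customization usually happens here: a subclass can extend the summary returned by the default step. A minimal sketch, assuming super().train_step runs the default forward/backward/optimize step (the batch_size key is a hypothetical addition):

```python
from torchmanager import Manager  # import path is an assumption

class VerboseManager(Manager):
    def train_step(self, x_train, y_train):
        # Run the default training step, then append a custom entry
        # to the returned summary dict.
        summary = super().train_step(x_train, y_train)
        summary["batch_size"] = float(x_train.shape[0])  # hypothetical extra key
        return summary
```
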
    • test

      • Tests the target model
      • Parameters:
        • dataset: A data.DataLoader to load the dataset
        • use_multi_gpus: A bool flag indicating whether to use multiple GPUs during testing
      • Returns: A dict of the validation summary
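
Evaluating with a held-out loader might look like this sketch (synthetic data again; the summary keys depend on the compiled loss and metrics, and the wrapper class names are assumptions):

```python
import torch
from torch.utils import data
from torchmanager import Manager, losses, metrics  # import path is an assumption

model = torch.nn.Linear(10, 2)
manager = Manager(
    model=model,
    optimizer=torch.optim.SGD(model.parameters(), lr=0.01),
    loss_fn=losses.CrossEntropy(),
    metrics={"accuracy": metrics.SparseCategoricalAccuracy()},
)

test_set = data.TensorDataset(torch.randn(20, 10), torch.randint(0, 2, (20,)))
summary = manager.test(data.DataLoader(test_set, batch_size=16))
print(summary)  # keys depend on the compiled loss and metrics
```
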
    • test_step

      • A single testing step
      • Parameters:
        • x_train: The testing data in torch.Tensor
        • y_train: The testing label in torch.Tensor
      • Returns: A dict of the validation summary
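
As with train_step, a subclass can override the testing step. A minimal sketch that augments the returned summary (the n_samples key is hypothetical):

```python
from torchmanager import Manager  # import path is an assumption

class CountingManager(Manager):
    def test_step(self, x_train, y_train):
        # Extend the default validation summary with a custom entry.
        summary = super().test_step(x_train, y_train)
        summary["n_samples"] = float(x_train.shape[0])  # hypothetical extra key
        return summary
```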