Home
Qisheng Robert He edited this page Jan 12, 2024 · 23 revisions
# torchmanager
Contains callbacks for the `fit` method in the manager

Contains wrapped loss functions

Contains wrapped metric functions
The basic manager
- Properties:
  - compiled: A `bool` flag of whether the manager has been compiled
  - loss_fn: A `Callable` method that takes the truth and predictions in `torch.Tensor` and returns a loss `torch.Tensor`
  - metrics: A `dict` of metrics with names in `str` and `Callable` methods that take the truth and predictions in `torch.Tensor` and return a metric `torch.Tensor`
  - model: A target `torch.nn.Module` to be trained
  - optimizer: A `torch.optim.Optimizer` to train the model
- Methods:
  - Constructor
    - Parameters:
      - loss_fn: An optional `Loss` object to calculate a single loss, or a `dict` of `Loss` objects with their names in `str` to calculate multiple losses
      - metrics: An optional `dict` of metrics with names in `str` and `Metric` objects to calculate the metrics
      - model: An optional target `torch.nn.Module` to be trained
      - optimizer: An optional `torch.optim.Optimizer` to train the model
  - compile
    - Compiles the manager
    - Parameters:
      - loss_fn: A `Loss` object to calculate a single loss, or a `dict` of `Loss` objects with their names in `str` to calculate multiple losses
      - metrics: A `dict` of metrics with names in `str` and `Metric` objects to calculate the metrics
      - optimizer: A `torch.optim.Optimizer` to train the model
  - from_checkpoint
    - Method to load a manager from a saved `Checkpoint`. The loaded manager will not be compiled with a loss function or metrics.
    - classmethod
    - Returns: A loaded `Manager`
  - to_checkpoint
    - Converts the current manager to a checkpoint
    - Returns: A `Checkpoint` with its model in `Module` type
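The properties and the `compile` parameters above can be illustrated with a small pure-Python sketch. It is not torchmanager's implementation: `SketchManager`, its `evaluate` helper, and the `Criterion` type are hypothetical, plain callables and floats stand in for `Loss`/`Metric` objects and `torch.Tensor`, and two behaviors are assumptions: a `dict` of losses is combined by summing, and the constructor compiles only when both an optimizer and a loss are given.

```python
from typing import Any, Callable, Dict, Optional, Union

# Hypothetical stand-in for a Loss or Metric object: any callable taking
# (truth, prediction) and returning a number (a torch.Tensor in torchmanager).
Criterion = Callable[[float, float], float]


class SketchManager:
    """Illustrates the compile contract described above; not torchmanager's code."""

    def __init__(self, model: Any = None, optimizer: Any = None,
                 loss_fn: Optional[Union[Criterion, Dict[str, Criterion]]] = None,
                 metrics: Optional[Dict[str, Criterion]] = None) -> None:
        self.model = model
        self.optimizer = optimizer
        self.loss_fn: Optional[Criterion] = None
        self.metrics: Dict[str, Criterion] = {}
        self.compiled = False
        # Assumption: compile right away only when an optimizer and a loss are both given.
        if optimizer is not None and loss_fn is not None:
            self.compile(loss_fn=loss_fn, metrics=metrics or {}, optimizer=optimizer)

    def compile(self, loss_fn: Union[Criterion, Dict[str, Criterion]],
                metrics: Dict[str, Criterion], optimizer: Any) -> None:
        if isinstance(loss_fn, dict):
            named_losses = dict(loss_fn)

            # Assumption: multiple named losses are combined by summing them.
            def total_loss(truth: float, pred: float) -> float:
                return sum(fn(truth, pred) for fn in named_losses.values())

            self.loss_fn = total_loss
        else:
            self.loss_fn = loss_fn
        self.metrics = dict(metrics)
        self.optimizer = optimizer
        self.compiled = True

    def evaluate(self, truth: float, pred: float) -> Dict[str, float]:
        """Hypothetical helper: the loss plus every named metric for one prediction."""
        if not self.compiled or self.loss_fn is None:
            raise RuntimeError("the manager has not been compiled")
        summary = {"loss": self.loss_fn(truth, pred)}
        summary.update({name: fn(truth, pred) for name, fn in self.metrics.items()})
        return summary


def squared_error(truth: float, pred: float) -> float:
    return (truth - pred) ** 2


def absolute_error(truth: float, pred: float) -> float:
    return abs(truth - pred)


# "sgd" is only a non-None placeholder standing in for a torch.optim.Optimizer.
manager = SketchManager(optimizer="sgd",
                        loss_fn={"mse": squared_error, "mae": absolute_error},
                        metrics={"mae": absolute_error})
print(manager.evaluate(3.0, 1.0))  # {'loss': 6.0, 'mae': 2.0}
```

Keeping the loss and the metrics separate mirrors the page's split between `loss_fn` (what is optimized) and `metrics` (what is only reported).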
The manager to load data during training or testing
- Methods:
  - unpack_data
    - Unpacks data into input and target
    - Parameters:
      - data: `Any` kind of data object
    - Returns: A `tuple` of `Any` kind of input and `Any` kind of target
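A minimal sketch of the splitting logic an `unpack_data` override might contain, written as a standalone function so it runs without PyTorch. The accepted layouts are assumptions, not torchmanager conventions: an `(input, target)` pair, or a hypothetical `dict` with `"input"` and `"target"` keys; anything else raises `TypeError`.

```python
from collections.abc import Mapping, Sequence
from typing import Any, Tuple


def unpack_data(data: Any) -> Tuple[Any, Any]:
    """Splits one batch into (input, target); the supported layouts are assumptions."""
    # An (input, target) pair, e.g. a batch from a DataLoader over (x, y) samples.
    if isinstance(data, Sequence) and not isinstance(data, (str, bytes)) and len(data) == 2:
        return data[0], data[1]
    # Hypothetical dict layout with "input" and "target" keys.
    if isinstance(data, Mapping) and "input" in data and "target" in data:
        return data["input"], data["target"]
    raise TypeError(f"cannot unpack data of type {type(data).__name__}")


x, y = unpack_data(([1.0, 2.0], [0, 1]))
print(x, y)  # [1.0, 2.0] [0, 1]
```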
A testing manager, only used for testing
- extends: `BaseManager`, `DataManager`
- Properties
  - compiled_losses: The loss function in `Loss` that must exist
  - compiled_metrics: The `dict` of metrics in `Metric` that does not contain losses
- Methods
  - test
    - Tests the target model
    - Parameters:
      - dataset: Either a `SizedIterable` or a `data.DataLoader` to load the dataset
      - device: An optional `torch.device` to test on
      - use_multi_gpus: A `bool` flag of whether to use multiple GPUs during testing
      - show_verbose: A `bool` flag of whether to show the progress bar during testing
    - Returns: A `dict` of validation summary
  - test_step
    - A single testing step
    - Parameters:
      - x_train: The testing data in `torch.Tensor`
      - y_train: The testing label in `torch.Tensor`
    - Returns: A `dict` of validation summary
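The relationship between `test` and `test_step` can be sketched as a loop that calls the step on every batch and aggregates the per-batch summaries. This simplified version omits device placement, multi-GPU support, and the progress bar; `run_test` and `example_test_step` are hypothetical names, and averaging each summary key over batches is an assumption about how the validation summary is aggregated.

```python
from typing import Callable, Dict, Iterable, Tuple

Batch = Tuple[float, float]
StepFn = Callable[[float, float], Dict[str, float]]


def run_test(dataset: Iterable[Batch], test_step: StepFn) -> Dict[str, float]:
    """Calls test_step on every batch and averages each summary key over batches."""
    totals: Dict[str, float] = {}
    count = 0
    for x, y in dataset:
        summary = test_step(x, y)
        for key, value in summary.items():
            totals[key] = totals.get(key, 0.0) + value
        count += 1
    if count == 0:
        raise ValueError("the dataset is empty")
    return {key: total / count for key, total in totals.items()}


def example_test_step(x: float, y: float) -> Dict[str, float]:
    """Hypothetical step: the "model" doubles its input; report the squared error."""
    y_pred = 2 * x
    return {"loss": (y_pred - y) ** 2}


print(run_test([(1.0, 2.0), (2.0, 5.0)], example_test_step))  # {'loss': 0.5}
```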
A training manager
- extends: `TestingManager`
- [Deprecation Warning]: The `train` method is protected as of v1.0.2, and the public method will be removed in v1.2.0. Override the `_train` method instead.
- Compile a model, optimizer, loss function, and metrics into the manager:
```python
import torch
from torchmanager import Manager, losses, metrics

class SomeModel(torch.nn.Module): ...

model = SomeModel()
optimizer = torch.optim.SGD(...)
loss_fn = losses.Loss(...)
metric_fns = {
    ...  # metric names in str mapped to metrics.Metric objects
}
manager = Manager(model, optimizer, loss_fn=loss_fn, metrics=metric_fns)
```
- Train using the `fit` method:
```python
from torch.utils.data import Dataset, DataLoader

dataset = Dataset(...)
dataset = DataLoader(dataset, ...)
epochs: int = ...
manager.fit(dataset, epochs, ...)
```
- Properties
  - current_epoch: The `int` index of the current training epoch
  - compiled_optimizer: The `torch.optim.Optimizer` that must exist
- Methods
  - _train
    - The training procedure for a single epoch
    - Parameters:
      - dataset: A `SizedIterable` training dataset
      - iterations: An optional `int` of total training iterations, which must be smaller than the size of the dataset
      - device: A `torch.device` where the data is moved to, which should be the same as the model's device
      - use_multi_gpus: A `bool` flag of whether to use multiple GPUs
      - show_verbose: A `bool` flag of whether to show the progress bar
      - verbose_type: A `view.VerboseType` that controls how verbose output is displayed
      - callbacks_list: A `list` of callbacks in `Callback`
    - Returns: A summary `dict` with keys as `str` and values as `float`
  - fit
    - Training algorithm
    - Parameters:
      - training_dataset: Any kind of training dataset, which must conform to `SizedIterable`
      - epochs: An optional `int` number of training epochs
      - iterations: An optional `int` number of training iterations
      - lr_scheduler: An optional `torch.optim.lr_scheduler._LRScheduler` to update the learning rate per epoch
      - is_dynamic_pruning: A `bool` flag of whether to use dynamic pruning
      - val_dataset: An optional validation dataset of `Any` kind
      - device: An optional `torch.device` where the data is moved to; the GPU will be used when available if not specified
      - use_multi_gpus: A `bool` flag of whether to use multiple GPUs
      - callbacks_list: A `list` of callbacks in `Callback`
      - **kwargs: Additional keyword arguments that will be passed to the `train` method
    - Returns: A trained `torch.nn.Module`
  - train_step
    - A single training step
    - Parameters:
      - x_train: The training data
      - y_train: The training label
    - Returns: A summary `dict` with keys as `str` and values as `float`
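The call structure described above, where `fit` loops over epochs, `_train` runs one epoch, and `train_step` handles one batch, can be sketched in plain Python. This is not torchmanager's code: `SketchTrainer`, the `Callback` hook names, and the `lr_decay` factor (standing in for a learning-rate scheduler stepped once per epoch) are hypothetical, and the "model" is a single weight fitted to `y = w * x` by gradient descent. Following the deprecation note, a subclass would customize per-epoch behavior by overriding `_train`.

```python
from typing import Dict, List, Optional, Sequence, Tuple

Batch = Tuple[float, float]


class Callback:
    """Hypothetical callback base class; the hook names are assumptions."""

    def on_epoch_start(self, epoch: int) -> None:
        pass

    def on_epoch_end(self, epoch: int, summary: Dict[str, float]) -> None:
        pass


class LossHistory(Callback):
    """Records the averaged loss reported at the end of every epoch."""

    def __init__(self) -> None:
        self.losses: List[float] = []

    def on_epoch_end(self, epoch: int, summary: Dict[str, float]) -> None:
        self.losses.append(summary["loss"])


class SketchTrainer:
    """Sketch of fit -> _train -> train_step; the "model" is one weight w in y = w * x."""

    def __init__(self, lr: float = 0.1) -> None:
        self.w = 0.0
        self.lr = lr
        self.current_epoch = 0

    def train_step(self, x_train: float, y_train: float) -> Dict[str, float]:
        # One gradient-descent update on the squared error (w * x - y) ** 2.
        error = self.w * x_train - y_train
        self.w -= self.lr * 2 * error * x_train
        return {"loss": error ** 2}

    def _train(self, dataset: Sequence[Batch],
               iterations: Optional[int] = None) -> Dict[str, float]:
        # One epoch over at most `iterations` batches; the summary averages the losses.
        batches = list(dataset)[:iterations]
        if not batches:
            raise ValueError("no batches to train on")
        total = sum(self.train_step(x, y)["loss"] for x, y in batches)
        return {"loss": total / len(batches)}

    def fit(self, dataset: Sequence[Batch], epochs: int,
            iterations: Optional[int] = None, lr_decay: float = 1.0,
            callbacks_list: Optional[List[Callback]] = None) -> float:
        callbacks_list = callbacks_list or []
        for epoch in range(epochs):
            self.current_epoch = epoch
            for callback in callbacks_list:
                callback.on_epoch_start(epoch)
            summary = self._train(dataset, iterations)
            self.lr *= lr_decay  # stands in for stepping an lr scheduler once per epoch
            for callback in callbacks_list:
                callback.on_epoch_end(epoch, summary)
        return self.w  # the trained "model" (the real fit returns a torch.nn.Module)


history = LossHistory()
trainer = SketchTrainer(lr=0.1)
w = trainer.fit([(1.0, 2.0), (2.0, 4.0)], epochs=20, callbacks_list=[history])
print(round(w, 6))  # 2.0
```

Keeping the per-batch update in `train_step` and the epoch loop in `_train` means a subclass can change one level without rewriting the others.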