# Home

Kison Ho edited this page Feb 24, 2022 · 23 revisions
- `callbacks`: Contains callbacks for the `fit` method in the manager
- `losses`: Contains wrapped loss functions
- `metrics`: Contains wrapped metric functions
- `train`: Contains other basic utilities for training
- `Manager`: A training manager
    - Properties:
        - `compiled_losses`: The loss function in `Metric` that must exist
        - `compiled_optimizer`: The `torch.optim.Optimizer` that must exist
        - `loss_fn`: A `Callable` method that takes the truth and predictions in `torch.Tensor` and returns a loss `torch.Tensor`
        - `metrics`: A `dict` of metrics with a name in `str` and a `Callable` method that takes the truth and predictions in `torch.Tensor` and returns a loss `torch.Tensor`
        - `model`: A target `torch.nn.Module` to be trained
        - `optimizer`: A `torch.optim.Optimizer` to train the model
    - Methods:
        - `__init__`: Constructor
            - Parameters:
                - `loss_fn`: An optional `Loss` object to calculate a single loss, or a `dict` of losses in `Loss` with their names in `str` to calculate multiple losses
                - `metrics`: An optional `dict` of metrics with a name in `str` and a `Metric` object to calculate the metric
                - `model`: An optional target `torch.nn.Module` to be trained
                - `optimizer`: An optional `torch.optim.Optimizer` to train the model
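The constructor above accepts either a single loss or a `dict` of named losses. A minimal pure-Python sketch of how such a constructor might normalize the two forms (the class name follows this page, but the parameter order and the normalization strategy are assumptions for illustration, not the library's actual code):

```python
from typing import Callable, Dict, Optional, Union


class Manager:
    """Illustrative skeleton of the training manager's constructor.

    All parameters are optional, mirroring the documentation above.
    """

    def __init__(
        self,
        model=None,
        optimizer=None,
        loss_fn: Optional[Union[Callable, Dict[str, Callable]]] = None,
        metrics: Optional[Dict[str, Callable]] = None,
    ) -> None:
        self.model = model
        self.optimizer = optimizer
        # One plausible way to accept both forms: wrap a single loss
        # function into a one-entry dict of named losses.
        if callable(loss_fn):
            self.loss_fn = {"loss": loss_fn}
        else:
            self.loss_fn = loss_fn  # already a dict of name -> Loss, or None
        self.metrics = metrics if metrics is not None else {}
```

With this normalization, the rest of the manager can treat the single-loss and multi-loss cases uniformly.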
        - `compile`: Compiles the manager
            - Parameters:
                - `loss_fn`: A `Loss` object to calculate a single loss, or a `dict` of losses in `Loss` with their names in `str` to calculate multiple losses
                - `metrics`: A `dict` of metrics with a name in `str` and a `Metric` object to calculate the metric
                - `optimizer`: A `torch.optim.Optimizer` to train the model
        - `fit`: Training algorithm
            - Parameters:
                - `training_dataset`: The `data.DataLoader` for the training dataset
                - `epochs`: The `int` number of training epochs
                - `lr_scheduler`: An optional `torch.optim.lr_scheduler._LRScheduler` to update the learning rate per epoch
                - `is_dynamic_pruning`: A `bool` flag of whether dynamic pruning is used
                - `show_verbose`: A `bool` flag of whether to show the progress bar
                - `val_dataset`: An optional validation `data.DataLoader`
                - `device`: An optional `torch.device` where the data is moved to; a GPU will be used when available if not specified
                - `use_multi_gpus`: A `bool` flag of whether to use multiple GPUs
                - `callbacks_list`: A `list` of callbacks in `Callback`
                - `**kwargs`: Additional keyword arguments that will be passed to the `train_step` method. If given, the `train` method must be overridden to accept these arguments.
            - Returns: A trained `torch.nn.Module`
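The control flow of `fit` can be sketched in pure Python, with the per-epoch `train` and `test` passes stubbed out as plain callables (the real method also handles devices, schedulers, and callbacks; this is only the loop structure implied by the parameters above):

```python
from typing import Callable, Dict, List, Optional


def fit(
    training_dataset,
    epochs: int,
    train_fn: Callable[..., Dict[str, float]],
    val_dataset=None,
    test_fn: Optional[Callable[..., Dict[str, float]]] = None,
) -> List[Dict[str, float]]:
    """Run `epochs` passes over the data, validating after each epoch
    when a validation dataset is given; return per-epoch summaries."""
    history: List[Dict[str, float]] = []
    for _ in range(epochs):
        summary = train_fn(training_dataset)      # one epoch of training
        if val_dataset is not None and test_fn is not None:
            val_summary = test_fn(val_dataset)    # validation pass
            summary.update({f"val_{k}": v for k, v in val_summary.items()})
        history.append(summary)
    return history
```

Prefixing validation metrics with `val_` in the merged summary is an assumption here, chosen to keep training and validation keys distinct.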
        - `from_checkpoint` (classmethod): Loads a manager from a saved `Checkpoint`. The manager will not be compiled with a loss function and its metrics.
            - Returns: A loaded `Manager`
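Conceptually, `from_checkpoint` restores the saved state while leaving the loss function and metrics uncompiled. A minimal stand-in using `pickle` as the storage format (the real implementation uses the package's own `Checkpoint` type; the file layout here is hypothetical):

```python
import pickle


class Manager:
    """Minimal stand-in for the documented manager."""

    def __init__(self, model=None):
        self.model = model
        self.loss_fn = None   # not restored: must be compiled again
        self.metrics = None   # not restored: must be compiled again

    @classmethod
    def from_checkpoint(cls, path: str) -> "Manager":
        # Read the saved state and rebuild a manager around its model.
        with open(path, "rb") as f:
            state = pickle.load(f)
        return cls(model=state["model"])
```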
        - `train`: The single training step for an epoch
            - Parameters:
                - `dataset`: The `data.DataLoader` for the training dataset
                - `device`: A `torch.device` where the data is moved to; should be the same as the model's
                - `use_multi_gpus`: A `bool` flag of whether to use multiple GPUs
                - `show_verbose`: A `bool` flag of whether to show the progress bar
                - `callbacks_list`: A `list` of callbacks in `Callback`
            - Returns: A summary `dict` with keys as `str` and values as `float`
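The epoch-level `train` above runs one step per batch and reduces the batch results into a single `str` to `float` summary. A pure-Python sketch of that aggregation (averaging over batches is an assumption; the real method also moves batches to the device and notifies callbacks):

```python
from typing import Callable, Dict, Iterable, Tuple


def train(
    dataset: Iterable[Tuple[object, object]],
    train_step: Callable[[object, object], Dict[str, float]],
) -> Dict[str, float]:
    """Run `train_step` on every batch and average the batch summaries."""
    totals: Dict[str, float] = {}
    count = 0
    for x_train, y_train in dataset:
        batch_summary = train_step(x_train, y_train)
        for key, value in batch_summary.items():
            totals[key] = totals.get(key, 0.0) + value
        count += 1
    # Average each metric over the number of batches seen
    return {key: value / max(count, 1) for key, value in totals.items()}
```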
        - `train_step`: A single training step
            - Parameters:
                - `x_train`: The training data
                - `y_train`: The training label
            - Returns: A summary `dict` with keys as `str` and values as `float`
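A toy `train_step` with scalars standing in for tensors, showing the documented contract (take data and label, update the model, return a `str` to `float` summary). The one-parameter linear model and manual gradient are purely illustrative, not the library's code:

```python
def train_step(x_train: float, y_train: float,
               state: dict, lr: float = 0.1) -> dict:
    """One SGD step for a 1-parameter linear model y = w * x."""
    w = state["w"]
    y_pred = w * x_train
    loss = (y_pred - y_train) ** 2            # squared error
    grad = 2 * (y_pred - y_train) * x_train   # d(loss)/dw
    state["w"] = w - lr * grad                # gradient descent update
    return {"loss": float(loss)}
```

The real method would compute the loss via `compiled_losses`, call `backward()`, and step `compiled_optimizer` instead of the hand-written gradient here.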
        - `test`: Tests the target model
            - Parameters:
                - `dataset`: A `data.DataLoader` to load the dataset
                - `use_multi_gpus`: A `bool` flag of whether to use multiple GPUs during testing
            - Returns: A `dict` of validation summary
        - `test_step`: A single testing step
            - Parameters:
                - `x_train`: The testing data in `torch.Tensor`
                - `y_train`: The testing label in `torch.Tensor`
            - Returns: A `dict` of validation summary