
Losses Documentation

Qisheng Robert He edited this page Jan 12, 2024 · 6 revisions

Classes

Loss

The main loss class, which wraps a loss function

  • extends: Metric
  • Can be used as a decorator on a function:
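```python
import torch
from torchmanager.losses import Loss

@Loss
def loss_fn(input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
    return (input - target).abs().mean()  # e.g. mean absolute error
```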
  • Loss tensors are kept in memory until reset is called
  • Constructor:
    • Parameters:
      • loss_fn: A callable that takes the input (y_pred) and the target (y_true), both of any type, and returns the loss as a torch.Tensor
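To illustrate the metric-like behavior described above (every computed loss is kept until reset is called), here is a minimal, self-contained sketch. `AccumulatingLoss` is a hypothetical class written for illustration only; it is not torchmanager's `Loss` implementation, and averaging the stored losses in `result` is an assumption.

```python
from typing import Any, Callable, List

import torch


class AccumulatingLoss:
    """Hypothetical stand-in for a metric-style loss: wraps a loss function
    and keeps every computed loss until reset() is called."""

    def __init__(self, loss_fn: Callable[[Any, Any], torch.Tensor]) -> None:
        self.loss_fn = loss_fn
        self._results: List[torch.Tensor] = []

    def __call__(self, input: Any, target: Any) -> torch.Tensor:
        # compute the loss for this batch and keep a detached copy in memory
        loss = self.loss_fn(input, target)
        self._results.append(loss.detach())
        return loss

    @property
    def result(self) -> torch.Tensor:
        # assumed behavior: average of all losses stored since the last reset
        return torch.stack(self._results).mean()

    def reset(self) -> None:
        # release the stored loss tensors
        self._results.clear()


@AccumulatingLoss
def l1(input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
    return (input - target).abs().mean()


l1(torch.tensor([1.0, 2.0]), torch.tensor([0.0, 0.0]))  # batch loss 1.5
l1(torch.tensor([3.0]), torch.tensor([0.0]))  # batch loss 3.0
print(l1.result)  # tensor(2.2500)
l1.reset()
```

Detaching the stored copies keeps the autograd graph from being held alive across batches while the returned loss can still be backpropagated.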

CrossEntropy

The cross entropy loss

  • extends: Loss
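As a rough sketch of what a cross entropy loss computes, assuming `CrossEntropy` behaves like PyTorch's `torch.nn.functional.cross_entropy` (raw logits plus integer class targets, averaged over the batch), the manual computation below matches the built-in one.

```python
import torch
import torch.nn.functional as F

logits = torch.tensor([[2.0, 0.5, -1.0], [0.1, 1.5, 0.3]])  # (batch, classes), unnormalized scores
target = torch.tensor([0, 1])  # integer class index per sample

# cross entropy = batch mean of -log softmax(logits)[true class]
log_probs = F.log_softmax(logits, dim=1)
manual = -log_probs[torch.arange(len(target)), target].mean()

builtin = F.cross_entropy(logits, target)
print(torch.allclose(manual, builtin))  # True
```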

Dice

The Dice loss

  • extends: Loss
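This page does not specify the exact formulation used by `Dice` (smoothing constant, whether softmax is applied, how classes are averaged), so the following is a common soft Dice sketch under those assumptions. `dice_loss` and `smooth` are hypothetical names.

```python
import torch
import torch.nn.functional as F


def dice_loss(logits: torch.Tensor, target: torch.Tensor, smooth: float = 1e-5) -> torch.Tensor:
    """Hypothetical soft Dice loss for (N, C, H, W) logits and (N, H, W) integer targets."""
    num_classes = logits.shape[1]
    probs = logits.softmax(dim=1)
    # one-hot encode the target, then move the class axis: (N, H, W, C) -> (N, C, H, W)
    one_hot = F.one_hot(target, num_classes).movedim(-1, 1).to(probs.dtype)
    intersection = (probs * one_hot).sum(dim=(2, 3))  # per sample, per class
    union = probs.sum(dim=(2, 3)) + one_hot.sum(dim=(2, 3))
    dice = (2 * intersection + smooth) / (union + smooth)
    return 1 - dice.mean()  # 0 for a perfect overlap


torch.manual_seed(0)
logits = torch.randn(2, 3, 4, 4)
target = torch.randint(0, 3, (2, 4, 4))
print(dice_loss(logits, target))  # a scalar between 0 and 1
```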

DiceCE

Combined Dice loss and CrossEntropy loss

  • extends: CrossEntropy, Dice
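One plausible way to combine the two terms, assuming an unweighted sum of cross entropy and soft Dice (the actual weighting inside `DiceCE` may differ), is sketched below. `dice_ce_loss` is a hypothetical helper that carries its own compact Dice term so the snippet runs on its own.

```python
import torch
import torch.nn.functional as F


def dice_ce_loss(logits: torch.Tensor, target: torch.Tensor, smooth: float = 1e-5) -> torch.Tensor:
    """Hypothetical combined loss for (N, C, H, W) logits and (N, H, W) integer targets."""
    ce = F.cross_entropy(logits, target)  # averaged over all pixels
    probs = logits.softmax(dim=1)
    one_hot = F.one_hot(target, logits.shape[1]).movedim(-1, 1).to(probs.dtype)
    intersection = (probs * one_hot).sum(dim=(2, 3))
    union = probs.sum(dim=(2, 3)) + one_hot.sum(dim=(2, 3))
    dice = 1 - ((2 * intersection + smooth) / (union + smooth)).mean()
    return ce + dice  # assumption: equal weighting of the two terms
```

The cross entropy term gives dense per-pixel gradients, while the Dice term measures region overlap and is less sensitive to class imbalance, which is the usual motivation for combining them.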

FocalCrossEntropy

The focal cross entropy loss

  • extends: Loss
  • Constructor:
    • Parameters:
      • alpha: A float, the alpha weighting factor of the focal cross entropy
      • gamma: A float, the gamma focusing parameter of the focal cross entropy
      • calculate_average: A bool flag indicating whether to average the focal loss
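A minimal sketch of the standard multi-class focal loss, FL(p_t) = -alpha * (1 - p_t)^gamma * log(p_t), where p_t is the predicted probability of the true class. The helper name and the default values `alpha=0.25`, `gamma=2.0` are assumptions for illustration, not torchmanager's defaults.

```python
import torch
import torch.nn.functional as F


def focal_cross_entropy(
    logits: torch.Tensor,
    target: torch.Tensor,
    alpha: float = 0.25,
    gamma: float = 2.0,
    calculate_average: bool = True,
) -> torch.Tensor:
    """Hypothetical focal loss for (N, C) logits and (N,) integer targets."""
    log_p = F.log_softmax(logits, dim=1)
    log_p_t = log_p.gather(1, target.unsqueeze(1)).squeeze(1)  # log prob of the true class
    p_t = log_p_t.exp()
    loss = -alpha * (1 - p_t) ** gamma * log_p_t  # per-sample focal loss
    return loss.mean() if calculate_average else loss


logits = torch.tensor([[3.0, 0.0], [0.0, 0.5]])
target = torch.tensor([0, 0])
# the confident, correct first sample is down-weighted far more than the second
print(focal_cross_entropy(logits, target, calculate_average=False))
```

With `alpha=1.0` and `gamma=0.0` the expression reduces to ordinary cross entropy; increasing `gamma` shrinks the contribution of well-classified samples.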

KLDiv

The KL-divergence loss

  • extends: Loss
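For reference, KL(P || Q) = sum_i P_i * (log P_i - log Q_i). Assuming `KLDiv` follows PyTorch's `torch.nn.functional.kl_div` conventions, note that the prediction must be passed as log-probabilities while the target is passed as probabilities; the sketch checks this against the direct formula.

```python
import torch
import torch.nn.functional as F

p = torch.tensor([[0.7, 0.2, 0.1]])  # target distribution P
q = torch.tensor([[0.5, 0.3, 0.2]])  # predicted distribution Q

# direct formula, averaged over the batch
manual = (p * (p.log() - q.log())).sum(dim=1).mean()

# F.kl_div takes the prediction as log-probabilities and the target as probabilities;
# reduction="batchmean" divides the summed result by the batch size
builtin = F.kl_div(q.log(), p, reduction="batchmean")
print(torch.allclose(manual, builtin))  # True
```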

MAE

The MAE (mean absolute error) loss

  • extends: _ReductableLoss
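The available `Reduction` values are not listed on this page; the sketch below assumes they mirror PyTorch's `"none"`, `"sum"` and `"mean"` modes and uses `torch.nn.functional.l1_loss` to show how each one reduces the element-wise absolute errors.

```python
import torch
import torch.nn.functional as F

y_pred = torch.tensor([1.0, 2.0, 4.0])
y_true = torch.tensor([1.5, 2.0, 2.0])

errors = F.l1_loss(y_pred, y_true, reduction="none")      # element-wise |y_pred - y_true|
total = F.l1_loss(y_pred, y_true, reduction="sum")        # sum of the errors
mean_error = F.l1_loss(y_pred, y_true, reduction="mean")  # average of the errors

print(errors)      # tensor([0.5000, 0.0000, 2.0000])
print(total)       # tensor(2.5000)
print(mean_error)  # tensor(0.8333)
```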

MSE

The MSE (mean squared error) loss

  • extends: _ReductableLoss
  • Properties:
    • reduction: The reduction method, a .loss.Reduction
    • replace_nan: A bool flag indicating whether NaN results are replaced with zeros
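The sketch below shows how a reduction mode and a `replace_nan` flag can interact. `mse` is a hypothetical helper, reductions are given as plain strings rather than `.loss.Reduction` values, and it assumes NaN entries are zeroed in the element-wise errors before reducing; the point at which torchmanager actually replaces NaNs may differ.

```python
import torch


def mse(y_pred: torch.Tensor, y_true: torch.Tensor, reduction: str = "mean", replace_nan: bool = False) -> torch.Tensor:
    """Hypothetical MSE with a reduction mode and optional NaN replacement."""
    squared_error = (y_pred - y_true) ** 2
    if replace_nan:
        # assumption: NaN entries are set to zero before the reduction is applied
        squared_error = torch.where(torch.isnan(squared_error), torch.zeros_like(squared_error), squared_error)
    if reduction == "mean":
        return squared_error.mean()
    if reduction == "sum":
        return squared_error.sum()
    if reduction == "none":
        return squared_error
    raise ValueError(f"Unknown reduction: {reduction}")


y_pred = torch.tensor([1.0, float("nan"), 3.0])
y_true = torch.tensor([0.0, 1.0, 1.0])
print(mse(y_pred, y_true))                    # tensor(nan)
print(mse(y_pred, y_true, replace_nan=True))  # tensor(1.6667)
```

Without replacement a single NaN poisons the whole reduced loss; with it, the NaN entry contributes zero while still counting toward the mean's denominator.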

MultiLosses

A loss with multiple losses

  • extends: Loss
  • Properties:
    • losses: A list of loss metrics, each a Metric
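As a sketch of the idea, assuming the combined value is the plain sum of the individual losses (torchmanager may weight or otherwise combine them), several loss functions can be merged into one callable. `multi_losses` is a hypothetical helper built on PyTorch's functional losses.

```python
from typing import Callable, List

import torch
import torch.nn.functional as F

LossFn = Callable[[torch.Tensor, torch.Tensor], torch.Tensor]


def multi_losses(losses: List[LossFn]) -> LossFn:
    """Hypothetical helper: combine several loss functions into one that returns their sum."""
    def combined(input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        return torch.stack([fn(input, target) for fn in losses]).sum()
    return combined


total_loss = multi_losses([F.mse_loss, F.l1_loss])
y_pred = torch.tensor([1.0, 3.0])
y_true = torch.tensor([0.0, 1.0])
print(total_loss(y_pred, y_true))  # MSE 2.5 + MAE 1.5 = tensor(4.)
```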

MultiOutputsLosses

A loss with multiple losses for multiple outputs

  • Pending Deprecation Warning: MultiOutputsLosses will be deprecated in v1.1.0; use MultiLosses with the target parameter set for each loss instead.
  • extends: Loss
  • Properties:
    • losses: A dict of loss metrics, each a Metric
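A minimal sketch of per-output losses, assuming the model returns a dict of outputs and each loss is matched to the output and target stored under the same key, with the results summed. `multi_outputs_losses` and the keys `"class"` and `"box"` are hypothetical.

```python
from typing import Callable, Dict

import torch
import torch.nn.functional as F

LossFn = Callable[[torch.Tensor, torch.Tensor], torch.Tensor]
TensorDict = Dict[str, torch.Tensor]


def multi_outputs_losses(losses: Dict[str, LossFn]) -> Callable[[TensorDict, TensorDict], torch.Tensor]:
    """Hypothetical helper: apply each loss to the output/target pair under its key, then sum."""
    def combined(outputs: TensorDict, targets: TensorDict) -> torch.Tensor:
        return torch.stack([fn(outputs[key], targets[key]) for key, fn in losses.items()]).sum()
    return combined


loss_fn = multi_outputs_losses({"class": F.cross_entropy, "box": F.l1_loss})
outputs = {"class": torch.tensor([[2.0, 0.0]]), "box": torch.tensor([[0.5, 0.5]])}
targets = {"class": torch.tensor([0]), "box": torch.tensor([[0.0, 1.0]])}
print(loss_fn(outputs, targets))  # cross entropy of "class" + L1 of "box"
```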

Methods

loss

A wrapper function that turns a function into a Loss

  • Use as a decorator
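```python
import torch
from torchmanager.losses import loss

@loss
def loss_fn(input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
    return (input - target).abs().mean()  # e.g. mean absolute error
```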