diff --git a/src/graphnet/training/loss_functions.py b/src/graphnet/training/loss_functions.py index 624a5fa53..d3fc43f7e 100644 --- a/src/graphnet/training/loss_functions.py +++ b/src/graphnet/training/loss_functions.py @@ -71,6 +71,8 @@ def _forward(self, prediction: Tensor, target: Tensor) -> Tensor: """Implement loss calculation.""" # Check(s) assert prediction.dim() == 2 + if target.dim() != prediction.dim(): + target = target.squeeze(1) assert prediction.size() == target.size() elements = torch.mean((prediction - target) ** 2, dim=-1) @@ -443,3 +445,93 @@ def _forward(self, prediction: Tensor, target: Tensor) -> Tensor: kappa = prediction[:, 3] p = kappa.unsqueeze(1) * prediction[:, [0, 1, 2]] return self._evaluate(p, target) + + +class EnsembleLoss(LossFunction): + """Chain multiple loss functions together.""" + + def __init__( + self, + loss_functions: List[LossFunction], + loss_factors: List[float] = None, + prediction_keys: Optional[List[List[int]]] = None, + *args: Any, + **kwargs: Any, + ) -> None: + """Chain multiple loss functions together. + + Optionally apply a weight to each loss function contribution. + + E.g. Loss = RMSE*0.5 + LogCoshLoss*1.5 + + Args: + loss_functions: A list of loss functions to use. + Each loss function contributes a term to the overall loss. + loss_factors: An optional list of factors that will be mulitplied + to each loss function contribution. Must be ordered according + to `loss_functions`. If not given, the weights default to 1. + prediction_keys: An optional list of lists of indices for which + prediction columns to use for each loss function. If not + given, all columns are used for all loss functions. + """ + if loss_factors is None: + # add weight of 1 - i.e no discrimination + loss_factors = np.repeat(1, len(loss_functions)).tolist() + + assert len(loss_functions) == len(loss_factors) + self._factors = loss_factors + self._loss_functions = loss_functions + + if prediction_keys is not None: + self._prediction_keys: Optional[List[List[int]]] = prediction_keys + else: + self._prediction_keys = None + super().__init__(*args, **kwargs) + + def _forward(self, prediction: Tensor, target: Tensor) -> Tensor: + """Calculate loss using multiple loss functions. + + Args: + prediction: Output of the model. + target: Target tensor, extracted from graph object. + + Returns: + Elementwise loss terms. Shape [N,] + """ + if self._prediction_keys is None: + prediction_keys = [list(range(prediction.size(1)))] * len( + self._loss_functions + ) + else: + prediction_keys = self._prediction_keys + for k, (loss_function, prediction_key) in enumerate( + zip(self._loss_functions, prediction_keys) + ): + if k == 0: + elements = self._factors[k] * loss_function._forward( + prediction=prediction[:, prediction_key], target=target + ) + else: + elements += self._factors[k] * loss_function._forward( + prediction=prediction[:, prediction_key], target=target + ) + return elements + + +class RMSEVonMisesFisher3DLoss(EnsembleLoss): + """Combine the VonMisesFisher3DLoss with RMSELoss.""" + + def __init__(self, vmfs_factor: float = 0.05) -> None: + """VonMisesFisher3DLoss with a RMSE penality term. + + The VonMisesFisher3DLoss will be weighted with `vmfs_factor`. + + Args: + vmfs_factor: A factor applied to the VonMisesFisher3DLoss term. + Defaults ot 0.05. + """ + super().__init__( + loss_functions=[RMSELoss(), VonMisesFisher3DLoss()], + loss_factors=[1, vmfs_factor], + prediction_keys=[[0, 1, 2], [0, 1, 2, 3]], + )