Skip to content

Commit

Permalink
Merge pull request #716 from Aske-Rosted/RMSEVonMisesFisher3DLoss
Browse files Browse the repository at this point in the history
RMSE von Mises-Fisher 3d loss
  • Loading branch information
Aske-Rosted authored May 21, 2024
2 parents f5167f3 + 338814c commit 1a944ae
Showing 1 changed file with 92 additions and 0 deletions.
92 changes: 92 additions & 0 deletions src/graphnet/training/loss_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,8 @@ def _forward(self, prediction: Tensor, target: Tensor) -> Tensor:
"""Implement loss calculation."""
# Check(s)
assert prediction.dim() == 2
if target.dim() != prediction.dim():
target = target.squeeze(1)
assert prediction.size() == target.size()

elements = torch.mean((prediction - target) ** 2, dim=-1)
Expand Down Expand Up @@ -443,3 +445,93 @@ def _forward(self, prediction: Tensor, target: Tensor) -> Tensor:
kappa = prediction[:, 3]
p = kappa.unsqueeze(1) * prediction[:, [0, 1, 2]]
return self._evaluate(p, target)


class EnsembleLoss(LossFunction):
"""Chain multiple loss functions together."""

def __init__(
self,
loss_functions: List[LossFunction],
loss_factors: List[float] = None,
prediction_keys: Optional[List[List[int]]] = None,
*args: Any,
**kwargs: Any,
) -> None:
"""Chain multiple loss functions together.
Optionally apply a weight to each loss function contribution.
E.g. Loss = RMSE*0.5 + LogCoshLoss*1.5
Args:
loss_functions: A list of loss functions to use.
Each loss function contributes a term to the overall loss.
loss_factors: An optional list of factors that will be mulitplied
to each loss function contribution. Must be ordered according
to `loss_functions`. If not given, the weights default to 1.
prediction_keys: An optional list of lists of indices for which
prediction columns to use for each loss function. If not
given, all columns are used for all loss functions.
"""
if loss_factors is None:
# add weight of 1 - i.e no discrimination
loss_factors = np.repeat(1, len(loss_functions)).tolist()

assert len(loss_functions) == len(loss_factors)
self._factors = loss_factors
self._loss_functions = loss_functions

if prediction_keys is not None:
self._prediction_keys: Optional[List[List[int]]] = prediction_keys
else:
self._prediction_keys = None
super().__init__(*args, **kwargs)

def _forward(self, prediction: Tensor, target: Tensor) -> Tensor:
"""Calculate loss using multiple loss functions.
Args:
prediction: Output of the model.
target: Target tensor, extracted from graph object.
Returns:
Elementwise loss terms. Shape [N,]
"""
if self._prediction_keys is None:
prediction_keys = [list(range(prediction.size(1)))] * len(
self._loss_functions
)
else:
prediction_keys = self._prediction_keys
for k, (loss_function, prediction_key) in enumerate(
zip(self._loss_functions, prediction_keys)
):
if k == 0:
elements = self._factors[k] * loss_function._forward(
prediction=prediction[:, prediction_key], target=target
)
else:
elements += self._factors[k] * loss_function._forward(
prediction=prediction[:, prediction_key], target=target
)
return elements


class RMSEVonMisesFisher3DLoss(EnsembleLoss):
"""Combine the VonMisesFisher3DLoss with RMSELoss."""

def __init__(self, vmfs_factor: float = 0.05) -> None:
"""VonMisesFisher3DLoss with a RMSE penality term.
The VonMisesFisher3DLoss will be weighted with `vmfs_factor`.
Args:
vmfs_factor: A factor applied to the VonMisesFisher3DLoss term.
Defaults ot 0.05.
"""
super().__init__(
loss_functions=[RMSELoss(), VonMisesFisher3DLoss()],
loss_factors=[1, vmfs_factor],
prediction_keys=[[0, 1, 2], [0, 1, 2, 3]],
)

0 comments on commit 1a944ae

Please sign in to comment.