Skip to content

Commit

Permalink
large refactor
Browse files Browse the repository at this point in the history
  • Loading branch information
Aske-Rosted committed May 20, 2024
1 parent 3e33904 commit b9c3195
Showing 1 changed file with 75 additions and 18 deletions.
93 changes: 75 additions & 18 deletions src/graphnet/training/loss_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -445,31 +445,88 @@ def _forward(self, prediction: Tensor, target: Tensor) -> Tensor:
return self._evaluate(p, target)


class RMSEVonMisesFisher3DLoss(VonMisesFisherLoss):
"""von Mises-Fisher loss function vectors in the 3D plane."""
class EnsembleLoss(LossFunction):
"""Chain multiple loss functions together."""

def __init__(
self,
loss_functions: List[LossFunction],
loss_factors: List[float] = None,
prediction_keys: Optional[List[List[int]]] = None,
) -> None:
"""Chain multiple loss functions together.
Optionally apply a weight to each loss function contribution.
E.g. Loss = RMSE*0.5 + LogCoshLoss*1.5
Args:
loss_functions: A list of loss functions to use.
Each loss function contributes a term to the overall loss.
loss_factors: An optional list of factors that will be mulitplied
to each loss function contribution. Must be ordered according
to `loss_functions`. If not given, the weights default to 1.
prediction_keys: An optional list of lists of indices for which
prediction columns to use for each loss function. If not
given, all columns are used for all loss functions.
"""
if loss_factors is None:
# add weight of 1 - i.e no discrimination
loss_factors = np.repeat(1, len(loss_functions)).tolist()

assert len(loss_functions) == len(loss_factors)
self._factors = loss_factors
self._loss_functions = loss_functions

if prediction_keys is not None:
self._prediction_keys: Optional[List[List[int]]] = prediction_keys
else:
self._prediction_keys = None

def _forward(self, prediction: Tensor, target: Tensor) -> Tensor:
"""Calculate von Mises-Fisher loss for a direction in the 3D.
"""Calculate loss using multiple loss functions.
Args:
prediction: Output of the model. Must have shape [N, 4] where
columns 0, 1, 2 are predictions of `direction` and last column
is an estimate of `kappa`.
prediction: Output of the model.
target: Target tensor, extracted from graph object.
Returns:
Elementwise von Mises-Fisher loss terms. Shape [N,]
Elementwise loss terms. Shape [N,]
"""
target = target.reshape(-1, 3)
# Check(s)
assert prediction.dim() == 2 and prediction.size()[1] == 4
assert target.dim() == 2
assert prediction.size()[0] == target.size()[0]
if self._prediction_keys is None:
prediction_keys = [list(range(prediction.size(1)))] * len(
self._loss_functions
)
else:
prediction_keys = self._prediction_keys
for k, (loss_function, prediction_key) in enumerate(
zip(self._loss_functions, prediction_keys)
):
if k == 0:
elements = self._factors[k] * loss_function._forward(
prediction=prediction[prediction_key], target=target
)
else:
elements += self._factors[k] * loss_function._forward(
prediction=prediction[prediction_key], target=target
)
return elements

kappa = prediction[:, 3]
p = kappa.unsqueeze(1) * prediction[:, [0, 1, 2]]
elements = 0.05 * self._evaluate(p, target)
elements += torch.sqrt(
torch.mean((prediction[:, :-1] - target) ** 2, dim=-1)

class RMSEVonMisesFisher3DLoss(EnsembleLoss):
"""Combine the VonMisesFisher3DLoss with RMSELoss."""

def __init__(self, vmfs_factor: float = 0.05) -> None:
"""VonMisesFisher3DLoss with a RMSE penality term.
The VonMisesFisher3DLoss will be weighted with `vmfs_factor`.
Args:
vmfs_factor: A factor applied to the VonMisesFisher3DLoss term.
Defaults ot 0.05.
"""
super().__init__(
loss_functions=[RMSELoss(), VonMisesFisher3DLoss()],
loss_factors=[1, vmfs_factor],
prediction_keys=[[0, 1, 2], [0, 1, 2, 3]],
)
return elements

0 comments on commit b9c3195

Please sign in to comment.