Skip to content

Commit

Permalink
fixing
Browse files Browse the repository at this point in the history
  • Loading branch information
Aske-Rosted committed May 20, 2024
1 parent b9c3195 commit 338814c
Showing 1 changed file with 7 additions and 2 deletions.
9 changes: 7 additions & 2 deletions src/graphnet/training/loss_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,8 @@ def _forward(self, prediction: Tensor, target: Tensor) -> Tensor:
"""Implement loss calculation."""
# Check(s)
assert prediction.dim() == 2
if target.dim() != prediction.dim():
target = target.squeeze(1)
assert prediction.size() == target.size()

elements = torch.mean((prediction - target) ** 2, dim=-1)
Expand Down Expand Up @@ -453,6 +455,8 @@ def __init__(
loss_functions: List[LossFunction],
loss_factors: List[float] = None,
prediction_keys: Optional[List[List[int]]] = None,
*args: Any,
**kwargs: Any,
) -> None:
"""Chain multiple loss functions together.
Expand Down Expand Up @@ -482,6 +486,7 @@ def __init__(
self._prediction_keys: Optional[List[List[int]]] = prediction_keys
else:
self._prediction_keys = None
super().__init__(*args, **kwargs)

def _forward(self, prediction: Tensor, target: Tensor) -> Tensor:
"""Calculate loss using multiple loss functions.
Expand All @@ -504,11 +509,11 @@ def _forward(self, prediction: Tensor, target: Tensor) -> Tensor:
):
if k == 0:
elements = self._factors[k] * loss_function._forward(
prediction=prediction[prediction_key], target=target
prediction=prediction[:, prediction_key], target=target
)
else:
elements += self._factors[k] * loss_function._forward(
prediction=prediction[prediction_key], target=target
prediction=prediction[:, prediction_key], target=target
)
return elements

Expand Down

0 comments on commit 338814c

Please sign in to comment.