Skip to content

Commit

Permalink
Add test for skip_unrolling in Loss
Browse files Browse the repository at this point in the history
  • Loading branch information
simeetnayan81 committed Jun 28, 2024
1 parent 64a3e39 commit 2ad32ee
Showing 1 changed file with 48 additions and 1 deletion.
49 changes: 48 additions & 1 deletion tests/ignite/metrics/test_loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import torch
from numpy.testing import assert_almost_equal
from torch import nn
from torch.nn.functional import nll_loss
from torch.nn.functional import mse_loss, nll_loss

import ignite.distributed as idist
from ignite.engine import State
Expand Down Expand Up @@ -314,3 +314,50 @@ def compute(self):
(torch.rand(4, 10), torch.randint(0, 3, size=(4,))),
]
evaluator.run(data)


class CustomMultiMSELoss(nn.Module):
def __init__(self) -> None:
super().__init__()

def forward(
self, y_pred: list[torch.Tensor, torch.Tensor], y_true: list[torch.Tensor, torch.Tensor]
) -> torch.Tensor:
a_true, b_true = y_true
a_pred, b_pred = y_pred
return mse_loss(a_pred, a_true) + mse_loss(b_pred, b_true)


class DummyLoss3(Loss):
def __init__(self, loss_fn, expected_loss, output_transform=lambda x: x, skip_unrolling=False):
super(DummyLoss3, self).__init__(loss_fn, output_transform=output_transform, skip_unrolling=skip_unrolling)
self._expected_loss = expected_loss
self._loss_fn = loss_fn

def reset(self):
pass

def compute(self):
pass

def update(self, output):
y_pred, y_true = output
calculated_loss = self._loss_fn(y_pred=y_pred, y_true=y_true)
assert calculated_loss == self._expected_loss


def test_skip_unrolling_loss():
a_pred = torch.rand(8, 1)
b_pred = torch.rand(8, 1)
y_pred = [a_pred, b_pred]
a_true = torch.rand(8, 1)
b_true = torch.rand(8, 1)
y_true = [a_true, b_true]

multi_output_mse_loss = CustomMultiMSELoss()
expected_loss = multi_output_mse_loss(y_pred=y_pred, y_true=y_true)

loss_metric = DummyLoss3(loss_fn=multi_output_mse_loss, expected_loss=expected_loss, skip_unrolling=True)
state = State(output=(y_pred, y_true))
engine = MagicMock(state=state)
loss_metric.iteration_completed(engine)

0 comments on commit 2ad32ee

Please sign in to comment.