From 2ad32ee07e69acc3b1c561635702a43c15a9c7c8 Mon Sep 17 00:00:00 2001 From: Simeet Nayan Date: Fri, 28 Jun 2024 21:00:00 +0530 Subject: [PATCH] Add test for skip_unrolling in Loss --- tests/ignite/metrics/test_loss.py | 49 ++++++++++++++++++++++++++++++- 1 file changed, 48 insertions(+), 1 deletion(-) diff --git a/tests/ignite/metrics/test_loss.py b/tests/ignite/metrics/test_loss.py index 19cc68cd45c..371c2a5e551 100644 --- a/tests/ignite/metrics/test_loss.py +++ b/tests/ignite/metrics/test_loss.py @@ -5,7 +5,7 @@ import torch from numpy.testing import assert_almost_equal from torch import nn -from torch.nn.functional import nll_loss +from torch.nn.functional import mse_loss, nll_loss import ignite.distributed as idist from ignite.engine import State @@ -314,3 +314,50 @@ def compute(self): (torch.rand(4, 10), torch.randint(0, 3, size=(4,))), ] evaluator.run(data) + + +class CustomMultiMSELoss(nn.Module): + def __init__(self) -> None: + super().__init__() + + def forward( + self, y_pred: list[torch.Tensor, torch.Tensor], y_true: list[torch.Tensor, torch.Tensor] + ) -> torch.Tensor: + a_true, b_true = y_true + a_pred, b_pred = y_pred + return mse_loss(a_pred, a_true) + mse_loss(b_pred, b_true) + + +class DummyLoss3(Loss): + def __init__(self, loss_fn, expected_loss, output_transform=lambda x: x, skip_unrolling=False): + super(DummyLoss3, self).__init__(loss_fn, output_transform=output_transform, skip_unrolling=skip_unrolling) + self._expected_loss = expected_loss + self._loss_fn = loss_fn + + def reset(self): + pass + + def compute(self): + pass + + def update(self, output): + y_pred, y_true = output + calculated_loss = self._loss_fn(y_pred=y_pred, y_true=y_true) + assert calculated_loss == self._expected_loss + + +def test_skip_unrolling_loss(): + a_pred = torch.rand(8, 1) + b_pred = torch.rand(8, 1) + y_pred = [a_pred, b_pred] + a_true = torch.rand(8, 1) + b_true = torch.rand(8, 1) + y_true = [a_true, b_true] + + multi_output_mse_loss = CustomMultiMSELoss() + expected_loss = multi_output_mse_loss(y_pred=y_pred, y_true=y_true) + + loss_metric = DummyLoss3(loss_fn=multi_output_mse_loss, expected_loss=expected_loss, skip_unrolling=True) + state = State(output=(y_pred, y_true)) + engine = MagicMock(state=state) + loss_metric.iteration_completed(engine)