From b0868359b671c8e16f0bdaf2de6bf4774c20db5c Mon Sep 17 00:00:00 2001 From: Simeet Nayan Date: Tue, 25 Jun 2024 20:56:27 +0530 Subject: [PATCH 01/13] Introduce a variable skip_unrolling in class Metric --- ignite/metrics/metric.py | 54 +++++++++++++++++++++------------------- 1 file changed, 29 insertions(+), 25 deletions(-) diff --git a/ignite/metrics/metric.py b/ignite/metrics/metric.py index 39e5cb74522..1aed61c5ee3 100644 --- a/ignite/metrics/metric.py +++ b/ignite/metrics/metric.py @@ -300,7 +300,7 @@ def compute(self): _required_output_keys = required_output_keys def __init__( - self, output_transform: Callable = lambda x: x, device: Union[str, torch.device] = torch.device("cpu") + self, output_transform: Callable = lambda x: x, device: Union[str, torch.device] = torch.device("cpu"), skip_unrolling : bool =False ): self._output_transform = output_transform @@ -309,6 +309,7 @@ def __init__( raise ValueError("Cannot create metric on an XLA device. Use device='cpu' instead.") self._device = torch.device(device) + self._skip_unrolling = skip_unrolling self.reset() @abstractmethod @@ -377,30 +378,33 @@ def iteration_completed(self, engine: Engine) -> None: """ output = self._output_transform(engine.state.output) - if isinstance(output, Mapping): - if self.required_output_keys is None: - raise TypeError( - f"Transformed engine output for {self.__class__.__name__} metric should be a tuple/list, " - f"but given {type(output)}" - ) - if not all([k in output for k in self.required_output_keys]): - raise ValueError( - "When transformed engine's output is a mapping, " - f"it should contain {self.required_output_keys} keys, but given {list(output.keys())}" - ) - output = tuple(output[k] for k in self.required_output_keys) - - if isinstance(output, Sequence) and all([_is_list_of_tensors_or_numbers(o) for o in output]): - if not (len(output) == 2 and len(output[0]) == len(output[1])): - raise ValueError( - f"Output should have 2 items of the same length, " - f"got {len(output)} and {len(output[0])}, {len(output[1])}" - ) - for o1, o2 in zip(output[0], output[1]): - # o1 and o2 are list of tensors or numbers - tensor_o1 = _to_batched_tensor(o1) - tensor_o2 = _to_batched_tensor(o2, device=tensor_o1.device) - self.update((tensor_o1, tensor_o2)) + if self._skip_unrolling is False: + if isinstance(output, Mapping): + if self.required_output_keys is None: + raise TypeError( + f"Transformed engine output for {self.__class__.__name__} metric should be a tuple/list, " + f"but given {type(output)}" + ) + if not all([k in output for k in self.required_output_keys]): + raise ValueError( + "When transformed engine's output is a mapping, " + f"it should contain {self.required_output_keys} keys, but given {list(output.keys())}" + ) + output = tuple(output[k] for k in self.required_output_keys) + + if isinstance(output, Sequence) and all([_is_list_of_tensors_or_numbers(o) for o in output]): + if not (len(output) == 2 and len(output[0]) == len(output[1])): + raise ValueError( + f"Output should have 2 items of the same length, " + f"got {len(output)} and {len(output[0])}, {len(output[1])}" + ) + for o1, o2 in zip(output[0], output[1]): + # o1 and o2 are list of tensors or numbers + tensor_o1 = _to_batched_tensor(o1) + tensor_o2 = _to_batched_tensor(o2, device=tensor_o1.device) + self.update((tensor_o1, tensor_o2)) + else: + self.update(output) else: self.update(output) From 40beddace00e7881e6d92d8c1a12b9ac18994812 Mon Sep 17 00:00:00 2001 From: Simeet Nayan Date: Wed, 26 Jun 2024 13:20:03 +0530 Subject: [PATCH 02/13] Add docstring for skip_unrolling, modify skip_unrolling clause --- ignite/metrics/metric.py | 67 ++++++++++++++++++++++------------------ 1 file changed, 37 insertions(+), 30 deletions(-) diff --git a/ignite/metrics/metric.py b/ignite/metrics/metric.py index 1aed61c5ee3..2acf6662f81 100644 --- a/ignite/metrics/metric.py +++ b/ignite/metrics/metric.py @@ -233,6 +233,9 @@ class Metric(Serializable, metaclass=ABCMeta): device: specifies which device updates are accumulated on. Setting the metric's device to be the same as your ``update`` arguments ensures the ``update`` method is non-blocking. By default, CPU. + skip_unrolling: specifies whether output should be unrolled before being fed to update method. Should be + true for multi-output model, for example, if``y_pred`` contains multi-ouput as ``(y_pred_a, y_pred_b)`` + Alternatively, output_transform can be used to handle this. Attributes: required_output_keys: dictionary defines required keys to be found in ``engine.state.output`` if the @@ -290,8 +293,8 @@ def compute(self): res = evaluator.run(data) - .. versionchanged:: 0.4.2 - ``required_output_keys`` became public attribute. + .. versionchanged:: 0.5.1 + ``skip_unrolling`` attribute is introduced for Metric class. """ # public class attribute @@ -300,7 +303,10 @@ def compute(self): _required_output_keys = required_output_keys def __init__( - self, output_transform: Callable = lambda x: x, device: Union[str, torch.device] = torch.device("cpu"), skip_unrolling : bool =False + self, + output_transform: Callable = lambda x: x, + device: Union[str, torch.device] = torch.device("cpu"), + skip_unrolling: bool = False, ): self._output_transform = output_transform @@ -378,33 +384,34 @@ def iteration_completed(self, engine: Engine) -> None: """ output = self._output_transform(engine.state.output) - if self._skip_unrolling is False: - if isinstance(output, Mapping): - if self.required_output_keys is None: - raise TypeError( - f"Transformed engine output for {self.__class__.__name__} metric should be a tuple/list, " - f"but given {type(output)}" - ) - if not all([k in output for k in self.required_output_keys]): - raise ValueError( - "When transformed engine's output is a mapping, " - f"it should contain {self.required_output_keys} keys, but given {list(output.keys())}" - ) - output = tuple(output[k] for k in self.required_output_keys) - - if isinstance(output, Sequence) and all([_is_list_of_tensors_or_numbers(o) for o in output]): - if not (len(output) == 2 and len(output[0]) == len(output[1])): - raise ValueError( - f"Output should have 2 items of the same length, " - f"got {len(output)} and {len(output[0])}, {len(output[1])}" - ) - for o1, o2 in zip(output[0], output[1]): - # o1 and o2 are list of tensors or numbers - tensor_o1 = _to_batched_tensor(o1) - tensor_o2 = _to_batched_tensor(o2, device=tensor_o1.device) - self.update((tensor_o1, tensor_o2)) - else: - self.update(output) + if isinstance(output, Mapping): + if self.required_output_keys is None: + raise TypeError( + f"Transformed engine output for {self.__class__.__name__} metric should be a tuple/list, " + f"but given {type(output)}" + ) + if not all([k in output for k in self.required_output_keys]): + raise ValueError( + "When transformed engine's output is a mapping, " + f"it should contain {self.required_output_keys} keys, but given {list(output.keys())}" + ) + output = tuple(output[k] for k in self.required_output_keys) + + if ( + (not self._skip_unrolling) + and isinstance(output, Sequence) + and all([_is_list_of_tensors_or_numbers(o) for o in output]) + ): + if not (len(output) == 2 and len(output[0]) == len(output[1])): + raise ValueError( + f"Output should have 2 items of the same length, " + f"got {len(output)} and {len(output[0])}, {len(output[1])}" + ) + for o1, o2 in zip(output[0], output[1]): + # o1 and o2 are list of tensors or numbers + tensor_o1 = _to_batched_tensor(o1) + tensor_o2 = _to_batched_tensor(o2, device=tensor_o1.device) + self.update((tensor_o1, tensor_o2)) else: self.update(output) From 571f5d2def1c0ab5a8f19091e93b411fdd4a040a Mon Sep 17 00:00:00 2001 From: Simeet Nayan <78461155+simeetnayan81@users.noreply.github.com> Date: Wed, 26 Jun 2024 16:04:38 +0530 Subject: [PATCH 03/13] Modify docstring Co-authored-by: vfdev --- ignite/metrics/metric.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ignite/metrics/metric.py b/ignite/metrics/metric.py index 2acf6662f81..7864d57f15a 100644 --- a/ignite/metrics/metric.py +++ b/ignite/metrics/metric.py @@ -234,7 +234,7 @@ class Metric(Serializable, metaclass=ABCMeta): metric's device to be the same as your ``update`` arguments ensures the ``update`` method is non-blocking. By default, CPU. skip_unrolling: specifies whether output should be unrolled before being fed to update method. Should be - true for multi-output model, for example, if``y_pred`` contains multi-ouput as ``(y_pred_a, y_pred_b)`` + true for multi-output model, for example, if ``y_pred`` contains multi-ouput as ``(y_pred_a, y_pred_b)`` Alternatively, output_transform can be used to handle this. Attributes: From 58114bd9ebe54b8250084c9c91665d329ff6107d Mon Sep 17 00:00:00 2001 From: Simeet Nayan <78461155+simeetnayan81@users.noreply.github.com> Date: Wed, 26 Jun 2024 16:06:09 +0530 Subject: [PATCH 04/13] Apply suggestions from code review Co-authored-by: vfdev --- ignite/metrics/metric.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/ignite/metrics/metric.py b/ignite/metrics/metric.py index 7864d57f15a..886eb258d1f 100644 --- a/ignite/metrics/metric.py +++ b/ignite/metrics/metric.py @@ -235,7 +235,7 @@ class Metric(Serializable, metaclass=ABCMeta): non-blocking. By default, CPU. skip_unrolling: specifies whether output should be unrolled before being fed to update method. Should be true for multi-output model, for example, if ``y_pred`` contains multi-ouput as ``(y_pred_a, y_pred_b)`` - Alternatively, output_transform can be used to handle this. + Alternatively, ``output_transform`` can be used to handle this. Attributes: required_output_keys: dictionary defines required keys to be found in ``engine.state.output`` if the @@ -294,7 +294,7 @@ def compute(self): res = evaluator.run(data) .. versionchanged:: 0.5.1 - ``skip_unrolling`` attribute is introduced for Metric class. + ``skip_unrolling`` attribute is added. """ # public class attribute From 9d5daea63f5d87fdd7ce193862ed10a4c3d38c01 Mon Sep 17 00:00:00 2001 From: Simeet Nayan Date: Thu, 27 Jun 2024 08:44:51 +0530 Subject: [PATCH 05/13] Modify docstring, revert version tag --- ignite/metrics/metric.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/ignite/metrics/metric.py b/ignite/metrics/metric.py index 886eb258d1f..d5cfa96d49d 100644 --- a/ignite/metrics/metric.py +++ b/ignite/metrics/metric.py @@ -293,6 +293,9 @@ def compute(self): res = evaluator.run(data) + .. versionchanged:: 0.4.2 + ``required_output_keys`` became public attribute. + .. versionchanged:: 0.5.1 ``skip_unrolling`` attribute is added. """ From c14e1d8c8d5063b45603fdffb61ee041d40425eb Mon Sep 17 00:00:00 2001 From: Simeet Nayan Date: Fri, 28 Jun 2024 11:09:52 +0530 Subject: [PATCH 06/13] Add test_skip_unrolling, DummyMetric5 class --- tests/ignite/metrics/test_metric.py | 30 +++++++++++++++++++++++++++++ 1 file changed, 30 insertions(+) diff --git a/tests/ignite/metrics/test_metric.py b/tests/ignite/metrics/test_metric.py index f9db11b1a37..96c19d668d7 100644 --- a/tests/ignite/metrics/test_metric.py +++ b/tests/ignite/metrics/test_metric.py @@ -1416,3 +1416,33 @@ def wrapper(x, **kwargs): assert (output == expected).all(), (output, expected) else: assert output == expected, (output, expected) + + +class DummyMetric5(Metric): + def __init__(self, true_output, output_transform=lambda x: x, skip_unrolling=False): + super(DummyMetric5, self).__init__(output_transform=output_transform, skip_unrolling=skip_unrolling) + self.true_output = true_output + + def reset(self): + pass + + def compute(self): + pass + + def update(self, output): + assert output == self.true_output + + +def test_skip_unrolling(): + # y_pred and y are ouputs recieved from a multi_output model + a_pred = torch.rand(8, 1) + b_pred = torch.rand(8, 1) + y_pred = [a_pred, b_pred] + a_true = torch.rand(8, 1) + b_true = torch.rand(8, 1) + y_true = [a_true, b_true] + + metric = DummyMetric5(true_output=(y_pred, y_true), skip_unrolling=True) + state = State(output=(y_pred, y_true)) + engine = MagicMock(state=state) + metric.iteration_completed(engine) From a66f856152234b17be4d80226e6f4a3351cc7d30 Mon Sep 17 00:00:00 2001 From: Simeet Nayan Date: Fri, 28 Jun 2024 17:11:24 +0530 Subject: [PATCH 07/13] Add example usage of skip unrolling in Metric, Update Loss class with skip_unrolling arg --- ignite/metrics/loss.py | 8 ++++++- ignite/metrics/metric.py | 49 ++++++++++++++++++++++++++++++++++++++++ 2 files changed, 56 insertions(+), 1 deletion(-) diff --git a/ignite/metrics/loss.py b/ignite/metrics/loss.py index 7182e7033d5..ad81ed1499a 100644 --- a/ignite/metrics/loss.py +++ b/ignite/metrics/loss.py @@ -29,6 +29,9 @@ class Loss(Metric): device: specifies which device updates are accumulated on. Setting the metric's device to be the same as your ``update`` arguments ensures the ``update`` method is non-blocking. By default, CPU. + skip_unrolling: specifies whether input should be unrolled or not before it is passed to to loss_fn. + Should be true for multi-output model, for example, if ``y_pred`` contains multi-ouput as + ``(y_pred_a, y_pred_b)`` Attributes: required_output_keys: dictionary defines required keys to be found in ``engine.state.output`` if the @@ -62,6 +65,8 @@ class Loss(Metric): -0.3499999... + .. versionchanged:: 0.5.1 + ``skip_unrolling`` attribute is added. """ required_output_keys = ("y_pred", "y", "criterion_kwargs") @@ -73,8 +78,9 @@ def __init__( output_transform: Callable = lambda x: x, batch_size: Callable = len, device: Union[str, torch.device] = torch.device("cpu"), + skip_unrolling: bool = False, ): - super(Loss, self).__init__(output_transform, device=device) + super(Loss, self).__init__(output_transform, device=device, skip_unrolling=skip_unrolling) self._loss_fn = loss_fn self._batch_size = batch_size diff --git a/ignite/metrics/metric.py b/ignite/metrics/metric.py index d5cfa96d49d..bfe2b55ac88 100644 --- a/ignite/metrics/metric.py +++ b/ignite/metrics/metric.py @@ -235,6 +235,55 @@ class Metric(Serializable, metaclass=ABCMeta): non-blocking. By default, CPU. skip_unrolling: specifies whether output should be unrolled before being fed to update method. Should be true for multi-output model, for example, if ``y_pred`` contains multi-ouput as ``(y_pred_a, y_pred_b)`` + + Example usage: + The following example shows a custom loss metric that expects input from a multi-output model. + .. code-block:: python + import torch + import torch.nn as nn + import torch.nn.functional as F + + from ignite.engine import create_supervised_evaluator + from ignite.metrics import Loss + + class MyLoss(nn.Module): + def __init__(self, ca: float = 1.0, cb: float = 1.0) -> None: + super().__init__() + self.ca = ca + self.cb = cb + + def forward(self, + y_pred: Tuple[torch.Tensor, torch.Tensor], + y_true: Tuple[torch.Tensor, torch.Tensor]) -> torch.Tensor: + a_true, b_true = y_true + a_pred, b_pred = y_pred + return self.ca * F.mse_loss(a_pred, a_true) + self.cb * F.cross_entropy(b_pred, b_true) + + + def prepare_batch(batch, device, non_blocking): + return torch.rand(4, 1), (torch.rand(4, 1), torch.rand(4, 2)) + + + class MyModel(nn.Module): + + def forward(self, x): + return torch.rand(4, 1), torch.rand(4, 2) + + + model = MyModel() + + device = "cpu" + loss = MyLoss(0.5, 1.0) + metrics = { + "Loss": Loss(loss, skip_unrolling=True) + } + train_evaluator = create_supervised_evaluator(model, metrics, device, prepare_batch=prepare_batch) + + + data = range(10) + train_evaluator.run(data) + train_evaluator.state.metrics["Loss"] + Alternatively, ``output_transform`` can be used to handle this. Attributes: From 64a3e39aa09211b9cf1dd281ec179ed4d32ed160 Mon Sep 17 00:00:00 2001 From: Simeet Nayan Date: Fri, 28 Jun 2024 17:31:37 +0530 Subject: [PATCH 08/13] Fix doc --- ignite/metrics/metric.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/ignite/metrics/metric.py b/ignite/metrics/metric.py index bfe2b55ac88..bd5e48e948f 100644 --- a/ignite/metrics/metric.py +++ b/ignite/metrics/metric.py @@ -238,7 +238,9 @@ class Metric(Serializable, metaclass=ABCMeta): Example usage: The following example shows a custom loss metric that expects input from a multi-output model. + .. code-block:: python + import torch import torch.nn as nn import torch.nn.functional as F From 2ad32ee07e69acc3b1c561635702a43c15a9c7c8 Mon Sep 17 00:00:00 2001 From: Simeet Nayan Date: Fri, 28 Jun 2024 21:00:00 +0530 Subject: [PATCH 09/13] Add test for skip_unrolling in Loss --- tests/ignite/metrics/test_loss.py | 49 ++++++++++++++++++++++++++++++- 1 file changed, 48 insertions(+), 1 deletion(-) diff --git a/tests/ignite/metrics/test_loss.py b/tests/ignite/metrics/test_loss.py index 19cc68cd45c..371c2a5e551 100644 --- a/tests/ignite/metrics/test_loss.py +++ b/tests/ignite/metrics/test_loss.py @@ -5,7 +5,7 @@ import torch from numpy.testing import assert_almost_equal from torch import nn -from torch.nn.functional import nll_loss +from torch.nn.functional import mse_loss, nll_loss import ignite.distributed as idist from ignite.engine import State @@ -314,3 +314,50 @@ def compute(self): (torch.rand(4, 10), torch.randint(0, 3, size=(4,))), ] evaluator.run(data) + + +class CustomMultiMSELoss(nn.Module): + def __init__(self) -> None: + super().__init__() + + def forward( + self, y_pred: list[torch.Tensor, torch.Tensor], y_true: list[torch.Tensor, torch.Tensor] + ) -> torch.Tensor: + a_true, b_true = y_true + a_pred, b_pred = y_pred + return mse_loss(a_pred, a_true) + mse_loss(b_pred, b_true) + + +class DummyLoss3(Loss): + def __init__(self, loss_fn, expected_loss, output_transform=lambda x: x, skip_unrolling=False): + super(DummyLoss3, self).__init__(loss_fn, output_transform=output_transform, skip_unrolling=skip_unrolling) + self._expected_loss = expected_loss + self._loss_fn = loss_fn + + def reset(self): + pass + + def compute(self): + pass + + def update(self, output): + y_pred, y_true = output + calculated_loss = self._loss_fn(y_pred=y_pred, y_true=y_true) + assert calculated_loss == self._expected_loss + + +def test_skip_unrolling_loss(): + a_pred = torch.rand(8, 1) + b_pred = torch.rand(8, 1) + y_pred = [a_pred, b_pred] + a_true = torch.rand(8, 1) + b_true = torch.rand(8, 1) + y_true = [a_true, b_true] + + multi_output_mse_loss = CustomMultiMSELoss() + expected_loss = multi_output_mse_loss(y_pred=y_pred, y_true=y_true) + + loss_metric = DummyLoss3(loss_fn=multi_output_mse_loss, expected_loss=expected_loss, skip_unrolling=True) + state = State(output=(y_pred, y_true)) + engine = MagicMock(state=state) + loss_metric.iteration_completed(engine) From 6cca3393f01eab61489b49c8edf3665c86d4107f Mon Sep 17 00:00:00 2001 From: vfdev Date: Fri, 28 Jun 2024 21:11:14 +0200 Subject: [PATCH 10/13] Apply suggestions from code review --- ignite/metrics/loss.py | 2 +- ignite/metrics/metric.py | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/ignite/metrics/loss.py b/ignite/metrics/loss.py index ad81ed1499a..2be0a7d2387 100644 --- a/ignite/metrics/loss.py +++ b/ignite/metrics/loss.py @@ -66,7 +66,7 @@ class Loss(Metric): -0.3499999... .. versionchanged:: 0.5.1 - ``skip_unrolling`` attribute is added. + ``skip_unrolling`` argument is added. """ required_output_keys = ("y_pred", "y", "criterion_kwargs") diff --git a/ignite/metrics/metric.py b/ignite/metrics/metric.py index bd5e48e948f..5c509ff5cb6 100644 --- a/ignite/metrics/metric.py +++ b/ignite/metrics/metric.py @@ -236,7 +236,7 @@ class Metric(Serializable, metaclass=ABCMeta): skip_unrolling: specifies whether output should be unrolled before being fed to update method. Should be true for multi-output model, for example, if ``y_pred`` contains multi-ouput as ``(y_pred_a, y_pred_b)`` - Example usage: + Example: The following example shows a custom loss metric that expects input from a multi-output model. .. code-block:: python @@ -348,7 +348,7 @@ def compute(self): ``required_output_keys`` became public attribute. .. versionchanged:: 0.5.1 - ``skip_unrolling`` attribute is added. + ``skip_unrolling`` argument is added. """ # public class attribute From 425776c16628be61a10cb42853e84849a958d2c2 Mon Sep 17 00:00:00 2001 From: vfdev Date: Fri, 28 Jun 2024 21:12:03 +0200 Subject: [PATCH 11/13] Update ignite/metrics/metric.py --- ignite/metrics/metric.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ignite/metrics/metric.py b/ignite/metrics/metric.py index 5c509ff5cb6..e299b460522 100644 --- a/ignite/metrics/metric.py +++ b/ignite/metrics/metric.py @@ -236,7 +236,7 @@ class Metric(Serializable, metaclass=ABCMeta): skip_unrolling: specifies whether output should be unrolled before being fed to update method. Should be true for multi-output model, for example, if ``y_pred`` contains multi-ouput as ``(y_pred_a, y_pred_b)`` - Example: + Examples: The following example shows a custom loss metric that expects input from a multi-output model. .. code-block:: python From 9d9e6dfe2da1b59173a85b104eade97c122b6967 Mon Sep 17 00:00:00 2001 From: Simeet Nayan Date: Sat, 29 Jun 2024 04:56:44 +0530 Subject: [PATCH 12/13] Update docstring --- ignite/metrics/metric.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/ignite/metrics/metric.py b/ignite/metrics/metric.py index e299b460522..4ccfd8ea7af 100644 --- a/ignite/metrics/metric.py +++ b/ignite/metrics/metric.py @@ -235,6 +235,7 @@ class Metric(Serializable, metaclass=ABCMeta): non-blocking. By default, CPU. skip_unrolling: specifies whether output should be unrolled before being fed to update method. Should be true for multi-output model, for example, if ``y_pred`` contains multi-ouput as ``(y_pred_a, y_pred_b)`` + Alternatively, ``output_transform`` can be used to handle this. Examples: The following example shows a custom loss metric that expects input from a multi-output model. @@ -286,8 +287,6 @@ def forward(self, x): train_evaluator.run(data) train_evaluator.state.metrics["Loss"] - Alternatively, ``output_transform`` can be used to handle this. - Attributes: required_output_keys: dictionary defines required keys to be found in ``engine.state.output`` if the latter is a dictionary. By default, ``("y_pred", "y")``. This is useful with custom metrics that can require From da69053536b783871a7854c4a525b9f4cd61aefe Mon Sep 17 00:00:00 2001 From: Simeet Nayan Date: Sat, 29 Jun 2024 10:50:08 +0530 Subject: [PATCH 13/13] fix test_loss.py for python below 3.9 --- tests/ignite/metrics/test_loss.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/tests/ignite/metrics/test_loss.py b/tests/ignite/metrics/test_loss.py index 371c2a5e551..0e945bec58c 100644 --- a/tests/ignite/metrics/test_loss.py +++ b/tests/ignite/metrics/test_loss.py @@ -1,4 +1,5 @@ import os +from typing import Tuple from unittest.mock import MagicMock import pytest @@ -321,7 +322,7 @@ def __init__(self) -> None: super().__init__() def forward( - self, y_pred: list[torch.Tensor, torch.Tensor], y_true: list[torch.Tensor, torch.Tensor] + self, y_pred: Tuple[torch.Tensor, torch.Tensor], y_true: Tuple[torch.Tensor, torch.Tensor] ) -> torch.Tensor: a_true, b_true = y_true a_pred, b_pred = y_pred