
Introduce a variable skip_unrolling in class Metric #3258

Merged: 14 commits (Jul 1, 2024)
8 changes: 7 additions & 1 deletion ignite/metrics/loss.py
@@ -29,6 +29,9 @@ class Loss(Metric):
device: specifies which device updates are accumulated on. Setting the
metric's device to be the same as your ``update`` arguments ensures the ``update`` method is
non-blocking. By default, CPU.
skip_unrolling: specifies whether the output should be unrolled before it is passed to ``loss_fn``.
Should be true for multi-output models, for example, when ``y_pred`` contains multiple outputs
such as ``(y_pred_a, y_pred_b)``.

Attributes:
required_output_keys: dictionary defines required keys to be found in ``engine.state.output`` if the
@@ -62,6 +65,8 @@ class Loss(Metric):

-0.3499999...

.. versionchanged:: 0.5.1
``skip_unrolling`` attribute is added.
"""

required_output_keys = ("y_pred", "y", "criterion_kwargs")
@@ -73,8 +78,9 @@ def __init__(
output_transform: Callable = lambda x: x,
batch_size: Callable = len,
device: Union[str, torch.device] = torch.device("cpu"),
skip_unrolling: bool = False,
):
- super(Loss, self).__init__(output_transform, device=device)
+ super(Loss, self).__init__(output_transform, device=device, skip_unrolling=skip_unrolling)
self._loss_fn = loss_fn
self._batch_size = batch_size

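A quick usage sketch of the new flag (hypothetical ``two_head_loss``; the unrolling guard runs in ``iteration_completed``, so a mock engine state is used here, mirroring the tests below):

.. code-block:: python

    from unittest.mock import MagicMock

    import torch
    import torch.nn.functional as F

    from ignite.engine import State
    from ignite.metrics import Loss

    def two_head_loss(y_pred, y):
        # unpack the paired predictions/targets of a two-headed model
        (a_pred, b_pred), (a_true, b_true) = y_pred, y
        return F.mse_loss(a_pred, a_true) + F.mse_loss(b_pred, b_true)

    # Without skip_unrolling=True, the paired tensor lists below would be
    # unrolled into per-item updates before reaching two_head_loss.
    loss = Loss(two_head_loss, skip_unrolling=True)

    y_pred = [torch.rand(4, 1), torch.rand(4, 1)]
    y = [torch.rand(4, 1), torch.rand(4, 1)]
    loss.iteration_completed(MagicMock(state=State(output=(y_pred, y))))
    print(loss.compute())  # tensor equal to two_head_loss(y_pred, y)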
69 changes: 67 additions & 2 deletions ignite/metrics/metric.py
@@ -233,6 +233,60 @@ class Metric(Serializable, metaclass=ABCMeta):
device: specifies which device updates are accumulated on. Setting the
metric's device to be the same as your ``update`` arguments ensures the ``update`` method is
non-blocking. By default, CPU.
skip_unrolling: specifies whether the output should be unrolled before being fed to the ``update`` method. Should be
true for multi-output models, for example, when ``y_pred`` contains multiple outputs such as ``(y_pred_a, y_pred_b)``.

Example usage:
The following example shows a custom loss metric that expects input from a multi-output model.

.. code-block:: python

    from typing import Tuple

    import torch
    import torch.nn as nn
    import torch.nn.functional as F

    from ignite.engine import create_supervised_evaluator
    from ignite.metrics import Loss

    class MyLoss(nn.Module):
        def __init__(self, ca: float = 1.0, cb: float = 1.0) -> None:
            super().__init__()
            self.ca = ca
            self.cb = cb

        def forward(
            self,
            y_pred: Tuple[torch.Tensor, torch.Tensor],
            y_true: Tuple[torch.Tensor, torch.Tensor],
        ) -> torch.Tensor:
            a_true, b_true = y_true
            a_pred, b_pred = y_pred
            return self.ca * F.mse_loss(a_pred, a_true) + self.cb * F.cross_entropy(b_pred, b_true)


    def prepare_batch(batch, device, non_blocking):
        return torch.rand(4, 1), (torch.rand(4, 1), torch.rand(4, 2))


    class MyModel(nn.Module):
        def forward(self, x):
            return torch.rand(4, 1), torch.rand(4, 2)


    model = MyModel()

    device = "cpu"
    loss = MyLoss(0.5, 1.0)
    metrics = {"Loss": Loss(loss, skip_unrolling=True)}
    train_evaluator = create_supervised_evaluator(model, metrics, device, prepare_batch=prepare_batch)

    data = range(10)
    train_evaluator.run(data)
    train_evaluator.state.metrics["Loss"]

Alternatively, ``output_transform`` can be used to handle this.
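One possible sketch of that alternative (hypothetical helper and loss function, not part of this change):
repackage the paired outputs into containers that the default unrolling check ignores, such as dicts.

.. code-block:: python

    def pack(output):
        # dicts are not sequences of tensors, so iteration_completed
        # will not unroll them into per-item updates
        (a_pred, b_pred), (a_true, b_true) = output
        return {"a": a_pred, "b": b_pred}, {"a": a_true, "b": b_true}

    loss = Loss(
        dict_loss_fn,  # hypothetical loss_fn accepting dict y_pred / y
        output_transform=pack,
        batch_size=lambda y: y["a"].shape[0],  # default len(y) would count dict keys
    )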

Attributes:
required_output_keys: dictionary defines required keys to be found in ``engine.state.output`` if the
@@ -292,6 +346,9 @@ def compute(self):

.. versionchanged:: 0.4.2
``required_output_keys`` became public attribute.

.. versionchanged:: 0.5.1
``skip_unrolling`` attribute is added.
"""

# public class attribute
@@ -300,7 +357,10 @@
_required_output_keys = required_output_keys

def __init__(
- self, output_transform: Callable = lambda x: x, device: Union[str, torch.device] = torch.device("cpu")
+ self,
+ output_transform: Callable = lambda x: x,
+ device: Union[str, torch.device] = torch.device("cpu"),
+ skip_unrolling: bool = False,
):
self._output_transform = output_transform

@@ -309,6 +369,7 @@ def __init__(
raise ValueError("Cannot create metric on an XLA device. Use device='cpu' instead.")

self._device = torch.device(device)
self._skip_unrolling = skip_unrolling
self.reset()

@abstractmethod
@@ -390,7 +451,11 @@ def iteration_completed(self, engine: Engine) -> None:
)
output = tuple(output[k] for k in self.required_output_keys)

- if isinstance(output, Sequence) and all([_is_list_of_tensors_or_numbers(o) for o in output]):
+ if (
+     (not self._skip_unrolling)
+     and isinstance(output, Sequence)
+     and all([_is_list_of_tensors_or_numbers(o) for o in output])
+ ):
if not (len(output) == 2 and len(output[0]) == len(output[1])):
raise ValueError(
f"Output should have 2 items of the same length, "
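For context, a self-contained sketch of the behavior this guard toggles (simplified: the real check also accepts numbers and validates that both lists have equal length; the unrolling loop itself sits just below this hunk in the file):

.. code-block:: python

    import torch

    def feed(update, output, skip_unrolling: bool) -> None:
        # Simplified mirror of Metric.iteration_completed: a 2-item pair of
        # lists of tensors is unrolled into per-item (y_pred, y) updates
        # unless skip_unrolling is set.
        is_pair_of_tensor_lists = (
            isinstance(output, (list, tuple))
            and len(output) == 2
            and all(
                isinstance(o, (list, tuple)) and all(isinstance(t, torch.Tensor) for t in o)
                for o in output
            )
        )
        if not skip_unrolling and is_pair_of_tensor_lists:
            for y_pred, y in zip(output[0], output[1]):
                update((y_pred, y))
        else:
            update(output)

    output = ([torch.rand(4, 1), torch.rand(4, 2)], [torch.rand(4, 1), torch.rand(4, 2)])
    feed(lambda o: print(type(o[0]).__name__), output, skip_unrolling=False)  # Tensor, Tensor
    feed(lambda o: print(type(o[0]).__name__), output, skip_unrolling=True)   # list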
49 changes: 48 additions & 1 deletion tests/ignite/metrics/test_loss.py
@@ -5,7 +5,7 @@
import torch
from numpy.testing import assert_almost_equal
from torch import nn
- from torch.nn.functional import nll_loss
+ from torch.nn.functional import mse_loss, nll_loss

import ignite.distributed as idist
from ignite.engine import State
@@ -314,3 +314,50 @@ def compute(self):
(torch.rand(4, 10), torch.randint(0, 3, size=(4,))),
]
evaluator.run(data)


class CustomMultiMSELoss(nn.Module):
    def __init__(self) -> None:
        super().__init__()

    def forward(
        self, y_pred: tuple[torch.Tensor, torch.Tensor], y_true: tuple[torch.Tensor, torch.Tensor]
    ) -> torch.Tensor:
        a_true, b_true = y_true
        a_pred, b_pred = y_pred
        return mse_loss(a_pred, a_true) + mse_loss(b_pred, b_true)


class DummyLoss3(Loss):
    def __init__(self, loss_fn, expected_loss, output_transform=lambda x: x, skip_unrolling=False):
        super(DummyLoss3, self).__init__(loss_fn, output_transform=output_transform, skip_unrolling=skip_unrolling)
        self._expected_loss = expected_loss
        self._loss_fn = loss_fn

    def reset(self):
        pass

    def compute(self):
        pass

    def update(self, output):
        y_pred, y_true = output
        calculated_loss = self._loss_fn(y_pred=y_pred, y_true=y_true)
        assert calculated_loss == self._expected_loss


def test_skip_unrolling_loss():
    a_pred = torch.rand(8, 1)
    b_pred = torch.rand(8, 1)
    y_pred = [a_pred, b_pred]
    a_true = torch.rand(8, 1)
    b_true = torch.rand(8, 1)
    y_true = [a_true, b_true]

    multi_output_mse_loss = CustomMultiMSELoss()
    expected_loss = multi_output_mse_loss(y_pred=y_pred, y_true=y_true)

    loss_metric = DummyLoss3(loss_fn=multi_output_mse_loss, expected_loss=expected_loss, skip_unrolling=True)
    state = State(output=(y_pred, y_true))
    engine = MagicMock(state=state)
    loss_metric.iteration_completed(engine)
30 changes: 30 additions & 0 deletions tests/ignite/metrics/test_metric.py
@@ -1416,3 +1416,33 @@ def wrapper(x, **kwargs):
assert (output == expected).all(), (output, expected)
else:
assert output == expected, (output, expected)


class DummyMetric5(Metric):
    def __init__(self, true_output, output_transform=lambda x: x, skip_unrolling=False):
        super(DummyMetric5, self).__init__(output_transform=output_transform, skip_unrolling=skip_unrolling)
        self.true_output = true_output

    def reset(self):
        pass

    def compute(self):
        pass

    def update(self, output):
        assert output == self.true_output


def test_skip_unrolling():
    # y_pred and y are outputs received from a multi-output model
    a_pred = torch.rand(8, 1)
    b_pred = torch.rand(8, 1)
    y_pred = [a_pred, b_pred]
    a_true = torch.rand(8, 1)
    b_true = torch.rand(8, 1)
    y_true = [a_true, b_true]

    metric = DummyMetric5(true_output=(y_pred, y_true), skip_unrolling=True)
    state = State(output=(y_pred, y_true))
    engine = MagicMock(state=state)
    metric.iteration_completed(engine)
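These additions can be exercised directly with pytest, e.g.
``pytest tests/ignite/metrics/test_loss.py::test_skip_unrolling_loss tests/ignite/metrics/test_metric.py::test_skip_unrolling``.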