
Introduce a variable skip_unrolling in class Metric #3258

Merged
14 commits merged on Jul 1, 2024
18 changes: 16 additions & 2 deletions ignite/metrics/metric.py
@@ -233,6 +233,9 @@ class Metric(Serializable, metaclass=ABCMeta):
device: specifies which device updates are accumulated on. Setting the
metric's device to be the same as your ``update`` arguments ensures the ``update`` method is
non-blocking. By default, CPU.
skip_unrolling: specifies whether the output should be unrolled before being fed to the update method. Should be
true for multi-output models, for example, if ``y_pred`` contains multi-output as ``(y_pred_a, y_pred_b)``.
Alternatively, ``output_transform`` can be used to handle this.

Attributes:
required_output_keys: dictionary defines required keys to be found in ``engine.state.output`` if the
@@ -292,6 +295,9 @@ def compute(self):

.. versionchanged:: 0.4.2
``required_output_keys`` became public attribute.

.. versionchanged:: 0.5.1
``skip_unrolling`` argument was added.
"""

# public class attribute
@@ -300,7 +306,10 @@ def compute(self):
_required_output_keys = required_output_keys

def __init__(
self, output_transform: Callable = lambda x: x, device: Union[str, torch.device] = torch.device("cpu")
self,
output_transform: Callable = lambda x: x,
device: Union[str, torch.device] = torch.device("cpu"),
skip_unrolling: bool = False,
):
self._output_transform = output_transform

@@ -309,6 +318,7 @@ def __init__(
raise ValueError("Cannot create metric on an XLA device. Use device='cpu' instead.")

self._device = torch.device(device)
self._skip_unrolling = skip_unrolling
self.reset()

@abstractmethod
@@ -390,7 +400,11 @@ def iteration_completed(self, engine: Engine) -> None:
)
output = tuple(output[k] for k in self.required_output_keys)

if isinstance(output, Sequence) and all([_is_list_of_tensors_or_numbers(o) for o in output]):
if (
(not self._skip_unrolling)
and isinstance(output, Sequence)
and all([_is_list_of_tensors_or_numbers(o) for o in output])
):
if not (len(output) == 2 and len(output[0]) == len(output[1])):
raise ValueError(
f"Output should have 2 items of the same length, "
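For context, here is a minimal sketch (not part of this diff) of how a custom metric might opt into the new flag, using the ``Metric`` API shown above; the ``PairwiseL1`` class and its arithmetic are purely illustrative:

import torch
from ignite.metrics import Metric


class PairwiseL1(Metric):
    # Hypothetical metric for a two-headed model. With skip_unrolling=True the raw
    # ([a_pred, b_pred], [a_true, b_true]) pair reaches update() without being unrolled.
    def __init__(self, output_transform=lambda x: x, device="cpu"):
        super().__init__(output_transform=output_transform, device=device, skip_unrolling=True)

    def reset(self):
        self._sum = 0.0
        self._num = 0

    def update(self, output):
        (a_pred, b_pred), (a_true, b_true) = output
        self._sum += torch.abs(a_pred - a_true).sum().item() + torch.abs(b_pred - b_true).sum().item()
        self._num += a_true.numel() + b_true.numel()

    def compute(self):
        # Mean absolute error accumulated over both heads.
        return self._sum / max(self._num, 1)

Attached to an evaluator in the usual way (e.g. ``PairwiseL1().attach(evaluator, "pairwise_l1")``), each engine output of the form ``([a_pred, b_pred], [a_true, b_true])`` would then be passed to ``update`` as-is instead of going through the unrolling branch shown above.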
30 changes: 30 additions & 0 deletions tests/ignite/metrics/test_metric.py
@@ -1416,3 +1416,33 @@ def wrapper(x, **kwargs):
assert (output == expected).all(), (output, expected)
else:
assert output == expected, (output, expected)


class DummyMetric5(Metric):
    def __init__(self, true_output, output_transform=lambda x: x, skip_unrolling=False):
        super(DummyMetric5, self).__init__(output_transform=output_transform, skip_unrolling=skip_unrolling)
        self.true_output = true_output

    def reset(self):
        pass

    def compute(self):
        pass

    def update(self, output):
        assert output == self.true_output


def test_skip_unrolling():
    # y_pred and y are outputs received from a multi-output model
    a_pred = torch.rand(8, 1)
    b_pred = torch.rand(8, 1)
    y_pred = [a_pred, b_pred]
    a_true = torch.rand(8, 1)
    b_true = torch.rand(8, 1)
    y_true = [a_true, b_true]

    metric = DummyMetric5(true_output=(y_pred, y_true), skip_unrolling=True)
    state = State(output=(y_pred, y_true))
    engine = MagicMock(state=state)
    metric.iteration_completed(engine)
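As a companion illustration (not included in this PR), the default path could be exercised roughly like this; it assumes the pre-existing unrolling behaviour of zipping the two lists and calling ``update`` once per pair, and the ``RecordingMetric`` name is hypothetical:

def test_unrolling_default_sketch():
    # With the default skip_unrolling=False, a pair of equal-length lists of
    # tensors goes through the unrolling branch in iteration_completed.
    calls = []

    class RecordingMetric(Metric):
        def reset(self):
            pass

        def compute(self):
            pass

        def update(self, output):
            calls.append(output)

    y_pred = [torch.rand(8, 1), torch.rand(8, 1)]
    y_true = [torch.rand(8, 1), torch.rand(8, 1)]

    metric = RecordingMetric()  # skip_unrolling defaults to False
    state = State(output=(y_pred, y_true))
    engine = MagicMock(state=state)
    metric.iteration_completed(engine)

    # Assumes the branch zips output[0] with output[1]: one update call per pair.
    assert len(calls) == 2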