From 40beddace00e7881e6d92d8c1a12b9ac18994812 Mon Sep 17 00:00:00 2001 From: Simeet Nayan Date: Wed, 26 Jun 2024 13:20:03 +0530 Subject: [PATCH] Add docstring for skip_unrolling, modify skip_unrolling clause --- ignite/metrics/metric.py | 67 ++++++++++++++++++++++------------------ 1 file changed, 37 insertions(+), 30 deletions(-) diff --git a/ignite/metrics/metric.py b/ignite/metrics/metric.py index 1aed61c5ee3..2acf6662f81 100644 --- a/ignite/metrics/metric.py +++ b/ignite/metrics/metric.py @@ -233,6 +233,9 @@ class Metric(Serializable, metaclass=ABCMeta): device: specifies which device updates are accumulated on. Setting the metric's device to be the same as your ``update`` arguments ensures the ``update`` method is non-blocking. By default, CPU. + skip_unrolling: specifies whether output should be unrolled before being fed to update method. Should be + true for multi-output model, for example, if ``y_pred`` contains multi-output as ``(y_pred_a, y_pred_b)``. + Alternatively, ``output_transform`` can be used to handle this. Attributes: required_output_keys: dictionary defines required keys to be found in ``engine.state.output`` if the @@ -290,8 +293,8 @@ def compute(self): res = evaluator.run(data) - .. versionchanged:: 0.4.2 - ``required_output_keys`` became public attribute. + .. versionchanged:: 0.5.1 + ``skip_unrolling`` attribute is introduced for the Metric class. 
""" # public class attribute @@ -300,7 +303,10 @@ def compute(self): _required_output_keys = required_output_keys def __init__( - self, output_transform: Callable = lambda x: x, device: Union[str, torch.device] = torch.device("cpu"), skip_unrolling : bool =False + self, + output_transform: Callable = lambda x: x, + device: Union[str, torch.device] = torch.device("cpu"), + skip_unrolling: bool = False, ): self._output_transform = output_transform @@ -378,33 +384,34 @@ def iteration_completed(self, engine: Engine) -> None: """ output = self._output_transform(engine.state.output) - if self._skip_unrolling is False: - if isinstance(output, Mapping): - if self.required_output_keys is None: - raise TypeError( - f"Transformed engine output for {self.__class__.__name__} metric should be a tuple/list, " - f"but given {type(output)}" - ) - if not all([k in output for k in self.required_output_keys]): - raise ValueError( - "When transformed engine's output is a mapping, " - f"it should contain {self.required_output_keys} keys, but given {list(output.keys())}" - ) - output = tuple(output[k] for k in self.required_output_keys) - - if isinstance(output, Sequence) and all([_is_list_of_tensors_or_numbers(o) for o in output]): - if not (len(output) == 2 and len(output[0]) == len(output[1])): - raise ValueError( - f"Output should have 2 items of the same length, " - f"got {len(output)} and {len(output[0])}, {len(output[1])}" - ) - for o1, o2 in zip(output[0], output[1]): - # o1 and o2 are list of tensors or numbers - tensor_o1 = _to_batched_tensor(o1) - tensor_o2 = _to_batched_tensor(o2, device=tensor_o1.device) - self.update((tensor_o1, tensor_o2)) - else: - self.update(output) + if isinstance(output, Mapping): + if self.required_output_keys is None: + raise TypeError( + f"Transformed engine output for {self.__class__.__name__} metric should be a tuple/list, " + f"but given {type(output)}" + ) + if not all([k in output for k in self.required_output_keys]): + raise ValueError( + 
"When transformed engine's output is a mapping, " + f"it should contain {self.required_output_keys} keys, but given {list(output.keys())}" + ) + output = tuple(output[k] for k in self.required_output_keys) + + if ( + (not self._skip_unrolling) + and isinstance(output, Sequence) + and all([_is_list_of_tensors_or_numbers(o) for o in output]) + ): + if not (len(output) == 2 and len(output[0]) == len(output[1])): + raise ValueError( + f"Output should have 2 items of the same length, " + f"got {len(output)} and {len(output[0])}, {len(output[1])}" + ) + for o1, o2 in zip(output[0], output[1]): + # o1 and o2 are list of tensors or numbers + tensor_o1 = _to_batched_tensor(o1) + tensor_o2 = _to_batched_tensor(o2, device=tensor_o1.device) + self.update((tensor_o1, tensor_o2)) else: self.update(output)