Skip to content

Commit

Permalink
Add docstring for skip_unrolling, modify skip_unrolling clause
Browse files Browse the repository at this point in the history
  • Loading branch information
simeetnayan81 committed Jun 26, 2024
1 parent b086835 commit 40bedda
Showing 1 changed file with 37 additions and 30 deletions.
67 changes: 37 additions & 30 deletions ignite/metrics/metric.py
Original file line number Diff line number Diff line change
Expand Up @@ -233,6 +233,9 @@ class Metric(Serializable, metaclass=ABCMeta):
device: specifies which device updates are accumulated on. Setting the
metric's device to be the same as your ``update`` arguments ensures the ``update`` method is
non-blocking. By default, CPU.
skip_unrolling: specifies whether output should be unrolled before being fed to update method. Should be
true for multi-output models, for example, if ``y_pred`` contains multi-output as ``(y_pred_a, y_pred_b)``.
Alternatively, ``output_transform`` can be used to handle this.
Attributes:
required_output_keys: dictionary defines required keys to be found in ``engine.state.output`` if the
Expand Down Expand Up @@ -290,8 +293,8 @@ def compute(self):
res = evaluator.run(data)
.. versionchanged:: 0.4.2
``required_output_keys`` became public attribute.
.. versionchanged:: 0.5.1
``skip_unrolling`` attribute is introduced for Metric class.
"""

# public class attribute
Expand All @@ -300,7 +303,10 @@ def compute(self):
_required_output_keys = required_output_keys

def __init__(
self, output_transform: Callable = lambda x: x, device: Union[str, torch.device] = torch.device("cpu"), skip_unrolling : bool =False
self,
output_transform: Callable = lambda x: x,
device: Union[str, torch.device] = torch.device("cpu"),
skip_unrolling: bool = False,
):
self._output_transform = output_transform

Expand Down Expand Up @@ -378,33 +384,34 @@ def iteration_completed(self, engine: Engine) -> None:
"""

output = self._output_transform(engine.state.output)
if self._skip_unrolling is False:
if isinstance(output, Mapping):
if self.required_output_keys is None:
raise TypeError(
f"Transformed engine output for {self.__class__.__name__} metric should be a tuple/list, "
f"but given {type(output)}"
)
if not all([k in output for k in self.required_output_keys]):
raise ValueError(
"When transformed engine's output is a mapping, "
f"it should contain {self.required_output_keys} keys, but given {list(output.keys())}"
)
output = tuple(output[k] for k in self.required_output_keys)

if isinstance(output, Sequence) and all([_is_list_of_tensors_or_numbers(o) for o in output]):
if not (len(output) == 2 and len(output[0]) == len(output[1])):
raise ValueError(
f"Output should have 2 items of the same length, "
f"got {len(output)} and {len(output[0])}, {len(output[1])}"
)
for o1, o2 in zip(output[0], output[1]):
# o1 and o2 are list of tensors or numbers
tensor_o1 = _to_batched_tensor(o1)
tensor_o2 = _to_batched_tensor(o2, device=tensor_o1.device)
self.update((tensor_o1, tensor_o2))
else:
self.update(output)
if isinstance(output, Mapping):
if self.required_output_keys is None:
raise TypeError(
f"Transformed engine output for {self.__class__.__name__} metric should be a tuple/list, "
f"but given {type(output)}"
)
if not all([k in output for k in self.required_output_keys]):
raise ValueError(
"When transformed engine's output is a mapping, "
f"it should contain {self.required_output_keys} keys, but given {list(output.keys())}"
)
output = tuple(output[k] for k in self.required_output_keys)

if (
(not self._skip_unrolling)
and isinstance(output, Sequence)
and all([_is_list_of_tensors_or_numbers(o) for o in output])
):
if not (len(output) == 2 and len(output[0]) == len(output[1])):
raise ValueError(
f"Output should have 2 items of the same length, "
f"got {len(output)} and {len(output[0])}, {len(output[1])}"
)
for o1, o2 in zip(output[0], output[1]):
# o1 and o2 are list of tensors or numbers
tensor_o1 = _to_batched_tensor(o1)
tensor_o2 = _to_batched_tensor(o2, device=tensor_o1.device)
self.update((tensor_o1, tensor_o2))
else:
self.update(output)

Expand Down

0 comments on commit 40bedda

Please sign in to comment.