Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feature: merge_state method #2786

Open
wants to merge 11 commits into
base: master
Choose a base branch
from
3 changes: 3 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Added `HausdorffDistance` to segmentation package ([#2122](https://github.com/Lightning-AI/torchmetrics/pull/2122))


- Added method `merge_state` to `Metric` ([#2786](https://github.com/Lightning-AI/torchmetrics/pull/2786))


### Changed

- Tracker higher is better integration ([#2649](https://github.com/Lightning-AI/torchmetrics/pull/2649))
Expand Down
2 changes: 1 addition & 1 deletion src/torchmetrics/clustering/adjusted_rand_score.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ class AdjustedRandScore(Metric):

is_differentiable = True
higher_is_better = None
full_state_update: bool = True
full_state_update: bool = False
plot_lower_bound: float = -0.5
plot_upper_bound: float = 1.0
preds: List[Tensor]
Expand Down
2 changes: 1 addition & 1 deletion src/torchmetrics/clustering/dunn_index.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ class DunnIndex(Metric):

is_differentiable: bool = True
higher_is_better: bool = True
full_state_update: bool = True
full_state_update: bool = False
plot_lower_bound: float = 0.0
data: List[Tensor]
labels: List[Tensor]
Expand Down
2 changes: 1 addition & 1 deletion src/torchmetrics/clustering/rand_score.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ class RandScore(Metric):

is_differentiable = True
higher_is_better = None
full_state_update: bool = True
full_state_update: bool = False
plot_lower_bound: float = 0.0
preds: List[Tensor]
target: List[Tensor]
Expand Down
57 changes: 56 additions & 1 deletion src/torchmetrics/metric.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,6 @@ class Metric(Module, ABC):
* ``update()``
* ``compute()``


Args:
kwargs: additional keyword arguments, see :ref:`Metric kwargs` for more info.

Expand Down Expand Up @@ -398,6 +397,60 @@ def _forward_reduce_state_update(self, *args: Any, **kwargs: Any) -> Any:

return batch_val

def merge_state(self, incoming_state: Union[Dict[str, Any], "Metric"]) -> None:
    """Merge incoming metric state into the current state of the metric.

    Args:
        incoming_state: either a dict containing a metric state similar to the metric itself or an instance of the
            metric class.

    Raises:
        ValueError:
            If ``incoming_state`` is neither a dict nor a ``Metric`` instance, or if it is a ``Metric`` of a
            different class than ``self``.
        RuntimeError:
            If the metric has ``full_state_update=True`` or ``dist_sync_on_step=True``. In these cases, the metric
            cannot be merged with another metric state in a simple way. The user should overwrite the method in the
            metric class to handle the merge operation.

    Example (merging another metric instance):
        >>> from torchmetrics.aggregation import SumMetric
        >>> metric1 = SumMetric()
        >>> metric2 = SumMetric()
        >>> metric1.update(1)
        >>> metric2.update(2)
        >>> metric1.merge_state(metric2)
        >>> metric1.compute()
        tensor(3)

    Example (merging a raw state dict):
        >>> import torch
        >>> from torchmetrics.aggregation import SumMetric
        >>> metric = SumMetric()
        >>> metric.update(1)
        >>> # SumMetric has one state variable called `sum_value`
        >>> metric.merge_state({"sum_value": torch.tensor(2)})
        >>> metric.compute()
        tensor(3)

    """
    if not isinstance(incoming_state, (dict, Metric)):
        raise ValueError(
            f"Expected incoming state to be a dict or an instance of Metric but got {type(incoming_state)}"
        )

    # ``full_state_update is None`` means it has not been determined yet; be conservative and refuse the merge,
    # same as for metrics that explicitly declare ``full_state_update=True``.
    if self.full_state_update or self.full_state_update is None or self.dist_sync_on_step:
        raise RuntimeError(
            "``merge_state`` is not supported for metrics with ``full_state_update=True`` or "
            "``dist_sync_on_step=True``. Please overwrite the merge_state method in the metric class."
        )

    if isinstance(incoming_state, Metric):
        this_class = self.__class__
        if not isinstance(incoming_state, this_class):
            raise ValueError(
                f"Expected incoming state to be an instance of {this_class.__name__} but got {type(incoming_state)}"
            )
        # Extract the raw state dict so the merge path below is shared with the dict input case.
        incoming_state = incoming_state.metric_state

    self._reduce_states(incoming_state)

def _reduce_states(self, incoming_state: Dict[str, Any]) -> None:
"""Add an incoming metric state to the current state of the metric.

Expand All @@ -407,6 +460,8 @@ def _reduce_states(self, incoming_state: Dict[str, Any]) -> None:
"""
for attr in self._defaults:
local_state = getattr(self, attr)
if attr not in incoming_state:
raise ValueError(f"Expected state variable {attr} to be present in incoming state {incoming_state}")
global_state = incoming_state[attr]
reduce_fn = self._reductions[attr]
if reduce_fn == dim_zero_sum:
Expand Down
78 changes: 77 additions & 1 deletion tests/unittests/bases/test_metric.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,11 @@
import torch
from torch import Tensor, tensor
from torch.nn import Module, Parameter
from torchmetrics.aggregation import MeanMetric, SumMetric
from torchmetrics.classification import BinaryAccuracy
from torchmetrics.regression import PearsonCorrCoef
from torchmetrics.clustering import AdjustedRandScore
from torchmetrics.image import StructuralSimilarityIndexMeasure
from torchmetrics.regression import PearsonCorrCoef, R2Score

from unittests._helpers import seed_all
from unittests._helpers.testers import DummyListMetric, DummyMetric, DummyMetricMultiOutput, DummyMetricSum
Expand Down Expand Up @@ -609,3 +612,76 @@ def test_dtype_property():
assert metric.dtype == torch.float64 # should not change after initialization
metric.set_dtype(torch.float32)
assert metric.dtype == torch.float32


def test_merge_state_feature_basic():
    """Check the merge_state method works as expected for a basic metric."""
    # Merging one metric instance into another should be equivalent to a
    # single metric having seen all the data.
    first, second = SumMetric(), SumMetric()
    first.update(1)
    second.update(2)
    first.merge_state(second)
    assert first.compute() == 3

    # Merging a raw state dict should give the same result.
    single = SumMetric()
    single.update(1)
    single.merge_state({"sum_value": torch.tensor(2)})
    assert single.compute() == 3


def test_merge_state_feature_raises_errors():
    """Check the merge_state method raises errors when expected."""

    class TempMetric(SumMetric):
        full_state_update = True

    full_update_metric = TempMetric()
    sum_metric = SumMetric()
    mean_metric = MeanMetric()

    # anything that is neither a dict nor a Metric instance is rejected
    with pytest.raises(ValueError, match="Expected incoming state to be a.*"):
        full_update_metric.merge_state(2)

    # metrics declaring full_state_update=True cannot be merged generically
    with pytest.raises(RuntimeError, match="``merge_state`` is not supported.*"):
        full_update_metric.merge_state({"sum_value": torch.tensor(2)})

    # merging two different metric classes is rejected
    with pytest.raises(ValueError, match="Expected incoming state to be an.*"):
        sum_metric.merge_state(mean_metric)


@pytest.mark.parametrize(
    ("metric_class", "preds", "target"),
    [
        (BinaryAccuracy, lambda: torch.randint(2, (100,)), lambda: torch.randint(2, (100,))),
        (R2Score, lambda: torch.randn(100), lambda: torch.randn(100)),
        (StructuralSimilarityIndexMeasure, lambda: torch.randn(1, 3, 25, 25), lambda: torch.randn(1, 3, 25, 25)),
        (AdjustedRandScore, lambda: torch.randint(10, (100,)), lambda: torch.randint(10, (100,))),
    ],
)
def test_merge_state_feature_for_different_metrics(metric_class, preds, target):
    """Check the merge_state method works as expected for different metrics.

    It should work such that the metric is the same as if it had seen the data twice, but in different ways.

    """
    half_a, half_b = metric_class(), metric_class()
    reference = metric_class()

    batch1 = (preds(), target())
    batch2 = (preds(), target())

    # ``half_a`` and ``half_b`` each see one batch; ``reference`` sees both.
    half_a.update(*batch1)
    half_b.update(*batch2)
    reference.update(*batch1)
    reference.update(*batch2)
    half_a.merge_state(half_b)

    # merged metric must match the metric that saw all the data directly
    merged_res = half_a.compute()
    reference_res = reference.compute()
    assert torch.allclose(merged_res, reference_res)

    # a metric that only saw half the data must differ from the reference
    partial_res = half_b.compute()
    assert not torch.allclose(partial_res, reference_res)
Loading