From ec741fded76c7bf1b0e71325bd5ab9c6237652ed Mon Sep 17 00:00:00 2001 From: Nicki Skafte Date: Tue, 15 Oct 2024 15:43:11 +0200 Subject: [PATCH 1/8] merge state method --- src/torchmetrics/metric.py | 50 ++++++++++++++++++++++++++++++++++++++ 1 file changed, 50 insertions(+) diff --git a/src/torchmetrics/metric.py b/src/torchmetrics/metric.py index 940e393c6d1..0371e3eb426 100644 --- a/src/torchmetrics/metric.py +++ b/src/torchmetrics/metric.py @@ -398,6 +398,54 @@ def _forward_reduce_state_update(self, *args: Any, **kwargs: Any) -> Any: return batch_val + def merge_states(self, incoming_state: Union[Dict[str, Any], "Metric"]) -> None: + """Merge incoming metric state to the current state of the metric. + + Args: + incoming_state: either a dict containing a metric state similar to the metric itself or an instance of the + metric class. + + Raises: + RuntimeError: + If the metric has ``full_state_update=True`` or ``dist_sync_on_step=True``. In these cases, the metric + cannot be merged with another metric state in a simple way. The user should overwrite the method in the + metric class to handle the merge operation. + + Example: + >>> from torchmetrics.aggregation import SumMetric + >>> metric1 = SumMetric() + >>> metric2 = SumMetric() + >>> metric1.update(1) + >>> metric2.update(2) + >>> metric1.merge_states(metric2) + >>> metric1.compute() + tensor(3) + + Example: + >>> from torchmetrics.aggregation import SumMetric + >>> metric = SumMetric() + >>> metric.update(1) + >>> # SumMetric has one state variable called `sum_value` + >>> metric.merge_states({"sum_value": torch.tensor(2)}) + >>> metric.compute() + tensor(3) + + """ + if self.full_state_update or self.full_state_update is None or self.dist_sync_on_step: + raise RuntimeError( + "``merge_states`` is not supported for metrics with ``full_state_update=True`` or " + "``dist_sync_on_step=True``. Please overwrite the merge_states method in the metric class." + ) + if isinstance(incoming_state, Metric): + this_class = self.__class__ + if not isinstance(incoming_state, this_class): + raise ValueError( + f"Expected incoming state to be an instance of {this_class.__name__} but got {type(incoming_state)}" + ) + incoming_state = incoming_state.metric_state + + self._reduce_states(incoming_state) + def _reduce_states(self, incoming_state: Dict[str, Any]) -> None: """Add an incoming metric state to the current state of the metric. @@ -407,6 +455,8 @@ def _reduce_states(self, incoming_state: Dict[str, Any]) -> None: """ for attr in self._defaults: local_state = getattr(self, attr) + if attr not in incoming_state: + raise ValueError(f"Expected state variable {attr} to be present in incoming state {incoming_state}") global_state = incoming_state[attr] reduce_fn = self._reductions[attr] if reduce_fn == dim_zero_sum: From 6eab70a5637310327a9d9a6c92299b58043501dd Mon Sep 17 00:00:00 2001 From: Nicki Skafte Date: Tue, 15 Oct 2024 15:45:03 +0200 Subject: [PATCH 2/8] small change --- src/torchmetrics/metric.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/torchmetrics/metric.py b/src/torchmetrics/metric.py index 0371e3eb426..c245c2dccb8 100644 --- a/src/torchmetrics/metric.py +++ b/src/torchmetrics/metric.py @@ -64,7 +64,6 @@ class Metric(Module, ABC): * ``update()`` * ``compute()`` - Args: kwargs: additional keyword arguments, see :ref:`Metric kwargs` for more info. From e71b041e57d1887e5e0133b758247d833d9aff9b Mon Sep 17 00:00:00 2001 From: Nicki Skafte Date: Fri, 18 Oct 2024 15:26:48 +0200 Subject: [PATCH 3/8] refactor: rename merge_states method to merge_state and update related documentation --- src/torchmetrics/metric.py | 16 +++++++++++----- 1 file changed, 11 insertions(+), 5 deletions(-) diff --git a/src/torchmetrics/metric.py b/src/torchmetrics/metric.py index c245c2dccb8..25928d5233b 100644 --- a/src/torchmetrics/metric.py +++ b/src/torchmetrics/metric.py @@ -397,7 +397,7 @@ def _forward_reduce_state_update(self, *args: Any, **kwargs: Any) -> Any: return batch_val - def merge_states(self, incoming_state: Union[Dict[str, Any], "Metric"]) -> None: + def merge_state(self, incoming_state: Union[Dict[str, Any], "Metric"]) -> None: """Merge incoming metric state to the current state of the metric. Args: @@ -416,7 +416,7 @@ def merge_states(self, incoming_state: Union[Dict[str, Any], "Metric"]) -> None: >>> metric2 = SumMetric() >>> metric1.update(1) >>> metric2.update(2) - >>> metric1.merge_states(metric2) + >>> metric1.merge_state(metric2) >>> metric1.compute() tensor(3) @@ -425,16 +425,22 @@ def merge_states(self, incoming_state: Union[Dict[str, Any], "Metric"]) -> None: >>> metric = SumMetric() >>> metric.update(1) >>> # SumMetric has one state variable called `sum_value` - >>> metric.merge_states({"sum_value": torch.tensor(2)}) + >>> metric.merge_state({"sum_value": torch.tensor(2)}) >>> metric.compute() tensor(3) """ + if not isinstance(incoming_state, (dict, Metric)): + raise ValueError( + f"Expected incoming state to be a dict or an instance of Metric but got {type(incoming_state)}" + ) + if self.full_state_update or self.full_state_update is None or self.dist_sync_on_step: raise RuntimeError( - "``merge_states`` is not supported for metrics with ``full_state_update=True`` or " - "``dist_sync_on_step=True``. Please overwrite the merge_states method in the metric class." + "``merge_state`` is not supported for metrics with ``full_state_update=True`` or " + "``dist_sync_on_step=True``. Please overwrite the merge_state method in the metric class." ) + if isinstance(incoming_state, Metric): this_class = self.__class__ if not isinstance(incoming_state, this_class): From 293841adc68f9538bd9881afadc7716ff63c36df Mon Sep 17 00:00:00 2001 From: Nicki Skafte Date: Fri, 18 Oct 2024 15:27:54 +0200 Subject: [PATCH 4/8] fix: set full_state_update to False in clustering metrics --- src/torchmetrics/clustering/adjusted_rand_score.py | 2 +- src/torchmetrics/clustering/dunn_index.py | 2 +- src/torchmetrics/clustering/rand_score.py | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/src/torchmetrics/clustering/adjusted_rand_score.py b/src/torchmetrics/clustering/adjusted_rand_score.py index 9df67a8d802..5c1f5f49276 100644 --- a/src/torchmetrics/clustering/adjusted_rand_score.py +++ b/src/torchmetrics/clustering/adjusted_rand_score.py @@ -64,7 +64,7 @@ class AdjustedRandScore(Metric): is_differentiable = True higher_is_better = None - full_state_update: bool = True + full_state_update: bool = False plot_lower_bound: float = -0.5 plot_upper_bound: float = 1.0 preds: List[Tensor] diff --git a/src/torchmetrics/clustering/dunn_index.py b/src/torchmetrics/clustering/dunn_index.py index 9373db1045e..5a85074443d 100644 --- a/src/torchmetrics/clustering/dunn_index.py +++ b/src/torchmetrics/clustering/dunn_index.py @@ -64,7 +64,7 @@ class DunnIndex(Metric): is_differentiable: bool = True higher_is_better: bool = True - full_state_update: bool = True + full_state_update: bool = False plot_lower_bound: float = 0.0 data: List[Tensor] labels: List[Tensor] diff --git a/src/torchmetrics/clustering/rand_score.py b/src/torchmetrics/clustering/rand_score.py index 84f244a1bbe..8ded8b27d0d 100644 --- a/src/torchmetrics/clustering/rand_score.py +++ b/src/torchmetrics/clustering/rand_score.py @@ -63,7 +63,7 @@ class RandScore(Metric): is_differentiable = True higher_is_better = None - full_state_update: bool = True + full_state_update: bool = False plot_lower_bound: float = 0.0 preds: List[Tensor] target: List[Tensor] From 17880a37c44bd4fbae12d8baaca45a657da8c4cc Mon Sep 17 00:00:00 2001 From: Nicki Skafte Date: Fri, 18 Oct 2024 15:29:20 +0200 Subject: [PATCH 5/8] tests --- tests/unittests/bases/test_metric.py | 78 +++++++++++++++++++++++++++- 1 file changed, 77 insertions(+), 1 deletion(-) diff --git a/tests/unittests/bases/test_metric.py b/tests/unittests/bases/test_metric.py index 753150478e4..363b2d31a66 100644 --- a/tests/unittests/bases/test_metric.py +++ b/tests/unittests/bases/test_metric.py @@ -24,8 +24,11 @@ import torch from torch import Tensor, tensor from torch.nn import Module, Parameter +from torchmetrics.aggregation import MeanMetric, SumMetric from torchmetrics.classification import BinaryAccuracy -from torchmetrics.regression import PearsonCorrCoef +from torchmetrics.clustering import AdjustedRandScore +from torchmetrics.image import StructuralSimilarityIndexMeasure +from torchmetrics.regression import PearsonCorrCoef, R2Score from unittests._helpers import seed_all from unittests._helpers.testers import DummyListMetric, DummyMetric, DummyMetricMultiOutput, DummyMetricSum @@ -609,3 +612,76 @@ def test_dtype_property(): assert metric.dtype == torch.float64 # should not change after initialization metric.set_dtype(torch.float32) assert metric.dtype == torch.float32 + + +def test_merge_state_feature_basic(): + """Check the merge_state method works as expected for a basic metric.""" + metric1 = SumMetric() + metric2 = SumMetric() + metric1.update(1) + metric2.update(2) + metric1.merge_state(metric2) + assert metric1.compute() == 3 + + metric = SumMetric() + metric.update(1) + metric.merge_state({"sum_value": torch.tensor(2)}) + assert metric.compute() == 3 + + +def test_merge_state_feature_raises_errors(): + """Check the merge_state method raises errors when expected.""" + + class TempMetric(SumMetric): + full_state_update = True + + metric = TempMetric() + metric2 = SumMetric() + metric3 = MeanMetric() + + with pytest.raises(ValueError, match="Expected incoming state to be a.*"): + metric.merge_state(2) + + with pytest.raises(RuntimeError, match="``merge_state`` is not supported.*"): + metric.merge_state({"sum_value": torch.tensor(2)}) + + with pytest.raises(ValueError, match="Expected incoming state to be an.*"): + metric2.merge_state(metric3) + + +@pytest.mark.parametrize( + ("metric_class", "preds", "target"), + [ + (BinaryAccuracy, lambda: torch.randint(2, (100,)), lambda: torch.randint(2, (100,))), + (R2Score, lambda: torch.randn(100), lambda: torch.randn(100)), + (StructuralSimilarityIndexMeasure, lambda: torch.randn(1, 3, 25, 25), lambda: torch.randn(1, 3, 25, 25)), + (AdjustedRandScore, lambda: torch.randint(10, (100,)), lambda: torch.randint(10, (100,))), + ], +) +def test_merge_state_feature_for_different_metrics(metric_class, preds, target): + """Check the merge_state method works as expected for different metrics. + + It should work such that the metric is the same as if it had seen the data twice, but in different ways. + + """ + metric1_1 = metric_class() + metric1_2 = metric_class() + metric2 = metric_class() + + preds1, target1 = preds(), target() + preds2, target2 = preds(), target() + + metric1_1.update(preds1, target1) + metric1_2.update(preds2, target2) + metric2.update(preds1, target1) + metric2.update(preds2, target2) + metric1_1.merge_state(metric1_2) + + # should be the same because it has seen the same data twice, but in different ways + res1 = metric1_1.compute() + res2 = metric2.compute() + assert torch.allclose(res1, res2) + + # should not be the same because it has only seen half the data + res3 = metric1_2.compute() + assert not torch.allclose(res3, res2) From 9954207cbfe17ae61e1a1add21ac0b2c8571d0db Mon Sep 17 00:00:00 2001 From: Nicki Skafte Date: Fri, 18 Oct 2024 15:32:14 +0200 Subject: [PATCH 6/8] changelog --- CHANGELOG.md | 3 +++ 1 file changed, 3 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index f432c7afa26..b4d5f5a7e77 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -33,6 +33,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - Added `HausdorffDistance` to segmentation package ([#2122](https://github.com/Lightning-AI/torchmetrics/pull/2122)) +- Added method `merge_state` to `Metric` ([#2786](https://github.com/Lightning-AI/torchmetrics/pull/2786)) + + ### Changed - Tracker higher is better integration ([#2649](https://github.com/Lightning-AI/torchmetrics/pull/2649)) From 858b599a5628bc2482acf0cc5de160670349c6c7 Mon Sep 17 00:00:00 2001 From: Nicki Skafte Detlefsen Date: Fri, 18 Oct 2024 15:53:49 +0200 Subject: [PATCH 7/8] Update src/torchmetrics/metric.py --- src/torchmetrics/metric.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/torchmetrics/metric.py b/src/torchmetrics/metric.py index 25928d5233b..ef4280a501a 100644 --- a/src/torchmetrics/metric.py +++ b/src/torchmetrics/metric.py @@ -418,7 +418,7 @@ def merge_state(self, incoming_state: Union[Dict[str, Any], "Metric"]) -> None: >>> metric2.update(2) >>> metric1.merge_state(metric2) >>> metric1.compute() - tensor(3) + tensor(3.) Example: >>> from torchmetrics.aggregation import SumMetric From 0556b92038d209e7fdf2ffbc1e2428ea1d34f076 Mon Sep 17 00:00:00 2001 From: Nicki Skafte Detlefsen Date: Fri, 18 Oct 2024 15:53:55 +0200 Subject: [PATCH 8/8] Update src/torchmetrics/metric.py --- src/torchmetrics/metric.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/torchmetrics/metric.py b/src/torchmetrics/metric.py index ef4280a501a..d858ed34fdc 100644 --- a/src/torchmetrics/metric.py +++ b/src/torchmetrics/metric.py @@ -427,7 +427,7 @@ def merge_state(self, incoming_state: Union[Dict[str, Any], "Metric"]) -> None: >>> # SumMetric has one state variable called `sum_value` >>> metric.merge_state({"sum_value": torch.tensor(2)}) >>> metric.compute() - tensor(3) + tensor(3.) """ if not isinstance(incoming_state, (dict, Metric)):