diff --git a/CHANGELOG.md b/CHANGELOG.md
index f8143e1e07e..aebadcf8fe5 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -27,7 +27,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
 
 ### Fixed
 
--
+- Fixed multiclass recall macro avg. ignore index ([#2710](https://github.com/Lightning-AI/torchmetrics/pull/2710))
 
 
 ---
diff --git a/src/torchmetrics/classification/precision_recall.py b/src/torchmetrics/classification/precision_recall.py
index 0380545b5ac..f889b01f770 100644
--- a/src/torchmetrics/classification/precision_recall.py
+++ b/src/torchmetrics/classification/precision_recall.py
@@ -746,6 +746,7 @@ def compute(self) -> Tensor:
             fn,
             average=self.average,
             multidim_average=self.multidim_average,
+            ignore_index=self.ignore_index,
             top_k=self.top_k,
             zero_division=self.zero_division,
         )
diff --git a/src/torchmetrics/functional/classification/precision_recall.py b/src/torchmetrics/functional/classification/precision_recall.py
index 96214c82274..5f233863e8a 100644
--- a/src/torchmetrics/functional/classification/precision_recall.py
+++ b/src/torchmetrics/functional/classification/precision_recall.py
@@ -43,6 +43,7 @@ def _precision_recall_reduce(
     average: Optional[Literal["binary", "micro", "macro", "weighted", "none"]],
     multidim_average: Literal["global", "samplewise"] = "global",
     multilabel: bool = False,
+    ignore_index: Optional[int] = None,
     top_k: int = 1,
     zero_division: float = 0,
 ) -> Tensor:
@@ -56,7 +57,7 @@
         return _safe_divide(tp, tp + different_stat, zero_division)
 
     score = _safe_divide(tp, tp + different_stat, zero_division)
-    return _adjust_weights_safe_divide(score, average, multilabel, tp, fp, fn, top_k=top_k)
+    return _adjust_weights_safe_divide(score, average, multilabel, tp, fp, fn, ignore_index=ignore_index, top_k=top_k)
 
 
 def binary_precision(
diff --git a/src/torchmetrics/utilities/compute.py b/src/torchmetrics/utilities/compute.py
index ee11a36136f..deb2e614ca3 100644
--- a/src/torchmetrics/utilities/compute.py
+++ b/src/torchmetrics/utilities/compute.py
@@ -61,7 +61,14 @@ def _safe_divide(num: Tensor, denom: Tensor, zero_division: float = 0.0) -> Tens
 
 
 def _adjust_weights_safe_divide(
-    score: Tensor, average: Optional[str], multilabel: bool, tp: Tensor, fp: Tensor, fn: Tensor, top_k: int = 1
+    score: Tensor,
+    average: Optional[str],
+    multilabel: bool,
+    tp: Tensor,
+    fp: Tensor,
+    fn: Tensor,
+    ignore_index: Optional[int] = None,
+    top_k: int = 1,
 ) -> Tensor:
     if average is None or average == "none":
         return score
@@ -71,6 +78,10 @@
     weights = torch.ones_like(score)
     if not multilabel:
         weights[tp + fp + fn == 0 if top_k == 1 else tp + fn == 0] = 0.0
+
+    if ignore_index is not None and 0 <= ignore_index < len(score):
+        weights[ignore_index] = 0.0
+
     return _safe_divide(weights * score, weights.sum(-1, keepdim=True)).sum(-1)
 
 
diff --git a/tests/unittests/classification/test_precision_recall.py b/tests/unittests/classification/test_precision_recall.py
index 00eee202cc0..a40427d16c4 100644
--- a/tests/unittests/classification/test_precision_recall.py
+++ b/tests/unittests/classification/test_precision_recall.py
@@ -659,6 +659,37 @@ def test_corner_case():
     assert res == 1.0
 
 
+def test_multiclass_recall_ignore_index():
+    """Issue: https://github.com/Lightning-AI/torchmetrics/issues/2441."""
+    y_true = torch.tensor([0, 0, 1, 1])
+    y_pred = torch.tensor([
+        [0.9, 0.1],
+        [0.9, 0.1],
+        [0.9, 0.1],
+        [0.1, 0.9],
+    ])
+
+    # Test with ignore_index=0 and average="macro"
+    metric_ignore_0 = MulticlassRecall(num_classes=2, ignore_index=0, average="macro")
+    res_ignore_0 = metric_ignore_0(y_pred, y_true)
+    assert res_ignore_0 == 0.5, f"Expected 0.5, but got {res_ignore_0}"
+
+    # Test with ignore_index=1 and average="macro"
+    metric_ignore_1 = MulticlassRecall(num_classes=2, ignore_index=1, average="macro")
+    res_ignore_1 = metric_ignore_1(y_pred, y_true)
+    assert res_ignore_1 == 1.0, f"Expected 1.0, but got {res_ignore_1}"
+
+    # Test with no ignore_index and average="macro"
+    metric_no_ignore = MulticlassRecall(num_classes=2, average="macro")
+    res_no_ignore = metric_no_ignore(y_pred, y_true)
+    assert res_no_ignore == 0.75, f"Expected 0.75, but got {res_no_ignore}"
+
+    # Test with ignore_index=0 and average="none"
+    metric_none = MulticlassRecall(num_classes=2, ignore_index=0, average="none")
+    res_none = metric_none(y_pred, y_true)
+    assert torch.allclose(res_none, torch.tensor([0.0, 0.5])), f"Expected [0.0, 0.5], but got {res_none}"
+
+
 @pytest.mark.parametrize(
     ("metric", "kwargs", "base_metric"),
     [
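
For context on what the new `weights[ignore_index] = 0.0` line changes, below is a minimal standalone sketch of the macro-averaging step. It assumes only `torch`; the helper name `_macro_average_with_ignore` is invented for illustration and deliberately omits the `multilabel`/`top_k` handling of the real `_adjust_weights_safe_divide`.

```python
from typing import Optional

import torch
from torch import Tensor


def _macro_average_with_ignore(score: Tensor, ignore_index: Optional[int] = None) -> Tensor:
    """Macro-average per-class scores, zeroing the weight of an ignored class.

    Illustrative sketch of the weighting step patched in this PR; not the
    library code itself.
    """
    weights = torch.ones_like(score)
    if ignore_index is not None and 0 <= ignore_index < len(score):
        # The core of the fix: the ignored class no longer dilutes the macro average.
        weights[ignore_index] = 0.0
    return (weights * score).sum(-1) / weights.sum(-1)


# Per-class recall from the test above with ignore_index=0:
# class 0 contributes 0.0 (its targets are ignored), class 1 recall is 1/2.
per_class = torch.tensor([0.0, 0.5])
print(_macro_average_with_ignore(per_class))                  # tensor(0.2500), the old buggy macro value
print(_macro_average_with_ignore(per_class, ignore_index=0))  # tensor(0.5000), the patched macro value
```

The same masking also composes with `average="weighted"`: there the weights start as `tp + fn` instead of ones, and zeroing the entry at `ignore_index` drops the ignored class from the weighted sum as well.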