Fix/multiclass recall macro avg ignore index #2710

Status: Open, wants to merge 26 commits into base: master

Commits (26)
176711d  Fix: Corrected MulticlassRecall macro average calculation when ignore…  (rittik9, Aug 31, 2024)
df36d0f  style: format code to comply with pre-commit hooks  (rittik9, Sep 1, 2024)
0773bab  test: Add test for MulticlassRecall with ignore_index+macro (fixes #2…  (rittik9, Sep 2, 2024)
78177ac  chlog  (Borda, Sep 9, 2024)
f7701ea  Merge branch 'master' into master  (Borda, Sep 16, 2024)
d6f041b  Merge branch 'master' into master  (mergify[bot], Sep 16, 2024)
fb6c23d  Merge branch 'master' into master  (mergify[bot], Sep 16, 2024)
3ae861b  Merge branch 'master' into master  (mergify[bot], Sep 16, 2024)
a0401f6  Merge branch 'master' into master  (Borda, Sep 17, 2024)
858e0d1  Merge branch 'master' into master  (mergify[bot], Sep 17, 2024)
bac6267  Merge branch 'master' into master  (mergify[bot], Sep 24, 2024)
2976947  Merge branch 'master' into master  (mergify[bot], Sep 24, 2024)
dbe1a5a  Merge branch 'master' into master  (mergify[bot], Oct 1, 2024)
bb36be4  Merge branch 'master' into master  (Borda, Oct 8, 2024)
ead62fe  Merge branch 'master' into master  (mergify[bot], Oct 9, 2024)
e0ed7e7  Merge branch 'master' into master  (mergify[bot], Oct 9, 2024)
263548d  Merge branch 'master' into master  (mergify[bot], Oct 10, 2024)
8cc5bf1  Merge branch 'master' into master  (mergify[bot], Oct 10, 2024)
0483219  Merge branch 'master' into master  (Borda, Oct 10, 2024)
982cfea  Merge branch 'master' into master  (mergify[bot], Oct 11, 2024)
d16c815  Merge branch 'master' into master  (mergify[bot], Oct 11, 2024)
c078bd2  Merge branch 'master' into master  (mergify[bot], Oct 11, 2024)
d61727e  Merge branch 'master' into master  (mergify[bot], Oct 14, 2024)
9aa5928  Merge branch 'master' into master  (Borda, Oct 15, 2024)
581d3ec  Merge branch 'master' into master  (mergify[bot], Oct 18, 2024)
61a4b56  Merge branch 'master' into master  (Borda, Oct 18, 2024)
2 changes: 1 addition & 1 deletion CHANGELOG.md
@@ -27,7 +27,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
 
 ### Fixed
 
--
+- Fixed multiclass recall macro avg. ignore index ([#2710](https://github.com/Lightning-AI/torchmetrics/pull/2710))
 
 
 ---
1 change: 1 addition & 0 deletions src/torchmetrics/classification/precision_recall.py
@@ -746,6 +746,7 @@ def compute(self) -> Tensor:
             fn,
             average=self.average,
             multidim_average=self.multidim_average,
+            ignore_index=self.ignore_index,
             top_k=self.top_k,
             zero_division=self.zero_division,
         )
src/torchmetrics/functional/classification/precision_recall.py
@@ -43,6 +43,7 @@ def _precision_recall_reduce(
     average: Optional[Literal["binary", "micro", "macro", "weighted", "none"]],
     multidim_average: Literal["global", "samplewise"] = "global",
     multilabel: bool = False,
+    ignore_index: Optional[int] = None,
     top_k: int = 1,
     zero_division: float = 0,
 ) -> Tensor:
@@ -56,7 +57,7 @@ def _precision_recall_reduce(
         return _safe_divide(tp, tp + different_stat, zero_division)
 
     score = _safe_divide(tp, tp + different_stat, zero_division)
-    return _adjust_weights_safe_divide(score, average, multilabel, tp, fp, fn, top_k=top_k)
+    return _adjust_weights_safe_divide(score, average, multilabel, tp, fp, fn, ignore_index=ignore_index, top_k=top_k)
 
 
 def binary_precision(
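For context (not part of the diff), a minimal sketch of how the fix surfaces through the functional API, using the example from issue #2441; the printed value assumes this branch:

import torch
from torchmetrics.functional.classification import multiclass_recall

preds = torch.tensor([[0.9, 0.1], [0.9, 0.1], [0.9, 0.1], [0.1, 0.9]])  # argmax -> [0, 0, 0, 1]
target = torch.tensor([0, 0, 1, 1])

# With ignore_index=0 the ignored class no longer contributes to the macro average,
# so only the class-1 recall (1/2) remains.
print(multiclass_recall(preds, target, num_classes=2, average="macro", ignore_index=0))  # tensor(0.5000)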
13 changes: 12 additions & 1 deletion src/torchmetrics/utilities/compute.py
@@ -61,7 +61,14 @@ def _safe_divide(num: Tensor, denom: Tensor, zero_division: float = 0.0) -> Tensor:
 
 
 def _adjust_weights_safe_divide(
-    score: Tensor, average: Optional[str], multilabel: bool, tp: Tensor, fp: Tensor, fn: Tensor, top_k: int = 1
+    score: Tensor,
+    average: Optional[str],
+    multilabel: bool,
+    tp: Tensor,
+    fp: Tensor,
+    fn: Tensor,
+    ignore_index: Optional[int] = None,
+    top_k: int = 1,
 ) -> Tensor:
     if average is None or average == "none":
         return score
@@ -71,6 +78,10 @@ def _adjust_weights_safe_divide(
         weights = torch.ones_like(score)
         if not multilabel:
             weights[tp + fp + fn == 0 if top_k == 1 else tp + fn == 0] = 0.0
+
+        if ignore_index is not None and 0 <= ignore_index < len(score):
+            weights[ignore_index] = 0.0
+
     return _safe_divide(weights * score, weights.sum(-1, keepdim=True)).sum(-1)
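Why this matters for the reported bug: with the issue #2441 example, the per-class recalls after dropping the ignore_index=0 samples come out as [0.0, 0.5]. Under the old weighting shown above, the ignored class kept weight 1 and pulled the macro average down to 0.25; zeroing its weight leaves 0.5. A small standalone sketch of that weighting step (illustration only, not the library's internal code):

import torch

# Per-class recall once the ignore_index=0 samples are dropped (issue #2441 example):
# class 0 has no remaining true samples -> 0.0 via zero_division, class 1 -> 1/2 = 0.5.
score = torch.tensor([0.0, 0.5])
ignore_index = 0

# Old behaviour: every class keeps weight 1 -> macro = (0.0 + 0.5) / 2 = 0.25.
weights = torch.ones_like(score)
print((weights * score).sum() / weights.sum())  # tensor(0.2500)

# New behaviour: the ignored class gets weight 0 -> macro = 0.5 / 1 = 0.5.
weights[ignore_index] = 0.0
print((weights * score).sum() / weights.sum())  # tensor(0.5000)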
31 changes: 31 additions & 0 deletions tests/unittests/classification/test_precision_recall.py
@@ -659,6 +659,37 @@ def test_corner_case():
     assert res == 1.0
 
 
+def test_multiclass_recall_ignore_index():
+    """Issue: https://github.com/Lightning-AI/torchmetrics/issues/2441."""
+    y_true = torch.tensor([0, 0, 1, 1])
+    y_pred = torch.tensor([
+        [0.9, 0.1],
+        [0.9, 0.1],
+        [0.9, 0.1],
+        [0.1, 0.9],
+    ])
+
+    # Test with ignore_index=0 and average="macro"
+    metric_ignore_0 = MulticlassRecall(num_classes=2, ignore_index=0, average="macro")
+    res_ignore_0 = metric_ignore_0(y_pred, y_true)
+    assert res_ignore_0 == 0.5, f"Expected 0.5, but got {res_ignore_0}"
+
+    # Test with ignore_index=1 and average="macro"
+    metric_ignore_1 = MulticlassRecall(num_classes=2, ignore_index=1, average="macro")
+    res_ignore_1 = metric_ignore_1(y_pred, y_true)
+    assert res_ignore_1 == 1.0, f"Expected 1.0, but got {res_ignore_1}"
+
+    # Test with no ignore_index and average="macro"
+    metric_no_ignore = MulticlassRecall(num_classes=2, average="macro")
+    res_no_ignore = metric_no_ignore(y_pred, y_true)
+    assert res_no_ignore == 0.75, f"Expected 0.75, but got {res_no_ignore}"
+
+    # Test with ignore_index=0 and average="none"
+    metric_none = MulticlassRecall(num_classes=2, ignore_index=0, average="none")
+    res_none = metric_none(y_pred, y_true)
+    assert torch.allclose(res_none, torch.tensor([0.0, 0.5])), f"Expected [0.0, 0.5], but got {res_none}"
+
+
 @pytest.mark.parametrize(
     ("metric", "kwargs", "base_metric"),
     [
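A quick sanity check on the expected values in the new test (a sketch against the fixed behaviour, not part of the PR): the argmax predictions are [0, 0, 0, 1] for targets [0, 0, 1, 1]. Without ignore_index the per-class recalls are [2/2, 1/2], so the macro average is 0.75; with ignore_index=0 only class 1 remains (1/2 -> 0.5); with ignore_index=1 only class 0 remains (2/2 -> 1.0).

import torch
from torchmetrics.classification import MulticlassRecall

preds = torch.tensor([[0.9, 0.1], [0.9, 0.1], [0.9, 0.1], [0.1, 0.9]])  # argmax -> [0, 0, 0, 1]
target = torch.tensor([0, 0, 1, 1])

# Expected macro recalls under the fix; derivation in the note above.
for ignore_index, expected in [(0, 0.5), (1, 1.0), (None, 0.75)]:
    metric = MulticlassRecall(num_classes=2, average="macro", ignore_index=ignore_index)
    assert torch.isclose(metric(preds, target), torch.tensor(expected))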