From 50982abc8d1f7a6615227c13e82817aac1a268a0 Mon Sep 17 00:00:00 2001 From: Nicki Skafte Date: Tue, 3 Dec 2024 16:10:07 +0100 Subject: [PATCH 1/8] fix implementation --- src/torchmetrics/functional/segmentation/dice.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/src/torchmetrics/functional/segmentation/dice.py b/src/torchmetrics/functional/segmentation/dice.py index 53c4b3d6622..db194cd8e6b 100644 --- a/src/torchmetrics/functional/segmentation/dice.py +++ b/src/torchmetrics/functional/segmentation/dice.py @@ -78,6 +78,10 @@ def _dice_score_compute( support: Optional[Tensor] = None, ) -> Tensor: """Compute the Dice score from the numerator and denominator.""" + # If both numerator and denominator are 0, the dice score is 0 + if torch.all(numerator == 0) and torch.all(denominator == 0): + return torch.tensor(0.0, device=numerator.device, dtype=torch.float) + if average == "micro": numerator = torch.sum(numerator, dim=-1) denominator = torch.sum(denominator, dim=-1) From 059f6df30ca1466ba11ca2b0b76f8befcadfd7f1 Mon Sep 17 00:00:00 2001 From: Nicki Skafte Date: Tue, 3 Dec 2024 16:11:10 +0100 Subject: [PATCH 2/8] add tests --- tests/unittests/segmentation/test_dice.py | 15 +++++++++++++++ 1 file changed, 15 insertions(+) diff --git a/tests/unittests/segmentation/test_dice.py b/tests/unittests/segmentation/test_dice.py index b009401f481..b324bdaa0f9 100644 --- a/tests/unittests/segmentation/test_dice.py +++ b/tests/unittests/segmentation/test_dice.py @@ -106,3 +106,18 @@ def test_dice_score_functional(self, preds, target, input_format, include_backgr "input_format": input_format, }, ) + + +@pytest.mark.parametrize("average", ["micro", "macro", "weighted", None]) +def test_corner_case_no_overlap(average): + """Check that if no overlap and intersection between target and preds, the dice score is 0. + + See issue: https://github.com/Lightning-AI/torchmetrics/issues/2851 + + """ + target = torch.full((4, 4, 128, 128), 0, dtype=torch.int8) + preds = torch.full((4, 4, 128, 128), 0, dtype=torch.int8) + target[0, 0] = 1 + preds[0, 0] = 1 + dice = DiceScore(num_classes=3, average=average, include_background=False) + assert dice(preds, target) == 0.0 From 4b0790580966ff348332269722d9a2b443c64efd Mon Sep 17 00:00:00 2001 From: Nicki Skafte Date: Tue, 3 Dec 2024 16:13:21 +0100 Subject: [PATCH 3/8] changelog --- CHANGELOG.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index a253f90f6ef..05635c9b7a6 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -27,7 +27,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Fixed -- +- Fixed `DiceScore` when there is zero overlap between predictions and targets ([#2860](https://github.com/PyTorchLightning/metrics/pull/2860)) --- From 3bf4f09ec1d70f817ea024a3899fb45752f27193 Mon Sep 17 00:00:00 2001 From: Nicki Skafte Date: Wed, 4 Dec 2024 10:12:25 +0100 Subject: [PATCH 4/8] introduce zero_division argument --- .../functional/segmentation/dice.py | 12 ++++++++--- src/torchmetrics/segmentation/dice.py | 9 +++++++-- src/torchmetrics/utilities/compute.py | 20 +++++++++++++++---- 3 files changed, 32 insertions(+), 9 deletions(-) diff --git a/src/torchmetrics/functional/segmentation/dice.py b/src/torchmetrics/functional/segmentation/dice.py index db194cd8e6b..029f5b25dd1 100644 --- a/src/torchmetrics/functional/segmentation/dice.py +++ b/src/torchmetrics/functional/segmentation/dice.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Optional +from typing import Optional, Union import torch from torch import Tensor @@ -27,6 +27,7 @@ def _dice_score_validate_args( include_background: bool, average: Optional[Literal["micro", "macro", "weighted", "none"]] = "micro", input_format: Literal["one-hot", "index"] = "one-hot", + zero_divide: Union[float, Literal["warn", "nan"]] = 1.0, ) -> None: """Validate the arguments of the metric.""" if not isinstance(num_classes, int) or num_classes <= 0: @@ -38,6 +39,10 @@ def _dice_score_validate_args( raise ValueError(f"Expected argument `average` to be one of {allowed_average} or None, but got {average}.") if input_format not in ["one-hot", "index"]: raise ValueError(f"Expected argument `input_format` to be one of 'one-hot', 'index', but got {input_format}.") + if zero_divide not in [1.0, 0.0, "warn", "nan"]: + raise ValueError( + f"Expected argument `zero_divide` to be one of 1.0, 0.0, 'warn', 'nan', but got {zero_divide}." + ) def _dice_score_update( @@ -76,6 +81,7 @@ def _dice_score_compute( denominator: Tensor, average: Optional[Literal["micro", "macro", "weighted", "none"]] = "micro", support: Optional[Tensor] = None, + zero_division: Union[float, Literal["warn", "nan"]] = 1.0, ) -> Tensor: """Compute the Dice score from the numerator and denominator.""" # If both numerator and denominator are 0, the dice score is 0 @@ -85,11 +91,11 @@ def _dice_score_compute( if average == "micro": numerator = torch.sum(numerator, dim=-1) denominator = torch.sum(denominator, dim=-1) - dice = _safe_divide(numerator, denominator, zero_division=1.0) + dice = _safe_divide(numerator, denominator, zero_division=zero_division) if average == "macro": dice = torch.mean(dice, dim=-1) elif average == "weighted" and support is not None: - weights = _safe_divide(support, torch.sum(support, dim=-1, keepdim=True), zero_division=1.0) + weights = _safe_divide(support, torch.sum(support, dim=-1, keepdim=True), zero_division=zero_division) dice = torch.sum(dice * weights, dim=-1) return dice diff --git a/src/torchmetrics/segmentation/dice.py b/src/torchmetrics/segmentation/dice.py index 05a6e29b387..f6319ecacf0 100644 --- a/src/torchmetrics/segmentation/dice.py +++ b/src/torchmetrics/segmentation/dice.py @@ -68,6 +68,8 @@ class DiceScore(Metric): or ``None``. This determines how to average the dice score across different classes. input_format: What kind of input the function receives. Choose between ``"one-hot"`` for one-hot encoded tensors or ``"index"`` for index tensors + zero_division: The value to return when there is a division by zero. Options are 1.0, 0.0, "warn" or "nan". + Setting it to "warn" behaves like 0.0 but will also create a warning. kwargs: Additional keyword arguments, see :ref:`Metric kwargs` for more info. Raises: @@ -110,14 +112,16 @@ def __init__( include_background: bool = True, average: Optional[Literal["micro", "macro", "weighted", "none"]] = "micro", input_format: Literal["one-hot", "index"] = "one-hot", + zero_division: Union[float, Literal["warn", "nan"]] = 0.0, **kwargs: Any, ) -> None: super().__init__(**kwargs) - _dice_score_validate_args(num_classes, include_background, average, input_format) + _dice_score_validate_args(num_classes, include_background, average, input_format, zero_division) self.num_classes = num_classes self.include_background = include_background self.average = average self.input_format = input_format + self.zero_division = zero_division num_classes = num_classes - 1 if not include_background else num_classes self.add_state("numerator", [], dist_reduce_fx="cat") @@ -141,7 +145,8 @@ def compute(self) -> Tensor: dim_zero_cat(self.denominator), self.average, support=dim_zero_cat(self.support) if self.average == "weighted" else None, - ).mean(dim=0) + zero_division=self.zero_division, + ).nanmean(dim=0) def plot(self, val: Union[Tensor, Sequence[Tensor], None] = None, ax: Optional[_AX_TYPE] = None) -> _PLOT_OUT_TYPE: """Plot a single or multiple values from the metric. diff --git a/src/torchmetrics/utilities/compute.py b/src/torchmetrics/utilities/compute.py index cbb648a8844..fdbaba8b8d9 100644 --- a/src/torchmetrics/utilities/compute.py +++ b/src/torchmetrics/utilities/compute.py @@ -11,10 +11,13 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Optional +from typing import Optional, Union import torch from torch import Tensor +from typing_extensions import Literal + +from torchmetrics.utilities import rank_zero_warn def _safe_matmul(x: Tensor, y: Tensor) -> Tensor: @@ -43,7 +46,11 @@ def _safe_xlogy(x: Tensor, y: Tensor) -> Tensor: return res -def _safe_divide(num: Tensor, denom: Tensor, zero_division: float = 0.0) -> Tensor: +def _safe_divide( + num: Tensor, + denom: Tensor, + zero_division: Union[float, Literal["warn", "nan"]] = 0.0, +) -> Tensor: """Safe division, by preventing division by zero. Function will cast to float if input is not already to secure backwards compatibility. @@ -63,8 +70,13 @@ def _safe_divide(num: Tensor, denom: Tensor, zero_division: float = 0.0) -> Tens """ num = num if num.is_floating_point() else num.float() denom = denom if denom.is_floating_point() else denom.float() - zero_division_tensor = torch.tensor(zero_division, dtype=num.dtype).to(num.device, non_blocking=True) - return torch.where(denom != 0, num / denom, zero_division_tensor) + if isinstance(zero_division, float) or zero_division == "warn": + if zero_division == "warn" and torch.any(denom == 0): + rank_zero_warn("Detected zero division in _safe_divide. Setting 0/0 to 0.0") + zero_division = 0.0 if zero_division == "warn" else zero_division + zero_division_tensor = torch.tensor(zero_division, dtype=num.dtype).to(num.device, non_blocking=True) + return torch.where(denom != 0, num / denom, zero_division_tensor) + return torch.true_divide(num, denom) def _adjust_weights_safe_divide( From 37d2530775c96b70d877d7a4a2875911c30d06b1 Mon Sep 17 00:00:00 2001 From: Nicki Skafte Date: Wed, 4 Dec 2024 10:12:43 +0100 Subject: [PATCH 5/8] add tests --- tests/unittests/segmentation/test_dice.py | 28 +++++++++++++++++++++++ 1 file changed, 28 insertions(+) diff --git a/tests/unittests/segmentation/test_dice.py b/tests/unittests/segmentation/test_dice.py index b324bdaa0f9..788fe32e427 100644 --- a/tests/unittests/segmentation/test_dice.py +++ b/tests/unittests/segmentation/test_dice.py @@ -16,6 +16,7 @@ import pytest import torch from sklearn.metrics import f1_score +from torch import tensor from torchmetrics.functional.segmentation.dice import dice_score from torchmetrics.segmentation.dice import DiceScore @@ -121,3 +122,30 @@ def test_corner_case_no_overlap(average): preds[0, 0] = 1 dice = DiceScore(num_classes=3, average=average, include_background=False) assert dice(preds, target) == 0.0 + + +@pytest.mark.parametrize("average", ["micro", "macro", "weighted", None]) +@pytest.mark.parametrize("zero_division", [1.0, 0.0, "warn", "nan"]) +def test_zero_division(zero_division, average): + """Test different zero_division values.""" + target = torch.full((1, 3, 128, 128), 0, dtype=torch.int8) + preds = torch.full((1, 3, 128, 128), 0, dtype=torch.int8) + target[0, 0] = 1 + dice = DiceScore(num_classes=3, average=average, zero_division=zero_division) + score = dice(preds, target) + + res_dict = { + "micro": {1.0: tensor(0.0), 0.0: tensor(0.0), "warn": tensor(0.0), "nan": tensor(0.0)}, + "macro": {1.0: tensor(0.6667), 0.0: tensor(0.0), "warn": tensor(0.0), "nan": tensor(float("nan"))}, + "weighted": {1.0: tensor(0.0), 0.0: tensor(0.0), "warn": tensor(0.0), "nan": tensor(float("nan"))}, + None: { + 1.0: tensor([0.0, 1.0, 1.0]), + 0.0: tensor([0.0, 0.0, 0.0]), + "warn": tensor([0.0, 0.0, 0.0]), + "nan": tensor([0.0, float("nan"), float("nan")]), + }, + } + + assert torch.allclose( + score, res_dict[average][zero_division], atol=1e-4, equal_nan=True + ), f"Expected {res_dict[average][zero_division]} but got {score}" From 539ca53d47671462c73fead46c254b310bf452f1 Mon Sep 17 00:00:00 2001 From: Nicki Skafte Date: Wed, 4 Dec 2024 10:13:51 +0100 Subject: [PATCH 6/8] changelog --- CHANGELOG.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 05635c9b7a6..09b50ae9d83 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -12,7 +12,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Added -- +- Added `zero_division` argument to `DiceScore` in segmentation package ([#2860](https://github.com/PyTorchLightning/metrics/pull/2860)) ### Changed From d59fad8a9e4e6e79fe3cba36799ddee9f7253f79 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Sat, 21 Dec 2024 10:34:15 +0000 Subject: [PATCH 7/8] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/torchmetrics/utilities/compute.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/torchmetrics/utilities/compute.py b/src/torchmetrics/utilities/compute.py index 62324c04ce9..eea5fd8665e 100644 --- a/src/torchmetrics/utilities/compute.py +++ b/src/torchmetrics/utilities/compute.py @@ -16,6 +16,7 @@ import torch from torch import Tensor from typing_extensions import Literal + from torchmetrics.utilities import rank_zero_warn From 7984f440e659156392522afdfc6effbe57dd00b3 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 8 Jan 2025 13:39:41 +0000 Subject: [PATCH 8/8] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- tests/unittests/segmentation/test_dice.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/unittests/segmentation/test_dice.py b/tests/unittests/segmentation/test_dice.py index b1f3a3e6c8d..a3df3105b3e 100644 --- a/tests/unittests/segmentation/test_dice.py +++ b/tests/unittests/segmentation/test_dice.py @@ -147,9 +147,9 @@ def test_zero_division(zero_division, average): }, } - assert torch.allclose( - score, res_dict[average][zero_division], atol=1e-4, equal_nan=True - ), f"Expected {res_dict[average][zero_division]} but got {score}" + assert torch.allclose(score, res_dict[average][zero_division], atol=1e-4, equal_nan=True), ( + f"Expected {res_dict[average][zero_division]} but got {score}" + ) @pytest.mark.parametrize("compute_groups", [True, False])