From 9808ce23c9640462748754ba08c2617747944829 Mon Sep 17 00:00:00 2001
From: Zifu Wang
Date: Tue, 3 Dec 2024 06:13:08 +0100
Subject: [PATCH] Modify Dice, Jaccard and Tversky losses (#8138)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

Fixes #8094.

### Description

The Dice, Jaccard and Tversky losses in `monai.losses.dice` and `monai.losses.tversky` are modified based on [JDTLoss](https://github.com/zifuwanggg/JDTLosses/blob/master/losses/jdt_loss.py) and [segmentation_models.pytorch](https://github.com/qubvel-org/segmentation_models.pytorch/blob/main/segmentation_models_pytorch/losses/_functional.py).

In the original versions, when `squared_pred=False`, the loss functions are incompatible with soft labels. For example, with a ground truth value of 0.5 for a single pixel, the Dice loss is minimized when the predicted value is 1, which is clearly erroneous. To address this, the intersection term is rewritten as $\frac{\|x\|_p^p + \|y\|_p^p - \|x-y\|_p^p}{2}$. When $p$ is 2 (`squared_pred=True`), this reformulation becomes the classical inner product $\langle x,y \rangle$. When $p$ is 1 (`squared_pred=False`), the reformulation has been proven to be equivalent to the original versions when the ground truth is binary (i.e., one-hot hard labels). Moreover, since the new versions are minimized if and only if the prediction is identical to the ground truth, even when the ground truth includes fractional values, they resolve the issue with soft labels [1, 2].

In summary, there are three scenarios:

* [Scenario 1] $x$ is nonnegative and $y$ is binary: the new versions are the same as the original versions.
* [Scenario 2] Both $x$ and $y$ are nonnegative: the new versions differ from the original versions. The new versions are minimized if and only if $x=y$, while the original versions may not be, making them incorrect.
* [Scenario 3] Either $x$ or $y$ is negative: the new versions differ from the original versions. The new versions are minimized if and only if $x=y$, while the original versions may not be, making them incorrect.

Due to these differences, particularly in Scenarios 2 and 3, some tests fail with the new versions:

* The target is non-binary: `test_multi_scale`
* The input is negative: `test_dice_loss`, `test_tversky_loss`, `test_generalized_dice_loss`, `test_masked_loss`, `test_seg_loss_integration`

The failures in `test_multi_scale` are expected, since the original versions are incorrectly defined for non-binary targets. Furthermore, because Dice, Jaccard, and Tversky losses are fundamentally defined over probabilities, which should be nonnegative, the new versions should not be tested against negative input or target values.
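To see the failure mode concretely, here is a minimal sketch (illustrative only, not part of the patch): for a single pixel with soft ground truth $y=0.5$, the original soft Dice score $\frac{2xy}{x+y}$ keeps increasing as the prediction $x$ grows toward 1, whereas the reformulated score peaks exactly at $x=y$.

```
import torch

# Single pixel, soft ground truth y = 0.5; sweep predictions x over (0, 1].
y = torch.tensor(0.5)
x = torch.linspace(0.01, 1.0, 100)

# Original soft Dice score: 2*x*y / (x + y) (higher is better).
dice_old = 2 * x * y / (x + y)

# Reformulated intersection with p = 1: (|x| + |y| - |x - y|) / 2.
intersection = (x.abs() + y.abs() - (x - y).abs()) / 2
dice_new = 2 * intersection / (x.abs() + y.abs())

print(x[dice_old.argmax()])  # tensor(1.)     -> original score is maximized at x = 1
print(x[dice_new.argmax()])  # tensor(0.5000) -> new score is maximized at x = y
```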
### Example

```
import torch
import torch.linalg as LA
import torch.nn.functional as F

torch.manual_seed(0)

b, c, h, w = 4, 3, 32, 32
dims = (0, 2, 3)

pred = torch.rand(b, c, h, w).softmax(dim=1)
soft_label = torch.rand(b, c, h, w).softmax(dim=1)
hard_label = torch.randint(low=0, high=c, size=(b, h, w))
one_hot_label = F.one_hot(hard_label, c).permute(0, 3, 1, 2).float()

def dice_old(x, y, ord, dims):
    cardinality = LA.vector_norm(x, ord=ord, dim=dims) ** ord + LA.vector_norm(y, ord=ord, dim=dims) ** ord
    intersection = torch.sum(x * y, dim=dims)
    return 2 * intersection / cardinality

def dice_new(x, y, ord, dims):
    cardinality = LA.vector_norm(x, ord=ord, dim=dims) ** ord + LA.vector_norm(y, ord=ord, dim=dims) ** ord
    difference = LA.vector_norm(x - y, ord=ord, dim=dims) ** ord
    intersection = (cardinality - difference) / 2
    return 2 * intersection / cardinality

print(dice_old(pred, one_hot_label, 1, dims), dice_new(pred, one_hot_label, 1, dims))
print(dice_old(pred, soft_label, 1, dims), dice_new(pred, soft_label, 1, dims))
print(dice_old(pred, pred, 1, dims), dice_new(pred, pred, 1, dims))
print(dice_old(pred, one_hot_label, 2, dims), dice_new(pred, one_hot_label, 2, dims))
print(dice_old(pred, soft_label, 2, dims), dice_new(pred, soft_label, 2, dims))
print(dice_old(pred, pred, 2, dims), dice_new(pred, pred, 2, dims))

# tensor([0.3345, 0.3310, 0.3317]) tensor([0.3345, 0.3310, 0.3317])
# tensor([0.3321, 0.3333, 0.3350]) tensor([0.8680, 0.8690, 0.8700])
# tensor([0.3487, 0.3502, 0.3544]) tensor([1., 1., 1.])
# tensor([0.4921, 0.4904, 0.4935]) tensor([0.4921, 0.4904, 0.4935])
# tensor([0.9489, 0.9499, 0.9503]) tensor([0.9489, 0.9499, 0.9503])
# tensor([1., 1., 1.]) tensor([1., 1., 1.])
```

### References

[1] Dice Semimetric Losses: Optimizing the Dice Score with Soft Labels. Zifu Wang, Teodora Popordanoska, Jeroen Bertels, Robin Lemmens, Matthew B. Blaschko. *MICCAI 2023*.

[2] Jaccard Metric Losses: Optimizing the Jaccard Index with Soft Labels. Zifu Wang, Xuefei Ning, Matthew B. Blaschko. *NeurIPS 2023*.

### Types of changes

- [x] Non-breaking change (fix or new feature that would not break existing functionality).
- [ ] Breaking change (fix or new feature that would cause existing functionality to change).
- [ ] New tests added to cover the changes.
- [ ] Integration tests passed locally by running `./runtests.sh -f -u --net --coverage`.
- [ ] Quick tests passed locally by running `./runtests.sh --quick --unittests --disttests`.
- [x] In-line docstrings updated.
- [x] Documentation updated, tested `make html` command in the `docs/` folder.
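### Usage sketch

For illustration only, a minimal usage of the new `soft_label` flag (assuming this patch is applied; the tensors and expected values below are taken from the new test cases added by the patch):

```
import torch
from monai.losses import DiceLoss

# Identical soft probability maps for prediction and target, shape (2, 1, 2, 2).
pred = torch.tensor([[[[0.3, 0.4], [0.7, 0.9]]], [[[1.0, 0.1], [0.5, 0.3]]]])
target = pred.clone()

# With soft_label=True the loss reaches its minimum (~0) when pred == target.
loss_soft = DiceLoss(include_background=True, smooth_nr=1e-4, smooth_dr=1e-4, soft_label=True)
print(loss_soft(pred, target))  # ~0.0

# The original formulation (soft_label=False) does not reach its minimum here.
loss_orig = DiceLoss(include_background=True, smooth_nr=1e-4, smooth_dr=1e-4, soft_label=False)
print(loss_orig(pred, target))  # ~0.307773, per the new test expectation
```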
---------

Signed-off-by: Zifu Wang
Co-authored-by: YunLiu <55491388+KumoLiu@users.noreply.github.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
---
 monai/losses/dice.py                | 55 +++++++++++++----------
 monai/losses/tversky.py             | 19 ++++----
 monai/losses/utils.py               | 68 +++++++++++++++++++++++++++++
 tests/test_dice_loss.py             | 16 +++++++
 tests/test_generalized_dice_loss.py | 16 +++++++
 tests/test_tversky_loss.py          | 16 +++++++
 6 files changed, 160 insertions(+), 30 deletions(-)
 create mode 100644 monai/losses/utils.py

diff --git a/monai/losses/dice.py b/monai/losses/dice.py
index 3f02fae6b8..4108820bec 100644
--- a/monai/losses/dice.py
+++ b/monai/losses/dice.py
@@ -23,6 +23,7 @@
 
 from monai.losses.focal_loss import FocalLoss
 from monai.losses.spatial_mask import MaskedLoss
+from monai.losses.utils import compute_tp_fp_fn
 from monai.networks import one_hot
 from monai.utils import DiceCEReduction, LossReduction, Weight, look_up_option, pytorch_after
 
@@ -39,8 +40,16 @@ class DiceLoss(_Loss):
     The `smooth_nr` and `smooth_dr` parameters are values added to the intersection and union components of
     the inter-over-union calculation to smooth results respectively, these values should be small.
 
-    The original paper: Milletari, F. et. al. (2016) V-Net: Fully Convolutional Neural Networks forVolumetric
-    Medical Image Segmentation, 3DV, 2016.
+    The original papers:
+
+    Milletari, F. et. al. (2016) V-Net: Fully Convolutional Neural Networks for Volumetric
+    Medical Image Segmentation. 3DV 2016.
+
+    Wang, Z. et. al. (2023) Jaccard Metric Losses: Optimizing the Jaccard Index with
+    Soft Labels. NeurIPS 2023.
+
+    Wang, Z. et. al. (2023) Dice Semimetric Losses: Optimizing the Dice Score with
+    Soft Labels. MICCAI 2023.
 
     """
 
@@ -58,6 +67,7 @@ def __init__(
         smooth_dr: float = 1e-5,
         batch: bool = False,
         weight: Sequence[float] | float | int | torch.Tensor | None = None,
+        soft_label: bool = False,
     ) -> None:
         """
         Args:
@@ -89,6 +99,8 @@ def __init__(
                 of the sequence should be the same as the number of classes. If not ``include_background``,
                 the number of classes should not include the background category class 0).
                 The value/values should be no less than 0. Defaults to None.
+            soft_label: whether the target contains non-binary values (soft labels) or not.
+                If True a soft label formulation of the loss will be used.
 
         Raises:
             TypeError: When ``other_act`` is not an ``Optional[Callable]``.
@@ -114,6 +126,7 @@ def __init__(
         weight = torch.as_tensor(weight) if weight is not None else None
         self.register_buffer("class_weight", weight)
         self.class_weight: None | torch.Tensor
+        self.soft_label = soft_label
 
     def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
         """
@@ -174,21 +187,15 @@ def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
             # reducing spatial dimensions and batch
             reduce_axis = [0] + reduce_axis
 
-        intersection = torch.sum(target * input, dim=reduce_axis)
-
-        if self.squared_pred:
-            ground_o = torch.sum(target**2, dim=reduce_axis)
-            pred_o = torch.sum(input**2, dim=reduce_axis)
-        else:
-            ground_o = torch.sum(target, dim=reduce_axis)
-            pred_o = torch.sum(input, dim=reduce_axis)
-
-        denominator = ground_o + pred_o
-
-        if self.jaccard:
-            denominator = 2.0 * (denominator - intersection)
+        ord = 2 if self.squared_pred else 1
+        tp, fp, fn = compute_tp_fp_fn(input, target, reduce_axis, ord, self.soft_label)
+        if not self.jaccard:
+            fp *= 0.5
+            fn *= 0.5
+        numerator = 2 * tp + self.smooth_nr
+        denominator = 2 * (tp + fp + fn) + self.smooth_dr
 
-        f: torch.Tensor = 1.0 - (2.0 * intersection + self.smooth_nr) / (denominator + self.smooth_dr)
+        f: torch.Tensor = 1 - numerator / denominator
 
         num_of_classes = target.shape[1]
         if self.class_weight is not None and num_of_classes != 1:
@@ -272,6 +279,7 @@ def __init__(
         smooth_nr: float = 1e-5,
         smooth_dr: float = 1e-5,
         batch: bool = False,
+        soft_label: bool = False,
     ) -> None:
         """
         Args:
@@ -295,6 +303,8 @@ def __init__(
             batch: whether to sum the intersection and union areas over the batch dimension before the dividing.
                 Defaults to False, intersection over union is computed from each item in the batch.
                 If True, the class-weighted intersection and union areas are first summed across the batches.
+            soft_label: whether the target contains non-binary values (soft labels) or not.
+                If True a soft label formulation of the loss will be used.
 
         Raises:
             TypeError: When ``other_act`` is not an ``Optional[Callable]``.
@@ -319,6 +329,7 @@ def __init__(
         self.smooth_nr = float(smooth_nr)
         self.smooth_dr = float(smooth_dr)
         self.batch = batch
+        self.soft_label = soft_label
 
     def w_func(self, grnd):
         if self.w_type == str(Weight.SIMPLE):
@@ -370,13 +381,13 @@ def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
         reduce_axis: list[int] = torch.arange(2, len(input.shape)).tolist()
         if self.batch:
             reduce_axis = [0] + reduce_axis
-        intersection = torch.sum(target * input, reduce_axis)
 
-        ground_o = torch.sum(target, reduce_axis)
-        pred_o = torch.sum(input, reduce_axis)
-
-        denominator = ground_o + pred_o
+        tp, fp, fn = compute_tp_fp_fn(input, target, reduce_axis, 1, self.soft_label)
+        fp *= 0.5
+        fn *= 0.5
+        denominator = 2 * (tp + fp + fn)
 
+        ground_o = torch.sum(target, reduce_axis)
         w = self.w_func(ground_o.float())
         infs = torch.isinf(w)
         if self.batch:
@@ -388,7 +399,7 @@ def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
             w = w + infs * max_values
 
         final_reduce_dim = 0 if self.batch else 1
-        numer = 2.0 * (intersection * w).sum(final_reduce_dim, keepdim=True) + self.smooth_nr
+        numer = 2.0 * (tp * w).sum(final_reduce_dim, keepdim=True) + self.smooth_nr
         denom = (denominator * w).sum(final_reduce_dim, keepdim=True) + self.smooth_dr
         f: torch.Tensor = 1.0 - (numer / denom)
 
diff --git a/monai/losses/tversky.py b/monai/losses/tversky.py
index 4f22bf84b4..154f34c526 100644
--- a/monai/losses/tversky.py
+++ b/monai/losses/tversky.py
@@ -17,6 +17,7 @@
 import torch
 from torch.nn.modules.loss import _Loss
 
+from monai.losses.utils import compute_tp_fp_fn
 from monai.networks import one_hot
 from monai.utils import LossReduction
 
@@ -28,6 +29,9 @@ class TverskyLoss(_Loss):
     Sadegh et al. (2017) Tversky loss function for image segmentation using 3D fully convolutional deep networks.
     (https://arxiv.org/abs/1706.05721)
 
+    Wang, Z. et. al. (2023) Dice Semimetric Losses: Optimizing the Dice Score with
+    Soft Labels. MICCAI 2023.
+
     Adapted from:
         https://github.com/NifTK/NiftyNet/blob/v0.6.0/niftynet/layer/loss_segmentation.py#L631
 
@@ -46,6 +50,7 @@ def __init__(
         smooth_nr: float = 1e-5,
         smooth_dr: float = 1e-5,
         batch: bool = False,
+        soft_label: bool = False,
     ) -> None:
         """
         Args:
@@ -70,6 +75,8 @@ def __init__(
             batch: whether to sum the intersection and union areas over the batch dimension before the dividing.
                 Defaults to False, a Dice loss value is computed independently from each item in the batch
                 before any `reduction`.
+            soft_label: whether the target contains non-binary values (soft labels) or not.
+                If True a soft label formulation of the loss will be used.
 
         Raises:
             TypeError: When ``other_act`` is not an ``Optional[Callable]``.
@@ -93,6 +100,7 @@ def __init__(
         self.smooth_nr = float(smooth_nr)
         self.smooth_dr = float(smooth_dr)
         self.batch = batch
+        self.soft_label = soft_label
 
     def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
         """
@@ -134,20 +142,15 @@ def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
         if target.shape != input.shape:
             raise AssertionError(f"ground truth has differing shape ({target.shape}) from input ({input.shape})")
 
-        p0 = input
-        p1 = 1 - p0
-        g0 = target
-        g1 = 1 - g0
-
         # reducing only spatial dimensions (not batch nor channels)
         reduce_axis: list[int] = torch.arange(2, len(input.shape)).tolist()
         if self.batch:
             # reducing spatial dimensions and batch
             reduce_axis = [0] + reduce_axis
 
-        tp = torch.sum(p0 * g0, reduce_axis)
-        fp = self.alpha * torch.sum(p0 * g1, reduce_axis)
-        fn = self.beta * torch.sum(p1 * g0, reduce_axis)
+        tp, fp, fn = compute_tp_fp_fn(input, target, reduce_axis, 1, self.soft_label, False)
+        fp *= self.alpha
+        fn *= self.beta
         numerator = tp + self.smooth_nr
         denominator = tp + fp + fn + self.smooth_dr
 
diff --git a/monai/losses/utils.py b/monai/losses/utils.py
new file mode 100644
index 0000000000..782fd9c9c2
--- /dev/null
+++ b/monai/losses/utils.py
@@ -0,0 +1,68 @@
+# Copyright (c) MONAI Consortium
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#     http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import annotations
+
+import torch
+import torch.linalg as LA
+
+
+def compute_tp_fp_fn(
+    input: torch.Tensor,
+    target: torch.Tensor,
+    reduce_axis: list[int],
+    ord: int,
+    soft_label: bool,
+    decoupled: bool = True,
+) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+    """
+    Args:
+        input: the shape should be BNH[WD], where N is the number of classes.
+        target: the shape should be BNH[WD] or B1H[WD], where N is the number of classes.
+        reduce_axis: the axis to be reduced.
+        ord: the order of the vector norm.
+        soft_label: whether the target contains non-binary values (soft labels) or not.
+            If True a soft label formulation of the loss will be used.
+        decoupled: whether the input and the target should be decoupled when computing fp and fn.
+            Only for the original implementation when soft_label is False.
+
+    Adapted from:
+        https://github.com/zifuwanggg/JDTLosses
+    """
+
+    # the original implementation that is erroneous with soft labels
+    if ord == 1 and not soft_label:
+        tp = torch.sum(input * target, dim=reduce_axis)
+        # the original implementation of Dice and Jaccard loss
+        if decoupled:
+            fp = torch.sum(input, dim=reduce_axis) - tp
+            fn = torch.sum(target, dim=reduce_axis) - tp
+        # the original implementation of Tversky loss
+        else:
+            fp = torch.sum(input * (1 - target), dim=reduce_axis)
+            fn = torch.sum((1 - input) * target, dim=reduce_axis)
+    # the new implementation that is correct with soft labels
+    # and it is identical to the original implementation with hard labels
+    else:
+        pred_o = LA.vector_norm(input, ord=ord, dim=reduce_axis)
+        ground_o = LA.vector_norm(target, ord=ord, dim=reduce_axis)
+        difference = LA.vector_norm(input - target, ord=ord, dim=reduce_axis)
+
+        if ord > 1:
+            pred_o = torch.pow(pred_o, exponent=ord)
+            ground_o = torch.pow(ground_o, exponent=ord)
+            difference = torch.pow(difference, exponent=ord)
+
+        tp = (pred_o + ground_o - difference) / 2
+        fp = pred_o - tp
+        fn = ground_o - tp
+
+    return tp, fp, fn
diff --git a/tests/test_dice_loss.py b/tests/test_dice_loss.py
index 14aa6ec241..cea6ccf113 100644
--- a/tests/test_dice_loss.py
+++ b/tests/test_dice_loss.py
@@ -34,6 +34,22 @@
         },
         0.416657,
     ],
+    [  # shape: (2, 1, 2, 2), (2, 1, 2, 2)
+        {"include_background": True, "smooth_nr": 1e-4, "smooth_dr": 1e-4, "soft_label": True},
+        {
+            "input": torch.tensor([[[[0.3, 0.4], [0.7, 0.9]]], [[[1.0, 0.1], [0.5, 0.3]]]]),
+            "target": torch.tensor([[[[0.3, 0.4], [0.7, 0.9]]], [[[1.0, 0.1], [0.5, 0.3]]]]),
+        },
+        0.0,
+    ],
+    [  # shape: (2, 1, 2, 2), (2, 1, 2, 2)
+        {"include_background": True, "smooth_nr": 1e-4, "smooth_dr": 1e-4, "soft_label": False},
+        {
+            "input": torch.tensor([[[[0.3, 0.4], [0.7, 0.9]]], [[[1.0, 0.1], [0.5, 0.3]]]]),
+            "target": torch.tensor([[[[0.3, 0.4], [0.7, 0.9]]], [[[1.0, 0.1], [0.5, 0.3]]]]),
+        },
+        0.307773,
+    ],
     [  # shape: (2, 2, 3), (2, 1, 3)
         {"include_background": False, "to_onehot_y": True, "smooth_nr": 0, "smooth_dr": 0},
         {
diff --git a/tests/test_generalized_dice_loss.py b/tests/test_generalized_dice_loss.py
index 5738f4a089..9706c2e746 100644
--- a/tests/test_generalized_dice_loss.py
+++ b/tests/test_generalized_dice_loss.py
@@ -34,6 +34,22 @@
         },
         0.416597,
     ],
+    [  # shape: (2, 1, 2, 2), (2, 1, 2, 2)
+        {"include_background": True, "smooth_nr": 1e-4, "smooth_dr": 1e-4, "soft_label": True},
+        {
+            "input": torch.tensor([[[[0.3, 0.4], [0.7, 0.9]]], [[[1.0, 0.1], [0.5, 0.3]]]]),
+            "target": torch.tensor([[[[0.3, 0.4], [0.7, 0.9]]], [[[1.0, 0.1], [0.5, 0.3]]]]),
+        },
+        0.0,
+    ],
+    [  # shape: (2, 1, 2, 2), (2, 1, 2, 2)
+        {"include_background": True, "smooth_nr": 1e-4, "smooth_dr": 1e-4, "soft_label": False},
+        {
+            "input": torch.tensor([[[[0.3, 0.4], [0.7, 0.9]]], [[[1.0, 0.1], [0.5, 0.3]]]]),
+            "target": torch.tensor([[[[0.3, 0.4], [0.7, 0.9]]], [[[1.0, 0.1], [0.5, 0.3]]]]),
+        },
+        0.307748,
+    ],
     [  # shape: (2, 2, 3), (2, 1, 3)
         {"include_background": False, "to_onehot_y": True, "smooth_nr": 0.0, "smooth_dr": 0.0},
         {
diff --git a/tests/test_tversky_loss.py b/tests/test_tversky_loss.py
index 0365503ea2..73a841a55d 100644
--- a/tests/test_tversky_loss.py
+++ b/tests/test_tversky_loss.py
@@ -34,6 +34,22 @@
         },
         0.416657,
     ],
+    [  # shape: (2, 1, 2, 2), (2, 1, 2, 2)
+        {"include_background": True, "smooth_nr": 1e-4, "smooth_dr": 1e-4, "soft_label": True},
+        {
+            "input": torch.tensor([[[[0.3, 0.4], [0.7, 0.9]]], [[[1.0, 0.1], [0.5, 0.3]]]]),
+            "target": torch.tensor([[[[0.3, 0.4], [0.7, 0.9]]], [[[1.0, 0.1], [0.5, 0.3]]]]),
+        },
+        0.0,
+    ],
+    [  # shape: (2, 1, 2, 2), (2, 1, 2, 2)
+        {"include_background": True, "smooth_nr": 1e-4, "smooth_dr": 1e-4, "soft_label": False},
+        {
+            "input": torch.tensor([[[[0.3, 0.4], [0.7, 0.9]]], [[[1.0, 0.1], [0.5, 0.3]]]]),
+            "target": torch.tensor([[[[0.3, 0.4], [0.7, 0.9]]], [[[1.0, 0.1], [0.5, 0.3]]]]),
+        },
+        0.307773,
+    ],
     [  # shape: (2, 2, 3), (2, 1, 3)
         {"include_background": False, "to_onehot_y": True, "smooth_nr": 0, "smooth_dr": 0},
         {