Modify Dice, Jaccard and Tversky losses (#8138)
Fixes #8094.

### Description

The Dice, Jaccard and Tversky losses in `monai.losses.dice` and
`monai.losses.tversky` are modified based on
[JDTLoss](https://github.com/zifuwanggg/JDTLosses/blob/master/losses/jdt_loss.py)
and
[segmentation_models.pytorch](https://github.com/qubvel-org/segmentation_models.pytorch/blob/main/segmentation_models_pytorch/losses/_functional.py).

In the original versions, when `squared_pred=False`, the loss functions
are incompatible with soft labels. For example, with a ground truth
value of 0.5 for a single pixel, the Dice loss is minimized when the
predicted value is 1, which is clearly erroneous. To address this, the
intersection term is rewritten as $\frac{\|x\|_p^p + \|y\|_p^p -
\|x-y\|_p^p}{2}$. When $p$ is 2 (`squared_pred=True`), this
reformulation becomes the classical inner product: $\langle x,y
\rangle$. When $p$ is 1 (`squared_pred=False`), the reformulation has
been proven equivalent to the original versions whenever the ground
truth is binary (i.e. one-hot hard labels). Moreover, since the new
versions are minimized if and only if the prediction is identical to
the ground truth, even when the ground truth includes fractional
values, they resolve the issue with soft labels [1, 2].
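
A minimal single-pixel sketch of this behaviour, using the formulas above
directly rather than the MONAI classes:

```
import torch

y = 0.5                        # soft ground truth for a single pixel
x = torch.linspace(0, 1, 101)  # candidate predictions

# Original soft Dice score 2xy / (x + y): increasing in x, so the loss
# is minimized at x = 1 even though the ground truth is 0.5.
dice_old = 2 * x * y / (x + y)

# Reformulated intersection (|x| + |y| - |x - y|) / 2 = min(x, y), so the
# score peaks (and the loss is minimized) exactly at x = y = 0.5.
dice_new = (x + y - (x - y).abs()) / (x + y)

print(x[dice_old.argmax()])  # tensor(1.)
print(x[dice_new.argmax()])  # tensor(0.5000)
```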

In summary, there are three scenarios:
* [Scenario 1] $x$ is nonnegative and $y$ is binary: The new versions
are the same as the original versions.
* [Scenario 2] Both $x$ and $y$ are nonnegative: The new versions differ
from the original versions. The new versions are minimized if and only
if $x=y$, whereas the original versions may not be minimized there,
making them incorrect.
* [Scenario 3] Either $x$ or $y$ is negative: The new versions differ
from the original versions. The new versions are minimized if and only
if $x=y$, whereas the original versions may not be minimized there,
making them incorrect.

Due to these differences, particularly in Scenarios 2 and 3, some tests
fail with the new versions:
* The target is non-binary: `test_multi_scale`
* The input is negative: `test_dice_loss`, `test_tversky_loss`,
`test_generalized_dice_loss`, `test_masked_loss`,
`test_seg_loss_integration`

The failures in `test_multi_scale` are expected, since the original
versions are incorrectly defined for non-binary targets. Furthermore,
because the Dice, Jaccard, and Tversky losses are fundamentally defined
over probabilities, which are nonnegative, the new versions should not
be tested against negative input or target values.

### Example
```
import torch
import torch.linalg as LA
import torch.nn.functional as F

torch.manual_seed(0)

b, c, h, w = 4, 3, 32, 32
dims = (0, 2, 3)

pred = torch.rand(b, c, h, w).softmax(dim=1)
soft_label = torch.rand(b, c, h, w).softmax(dim=1)
hard_label = torch.randint(low=0, high=c, size=(b, h, w))
one_hot_label = F.one_hot(hard_label, c).permute(0, 3, 1, 2).float()

def dice_old(x, y, ord, dims):
    cardinality = LA.vector_norm(x, ord=ord, dim=dims) ** ord + LA.vector_norm(y, ord=ord, dim=dims) ** ord
    intersection = torch.sum(x * y, dim=dims)
    return 2 * intersection / cardinality

def dice_new(x, y, ord, dims):
    cardinality = LA.vector_norm(x, ord=ord, dim=dims) ** ord + LA.vector_norm(y, ord=ord, dim=dims) ** ord
    difference = LA.vector_norm(x - y, ord=ord, dim=dims) ** ord
    intersection = (cardinality - difference) / 2
    return 2 * intersection / cardinality

print(dice_old(pred, one_hot_label, 1, dims), dice_new(pred, one_hot_label, 1, dims))
print(dice_old(pred, soft_label, 1, dims), dice_new(pred, soft_label, 1, dims))
print(dice_old(pred, pred, 1, dims), dice_new(pred, pred, 1, dims))

print(dice_old(pred, one_hot_label, 2, dims), dice_new(pred, one_hot_label, 2, dims))
print(dice_old(pred, soft_label, 2, dims), dice_new(pred, soft_label, 2, dims))
print(dice_old(pred, pred, 2, dims), dice_new(pred, pred, 2, dims))

# tensor([0.3345, 0.3310, 0.3317]) tensor([0.3345, 0.3310, 0.3317])
# tensor([0.3321, 0.3333, 0.3350]) tensor([0.8680, 0.8690, 0.8700])
# tensor([0.3487, 0.3502, 0.3544]) tensor([1., 1., 1.])

# tensor([0.4921, 0.4904, 0.4935]) tensor([0.4921, 0.4904, 0.4935])
# tensor([0.9489, 0.9499, 0.9503]) tensor([0.9489, 0.9499, 0.9503])
# tensor([1., 1., 1.]) tensor([1., 1., 1.])
```
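
Beyond these standalone formulas, the new behaviour is exposed through the
`soft_label` flag added to the loss classes. A minimal usage sketch, with
expected values taken from the test cases added below (assumes a MONAI build
that includes this change):

```
import torch
from monai.losses import DiceLoss

pred = torch.tensor([[[[0.3, 0.4], [0.7, 0.9]]], [[[1.0, 0.1], [0.5, 0.3]]]])
target = pred.clone()  # non-binary target identical to the prediction

# With soft_label=True, the loss attains its minimum (0) at input == target.
print(DiceLoss(smooth_nr=1e-4, smooth_dr=1e-4, soft_label=True)(pred, target))
# tensor(0.)

# The original formulation (soft_label=False) does not.
print(DiceLoss(smooth_nr=1e-4, smooth_dr=1e-4, soft_label=False)(pred, target))
# tensor(0.3078)
```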

### References
[1] Dice Semimetric Losses: Optimizing the Dice Score with Soft Labels.
Zifu Wang, Teodora Popordanoska, Jeroen Bertels, Robin Lemmens, Matthew
B. Blaschko. *MICCAI 2023*.

[2] Jaccard Metric Losses: Optimizing the Jaccard Index with Soft
Labels. Zifu Wang, Xuefei Ning, Matthew B. Blaschko. *NeurIPS 2023*.

### Types of changes
- [x] Non-breaking change (fix or new feature that would not break
existing functionality).
- [ ] Breaking change (fix or new feature that would cause existing
functionality to change).
- [ ] New tests added to cover the changes.
- [ ] Integration tests passed locally by running `./runtests.sh -f -u
--net --coverage`.
- [ ] Quick tests passed locally by running `./runtests.sh --quick
--unittests --disttests`.
- [x] In-line docstrings updated.
- [x] Documentation updated, tested `make html` command in the `docs/`
folder.

---------

Signed-off-by: Zifu Wang <[email protected]>
Co-authored-by: YunLiu <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
3 people authored Dec 3, 2024
1 parent 7f88a46 commit 9808ce2
Showing 6 changed files with 160 additions and 30 deletions.
55 changes: 33 additions & 22 deletions monai/losses/dice.py
@@ -23,6 +23,7 @@

from monai.losses.focal_loss import FocalLoss
from monai.losses.spatial_mask import MaskedLoss
from monai.losses.utils import compute_tp_fp_fn
from monai.networks import one_hot
from monai.utils import DiceCEReduction, LossReduction, Weight, look_up_option, pytorch_after

@@ -39,8 +40,16 @@ class DiceLoss(_Loss):
The `smooth_nr` and `smooth_dr` parameters are values added to the intersection and union components of
the inter-over-union calculation to smooth results respectively, these values should be small.
The original paper: Milletari, F. et. al. (2016) V-Net: Fully Convolutional Neural Networks forVolumetric
Medical Image Segmentation, 3DV, 2016.
The original papers:
Milletari, F. et. al. (2016) V-Net: Fully Convolutional Neural Networks for Volumetric
Medical Image Segmentation. 3DV 2016.
Wang, Z. et. al. (2023) Jaccard Metric Losses: Optimizing the Jaccard Index with
Soft Labels. NeurIPS 2023.
Wang, Z. et. al. (2023) Dice Semimetric Losses: Optimizing the Dice Score with
Soft Labels. MICCAI 2023.
"""

@@ -58,6 +67,7 @@ def __init__(
smooth_dr: float = 1e-5,
batch: bool = False,
weight: Sequence[float] | float | int | torch.Tensor | None = None,
soft_label: bool = False,
) -> None:
"""
Args:
@@ -89,6 +99,8 @@ def __init__(
of the sequence should be the same as the number of classes. If not ``include_background``,
the number of classes should not include the background category class 0).
The value/values should be no less than 0. Defaults to None.
soft_label: whether the target contains non-binary values (soft labels) or not.
If True a soft label formulation of the loss will be used.
Raises:
TypeError: When ``other_act`` is not an ``Optional[Callable]``.
@@ -114,6 +126,7 @@ def __init__(
weight = torch.as_tensor(weight) if weight is not None else None
self.register_buffer("class_weight", weight)
self.class_weight: None | torch.Tensor
self.soft_label = soft_label

def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
"""
@@ -174,21 +187,15 @@ def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
# reducing spatial dimensions and batch
reduce_axis = [0] + reduce_axis

intersection = torch.sum(target * input, dim=reduce_axis)

if self.squared_pred:
ground_o = torch.sum(target**2, dim=reduce_axis)
pred_o = torch.sum(input**2, dim=reduce_axis)
else:
ground_o = torch.sum(target, dim=reduce_axis)
pred_o = torch.sum(input, dim=reduce_axis)

denominator = ground_o + pred_o

if self.jaccard:
denominator = 2.0 * (denominator - intersection)
ord = 2 if self.squared_pred else 1
tp, fp, fn = compute_tp_fp_fn(input, target, reduce_axis, ord, self.soft_label)
if not self.jaccard:
fp *= 0.5
fn *= 0.5
numerator = 2 * tp + self.smooth_nr
denominator = 2 * (tp + fp + fn) + self.smooth_dr

f: torch.Tensor = 1.0 - (2.0 * intersection + self.smooth_nr) / (denominator + self.smooth_dr)
f: torch.Tensor = 1 - numerator / denominator

num_of_classes = target.shape[1]
if self.class_weight is not None and num_of_classes != 1:
@@ -272,6 +279,7 @@ def __init__(
smooth_nr: float = 1e-5,
smooth_dr: float = 1e-5,
batch: bool = False,
soft_label: bool = False,
) -> None:
"""
Args:
@@ -295,6 +303,8 @@
batch: whether to sum the intersection and union areas over the batch dimension before the dividing.
Defaults to False, intersection over union is computed from each item in the batch.
If True, the class-weighted intersection and union areas are first summed across the batches.
soft_label: whether the target contains non-binary values (soft labels) or not.
If True a soft label formulation of the loss will be used.
Raises:
TypeError: When ``other_act`` is not an ``Optional[Callable]``.
@@ -319,6 +329,7 @@
self.smooth_nr = float(smooth_nr)
self.smooth_dr = float(smooth_dr)
self.batch = batch
self.soft_label = soft_label

def w_func(self, grnd):
if self.w_type == str(Weight.SIMPLE):
@@ -370,13 +381,13 @@ def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
reduce_axis: list[int] = torch.arange(2, len(input.shape)).tolist()
if self.batch:
reduce_axis = [0] + reduce_axis
intersection = torch.sum(target * input, reduce_axis)

ground_o = torch.sum(target, reduce_axis)
pred_o = torch.sum(input, reduce_axis)

denominator = ground_o + pred_o
tp, fp, fn = compute_tp_fp_fn(input, target, reduce_axis, 1, self.soft_label)
fp *= 0.5
fn *= 0.5
denominator = 2 * (tp + fp + fn)

ground_o = torch.sum(target, reduce_axis)
w = self.w_func(ground_o.float())
infs = torch.isinf(w)
if self.batch:
Expand All @@ -388,7 +399,7 @@ def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
w = w + infs * max_values

final_reduce_dim = 0 if self.batch else 1
numer = 2.0 * (intersection * w).sum(final_reduce_dim, keepdim=True) + self.smooth_nr
numer = 2.0 * (tp * w).sum(final_reduce_dim, keepdim=True) + self.smooth_nr
denom = (denominator * w).sum(final_reduce_dim, keepdim=True) + self.smooth_dr
f: torch.Tensor = 1.0 - (numer / denom)

19 changes: 11 additions & 8 deletions monai/losses/tversky.py
@@ -17,6 +17,7 @@
import torch
from torch.nn.modules.loss import _Loss

from monai.losses.utils import compute_tp_fp_fn
from monai.networks import one_hot
from monai.utils import LossReduction

@@ -28,6 +29,9 @@ class TverskyLoss(_Loss):
Sadegh et al. (2017) Tversky loss function for image segmentation
using 3D fully convolutional deep networks. (https://arxiv.org/abs/1706.05721)
Wang, Z. et. al. (2023) Dice Semimetric Losses: Optimizing the Dice Score with
Soft Labels. MICCAI 2023.
Adapted from:
https://github.com/NifTK/NiftyNet/blob/v0.6.0/niftynet/layer/loss_segmentation.py#L631
@@ -46,6 +50,7 @@ def __init__(
smooth_nr: float = 1e-5,
smooth_dr: float = 1e-5,
batch: bool = False,
soft_label: bool = False,
) -> None:
"""
Args:
@@ -70,6 +75,8 @@
batch: whether to sum the intersection and union areas over the batch dimension before the dividing.
Defaults to False, a Dice loss value is computed independently from each item in the batch
before any `reduction`.
soft_label: whether the target contains non-binary values (soft labels) or not.
If True a soft label formulation of the loss will be used.
Raises:
TypeError: When ``other_act`` is not an ``Optional[Callable]``.
@@ -93,6 +100,7 @@
self.smooth_nr = float(smooth_nr)
self.smooth_dr = float(smooth_dr)
self.batch = batch
self.soft_label = soft_label

def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
"""
@@ -134,20 +142,15 @@ def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
if target.shape != input.shape:
raise AssertionError(f"ground truth has differing shape ({target.shape}) from input ({input.shape})")

p0 = input
p1 = 1 - p0
g0 = target
g1 = 1 - g0

# reducing only spatial dimensions (not batch nor channels)
reduce_axis: list[int] = torch.arange(2, len(input.shape)).tolist()
if self.batch:
# reducing spatial dimensions and batch
reduce_axis = [0] + reduce_axis

tp = torch.sum(p0 * g0, reduce_axis)
fp = self.alpha * torch.sum(p0 * g1, reduce_axis)
fn = self.beta * torch.sum(p1 * g0, reduce_axis)
tp, fp, fn = compute_tp_fp_fn(input, target, reduce_axis, 1, self.soft_label, False)
fp *= self.alpha
fn *= self.beta
numerator = tp + self.smooth_nr
denominator = tp + fp + fn + self.smooth_dr

68 changes: 68 additions & 0 deletions monai/losses/utils.py
@@ -0,0 +1,68 @@
# Copyright (c) MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations

import torch
import torch.linalg as LA


def compute_tp_fp_fn(
input: torch.Tensor,
target: torch.Tensor,
reduce_axis: list[int],
ord: int,
soft_label: bool,
decoupled: bool = True,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""
Args:
input: the shape should be BNH[WD], where N is the number of classes.
target: the shape should be BNH[WD] or B1H[WD], where N is the number of classes.
reduce_axis: the axis to be reduced.
ord: the order of the vector norm.
soft_label: whether the target contains non-binary values (soft labels) or not.
If True a soft label formulation of the loss will be used.
decoupled: whether the input and the target should be decoupled when computing fp and fn.
Only for the original implementation when soft_label is False.
Adapted from:
https://github.com/zifuwanggg/JDTLosses
"""

# the original implementation that is erroneous with soft labels
if ord == 1 and not soft_label:
tp = torch.sum(input * target, dim=reduce_axis)
# the original implementation of Dice and Jaccard loss
if decoupled:
fp = torch.sum(input, dim=reduce_axis) - tp
fn = torch.sum(target, dim=reduce_axis) - tp
# the original implementation of Tversky loss
else:
fp = torch.sum(input * (1 - target), dim=reduce_axis)
fn = torch.sum((1 - input) * target, dim=reduce_axis)
# the new implementation that is correct with soft labels
# and it is identical to the original implementation with hard labels
else:
pred_o = LA.vector_norm(input, ord=ord, dim=reduce_axis)
ground_o = LA.vector_norm(target, ord=ord, dim=reduce_axis)
difference = LA.vector_norm(input - target, ord=ord, dim=reduce_axis)

if ord > 1:
pred_o = torch.pow(pred_o, exponent=ord)
ground_o = torch.pow(ground_o, exponent=ord)
difference = torch.pow(difference, exponent=ord)

tp = (pred_o + ground_o - difference) / 2
fp = pred_o - tp
fn = ground_o - tp

return tp, fp, fn
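
A small sanity check of the Scenario 1 claim, not part of this commit: with a
one-hot target and `ord=1`, the original path (`soft_label=False`) and the
reformulated path (`soft_label=True`) of `compute_tp_fp_fn` should coincide:

```
import torch
import torch.nn.functional as F
from monai.losses.utils import compute_tp_fp_fn

torch.manual_seed(0)
b, c, h, w = 2, 3, 4, 4
pred = torch.rand(b, c, h, w).softmax(dim=1)
label = F.one_hot(torch.randint(0, c, (b, h, w)), c).permute(0, 3, 1, 2).float()

old = compute_tp_fp_fn(pred, label, reduce_axis=[0, 2, 3], ord=1, soft_label=False)
new = compute_tp_fp_fn(pred, label, reduce_axis=[0, 2, 3], ord=1, soft_label=True)

for o, n in zip(old, new):
    print(torch.allclose(o, n, atol=1e-6))  # True for each of tp, fp, fn
```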
16 changes: 16 additions & 0 deletions tests/test_dice_loss.py
@@ -34,6 +34,22 @@
},
0.416657,
],
[ # shape: (2, 1, 2, 2), (2, 1, 2, 2)
{"include_background": True, "smooth_nr": 1e-4, "smooth_dr": 1e-4, "soft_label": True},
{
"input": torch.tensor([[[[0.3, 0.4], [0.7, 0.9]]], [[[1.0, 0.1], [0.5, 0.3]]]]),
"target": torch.tensor([[[[0.3, 0.4], [0.7, 0.9]]], [[[1.0, 0.1], [0.5, 0.3]]]]),
},
0.0,
],
[ # shape: (2, 1, 2, 2), (2, 1, 2, 2)
{"include_background": True, "smooth_nr": 1e-4, "smooth_dr": 1e-4, "soft_label": False},
{
"input": torch.tensor([[[[0.3, 0.4], [0.7, 0.9]]], [[[1.0, 0.1], [0.5, 0.3]]]]),
"target": torch.tensor([[[[0.3, 0.4], [0.7, 0.9]]], [[[1.0, 0.1], [0.5, 0.3]]]]),
},
0.307773,
],
[ # shape: (2, 2, 3), (2, 1, 3)
{"include_background": False, "to_onehot_y": True, "smooth_nr": 0, "smooth_dr": 0},
{
16 changes: 16 additions & 0 deletions tests/test_generalized_dice_loss.py
@@ -34,6 +34,22 @@
},
0.416597,
],
[ # shape: (2, 1, 2, 2), (2, 1, 2, 2)
{"include_background": True, "smooth_nr": 1e-4, "smooth_dr": 1e-4, "soft_label": True},
{
"input": torch.tensor([[[[0.3, 0.4], [0.7, 0.9]]], [[[1.0, 0.1], [0.5, 0.3]]]]),
"target": torch.tensor([[[[0.3, 0.4], [0.7, 0.9]]], [[[1.0, 0.1], [0.5, 0.3]]]]),
},
0.0,
],
[ # shape: (2, 1, 2, 2), (2, 1, 2, 2)
{"include_background": True, "smooth_nr": 1e-4, "smooth_dr": 1e-4, "soft_label": False},
{
"input": torch.tensor([[[[0.3, 0.4], [0.7, 0.9]]], [[[1.0, 0.1], [0.5, 0.3]]]]),
"target": torch.tensor([[[[0.3, 0.4], [0.7, 0.9]]], [[[1.0, 0.1], [0.5, 0.3]]]]),
},
0.307748,
],
[ # shape: (2, 2, 3), (2, 1, 3)
{"include_background": False, "to_onehot_y": True, "smooth_nr": 0.0, "smooth_dr": 0.0},
{
16 changes: 16 additions & 0 deletions tests/test_tversky_loss.py
@@ -34,6 +34,22 @@
},
0.416657,
],
[ # shape: (2, 1, 2, 2), (2, 1, 2, 2)
{"include_background": True, "smooth_nr": 1e-4, "smooth_dr": 1e-4, "soft_label": True},
{
"input": torch.tensor([[[[0.3, 0.4], [0.7, 0.9]]], [[[1.0, 0.1], [0.5, 0.3]]]]),
"target": torch.tensor([[[[0.3, 0.4], [0.7, 0.9]]], [[[1.0, 0.1], [0.5, 0.3]]]]),
},
0.0,
],
[ # shape: (2, 1, 2, 2), (2, 1, 2, 2)
{"include_background": True, "smooth_nr": 1e-4, "smooth_dr": 1e-4, "soft_label": False},
{
"input": torch.tensor([[[[0.3, 0.4], [0.7, 0.9]]], [[[1.0, 0.1], [0.5, 0.3]]]]),
"target": torch.tensor([[[[0.3, 0.4], [0.7, 0.9]]], [[[1.0, 0.1], [0.5, 0.3]]]]),
},
0.307773,
],
[ # shape: (2, 2, 3), (2, 1, 3)
{"include_background": False, "to_onehot_y": True, "smooth_nr": 0, "smooth_dr": 0},
{
