Modify Dice, Jaccard and Tversky losses (#8138)
Fixes #8094.

### Description

The Dice, Jaccard and Tversky losses in `monai.losses.dice` and `monai.losses.tversky` are modified based on [JDTLoss](https://github.com/zifuwanggg/JDTLosses/blob/master/losses/jdt_loss.py) and [segmentation_models.pytorch](https://github.com/qubvel-org/segmentation_models.pytorch/blob/main/segmentation_models_pytorch/losses/_functional.py).

In the original versions, when `squared_pred=False`, the loss functions are incompatible with soft labels. For example, with a ground truth value of 0.5 for a single pixel, the Dice loss is minimized when the predicted value is 1, which is clearly erroneous.

To address this, the intersection term is rewritten as $\frac{\|x\|_p^p + \|y\|_p^p - \|x-y\|_p^p}{2}$. When $p$ is 2 (`squared_pred=True`), this reformulation reduces to the classical inner product $\langle x,y \rangle$. When $p$ is 1 (`squared_pred=False`), the reformulation is provably equivalent to the original version whenever the ground truth is binary (i.e. one-hot hard labels): for $x \in [0,1]$ and $y \in \{0,1\}$, the identity $\frac{x + y - |x-y|}{2} = xy$ holds elementwise. Moreover, since the new versions are minimized if and only if the prediction is identical to the ground truth, even when the ground truth contains fractional values, they resolve the issue with soft labels [1, 2].

In summary, there are three scenarios:

* [Scenario 1] $x$ is nonnegative and $y$ is binary: The new versions are identical to the original versions.
* [Scenario 2] Both $x$ and $y$ are nonnegative: The new versions differ from the original versions. The new versions are minimized if and only if $x=y$, while the original versions may not be, making them incorrect.
* [Scenario 3] Either $x$ or $y$ is negative: The new versions differ from the original versions. The new versions are minimized if and only if $x=y$, while the original versions may not be, making them incorrect.

Due to these differences, particularly in Scenarios 2 and 3, some tests fail with the new versions:

* The target is non-binary: `test_multi_scale`
* The input is negative: `test_dice_loss`, `test_tversky_loss`, `test_generalized_dice_loss`, `test_masked_loss`, `test_seg_loss_integration`

The failures in `test_multi_scale` are expected, since the original versions are incorrectly defined for non-binary targets. Furthermore, because the Dice, Jaccard, and Tversky losses are fundamentally defined over probabilities, which should be nonnegative, the new versions should not be tested against negative input or target values.
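The single-pixel claim above can be checked directly. The minimal sketch below (plain PyTorch; the names `old_score` and `new_score` are chosen for illustration and are not part of the MONAI API) compares the two formulations for a scalar prediction $x$ against a soft target $y = 0.5$, with $p = 1$; the Dice *loss* is one minus these scores, so a maximal score corresponds to a minimal loss.

```
import torch

# Single-pixel soft-label check: y = 0.5, x swept over (0, 1].
y = torch.tensor(0.5)
x = torch.linspace(0.01, 1.0, 100)

# Original Dice score: intersection is x * y.
old_score = 2 * x * y / (x + y)
# Reformulated Dice score: intersection is (|x| + |y| - |x - y|) / 2.
new_score = (x + y - (x - y).abs()) / (x + y)

print(x[old_score.argmax()].item())  # ~1.0: the original score rewards over-confidence
print(x[new_score.argmax()].item())  # ~0.5: the new score peaks exactly at x == y
```

The full example below extends this comparison to batched tensors with one-hot and soft labels.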
### Example

```
import torch
import torch.linalg as LA
import torch.nn.functional as F

torch.manual_seed(0)

b, c, h, w = 4, 3, 32, 32
dims = (0, 2, 3)

pred = torch.rand(b, c, h, w).softmax(dim=1)
soft_label = torch.rand(b, c, h, w).softmax(dim=1)
hard_label = torch.randint(low=0, high=c, size=(b, h, w))
one_hot_label = F.one_hot(hard_label, c).permute(0, 3, 1, 2).float()


def dice_old(x, y, ord, dims):
    # Original formulation: the intersection is the inner product <x, y>.
    cardinality = LA.vector_norm(x, ord=ord, dim=dims) ** ord + LA.vector_norm(y, ord=ord, dim=dims) ** ord
    intersection = torch.sum(x * y, dim=dims)
    return 2 * intersection / cardinality


def dice_new(x, y, ord, dims):
    # Reformulation: the intersection is (||x||_p^p + ||y||_p^p - ||x - y||_p^p) / 2.
    cardinality = LA.vector_norm(x, ord=ord, dim=dims) ** ord + LA.vector_norm(y, ord=ord, dim=dims) ** ord
    difference = LA.vector_norm(x - y, ord=ord, dim=dims) ** ord
    intersection = (cardinality - difference) / 2
    return 2 * intersection / cardinality


print(dice_old(pred, one_hot_label, 1, dims), dice_new(pred, one_hot_label, 1, dims))
print(dice_old(pred, soft_label, 1, dims), dice_new(pred, soft_label, 1, dims))
print(dice_old(pred, pred, 1, dims), dice_new(pred, pred, 1, dims))
print(dice_old(pred, one_hot_label, 2, dims), dice_new(pred, one_hot_label, 2, dims))
print(dice_old(pred, soft_label, 2, dims), dice_new(pred, soft_label, 2, dims))
print(dice_old(pred, pred, 2, dims), dice_new(pred, pred, 2, dims))

# tensor([0.3345, 0.3310, 0.3317]) tensor([0.3345, 0.3310, 0.3317])
# tensor([0.3321, 0.3333, 0.3350]) tensor([0.8680, 0.8690, 0.8700])
# tensor([0.3487, 0.3502, 0.3544]) tensor([1., 1., 1.])
# tensor([0.4921, 0.4904, 0.4935]) tensor([0.4921, 0.4904, 0.4935])
# tensor([0.9489, 0.9499, 0.9503]) tensor([0.9489, 0.9499, 0.9503])
# tensor([1., 1., 1.]) tensor([1., 1., 1.])
```

### References

[1] Dice Semimetric Losses: Optimizing the Dice Score with Soft Labels. Zifu Wang, Teodora Popordanoska, Jeroen Bertels, Robin Lemmens, Matthew B. Blaschko. *MICCAI 2023*.

[2] Jaccard Metric Losses: Optimizing the Jaccard Index with Soft Labels. Zifu Wang, Xuefei Ning, Matthew B. Blaschko. *NeurIPS 2023*.

### Types of changes

- [x] Non-breaking change (fix or new feature that would not break existing functionality).
- [ ] Breaking change (fix or new feature that would cause existing functionality to change).
- [ ] New tests added to cover the changes.
- [ ] Integration tests passed locally by running `./runtests.sh -f -u --net --coverage`.
- [ ] Quick tests passed locally by running `./runtests.sh --quick --unittests --disttests`.
- [x] In-line docstrings updated.
- [x] Documentation updated, tested `make html` command in the `docs/` folder.

---------

Signed-off-by: Zifu Wang <[email protected]>
Co-authored-by: YunLiu <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>