
Segmentation IoU computation: ignore tagged values that should not be counted (such as 255) #2747

Open
woldier opened this issue Sep 16, 2024 · 1 comment
Labels
enhancement New feature or request


woldier commented Sep 16, 2024

🚀 Feature

When computing IoU, a target that contains an ignore label (here 255) cannot be handled:

```python
import torch
from torchmetrics.segmentation import MeanIoU

_ = torch.manual_seed(0)

miou = MeanIoU(num_classes=3, input_format="index")
preds = torch.randint(0, 2, (5,))
target = torch.as_tensor((0, 1, 2, 0, 255))  # a label of 255 marks a pixel to be ignored
miou(preds, target)  # raises an error: 255 is outside the valid class range [0, 3)
```
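The error most likely comes from the one-hot encoding step used for index-format input: `torch.nn.functional.one_hot` rejects any value that is not smaller than `num_classes`, so a single 255 in the target is enough to make it fail. A minimal reproduction of just that step, independent of torchmetrics:

```python
import torch
import torch.nn.functional as F

target = torch.as_tensor((0, 1, 2, 0, 255))  # 255 marks a pixel to be ignored

try:
    F.one_hot(target, num_classes=3)
    raised = False
except RuntimeError as err:
    # one_hot requires every value to lie in [0, num_classes)
    raised = True
    print(f"one_hot failed: {err}")
```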

Motivation

When I generate image/mask pairs (assuming 3 classes), not every pixel in the mask belongs to one of the classes, so I set those pixels to 255. During training these pixels are ignored in the loss via `torch.nn.CrossEntropyLoss(ignore_index=255)`. The IoU metric has no equivalent option, so the 255-labelled pixels make the IoU computation fail or come out wrong. Could `MeanIoU` also accept an `ignore_index` parameter so that such pixels are excluded?
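For reference, the requested behavior can be sketched in plain PyTorch: drop every pixel whose target equals `ignore_index` (as `CrossEntropyLoss(ignore_index=255)` does for the loss), then compute per-class IoU from a confusion matrix. `mean_iou_with_ignore` is a hypothetical helper, not a torchmetrics function, and averaging only over classes that appear in either `preds` or `target` is an assumption that may differ from how `MeanIoU` averages.

```python
import torch
from torch import Tensor


def mean_iou_with_ignore(preds: Tensor, target: Tensor, num_classes: int, ignore_index: int = 255) -> Tensor:
    """Hypothetical helper: mean IoU over classes, skipping pixels labelled ignore_index."""
    keep = target != ignore_index  # drop ignored pixels entirely
    p = preds[keep]
    t = target[keep]
    # confusion matrix: rows = target class, columns = predicted class
    cm = torch.bincount(t * num_classes + p, minlength=num_classes**2).reshape(num_classes, num_classes)
    intersection = cm.diag().float()
    # union = predicted count + target count - intersection, per class
    union = cm.sum(0).float() + cm.sum(1).float() - intersection
    # assumption: classes absent from both preds and target are left out of the mean
    present = union > 0
    return (intersection[present] / union[present]).mean()


preds = torch.tensor([0, 1, 1, 0, 2])
target = torch.tensor([0, 1, 2, 0, 255])  # the last pixel is ignored
print(mean_iou_with_ignore(preds, target, num_classes=3))  # tensor(0.5000)
```

Here the per-class IoUs are 1.0, 0.5 and 0.0; the prediction of class 2 on the ignored pixel does not enter the union of class 2.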

Pitch

```python
import torch
from torchmetrics.segmentation import MeanIoU

_ = torch.manual_seed(0)

miou = MeanIoU(num_classes=3, input_format="index", ignore_index=255)  # proposed: skip pixels labelled 255
preds = torch.randint(0, 2, (5,))
target = torch.as_tensor((0, 1, 2, 0, 255))  # a label of 255 marks a pixel to be ignored
miou(preds, target)
```

Alternatives

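Current code in `_mean_iou_update` (index input is one-hot encoded directly, so any label >= `num_classes` fails):

```python
if input_format == "index":
    preds = torch.nn.functional.one_hot(preds, num_classes=num_classes).movedim(-1, 1)
    target = torch.nn.functional.one_hot(target, num_classes=num_classes).movedim(-1, 1)
```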

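Proposed change: add an `ignore_index` argument and turn the one-hot vectors of ignored pixels into all-zero vectors, in both `preds` and `target`, so they count toward neither the intersection nor the union:

```python
from typing import Literal, Tuple

import torch
from torch import Tensor


def _mean_iou_update(
    preds: Tensor,
    target: Tensor,
    num_classes: int,
    include_background: bool = False,
    input_format: Literal["one-hot", "index"] = "one-hot",
    ignore_index: int = 255,  # proposed: label value to exclude from the metric
) -> Tuple[Tensor, Tensor]:
    ...

    if input_format == "index":
        # boolean mask of pixels carrying the ignore label
        mask = target == ignore_index
        # replace ignored labels with a valid class so one_hot accepts them
        # (masked_fill returns a copy, so the caller's tensor is not modified)
        target = target.masked_fill(mask, 0)
        preds = torch.nn.functional.one_hot(preds, num_classes=num_classes)
        target = torch.nn.functional.one_hot(target, num_classes=num_classes)
        # set the one-hot vectors of ignored pixels to all zeros
        preds[mask] = 0
        target[mask] = 0
        preds = preds.movedim(-1, 1)
        target = target.movedim(-1, 1)
    ...
```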
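Until something like the proposed argument exists, a similar effect can probably be obtained on the caller side: encode the index maps yourself, zero the one-hot vectors of ignored pixels, and pass the result to `MeanIoU(num_classes=3, input_format="one-hot")`. This is a sketch under the assumption that an all-zero vector contributes to no class in the one-hot code path; it has not been checked against every torchmetrics release. `one_hot_with_ignore` is a hypothetical helper name.

```python
import torch
import torch.nn.functional as F
from torch import Tensor


def one_hot_with_ignore(preds: Tensor, target: Tensor, num_classes: int, ignore_index: int = 255):
    """Hypothetical caller-side helper: one-hot encode index maps of shape (N, ...),
    zeroing the vectors of ignored pixels. Returns two tensors of shape (N, C, ...)."""
    mask = target == ignore_index
    safe_target = target.masked_fill(mask, 0)  # valid placeholder so one_hot accepts it
    preds_oh = F.one_hot(preds, num_classes)  # (N, ..., C)
    target_oh = F.one_hot(safe_target, num_classes)  # (N, ..., C)
    preds_oh[mask] = 0  # ignored pixels become all-zero vectors
    target_oh[mask] = 0
    return preds_oh.movedim(-1, 1), target_oh.movedim(-1, 1)


preds = torch.tensor([[[0, 1], [2, 1]]])  # (N=1, H=2, W=2)
target = torch.tensor([[[0, 1], [255, 2]]])  # pixel (1, 0) is ignored
preds_oh, target_oh = one_hot_with_ignore(preds, target, num_classes=3)
print(preds_oh.shape)  # torch.Size([1, 3, 2, 2])
```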

Additional context

woldier added the enhancement (New feature or request) label Sep 16, 2024

Hi! Thanks for your contribution, great first issue!
