
Make ignore_index work when all batch elements are to be ignored #2685

Open
fteufel opened this issue Aug 12, 2024 · 1 comment
Labels
enhancement New feature or request


fteufel commented Aug 12, 2024

🚀 Feature

The ignore_index argument in e.g. the AUROC metric allows one to specify a label value that is excluded from the metric computation. This works as expected when only some batch elements are to be ignored, but when the metric is called with a target tensor in which every entry equals the ignore_index, it raises an IndexError, presumably because the accumulated state ends up empty.
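A minimal sketch reproducing the error, assuming BinaryAUROC and the common -100 ignore value (the tensor shapes are illustrative):

import torch
from torchmetrics.classification import BinaryAUROC

auroc = BinaryAUROC(ignore_index=-100)
preds = torch.rand(8)
target = torch.full((8,), -100)  # every element carries the ignore label
auroc(preds, target)  # raises IndexError, see the traceback below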

    self.aucs[f"val_label_{i}"](label_logits[:, i].squeeze(-1), labels_target[:, i])
  File "lib/python3.10/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "lib/python3.10/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
    return forward_call(*args, **kwargs)
  File "lib/python3.10/site-packages/torchmetrics/metric.py", line 312, in forward
    self._forward_cache = self._forward_reduce_state_update(*args, **kwargs)
  File "lib/python3.10/site-packages/torchmetrics/metric.py", line 382, in _forward_reduce_state_update
    batch_val = self.compute()
  File "/lib/python3.10/site-packages/torchmetrics/metric.py", line 633, in wrapped_func
    value = _squeeze_if_scalar(compute(*args, **kwargs))
  File "lib/python3.10/site-packages/torchmetrics/classification/auroc.py", line 124, in compute
    return _binary_auroc_compute(state, self.thresholds, self.max_fpr)
  File "lib/python3.10/site-packages/torchmetrics/functional/classification/auroc.py", line 89, in _binary_auroc_compute
    fpr, tpr, _ = _binary_roc_compute(state, thresholds, pos_label)
  File "lib/python3.10/site-packages/torchmetrics/functional/classification/roc.py", line 54, in _binary_roc_compute
    fps, tps, thres = _binary_clf_curve(preds=state[0], target=state[1], pos_label=pos_label)
  File "lib/python3.10/site-packages/torchmetrics/functional/classification/precision_recall_curve.py", line 72, in _binary_clf_curve
    tps = _cumsum(target * weight, dim=0)[threshold_idxs]

Motivation

Having batches without any valid labels may sound counterintuitive at first, but in multitask problems it happens quite easily: when a metric tracks only a given subtask and batches are sampled randomly, a batch may contain no examples for that subtask at all.

Pitch

It would be helpful if this just worked, ideally emitting a warning and returning 0 or NaN instead of raising.
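As a sketch of the desired behavior, here is a hypothetical subclass (the name SafeBinaryAUROC is made up) that returns NaN with a warning. It assumes the list-based preds state that BinaryAUROC keeps when thresholds=None, which is also what the workaround below relies on:

import warnings

import torch
from torchmetrics.classification import BinaryAUROC

class SafeBinaryAUROC(BinaryAUROC):
    def compute(self) -> torch.Tensor:
        # With thresholds=None the state is a list of per-update tensors;
        # if every update so far was fully ignored, all of them are empty.
        if all(len(p) == 0 for p in self.preds):
            warnings.warn("All seen elements matched ignore_index; returning NaN.")
            return torch.tensor(float("nan"))
        return super().compute()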

Alternatives

Right now, this needs to be handled manually, e.g. by skipping the update:

if not (target == -100).all():
    self.auc(logits, target)

or, when calling compute() after some update steps:

if not all(len(x) == 0 for x in self.auc.metric_state['preds']):
    self.auc.compute()
fteufel added the enhancement label Aug 12, 2024

Hi! Thanks for your contribution, great first issue!
