Skip to content

Commit

Permalink
Add specificity metric for binary features (#3025)
Browse files Browse the repository at this point in the history
  • Loading branch information
jppgks authored Jan 30, 2023
1 parent 6d3a8c3 commit 35fde29
Show file tree
Hide file tree
Showing 3 changed files with 28 additions and 1 deletion.
1 change: 1 addition & 0 deletions ludwig/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,7 @@
JACCARD = "jaccard"
PRECISION = "precision"
RECALL = "recall"
SPECIFICITY = "specificity"
PREDICTIONS = "predictions"
TOP_K = "top_k"
TOP_K_PREDICTIONS = "top_k_predictions"
Expand Down
19 changes: 18 additions & 1 deletion ludwig/modules/metric_modules.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
from torchmetrics import Accuracy as _Accuracy
from torchmetrics import AUROC, MeanAbsoluteError
from torchmetrics import MeanMetric as _MeanMetric
from torchmetrics import MeanSquaredError, Metric, Precision, Recall
from torchmetrics import MeanSquaredError, Metric, Precision, Recall, Specificity
from torchmetrics.functional.regression.r2 import _r2_score_compute, _r2_score_update
from torchmetrics.metric import jit_distributed_available

Expand Down Expand Up @@ -51,6 +51,7 @@
ROOT_MEAN_SQUARED_PERCENTAGE_ERROR,
SEQUENCE,
SET,
SPECIFICITY,
TEXT,
TOKEN_ACCURACY,
VECTOR,
Expand Down Expand Up @@ -184,6 +185,22 @@ def get_inputs(cls):
return PROBABILITIES


@register_metric(SPECIFICITY, [BINARY])
class SpecificityMetric(Specificity, LudwigMetric):
    """Specificity (true negative rate) metric for binary output features.

    Thin wrapper around ``torchmetrics.Specificity`` that wires in Ludwig's
    distributed tensor-gathering hook and declares how the trainer should
    feed and interpret the metric.
    """

    def __init__(self, **kwargs):
        # Feature-config kwargs are intentionally ignored; only the
        # distributed sync function is forwarded to torchmetrics.
        super().__init__(dist_sync_fn=_gather_all_tensors_fn())

    @classmethod
    def get_inputs(cls):
        # The metric consumes predicted probabilities (thresholded internally
        # by torchmetrics), not hard predictions.
        return PROBABILITIES

    @classmethod
    def get_objective(cls):
        # Higher specificity is better, so hyperopt/early-stopping maximize it.
        return MAXIMIZE


class MeanMetric(LudwigMetric):
"""Abstract class for computing mean of metrics."""

Expand Down
9 changes: 9 additions & 0 deletions tests/ludwig/modules/test_metric_modules.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,15 @@ def test_roc_auc_metric(preds: torch.Tensor, target: torch.Tensor, output: torch
assert output == metric.compute()


@pytest.mark.parametrize("preds", [torch.tensor([0.2, 0.3, 0.8, 0.1, 0.8])])
@pytest.mark.parametrize("target", [torch.tensor([0, 0, 1, 1, 0])])
@pytest.mark.parametrize("output", [torch.tensor(0.6667).float()])
def test_specificity_metric(preds: torch.Tensor, target: torch.Tensor, output: torch.Tensor):
    # Specificity = TN / (TN + FP). With the default 0.5 threshold the
    # probabilities become [0, 0, 1, 0, 1]; over the three negatives
    # (targets 0 at indices 0, 1, 4) that yields TN=2, FP=1 -> 2/3.
    metric = metric_modules.SpecificityMetric()
    metric.update(preds, target)
    result = metric.compute()
    assert torch.isclose(output, result, rtol=0.0001)


@pytest.mark.parametrize("preds", [torch.arange(6).reshape(3, 2).float()])
@pytest.mark.parametrize("target", [torch.arange(6, 12).reshape(3, 2).float()])
@pytest.mark.parametrize("output", [torch.tensor(0.7527).float()])
Expand Down

0 comments on commit 35fde29

Please sign in to comment.