Skip to content

Commit

Permalink
ADD: ability to pass custom comparison fn
Browse files Browse the repository at this point in the history
  • Loading branch information
madisonmay committed Nov 1, 2024
1 parent 0aeb491 commit 70ae6f7
Show file tree
Hide file tree
Showing 2 changed files with 62 additions and 10 deletions.
30 changes: 20 additions & 10 deletions sequence_metrics/metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import re
from collections import OrderedDict, defaultdict
from functools import partial
from typing import Callable

import numpy as np
import spacy
Expand Down Expand Up @@ -345,16 +346,25 @@ def sequence_labeling_counts(true, predicted, equality_fn):
}


# TODO: rewrite this to use the map above
def get_seq_count_fn(span_type="token"):
    """Look up the counting function registered under ``span_type``.

    ``"token"`` maps to ``sequence_labeling_token_counts``; every other name
    maps to ``sequence_labeling_counts`` with its ``equality_fn`` pre-bound.
    An unregistered name raises ``KeyError``.
    """
    equality_fns = {
        "overlap": sequences_overlap,
        "exact": sequence_exact_match,
        "superset": sequence_superset,
        "value": fuzzy_compare,
    }
    count_fns = {"token": sequence_labeling_token_counts}
    for name, equality_fn in equality_fns.items():
        count_fns[name] = partial(sequence_labeling_counts, equality_fn=equality_fn)
    return count_fns[span_type]
# Registry of built-in span-matching strategies, keyed by name. "token" has a
# dedicated counting function; the remaining strategies all share
# sequence_labeling_counts and differ only in the equality function bound to it.
SPAN_TYPE_FN_MAPPING = {
    "token": sequence_labeling_token_counts,
    **{
        name: partial(sequence_labeling_counts, equality_fn=equality_fn)
        for name, equality_fn in (
            ("overlap", sequences_overlap),
            ("exact", sequence_exact_match),
            ("superset", sequence_superset),
            ("value", fuzzy_compare),
        )
    },
}


def get_seq_count_fn(span_type: str | Callable = "token"):
if isinstance(span_type, str):
return SPAN_TYPE_FN_MAPPING[span_type]
elif callable(span_type):
# Interpret span_type as an equality function
return partial(sequence_labeling_counts, equality_fn=span_type)

raise ValueError(
f"Invalid span_type: {span_type}. Must either be a string or a callable."
)


def sequence_labeling_overlap_precision(true, predicted):
Expand Down
42 changes: 42 additions & 0 deletions tests/test_metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -888,3 +888,45 @@ def test_empty_preds_metrics(classes):
def test_empty_preds_metrics(classes):
    """Metrics computed from one empty prediction/label document keep the expected structure."""
    verify_all_metrics_structure(
        all_metrics=get_all_metrics(preds=[[]], labels=[[]], field_names=classes),
        classes=classes,
    )


def _same_charset(a: dict, b: dict):
return set(a["text"]) == set(b["text"])


@pytest.mark.parametrize(
    "true,pred,expected",
    [
        # Identical text: one match.
        (
            [[{"text": "a", "label": "class1"}]],
            [[{"text": "a", "label": "class1"}]],
            {"TP": 1, "FP": 0, "FN": 0},
        ),
        # Disjoint character sets: the prediction is spurious and the label is missed.
        (
            [[{"text": "a", "label": "class1"}]],
            [[{"text": "b", "label": "class1"}]],
            {"TP": 0, "FP": 1, "FN": 1},
        ),
        # Same characters in a different order still count as equal.
        (
            [[{"text": "ab", "label": "class1"}]],
            [[{"text": "ba", "label": "class1"}]],
            {"TP": 1, "FP": 0, "FN": 0},
        ),
        # One prediction matches; the extra one is a false positive.
        (
            [[{"text": "ab", "label": "class1"}]],
            [[{"text": "ba", "label": "class1"}, {"text": "ac", "label": "class1"}]],
            {"TP": 1, "FP": 1, "FN": 0},
        ),
    ],
)
def test_custom_equality_fn(true, pred, expected):
    """A callable passed to get_seq_count_fn is used as the span equality function."""
    class_counts = get_seq_count_fn(_same_charset)(true, pred)["class1"]
    # Compare all three counts at once so a failure shows the full picture.
    observed = {
        "TP": len(class_counts["true_positives"]),
        "FP": len(class_counts["false_positives"]),
        "FN": len(class_counts["false_negatives"]),
    }
    assert observed == expected

0 comments on commit 70ae6f7

Please sign in to comment.