ADD: break out counts aka quadrants so that they are useful (#12)
benleetownsend authored Nov 6, 2024
1 parent 9641c0a commit f6cc5ce
Showing 5 changed files with 115 additions and 76 deletions.
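For context on the change itself: before this commit, each quadrant list ("true_positives", "false_negatives", "false_positives") held bare span dicts; after it, every entry pairs the gold and predicted spans so callers can inspect both sides of each match. A minimal sketch with hypothetical data, not part of the diff:

# Before (0.0.6): a true positive was just the matched gold span.
old_entry = {"start": 0, "end": 5, "label": "name", "text": "Alice"}

# After (0.0.7): every quadrant entry is a {"true": ..., "pred": ...} pair,
# and the missing side of a false negative/positive is None.
tp_entry = {
    "true": {"start": 0, "end": 5, "label": "name", "text": "Alice"},
    "pred": {"start": 0, "end": 5, "label": "name", "text": "Alice"},
}
fn_entry = {"true": {"start": 6, "end": 9, "label": "name", "text": "Bob"}, "pred": None}
fp_entry = {"true": None, "pred": {"start": 6, "end": 9, "label": "name", "text": "Bob"}}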
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -1,6 +1,6 @@
[tool.poetry]
name = "sequence-metrics"
version = "0.0.6"
version = "0.0.7"
description = "A set of metrics for Sequence Labelling tasks"
readme = "README.md"
authors = ["Indico Data <[email protected]>"]
142 changes: 79 additions & 63 deletions sequence_metrics/metrics.py
@@ -94,9 +94,9 @@ def sequence_labeling_token_confusion(text, true, predicted):
)


def sequence_labeling_token_counts(true, predicted):
def sequence_labeling_token_quadrants(true, predicted):
"""
Return FP, FN, and TP counts
Return FP, FN, and TP quadrants
"""

unique_classes = _get_unique_classes(true, predicted)
@@ -131,14 +131,22 @@ def sequence_labeling_token_counts(true, predicted):
and pred_token["end"] == true_token["end"]
):
if pred_token["label"] == true_token["label"]:
d[true_token["label"]]["true_positives"].append(true_token)
d[true_token["label"]]["true_positives"].append(
{"true": true_token, "pred": pred_token}
)
else:
d[true_token["label"]]["false_negatives"].append(true_token)
d[pred_token["label"]]["false_positives"].append(pred_token)
d[true_token["label"]]["false_negatives"].append(
{"true": true_token, "pred": None}
)
d[pred_token["label"]]["false_positives"].append(
{"true": None, "pred": pred_token}
)

break
else:
d[true_token["label"]]["false_negatives"].append(true_token)
d[true_token["label"]]["false_negatives"].append(
{"true": true_token, "pred": None}
)

# false positives
for pred_token in pred_tokens:
@@ -149,7 +157,9 @@ def sequence_labeling_token_counts(true, predicted):
):
break
else:
d[pred_token["label"]]["false_positives"].append(pred_token)
d[pred_token["label"]]["false_positives"].append(
{"true": None, "pred": pred_token}
)

return d
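The dict returned here is keyed by class label. A sketch of the new shape, as assembled by the appends in this hunk (placeholder contents):

# d = {
#     "name": {
#         "true_positives":  [{"true": {...}, "pred": {...}}, ...],
#         "false_negatives": [{"true": {...}, "pred": None}, ...],
#         "false_positives": [{"true": None, "pred": {...}}, ...],
#     },
#     ...one entry per class seen in the gold or predicted spans...
# }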

@@ -176,47 +186,47 @@ def calc_f1(recall, precision):


def seq_recall(true, predicted, span_type: str | Callable = "token"):
count_fn = get_seq_count_fn(span_type)
class_counts = count_fn(true, predicted)
quadrants_fn = get_seq_quadrants_fn(span_type)
class_quadrants = quadrants_fn(true, predicted)
results = {}
for cls_, counts in class_counts.items():
if counts is None:
for cls_, quadrants in class_quadrants.items():
if quadrants is None:
# Class is skipped due to missing start or end
results[cls_] = None
continue
FN = len(counts["false_negatives"])
TP = len(counts["true_positives"])
FN = len(quadrants["false_negatives"])
TP = len(quadrants["true_positives"])
results[cls_] = calc_recall(TP, FN)
return results


def seq_precision(true, predicted, span_type: str | Callable = "token"):
count_fn = get_seq_count_fn(span_type)
class_counts = count_fn(true, predicted)
quadrants_fn = get_seq_quadrants_fn(span_type)
class_quadrants = quadrants_fn(true, predicted)
results = {}
for cls_, counts in class_counts.items():
if counts is None:
for cls_, quadrants in class_quadrants.items():
if quadrants is None:
# Class is skipped due to missing start or end
results[cls_] = None
continue
FP = len(counts["false_positives"])
TP = len(counts["true_positives"])
FP = len(quadrants["false_positives"])
TP = len(quadrants["true_positives"])
results[cls_] = calc_precision(TP, FP)
return results


def micro_f1(true, predicted, span_type: str | Callable = "token"):
count_fn = get_seq_count_fn(span_type)
class_counts = count_fn(true, predicted)
quadrants_fn = get_seq_quadrants_fn(span_type)
class_quadrants = quadrants_fn(true, predicted)
TP, FP, FN = 0, 0, 0
for counts in class_counts.values():
if counts is None:
for quadrants in class_quadrants.values():
if quadrants is None:
# Class is skipped due to missing start or end
# We cannot calculate a micro_f1
return None
FN += len(counts["false_negatives"])
TP += len(counts["true_positives"])
FP += len(counts["false_positives"])
FN += len(quadrants["false_negatives"])
TP += len(quadrants["true_positives"])
FP += len(quadrants["false_positives"])
recall = calc_recall(TP, FN)
precision = calc_precision(TP, FP)
return calc_f1(recall, precision)
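micro_f1 pools TP, FP, and FN across all classes before computing the scores. A worked example using the standard definitions, which calc_recall, calc_precision, and calc_f1 are assumed to implement (with guards for zero denominators):

TP, FP, FN = 8, 2, 4
precision = TP / (TP + FP)                          # 0.8
recall = TP / (TP + FN)                             # 0.6667
f1 = 2 * precision * recall / (precision + recall)  # ~0.7273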
Expand All @@ -226,18 +236,18 @@ def per_class_f1(true, predicted, span_type: str | Callable = "token"):
"""
F1-scores per class
"""
count_fn = get_seq_count_fn(span_type)
class_counts = count_fn(true, predicted)
quadrants_fn = get_seq_quadrants_fn(span_type)
class_quadrants = quadrants_fn(true, predicted)
results = OrderedDict()
for cls_, counts in class_counts.items():
if counts is None:
for cls_, quadrants in class_quadrants.items():
if quadrants is None:
# Class is skipped due to missing start or end
results[cls_] = None
continue
results[cls_] = {}
FP = len(counts["false_positives"])
FN = len(counts["false_negatives"])
TP = len(counts["true_positives"])
FP = len(quadrants["false_positives"])
FN = len(quadrants["false_negatives"])
TP = len(quadrants["true_positives"])
recall = calc_recall(TP, FN)
precision = calc_precision(TP, FP)
results[cls_]["support"] = FN + TP
@@ -336,36 +346,42 @@ def sequence_superset(true_seq, pred_seq):
return pred_seq["start"] <= true_seq["start"] and pred_seq["end"] >= true_seq["end"]


def single_class_single_example_counts(true, predicted, equality_fn):
def single_class_single_example_quadrants(true, predicted, equality_fn):
"""
Return FP, FN, and TP counts for a single class
Return FP, FN, and TP quadrants for a single class
"""
# Some of the equality_fn checks are redundant, so it's helpful if the equality_fn is cached
counts = {"false_positives": [], "false_negatives": [], "true_positives": []}
quadrants = {"false_positives": [], "false_negatives": [], "true_positives": []}
try:
for true_annotation in true:
for pred_annotation in predicted:
if equality_fn(true_annotation, pred_annotation):
counts["true_positives"].append(true_annotation)
quadrants["true_positives"].append(
{"true": true_annotation, "pred": pred_annotation}
)
break
else:
counts["false_negatives"].append(true_annotation)
quadrants["false_negatives"].append(
{"true": true_annotation, "pred": None}
)

for pred_annotation in predicted:
for true_annotation in true:
if equality_fn(true_annotation, pred_annotation):
break
else:
counts["false_positives"].append(pred_annotation)
quadrants["false_positives"].append(
{"true": None, "pred": pred_annotation}
)
except KeyError:
# Missing start or end
return {"skip_class": True}
return counts
return quadrants
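Any callable with the signature equality_fn(true_annotation, pred_annotation) -> bool fits here; sequences_overlap, sequence_exact_match, sequence_superset, and fuzzy_compare are the built-ins wired up below. A toy example, hypothetical and matching purely on span text:

def same_text(true_annotation: dict, pred_annotation: dict) -> bool:
    # A KeyError raised here (e.g. a span missing "text") is caught by the
    # caller above, which then skips the whole class via {"skip_class": True}.
    return true_annotation["text"] == pred_annotation["text"]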


def sequence_labeling_counts(true, predicted, equality_fn, n_threads=5):
def sequence_labeling_quadrants(true, predicted, equality_fn, n_threads=5):
"""
Return FP, FN, and TP counts
Return FP, FN, and TP quadrants
"""
unique_classes = _get_unique_classes(true, predicted)

@@ -391,18 +407,18 @@ def sequence_labeling_counts(true, predicted, equality_fn, n_threads=5):
for annotation in annotations:
annotation["doc_idx"] = i

ex_counts_future = pool.submit(
single_class_single_example_counts,
ex_quadrants_future = pool.submit(
single_class_single_example_quadrants,
true_cls_annotations,
predicted_cls_annotations,
equality_fn,
)
future_to_cls[ex_counts_future] = cls_
future_to_cls[ex_quadrants_future] = cls_

for future in future_to_cls:
cls_ = future_to_cls[future]
ex_counts = future.result()
if ex_counts.get("skip_class", False) or cls_ in d and d[cls_] is None:
ex_quadrants = future.result()
if ex_quadrants.get("skip_class", False) or cls_ in d and d[cls_] is None:
# Class is skipped due to key error on equality function
d[cls_] = None
continue
@@ -412,7 +428,7 @@ def sequence_labeling_counts(true, predicted, equality_fn, n_threads=5):
"false_negatives": [],
"true_positives": [],
}
for key, value in ex_counts.items():
for key, value in ex_quadrants.items():
d[cls_][key].extend(value)

return d
@@ -427,20 +443,20 @@ def sequence_labeling_counts(true, predicted, equality_fn, n_threads=5):


SPAN_TYPE_FN_MAPPING = {
"token": sequence_labeling_token_counts,
"overlap": partial(sequence_labeling_counts, equality_fn=sequences_overlap),
"exact": partial(sequence_labeling_counts, equality_fn=sequence_exact_match),
"superset": partial(sequence_labeling_counts, equality_fn=sequence_superset),
"value": partial(sequence_labeling_counts, equality_fn=fuzzy_compare),
"token": sequence_labeling_token_quadrants,
"overlap": partial(sequence_labeling_quadrants, equality_fn=sequences_overlap),
"exact": partial(sequence_labeling_quadrants, equality_fn=sequence_exact_match),
"superset": partial(sequence_labeling_quadrants, equality_fn=sequence_superset),
"value": partial(sequence_labeling_quadrants, equality_fn=fuzzy_compare),
}


def get_seq_count_fn(span_type: str | Callable = "token"):
def get_seq_quadrants_fn(span_type: str | Callable = "token"):
if isinstance(span_type, str):
return SPAN_TYPE_FN_MAPPING[span_type]
elif callable(span_type):
# Interpret span_type as an equality function
return partial(sequence_labeling_counts, equality_fn=span_type)
return partial(sequence_labeling_quadrants, equality_fn=span_type)

raise ValueError(
f"Invalid span_type: {span_type}. Must either be a string or a callable."
@@ -524,7 +540,7 @@ def annotation_report(


def get_spantype_metrics(span_type, preds, labels, field_names) -> dict[str, dict]:
counts = get_seq_count_fn(span_type)(labels, preds)
quadrants = get_seq_quadrants_fn(span_type)(labels, preds)
precisions = seq_precision(labels, preds, span_type)
recalls = seq_recall(labels, preds, span_type)
per_class_f1s = sequence_f1(labels, preds, span_type)
@@ -534,17 +550,17 @@ def get_spantype_metrics(span_type, preds, labels, field_names) -> dict[str, dict]:
f1=(per_class_f1s[class_] or {}).get("f1-score", None),
recall=recalls[class_],
precision=precisions[class_],
false_positives=len(counts[class_]["false_positives"])
if counts[class_] is not None
false_positives=len(quadrants[class_]["false_positives"])
if quadrants[class_] is not None
else None,
false_negatives=len(counts[class_]["false_negatives"])
if counts[class_] is not None
false_negatives=len(quadrants[class_]["false_negatives"])
if quadrants[class_] is not None
else None,
true_positives=len(counts[class_]["true_positives"])
if counts[class_] is not None
true_positives=len(quadrants[class_]["true_positives"])
if quadrants[class_] is not None
else None,
)
if class_ in counts
if class_ in quadrants
else dict(
f1=0.0,
recall=0.0,
8 changes: 1 addition & 7 deletions sequence_metrics/wandb_tools.py
@@ -7,7 +7,7 @@
EQUALITY_FN_MAP,
_get_unique_classes,
get_all_metrics,
get_seq_count_fn,
get_seq_quadrants_fn,
)


@@ -33,12 +33,6 @@ class PredSpan(LabelSpan):
}


def sequence_labeling_counts(true, predicted, equality_fn):
"""
Return FP, FN, and TP counts
"""


def map_span(pred, tpe):
output = {}
for key, value in pred.items():
35 changes: 32 additions & 3 deletions tests/test_metrics.py
@@ -2,7 +2,7 @@

from sequence_metrics.metrics import (
get_all_metrics,
get_seq_count_fn,
get_seq_quadrants_fn,
seq_precision,
seq_recall,
sequence_f1,
@@ -57,7 +57,7 @@ def test_overlap(a, b, expected):


def check_metrics(Y, Y_pred, expected, span_type=None):
counts = get_seq_count_fn(span_type)(Y, Y_pred)
counts = get_seq_quadrants_fn(span_type)(Y, Y_pred)
precisions = seq_precision(Y, Y_pred, span_type=span_type)
recalls = seq_recall(Y, Y_pred, span_type=span_type)
micro_f1_score = sequence_f1(Y, Y_pred, span_type=span_type, average="micro")
@@ -876,7 +876,7 @@ def _same_charset(a: dict, b: dict):
],
)
def test_custom_equality_fn(true, pred, expected, expected_f1):
result = get_seq_count_fn(_same_charset)(true, pred)
result = get_seq_quadrants_fn(_same_charset)(true, pred)
result_subset = {
k: v
for k, v in result["class1"].items()
@@ -887,3 +887,32 @@ def test_custom_equality_fn(true, pred, expected, expected_f1):
assert len(result_subset["false_negatives"]) == expected["FN"]
predicted_f1 = sequence_f1(true, pred, span_type=_same_charset, average="macro")
assert abs(predicted_f1 - expected_f1) < 0.001


@pytest.mark.parametrize(
"span_type", ["value", "exact", "overlap", lambda x, y: x["text"] == y["text"]]
)
def test_sequence_labeling_quadrants(span_type):
true = [
[{"start": 0, "end": 1, "text": "a", "label": "class1", "other_key": "true_a"}],
[{"start": 0, "end": 1, "text": "b", "label": "class1", "other_key": "true_b"}],
]
pred = [
[{"start": 0, "end": 1, "text": "a", "label": "class1", "other_key": "pred_a"}],
[{"start": 0, "end": 1, "text": "b", "label": "class1", "other_key": "pred_b"}],
]
quadrants = get_seq_quadrants_fn(span_type=span_type)(true, pred)
assert quadrants.keys() == {"class1"}
for quadrant in ["true_positives", "false_positives", "false_negatives"]:
assert isinstance(quadrants["class1"][quadrant], list)
for instance in quadrants["class1"][quadrant]:
assert instance.keys() == {"true", "pred"}
if quadrant == "true_positives":
assert instance["true"] is not None and instance["pred"] is not None
else:
assert instance["true"] is None or instance["pred"] is None
for key in ["true", "pred"]:
# Assert that all keys are preserved.
pred_or_label = instance[key]
assert "other_key" in pred_or_label
assert pred_or_label["other_key"] == f"{key}_{pred_or_label['text']}"
4 changes: 2 additions & 2 deletions tests/test_metrics_robust.py
@@ -2,7 +2,7 @@

from sequence_metrics.metrics import (
get_all_metrics,
get_seq_count_fn,
get_seq_quadrants_fn,
seq_precision,
seq_recall,
sequence_f1,
@@ -40,7 +40,7 @@ def all_combos(true, pred):
assert f1 is None
else:
assert isinstance(f1, float)
counts = get_seq_count_fn(span_type)(true, pred)
counts = get_seq_quadrants_fn(span_type)(true, pred)
assert set(counts.keys()) == set(classes)
f1_by_class = sequence_f1(true, pred, span_type=span_type)
precision_by_class = seq_precision(true, pred, span_type=span_type)
