Allow custom comparison fns #11

Merged · 4 commits · Nov 5, 2024
4 changes: 2 additions & 2 deletions pyproject.toml
@@ -1,6 +1,6 @@
[tool.poetry]
name = "sequence-metrics"
version = "0.0.5"
version = "0.0.6"
description = "A set of metrics for Sequence Labelling tasks"
readme = "README.md"
authors = ["Indico Data <[email protected]>"]
@@ -23,4 +23,4 @@ en-core-web-sm = {url = "https://github.com/explosion/spacy-models/releases/down
addopts = "-ra -sv"
testpaths = [
"tests"
]
]
110 changes: 83 additions & 27 deletions sequence_metrics/metrics.py
@@ -30,12 +30,14 @@ def _get_unique_classes(true, predicted):
return list(set([seq["label"] for seqs in true_and_pred for seq in seqs]))


def _convert_to_token_list(annotations, doc_idx=None):
def _convert_to_token_list(annotations, doc_idx=None, unique_classes=None):
nlp = get_spacy()
tokens = []
annotations = copy.deepcopy(annotations)

for annotation in annotations:
if unique_classes and annotation.get("label") not in unique_classes:
continue
start_idx = annotation.get("start")
tokens.extend(
[
@@ -98,15 +100,28 @@ def sequence_labeling_token_counts(true, predicted):
"""

unique_classes = _get_unique_classes(true, predicted)
classes_to_skip = set(
l["label"]
for label in true + predicted
for l in label
if "start" not in l or "end" not in l
)
unique_classes = [c for c in unique_classes if c not in classes_to_skip]

d = {
cls_: {"false_positives": [], "false_negatives": [], "true_positives": []}
for cls_ in unique_classes
}
for cls_ in classes_to_skip:
d[cls_] = None

for i, (true_list, pred_list) in enumerate(zip(true, predicted)):
true_tokens = _convert_to_token_list(true_list, doc_idx=i)
pred_tokens = _convert_to_token_list(pred_list, doc_idx=i)
true_tokens = _convert_to_token_list(
true_list, doc_idx=i, unique_classes=unique_classes
)
pred_tokens = _convert_to_token_list(
pred_list, doc_idx=i, unique_classes=unique_classes
)

# correct + false negatives
for true_token in true_tokens:
@@ -165,6 +180,10 @@ def seq_recall(true, predicted, span_type: str | Callable = "token"):
class_counts = count_fn(true, predicted)
results = {}
for cls_, counts in class_counts.items():
if counts is None:
# Class is skipped due to missing start or end
results[cls_] = None
continue
FN = len(counts["false_negatives"])
TP = len(counts["true_positives"])
results[cls_] = calc_recall(TP, FN)
@@ -176,6 +195,10 @@ def seq_precision(true, predicted, span_type: str | Callable = "token"):
class_counts = count_fn(true, predicted)
results = {}
for cls_, counts in class_counts.items():
if counts is None:
# Class is skipped due to missing start or end
results[cls_] = None
continue
FP = len(counts["false_positives"])
TP = len(counts["true_positives"])
results[cls_] = calc_precision(TP, FP)
@@ -187,6 +210,10 @@ def micro_f1(true, predicted, span_type: str | Callable = "token"):
class_counts = count_fn(true, predicted)
TP, FP, FN = 0, 0, 0
for counts in class_counts.values():
if counts is None:
# Class is skipped due to missing start or end
# We cannot calculate a micro_f1
return None
FN += len(counts["false_negatives"])
TP += len(counts["true_positives"])
FP += len(counts["false_positives"])
@@ -203,6 +230,10 @@ def per_class_f1(true, predicted, span_type: str | Callable = "token"):
class_counts = count_fn(true, predicted)
results = OrderedDict()
for cls_, counts in class_counts.items():
if counts is None:
# Class is skipped due to missing start or end
results[cls_] = None
continue
results[cls_] = {}
FP = len(counts["false_positives"])
FN = len(counts["false_negatives"])
@@ -223,15 +254,22 @@ def sequence_f1(true, predicted, span_type: str | Callable = "token", average=No
return micro_f1(true, predicted, span_type)

f1s_by_class = per_class_f1(true, predicted, span_type)
if average is None:
return f1s_by_class

if any(v is None for v in f1s_by_class.values()):
# Some classes are skipped due to missing start or end
return None
f1s = [value.get("f1-score") for key, value in f1s_by_class.items()]
supports = [value.get("support") for key, value in f1s_by_class.items()]

if average == "weighted":
if sum(supports) == 0:
return 0.0
return np.average(np.array(f1s), weights=np.array(supports))
if average == "macro":
return np.average(f1s)
else:
return f1s_by_class
raise ValueError(f"Unknown average: {average}")


def strip_whitespace(y):
@@ -304,20 +342,24 @@ def single_class_single_example_counts(true, predicted, equality_fn):
"""
# Some of the equality_fn checks are redundant, so it's helpful if the equality_fn is cached
counts = {"false_positives": [], "false_negatives": [], "true_positives": []}
for true_annotation in true:
for pred_annotation in predicted:
if equality_fn(true_annotation, pred_annotation):
counts["true_positives"].append(true_annotation)
break
else:
counts["false_negatives"].append(true_annotation)

for pred_annotation in predicted:
try:
for true_annotation in true:
if equality_fn(true_annotation, pred_annotation):
break
else:
counts["false_positives"].append(pred_annotation)
for pred_annotation in predicted:
if equality_fn(true_annotation, pred_annotation):
counts["true_positives"].append(true_annotation)
break
else:
counts["false_negatives"].append(true_annotation)

for pred_annotation in predicted:
for true_annotation in true:
if equality_fn(true_annotation, pred_annotation):
break
else:
counts["false_positives"].append(pred_annotation)
except KeyError:
# Missing start or end
return {"skip_class": True}
return counts


@@ -331,11 +373,6 @@ def sequence_labeling_counts(true, predicted, equality_fn, n_threads=5):
future_to_cls = {}
with ThreadPoolExecutor(max_workers=n_threads) as pool:
for cls_ in unique_classes:
d[cls_] = {
"false_positives": [],
"false_negatives": [],
"true_positives": [],
}
for i, (true_annotations, predicted_annotations) in enumerate(
zip(true, predicted)
):
@@ -365,6 +402,16 @@ def sequence_labeling_counts(true, predicted, equality_fn, n_threads=5):
for future in future_to_cls:
cls_ = future_to_cls[future]
ex_counts = future.result()
if ex_counts.get("skip_class", False) or cls_ in d and d[cls_] is None:
# Class is skipped due to key error on equality function
d[cls_] = None
continue
if cls_ not in d:
d[cls_] = {
"false_positives": [],
"false_negatives": [],
"true_positives": [],
}
for key, value in ex_counts.items():
d[cls_][key].extend(value)

@@ -484,12 +531,18 @@ def get_spantype_metrics(span_type, preds, labels, field_names) -> dict[str, dic
return {
class_: (
dict(
f1=per_class_f1s[class_].get("f1-score"),
f1=(per_class_f1s[class_] or {}).get("f1-score", None),
recall=recalls[class_],
precision=precisions[class_],
false_positives=len(counts[class_]["false_positives"]),
false_negatives=len(counts[class_]["false_negatives"]),
true_positives=len(counts[class_]["true_positives"]),
false_positives=len(counts[class_]["false_positives"])
if counts[class_] is not None
else None,
false_negatives=len(counts[class_]["false_negatives"])
if counts[class_] is not None
else None,
true_positives=len(counts[class_]["true_positives"])
if counts[class_] is not None
else None,
)
if class_ in counts
else dict(
@@ -528,6 +581,9 @@ def summary_metrics(metrics):
TP = 0
FP = 0
FN = 0
if any(cls_metrics["f1"] is None for cls_metrics in span_metrics.values()):
summary[span_type] = None
continue
for cls_metrics in span_metrics.values():
f1.append(cls_metrics["f1"])
precision.append(cls_metrics["precision"])
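The net effect of the metrics.py changes: annotations missing `start` or `end` no longer raise during span-based counting; instead the whole class is skipped, its per-class metrics come back as `None`, and any aggregate that would include it (micro/macro/weighted averages, summary metrics) degrades to `None` as well. A minimal sketch of that behaviour, assuming the `sequence_metrics.metrics` import path and an installed spaCy `en_core_web_sm` model; the documents and labels below are invented:

```python
from sequence_metrics.metrics import per_class_f1, seq_precision, sequence_f1

true = [[
    {"label": "person", "start": 7, "end": 12, "text": "Alice"},
    {"label": "note", "text": "hello"},  # no start/end -> class gets skipped
]]
pred = [[
    {"label": "person", "start": 7, "end": 12, "text": "Alice"},
    {"label": "note", "text": "hello"},
]]

by_class = per_class_f1(true, pred, span_type="token")
assert by_class["person"]["f1-score"] == 1.0
assert by_class["note"] is None  # skipped class reports None
assert seq_precision(true, pred, span_type="token")["note"] is None

# Aggregates that would mix in a skipped class degrade to None as well.
assert sequence_f1(true, pred, span_type="token", average="macro") is None
```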
59 changes: 59 additions & 0 deletions sequence_metrics/testing.py
@@ -0,0 +1,59 @@
def verify_all_metrics_structure(all_metrics, classes, none_classes=None):
    span_types = ["token", "overlap", "exact", "superset", "value"]
    assert len(all_metrics.keys()) == 2
    summary_metrics = all_metrics["summary_metrics"]
    assert len(summary_metrics.keys()) == len(span_types)
    for span_type in span_types:
        if none_classes and span_type != "value":
            assert summary_metrics[span_type] is None
        else:
            assert len(summary_metrics[span_type].keys()) == 9
            for metric in [
                "macro_f1",
                "macro_precision",
                "macro_recall",
                "micro_precision",
                "micro_recall",
                "micro_f1",
                "weighted_f1",
                "weighted_precision",
                "weighted_recall",
            ]:
                assert isinstance(summary_metrics[span_type][metric], float)
    class_metrics = all_metrics["class_metrics"]
    assert len(class_metrics) == len(span_types)
    for span_type in span_types:
        assert len(class_metrics[span_type]) == len(classes)
        for cls_, metrics in class_metrics[span_type].items():
            assert cls_ in classes
            assert len(metrics.keys()) == 6
            for metric in ["f1", "precision", "recall"]:
                if none_classes and cls_ in none_classes and span_type != "value":
                    assert (
                        metrics[metric] is None
                    ), f"{cls_} {metric}, {span_type} {metrics[metric]} should be None"
                else:
                    assert isinstance(metrics[metric], float)
            for metric in ["false_positives", "true_positives", "false_negatives"]:
                if none_classes and cls_ in none_classes and span_type != "value":
                    assert metrics[metric] is None
                else:
                    assert isinstance(metrics[metric], int)


def insert_text(docs, labels):
    if len(docs) != len(labels):
        raise ValueError("Number of documents must be equal to the number of labels")
    for doc, label in zip(docs, labels):
        for l in label:
            if "text" not in l:
                l["text"] = doc[l["start"] : l["end"]]
    return labels


def extend_label(text, label, amt):
    return insert_text([text for _ in range(amt)], [label for _ in range(amt)])


def remove_label(recs, label):
    return [[pred for pred in rec if not pred.get("label") == label] for rec in recs]
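These helpers were lifted out of tests/test_metrics.py so other test modules can reuse them, and `verify_all_metrics_structure` gained a `none_classes` argument so tests can assert that skipped classes report `None`. A quick sketch of how the smaller helpers compose; the document text and labels are invented:

```python
from sequence_metrics.testing import extend_label, insert_text, remove_label

doc = "Alice flew to Paris."
labels = [[{"label": "person", "start": 0, "end": 5}]]

# insert_text back-fills the "text" key from the document span.
labels = insert_text([doc], labels)
assert labels[0][0]["text"] == "Alice"

# extend_label repeats one (doc, label) pair to build a small corpus.
corpus = extend_label(doc, labels[0], 3)
assert len(corpus) == 3

# remove_label drops every annotation carrying the given class.
assert remove_label(corpus, "person") == [[], [], []]
```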
50 changes: 1 addition & 49 deletions tests/test_metrics.py
@@ -8,23 +8,7 @@
sequence_f1,
sequences_overlap,
)


def insert_text(docs, labels):
if len(docs) != len(labels):
raise ValueError("Number of documents must be equal to the number of labels")
for doc, label in zip(docs, labels):
for l in label:
l["text"] = doc[l["start"] : l["end"]]
return labels


def extend_label(text, label, amt):
return insert_text([text for _ in range(amt)], [label for _ in range(amt)])


def remove_label(recs, label):
return [[pred for pred in rec if not pred.get("label") == label] for rec in recs]
from sequence_metrics.testing import extend_label, verify_all_metrics_structure


@pytest.mark.parametrize(
Expand Down Expand Up @@ -772,38 +756,6 @@ def test_value_metrics(pred):
)


def verify_all_metrics_structure(all_metrics, classes):
span_types = ["token", "overlap", "exact", "superset", "value"]
assert len(all_metrics.keys()) == 2
summary_metrics = all_metrics["summary_metrics"]
assert len(summary_metrics.keys()) == len(span_types)
for span_type in span_types:
assert len(summary_metrics[span_type].keys()) == 9
for metric in [
"macro_f1",
"macro_precision",
"macro_recall",
"micro_precision",
"micro_recall",
"micro_f1",
"weighted_f1",
"weighted_precision",
"weighted_recall",
]:
assert isinstance(summary_metrics[span_type][metric], float)
class_metrics = all_metrics["class_metrics"]
assert len(class_metrics) == len(span_types)
for span_type in span_types:
assert len(class_metrics[span_type]) == len(classes)
for cls_, metrics in class_metrics[span_type].items():
assert cls_ in classes
assert len(metrics.keys()) == 6
for metric in ["f1", "precision", "recall"]:
assert isinstance(metrics[metric], float)
for metric in ["false_positives", "true_positives", "false_negatives"]:
assert isinstance(metrics[metric], int)


def test_get_all_metrics():
text = "Alert: Pepsi Company stocks are up today April 5, 2010 and no one profited."

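On the PR's title: the `span_type: str | Callable = "token"` hints suggest callers can now pass their own comparison function instead of a named span type. The wiring from a `Callable` span type down to `sequence_labeling_counts` is not shown in this diff, so the following is only a sketch under that assumption; the comparison function is hypothetical, and indexing `start`/`end` inside it is what triggers the new KeyError-based skip for classes missing those keys:

```python
from sequence_metrics.metrics import sequence_f1  # assumed import path


def same_label_overlapping(true_ann, pred_ann):
    # Hypothetical comparison: labels must agree and character spans must overlap.
    # A KeyError here (e.g. a missing "start") turns the class's metrics into None.
    return (
        true_ann["label"] == pred_ann["label"]
        and true_ann["start"] < pred_ann["end"]
        and pred_ann["start"] < true_ann["end"]
    )


true = [[{"label": "person", "start": 7, "end": 12, "text": "Alice"}]]
pred = [[{"label": "person", "start": 9, "end": 12, "text": "ice"}]]

# Partial overlap counts as a match under the custom comparison.
print(sequence_f1(true, pred, span_type=same_label_overlapping, average="macro"))
```

Per the comment in `single_class_single_example_counts`, the same pair may be compared more than once, so an expensive comparison function is worth caching.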