Skip to content

Commit

Permalink
WIP: sequence metrics threading
Browse files Browse the repository at this point in the history
  • Loading branch information
madisonmay committed Nov 1, 2024
1 parent d1e1284 commit 30a1274
Showing 1 changed file with 65 additions and 32 deletions.
97 changes: 65 additions & 32 deletions sequence_metrics/metrics.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import copy
import re
from collections import OrderedDict, defaultdict
from concurrent.futures import ThreadPoolExecutor, as_completed
from functools import partial
from typing import Callable

Expand Down Expand Up @@ -297,43 +298,75 @@ def sequence_superset(true_seq, pred_seq):
return pred_seq["start"] <= true_seq["start"] and pred_seq["end"] >= true_seq["end"]


def sequence_labeling_counts(true, predicted, equality_fn):
def single_class_single_example_counts(true, predicted, equality_fn):
    """
    Bucket the annotations of one class from one example into true
    positives, false negatives and false positives.

    A true annotation is a true positive when ``equality_fn(true, pred)``
    holds for at least one predicted annotation, otherwise it is a false
    negative.  A predicted annotation is a false positive when no true
    annotation satisfies ``equality_fn(true, pred)``.

    Returns a dict with the keys "false_positives", "false_negatives" and
    "true_positives", each mapping to a list of annotations.
    """

    # The same (true, pred) pairs can be compared in both passes, so a cached
    # equality_fn avoids repeating that work.
    def matched_by_some_prediction(gold):
        return any(equality_fn(gold, candidate) for candidate in predicted)

    def matches_some_truth(candidate):
        return any(equality_fn(gold, candidate) for gold in true)

    hits = []
    misses = []
    for gold in true:
        bucket = hits if matched_by_some_prediction(gold) else misses
        bucket.append(gold)

    spurious = [
        candidate for candidate in predicted if not matches_some_truth(candidate)
    ]

    return {
        "false_positives": spurious,
        "false_negatives": misses,
        "true_positives": hits,
    }


def sequence_labeling_counts(true, predicted, equality_fn, n_threads=5):
    """
    Return FP, FN, and TP counts

    ``true`` and ``predicted`` are parallel sequences with one list of
    annotations per document; each annotation is a dict that has a "label"
    key.  Documents are paired with ``zip``, so extra documents on the longer
    side are ignored.

    ``equality_fn(true_annotation, pred_annotation)`` decides whether two
    annotations of the same class match.  ``n_threads`` is passed to
    ``ThreadPoolExecutor`` as ``max_workers``.

    Returns a dict keyed by class; each value is a dict with the keys
    "false_positives", "false_negatives" and "true_positives" mapping to
    lists of annotations, accumulated in document order.

    Side effect: every annotation in the paired documents gets a "doc_idx"
    key holding the index of its document.
    """
    unique_classes = _get_unique_classes(true, predicted)

    d = {
        cls_: {"false_positives": [], "false_negatives": [], "true_positives": []}
        for cls_ in unique_classes
    }

    # NOTE(review): the per-example comparison is plain Python, so threads
    # presumably only pay off when equality_fn releases the GIL -- confirm
    # by profiling before relying on n_threads for speed.
    future_to_cls = {}
    with ThreadPoolExecutor(max_workers=n_threads) as pool:
        for i, (true_annotations, predicted_annotations) in enumerate(
            zip(true, predicted)
        ):
            # Group each document's annotations by label once instead of
            # re-filtering the full lists for every class.
            true_by_cls = defaultdict(list)
            predicted_by_cls = defaultdict(list)
            for annotations, by_cls in (
                (true_annotations, true_by_cls),
                (predicted_annotations, predicted_by_cls),
            ):
                for annotation in annotations:
                    # add doc idx to make verification easier
                    annotation["doc_idx"] = i
                    by_cls[annotation["label"]].append(annotation)

            for cls_ in d:
                true_cls_annotations = true_by_cls.get(cls_, [])
                predicted_cls_annotations = predicted_by_cls.get(cls_, [])
                if not true_cls_annotations and not predicted_cls_annotations:
                    # Nothing to compare, so the counts would all be empty;
                    # skip the task rather than queue a no-op.
                    continue

                ex_counts_future = pool.submit(
                    single_class_single_example_counts,
                    true_cls_annotations,
                    predicted_cls_annotations,
                    equality_fn,
                )
                future_to_cls[ex_counts_future] = cls_

    # Collect in submission order (dicts preserve insertion order) so each
    # class's lists stay sorted by document; result() re-raises any
    # exception raised by equality_fn inside a worker.
    for future, cls_ in future_to_cls.items():
        for key, value in future.result().items():
            d[cls_][key].extend(value)

    return d

Expand Down

0 comments on commit 30a1274

Please sign in to comment.