ADD: get_all_metrics and associated tests
benleetownsend committed Jan 18, 2024
1 parent 28b8ce6 commit ab24415
Showing 3 changed files with 484 additions and 126 deletions.
17 changes: 17 additions & 0 deletions .pre-commit-config.yaml
@@ -0,0 +1,17 @@
repos:
  - repo: https://github.com/matthorgan/pre-commit-conventional-commits
    rev: 20fb9631be1385998138432592d0b6d4dfa38fc9
    hooks:
      - id: conventional-commit-check
        stages:
          - commit-msg
  - repo: https://github.com/pycqa/isort
    rev: 5.12.0
    hooks:
      - id: isort
        name: isort (python)
        args: ["--profile", "black"]
  - repo: https://github.com/psf/black
    rev: 23.11.0
    hooks:
      - id: black
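
Note: the conventional-commit-check hook runs at the commit-msg stage, which pre-commit does not install by default. Assuming a standard pre-commit setup, both hook types can be enabled with:

    pre-commit install --hook-type pre-commit --hook-type commit-msg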
98 changes: 96 additions & 2 deletions sequence_metrics/metrics.py
@@ -184,7 +184,7 @@ def micro_f1(true, predicted, span_type="token"):
    count_fn = get_seq_count_fn(span_type)
    class_counts = count_fn(true, predicted)
    TP, FP, FN = 0, 0, 0
-    for cls_, counts in class_counts.items():
+    for counts in class_counts.values():
        FN += len(counts["false_negatives"])
        TP += len(counts["true_positives"])
        FP += len(counts["false_positives"])
@@ -246,7 +246,7 @@ def strip_whitespace(y):


def _norm_text(text: str) -> str:
-    return re.sub(r"[^A-Za-z0-9]", "", text)
+    return re.sub(r"[^A-Za-z0-9]", "", text).lower()


def fuzzy_compare(x: dict, y: dict) -> bool:
@@ -422,3 +422,97 @@ def annotation_report(
    ]
    report += row_fmt.format(last_line_heading, *averages, width=width, digits=digits)
    return report


def get_spantype_metrics(span_type, preds, labels, field_names) -> dict[str, dict]:
    counts = get_seq_count_fn(span_type)(labels, preds)
    precisions = seq_precision(labels, preds, span_type)
    recalls = seq_recall(labels, preds, span_type)
    per_class_f1s = sequence_f1(labels, preds, span_type)
    return {
        class_: (
            dict(
                f1=per_class_f1s[class_].get("f1-score"),
                recall=recalls[class_],
                precision=precisions[class_],
                false_positives=len(counts[class_]["false_positives"]),
                false_negatives=len(counts[class_]["false_negatives"]),
                true_positives=len(counts[class_]["true_positives"]),
            )
            if class_ in counts
            # Classes requested in field_names but absent from the data
            # get all-zero metrics rather than a KeyError.
            else dict(
                f1=0.0,
                recall=0.0,
                precision=0.0,
                false_positives=0,
                false_negatives=0,
                true_positives=0,
            )
        )
        for class_ in field_names
    }


def weighted_mean(value, weights):
    # Support-weighted average; returns 0.0 when there is no support at all.
    if sum(weights) == 0.0:
        return 0.0
    return sum(v * w for v, w in zip(value, weights)) / sum(weights)


def mean(value: list):
    # The sum(value) == 0 guard also covers the empty list, avoiding ZeroDivisionError.
    if sum(value) == 0:
        return 0.0
    return sum(value) / len(value)
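
A quick worked example of the two helpers (illustrative values, not from this diff):

    weighted_mean([1.0, 0.0], [3, 1])   # (1.0*3 + 0.0*1) / 4 == 0.75
    weighted_mean([1.0, 0.0], [0, 0])   # 0.0 -- the zero-total-weight guard
    mean([0.5, 1.0])                    # 1.5 / 2 == 0.75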


def summary_metrics(metrics):
    summary = {}
    for span_type, span_metrics in metrics.items():
        span_type_summary = {}
        f1 = []
        precision = []
        recall = []
        weight = []
        TP = 0
        FP = 0
        FN = 0
        for cls_metrics in span_metrics.values():
            f1.append(cls_metrics["f1"])
            precision.append(cls_metrics["precision"])
            recall.append(cls_metrics["recall"])
            TP += cls_metrics["true_positives"]
            FP += cls_metrics["false_positives"]
            FN += cls_metrics["false_negatives"]
            # Weight each class by its support: TP + FN = number of true spans.
            weight.append(
                cls_metrics["true_positives"] + cls_metrics["false_negatives"]
            )
        # Macro: unweighted average over classes.
        span_type_summary["macro_f1"] = mean(f1)
        span_type_summary["macro_precision"] = mean(precision)
        span_type_summary["macro_recall"] = mean(recall)

        # Micro: precision/recall/f1 computed from counts pooled across classes.
        span_type_summary["micro_precision"] = calc_precision(TP, FP)
        span_type_summary["micro_recall"] = calc_recall(TP, FN)
        span_type_summary["micro_f1"] = calc_f1(
            span_type_summary["micro_recall"], span_type_summary["micro_precision"]
        )

        # Weighted: per-class metrics averaged with support weights.
        span_type_summary["weighted_f1"] = weighted_mean(f1, weight)
        span_type_summary["weighted_precision"] = weighted_mean(precision, weight)
        span_type_summary["weighted_recall"] = weighted_mean(recall, weight)
        summary[span_type] = span_type_summary

    return summary
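
For intuition (illustrative numbers, not from this diff): given two classes, "name" with f1 = 1.0 and support 9, and "date" with f1 = 0.0 and support 1, macro_f1 = (1.0 + 0.0) / 2 = 0.5 while weighted_f1 = (1.0 * 9 + 0.0 * 1) / 10 = 0.9; the micro variants instead pool TP/FP/FN across classes before applying calc_precision, calc_recall, and calc_f1.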


def get_all_metrics(preds, labels, field_names=None):
    if field_names is None:
        # Default to every label that appears in either labels or predictions.
        field_names = sorted(
            {span["label"] for doc in (labels + preds) for span in doc}
        )
    detailed_metrics = dict()
    for span_type in ["token", "overlap", "exact", "superset", "value"]:
        detailed_metrics[span_type] = get_spantype_metrics(
            span_type=span_type, preds=preds, labels=labels, field_names=field_names
        )
    return {
        "summary_metrics": summary_metrics(detailed_metrics),
        "class_metrics": detailed_metrics,
    }
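
A minimal usage sketch (hypothetical data; the "label" key is required by the code above, while "start", "end", and "text" are assumed span fields based on the span-matching helpers in this module):

    labels = [[{"start": 0, "end": 4, "label": "name", "text": "Jane"}]]
    preds = [[{"start": 0, "end": 4, "label": "name", "text": "Jane"}]]

    metrics = get_all_metrics(preds=preds, labels=labels)
    metrics["summary_metrics"]["exact"]["micro_f1"]    # 1.0 for this perfect match
    metrics["class_metrics"]["token"]["name"]          # per-class detail per span type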