From b024399b1065648ccf6d1eb697afb81c368a4d90 Mon Sep 17 00:00:00 2001
From: benleetownsend
Date: Mon, 4 Nov 2024 13:30:35 +0000
Subject: [PATCH 1/2] ADD: make metrics robust to missing values and missing starts and ends

---
 sequence_metrics/metrics.py  | 110 +++++++++++++++++++-------
 sequence_metrics/testing.py  |  59 ++++++++++++++
 tests/test_metrics.py        |  50 +-----------
 tests/test_metrics_robust.py | 149 +++++++++++++++++++++++++++++++++++
 4 files changed, 292 insertions(+), 76 deletions(-)
 create mode 100644 sequence_metrics/testing.py
 create mode 100644 tests/test_metrics_robust.py

diff --git a/sequence_metrics/metrics.py b/sequence_metrics/metrics.py
index b4e3a90..0acd265 100644
--- a/sequence_metrics/metrics.py
+++ b/sequence_metrics/metrics.py
@@ -30,12 +30,14 @@ def _get_unique_classes(true, predicted):
     return list(set([seq["label"] for seqs in true_and_pred for seq in seqs]))


-def _convert_to_token_list(annotations, doc_idx=None):
+def _convert_to_token_list(annotations, doc_idx=None, unique_classes=None):
     nlp = get_spacy()
     tokens = []
     annotations = copy.deepcopy(annotations)
     for annotation in annotations:
+        if unique_classes and annotation.get("label") not in unique_classes:
+            continue
         start_idx = annotation.get("start")
         tokens.extend(
             [
@@ -98,15 +100,28 @@ def sequence_labeling_token_counts(true, predicted):
     """
     unique_classes = _get_unique_classes(true, predicted)
+    classes_to_skip = set(
+        l["label"]
+        for label in true + predicted
+        for l in label
+        if "start" not in l or "end" not in l
+    )
+    unique_classes = [c for c in unique_classes if c not in classes_to_skip]
     d = {
         cls_: {"false_positives": [], "false_negatives": [], "true_positives": []}
         for cls_ in unique_classes
     }
+    for cls_ in classes_to_skip:
+        d[cls_] = None
     for i, (true_list, pred_list) in enumerate(zip(true, predicted)):
-        true_tokens = _convert_to_token_list(true_list, doc_idx=i)
-        pred_tokens = _convert_to_token_list(pred_list, doc_idx=i)
+        true_tokens = _convert_to_token_list(
+            true_list, doc_idx=i, unique_classes=unique_classes
+        )
+        pred_tokens = _convert_to_token_list(
+            pred_list, doc_idx=i, unique_classes=unique_classes
+        )

         # correct + false negatives
         for true_token in true_tokens:
@@ -165,6 +180,10 @@ def seq_recall(true, predicted, span_type: str | Callable = "token"):
     class_counts = count_fn(true, predicted)
     results = {}
     for cls_, counts in class_counts.items():
+        if counts is None:
+            # Class is skipped due to missing start or end
+            results[cls_] = None
+            continue
         FN = len(counts["false_negatives"])
         TP = len(counts["true_positives"])
         results[cls_] = calc_recall(TP, FN)
@@ -176,6 +195,10 @@ def seq_precision(true, predicted, span_type: str | Callable = "token"):
     class_counts = count_fn(true, predicted)
     results = {}
     for cls_, counts in class_counts.items():
+        if counts is None:
+            # Class is skipped due to missing start or end
+            results[cls_] = None
+            continue
         FP = len(counts["false_positives"])
         TP = len(counts["true_positives"])
         results[cls_] = calc_precision(TP, FP)
@@ -187,6 +210,10 @@ def micro_f1(true, predicted, span_type: str | Callable = "token"):
     class_counts = count_fn(true, predicted)
     TP, FP, FN = 0, 0, 0
     for counts in class_counts.values():
+        if counts is None:
+            # Class is skipped due to missing start or end
+            # We cannot calculate a micro_f1
+            return None
         FN += len(counts["false_negatives"])
         TP += len(counts["true_positives"])
         FP += len(counts["false_positives"])
@@ -203,6 +230,10 @@ def per_class_f1(true, predicted, span_type: str | Callable = "token"):
     class_counts = count_fn(true, predicted)
     results = OrderedDict()
     for cls_, counts in class_counts.items():
+        if counts is None:
+            # Class is skipped due to missing start or end
+            results[cls_] = None
+            continue
         results[cls_] = {}
         FP = len(counts["false_positives"])
         FN = len(counts["false_negatives"])
@@ -223,15 +254,22 @@ def sequence_f1(true, predicted, span_type: str | Callable = "token", average=No
         return micro_f1(true, predicted, span_type)

     f1s_by_class = per_class_f1(true, predicted, span_type)
+    if average is None:
+        return f1s_by_class
+
+    if any(v is None for v in f1s_by_class.values()):
+        # Some classes are skipped due to missing start or end
+        return None
     f1s = [value.get("f1-score") for key, value in f1s_by_class.items()]
     supports = [value.get("support") for key, value in f1s_by_class.items()]

     if average == "weighted":
+        if sum(supports) == 0:
+            return 0.0
         return np.average(np.array(f1s), weights=np.array(supports))
     if average == "macro":
         return np.average(f1s)
-    else:
-        return f1s_by_class
+    raise ValueError(f"Unknown average: {average}")


 def strip_whitespace(y):
@@ -304,20 +342,24 @@ def single_class_single_example_counts(true, predicted, equality_fn):
     """
     # Some of the equality_fn checks are redundant, so it's helpful if the equality_fn is cached
     counts = {"false_positives": [], "false_negatives": [], "true_positives": []}
-    for true_annotation in true:
-        for pred_annotation in predicted:
-            if equality_fn(true_annotation, pred_annotation):
-                counts["true_positives"].append(true_annotation)
-                break
-        else:
-            counts["false_negatives"].append(true_annotation)
-
-    for pred_annotation in predicted:
+    try:
         for true_annotation in true:
-            if equality_fn(true_annotation, pred_annotation):
-                break
-        else:
-            counts["false_positives"].append(pred_annotation)
+            for pred_annotation in predicted:
+                if equality_fn(true_annotation, pred_annotation):
+                    counts["true_positives"].append(true_annotation)
+                    break
+            else:
+                counts["false_negatives"].append(true_annotation)
+
+        for pred_annotation in predicted:
+            for true_annotation in true:
+                if equality_fn(true_annotation, pred_annotation):
+                    break
+            else:
+                counts["false_positives"].append(pred_annotation)
+    except KeyError:
+        # Missing start or end
+        return {"skip_class": True}
     return counts


@@ -331,11 +373,6 @@ def sequence_labeling_counts(true, predicted, equality_fn, n_threads=5):
     future_to_cls = {}
     with ThreadPoolExecutor(max_workers=n_threads) as pool:
         for cls_ in unique_classes:
-            d[cls_] = {
-                "false_positives": [],
-                "false_negatives": [],
-                "true_positives": [],
-            }
             for i, (true_annotations, predicted_annotations) in enumerate(
                 zip(true, predicted)
             ):
@@ -365,6 +402,16 @@
     for future in future_to_cls:
         cls_ = future_to_cls[future]
         ex_counts = future.result()
+        if ex_counts.get("skip_class", False) or cls_ in d and d[cls_] is None:
+            # Class is skipped due to key error on equality function
+            d[cls_] = None
+            continue
+        if cls_ not in d:
+            d[cls_] = {
+                "false_positives": [],
+                "false_negatives": [],
+                "true_positives": [],
+            }
         for key, value in ex_counts.items():
             d[cls_][key].extend(value)

@@ -484,12 +531,18 @@ def get_spantype_metrics(span_type, preds, labels, field_names) -> dict[str, dic
     return {
         class_: (
             dict(
-                f1=per_class_f1s[class_].get("f1-score"),
+                f1=(per_class_f1s[class_] or {}).get("f1-score", None),
                 recall=recalls[class_],
                 precision=precisions[class_],
-                false_positives=len(counts[class_]["false_positives"]),
-                false_negatives=len(counts[class_]["false_negatives"]),
-                true_positives=len(counts[class_]["true_positives"]),
+                false_positives=len(counts[class_]["false_positives"])
+                if counts[class_] is not None
+                else None,
+                false_negatives=len(counts[class_]["false_negatives"])
+                if counts[class_] is not None
+                else None,
+                true_positives=len(counts[class_]["true_positives"])
+                if counts[class_] is not None
+                else None,
             )
             if class_ in counts
             else dict(
@@ -528,6 +581,9 @@ def summary_metrics(metrics):
         TP = 0
         FP = 0
         FN = 0
+        if any(cls_metrics["f1"] is None for cls_metrics in span_metrics.values()):
+            summary[span_type] = None
+            continue
         for cls_metrics in span_metrics.values():
             f1.append(cls_metrics["f1"])
             precision.append(cls_metrics["precision"])
diff --git a/sequence_metrics/testing.py b/sequence_metrics/testing.py
new file mode 100644
index 0000000..be293ee
--- /dev/null
+++ b/sequence_metrics/testing.py
@@ -0,0 +1,59 @@
+def verify_all_metrics_structure(all_metrics, classes, none_classes=None):
+    span_types = ["token", "overlap", "exact", "superset", "value"]
+    assert len(all_metrics.keys()) == 2
+    summary_metrics = all_metrics["summary_metrics"]
+    assert len(summary_metrics.keys()) == len(span_types)
+    for span_type in span_types:
+        if none_classes and span_type != "value":
+            assert summary_metrics[span_type] is None
+        else:
+            assert len(summary_metrics[span_type].keys()) == 9
+            for metric in [
+                "macro_f1",
+                "macro_precision",
+                "macro_recall",
+                "micro_precision",
+                "micro_recall",
+                "micro_f1",
+                "weighted_f1",
+                "weighted_precision",
+                "weighted_recall",
+            ]:
+                assert isinstance(summary_metrics[span_type][metric], float)
+    class_metrics = all_metrics["class_metrics"]
+    assert len(class_metrics) == len(span_types)
+    for span_type in span_types:
+        assert len(class_metrics[span_type]) == len(classes)
+        for cls_, metrics in class_metrics[span_type].items():
+            assert cls_ in classes
+            assert len(metrics.keys()) == 6
+            for metric in ["f1", "precision", "recall"]:
+                if none_classes and cls_ in none_classes and span_type != "value":
+                    assert (
+                        metrics[metric] is None
+                    ), f"{cls_} {metric}, {span_type} {metrics[metric]} should be None"
+                else:
+                    assert isinstance(metrics[metric], float)
+            for metric in ["false_positives", "true_positives", "false_negatives"]:
+                if none_classes and cls_ in none_classes and span_type != "value":
+                    assert metrics[metric] is None
+                else:
+                    assert isinstance(metrics[metric], int)
+
+
+def insert_text(docs, labels):
+    if len(docs) != len(labels):
+        raise ValueError("Number of documents must be equal to the number of labels")
+    for doc, label in zip(docs, labels):
+        for l in label:
+            if "text" not in l:
+                l["text"] = doc[l["start"] : l["end"]]
+    return labels
+
+
+def extend_label(text, label, amt):
+    return insert_text([text for _ in range(amt)], [label for _ in range(amt)])
+
+
+def remove_label(recs, label):
+    return [[pred for pred in rec if not pred.get("label") == label] for rec in recs]
diff --git a/tests/test_metrics.py b/tests/test_metrics.py
index fda900f..ef9c657 100644
--- a/tests/test_metrics.py
+++ b/tests/test_metrics.py
@@ -8,23 +8,7 @@
     sequence_f1,
     sequences_overlap,
 )
-
-
-def insert_text(docs, labels):
-    if len(docs) != len(labels):
-        raise ValueError("Number of documents must be equal to the number of labels")
-    for doc, label in zip(docs, labels):
-        for l in label:
-            l["text"] = doc[l["start"] : l["end"]]
-    return labels
-
-
-def extend_label(text, label, amt):
-    return insert_text([text for _ in range(amt)], [label for _ in range(amt)])
-
-
-def remove_label(recs, label):
-    return [[pred for pred in rec if not pred.get("label") == label] for rec in recs]
+from sequence_metrics.testing import extend_label, verify_all_metrics_structure


 @pytest.mark.parametrize(
@@ -772,38 +756,6 @@ def test_value_metrics(pred):
     )


-def verify_all_metrics_structure(all_metrics, classes):
-    span_types = ["token", "overlap", "exact", "superset", "value"]
-    assert len(all_metrics.keys()) == 2
-    summary_metrics = all_metrics["summary_metrics"]
-    assert len(summary_metrics.keys()) == len(span_types)
-    for span_type in span_types:
-        assert len(summary_metrics[span_type].keys()) == 9
-        for metric in [
-            "macro_f1",
-            "macro_precision",
-            "macro_recall",
-            "micro_precision",
-            "micro_recall",
-            "micro_f1",
-            "weighted_f1",
-            "weighted_precision",
-            "weighted_recall",
-        ]:
-            assert isinstance(summary_metrics[span_type][metric], float)
-    class_metrics = all_metrics["class_metrics"]
-    assert len(class_metrics) == len(span_types)
-    for span_type in span_types:
-        assert len(class_metrics[span_type]) == len(classes)
-        for cls_, metrics in class_metrics[span_type].items():
-            assert cls_ in classes
-            assert len(metrics.keys()) == 6
-            for metric in ["f1", "precision", "recall"]:
-                assert isinstance(metrics[metric], float)
-            for metric in ["false_positives", "true_positives", "false_negatives"]:
-                assert isinstance(metrics[metric], int)
-
-
 def test_get_all_metrics():
     text = "Alert: Pepsi Company stocks are up today April 5, 2010 and no one profited."

diff --git a/tests/test_metrics_robust.py b/tests/test_metrics_robust.py
new file mode 100644
index 0000000..c9ab674
--- /dev/null
+++ b/tests/test_metrics_robust.py
@@ -0,0 +1,149 @@
+import pytest
+
+from sequence_metrics.metrics import (
+    get_all_metrics,
+    get_seq_count_fn,
+    seq_precision,
+    seq_recall,
+    sequence_f1,
+    sequences_overlap,
+)
+from sequence_metrics.testing import extend_label, verify_all_metrics_structure
+
+
+def exact_equality(a, b):
+    return a["text"] == b["text"]
+
+
+def all_combos(true, pred):
+    classes = list(set([l["label"] for labels in true + pred for l in labels]))
+    classes_with_missing_keys = set(
+        [
+            l["label"]
+            for label in true + pred
+            for l in label
+            if "start" not in l or "end" not in l
+        ]
+    )
+
+    for span_type, skips_missing in [
+        ("token", True),
+        ("overlap", True),
+        ("exact", True),
+        ("superset", True),
+        ("value", False),
+        (exact_equality, False),
+    ]:
+        for average in ["micro", "macro", "weighted"]:
+            f1 = sequence_f1(true, pred, span_type=span_type, average=average)
+            if len(classes_with_missing_keys) > 0 and skips_missing:
+                assert f1 is None
+            else:
+                assert isinstance(f1, float)
+        counts = get_seq_count_fn(span_type)(true, pred)
+        assert set(counts.keys()) == set(classes)
+        f1_by_class = sequence_f1(true, pred, span_type=span_type)
+        precision_by_class = seq_precision(true, pred, span_type=span_type)
+        recall_by_class = seq_recall(true, pred, span_type=span_type)
+        for cls_ in classes:
+            f1 = f1_by_class[cls_]
+            prec = precision_by_class[cls_]
+            rec = recall_by_class[cls_]
+            cls_counts = counts[cls_]
+
+            if cls_ in classes_with_missing_keys and skips_missing:
+                assert f1 is None
+                assert prec is None
+                assert rec is None
+                assert cls_counts is None
+            else:
+                assert isinstance(f1, dict)
+                assert isinstance(f1["f1-score"], float)
+                assert isinstance(f1["support"], int)
+                assert isinstance(prec, float)
+                assert isinstance(rec, float)
+                assert isinstance(cls_counts, dict)
+                assert len(cls_counts.keys()) == 3
+                for metric in ["false_positives", "true_positives", "false_negatives"]:
+                    assert isinstance(
+                        cls_counts[metric], list
+                    )  # Inexplicably, this is a list
+
+
+def test_empty_preds():
+    text = "Alert: Pepsi Company stocks are up today April 5, 2010 and no one profited."
+
+    y_true = extend_label(
+        text,
+        [
+            {"start": 7, "end": 20, "label": "entity"},
+            {"start": 41, "end": 54, "label": "date"},
+        ],
+        10,
+    )
+
+    y_pred = extend_label(
+        text,
+        [],
+        10,
+    )
+
+    all_metrics = get_all_metrics(preds=y_pred, labels=y_true)
+    verify_all_metrics_structure(all_metrics=all_metrics, classes=["entity", "date"])
+    all_combos(y_true, y_pred)
+
+
+def test_empty_labels():
+    text = "Alert: Pepsi Company stocks are up today April 5, 2010 and no one profited."
+
+    y_true = extend_label(
+        text,
+        [],
+        10,
+    )
+
+    y_pred = extend_label(
+        text,
+        [
+            {"start": 7, "end": 20, "label": "entity"},
+            {"start": 41, "end": 54, "label": "date"},
+        ],
+        10,
+    )
+
+    all_metrics = get_all_metrics(preds=y_pred, labels=y_true)
+    verify_all_metrics_structure(all_metrics=all_metrics, classes=["entity", "date"])
+    all_combos(y_true, y_pred)
+
+
+def test_all_empty():
+    text = "Alert: Pepsi Company stocks are up today April 5, 2010 and no one profited."
+
+    empty = extend_label(
+        text,
+        [],
+        10,
+    )
+
+    all_metrics = get_all_metrics(preds=empty, labels=empty)
+    verify_all_metrics_structure(all_metrics=all_metrics, classes=[])
+    all_combos(empty, empty)
+
+
+def test_missing_start_end():
+    text = "Alert: Pepsi Company stocks are up today April 5, 2010 and no one profited."
+
+    y_true = extend_label(
+        text,
+        [
+            {"label": "entity", "text": "Pepsi Company"},
+            {"start": 41, "end": 54, "label": "date"},
+        ],
+        10,
+    )
+
+    all_metrics = get_all_metrics(preds=y_true, labels=y_true)
+    verify_all_metrics_structure(
+        all_metrics=all_metrics, classes=["entity", "date"], none_classes=["entity"]
+    )
+    all_combos(y_true, y_true)

From ecac6b7678bf6bcd88b336e930c037876eac62be Mon Sep 17 00:00:00 2001
From: Madison May
Date: Tue, 5 Nov 2024 13:15:06 -0500
Subject: [PATCH 2/2] Update pyproject.toml

---
 pyproject.toml | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/pyproject.toml b/pyproject.toml
index 235803c..0e10585 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -1,6 +1,6 @@
 [tool.poetry]
 name = "sequence-metrics"
-version = "0.0.5"
+version = "0.0.6"
 description = "A set of metrics for Sequence Labelling tasks"
 readme = "README.md"
 authors = ["Indico Data "]
@@ -23,4 +23,4 @@ en-core-web-sm = {url = "https://github.com/explosion/spacy-models/releases/down
 addopts = "-ra -sv"
 testpaths = [
     "tests"
-]
\ No newline at end of file
+]
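
A minimal usage sketch of the behaviour PATCH 1/2 introduces (illustration only, not part of either patch; it assumes the patched sequence-metrics package is importable and mirrors the test_missing_start_end case above). Offset-based span types now report None for a class whose annotations lack "start"/"end" instead of raising, while the text-only "value" span type is still scored:

    # Illustration only -- not part of the patches above; expected outputs in the
    # comments follow the new tests in tests/test_metrics_robust.py.
    from sequence_metrics.metrics import get_all_metrics, sequence_f1

    text = "Alert: Pepsi Company stocks are up today April 5, 2010 and no one profited."
    labels = [
        [
            # No "start"/"end" offsets here, only the text value.
            {"label": "entity", "text": "Pepsi Company"},
            {"start": 41, "end": 54, "label": "date", "text": "April 5, 2010"},
        ]
    ]

    # Offset-based span types ("token", "overlap", "exact", "superset") skip the
    # offset-less class: its per-class scores are None and averaged scores for the
    # whole span type collapse to None.
    print(sequence_f1(labels, labels, span_type="exact"))                   # entity -> None, date -> {...}
    print(sequence_f1(labels, labels, span_type="exact", average="macro"))  # None

    # The text-only "value" span type still scores every class.
    print(sequence_f1(labels, labels, span_type="value", average="macro"))  # 1.0

    all_metrics = get_all_metrics(preds=labels, labels=labels)
    print(all_metrics["summary_metrics"]["exact"])                          # None
    print(all_metrics["class_metrics"]["value"]["date"]["f1"])              # 1.0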