From f6cc5ce861fc74b302577e90c9942e98cac80ae7 Mon Sep 17 00:00:00 2001 From: benleetownsend Date: Wed, 6 Nov 2024 10:18:10 +0000 Subject: [PATCH] ADD: break out counts aka quadrants so that they are useful (#12) --- pyproject.toml | 2 +- sequence_metrics/metrics.py | 142 ++++++++++++++++++-------------- sequence_metrics/wandb_tools.py | 8 +- tests/test_metrics.py | 35 +++++++- tests/test_metrics_robust.py | 4 +- 5 files changed, 115 insertions(+), 76 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 0e10585..867105c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "sequence-metrics" -version = "0.0.6" +version = "0.0.7" description = "A set of metrics for Sequence Labelling tasks" readme = "README.md" authors = ["Indico Data "] diff --git a/sequence_metrics/metrics.py b/sequence_metrics/metrics.py index 0acd265..b128f78 100644 --- a/sequence_metrics/metrics.py +++ b/sequence_metrics/metrics.py @@ -94,9 +94,9 @@ def sequence_labeling_token_confusion(text, true, predicted): ) -def sequence_labeling_token_counts(true, predicted): +def sequence_labeling_token_quadrants(true, predicted): """ - Return FP, FN, and TP counts + Return FP, FN, and TP quadrants """ unique_classes = _get_unique_classes(true, predicted) @@ -131,14 +131,22 @@ def sequence_labeling_token_counts(true, predicted): and pred_token["end"] == true_token["end"] ): if pred_token["label"] == true_token["label"]: - d[true_token["label"]]["true_positives"].append(true_token) + d[true_token["label"]]["true_positives"].append( + {"true": true_token, "pred": pred_token} + ) else: - d[true_token["label"]]["false_negatives"].append(true_token) - d[pred_token["label"]]["false_positives"].append(pred_token) + d[true_token["label"]]["false_negatives"].append( + {"true": true_token, "pred": None} + ) + d[pred_token["label"]]["false_positives"].append( + {"true": None, "pred": pred_token} + ) break else: - d[true_token["label"]]["false_negatives"].append(true_token) + d[true_token["label"]]["false_negatives"].append( + {"true": true_token, "pred": None} + ) # false positives for pred_token in pred_tokens: @@ -149,7 +157,9 @@ def sequence_labeling_token_counts(true, predicted): ): break else: - d[pred_token["label"]]["false_positives"].append(pred_token) + d[pred_token["label"]]["false_positives"].append( + {"true": None, "pred": pred_token} + ) return d @@ -176,47 +186,47 @@ def calc_f1(recall, precision): def seq_recall(true, predicted, span_type: str | Callable = "token"): - count_fn = get_seq_count_fn(span_type) - class_counts = count_fn(true, predicted) + quadrants_fn = get_seq_quadrants_fn(span_type) + class_quadrants = quadrants_fn(true, predicted) results = {} - for cls_, counts in class_counts.items(): - if counts is None: + for cls_, quadrants in class_quadrants.items(): + if quadrants is None: # Class is skipped due to missing start or end results[cls_] = None continue - FN = len(counts["false_negatives"]) - TP = len(counts["true_positives"]) + FN = len(quadrants["false_negatives"]) + TP = len(quadrants["true_positives"]) results[cls_] = calc_recall(TP, FN) return results def seq_precision(true, predicted, span_type: str | Callable = "token"): - count_fn = get_seq_count_fn(span_type) - class_counts = count_fn(true, predicted) + quadrants_fn = get_seq_quadrants_fn(span_type) + class_quadrants = quadrants_fn(true, predicted) results = {} - for cls_, counts in class_counts.items(): - if counts is None: + for cls_, quadrants in class_quadrants.items(): + if quadrants is None: # Class is skipped due to missing start or end results[cls_] = None continue - FP = len(counts["false_positives"]) - TP = len(counts["true_positives"]) + FP = len(quadrants["false_positives"]) + TP = len(quadrants["true_positives"]) results[cls_] = calc_precision(TP, FP) return results def micro_f1(true, predicted, span_type: str | Callable = "token"): - count_fn = get_seq_count_fn(span_type) - class_counts = count_fn(true, predicted) + quadrants_fn = get_seq_quadrants_fn(span_type) + class_quadrants = quadrants_fn(true, predicted) TP, FP, FN = 0, 0, 0 - for counts in class_counts.values(): - if counts is None: + for quadrants in class_quadrants.values(): + if quadrants is None: # Class is skipped due to missing start or end # We cannot calculate a micro_f1 return None - FN += len(counts["false_negatives"]) - TP += len(counts["true_positives"]) - FP += len(counts["false_positives"]) + FN += len(quadrants["false_negatives"]) + TP += len(quadrants["true_positives"]) + FP += len(quadrants["false_positives"]) recall = calc_recall(TP, FN) precision = calc_precision(TP, FP) return calc_f1(recall, precision) @@ -226,18 +236,18 @@ def per_class_f1(true, predicted, span_type: str | Callable = "token"): """ F1-scores per class """ - count_fn = get_seq_count_fn(span_type) - class_counts = count_fn(true, predicted) + quadrants_fn = get_seq_quadrants_fn(span_type) + class_quadrants = quadrants_fn(true, predicted) results = OrderedDict() - for cls_, counts in class_counts.items(): - if counts is None: + for cls_, quadrants in class_quadrants.items(): + if quadrants is None: # Class is skipped due to missing start or end results[cls_] = None continue results[cls_] = {} - FP = len(counts["false_positives"]) - FN = len(counts["false_negatives"]) - TP = len(counts["true_positives"]) + FP = len(quadrants["false_positives"]) + FN = len(quadrants["false_negatives"]) + TP = len(quadrants["true_positives"]) recall = calc_recall(TP, FN) precision = calc_precision(TP, FP) results[cls_]["support"] = FN + TP @@ -336,36 +346,42 @@ def sequence_superset(true_seq, pred_seq): return pred_seq["start"] <= true_seq["start"] and pred_seq["end"] >= true_seq["end"] -def single_class_single_example_counts(true, predicted, equality_fn): +def single_class_single_example_quadrants(true, predicted, equality_fn): """ - Return FP, FN, and TP counts for a single class + Return FP, FN, and TP quadrants for a single class """ # Some of the equality_fn checks are redundant, so it's helpful if the equality_fn is cached - counts = {"false_positives": [], "false_negatives": [], "true_positives": []} + quadrants = {"false_positives": [], "false_negatives": [], "true_positives": []} try: for true_annotation in true: for pred_annotation in predicted: if equality_fn(true_annotation, pred_annotation): - counts["true_positives"].append(true_annotation) + quadrants["true_positives"].append( + {"true": true_annotation, "pred": pred_annotation} + ) break else: - counts["false_negatives"].append(true_annotation) + quadrants["false_negatives"].append( + {"true": true_annotation, "pred": None} + ) for pred_annotation in predicted: for true_annotation in true: if equality_fn(true_annotation, pred_annotation): break else: - counts["false_positives"].append(pred_annotation) + quadrants["false_positives"].append( + {"true": None, "pred": pred_annotation} + ) except KeyError: # Missing start or end return {"skip_class": True} - return counts + return quadrants -def sequence_labeling_counts(true, predicted, equality_fn, n_threads=5): +def sequence_labeling_quadrants(true, predicted, equality_fn, n_threads=5): """ - Return FP, FN, and TP counts + Return FP, FN, and TP quadrants """ unique_classes = _get_unique_classes(true, predicted) @@ -391,18 +407,18 @@ def sequence_labeling_counts(true, predicted, equality_fn, n_threads=5): for annotation in annotations: annotation["doc_idx"] = i - ex_counts_future = pool.submit( - single_class_single_example_counts, + ex_quadrants_future = pool.submit( + single_class_single_example_quadrants, true_cls_annotations, predicted_cls_annotations, equality_fn, ) - future_to_cls[ex_counts_future] = cls_ + future_to_cls[ex_quadrants_future] = cls_ for future in future_to_cls: cls_ = future_to_cls[future] - ex_counts = future.result() - if ex_counts.get("skip_class", False) or cls_ in d and d[cls_] is None: + ex_quadrants = future.result() + if ex_quadrants.get("skip_class", False) or cls_ in d and d[cls_] is None: # Class is skipped due to key error on equality function d[cls_] = None continue @@ -412,7 +428,7 @@ def sequence_labeling_counts(true, predicted, equality_fn, n_threads=5): "false_negatives": [], "true_positives": [], } - for key, value in ex_counts.items(): + for key, value in ex_quadrants.items(): d[cls_][key].extend(value) return d @@ -427,20 +443,20 @@ def sequence_labeling_counts(true, predicted, equality_fn, n_threads=5): SPAN_TYPE_FN_MAPPING = { - "token": sequence_labeling_token_counts, - "overlap": partial(sequence_labeling_counts, equality_fn=sequences_overlap), - "exact": partial(sequence_labeling_counts, equality_fn=sequence_exact_match), - "superset": partial(sequence_labeling_counts, equality_fn=sequence_superset), - "value": partial(sequence_labeling_counts, equality_fn=fuzzy_compare), + "token": sequence_labeling_token_quadrants, + "overlap": partial(sequence_labeling_quadrants, equality_fn=sequences_overlap), + "exact": partial(sequence_labeling_quadrants, equality_fn=sequence_exact_match), + "superset": partial(sequence_labeling_quadrants, equality_fn=sequence_superset), + "value": partial(sequence_labeling_quadrants, equality_fn=fuzzy_compare), } -def get_seq_count_fn(span_type: str | Callable = "token"): +def get_seq_quadrants_fn(span_type: str | Callable = "token"): if isinstance(span_type, str): return SPAN_TYPE_FN_MAPPING[span_type] elif callable(span_type): # Interpret span_type as an equality function - return partial(sequence_labeling_counts, equality_fn=span_type) + return partial(sequence_labeling_quadrants, equality_fn=span_type) raise ValueError( f"Invalid span_type: {span_type}. Must either be a string or a callable." @@ -524,7 +540,7 @@ def annotation_report( def get_spantype_metrics(span_type, preds, labels, field_names) -> dict[str, dict]: - counts = get_seq_count_fn(span_type)(labels, preds) + quadrants = get_seq_quadrants_fn(span_type)(labels, preds) precisions = seq_precision(labels, preds, span_type) recalls = seq_recall(labels, preds, span_type) per_class_f1s = sequence_f1(labels, preds, span_type) @@ -534,17 +550,17 @@ def get_spantype_metrics(span_type, preds, labels, field_names) -> dict[str, dic f1=(per_class_f1s[class_] or {}).get("f1-score", None), recall=recalls[class_], precision=precisions[class_], - false_positives=len(counts[class_]["false_positives"]) - if counts[class_] is not None + false_positives=len(quadrants[class_]["false_positives"]) + if quadrants[class_] is not None else None, - false_negatives=len(counts[class_]["false_negatives"]) - if counts[class_] is not None + false_negatives=len(quadrants[class_]["false_negatives"]) + if quadrants[class_] is not None else None, - true_positives=len(counts[class_]["true_positives"]) - if counts[class_] is not None + true_positives=len(quadrants[class_]["true_positives"]) + if quadrants[class_] is not None else None, ) - if class_ in counts + if class_ in quadrants else dict( f1=0.0, recall=0.0, diff --git a/sequence_metrics/wandb_tools.py b/sequence_metrics/wandb_tools.py index ecea2f6..1f94a26 100644 --- a/sequence_metrics/wandb_tools.py +++ b/sequence_metrics/wandb_tools.py @@ -7,7 +7,7 @@ EQUALITY_FN_MAP, _get_unique_classes, get_all_metrics, - get_seq_count_fn, + get_seq_quadrants_fn, ) @@ -33,12 +33,6 @@ class PredSpan(LabelSpan): } -def sequence_labeling_counts(true, predicted, equality_fn): - """ - Return FP, FN, and TP counts - """ - - def map_span(pred, tpe): output = {} for key, value in pred.items(): diff --git a/tests/test_metrics.py b/tests/test_metrics.py index ef9c657..1579147 100644 --- a/tests/test_metrics.py +++ b/tests/test_metrics.py @@ -2,7 +2,7 @@ from sequence_metrics.metrics import ( get_all_metrics, - get_seq_count_fn, + get_seq_quadrants_fn, seq_precision, seq_recall, sequence_f1, @@ -57,7 +57,7 @@ def test_overlap(a, b, expected): def check_metrics(Y, Y_pred, expected, span_type=None): - counts = get_seq_count_fn(span_type)(Y, Y_pred) + counts = get_seq_quadrants_fn(span_type)(Y, Y_pred) precisions = seq_precision(Y, Y_pred, span_type=span_type) recalls = seq_recall(Y, Y_pred, span_type=span_type) micro_f1_score = sequence_f1(Y, Y_pred, span_type=span_type, average="micro") @@ -876,7 +876,7 @@ def _same_charset(a: dict, b: dict): ], ) def test_custom_equality_fn(true, pred, expected, expected_f1): - result = get_seq_count_fn(_same_charset)(true, pred) + result = get_seq_quadrants_fn(_same_charset)(true, pred) result_subset = { k: v for k, v in result["class1"].items() @@ -887,3 +887,32 @@ def test_custom_equality_fn(true, pred, expected, expected_f1): assert len(result_subset["false_negatives"]) == expected["FN"] predicted_f1 = sequence_f1(true, pred, span_type=_same_charset, average="macro") assert abs(predicted_f1 - expected_f1) < 0.001 + + +@pytest.mark.parametrize( + "span_type", ["value", "exact", "overlap", lambda x, y: x["text"] == y["text"]] +) +def test_sequence_labeling_quadrants(span_type): + true = [ + [{"start": 0, "end": 1, "text": "a", "label": "class1", "other_key": "true_a"}], + [{"start": 0, "end": 1, "text": "b", "label": "class1", "other_key": "true_b"}], + ] + pred = [ + [{"start": 0, "end": 1, "text": "a", "label": "class1", "other_key": "pred_a"}], + [{"start": 0, "end": 1, "text": "b", "label": "class1", "other_key": "pred_b"}], + ] + quadrants = get_seq_quadrants_fn(span_type=span_type)(true, pred) + assert quadrants.keys() == {"class1"} + for quadrant in ["true_positives", "false_positives", "false_negatives"]: + assert isinstance(quadrants["class1"][quadrant], list) + for instance in quadrants["class1"][quadrant]: + assert instance.keys() == {"true", "pred"} + if quadrant == "true_positives": + assert instance["true"] is not None and instance["pred"] is not None + else: + assert instance["true"] is None or instance["pred"] is None + for key in ["true", "pred"]: + # Assert that all keys are preserved. + pred_or_label = instance[key] + assert "other_key" in pred_or_label + assert pred_or_label["other_key"] == f"{key}_{pred_or_label['text']}" diff --git a/tests/test_metrics_robust.py b/tests/test_metrics_robust.py index c9ab674..597a2cf 100644 --- a/tests/test_metrics_robust.py +++ b/tests/test_metrics_robust.py @@ -2,7 +2,7 @@ from sequence_metrics.metrics import ( get_all_metrics, - get_seq_count_fn, + get_seq_quadrants_fn, seq_precision, seq_recall, sequence_f1, @@ -40,7 +40,7 @@ def all_combos(true, pred): assert f1 is None else: assert isinstance(f1, float) - counts = get_seq_count_fn(span_type)(true, pred) + counts = get_seq_quadrants_fn(span_type)(true, pred) assert set(counts.keys()) == set(classes) f1_by_class = sequence_f1(true, pred, span_type=span_type) precision_by_class = seq_precision(true, pred, span_type=span_type)