diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
new file mode 100644
index 0000000..e1c6fd2
--- /dev/null
+++ b/.pre-commit-config.yaml
@@ -0,0 +1,17 @@
+repos:
+- repo: https://github.com/matthorgan/pre-commit-conventional-commits
+  rev: 20fb9631be1385998138432592d0b6d4dfa38fc9
+  hooks:
+    - id: conventional-commit-check
+      stages:
+        - commit-msg
+- repo: https://github.com/pycqa/isort
+  rev: 5.12.0
+  hooks:
+    - id: isort
+      name: isort (python)
+      args: ["--profile", "black"]
+- repo: https://github.com/psf/black
+  rev: 23.11.0
+  hooks:
+    - id: black
diff --git a/sequence_metrics/metrics.py b/sequence_metrics/metrics.py
index 3558342..9eff6b8 100644
--- a/sequence_metrics/metrics.py
+++ b/sequence_metrics/metrics.py
@@ -184,7 +184,7 @@ def micro_f1(true, predicted, span_type="token"):
     count_fn = get_seq_count_fn(span_type)
     class_counts = count_fn(true, predicted)
     TP, FP, FN = 0, 0, 0
-    for cls_, counts in class_counts.items():
+    for counts in class_counts.values():
         FN += len(counts["false_negatives"])
         TP += len(counts["true_positives"])
         FP += len(counts["false_positives"])
@@ -246,7 +246,7 @@ def strip_whitespace(y):
 
 
 def _norm_text(text: str) -> str:
-    return re.sub(r"[^A-Za-z0-9]", "", text)
+    return re.sub(r"[^A-Za-z0-9]", "", text).lower()
 
 
 def fuzzy_compare(x: dict, y: dict) -> bool:
@@ -422,3 +422,97 @@ def annotation_report(
     ]
     report += row_fmt.format(last_line_heading, *averages, width=width, digits=digits)
     return report
+
+
+def get_spantype_metrics(span_type, preds, labels, field_names) -> dict[str, dict]:
+    counts = get_seq_count_fn(span_type)(labels, preds)
+    precisions = seq_precision(labels, preds, span_type)
+    recalls = seq_recall(labels, preds, span_type)
+    per_class_f1s = sequence_f1(labels, preds, span_type)
+    return {
+        class_: (
+            dict(
+                f1=per_class_f1s[class_].get("f1-score"),
+                recall=recalls[class_],
+                precision=precisions[class_],
+                false_positives=len(counts[class_]["false_positives"]),
+                false_negatives=len(counts[class_]["false_negatives"]),
+                true_positives=len(counts[class_]["true_positives"]),
+            )
+            if class_ in counts
+            else dict(
+                f1=0.0,
+                recall=0.0,
+                precision=0.0,
+                false_positives=0,
+                false_negatives=0,
+                true_positives=0,
+            )
+        )
+        for class_ in field_names
+    }
+
+
+def weighted_mean(value, weights):
+    if sum(weights) == 0.0:
+        return 0.0
+    return sum(v * w for v, w in zip(value, weights)) / sum(weights)
+
+
+def mean(value: list):
+    if sum(value) == 0:
+        return 0.0
+    return sum(value) / len(value)
+
+
+def summary_metrics(metrics):
+    summary = {}
+    for span_type, span_metrics in metrics.items():
+        span_type_summary = {}
+        f1 = []
+        precision = []
+        recall = []
+        weight = []
+        TP = 0
+        FP = 0
+        FN = 0
+        for cls_metrics in span_metrics.values():
+            f1.append(cls_metrics["f1"])
+            precision.append(cls_metrics["precision"])
+            recall.append(cls_metrics["recall"])
+            TP += cls_metrics["true_positives"]
+            FP += cls_metrics["false_positives"]
+            FN += cls_metrics["false_negatives"]
+            weight.append(
+                cls_metrics["true_positives"] + cls_metrics["false_negatives"]
+            )
+        span_type_summary["macro_f1"] = mean(f1)
+        span_type_summary["macro_precision"] = mean(precision)
+        span_type_summary["macro_recall"] = mean(recall)
+
+        span_type_summary["micro_precision"] = calc_precision(TP, FP)
+        span_type_summary["micro_recall"] = calc_recall(TP, FN)
+        span_type_summary["micro_f1"] = calc_f1(
+            span_type_summary["micro_recall"], span_type_summary["micro_precision"]
+        )
+
+        span_type_summary["weighted_f1"] = weighted_mean(f1, weight)
span_type_summary["weighted_precision"] = weighted_mean(precision, weight) + span_type_summary["weighted_recall"] = weighted_mean(recall, weight) + summary[span_type] = span_type_summary + + return summary + + +def get_all_metrics(preds, labels, field_names=None): + if field_names is None: + field_names = sorted(set(l["label"] for li in (labels + preds) for l in li)) + detailed_metrics = dict() + for span_type in ["token", "overlap", "exact", "superset", "value"]: + detailed_metrics[span_type] = get_spantype_metrics( + span_type=span_type, preds=preds, labels=labels, field_names=field_names + ) + return { + "summary_metrics": summary_metrics(detailed_metrics), + "class_metrics": detailed_metrics, + } diff --git a/tests/test_metrics.py b/tests/test_metrics.py index f03d9ea..7f395f3 100644 --- a/tests/test_metrics.py +++ b/tests/test_metrics.py @@ -1,6 +1,7 @@ import pytest from sequence_metrics.metrics import ( + get_all_metrics, get_seq_count_fn, seq_precision, seq_recall, @@ -82,41 +83,81 @@ def check_metrics(Y, Y_pred, expected, span_type=None): for cls_ in counts: for metric in counts[cls_]: assert len(counts[cls_][metric]) == expected[cls_][metric] - assert recalls[cls_] == pytest.approx(expected[cls_]["recall"], abs=1e-3) - assert per_class_f1s[cls_]["f1-score"] == pytest.approx( - expected[cls_]["f1-score"], abs=1e-3 - ) - assert precisions[cls_] == pytest.approx( - expected[cls_]["precision"], abs=1e-3 - ) - - assert micro_f1_score == pytest.approx(expected["micro-f1"], abs=1e-3) - assert weighted_f1 == pytest.approx(expected["weighted-f1"], abs=1e-3) - assert macro_f1 == pytest.approx(expected["macro-f1"], abs=1e-3) + assert recalls[cls_] == pytest.approx(expected[cls_]["recall"], abs=1e-3) + assert per_class_f1s[cls_]["f1-score"] == pytest.approx( + expected[cls_]["f1-score"], abs=1e-3 + ) + assert precisions[cls_] == pytest.approx(expected[cls_]["precision"], abs=1e-3) + + assert micro_f1_score == pytest.approx(expected["micro_f1"], abs=1e-3) + assert weighted_f1 == pytest.approx(expected["weighted_f1"], abs=1e-3) + assert macro_f1 == pytest.approx(expected["macro_f1"], abs=1e-3) + + all_metrics = get_all_metrics(preds=Y_pred, labels=Y) + cls_metrics = all_metrics["class_metrics"][span_type] + for cls_, metrics in cls_metrics.items(): + assert metrics["false_positives"] == expected[cls_]["false_positives"] + assert metrics["false_negatives"] == expected[cls_]["false_negatives"] + assert metrics["true_positives"] == expected[cls_]["true_positives"] + + assert metrics["recall"] == pytest.approx(expected[cls_]["recall"], abs=1e-3) + assert metrics["f1"] == pytest.approx(expected[cls_]["f1-score"], abs=1e-3) + assert precisions[cls_] == pytest.approx(metrics["precision"], abs=1e-3) + summary_metrics = all_metrics["summary_metrics"][span_type] + assert summary_metrics["micro_f1"] == pytest.approx(expected["micro_f1"], abs=1e-3) + assert summary_metrics["weighted_f1"] == pytest.approx( + expected["weighted_f1"], abs=1e-3 + ) + assert summary_metrics["macro_f1"] == pytest.approx(expected["macro_f1"], abs=1e-3) + assert summary_metrics["micro_recall"] == pytest.approx( + expected["micro_recall"], abs=1e-3 + ) + assert summary_metrics["weighted_recall"] == pytest.approx( + expected["weighted_recall"], abs=1e-3 + ) + assert summary_metrics["macro_recall"] == pytest.approx( + expected["macro_recall"], abs=1e-3 + ) + assert summary_metrics["micro_precision"] == pytest.approx( + expected["micro_precision"], abs=1e-3 + ) + assert summary_metrics["weighted_precision"] == pytest.approx( + 
expected["weighted_precision"], abs=1e-3 + ) + assert summary_metrics["macro_precision"] == pytest.approx( + expected["macro_precision"], abs=1e-3 + ) -def test_token_incorrect(): +@pytest.mark.parametrize("span_type", ["overlap", "exact", "superset", "value"]) +def test_incorrect(span_type): text = "Alert: Pepsi Company stocks are up today April 5, 2010 and no one profited." expected = { "entity": { - "false_positives": 10, - "false_negatives": 20, + "false_positives": 10 if span_type == "token" else 10, + "false_negatives": 20 if span_type == "token" else 10, "true_positives": 0, "precision": 0.0, "recall": 0.0, "f1-score": 0.0, }, "date": { - "false_positives": 10, - "false_negatives": 40, + "false_positives": 10 if span_type == "token" else 10, + "false_negatives": 40 if span_type == "token" else 10, "true_positives": 0, "precision": 0.0, "recall": 0.0, "f1-score": 0.0, }, - "micro-f1": 0.0, - "macro-f1": 0.0, - "weighted-f1": 0.0, + "micro_f1": 0.0, + "macro_f1": 0.0, + "weighted_f1": 0.0, + "micro_precision": 0.0, + "macro_precision": 0.0, + "weighted_precision": 0.0, + "micro_recall": 0.0, + "macro_recall": 0.0, + "weighted_recall": 0.0, } y_true = extend_label( text, @@ -134,7 +175,7 @@ def test_token_incorrect(): ], 10, ) - check_metrics(y_true, y_false_pos, expected, span_type="token") + check_metrics(y_true, y_false_pos, expected, span_type=span_type) def test_token_correct(): @@ -156,9 +197,15 @@ def test_token_correct(): "recall": 1.0, "f1-score": 1.0, }, - "micro-f1": 1.0, - "macro-f1": 1.0, - "weighted-f1": 1.0, + "micro_f1": 1.0, + "macro_f1": 1.0, + "weighted_f1": 1.0, + "micro_precision": 1.0, + "macro_precision": 1.0, + "weighted_precision": 1.0, + "micro_recall": 1.0, + "macro_recall": 1.0, + "weighted_recall": 1.0, } y_true = extend_label( text, @@ -215,9 +262,15 @@ def test_token_mixed(): "recall": 0.5, "f1-score": 0.6153, }, - "micro-f1": 0.6, - "macro-f1": 0.593, - "weighted-f1": 0.601, + "micro_f1": 0.6, + "macro_f1": 0.593, + "weighted_f1": 0.601, + "micro_precision": 0.75, + "macro_precision": 0.7333, + "weighted_precision": 0.7555, + "micro_recall": 0.5, + "macro_recall": 0.5, + "weighted_recall": 0.5, } check_metrics( y_true, @@ -273,9 +326,15 @@ def test_token_mixed_2(): "recall": 0.2, "f1-score": 0.302, }, - "micro-f1": 0.409, - "macro-f1": 0.437, - "weighted-f1": 0.392, + "micro_f1": 0.409, + "macro_f1": 0.437, + "weighted_f1": 0.392, + "micro_precision": 0.642, + "macro_precision": 0.641, + "weighted_precision": 0.632, + "micro_recall": 0.3, + "macro_recall": 0.35, + "weighted_recall": 0.3, } check_metrics( y_true, @@ -285,71 +344,6 @@ def test_token_mixed_2(): ) -def test_seq_incorrect(): - text = "Alert: Pepsi Company stocks are up today April 5, 2010 and no one profited." 
-    y_true = extend_label(
-        text,
-        [
-            {"start": 7, "end": 20, "label": "entity"},
-            {"start": 41, "end": 54, "label": "date"},
-        ],
-        10,
-    )
-    y_false_pos = extend_label(
-        text,
-        [
-            {"start": 21, "end": 28, "label": "entity"},
-            {"start": 62, "end": 65, "label": "date"},
-        ],
-        10,
-    )
-    seq_expected_incorrect = {
-        "entity": {
-            "false_positives": 10,
-            "false_negatives": 10,
-            "true_positives": 0,
-            "precision": 0.0,
-            "recall": 0.0,
-            "f1-score": 0.0,
-        },
-        "date": {
-            "false_positives": 10,
-            "false_negatives": 10,
-            "true_positives": 0,
-            "precision": 0.0,
-            "recall": 0.0,
-            "f1-score": 0.0,
-        },
-        "micro-f1": 0.0,
-        "macro-f1": 0.0,
-        "weighted-f1": 0.0,
-    }
-
-    # Overlap
-    check_metrics(
-        y_true,
-        y_false_pos,
-        seq_expected_incorrect,
-        span_type="overlap",
-    )
-
-    # Exact
-    check_metrics(
-        y_true,
-        y_false_pos,
-        seq_expected_incorrect,
-        span_type="exact",
-    )
-
-    # Superset
-    check_metrics(
-        y_true,
-        y_false_pos,
-        seq_expected_incorrect,
-        span_type="superset",
-    )
-
-
 @pytest.mark.parametrize(
     "overlapping",
     [
         [
             {"start": 7, "end": 23, "label": "entity"},
             {"start": 45, "end": 47, "label": "date"},
         ],
         [
             {"start": 7, "end": 23, "label": "entity"},
             {"start": 45, "end": 47, "label": "date"},
             {"start": 7, "end": 23, "label": "entity"},
             {"start": 45, "end": 47, "label": "date"},
@@ -368,7 +362,7 @@
         ],
     ],
 )
-def test_seq_mixed(overlapping):
+def test_seq_mixed_overlap(overlapping):
     text = "Alert: Pepsi Company stocks are up today April 5, 2010 and no one profited."
     expected = {
         "entity": {
@@ -387,9 +381,15 @@
             "recall": 0.6,
             "f1-score": 0.6,
         },
-        "micro-f1": 0.6,
-        "macro-f1": 0.6,
-        "weighted-f1": 0.6,
+        "micro_f1": 0.6,
+        "macro_f1": 0.6,
+        "weighted_f1": 0.6,
+        "micro_precision": 0.6,
+        "macro_precision": 0.6,
+        "weighted_precision": 0.6,
+        "micro_recall": 0.6,
+        "macro_recall": 0.6,
+        "weighted_recall": 0.6,
     }
     y_true = extend_label(
         text,
@@ -434,6 +434,20 @@
             {"start": 38, "end": 60, "label": "date"},
         ],
     ),
+    (
+        "value",
+        [
+            {"start": 7, "end": 20, "label": "entity"},
+            {"start": 41, "end": 54, "label": "date"},
+        ],
+    ),
+    (
+        "value",
+        [
+            {"start": 6, "end": 21, "label": "entity"},
+            {"start": 40, "end": 54, "label": "date"},
+        ],
+    ),
 ],
 )
 def test_mixed_overlap(span_type, true_positive_non_exact):
@@ -472,9 +486,15 @@
             "recall": 0.3,
             "f1-score": 0.353,
         },
-        "micro-f1": 0.4864,
-        "macro-f1": 0.476,
-        "weighted-f1": 0.476,
+        "micro_f1": 0.4864,
+        "macro_f1": 0.476,
+        "weighted_f1": 0.476,
+        "micro_precision": 0.52941,
+        "macro_precision": 0.51428,
+        "weighted_precision": 0.51428,
+        "micro_recall": 0.45,
+        "macro_recall": 0.45,
+        "weighted_recall": 0.45,
     }
     check_metrics(
         extend_label(text, y_true, 10),
@@ -508,9 +528,15 @@
             "recall": 0.0,
             "f1-score": 0.0,
         },
-        "micro-f1": 0.66666,  # Calculated as the harmonic mean of Recall = 1, Precision = 0.5
-        "macro-f1": 0.5,
-        "weighted-f1": 1.0,  # because there is no support for class2
+        "micro_f1": 0.66666,  # Calculated as the harmonic mean of Recall = 1, Precision = 0.5
+        "macro_f1": 0.5,
+        "weighted_f1": 1.0,  # because there is no support for class2
+        "micro_precision": 0.5,
+        "macro_precision": 0.5,
+        "weighted_precision": 1.0,
+        "micro_recall": 1.0,
+        "macro_recall": 0.5,
+        "weighted_recall": 1.0,
     }
     check_metrics(
         [y_true],
@@ -544,9 +570,15 @@
             "recall": 0.0,
             "f1-score": 0.0,
         },
-        "micro-f1": 0.66666,  # Calculated as the harmonic mean of Recall = 1, Precision = 0.5
-        "macro-f1": 0.5,
-        "weighted-f1": 1.0,  # because there is no support for class2
+        "micro_f1": 0.66666,  # Calculated as the harmonic mean of Recall = 1, Precision = 0.5
+        "macro_f1": 0.5,
"weighted_f1": 1.0, # because there is no support for class2 + "micro_precision": 0.5, + "macro_precision": 0.5, + "weighted_precision": 1.0, + "micro_recall": 1.0, + "macro_recall": 0.5, + "weighted_recall": 1.0, } check_metrics( [y_true], @@ -572,9 +604,15 @@ def test_overlapping_1_class(): "recall": 1.0, "f1-score": 1.0, }, - "micro-f1": 1.0, - "macro-f1": 1.0, - "weighted-f1": 1.0, + "micro_f1": 1.0, + "macro_f1": 1.0, + "weighted_f1": 1.0, + "micro_precision": 1.0, + "macro_precision": 1.0, + "weighted_precision": 1.0, + "micro_recall": 1.0, + "macro_recall": 1.0, + "weighted_recall": 1.0, } check_metrics( [y_true], @@ -602,9 +640,15 @@ def test_2_class(): "recall": 1.0, "f1-score": 1.0, }, - "micro-f1": 1.0, - "macro-f1": 1.0, - "weighted-f1": 1.0, + "micro_f1": 1.0, + "macro_f1": 1.0, + "weighted_f1": 1.0, + "micro_precision": 1.0, + "macro_precision": 1.0, + "weighted_precision": 1.0, + "micro_recall": 1.0, + "macro_recall": 1.0, + "weighted_recall": 1.0, } for span_type in ["overlap", "superset"]: check_metrics( @@ -615,7 +659,10 @@ def test_2_class(): ) -def test_whitespace(): +@pytest.mark.parametrize( + "span_type", ["overlap", "exact", "superset", "value", "token"] +) +def test_whitespace(span_type): x = "a and b" y_true = [{"start": 0, "end": 7, "text": x, "label": "class1"}] y_pred = [ @@ -623,6 +670,83 @@ def test_whitespace(): ] expected = { "class1": { + "false_positives": 0, + "false_negatives": 0, + "true_positives": 3 if span_type == "token" else 1, + "precision": 1.0, + "recall": 1.0, + "f1-score": 1.0, + }, + "micro_f1": 1.0, + "macro_f1": 1.0, + "weighted_f1": 1.0, + "micro_precision": 1.0, + "macro_precision": 1.0, + "weighted_precision": 1.0, + "micro_recall": 1.0, + "macro_recall": 1.0, + "weighted_recall": 1.0, + } + check_metrics( + [y_true], + [y_pred], + expected=expected, + span_type=span_type, + ) + + +def test_class_filtering_get_all_metrics(): + y_true = [{"start": 0, "end": 7, "text": "a and b", "label": "class1"}] + y_pred = [ + {"start": 0, "end": 7, "text": "a and b", "label": "class1"}, + {"start": 8, "end": 17, "text": "b", "label": "class2"}, + ] + all_metrics = get_all_metrics( + preds=[y_pred], labels=[y_true], field_names=["class1"] + ) + for span_type, metrics in all_metrics["class_metrics"].items(): + assert len(metrics.keys()) == 1 + summary_metrics = all_metrics["summary_metrics"][span_type] + assert all( + [sm == 1.0 for sm in summary_metrics.values()] + ), summary_metrics # All 1 because we have only one class. 
+
+
+@pytest.mark.parametrize(
+    "pred",
+    [
+        {
+            "start": 10,
+            "end": 16,
+            "text": "friday",
+            "label": "label1",
+        },
+        {
+            "start": 10,
+            "end": 17,
+            "text": "friday ",
+            "label": "label1",
+        },
+        {
+            "start": 10,
+            "end": 17,
+            "text": "Fri-day",
+            "label": "label1",
+        },
+    ],
+)
+def test_value_metrics(pred):
+    y_true = [
+        {
+            "start": 5,
+            "end": 11,
+            "text": "Friday",
+            "label": "label1",
+        }
+    ]
+    y_pred = [pred]
+    expected = {
+        "label1": {
@@ -630,14 +754,137 @@
             "false_positives": 0,
             "false_negatives": 0,
             "true_positives": 1,
             "precision": 1.0,
             "recall": 1.0,
             "f1-score": 1.0,
         },
-        "micro-f1": 1.0,
-        "macro-f1": 1.0,
-        "weighted-f1": 1.0,
+        "micro_f1": 1.0,
+        "macro_f1": 1.0,
+        "weighted_f1": 1.0,
+        "micro_precision": 1.0,
+        "macro_precision": 1.0,
+        "weighted_precision": 1.0,
+        "micro_recall": 1.0,
+        "macro_recall": 1.0,
+        "weighted_recall": 1.0,
     }
-    for span_type in ["superset", "overlap", "exact"]:
-        check_metrics(
-            [y_true],
-            [y_pred],
-            expected=expected,
-            span_type=span_type,
-        )
+    check_metrics(
+        [y_true],
+        [y_pred],
+        expected=expected,
+        span_type="value",
+    )
+
+
+def verify_all_metrics_structure(all_metrics, classes):
+    span_types = ["token", "overlap", "exact", "superset", "value"]
+    assert len(all_metrics.keys()) == 2
+    summary_metrics = all_metrics["summary_metrics"]
+    assert len(summary_metrics.keys()) == len(span_types)
+    for span_type in span_types:
+        assert len(summary_metrics[span_type].keys()) == 9
+        for metric in [
+            "macro_f1",
+            "macro_precision",
+            "macro_recall",
+            "micro_precision",
+            "micro_recall",
+            "micro_f1",
+            "weighted_f1",
+            "weighted_precision",
+            "weighted_recall",
+        ]:
+            assert isinstance(summary_metrics[span_type][metric], float)
+    class_metrics = all_metrics["class_metrics"]
+    assert len(class_metrics) == len(span_types)
+    for span_type in span_types:
+        assert len(class_metrics[span_type]) == len(classes)
+        for cls_, metrics in class_metrics[span_type].items():
+            assert cls_ in classes
+            assert len(metrics.keys()) == 6
+            for metric in ["f1", "precision", "recall"]:
+                assert isinstance(metrics[metric], float)
+            for metric in ["false_positives", "true_positives", "false_negatives"]:
+                assert isinstance(metrics[metric], int)
+
+
+def test_get_all_metrics():
+    text = "Alert: Pepsi Company stocks are up today April 5, 2010 and no one profited."
+
+    y_mixed = extend_label(
+        text,
+        [
+            {"start": 21, "end": 28, "label": "entity"},
+            {"start": 62, "end": 65, "label": "date"},
+        ],
+        5,
+    ) + extend_label(
+        text,
+        [
+            {"start": 7, "end": 20, "label": "entity"},
+            {"start": 41, "end": 54, "label": "date"},
+        ],
+        5,
+    )
+    y_true = extend_label(
+        text,
+        [
+            {"start": 7, "end": 20, "label": "entity"},
+            {"start": 41, "end": 54, "label": "date"},
+        ],
+        10,
+    )
+
+    all_metrics = get_all_metrics(preds=y_mixed, labels=y_true)
+    verify_all_metrics_structure(all_metrics=all_metrics, classes=["entity", "date"])
+
+
+def test_get_all_metrics_missing_class():
+    y_true = [[{"start": 0, "end": 7, "text": "a and b", "label": "entity"}]]
+    span_types = ["token", "overlap", "exact", "superset", "value"]
+
+    all_metrics_2_classes = get_all_metrics(
+        preds=y_true, labels=y_true, field_names=["entity", "date"]
+    )
+    verify_all_metrics_structure(
+        all_metrics=all_metrics_2_classes, classes=["entity", "date"]
+    )
+
+    all_metrics_1_classes = get_all_metrics(
+        preds=y_true, labels=y_true, field_names=["entity"]
+    )
+    verify_all_metrics_structure(all_metrics=all_metrics_1_classes, classes=["entity"])
+
+    for m in ["f1", "precision", "recall"]:
+        # Macro should differ - but micro and weighted should be exactly the same.
+        for span_type in span_types:
+            assert (
+                all_metrics_2_classes["summary_metrics"][span_type][f"macro_{m}"]
+                == 0.5
+                * all_metrics_1_classes["summary_metrics"][span_type][f"macro_{m}"]
+            )
+            assert (
+                all_metrics_2_classes["summary_metrics"][span_type][f"micro_{m}"]
+                == all_metrics_1_classes["summary_metrics"][span_type][f"micro_{m}"]
+            )
+            assert (
+                all_metrics_2_classes["summary_metrics"][span_type][f"weighted_{m}"]
+                == all_metrics_1_classes["summary_metrics"][span_type][f"weighted_{m}"]
+            )
+
+
+@pytest.mark.parametrize("classes", [[], ["date"]])
+def test_all_metrics_no_class_match(classes):
+    y_true = [[{"start": 0, "end": 7, "text": "a and b", "label": "entity"}]]
+    all_metrics_0_classes = get_all_metrics(
+        preds=y_true, labels=y_true, field_names=classes
+    )
+    verify_all_metrics_structure(all_metrics=all_metrics_0_classes, classes=classes)
+
+
+@pytest.mark.parametrize("classes", [[], ["date"]])
+def test_empty_preds_metrics(classes):
+    all_metrics = get_all_metrics(preds=[], labels=[], field_names=classes)
+    verify_all_metrics_structure(all_metrics=all_metrics, classes=classes)
+
+
+@pytest.mark.parametrize("classes", [[], ["date"]])
+def test_empty_docs_metrics(classes):
+    all_metrics = get_all_metrics(preds=[[]], labels=[[]], field_names=classes)
+    verify_all_metrics_structure(all_metrics=all_metrics, classes=classes)
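For reference, a minimal usage sketch of the `get_all_metrics` entry point added in `sequence_metrics/metrics.py` above; the documents and span values are invented for illustration, and only the keys and return structure visible in this diff are assumed.

```python
# Illustrative only: call the new get_all_metrics API from this PR.
# The label/prediction documents below are made up for demonstration.
from sequence_metrics.metrics import get_all_metrics

labels = [  # one list of gold spans per document
    [{"start": 7, "end": 20, "text": "Pepsi Company", "label": "entity"}],
    [{"start": 41, "end": 54, "text": "April 5, 2010", "label": "date"}],
]
preds = [  # one list of predicted spans per document
    [{"start": 7, "end": 20, "text": "Pepsi Company", "label": "entity"}],
    [],  # the date span was missed in the second document
]

metrics = get_all_metrics(preds=preds, labels=labels)

# Both "summary_metrics" and "class_metrics" are keyed by span type:
# "token", "overlap", "exact", "superset", "value".
print(metrics["summary_metrics"]["exact"]["micro_f1"])
print(metrics["class_metrics"]["exact"]["date"]["false_negatives"])
```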