From 191e21f424faad0c7cae0a1d9e588f380f991c5a Mon Sep 17 00:00:00 2001 From: Ashwin Vaidya Date: Wed, 2 Oct 2024 13:46:23 +0200 Subject: [PATCH] Add PIMO (#2329) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * PIMO (#1726) * update Signed-off-by: jpcbertoldo <24547377+jpcbertoldo@users.noreply.github.com> * test binclf curves numpy and numba and fixes Signed-off-by: jpcbertoldo <24547377+jpcbertoldo@users.noreply.github.com> * correct some docstrings Signed-off-by: jpcbertoldo <24547377+jpcbertoldo@users.noreply.github.com> * torch interface and tests Signed-off-by: jpcbertoldo <24547377+jpcbertoldo@users.noreply.github.com> * torch interface and tests Signed-off-by: jpcbertoldo <24547377+jpcbertoldo@users.noreply.github.com> * constants regrouped in dataclass as class vars Signed-off-by: jpcbertoldo <24547377+jpcbertoldo@users.noreply.github.com> * result class was unnecessary for per_image_binclf_curve Signed-off-by: jpcbertoldo <24547377+jpcbertoldo@users.noreply.github.com> * factorize function _get_threshs_minmax_linspace Signed-off-by: jpcbertoldo <24547377+jpcbertoldo@users.noreply.github.com> * small docs fixes Signed-off-by: jpcbertoldo <24547377+jpcbertoldo@users.noreply.github.com> * add pimo numpy version and test Signed-off-by: jpcbertoldo <24547377+jpcbertoldo@users.noreply.github.com> * move validation Signed-off-by: jpcbertoldo <24547377+jpcbertoldo@users.noreply.github.com> * add `shared_fpr_metric` option Signed-off-by: jpcbertoldo <24547377+jpcbertoldo@users.noreply.github.com> * add pimo torch functional version and test Signed-off-by: jpcbertoldo <24547377+jpcbertoldo@users.noreply.github.com> * add torchmetrics interface and test Signed-off-by: jpcbertoldo <24547377+jpcbertoldo@users.noreply.github.com> * renames and put things in init Signed-off-by: jpcbertoldo <24547377+jpcbertoldo@users.noreply.github.com> * validate inputs in result objects Signed-off-by: jpcbertoldo <24547377+jpcbertoldo@users.noreply.github.com> * result objects to/from dict and tests Signed-off-by: jpcbertoldo <24547377+jpcbertoldo@users.noreply.github.com> * add save and load methods to result objects and test Signed-off-by: jpcbertoldo <24547377+jpcbertoldo@users.noreply.github.com> * refactor validations and minor changes Signed-off-by: jpcbertoldo <24547377+jpcbertoldo@users.noreply.github.com> * test result objects' properties Signed-off-by: jpcbertoldo <24547377+jpcbertoldo@users.noreply.github.com> * minor refactors Signed-off-by: jpcbertoldo <24547377+jpcbertoldo@users.noreply.github.com> * add missing docstrings Signed-off-by: jpcbertoldo <24547377+jpcbertoldo@users.noreply.github.com> * minor vocabulary fix for consistency Signed-off-by: jpcbertoldo <24547377+jpcbertoldo@users.noreply.github.com> * add per image scores statistics and test it Signed-off-by: jpcbertoldo <24547377+jpcbertoldo@users.noreply.github.com> * refactor constants notation Signed-off-by: jpcbertoldo <24547377+jpcbertoldo@users.noreply.github.com> * add stats tests and test it Signed-off-by: jpcbertoldo <24547377+jpcbertoldo@users.noreply.github.com> * change the meaning of AUPIMO.num_thresh Signed-off-by: jpcbertoldo <24547377+jpcbertoldo@users.noreply.github.com> * interface to format pairwise test results Signed-off-by: jpcbertoldo <24547377+jpcbertoldo@users.noreply.github.com> * improve doc Signed-off-by: jpcbertoldo <24547377+jpcbertoldo@users.noreply.github.com> * add optional `paths` to result objects and some minor fixes and refactors Signed-off-by: jpcbertoldo
<24547377+jpcbertoldo@users.noreply.github.com> * remove frozen from dataclasses and some done todos Signed-off-by: jpcbertoldo <24547377+jpcbertoldo@users.noreply.github.com> * review headers Signed-off-by: jpcbertoldo <24547377+jpcbertoldo@users.noreply.github.com> * doc modifs Signed-off-by: jpcbertoldo <24547377+jpcbertoldo@users.noreply.github.com> * refactor `score_less_than_thresh` in `_binclf_one_curve_python` Signed-off-by: jpcbertoldo <24547377+jpcbertoldo@users.noreply.github.com> * correct license comments Signed-off-by: jpcbertoldo <24547377+jpcbertoldo@users.noreply.github.com> * fix doc Signed-off-by: jpcbertoldo <24547377+jpcbertoldo@users.noreply.github.com> * numba as extra requirement Signed-off-by: jpcbertoldo <24547377+jpcbertoldo@users.noreply.github.com> * refactor copyrights from jpcbertoldo Signed-off-by: jpcbertoldo <24547377+jpcbertoldo@users.noreply.github.com> * remove from __future__ import annotations Signed-off-by: jpcbertoldo <24547377+jpcbertoldo@users.noreply.github.com> * refactor validations names Signed-off-by: jpcbertoldo <24547377+jpcbertoldo@users.noreply.github.com> * dedupe file path validation Signed-off-by: jpcbertoldo <24547377+jpcbertoldo@users.noreply.github.com> * fix tests Signed-off-by: jpcbertoldo <24547377+jpcbertoldo@users.noreply.github.com> * Add todo Signed-off-by: jpcbertoldo <24547377+jpcbertoldo@users.noreply.github.com> * refactor enums Signed-off-by: jpcbertoldo <24547377+jpcbertoldo@users.noreply.github.com> * only logger.warning Signed-off-by: jpcbertoldo <24547377+jpcbertoldo@users.noreply.github.com> * refactor test imports Signed-off-by: jpcbertoldo <24547377+jpcbertoldo@users.noreply.github.com> * refactor docs Signed-off-by: jpcbertoldo <24547377+jpcbertoldo@users.noreply.github.com> * refactor some docs Signed-off-by: jpcbertoldo <24547377+jpcbertoldo@users.noreply.github.com> * correct pre-commit errors Signed-off-by: jpcbertoldo <24547377+jpcbertoldo@users.noreply.github.com> * remove author tag Signed-off-by: jpcbertoldo <24547377+jpcbertoldo@users.noreply.github.com> * add third party program Signed-off-by: jpcbertoldo <24547377+jpcbertoldo@users.noreply.github.com> * Update src/anomalib/metrics/per_image/pimo.py * move HAS_NUMBA Signed-off-by: jpcbertoldo <24547377+jpcbertoldo@users.noreply.github.com> * remove PIMOSharedFPRMetric Signed-off-by: jpcbertoldo <24547377+jpcbertoldo@users.noreply.github.com> * make torchmetrics compute avg by default Signed-off-by: jpcbertoldo <24547377+jpcbertoldo@users.noreply.github.com> * pre-commit hooks corrections Signed-off-by: jpcbertoldo <24547377+jpcbertoldo@users.noreply.github.com> * correct numpy.trapezoid Signed-off-by: jpcbertoldo <24547377+jpcbertoldo@users.noreply.github.com> --------- Signed-off-by: jpcbertoldo <24547377+jpcbertoldo@users.noreply.github.com> Co-authored-by: Samet Akcay * 🗑️ Remove numba (#2313) * remove numba Signed-off-by: Ashwin Vaidya * fix pre-commit checks Signed-off-by: Ashwin Vaidya * add third-party-programs.txt Signed-off-by: Ashwin Vaidya --------- Signed-off-by: Ashwin Vaidya * 🗑️ Remove unused methods (#2315) * remove numba Signed-off-by: Ashwin Vaidya * fix pre-commit checks Signed-off-by: Ashwin Vaidya * remove all unused methods Signed-off-by: Ashwin Vaidya --------- Signed-off-by: Ashwin Vaidya * PIMO: Port Numpy → Torch (#2316) * remove numba Signed-off-by: Ashwin Vaidya * fix pre-commit checks Signed-off-by: Ashwin Vaidya * remove all unused methods Signed-off-by: Ashwin Vaidya * replace numpy with torch Signed-off-by: Ashwin
Vaidya --------- Signed-off-by: Ashwin Vaidya * 🔨Refactor methods across files (#2321) * remove numba Signed-off-by: Ashwin Vaidya * fix pre-commit checks Signed-off-by: Ashwin Vaidya * remove all unused methods Signed-off-by: Ashwin Vaidya * replace numpy with torch Signed-off-by: Ashwin Vaidya * refactor code Signed-off-by: Ashwin Vaidya * refactor move functional inside update remove path from the metric * Add changes from comments Signed-off-by: Ashwin Vaidya --------- Signed-off-by: Ashwin Vaidya * Remove model to model comparison (#2325) * rename to pimo Signed-off-by: Ashwin Vaidya * minor refactor Signed-off-by: Ashwin Vaidya * remove model to model comparison Signed-off-by: Ashwin Vaidya * fix test Signed-off-by: Ashwin Vaidya * PR comments Signed-off-by: Ashwin Vaidya * Minor refactor Signed-off-by: Ashwin Vaidya --------- Signed-off-by: Ashwin Vaidya * PR comments Signed-off-by: Ashwin Vaidya * Remove unused enums Signed-off-by: Ashwin Vaidya * update doc strings Signed-off-by: Ashwin Vaidya * update param names Signed-off-by: Ashwin Vaidya * add aupimo basic usage tutorial notebook (#2330) * add aupimo basic usage tutorial notebook Signed-off-by: jpcbertoldo <24547377+jpcbertoldo@users.noreply.github.com> * update scipy import Signed-off-by: jpcbertoldo <24547377+jpcbertoldo@users.noreply.github.com> * add cite us Signed-off-by: jpcbertoldo <24547377+jpcbertoldo@users.noreply.github.com> * minor Signed-off-by: jpcbertoldo <24547377+jpcbertoldo@users.noreply.github.com> * modify texts and add illustration Signed-off-by: jpcbertoldo <24547377+jpcbertoldo@users.noreply.github.com> * update working dir Signed-off-by: jpcbertoldo <24547377+jpcbertoldo@users.noreply.github.com> --------- Signed-off-by: jpcbertoldo <24547377+jpcbertoldo@users.noreply.github.com> --------- Signed-off-by: jpcbertoldo <24547377+jpcbertoldo@users.noreply.github.com> Signed-off-by: Ashwin Vaidya Co-authored-by: Joao P C Bertoldo <24547377+jpcbertoldo@users.noreply.github.com> Co-authored-by: Samet Akcay --- notebooks/700_metrics/701a_aupimo.ipynb | 549 ++++++++++++++ notebooks/700_metrics/roc_pro_pimo.svg | 690 ++++++++++++++++++ src/anomalib/data/utils/path.py | 14 +- src/anomalib/metrics/__init__.py | 3 + src/anomalib/metrics/pimo/__init__.py | 23 + src/anomalib/metrics/pimo/_validate.py | 427 +++++++++++ .../pimo/binary_classification_curve.py | 334 +++++++++ src/anomalib/metrics/pimo/dataclasses.py | 226 ++++++ src/anomalib/metrics/pimo/functional.py | 355 +++++++++ src/anomalib/metrics/pimo/pimo.py | 296 ++++++++ src/anomalib/metrics/pimo/utils.py | 19 + tests/unit/data/utils/test_path.py | 6 + tests/unit/metrics/pimo/__init__.py | 8 + .../pimo/test_binary_classification_curve.py | 423 +++++++++++ tests/unit/metrics/pimo/test_pimo.py | 368 ++++++++++ third-party-programs.txt | 4 + 16 files changed, 3744 insertions(+), 1 deletion(-) create mode 100644 notebooks/700_metrics/701a_aupimo.ipynb create mode 100644 notebooks/700_metrics/roc_pro_pimo.svg create mode 100644 src/anomalib/metrics/pimo/__init__.py create mode 100644 src/anomalib/metrics/pimo/_validate.py create mode 100644 src/anomalib/metrics/pimo/binary_classification_curve.py create mode 100644 src/anomalib/metrics/pimo/dataclasses.py create mode 100644 src/anomalib/metrics/pimo/functional.py create mode 100644 src/anomalib/metrics/pimo/pimo.py create mode 100644 src/anomalib/metrics/pimo/utils.py create mode 100644 tests/unit/metrics/pimo/__init__.py create mode 100644 tests/unit/metrics/pimo/test_binary_classification_curve.py create mode
100644 tests/unit/metrics/pimo/test_pimo.py diff --git a/notebooks/700_metrics/701a_aupimo.ipynb b/notebooks/700_metrics/701a_aupimo.ipynb new file mode 100644 index 0000000000..e6333df6df --- /dev/null +++ b/notebooks/700_metrics/701a_aupimo.ipynb @@ -0,0 +1,549 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# AUPIMO\n", + "\n", + "Basic usage of the metric AUPIMO (pronounced \"a-u-pee-mo\")." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "\n", + "# What is AUPIMO?\n", + "\n", + "The `Area Under the Per-Image Overlap [curve]` (AUPIMO) is a metric of recall (higher is better) designed for visual anomaly detection.\n", + "\n", + "Inspired by the [ROC](https://en.wikipedia.org/wiki/Receiver_operating_characteristic) and [PRO](https://link.springer.com/article/10.1007/s11263-020-01400-4) curves, \n", + "\n", + "> AUPIMO is the area under a curve of True Positive Rate (TPR or _recall_) as a function of False Positive Rate (FPR) restricted to a fixed range. \n", + "\n", + "But:\n", + "- the TPR (Y-axis) is *per-image* (1 image = 1 curve/score);\n", + "- the FPR (X-axis) considers the (average of) **normal** images only; \n", + "- the FPR (X-axis) is in log scale and its range is [1e-5, 1e-4]\\* (harder detection task!).\n", + "\n", + "\\* The score (the area under the curve) is normalized to be in [0, 1].\n", + "\n", + "AUPIMO can be interpreted as\n", + "\n", + "> average segmentation recall in an image given that the model (nearly) does not yield false positives in normal images.\n", + "\n", + "References in the last cell.\n", + "\n", + "![AUROC vs. AUPRO vs. AUPIMO](./roc_pro_pimo.svg)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# Setup" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Install `anomalib` using `pip`." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# TODO(jpcbertoldo): replace by `pip install anomalib` when AUPIMO is released # noqa: TD003\n", "%pip install ../.." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Change the directory to have access to the datasets." ] }, { "cell_type": "code", "execution_count": 15, "metadata": {}, "outputs": [], "source": [ "from pathlib import Path\n", "\n", "# NOTE: Provide the path to the dataset root directory.\n", "# If the dataset is not downloaded, it will be downloaded\n", "# to this directory.\n", "dataset_root = Path.cwd().parent.parent / \"datasets\" / \"MVTec\"" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Imports" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import numpy as np\n", "import torch\n", "from matplotlib import pyplot as plt\n", "from matplotlib.ticker import MaxNLocator, PercentFormatter\n", "from scipy import stats\n", "\n", "from anomalib import TaskType\n", "from anomalib.data import MVTec\n", "from anomalib.engine import Engine\n", "from anomalib.metrics import AUPIMO\n", "from anomalib.models import Padim" ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [], "source": [ "%matplotlib inline" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Data Module\n", "\n", "We will use the Leather dataset from MVTec AD. 
\n", + "\n", + "> See the notebooks below for more details on datamodules. \n", + "> [github.com/openvinotoolkit/anomalib/tree/main/notebooks/100_datamodules]((https://github.com/openvinotoolkit/anomalib/tree/main/notebooks/100_datamodules))" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": {}, + "outputs": [], + "source": [ + "task = TaskType.SEGMENTATION\n", + "datamodule = MVTec(\n", + " root=dataset_root,\n", + " category=\"leather\",\n", + " image_size=256,\n", + " train_batch_size=32,\n", + " eval_batch_size=32,\n", + " num_workers=8,\n", + " task=task,\n", + ")" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Model\n", + "\n", + "We will use `PaDiM` (performance is not the best, but it is fast to train).\n", + "\n", + "> See the notebooks below for more details on models. \n", + "> [github.com/openvinotoolkit/anomalib/tree/main/notebooks/200_models](https://github.com/openvinotoolkit/anomalib/tree/main/notebooks/200_models)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Instantiate the model." + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": {}, + "outputs": [], + "source": [ + "model = Padim(\n", + " # only use one layer to speed it up\n", + " layers=[\"layer1\"],\n", + " n_features=32,\n", + " backbone=\"resnet18\",\n", + " pre_trained=True,\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Average AUPIMO (Basic)\n", + "\n", + "The easiest way to use AUPIMO is via the collection of pixel metrics in the engine.\n", + "\n", + "By default, the average AUPIMO is calculated." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "engine = Engine(\n", + " pixel_metrics=\"AUPIMO\", # others can be added\n", + " accelerator=\"auto\", # \\<\"cpu\", \"gpu\", \"tpu\", \"ipu\", \"hpu\", \"auto\">,\n", + " devices=1,\n", + " logger=False,\n", + ")\n", + "engine.fit(datamodule=datamodule, model=model)" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "F1Score class exists for backwards compatibility. It will be removed in v1.1. Please use BinaryF1Score from torchmetrics instead\n", + "Metric `AUPIMO` will save all targets and predictions in buffer. For large datasets this may lead to large memory footprint.\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "58335955473a43dab43e586caf66aa11", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Testing: | | 0/? 
[00:00┏━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┓\n", + "┃ Test metric ┃ DataLoader 0 ┃\n", + "┑━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━┩\n", + "β”‚ image_AUROC β”‚ 0.9735053777694702 β”‚\n", + "β”‚ image_F1Score β”‚ 0.9518716335296631 β”‚\n", + "β”‚ pixel_AUPIMO β”‚ 0.6273086756193275 β”‚\n", + "β””β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”΄β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”˜\n", + "\n" + ], + "text/plain": [ + "┏━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┓\n", + "┃\u001b[1m \u001b[0m\u001b[1m Test metric \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1m DataLoader 0 \u001b[0m\u001b[1m \u001b[0m┃\n", + "┑━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━┩\n", + "β”‚\u001b[36m \u001b[0m\u001b[36m image_AUROC \u001b[0m\u001b[36m \u001b[0mβ”‚\u001b[35m \u001b[0m\u001b[35m 0.9735053777694702 \u001b[0m\u001b[35m \u001b[0mβ”‚\n", + "β”‚\u001b[36m \u001b[0m\u001b[36m image_F1Score \u001b[0m\u001b[36m \u001b[0mβ”‚\u001b[35m \u001b[0m\u001b[35m 0.9518716335296631 \u001b[0m\u001b[35m \u001b[0mβ”‚\n", + "β”‚\u001b[36m \u001b[0m\u001b[36m pixel_AUPIMO \u001b[0m\u001b[36m \u001b[0mβ”‚\u001b[35m \u001b[0m\u001b[35m 0.6273086756193275 \u001b[0m\u001b[35m \u001b[0mβ”‚\n", + "β””β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”΄β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”˜\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/plain": [ + "[{'pixel_AUPIMO': 0.6273086756193275,\n", + " 'image_AUROC': 0.9735053777694702,\n", + " 'image_F1Score': 0.9518716335296631}]" + ] + }, + "execution_count": 8, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "# will output the AUPIMO score on the test set\n", + "engine.test(datamodule=datamodule, model=model)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Individual AUPIMO Scores (Detailed)\n", + "\n", + "AUPIMO assigns one recall score per anomalous image in the dataset.\n", + "\n", + "It is possible to access each of the individual AUPIMO scores and look at the distribution." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Collect the predictions and the ground truth." + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "ckpt_path is not provided. Model weights will not be loaded.\n", + "F1Score class exists for backwards compatibility. It will be removed in v1.1. Please use BinaryF1Score from torchmetrics instead\n", + "Metric `AUPIMO` will save all targets and predictions in buffer. For large datasets this may lead to large memory footprint.\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "678cb90805ee4b7bb1dd0c30944edab9", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Predicting: | | 0/? 
[00:00" ] }, "execution_count": 14, "metadata": {}, "output_type": "execute_result" } ], "source": [ "fig, ax = plt.subplots()\n", "ax.hist(aupimo_result.aupimos.numpy(), bins=np.linspace(0, 1, 11), edgecolor=\"black\")\n", "ax.set_ylabel(\"Count (number of images)\")\n", "ax.yaxis.set_major_locator(MaxNLocator(5, integer=True))\n", "ax.set_xlim(0, 1)\n", "ax.set_xlabel(\"AUPIMO [%]\")\n", "ax.xaxis.set_major_formatter(PercentFormatter(1))\n", "ax.grid()\n", "ax.set_title(\"AUPIMO distribution\")\n", "fig # noqa: B018, RUF100" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# Cite Us\n", "\n", "AUPIMO was developed during Google Summer of Code 2023 (GSoC 2023) with the `anomalib` team from OpenVINO Toolkit.\n", "\n", "Our work was accepted to the British Machine Vision Conference 2024 (BMVC 2024).\n", "\n", "```bibtex\n", "@misc{bertoldo2024aupimo,\n", " title={{AUPIMO: Redefining Visual Anomaly Detection Benchmarks with High Speed and Low Tolerance}}, \n", " author={Joao P. C. Bertoldo and Dick Ameln and Ashwin Vaidya and Samet Akçay},\n", " year={2024},\n", " eprint={2401.01984},\n", " archivePrefix={arXiv},\n", " primaryClass={cs.CV},\n", " url={https://arxiv.org/abs/2401.01984}, \n", "}\n", "```\n", "\n", "Paper on arXiv: [arxiv.org/abs/2401.01984](https://arxiv.org/abs/2401.01984) (accepted to BMVC 2024)\n", "\n", "Medium post: [medium.com/p/c653ac30e802](https://medium.com/p/c653ac30e802)\n", "\n", "Official repository: [github.com/jpcbertoldo/aupimo](https://github.com/jpcbertoldo/aupimo) (numpy-only API and numba-accelerated versions available)\n", "\n", "GSoC 2023 page: [summerofcode.withgoogle.com/archive/2023/projects/SPMopugd](https://summerofcode.withgoogle.com/archive/2023/projects/SPMopugd)" ] } ], "metadata": { "kernelspec": { "display_name": "anomalib-dev", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.10.14" }, "orig_nbformat": 4 }, "nbformat": 4, "nbformat_minor": 2 +} diff --git a/notebooks/700_metrics/roc_pro_pimo.svg b/notebooks/700_metrics/roc_pro_pimo.svg new file mode 100644 index 0000000000..b580e89d17 --- /dev/null +++ b/notebooks/700_metrics/roc_pro_pimo.svg @@ -0,0 +1,690 @@ +[SVG illustration. Text content: "Each curve summarizes the test set with different aggregations." Curves: ROC, PRO, PIMO (one per image!); areas under the curves: AUROC, AUPRO, AUPIMO; axis label: Recall.] diff --git a/src/anomalib/data/utils/path.py b/src/anomalib/data/utils/path.py index 9c3f56273b..7bc61b27fe 100644 --- a/src/anomalib/data/utils/path.py +++ b/src/anomalib/data/utils/path.py @@ -142,13 +142,20 @@ def contains_non_printable_characters(path: str | Path) -> bool: return not printable_pattern.match(str(path)) -def validate_path(path: str | Path, base_dir: str | Path | None = None, should_exist: bool = True) -> Path: +def validate_path( + path: str | Path, + base_dir: str | Path | None = None, + should_exist: bool = True, + extensions: tuple[str, ...] | None = None, +) -> Path: """Validate the path. Args: path (str | Path): Path to validate. base_dir (str | Path): Base directory to restrict file access. should_exist (bool): If True, raise an exception if the path does not exist. 
+ extensions (tuple[str, ...] | None): Accepted extensions for the path. An exception is raised if the + path does not have one of the accepted extensions. If None, no check is performed. Defaults to None. Returns: Path: Validated path. @@ -213,6 +220,11 @@ def validate_path(path: str | Path, base_dir: str | Path | None = None, should_e msg = f"Read or execute permissions denied for the path: {path}" raise PermissionError(msg) + # Check if the path has one of the accepted extensions + if extensions is not None and path.suffix not in extensions: + msg = f"Path extension is not accepted. Accepted extensions: {extensions}. Path: {path}" + raise ValueError(msg) + return path diff --git a/src/anomalib/metrics/__init__.py b/src/anomalib/metrics/__init__.py index 4c3eafa811..81bab3c93f 100644 --- a/src/anomalib/metrics/__init__.py +++ b/src/anomalib/metrics/__init__.py @@ -19,6 +19,7 @@ from .f1_max import F1Max from .f1_score import F1Score from .min_max import MinMax +from .pimo import AUPIMO, PIMO from .precision_recall_curve import BinaryPrecisionRecallCurve from .pro import PRO from .threshold import F1AdaptiveThreshold, ManualThreshold @@ -35,6 +36,8 @@ "ManualThreshold", "MinMax", "PRO", + "PIMO", + "AUPIMO", ] logger = logging.getLogger(__name__) diff --git a/src/anomalib/metrics/pimo/__init__.py b/src/anomalib/metrics/pimo/__init__.py new file mode 100644 index 0000000000..174f546e4d --- /dev/null +++ b/src/anomalib/metrics/pimo/__init__.py @@ -0,0 +1,23 @@ +"""Per-Image Metrics.""" + +# Original Code +# https://github.com/jpcbertoldo/aupimo +# +# Modified +# Copyright (C) 2024 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +from .binary_classification_curve import ThresholdMethod +from .pimo import AUPIMO, PIMO, AUPIMOResult, PIMOResult + +__all__ = [ + # constants + "ThresholdMethod", + # result classes + "PIMOResult", + "AUPIMOResult", + # torchmetrics interfaces + "PIMO", + "AUPIMO", +] diff --git a/src/anomalib/metrics/pimo/_validate.py b/src/anomalib/metrics/pimo/_validate.py new file mode 100644 index 0000000000..f0ba7af4bf --- /dev/null +++ b/src/anomalib/metrics/pimo/_validate.py @@ -0,0 +1,427 @@ +"""Utils for validating arguments and results. + +TODO(jpcbertoldo): Move validations to a common place and reuse them across the codebase. +https://github.com/openvinotoolkit/anomalib/issues/2093 +""" + +# Original Code +# https://github.com/jpcbertoldo/aupimo +# +# Modified +# Copyright (C) 2024 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +import logging + +import torch +from torch import Tensor + +from .utils import images_classes_from_masks + +logger = logging.getLogger(__name__) + + +def is_num_thresholds_gte2(num_thresholds: int) -> None: + """Validate that the number of thresholds is an integer >= 2.""" + if not isinstance(num_thresholds, int): + msg = f"Expected the number of thresholds to be an integer, but got {type(num_thresholds)}" + raise TypeError(msg) + + if num_thresholds < 2: + msg = f"Expected the number of thresholds to be larger than 1, but got {num_thresholds}" + raise ValueError(msg) + + +def is_same_shape(*args) -> None: + """Validate that all arguments have the same shape (works for tensors and ndarrays).""" + assert len(args) > 0 + shapes = sorted({tuple(arg.shape) for arg in args}) + if len(shapes) > 1: + msg = f"Expected arguments to have the same shape, but got {shapes}" + raise ValueError(msg) + + +def is_rate(rate: float | int, zero_ok: bool, one_ok: bool) -> None: + """Validates a rate parameter. 
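+
+    A valid rate lies in the closed interval [0, 1]; `zero_ok` and `one_ok` control whether the
+    endpoint values 0 and 1 are themselves accepted.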
+ + Args: + rate (float | int): The rate to be validated. + zero_ok (bool): Flag indicating if rate can be 0. + one_ok (bool): Flag indicating if rate can be 1. + """ + if not isinstance(rate, float | int): + msg = f"Expected rate to be a float or int, but got {type(rate)}." + raise TypeError(msg) + + if rate < 0.0 or rate > 1.0: + msg = f"Expected rate to be in [0, 1], but got {rate}." + raise ValueError(msg) + + if not zero_ok and rate == 0.0: + msg = "Rate cannot be 0." + raise ValueError(msg) + + if not one_ok and rate == 1.0: + msg = "Rate cannot be 1." + raise ValueError(msg) + + +def is_rate_range(bounds: tuple[float, float]) -> None: + """Validates that the bounds define a valid range of rates. + + Args: + bounds (tuple[float, float]): The lower and upper bounds of the rates. + """ + if not isinstance(bounds, tuple): + msg = f"Expected the bounds to be a tuple, but got {type(bounds)}" + raise TypeError(msg) + + if len(bounds) != 2: + msg = f"Expected the bounds to be a tuple of length 2, but got {len(bounds)}" + raise ValueError(msg) + + lower, upper = bounds + is_rate(lower, zero_ok=False, one_ok=False) + is_rate(upper, zero_ok=False, one_ok=True) + + if lower >= upper: + msg = f"Expected the upper bound to be larger than the lower bound, but got {upper=} <= {lower=}" + raise ValueError(msg) + + +def is_valid_threshold(thresholds: Tensor) -> None: + """Validate that the thresholds are valid and monotonically increasing.""" + if not isinstance(thresholds, Tensor): + msg = f"Expected thresholds to be a Tensor, but got {type(thresholds)}" + raise TypeError(msg) + + if thresholds.ndim != 1: + msg = f"Expected thresholds to be 1D, but got {thresholds.ndim}" + raise ValueError(msg) + + if not thresholds.dtype.is_floating_point: + msg = f"Expected thresholds to be of float type, but got Tensor with dtype {thresholds.dtype}" + raise TypeError(msg) + + # make sure they are strictly increasing + if not torch.all(torch.diff(thresholds) > 0): + msg = "Expected thresholds to be strictly increasing, but they are not." + raise ValueError(msg) + + +def validate_threshold_bounds(threshold_bounds: tuple[float, float]) -> None: + if not isinstance(threshold_bounds, tuple): + msg = f"Expected threshold bounds to be a tuple, but got {type(threshold_bounds)}." + raise TypeError(msg) + + if len(threshold_bounds) != 2: + msg = f"Expected threshold bounds to be a tuple of length 2, but got {len(threshold_bounds)}." + raise ValueError(msg) + + lower, upper = threshold_bounds + + if not isinstance(lower, float): + msg = f"Expected lower threshold bound to be a float, but got {type(lower)}." + raise TypeError(msg) + + if not isinstance(upper, float): + msg = f"Expected upper threshold bound to be a float, but got {type(upper)}." + raise TypeError(msg) + + if upper <= lower: + msg = f"Expected the upper bound to be greater than the lower bound, but got {upper} <= {lower}." 
+ raise ValueError(msg) + + +def is_anomaly_maps(anomaly_maps: Tensor) -> None: + if anomaly_maps.ndim != 3: + msg = f"Expected anomaly maps to have 3 dimensions (N, H, W), but got {anomaly_maps.ndim} dimensions" + raise ValueError(msg) + + if not anomaly_maps.dtype.is_floating_point: + msg = ( + "Expected anomaly maps to be a floating Tensor with anomaly scores," + f" but got Tensor with dtype {anomaly_maps.dtype}" + ) + raise TypeError(msg) + + +def is_masks(masks: Tensor) -> None: + if masks.ndim != 3: + msg = f"Expected masks to have 3 dimensions (N, H, W), but got {masks.ndim} dimensions" + raise ValueError(msg) + + if masks.dtype == torch.bool: + pass + elif masks.dtype.is_floating_point: + msg = ( + "Expected masks to be an integer or boolean Tensor with ground truth labels, " + f"but got Tensor with dtype {masks.dtype}" + ) + raise TypeError(msg) + else: + # assumes the type to be (signed or unsigned) integer + # this will change with the dataclass refactor + masks_unique_vals = torch.unique(masks) + if torch.any((masks_unique_vals != 0) & (masks_unique_vals != 1)): + msg = ( + "Expected masks to be a *binary* Tensor with ground truth labels, " + f"but got Tensor with unique values {sorted(masks_unique_vals)}" + ) + raise ValueError(msg) + + +def is_binclf_curves(binclf_curves: Tensor, valid_thresholds: Tensor | None) -> None: + if binclf_curves.ndim != 4: + msg = f"Expected binclf curves to be 4D, but got {binclf_curves.ndim}D" + raise ValueError(msg) + + if binclf_curves.shape[-2:] != (2, 2): + msg = f"Expected binclf curves to have shape (..., 2, 2), but got {binclf_curves.shape}" + raise ValueError(msg) + + if binclf_curves.dtype != torch.int64: + msg = f"Expected binclf curves to have dtype int64, but got {binclf_curves.dtype}." + raise TypeError(msg) + + if (binclf_curves < 0).any(): + msg = "Expected binclf curves to have non-negative values, but got negative values." + raise ValueError(msg) + + neg = binclf_curves[:, :, 0, :].sum(axis=-1) # (num_images, num_thresholds) + + if (neg != neg[:, :1]).any(): + msg = "Expected binclf curves to have the same number of negatives per image for every threshold." + raise ValueError(msg) + + pos = binclf_curves[:, :, 1, :].sum(axis=-1) # (num_images, num_thresholds) + + if (pos != pos[:, :1]).any(): + msg = "Expected binclf curves to have the same number of positives per image for every threshold." + raise ValueError(msg) + + if valid_thresholds is None: + return + + if binclf_curves.shape[1] != valid_thresholds.shape[0]: + msg = ( + "Expected the binclf curves to have as many confusion matrices as the thresholds sequence, " + f"but got {binclf_curves.shape[1]} and {valid_thresholds.shape[0]}" + ) + raise RuntimeError(msg) + + +def is_images_classes(images_classes: Tensor) -> None: + if images_classes.ndim != 1: + msg = f"Expected image classes to be 1D, but got {images_classes.ndim}D." 
+ raise ValueError(msg) + + if images_classes.dtype == torch.bool: + pass + elif images_classes.dtype.is_floating_point: + msg = ( + "Expected image classes to be an integer or boolean Tensor with ground truth labels, " + f"but got Tensor with dtype {images_classes.dtype}" + ) + raise TypeError(msg) + else: + # assumes the type to be (signed or unsigned) integer + # this will change with the dataclass refactor + unique_vals = torch.unique(images_classes) + if torch.any((unique_vals != 0) & (unique_vals != 1)): + msg = ( + "Expected image classes to be a *binary* Tensor with ground truth labels, " + f"but got Tensor with unique values {sorted(unique_vals)}" + ) + raise ValueError(msg) + + +def is_rates(rates: Tensor, nan_allowed: bool) -> None: + if rates.ndim != 1: + msg = f"Expected rates to be 1D, but got {rates.ndim}D." + raise ValueError(msg) + + if not rates.dtype.is_floating_point: + msg = f"Expected rates to have a floating dtype, but got {rates.dtype}." + raise ValueError(msg) + + isnan_mask = torch.isnan(rates) + if nan_allowed: + # if they are all nan, then there is nothing to validate + if isnan_mask.all(): + return + valid_values = rates[~isnan_mask] + elif isnan_mask.any(): + msg = "Expected rates to not contain NaN values, but got NaN values." + raise ValueError(msg) + else: + valid_values = rates + + if (valid_values < 0).any(): + msg = "Expected rates to have values in the interval [0, 1], but got values < 0." + raise ValueError(msg) + + if (valid_values > 1).any(): + msg = "Expected rates to have values in the interval [0, 1], but got values > 1." + raise ValueError(msg) + + +def is_rate_curve(rate_curve: Tensor, nan_allowed: bool, decreasing: bool) -> None: + is_rates(rate_curve, nan_allowed=nan_allowed) + + diffs = torch.diff(rate_curve) + diffs_valid = diffs[~torch.isnan(diffs)] if nan_allowed else diffs + + if decreasing and (diffs_valid > 0).any(): + msg = "Expected rate curve to be monotonically decreasing, but got non-monotonically decreasing values." + raise ValueError(msg) + + if not decreasing and (diffs_valid < 0).any(): + msg = "Expected rate curve to be monotonically increasing, but got non-monotonically increasing values." + raise ValueError(msg) + + +def is_per_image_rate_curves(rate_curves: Tensor, nan_allowed: bool, decreasing: bool | None) -> None: + if rate_curves.ndim != 2: + msg = f"Expected per-image rate curves to be 2D, but got {rate_curves.ndim}D." + raise ValueError(msg) + + if not rate_curves.dtype.is_floating_point: + msg = f"Expected per-image rate curves to have a floating dtype, but got {rate_curves.dtype}." + raise ValueError(msg) + + isnan_mask = torch.isnan(rate_curves) + if nan_allowed: + # if they are all nan, then there is nothing to validate + if isnan_mask.all(): + return + valid_values = rate_curves[~isnan_mask] + elif isnan_mask.any(): + msg = "Expected per-image rate curves to not contain NaN values, but got NaN values." + raise ValueError(msg) + else: + valid_values = rate_curves + + if (valid_values < 0).any(): + msg = "Expected per-image rate curves to have values in the interval [0, 1], but got values < 0." + raise ValueError(msg) + + if (valid_values > 1).any(): + msg = "Expected per-image rate curves to have values in the interval [0, 1], but got values > 1." 
+ raise ValueError(msg) + + if decreasing is None: + return + + diffs = torch.diff(rate_curves, axis=1) + diffs_valid = diffs[~torch.isnan(diffs)] if nan_allowed else diffs + + if decreasing and (diffs_valid > 0).any(): + msg = ( + "Expected per-image rate curves to be monotonically decreasing, " + "but got non-monotonically decreasing values." + ) + raise ValueError(msg) + + if not decreasing and (diffs_valid < 0).any(): + msg = ( + "Expected per-image rate curves to be monotonically increasing, " + "but got non-monotonically increasing values." + ) + raise ValueError(msg) + + +def is_scores_batch(scores_batch: torch.Tensor) -> None: + """scores_batch (torch.Tensor): floating (N, D).""" + if not isinstance(scores_batch, torch.Tensor): + msg = f"Expected `scores_batch` to be a torch.Tensor, but got {type(scores_batch)}" + raise TypeError(msg) + + if not scores_batch.dtype.is_floating_point: + msg = ( + "Expected `scores_batch` to be a floating torch.Tensor with anomaly scores," + f" but got torch.Tensor with dtype {scores_batch.dtype}" + ) + raise TypeError(msg) + + if scores_batch.ndim != 2: + msg = f"Expected `scores_batch` to be 2D, but got {scores_batch.ndim}" + raise ValueError(msg) + + +def is_gts_batch(gts_batch: torch.Tensor) -> None: + """gts_batch (torch.Tensor): boolean (N, D).""" + if not isinstance(gts_batch, torch.Tensor): + msg = f"Expected `gts_batch` to be a torch.Tensor, but got {type(gts_batch)}" + raise TypeError(msg) + + if gts_batch.dtype != torch.bool: + msg = ( + "Expected `gts_batch` to be a boolean torch.Tensor with ground truth labels," + f" but got torch.Tensor with dtype {gts_batch.dtype}" + ) + raise TypeError(msg) + + if gts_batch.ndim != 2: + msg = f"Expected `gts_batch` to be 2D, but got {gts_batch.ndim}" + raise ValueError(msg) + + +def has_at_least_one_anomalous_image(masks: torch.Tensor) -> None: + is_masks(masks) + image_classes = images_classes_from_masks(masks) + if (image_classes == 1).sum() == 0: + msg = "Expected at least one ANOMALOUS image, but found none." + raise ValueError(msg) + + +def has_at_least_one_normal_image(masks: torch.Tensor) -> None: + is_masks(masks) + image_classes = images_classes_from_masks(masks) + if (image_classes == 0).sum() == 0: + msg = "Expected at least one NORMAL image, but found none." + raise ValueError(msg) + + +def joint_validate_thresholds_shared_fpr(thresholds: torch.Tensor, shared_fpr: torch.Tensor) -> None: + if thresholds.shape[0] != shared_fpr.shape[0]: + msg = ( + "Expected `thresholds` and `shared_fpr` to have the same number of elements, " + f"but got {thresholds.shape[0]} != {shared_fpr.shape[0]}" + ) + raise ValueError(msg) + + +def is_per_image_tprs(per_image_tprs: torch.Tensor, image_classes: torch.Tensor) -> None: + is_images_classes(image_classes) + # general validations + is_per_image_rate_curves( + per_image_tprs, + nan_allowed=True, # normal images have NaN TPRs + decreasing=None, # not checked here + ) + + # specific to anomalous images + is_per_image_rate_curves( + per_image_tprs[image_classes == 1], + nan_allowed=False, + decreasing=True, + ) + + # specific to normal images + normal_images_tprs = per_image_tprs[image_classes == 0] + if not normal_images_tprs.isnan().all(): + msg = "Expected all normal images to have NaN TPRs, but some have non-NaN values." + raise ValueError(msg) + + +def is_per_image_scores(per_image_scores: torch.Tensor) -> None: + if per_image_scores.ndim != 1: + msg = f"Expected per-image scores to be 1D, but got {per_image_scores.ndim}D." 
+ raise ValueError(msg) + + +def is_image_class(image_class: int) -> None: + if image_class not in {0, 1}: + msg = f"Expected image class to be either 0 for 'normal' or 1 for 'anomalous', but got {image_class}." + raise ValueError(msg) diff --git a/src/anomalib/metrics/pimo/binary_classification_curve.py b/src/anomalib/metrics/pimo/binary_classification_curve.py new file mode 100644 index 0000000000..1a80944041 --- /dev/null +++ b/src/anomalib/metrics/pimo/binary_classification_curve.py @@ -0,0 +1,334 @@ +"""Binary classification curve (torch interface over a numpy implementation). + +A binary classification (binclf) matrix (TP, FP, FN, TN) is evaluated at multiple thresholds. + +The thresholds are shared by all instances/images, but their binclf matrices are computed independently for each instance/image. +""" + +# Original Code +# https://github.com/jpcbertoldo/aupimo +# +# Modified +# Copyright (C) 2024 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +import itertools +import logging +from enum import Enum +from functools import partial + +import numpy as np +import torch + +from . import _validate + +logger = logging.getLogger(__name__) + + +class ThresholdMethod(Enum): + """Sequence of thresholds to use.""" + + GIVEN: str = "given" + MINMAX_LINSPACE: str = "minmax-linspace" + MEAN_FPR_OPTIMIZED: str = "mean-fpr-optimized" + + +def _binary_classification_curve(scores: np.ndarray, gts: np.ndarray, thresholds: np.ndarray) -> np.ndarray: + """One binary classification matrix at each threshold. + + In the case where the thresholds are given (i.e. not considering all possible thresholds based on the scores), + this weird-looking function is faster than the two options in `torchmetrics` on the CPU: + - `_binary_precision_recall_curve_update_vectorized` + - `_binary_precision_recall_curve_update_loop` + (both in module `torchmetrics.functional.classification.precision_recall_curve` in `torchmetrics==1.1.0`). + Note: VALIDATION IS NOT DONE HERE. Make sure to validate the arguments before calling this function. + + Args: + scores (np.ndarray): Anomaly scores (D,). + gts (np.ndarray): Binary (bool) ground truth of shape (D,). + thresholds (np.ndarray): Sequence of thresholds in ascending order (K,). + + Returns: + np.ndarray: Binary classification matrix curve (K, 2, 2) + Details: `anomalib.metrics.pimo.binary_classification_curve.binary_classification_curve`. 
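+
+    Example:
+        A minimal sketch with tiny 1-D arrays (illustrative values, checked by hand):
+
+        >>> scores = np.asarray([0.1, 0.4, 0.6, 0.9])
+        >>> gts = np.asarray([False, False, True, True])
+        >>> thresholds = np.asarray([0.0, 0.5, 1.0])
+        >>> curves = _binary_classification_curve(scores, gts, thresholds)
+        >>> curves.shape
+        (3, 2, 2)
+        >>> curves[:, 1, 1]  # true positives at each threshold (prediction is `score >= thresh`)
+        array([2, 2, 0])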
+ """ + num_th = len(thresholds) + + # POSITIVES + scores_positives = scores[gts] + # the sorting is very important for the algorithm to work and the speedup + scores_positives = np.sort(scores_positives) + # variable updated in the loop; start counting with lowest thresh ==> everything is predicted as positive + num_pos = current_count_tp = scores_positives.size + tps = np.empty((num_th,), dtype=np.int64) + + # NEGATIVES + # same thing but for the negative samples + scores_negatives = scores[~gts] + scores_negatives = np.sort(scores_negatives) + num_neg = current_count_fp = scores_negatives.size + fps = np.empty((num_th,), dtype=np.int64) + + def score_less_than_thresh(score: float, thresh: float) -> bool: + return score < thresh + + # it will progressively drop the scores that are below the current thresh + for thresh_idx, thresh in enumerate(thresholds): + # UPDATE POSITIVES + # < becasue it is the same as ~(>=) + num_drop = sum(1 for _ in itertools.takewhile(partial(score_less_than_thresh, thresh=thresh), scores_positives)) + scores_positives = scores_positives[num_drop:] + current_count_tp -= num_drop + tps[thresh_idx] = current_count_tp + + # UPDATE NEGATIVES + # same with the negatives + num_drop = sum(1 for _ in itertools.takewhile(partial(score_less_than_thresh, thresh=thresh), scores_negatives)) + scores_negatives = scores_negatives[num_drop:] + current_count_fp -= num_drop + fps[thresh_idx] = current_count_fp + + # deduce the rest of the matrix counts + fns = num_pos * np.ones((num_th,), dtype=np.int64) - tps + tns = num_neg * np.ones((num_th,), dtype=np.int64) - fps + + # sequence of dimensions is (thresholds, true class, predicted class) (see docstring) + return np.stack( + [ + np.stack([tns, fps], axis=-1), + np.stack([fns, tps], axis=-1), + ], + axis=-1, + ).transpose(0, 2, 1) + + +def binary_classification_curve( + scores_batch: torch.Tensor, + gts_batch: torch.Tensor, + thresholds: torch.Tensor, +) -> torch.Tensor: + """Returns a binary classification matrix at each threshold for each image in the batch. + + This is a wrapper around `_binary_classification_curve`. + Validation of the arguments is done here (not in the actual implementation functions). + + Note: predicted as positive condition is `score >= thresh`. + + Args: + scores_batch (torch.Tensor): Anomaly scores (N, D,). + gts_batch (torch.Tensor): Binary (bool) ground truth of shape (N, D,). + thresholds (torch.Tensor): Sequence of thresholds in ascending order (K,). + + Returns: + torch.Tensor: Binary classification matrix curves (N, K, 2, 2) + + The last two dimensions are the confusion matrix (ground truth, predictions) + So for each thresh it gives: + - `tp`: `[... , 1, 1]` + - `fp`: `[... , 0, 1]` + - `fn`: `[... , 1, 0]` + - `tn`: `[... , 0, 0]` + + `t` is for `true` and `f` is for `false`, `p` is for `positive` and `n` is for `negative`, so: + - `tp` stands for `true positive` + - `fp` stands for `false positive` + - `fn` stands for `false negative` + - `tn` stands for `true negative` + + The numbers in each confusion matrix are the counts (not the ratios). + + Counts are relative to each instance (i.e. from 0 to D, e.g. the total is the number of pixels in the image). + + Thresholds are shared across all instances, so all confusion matrices, for instance, + at position [:, 0, :, :] are relative to the 1st threshold in `thresholds`. + + Thresholds are sorted in ascending order. 
+ """ + _validate.is_scores_batch(scores_batch) + _validate.is_gts_batch(gts_batch) + _validate.is_same_shape(scores_batch, gts_batch) + _validate.is_valid_threshold(thresholds) + # TODO(ashwinvaidya17): this is kept as numpy for now because it is much faster. + # TEMP-0 + result = np.vectorize(_binary_classification_curve, signature="(n),(n),(k)->(k,2,2)")( + scores_batch.detach().cpu().numpy(), + gts_batch.detach().cpu().numpy(), + thresholds.detach().cpu().numpy(), + ) + return torch.from_numpy(result).to(scores_batch.device) + + +def _get_linspaced_thresholds(anomaly_maps: torch.Tensor, num_thresholds: int) -> torch.Tensor: + """Get thresholds linearly spaced between the min and max of the anomaly maps.""" + _validate.is_num_thresholds_gte2(num_thresholds) + # this operation can be a bit expensive + thresh_low, thresh_high = thresh_bounds = (anomaly_maps.min().item(), anomaly_maps.max().item()) + try: + _validate.validate_threshold_bounds(thresh_bounds) + except ValueError as ex: + msg = f"Invalid threshold bounds computed from the given anomaly maps. Cause: {ex}" + raise ValueError(msg) from ex + return torch.linspace(thresh_low, thresh_high, num_thresholds, dtype=anomaly_maps.dtype) + + +def threshold_and_binary_classification_curve( + anomaly_maps: torch.Tensor, + masks: torch.Tensor, + threshold_choice: ThresholdMethod | str = ThresholdMethod.MINMAX_LINSPACE, + thresholds: torch.Tensor | None = None, + num_thresholds: int | None = None, +) -> tuple[torch.Tensor, torch.Tensor]: + """Return thresholds and binary classification matrix at each threshold for each image in the batch. + + Args: + anomaly_maps (torch.Tensor): Anomaly score maps of shape (N, H, W) + masks (torch.Tensor): Binary ground truth masks of shape (N, H, W) + threshold_choice (str, optional): Sequence of thresholds to use. Defaults to THRESH_SEQUENCE_MINMAX_LINSPACE. + thresholds (torch.Tensor, optional): Sequence of thresholds to use. + Only applicable when threshold_choice is THRESH_SEQUENCE_GIVEN. + num_thresholds (int, optional): Number of thresholds between the min and max of the anomaly maps. + Only applicable when threshold_choice is THRESH_SEQUENCE_MINMAX_LINSPACE. + + Returns: + tuple[torch.Tensor, torch.Tensor]: + [0] Thresholds of shape (K,) and dtype is the same as `anomaly_maps.dtype`. + + [1] Binary classification matrices of shape (N, K, 2, 2) + + N: number of images/instances + K: number of thresholds + + The last two dimensions are the confusion matrix (ground truth, predictions) + So for each thresh it gives: + - `tp`: `[... , 1, 1]` + - `fp`: `[... , 0, 1]` + - `fn`: `[... , 1, 0]` + - `tn`: `[... , 0, 0]` + + `t` is for `true` and `f` is for `false`, `p` is for `positive` and `n` is for `negative`, so: + - `tp` stands for `true positive` + - `fp` stands for `false positive` + - `fn` stands for `false negative` + - `tn` stands for `true negative` + + The numbers in each confusion matrix are the counts of pixels in the image (not the ratios). + + Thresholds are shared across all images, so all confusion matrices, for instance, + at position [:, 0, :, :] are relative to the 1st threshold in `thresholds`. + + Thresholds are sorted in ascending order. 
+ """ + threshold_choice = ThresholdMethod(threshold_choice) + _validate.is_anomaly_maps(anomaly_maps) + _validate.is_masks(masks) + _validate.is_same_shape(anomaly_maps, masks) + + if threshold_choice == ThresholdMethod.GIVEN: + assert thresholds is not None + _validate.is_valid_threshold(thresholds) + if num_thresholds is not None: + logger.warning( + "Argument `num_thresholds` was given, " + f"but it is ignored because `thresholds_choice` is '{threshold_choice.value}'.", + ) + thresholds = thresholds.to(anomaly_maps.dtype) + + elif threshold_choice == ThresholdMethod.MINMAX_LINSPACE: + assert num_thresholds is not None + if thresholds is not None: + logger.warning( + "Argument `thresholds_given` was given, " + f"but it is ignored because `thresholds_choice` is '{threshold_choice.value}'.", + ) + # `num_thresholds` is validated in the function below + thresholds = _get_linspaced_thresholds(anomaly_maps, num_thresholds) + + elif threshold_choice == ThresholdMethod.MEAN_FPR_OPTIMIZED: + raise NotImplementedError(f"TODO implement {threshold_choice.value}") # noqa: EM102 + + else: + msg = ( + f"Expected `threshs_choice` to be from {list(ThresholdMethod.__members__)}," + f" but got '{threshold_choice.value}'" + ) + raise NotImplementedError(msg) + + # keep the batch dimension and flatten the rest + scores_batch = anomaly_maps.reshape(anomaly_maps.shape[0], -1) + gts_batch = masks.reshape(masks.shape[0], -1).to(bool) # make sure it is boolean + + binclf_curves = binary_classification_curve(scores_batch, gts_batch, thresholds) + + num_images = anomaly_maps.shape[0] + + try: + _validate.is_binclf_curves(binclf_curves, valid_thresholds=thresholds) + + # these two validations cannot be done in `_validate.binclf_curves` because it does not have access to the + # original shapes of `anomaly_maps` + if binclf_curves.shape[0] != num_images: + msg = ( + "Expected `binclf_curves` to have the same number of images as `anomaly_maps`, " + f"but got {binclf_curves.shape[0]} and {anomaly_maps.shape[0]}" + ) + raise RuntimeError(msg) + + except (TypeError, ValueError) as ex: + msg = f"Invalid `binclf_curves` was computed. Cause: {ex}" + raise RuntimeError(msg) from ex + + return thresholds, binclf_curves + + +def per_image_tpr(binclf_curves: torch.Tensor) -> torch.Tensor: + """True positive rates (TPR) for image for each thresh. + + TPR = TP / P = TP / (TP + FN) + + TP: true positives + FM: false negatives + P: positives (TP + FN) + + Args: + binclf_curves (torch.Tensor): Binary classification matrix curves (N, K, 2, 2). See `per_image_binclf_curve`. + + Returns: + torch.Tensor: shape (N, K), dtype float64 + N: number of images + K: number of thresholds + + Thresholds are sorted in ascending order, so TPR is in descending order. + """ + # shape: (num images, num thresholds) + tps = binclf_curves[..., 1, 1] + pos = binclf_curves[..., 1, :].sum(axis=2) # 2 was the 3 originally + + # tprs will be nan if pos == 0 (normal image), which is expected + return tps.to(torch.float64) / pos.to(torch.float64) + + +def per_image_fpr(binclf_curves: torch.Tensor) -> torch.Tensor: + """False positive rates (TPR) for image for each thresh. + + FPR = FP / N = FP / (FP + TN) + + FP: false positives + TN: true negatives + N: negatives (FP + TN) + + Args: + binclf_curves (torch.Tensor): Binary classification matrix curves (N, K, 2, 2). See `per_image_binclf_curve`. 
+ + Returns: + torch.Tensor: shape (N, K), dtype float64 + N: number of images + K: number of thresholds + + Thresholds are sorted in ascending order, so FPR is in descending order. + """ + # shape: (num images, num thresholds) + fps = binclf_curves[..., 0, 1] + neg = binclf_curves[..., 0, :].sum(axis=2) + + # it can be `nan` if an anomalous image is fully covered by the mask + return fps.to(torch.float64) / neg.to(torch.float64) diff --git a/src/anomalib/metrics/pimo/dataclasses.py b/src/anomalib/metrics/pimo/dataclasses.py new file mode 100644 index 0000000000..0c5aeb025d --- /dev/null +++ b/src/anomalib/metrics/pimo/dataclasses.py @@ -0,0 +1,226 @@ +"""Dataclasses for PIMO metrics.""" + +# Based on the code: https://github.com/jpcbertoldo/aupimo +# +# Copyright (C) 2024 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +from dataclasses import dataclass, field + +import torch + +from . import _validate, functional + + +@dataclass +class PIMOResult: + """Per-Image Overlap (PIMO, pronounced pee-mo) curve. + + This interface gathers the PIMO curve data and metadata and provides several utility methods. + + Notation: + - N: number of images + - K: number of thresholds + - FPR: False Positive Rate + - TPR: True Positive Rate + + Attributes: + thresholds (torch.Tensor): sequence of K (monotonically increasing) thresholds used to compute the PIMO curve + shared_fpr (torch.Tensor): K values of the shared FPR metric at the corresponding thresholds + per_image_tprs (torch.Tensor): for each of the N images, the K values of in-image TPR at the corresponding + thresholds + """ + + # data + thresholds: torch.Tensor = field(repr=False) # shape => (K,) + shared_fpr: torch.Tensor = field(repr=False) # shape => (K,) + per_image_tprs: torch.Tensor = field(repr=False) # shape => (N, K) + + @property + def num_thresholds(self) -> int: + """Number of thresholds.""" + return self.thresholds.shape[0] + + @property + def num_images(self) -> int: + """Number of images.""" + return self.per_image_tprs.shape[0] + + @property + def image_classes(self) -> torch.Tensor: + """Image classes (0: normal, 1: anomalous). + + Deduced from the per-image TPRs. + If any TPR value is not NaN, the image is considered anomalous. + """ + return (~torch.isnan(self.per_image_tprs)).any(dim=1).to(torch.int32) + + def __post_init__(self) -> None: + """Validate that the inputs for the result object are consistent.""" + try: + _validate.is_valid_threshold(self.thresholds) + _validate.is_rate_curve(self.shared_fpr, nan_allowed=False, decreasing=True) + _validate.is_per_image_tprs(self.per_image_tprs, self.image_classes) + + except (TypeError, ValueError) as ex: + msg = f"Invalid inputs for {self.__class__.__name__} object. Cause: {ex}." + raise TypeError(msg) from ex + + if self.thresholds.shape != self.shared_fpr.shape: + msg = ( + f"Invalid {self.__class__.__name__} object. Attributes have inconsistent shapes: " + f"{self.thresholds.shape=} != {self.shared_fpr.shape=}." + ) + raise TypeError(msg) + + if self.thresholds.shape[0] != self.per_image_tprs.shape[1]: + msg = ( + f"Invalid {self.__class__.__name__} object. Attributes have inconsistent shapes: " + f"{self.thresholds.shape[0]=} != {self.per_image_tprs.shape[1]=}." + ) + raise TypeError(msg) + + def thresh_at(self, fpr_level: float) -> tuple[int, float, float]: + """Return the threshold at the given shared FPR. + + See `anomalib.metrics.pimo.functional.thresh_at_shared_fpr_level` for details. 
+ + Args: + fpr_level (float): shared FPR level + + Returns: + tuple[int, float, float]: + [0] index of the threshold + [1] threshold + [2] the actual shared FPR value at the returned threshold + """ + return functional.thresh_at_shared_fpr_level( + self.thresholds, + self.shared_fpr, + fpr_level, + ) + + +@dataclass +class AUPIMOResult: + """Area Under the Per-Image Overlap (AUPIMO, pronounced a-u-pee-mo) curve. + + This interface gathers the AUPIMO data and metadata and provides several utility methods. + + Attributes: + fpr_lower_bound (float): [metadata] LOWER bound of the FPR integration range + fpr_upper_bound (float): [metadata] UPPER bound of the FPR integration range + num_thresholds (int): [metadata] number of thresholds used to effectively compute AUPIMO; + should not be confused with the number of thresholds used to compute the PIMO curve + thresh_lower_bound (float): LOWER threshold bound --> corresponds to the UPPER FPR bound + thresh_upper_bound (float): UPPER threshold bound --> corresponds to the LOWER FPR bound + aupimos (torch.Tensor): values of AUPIMO scores (1 per image) + """ + + # metadata + fpr_lower_bound: float + fpr_upper_bound: float + num_thresholds: int + + # data + thresh_lower_bound: float = field(repr=False) + thresh_upper_bound: float = field(repr=False) + aupimos: torch.Tensor = field(repr=False) # shape => (N,) + + @property + def num_images(self) -> int: + """Number of images.""" + return self.aupimos.shape[0] + + @property + def num_normal_images(self) -> int: + """Number of normal images.""" + return int((self.image_classes == 0).sum()) + + @property + def num_anomalous_images(self) -> int: + """Number of anomalous images.""" + return int((self.image_classes == 1).sum()) + + @property + def image_classes(self) -> torch.Tensor: + """Image classes (0: normal, 1: anomalous).""" + # if an instance has `nan` aupimo it's because it's a normal image + return self.aupimos.isnan().to(torch.int32) + + @property + def fpr_bounds(self) -> tuple[float, float]: + """Lower and upper bounds of the FPR integration range.""" + return self.fpr_lower_bound, self.fpr_upper_bound + + @property + def thresh_bounds(self) -> tuple[float, float]: + """Lower and upper bounds of the threshold integration range. + + Recall: they correspond to the FPR bounds in reverse order. + I.e.: + fpr_lower_bound --> thresh_upper_bound + fpr_upper_bound --> thresh_lower_bound + """ + return self.thresh_lower_bound, self.thresh_upper_bound + + def __post_init__(self) -> None: + """Validate that the inputs for the result object are consistent.""" + try: + _validate.is_rate_range((self.fpr_lower_bound, self.fpr_upper_bound)) + # TODO(jpcbertoldo): warn when it's too low (use parameters from the numpy code) # noqa: TD003 + _validate.is_num_thresholds_gte2(self.num_thresholds) + _validate.is_rates(self.aupimos, nan_allowed=True) + + _validate.validate_threshold_bounds((self.thresh_lower_bound, self.thresh_upper_bound)) + + except (TypeError, ValueError) as ex: + msg = f"Invalid inputs for {self.__class__.__name__} object. Cause: {ex}." + raise TypeError(msg) from ex + + @classmethod + def from_pimo_result( + cls: type["AUPIMOResult"], + pimo_result: PIMOResult, + fpr_bounds: tuple[float, float], + num_thresholds_auc: int, + aupimos: torch.Tensor, + ) -> "AUPIMOResult": + """Return an AUPIMO result object from a PIMO result object. 
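+
+        The threshold bounds are recovered from the PIMO curve via `PIMOResult.thresh_at`: the upper
+        FPR bound maps to the lower threshold bound and vice versa.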
+
+        Args:
+            pimo_result: PIMO result object
+            fpr_bounds: lower and upper bounds of the FPR integration range
+            num_thresholds_auc: number of thresholds used to effectively compute AUPIMO;
+                NOT the number of thresholds used to compute the PIMO curve!
+            aupimos: AUPIMO scores
+        """
+        if pimo_result.per_image_tprs.shape[0] != aupimos.shape[0]:
+            msg = (
+                f"Invalid {cls.__name__} object. Attributes have inconsistent shapes: "
+                f"there are {pimo_result.per_image_tprs.shape[0]} PIMO curves but {aupimos.shape[0]} AUPIMO scores."
+            )
+            raise TypeError(msg)
+
+        if not torch.isnan(aupimos[pimo_result.image_classes == 0]).all():
+            msg = "Expected all normal images to have NaN AUPIMOs, but some have non-NaN values."
+            raise TypeError(msg)
+
+        if torch.isnan(aupimos[pimo_result.image_classes == 1]).any():
+            msg = "Expected all anomalous images to have valid AUPIMOs (not NaN), but some have NaN values."
+            raise TypeError(msg)
+
+        fpr_lower_bound, fpr_upper_bound = fpr_bounds
+        # recall: fpr upper/lower bounds are the same as the thresh lower/upper bounds
+        _, thresh_lower_bound, __ = pimo_result.thresh_at(fpr_upper_bound)
+        _, thresh_upper_bound, __ = pimo_result.thresh_at(fpr_lower_bound)
+        # `_` is the threshold's index, `__` is the actual fpr value
+        return cls(
+            fpr_lower_bound=fpr_lower_bound,
+            fpr_upper_bound=fpr_upper_bound,
+            num_thresholds=num_thresholds_auc,
+            thresh_lower_bound=float(thresh_lower_bound),
+            thresh_upper_bound=float(thresh_upper_bound),
+            aupimos=aupimos,
+        )
diff --git a/src/anomalib/metrics/pimo/functional.py b/src/anomalib/metrics/pimo/functional.py
new file mode 100644
index 0000000000..7eac07b1bd
--- /dev/null
+++ b/src/anomalib/metrics/pimo/functional.py
@@ -0,0 +1,355 @@
+"""Per-Image Overlap curve (PIMO, pronounced pee-mo) and its area under the curve (AUPIMO).
+
+Details: `anomalib.metrics.pimo.pimo`.
+"""
+
+# Original Code
+# https://github.com/jpcbertoldo/aupimo
+#
+# Modified
+# Copyright (C) 2024 Intel Corporation
+# SPDX-License-Identifier: Apache-2.0
+
+import logging
+
+import numpy as np
+import torch
+
+from . import _validate
+from .binary_classification_curve import (
+    ThresholdMethod,
+    _get_linspaced_thresholds,
+    per_image_fpr,
+    per_image_tpr,
+    threshold_and_binary_classification_curve,
+)
+from .utils import images_classes_from_masks
+
+logger = logging.getLogger(__name__)
+
+
+def pimo_curves(
+    anomaly_maps: torch.Tensor,
+    masks: torch.Tensor,
+    num_thresholds: int,
+) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
+    """Compute the Per-IMage Overlap (PIMO, pronounced pee-mo) curves.
+
+    PIMO is a curve of True Positive Rate (TPR) values on each image across multiple anomaly score thresholds.
+    The anomaly score thresholds are indexed by a (cross-image shared) value of False Positive Rate (FPR) measure on
+    the normal images.
+
+    Details: `anomalib.metrics.pimo.pimo`.
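+
+    Example (a minimal, self-contained sketch with random scores; real anomaly maps come from a model):
+        >>> anomaly_maps = torch.rand(4, 64, 64)
+        >>> masks = torch.zeros(4, 64, 64, dtype=torch.int32)
+        >>> masks[2:, 16:32, 16:32] = 1  # make the last two images anomalous
+        >>> threshs, shared_fpr, per_image_tprs, classes = pimo_curves(anomaly_maps, masks, num_thresholds=100)
+        >>> classes
+        tensor([0, 0, 1, 1], dtype=torch.int32)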
+
+    Notation:
+        N: number of images
+        H: image height
+        W: image width
+        K: number of thresholds
+
+    Args:
+        anomaly_maps: floating point anomaly score maps of shape (N, H, W)
+        masks: binary (bool or int) ground truth masks of shape (N, H, W)
+        num_thresholds: number of thresholds to compute (K)
+
+    Returns:
+        tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
+            [0] thresholds of shape (K,) in ascending order
+            [1] shared FPR values of shape (K,) in descending order (indices correspond to the thresholds)
+            [2] per-image TPR curves of shape (N, K), axis 1 in descending order (indices correspond to the thresholds)
+            [3] image classes of shape (N,) with values 0 (normal) or 1 (anomalous)
+    """
+    # validate the inputs
+    _validate.is_num_thresholds_gte2(num_thresholds)
+    _validate.is_anomaly_maps(anomaly_maps)
+    _validate.is_masks(masks)
+    _validate.is_same_shape(anomaly_maps, masks)
+    _validate.has_at_least_one_anomalous_image(masks)
+    _validate.has_at_least_one_normal_image(masks)
+
+    image_classes = images_classes_from_masks(masks)
+
+    # the thresholds are computed here so that they can be restricted to the normal images,
+    # therefore getting a better resolution in terms of FPR quantization;
+    # otherwise `threshold_and_binary_classification_curve` would compute the range of thresholds
+    # from all the images (normal + anomalous)
+    thresholds = _get_linspaced_thresholds(
+        anomaly_maps[image_classes == 0],
+        num_thresholds,
+    )
+
+    # N: number of images, K: number of thresholds
+    # shapes are (K,) and (N, K, 2, 2)
+    thresholds, binclf_curves = threshold_and_binary_classification_curve(
+        anomaly_maps=anomaly_maps,
+        masks=masks,
+        threshold_choice=ThresholdMethod.GIVEN.value,
+        thresholds=thresholds,
+        num_thresholds=None,
+    )
+
+    shared_fpr: torch.Tensor
+    # mean-per-image-fpr on normal images
+    # shape -> (N, K)
+    per_image_fprs_normals = per_image_fpr(binclf_curves[image_classes == 0])
+    try:
+        _validate.is_per_image_rate_curves(per_image_fprs_normals, nan_allowed=False, decreasing=True)
+    except ValueError as ex:
+        msg = f"Cannot compute PIMO because the per-image FPR curves from normal images are invalid. Cause: {ex}"
+        raise RuntimeError(msg) from ex
+
+    # shape -> (K,)
+    # this is the only shared FPR metric implemented so far;
+    # see the note about shared FPR alternatives in `anomalib.metrics.pimo.pimo`.
+    shared_fpr = per_image_fprs_normals.mean(axis=0)
+
+    # shape -> (N, K)
+    per_image_tprs = per_image_tpr(binclf_curves)
+
+    return thresholds, shared_fpr, per_image_tprs, image_classes
+
+
+# =========================================== AUPIMO ===========================================
+
+
+def aupimo_scores(
+    anomaly_maps: torch.Tensor,
+    masks: torch.Tensor,
+    num_thresholds: int = 300_000,
+    fpr_bounds: tuple[float, float] = (1e-5, 1e-4),
+    force: bool = False,
+) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, int]:
+    """Compute the PIMO curves and their Area Under the Curve (i.e. AUPIMO) scores.
+
+    Scores are computed from the integration of the PIMO curves within the given FPR bounds, then normalized to [0, 1].
+    The score can be thought of as the average TPR of the PIMO curves within the given FPR bounds.
+
+    Details: `anomalib.metrics.pimo.pimo`.
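+
+    Example (a sketch with random scores; wide, non-default FPR bounds are used here because the toy maps
+    have too few pixels to resolve the default bounds (1e-5, 1e-4)):
+        >>> anomaly_maps = torch.rand(4, 64, 64)
+        >>> masks = torch.zeros(4, 64, 64, dtype=torch.int32)
+        >>> masks[2:, 16:32, 16:32] = 1
+        >>> *_, aupimos, num_points = aupimo_scores(
+        ...     anomaly_maps, masks, num_thresholds=10_000, fpr_bounds=(0.1, 1.0), force=True)
+        >>> aupimos.shape
+        torch.Size([4])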
+
+    Notation:
+        N: number of images
+        H: image height
+        W: image width
+        K: number of thresholds
+
+    Args:
+        anomaly_maps: floating point anomaly score maps of shape (N, H, W)
+        masks: binary (bool or int) ground truth masks of shape (N, H, W)
+        num_thresholds: number of thresholds to compute (K)
+        fpr_bounds: lower and upper bounds of the FPR integration range
+        force: whether to force the computation despite bad conditions
+
+    Returns:
+        tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, int]:
+            [0] thresholds of shape (K,) in ascending order
+            [1] shared FPR values of shape (K,) in descending order (indices correspond to the thresholds)
+            [2] per-image TPR curves of shape (N, K), axis 1 in descending order (indices correspond to the thresholds)
+            [3] image classes of shape (N,) with values 0 (normal) or 1 (anomalous)
+            [4] AUPIMO scores of shape (N,) in [0, 1]
+            [5] number of points used in the AUC integration
+    """
+    _validate.is_rate_range(fpr_bounds)
+
+    # other validations are done in the `pimo_curves()` function
+    thresholds, shared_fpr, per_image_tprs, image_classes = pimo_curves(
+        anomaly_maps=anomaly_maps,
+        masks=masks,
+        num_thresholds=num_thresholds,
+    )
+    try:
+        _validate.is_valid_threshold(thresholds)
+        _validate.is_rate_curve(shared_fpr, nan_allowed=False, decreasing=True)
+        _validate.is_images_classes(image_classes)
+        _validate.is_per_image_rate_curves(per_image_tprs[image_classes == 1], nan_allowed=False, decreasing=True)
+
+    except ValueError as ex:
+        msg = f"Cannot compute AUPIMO because the PIMO curves are invalid. Cause: {ex}"
+        raise RuntimeError(msg) from ex
+
+    fpr_lower_bound, fpr_upper_bound = fpr_bounds
+
+    # get the threshold indices where the fpr bounds are achieved
+    fpr_lower_bound_thresh_idx, _, fpr_lower_bound_defacto = thresh_at_shared_fpr_level(
+        thresholds,
+        shared_fpr,
+        fpr_lower_bound,
+    )
+    fpr_upper_bound_thresh_idx, _, fpr_upper_bound_defacto = thresh_at_shared_fpr_level(
+        thresholds,
+        shared_fpr,
+        fpr_upper_bound,
+    )
+
+    if not torch.isclose(
+        fpr_lower_bound_defacto,
+        torch.tensor(fpr_lower_bound, dtype=fpr_lower_bound_defacto.dtype, device=fpr_lower_bound_defacto.device),
+        rtol=(rtol := 1e-2),
+    ):
+        logger.warning(
+            "The lower bound of the shared FPR integration range is not exactly achieved. "
+            f"Expected {fpr_lower_bound} but got {fpr_lower_bound_defacto}, which is not within {rtol=}.",
+        )
+
+    if not torch.isclose(
+        fpr_upper_bound_defacto,
+        torch.tensor(fpr_upper_bound, dtype=fpr_upper_bound_defacto.dtype, device=fpr_upper_bound_defacto.device),
+        rtol=rtol,
+    ):
+        logger.warning(
+            "The upper bound of the shared FPR integration range is not exactly achieved. "
+            f"Expected {fpr_upper_bound} but got {fpr_upper_bound_defacto}, which is not within {rtol=}.",
+        )
+
+    # reminder: fpr lower/upper bound is threshold upper/lower bound (reversed)
+    thresh_lower_bound_idx = fpr_upper_bound_thresh_idx
+    thresh_upper_bound_idx = fpr_lower_bound_thresh_idx
+
+    # deal with edge cases
+    if thresh_lower_bound_idx >= thresh_upper_bound_idx:
+        msg = (
+            "The thresholds corresponding to the given `fpr_bounds` are not valid because "
+            "they match the same threshold or they are in the wrong order. "
+            f"FPR upper/lower = threshold lower/upper = {thresh_lower_bound_idx} and {thresh_upper_bound_idx}."
+        )
+        raise RuntimeError(msg)
+
+    # limit the curves to the integration range [lbound, ubound]
+    shared_fpr_bounded: torch.Tensor = shared_fpr[thresh_lower_bound_idx : (thresh_upper_bound_idx + 1)]
+    per_image_tprs_bounded: torch.Tensor = per_image_tprs[:, thresh_lower_bound_idx : (thresh_upper_bound_idx + 1)]
+
+    # `shared_fpr` and `tprs` are in descending order; `flip()` reverts to ascending order
+    shared_fpr_bounded = torch.flip(shared_fpr_bounded, dims=[0])
+    per_image_tprs_bounded = torch.flip(per_image_tprs_bounded, dims=[1])
+
+    # the log's base does not matter because it's a constant factor canceled by the normalization factor
+    shared_fpr_bounded_log = torch.log(shared_fpr_bounded)
+
+    # deal with edge cases
+    invalid_shared_fpr = ~torch.isfinite(shared_fpr_bounded_log)
+
+    if invalid_shared_fpr.all():
+        msg = (
+            "Cannot compute AUPIMO because the shared fpr integration range is invalid. "
+            "Try increasing the number of thresholds."
+        )
+        raise RuntimeError(msg)
+
+    if invalid_shared_fpr.any():
+        logger.warning(
+            "Some values in the shared fpr integration range are NaN. "
+            "The AUPIMO will be computed without these values.",
+        )
+
+        # get rid of NaN values by removing them from the integration range
+        shared_fpr_bounded_log = shared_fpr_bounded_log[~invalid_shared_fpr]
+        per_image_tprs_bounded = per_image_tprs_bounded[:, ~invalid_shared_fpr]
+
+    num_points_integral = int(shared_fpr_bounded_log.shape[0])
+
+    if num_points_integral <= 30:
+        msg = (
+            "Cannot compute AUPIMO because the shared fpr integration range doesn't have enough points. "
+            f"Found {num_points_integral} points in the integration range. "
+            "Try increasing `num_thresholds`."
+        )
+        if not force:
+            raise RuntimeError(msg)
+        msg += " Computation was forced!"
+        logger.warning(msg)
+
+    if num_points_integral < 300:
+        logger.warning(
+            "The AUPIMO may be inaccurate because the shared fpr integration range doesn't have enough points. "
+            f"Found {num_points_integral} points in the integration range. "
+            "Try increasing `num_thresholds`.",
+        )
+
+    aucs: torch.Tensor = torch.trapezoid(per_image_tprs_bounded, x=shared_fpr_bounded_log, dim=1)
+
+    # normalize, then clip(0, 1) makes sure that the values are in [0, 1] in case of numerical errors
+    normalization_factor = aupimo_normalizing_factor(fpr_bounds)
+    aucs = (aucs / normalization_factor).clip(0, 1)
+
+    return thresholds, shared_fpr, per_image_tprs, image_classes, aucs, num_points_integral
+
+
+# =========================================== AUX ===========================================
+
+
+def thresh_at_shared_fpr_level(
+    thresholds: torch.Tensor,
+    shared_fpr: torch.Tensor,
+    fpr_level: float,
+) -> tuple[int, float, torch.Tensor]:
+    """Return the threshold and its index at the given shared FPR level.
+
+    Three cases are possible:
+        - fpr_level == 0: the lowest threshold that achieves 0 FPR is returned
+        - fpr_level == 1: the highest threshold that achieves 1 FPR is returned
+        - 0 < fpr_level < 1: the threshold that achieves the closest (higher or lower) FPR to `fpr_level` is returned
+
+    Args:
+        thresholds: thresholds at which the shared FPR was computed.
+        shared_fpr: shared FPR values.
+        fpr_level: shared FPR value at which to get the threshold.
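+
+        Example (a sketch; `thresholds` and `shared_fpr` are assumed to come from `pimo_curves`):
+            >>> index, thresh, fpr_defacto = thresh_at_shared_fpr_level(thresholds, shared_fpr, 1e-4)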
+
+    Returns:
+        tuple[int, float, torch.Tensor]:
+            [0] index of the threshold
+            [1] threshold
+            [2] the actual shared FPR value at the returned threshold
+    """
+    _validate.is_valid_threshold(thresholds)
+    _validate.is_rate_curve(shared_fpr, nan_allowed=False, decreasing=True)
+    _validate.joint_validate_thresholds_shared_fpr(thresholds, shared_fpr)
+    _validate.is_rate(fpr_level, zero_ok=True, one_ok=True)
+
+    shared_fpr_min, shared_fpr_max = shared_fpr.min(), shared_fpr.max()
+
+    if fpr_level < shared_fpr_min:
+        msg = (
+            "Invalid `fpr_level` because it's out of the range of `shared_fpr` = "
+            f"[{shared_fpr_min}, {shared_fpr_max}], and got {fpr_level}."
+        )
+        raise ValueError(msg)
+
+    if fpr_level > shared_fpr_max:
+        msg = (
+            "Invalid `fpr_level` because it's out of the range of `shared_fpr` = "
+            f"[{shared_fpr_min}, {shared_fpr_max}], and got {fpr_level}."
+        )
+        raise ValueError(msg)
+
+    # fpr_level == 0 or 1 are special cases
+    # because there may be multiple solutions, and the chosen index should be their MINIMUM/MAXIMUM respectively
+    if fpr_level == 0.0:
+        index = torch.min(torch.where(shared_fpr == fpr_level)[0])
+
+    elif fpr_level == 1.0:
+        index = torch.max(torch.where(shared_fpr == fpr_level)[0])
+
+    else:
+        index = torch.argmin(torch.abs(shared_fpr - fpr_level))
+
+    index = int(index)
+    fpr_level_defacto = shared_fpr[index]
+    thresh = thresholds[index]
+    return index, thresh, fpr_level_defacto
+
+
+def aupimo_normalizing_factor(fpr_bounds: tuple[float, float]) -> float:
+    """Constant that normalizes the AUPIMO integral to the [0, 1] range.
+
+    It is the maximum possible value from the integral in AUPIMO's definition.
+    It corresponds to assuming a constant function T_i: thresh --> 1.
+
+    Args:
+        fpr_bounds: lower and upper bounds of the FPR integration range.
+
+    Returns:
+        float: the normalization factor (>0).
+    """
+    _validate.is_rate_range(fpr_bounds)
+    fpr_lower_bound, fpr_upper_bound = fpr_bounds
+    # the log's base must be the same as the one used in the integration!
+    return float(np.log(fpr_upper_bound / fpr_lower_bound))
diff --git a/src/anomalib/metrics/pimo/pimo.py b/src/anomalib/metrics/pimo/pimo.py
new file mode 100644
index 0000000000..9703b60b59
--- /dev/null
+++ b/src/anomalib/metrics/pimo/pimo.py
@@ -0,0 +1,296 @@
+"""Per-Image Overlap curve (PIMO, pronounced pee-mo) and its area under the curve (AUPIMO).
+
+# PIMO
+
+PIMO is a curve of True Positive Rate (TPR) values on each image across multiple anomaly score thresholds.
+The anomaly score thresholds are indexed by a (shared) value of a False Positive Rate (FPR) measure on
+the normal images.
+
+Each *anomalous* image has its own curve such that the X-axis is shared by all of them.
+
+At a given threshold:
+    X-axis: Shared FPR (may vary)
+        1. Log of the Average of per-image FPR on normal images.
+        SEE NOTE BELOW.
+    Y-axis: per-image TP Rate (TPR), or "Overlap" between the ground truth and the predicted masks.
+
+*** Note about other shared FPR alternatives ***
+The shared FPR metric can be made harder by using the cross-image max (or high-percentile) FPRs instead of the mean.
+Rationale: this will further punish models that have exceptional FPs in normal images.
+So far there is only one shared FPR metric implemented but others will be added in the future.
+
+# AUPIMO
+
+`AUPIMO` is the area under each `PIMO` curve with bounded integration range in terms of shared FPR.
+
+# Disclaimer
+
+This module implements torchmetrics interfaces around the functional implementation in `functional.py`.
+Inputs are validated there, and the results are eventually wrapped in dataclass objects
+(see `PIMOResult` and `AUPIMOResult` in `dataclasses.py`).
+"""
+
+# Original Code
+# https://github.com/jpcbertoldo/aupimo
+#
+# Modified
+# Copyright (C) 2024 Intel Corporation
+# SPDX-License-Identifier: Apache-2.0
+
+import logging
+
+import torch
+from torchmetrics import Metric
+
+from . import _validate, functional
+from .dataclasses import AUPIMOResult, PIMOResult
+
+logger = logging.getLogger(__name__)
+
+
+class PIMO(Metric):
+    """Per-IMage Overlap (PIMO, pronounced pee-mo) curves.
+
+    This torchmetrics interface is a wrapper around the functional interface `pimo_curves()`;
+    the accumulated tensors are validated there and the results are wrapped in a dataclass object.
+
+    PIMO is a curve of True Positive Rate (TPR) values on each image across multiple anomaly score thresholds.
+    The anomaly score thresholds are indexed by a (cross-image shared) value of False Positive Rate (FPR) measure on
+    the normal images.
+
+    Details: `anomalib.metrics.pimo.pimo`.
+
+    Notation:
+        N: number of images
+        H: image height
+        W: image width
+        K: number of thresholds
+
+    Attributes:
+        anomaly_maps: floating point anomaly score maps of shape (N, H, W)
+        masks: binary (bool or int) ground truth masks of shape (N, H, W)
+
+    Args:
+        num_thresholds: number of thresholds to compute (K)
+
+    Returns:
+        PIMOResult: PIMO curves dataclass object. See `PIMOResult` for details.
+    """
+
+    is_differentiable: bool = False
+    higher_is_better: bool | None = None
+    full_state_update: bool = False
+
+    num_thresholds: int
+
+    anomaly_maps: list[torch.Tensor]
+    masks: list[torch.Tensor]
+
+    @property
+    def _is_empty(self) -> bool:
+        """Return True if the metric has not been updated yet."""
+        return len(self.anomaly_maps) == 0
+
+    @property
+    def num_images(self) -> int:
+        """Number of images."""
+        return sum(am.shape[0] for am in self.anomaly_maps)
+
+    @property
+    def image_classes(self) -> torch.Tensor:
+        """Image classes (0: normal, 1: anomalous)."""
+        return functional.images_classes_from_masks(self.masks)
+
+    def __init__(self, num_thresholds: int) -> None:
+        """Per-Image Overlap (PIMO) curve.
+
+        Args:
+            num_thresholds: number of thresholds used to compute the PIMO curve (K)
+        """
+        super().__init__()
+
+        logger.warning(
+            f"Metric `{self.__class__.__name__}` will save all targets and predictions in buffer."
+            " For large datasets this may lead to large memory footprint.",
+        )
+
+        # the options below are, redundantly, validated here to avoid reaching
+        # an error later in the execution
+
+        _validate.is_num_thresholds_gte2(num_thresholds)
+        self.num_thresholds = num_thresholds
+
+        self.add_state("anomaly_maps", default=[], dist_reduce_fx="cat")
+        self.add_state("masks", default=[], dist_reduce_fx="cat")
+
+    def update(self, anomaly_maps: torch.Tensor, masks: torch.Tensor) -> None:
+        """Update lists of anomaly maps and masks.
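+
+        Example (a shape sketch only; real maps and masks come from a model and a dataset):
+            >>> metric = PIMO(num_thresholds=10_000)
+            >>> metric.update(torch.rand(8, 256, 256), torch.zeros(8, 256, 256, dtype=torch.int32))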
+
+        Args:
+            anomaly_maps (torch.Tensor): predictions of the model, shape (N, H, W), float
+            masks (torch.Tensor): ground truth masks, shape (N, H, W), binary (bool or int)
+        """
+        _validate.is_anomaly_maps(anomaly_maps)
+        _validate.is_masks(masks)
+        _validate.is_same_shape(anomaly_maps, masks)
+        self.anomaly_maps.append(anomaly_maps)
+        self.masks.append(masks)
+
+    def compute(self) -> PIMOResult:
+        """Compute the PIMO curves.
+
+        Call the functional interface `pimo_curves()`.
+
+        Returns:
+            PIMOResult: PIMO curves dataclass object. See `PIMOResult` for details.
+        """
+        if self._is_empty:
+            msg = "No anomaly maps and masks have been added yet. Please call `update()` first."
+            raise RuntimeError(msg)
+        anomaly_maps = torch.concat(self.anomaly_maps, dim=0)
+        masks = torch.concat(self.masks, dim=0)
+
+        thresholds, shared_fpr, per_image_tprs, _ = functional.pimo_curves(
+            anomaly_maps,
+            masks,
+            self.num_thresholds,
+        )
+        return PIMOResult(
+            thresholds=thresholds,
+            shared_fpr=shared_fpr,
+            per_image_tprs=per_image_tprs,
+        )
+
+
+class AUPIMO(PIMO):
+    """Area Under the Per-Image Overlap (AUPIMO) curve.
+
+    This torchmetrics interface is a wrapper around the functional interface `aupimo_scores()`;
+    the accumulated tensors are validated there and the results are wrapped in dataclass objects.
+
+    Scores are computed from the integration of the PIMO curves within the given FPR bounds, then normalized to [0, 1].
+    The score can be thought of as the average TPR of the PIMO curves within the given FPR bounds.
+
+    Details: `anomalib.metrics.pimo.pimo`.
+
+    Notation:
+        N: number of images
+        H: image height
+        W: image width
+        K: number of thresholds
+
+    Attributes:
+        anomaly_maps: floating point anomaly score maps of shape (N, H, W)
+        masks: binary (bool or int) ground truth masks of shape (N, H, W)
+
+    Args:
+        num_thresholds: number of thresholds to compute (K)
+        fpr_bounds: lower and upper bounds of the FPR integration range
+        return_average: if True, `compute()` returns the average of the valid AUPIMO scores instead of the
+            result objects
+        force: whether to force the computation despite bad conditions
+
+    Returns:
+        tuple[PIMOResult, AUPIMOResult]: PIMO and AUPIMO results dataclass objects (when `return_average` is False).
+        See `PIMOResult` and `AUPIMOResult` for details.
+    """
+
+    fpr_bounds: tuple[float, float]
+    return_average: bool
+    force: bool
+
+    @staticmethod
+    def normalizing_factor(fpr_bounds: tuple[float, float]) -> float:
+        """Constant that normalizes the AUPIMO integral to the [0, 1] range.
+
+        It is the maximum possible value from the integral in AUPIMO's definition.
+        It corresponds to assuming a constant function T_i: thresh --> 1.
+
+        Args:
+            fpr_bounds: lower and upper bounds of the FPR integration range.
+
+        Returns:
+            float: the normalization factor (>0).
+        """
+        return functional.aupimo_normalizing_factor(fpr_bounds)
+
+    def __repr__(self) -> str:
+        """Show the metric name and its integration bounds."""
+        lower, upper = self.fpr_bounds
+        return f"{self.__class__.__name__}([{lower:.2g}, {upper:.2g}])"
+
+    def __init__(
+        self,
+        num_thresholds: int = 300_000,
+        fpr_bounds: tuple[float, float] = (1e-5, 1e-4),
+        return_average: bool = True,
+        force: bool = False,
+    ) -> None:
+        """Area Under the Per-Image Overlap (AUPIMO) curve.
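+
+        Example (a sketch; the defaults are spelled out for clarity):
+            >>> metric = AUPIMO(num_thresholds=300_000, fpr_bounds=(1e-5, 1e-4), return_average=True, force=False)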
+
+        Args:
+            num_thresholds: [passed to parent `PIMO`] number of thresholds used to compute the PIMO curve
+            fpr_bounds: lower and upper bounds of the FPR integration range
+            return_average: if True, return the average AUPIMO score; if False, return all the individual AUPIMO scores
+            force: if True, force the computation of the AUPIMO scores even in bad conditions (e.g. few points)
+        """
+        super().__init__(num_thresholds=num_thresholds)
+
+        # other validations are done in PIMO.__init__()
+
+        _validate.is_rate_range(fpr_bounds)
+        self.fpr_bounds = fpr_bounds
+        self.return_average = return_average
+        self.force = force
+
+    def compute(self, force: bool | None = None) -> torch.Tensor | tuple[PIMOResult, AUPIMOResult]:  # type: ignore[override]
+        """Compute the PIMO curves and their Area Under the Curve (AUPIMO) scores.
+
+        Call the functional interface `aupimo_scores()`.
+
+        Args:
+            force: if given (not None), override the `force` attribute.
+
+        Returns:
+            torch.Tensor | tuple[PIMOResult, AUPIMOResult]: the average of the valid (anomalous-image) AUPIMO
+                scores when `return_average` is True; otherwise, the PIMO curves and AUPIMO scores dataclass
+                objects. See `PIMOResult` and `AUPIMOResult` for details.
+        """
+        if self._is_empty:
+            msg = "No anomaly maps and masks have been added yet. Please call `update()` first."
+            raise RuntimeError(msg)
+        anomaly_maps = torch.concat(self.anomaly_maps, dim=0)
+        masks = torch.concat(self.masks, dim=0)
+        force = force if force is not None else self.force
+
+        # other validations are done in the functional code
+
+        thresholds, shared_fpr, per_image_tprs, _, aupimos, num_thresholds_auc = functional.aupimo_scores(
+            anomaly_maps,
+            masks,
+            self.num_thresholds,
+            fpr_bounds=self.fpr_bounds,
+            force=force,
+        )
+
+        pimo_result = PIMOResult(
+            thresholds=thresholds,
+            shared_fpr=shared_fpr,
+            per_image_tprs=per_image_tprs,
+        )
+        aupimo_result = AUPIMOResult.from_pimo_result(
+            pimo_result,
+            fpr_bounds=self.fpr_bounds,
+            # not `num_thresholds`!
+            # `num_thresholds` is the number of thresholds used to compute the PIMO curve;
+            # this is the number of thresholds used to compute the AUPIMO integral
+            num_thresholds_auc=num_thresholds_auc,
+            aupimos=aupimos,
+        )
+        if self.return_average:
+            # normal images have NaN AUPIMO scores
+            is_nan = torch.isnan(aupimo_result.aupimos)
+            return aupimo_result.aupimos[~is_nan].mean()
+        return pimo_result, aupimo_result
diff --git a/src/anomalib/metrics/pimo/utils.py b/src/anomalib/metrics/pimo/utils.py
new file mode 100644
index 0000000000..f0cac45657
--- /dev/null
+++ b/src/anomalib/metrics/pimo/utils.py
@@ -0,0 +1,19 @@
+"""Utility functions for per-image metrics."""
+
+# Original Code
+# https://github.com/jpcbertoldo/aupimo
+#
+# Modified
+# Copyright (C) 2024 Intel Corporation
+# SPDX-License-Identifier: Apache-2.0
+
+import logging
+
+import torch
+
+logger = logging.getLogger(__name__)
+
+
+def images_classes_from_masks(masks: torch.Tensor) -> torch.Tensor:
+    """Deduce the image classes from the masks."""
+    return (masks == 1).any(axis=(1, 2)).to(torch.int32)
diff --git a/tests/unit/data/utils/test_path.py b/tests/unit/data/utils/test_path.py
index c3f134b021..09f88496ad 100644
--- a/tests/unit/data/utils/test_path.py
+++ b/tests/unit/data/utils/test_path.py
@@ -76,3 +76,9 @@ def test_no_read_execute_permission() -> None:
         Path(tmp_dir).chmod(0o222)  # Remove read and execute permission
         with pytest.raises(PermissionError, match=r"Read or execute permissions denied for the path:*"):
             validate_path(tmp_dir, base_dir=Path(tmp_dir))
+
+    @staticmethod
+    def test_file_wrongsuffix() -> None:
+        """Test ``validate_path`` raises ValueError for a file with wrong suffix."""
+        with pytest.raises(ValueError, match="Path extension is not accepted."):
+            validate_path("file.png", should_exist=False, extensions=(".json", ".txt"))
diff --git a/tests/unit/metrics/pimo/__init__.py b/tests/unit/metrics/pimo/__init__.py
new file mode 100644
index 0000000000..555d67a102
--- /dev/null
+++ b/tests/unit/metrics/pimo/__init__.py
@@ -0,0 +1,8 @@
+"""Per-Image Metrics Tests."""
+
+# Original Code
+# https://github.com/jpcbertoldo/aupimo
+#
+# Modified
+# Copyright (C) 2024 Intel Corporation
+# SPDX-License-Identifier: Apache-2.0
diff --git a/tests/unit/metrics/pimo/test_binary_classification_curve.py b/tests/unit/metrics/pimo/test_binary_classification_curve.py
new file mode 100644
index 0000000000..5459d08a14
--- /dev/null
+++ b/tests/unit/metrics/pimo/test_binary_classification_curve.py
@@ -0,0 +1,423 @@
+"""Tests for the per-image binary classification curves (torch implementation)."""
+
+# Original Code
+# https://github.com/jpcbertoldo/aupimo
+#
+# Modified
+# Copyright (C) 2024 Intel Corporation
+# SPDX-License-Identifier: Apache-2.0
+
+# ruff: noqa: SLF001, PT011
+
+import pytest
+import torch
+
+from anomalib.metrics.pimo.binary_classification_curve import (
+    _binary_classification_curve,
+    binary_classification_curve,
+    per_image_fpr,
+    per_image_tpr,
+    threshold_and_binary_classification_curve,
+)
+
+
+def pytest_generate_tests(metafunc: pytest.Metafunc) -> None:
+    """Generate test cases."""
+    pred = torch.arange(1, 5, dtype=torch.float32)
+    thresholds = torch.arange(1, 5, dtype=torch.float32)
+
+    gt_norm = torch.zeros(4).to(bool)
+    gt_anom = torch.concatenate([torch.zeros(2), torch.ones(2)]).to(bool)
+
+    # in the case where thresholds are all unique values in the predictions
+    expected_norm = torch.stack(
+        [
+            torch.tensor([[0, 4], [0, 0]]),
+            torch.tensor([[1, 3], [0, 0]]),
+            torch.tensor([[2, 2], [0, 0]]),
+
torch.tensor([[3, 1], [0, 0]]), + ], + axis=0, + ).to(int) + expected_anom = torch.stack( + [ + torch.tensor([[0, 2], [0, 2]]), + torch.tensor([[1, 1], [0, 2]]), + torch.tensor([[2, 0], [0, 2]]), + torch.tensor([[2, 0], [1, 1]]), + ], + axis=0, + ).to(int) + + expected_tprs_norm = torch.tensor([torch.nan, torch.nan, torch.nan, torch.nan]) + expected_tprs_anom = torch.tensor([1.0, 1.0, 1.0, 0.5]) + expected_tprs = torch.stack([expected_tprs_anom, expected_tprs_norm], axis=0).to(torch.float64) + + expected_fprs_norm = torch.tensor([1.0, 0.75, 0.5, 0.25]) + expected_fprs_anom = torch.tensor([1.0, 0.5, 0.0, 0.0]) + expected_fprs = torch.stack([expected_fprs_anom, expected_fprs_norm], axis=0).to(torch.float64) + + # in the case where all thresholds are higher than the highest prediction + expected_norm_thresholds_too_high = torch.stack( + [ + torch.tensor([[4, 0], [0, 0]]), + torch.tensor([[4, 0], [0, 0]]), + torch.tensor([[4, 0], [0, 0]]), + torch.tensor([[4, 0], [0, 0]]), + ], + axis=0, + ).to(int) + expected_anom_thresholds_too_high = torch.stack( + [ + torch.tensor([[2, 0], [2, 0]]), + torch.tensor([[2, 0], [2, 0]]), + torch.tensor([[2, 0], [2, 0]]), + torch.tensor([[2, 0], [2, 0]]), + ], + axis=0, + ).to(int) + + # in the case where all thresholds are lower than the lowest prediction + expected_norm_thresholds_too_low = torch.stack( + [ + torch.tensor([[0, 4], [0, 0]]), + torch.tensor([[0, 4], [0, 0]]), + torch.tensor([[0, 4], [0, 0]]), + torch.tensor([[0, 4], [0, 0]]), + ], + axis=0, + ).to(int) + expected_anom_thresholds_too_low = torch.stack( + [ + torch.tensor([[0, 2], [0, 2]]), + torch.tensor([[0, 2], [0, 2]]), + torch.tensor([[0, 2], [0, 2]]), + torch.tensor([[0, 2], [0, 2]]), + ], + axis=0, + ).to(int) + + if metafunc.function is test__binclf_one_curve: + metafunc.parametrize( + argnames=("pred", "gt", "thresholds", "expected"), + argvalues=[ + (pred, gt_anom, thresholds[:3], expected_anom[:3]), + (pred, gt_anom, thresholds, expected_anom), + (pred, gt_norm, thresholds, expected_norm), + (pred, gt_norm, 10 * thresholds, expected_norm_thresholds_too_high), + (pred, gt_anom, 10 * thresholds, expected_anom_thresholds_too_high), + (pred, gt_norm, 0.001 * thresholds, expected_norm_thresholds_too_low), + (pred, gt_anom, 0.001 * thresholds, expected_anom_thresholds_too_low), + ], + ) + + preds = torch.stack([pred, pred], axis=0) + gts = torch.stack([gt_anom, gt_norm], axis=0) + binclf_curves = torch.stack([expected_anom, expected_norm], axis=0) + binclf_curves_thresholds_too_high = torch.stack( + [expected_anom_thresholds_too_high, expected_norm_thresholds_too_high], + axis=0, + ) + binclf_curves_thresholds_too_low = torch.stack( + [expected_anom_thresholds_too_low, expected_norm_thresholds_too_low], + axis=0, + ) + + if metafunc.function is test__binclf_multiple_curves: + metafunc.parametrize( + argnames=("preds", "gts", "thresholds", "expecteds"), + argvalues=[ + (preds, gts, thresholds[:3], binclf_curves[:, :3]), + (preds, gts, thresholds, binclf_curves), + ], + ) + + if metafunc.function is test_binclf_multiple_curves: + metafunc.parametrize( + argnames=( + "preds", + "gts", + "thresholds", + "expected_binclf_curves", + ), + argvalues=[ + (preds[:1], gts[:1], thresholds, binclf_curves[:1]), + (preds, gts, thresholds, binclf_curves), + (10 * preds, gts, 10 * thresholds, binclf_curves), + ], + ) + + if metafunc.function is test_binclf_multiple_curves_validations: + metafunc.parametrize( + argnames=("args", "kwargs", "exception"), + argvalues=[ + # `scores` and `gts` must be 2D + 
([preds.reshape(2, 2, 2), gts, thresholds], {}, ValueError),
+                ([preds, gts.flatten(), thresholds], {}, ValueError),
+                # `thresholds` must be 1D
+                ([preds, gts, thresholds.reshape(2, 2)], {}, ValueError),
+                # `scores` and `gts` must have the same shape
+                ([preds, gts[:1], thresholds], {}, ValueError),
+                ([preds[:, :2], gts, thresholds], {}, ValueError),
+                # `scores` must be of type float
+                ([preds.to(int), gts, thresholds], {}, TypeError),
+                # `gts` must be of type bool
+                ([preds, gts.to(int), thresholds], {}, TypeError),
+                # `thresholds` must be of type float
+                ([preds, gts, thresholds.to(int)], {}, TypeError),
+                # `thresholds` must be sorted in ascending order
+                ([preds, gts, torch.flip(thresholds, dims=[0])], {}, ValueError),
+                ([preds, gts, torch.concatenate([thresholds[-2:], thresholds[:2]])], {}, ValueError),
+                # `thresholds` must be unique
+                ([preds, gts, torch.sort(torch.concatenate([thresholds, thresholds]))[0]], {}, ValueError),
+            ],
+        )
+
+    # the following tests are for `threshold_and_binary_classification_curve()`, which expects
+    # inputs in image spatial format, i.e. (height, width)
+    preds = preds.reshape(2, 2, 2)
+    gts = gts.reshape(2, 2, 2)
+
+    per_image_binclf_curves_argvalues = [
+        # `threshold_choice` = "given"
+        (
+            preds,
+            gts,
+            "given",
+            thresholds,
+            None,
+            thresholds,
+            binclf_curves,
+        ),
+        (
+            preds,
+            gts,
+            "given",
+            10 * thresholds,
+            2,
+            10 * thresholds,
+            binclf_curves_thresholds_too_high,
+        ),
+        (
+            preds,
+            gts,
+            "given",
+            0.01 * thresholds,
+            None,
+            0.01 * thresholds,
+            binclf_curves_thresholds_too_low,
+        ),
+        # `threshold_choice` = "minmax-linspace"
+        (
+            preds,
+            gts,
+            "minmax-linspace",
+            None,
+            len(thresholds),
+            thresholds,
+            binclf_curves,
+        ),
+        (
+            2 * preds,
+            gts.to(int),  # this is ok
+            "minmax-linspace",
+            None,
+            len(thresholds),
+            2 * thresholds,
+            binclf_curves,
+        ),
+    ]
+
+    if metafunc.function is test_per_image_binclf_curve:
+        metafunc.parametrize(
+            argnames=(
+                "anomaly_maps",
+                "masks",
+                "threshold_choice",
+                "thresholds",
+                "num_thresholds",
+                "expected_thresholds",
+                "expected_binclf_curves",
+            ),
+            argvalues=per_image_binclf_curves_argvalues,
+        )
+
+    if metafunc.function is test_per_image_binclf_curve_validations:
+        metafunc.parametrize(
+            argnames=("args", "exception"),
+            argvalues=[
+                # `scores` and `gts` must be 3D
+                ([preds.reshape(2, 2, 2, 1), gts], ValueError),
+                ([preds, gts.flatten()], ValueError),
+                # `scores` and `gts` must have the same shape
+                ([preds, gts[:1]], ValueError),
+                ([preds[:, :1], gts], ValueError),
+                # `scores` must be of type float
+                ([preds.to(int), gts], TypeError),
+                # `gts` must be of type bool or int
+                ([preds, gts.to(float)], TypeError),
+                # `thresholds` must be of type float
+                ([preds, gts, thresholds.to(int)], TypeError),
+            ],
+        )
+        metafunc.parametrize(
+            argnames=("kwargs",),
+            argvalues=[
+                (
+                    {
+                        "threshold_choice": "minmax-linspace",
+                        "thresholds": None,
+                        "num_thresholds": len(thresholds),
+                    },
+                ),
+            ],
+        )
+
+    # same as above but testing other validations
+    if metafunc.function is test_per_image_binclf_curve_validations_alt:
+        metafunc.parametrize(
+            argnames=("args", "kwargs", "exception"),
+            argvalues=[
+                # invalid `threshold_choice`
+                (
+                    [preds, gts],
+                    {"threshold_choice": "glfrb", "thresholds": thresholds, "num_thresholds": None},
+                    ValueError,
+                ),
+            ],
+        )
+
+    if metafunc.function is test_rate_metrics:
+        metafunc.parametrize(
+            argnames=("binclf_curves", "expected_fprs", "expected_tprs"),
+            argvalues=[
+                (binclf_curves, expected_fprs, expected_tprs),
+                (10 * binclf_curves, expected_fprs, expected_tprs),
+            ],
+        )
+
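+# A reading aid for the expected values above (not part of the API under test):
+# each binclf curve entry is a 2x2 confusion matrix laid out as [[TN, FP], [FN, TP]];
+# e.g. torch.tensor([[1, 3], [0, 0]]) reads as 1 TN, 3 FP, and no positive pixels (a normal image).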
+
+# ==================================================================================================
+# LOW-LEVEL FUNCTIONS
+
+
+def test__binclf_one_curve(
+    pred: torch.Tensor,
+    gt: torch.Tensor,
+    thresholds: torch.Tensor,
+    expected: torch.Tensor,
+) -> None:
+    """Test if `_binary_classification_curve()` returns the expected values."""
+    computed = _binary_classification_curve(pred, gt, thresholds)
+    assert computed.shape == (thresholds.numel(), 2, 2)
+    assert (computed == expected).all()
+
+
+def test__binclf_multiple_curves(
+    preds: torch.Tensor,
+    gts: torch.Tensor,
+    thresholds: torch.Tensor,
+    expecteds: torch.Tensor,
+) -> None:
+    """Test if `binary_classification_curve()` returns the expected values."""
+    computed = binary_classification_curve(preds, gts, thresholds)
+    assert computed.shape == (preds.shape[0], thresholds.numel(), 2, 2)
+    assert (computed == expecteds).all()
+
+
+# ==================================================================================================
+# API FUNCTIONS
+
+
+def test_binclf_multiple_curves(
+    preds: torch.Tensor,
+    gts: torch.Tensor,
+    thresholds: torch.Tensor,
+    expected_binclf_curves: torch.Tensor,
+) -> None:
+    """Test if `binary_classification_curve()` returns the expected values (API-level cases)."""
+    computed = binary_classification_curve(
+        preds,
+        gts,
+        thresholds,
+    )
+    assert computed.shape == expected_binclf_curves.shape
+    assert (computed == expected_binclf_curves).all()
+
+    # it's ok to have the thresholds beyond the range of the preds
+    binary_classification_curve(preds, gts, 2 * thresholds)
+
+    # or inside the bounds without reaching them
+    binary_classification_curve(preds, gts, 0.5 * thresholds)
+
+    # it's also ok to have more thresholds than unique values in the preds
+    # add the values in between the thresholds
+    thresholds_unnecessary = 0.5 * (thresholds[:-1] + thresholds[1:])
+    thresholds_unnecessary = torch.concatenate([thresholds_unnecessary, thresholds])
+    thresholds_unnecessary = torch.sort(thresholds_unnecessary)[0]
+    binary_classification_curve(preds, gts, thresholds_unnecessary)
+
+    # or less
+    binary_classification_curve(preds, gts, thresholds[1:3])
+
+
+def test_binclf_multiple_curves_validations(args: list, kwargs: dict, exception: Exception) -> None:
+    """Test if `binary_classification_curve()` raises the expected errors."""
+    with pytest.raises(exception):
+        binary_classification_curve(*args, **kwargs)
+
+
+def test_per_image_binclf_curve(
+    anomaly_maps: torch.Tensor,
+    masks: torch.Tensor,
+    threshold_choice: str,
+    thresholds: torch.Tensor | None,
+    num_thresholds: int | None,
+    expected_thresholds: torch.Tensor,
+    expected_binclf_curves: torch.Tensor,
+) -> None:
+    """Test if `threshold_and_binary_classification_curve()` returns the expected values."""
+    computed_thresholds, computed_binclf_curves = threshold_and_binary_classification_curve(
+        anomaly_maps,
+        masks,
+        threshold_choice=threshold_choice,
+        thresholds=thresholds,
+        num_thresholds=num_thresholds,
+    )
+
+    # thresholds
+    assert computed_thresholds.shape == expected_thresholds.shape
+    assert computed_thresholds.dtype == expected_thresholds.dtype
+    assert (computed_thresholds == expected_thresholds).all()
+
+    # binclf_curves
+    assert computed_binclf_curves.shape == expected_binclf_curves.shape
+    assert computed_binclf_curves.dtype == expected_binclf_curves.dtype
+    assert (computed_binclf_curves == expected_binclf_curves).all()
+
+
+def test_per_image_binclf_curve_validations(args: list, kwargs: dict, exception: Exception) -> None:
+    """Test if `threshold_and_binary_classification_curve()` raises the expected errors."""
+    with pytest.raises(exception):
+        threshold_and_binary_classification_curve(*args, **kwargs)
+
+
+def test_per_image_binclf_curve_validations_alt(args: list, kwargs: dict, exception: Exception) -> None:
+    """Test if `threshold_and_binary_classification_curve()` raises the expected errors (alternative cases)."""
+    test_per_image_binclf_curve_validations(args, kwargs, exception)
+
+
+def test_rate_metrics(
+    binclf_curves: torch.Tensor,
+    expected_fprs: torch.Tensor,
+    expected_tprs: torch.Tensor,
+) -> None:
+    """Test if rate metrics are computed correctly."""
+    tprs = per_image_tpr(binclf_curves)
+    fprs = per_image_fpr(binclf_curves)
+
+    assert tprs.shape == expected_tprs.shape
+    assert fprs.shape == expected_fprs.shape
+
+    assert torch.allclose(tprs, expected_tprs, equal_nan=True)
+    assert torch.allclose(fprs, expected_fprs, equal_nan=True)
diff --git a/tests/unit/metrics/pimo/test_pimo.py b/tests/unit/metrics/pimo/test_pimo.py
new file mode 100644
index 0000000000..81bafe4c8e
--- /dev/null
+++ b/tests/unit/metrics/pimo/test_pimo.py
@@ -0,0 +1,368 @@
+"""Test `anomalib.metrics.pimo.functional` and `anomalib.metrics.pimo.pimo`."""
+
+# Original Code
+# https://github.com/jpcbertoldo/aupimo
+#
+# Modified
+# Copyright (C) 2024 Intel Corporation
+# SPDX-License-Identifier: Apache-2.0
+
+import logging
+
+import pytest
+import torch
+from torch import Tensor
+
+from anomalib.metrics.pimo import AUPIMOResult, PIMOResult, functional, pimo
+
+
+def pytest_generate_tests(metafunc: pytest.Metafunc) -> None:
+    """Generate tests for all functions in this module.
+
+    All functions are parametrized with the same setting: 1 normal and 2 anomalous images.
+    The anomaly maps are the same for all functions, but the masks are different.
+    """
+    expected_thresholds = torch.arange(1, 7 + 1, dtype=torch.float32)
+    shape = (1000, 1000)  # (H, W), 1 million pixels
+
+    # --- normal ---
+    # histogram of scores:
+    # value:   7   6    5    4    3    2    1
+    # count:   1   9   90  900   9k  90k 900k
+    # cumsum:  1  10  100   1k  10k 100k   1M
+    pred_norm = torch.ones(1_000_000, dtype=torch.float32)
+    pred_norm[:100_000] += 1
+    pred_norm[:10_000] += 1
+    pred_norm[:1_000] += 1
+    pred_norm[:100] += 1
+    pred_norm[:10] += 1
+    pred_norm[:1] += 1
+    pred_norm = pred_norm.reshape(shape)
+    mask_norm = torch.zeros_like(pred_norm, dtype=torch.int32)
+
+    expected_fpr_norm = torch.tensor([1.0, 1e-1, 1e-2, 1e-3, 1e-4, 1e-5, 1e-6], dtype=torch.float64)
+    expected_tpr_norm = torch.full((7,), torch.nan, dtype=torch.float64)
+
+    # --- anomalous ---
+    pred_anom1 = pred_norm.clone()
+    mask_anom1 = torch.ones_like(pred_anom1, dtype=torch.int32)
+    expected_tpr_anom1 = expected_fpr_norm.clone()
+
+    # only the first 100_000 pixels are anomalous,
+    # which corresponds to the 100_000 highest scores (2 to 7)
+    pred_anom2 = pred_norm.clone()
+    mask_anom2 = torch.concatenate([torch.ones(100_000), torch.zeros(900_000)]).reshape(shape).to(torch.int32)
+    expected_tpr_anom2 = (10 * expected_fpr_norm).clip(0, 1)
+
+    anomaly_maps = torch.stack([pred_norm, pred_anom1, pred_anom2], axis=0)
+    masks = torch.stack([mask_norm, mask_anom1, mask_anom2], axis=0)
+
+    expected_shared_fpr = expected_fpr_norm
+    expected_per_image_tprs = torch.stack([expected_tpr_norm, expected_tpr_anom1, expected_tpr_anom2], axis=0)
+    expected_image_classes = torch.tensor([0, 1, 1], dtype=torch.int32)
+
+    if metafunc.function is test_pimo or metafunc.function is test_aupimo_values:
+        argvalues_tensors = [
+            (
+                anomaly_maps,
+                masks,
+                expected_thresholds,
+                expected_shared_fpr,
+                expected_per_image_tprs,
+
expected_image_classes, + ), + ( + 10 * anomaly_maps, + masks, + 10 * expected_thresholds, + expected_shared_fpr, + expected_per_image_tprs, + expected_image_classes, + ), + ] + metafunc.parametrize( + argnames=( + "anomaly_maps", + "masks", + "expected_thresholds", + "expected_shared_fpr", + "expected_per_image_tprs", + "expected_image_classes", + ), + argvalues=argvalues_tensors, + ) + + if metafunc.function is test_aupimo_values: + argvalues_tensors = [ + ( + (1e-1, 1.0), + torch.tensor( + [ + torch.nan, + # recall: trapezium area = (a + b) * h / 2 + (0.10 + 1.0) * 1 / 2, + (1.0 + 1.0) * 1 / 2, + ], + dtype=torch.float64, + ), + ), + ( + (1e-3, 1e-1), + torch.tensor( + [ + torch.nan, + # average of two trapezium areas / 2 (normalizing factor) + (((1e-3 + 1e-2) * 1 / 2) + ((1e-2 + 1e-1) * 1 / 2)) / 2, + (((1e-2 + 1e-1) * 1 / 2) + ((1e-1 + 1.0) * 1 / 2)) / 2, + ], + dtype=torch.float64, + ), + ), + ( + (1e-5, 1e-4), + torch.tensor( + [ + torch.nan, + (1e-5 + 1e-4) * 1 / 2, + (1e-4 + 1e-3) * 1 / 2, + ], + dtype=torch.float64, + ), + ), + ] + metafunc.parametrize( + argnames=( + "fpr_bounds", + "expected_aupimos", # trapezoid surfaces + ), + argvalues=argvalues_tensors, + ) + + if metafunc.function is test_aupimo_edge: + metafunc.parametrize( + argnames=( + "anomaly_maps", + "masks", + ), + argvalues=[ + ( + anomaly_maps, + masks, + ), + ( + 10 * anomaly_maps, + masks, + ), + ], + ) + metafunc.parametrize( + argnames=("fpr_bounds",), + argvalues=[ + ((1e-1, 1.0),), + ((1e-3, 1e-2),), + ((1e-5, 1e-4),), + (None,), + ], + ) + + +def _do_test_pimo_outputs( + thresholds: Tensor, + shared_fpr: Tensor, + per_image_tprs: Tensor, + image_classes: Tensor, + expected_thresholds: Tensor, + expected_shared_fpr: Tensor, + expected_per_image_tprs: Tensor, + expected_image_classes: Tensor, +) -> None: + """Test if the outputs of any of the PIMO interfaces are correct.""" + assert isinstance(shared_fpr, Tensor) + assert isinstance(per_image_tprs, Tensor) + assert isinstance(image_classes, Tensor) + assert isinstance(expected_thresholds, Tensor) + assert isinstance(expected_shared_fpr, Tensor) + assert isinstance(expected_per_image_tprs, Tensor) + assert isinstance(expected_image_classes, Tensor) + allclose = torch.allclose + + assert thresholds.ndim == 1 + assert shared_fpr.ndim == 1 + assert per_image_tprs.ndim == 2 + assert tuple(image_classes.shape) == (3,) + + assert allclose(thresholds, expected_thresholds) + assert allclose(shared_fpr, expected_shared_fpr) + assert allclose(per_image_tprs, expected_per_image_tprs, equal_nan=True) + assert (image_classes == expected_image_classes).all() + + +def test_pimo( + anomaly_maps: Tensor, + masks: Tensor, + expected_thresholds: Tensor, + expected_shared_fpr: Tensor, + expected_per_image_tprs: Tensor, + expected_image_classes: Tensor, +) -> None: + """Test if `pimo()` returns the expected values.""" + + def do_assertions(pimo_result: PIMOResult) -> None: + thresholds = pimo_result.thresholds + shared_fpr = pimo_result.shared_fpr + per_image_tprs = pimo_result.per_image_tprs + image_classes = pimo_result.image_classes + _do_test_pimo_outputs( + thresholds, + shared_fpr, + per_image_tprs, + image_classes, + expected_thresholds, + expected_shared_fpr, + expected_per_image_tprs, + expected_image_classes, + ) + + # metric interface + metric = pimo.PIMO( + num_thresholds=7, + ) + metric.update(anomaly_maps, masks) + pimo_result = metric.compute() + do_assertions(pimo_result) + + +def _do_test_aupimo_outputs( + thresholds: Tensor, + shared_fpr: Tensor, + 
per_image_tprs: Tensor, + image_classes: Tensor, + aupimos: Tensor, + expected_thresholds: Tensor, + expected_shared_fpr: Tensor, + expected_per_image_tprs: Tensor, + expected_image_classes: Tensor, + expected_aupimos: Tensor, +) -> None: + _do_test_pimo_outputs( + thresholds, + shared_fpr, + per_image_tprs, + image_classes, + expected_thresholds, + expected_shared_fpr, + expected_per_image_tprs, + expected_image_classes, + ) + assert isinstance(aupimos, Tensor) + assert isinstance(expected_aupimos, Tensor) + allclose = torch.allclose + assert tuple(aupimos.shape) == (3,) + assert allclose(aupimos, expected_aupimos, equal_nan=True) + + +def test_aupimo_values( + anomaly_maps: torch.Tensor, + masks: torch.Tensor, + fpr_bounds: tuple[float, float], + expected_thresholds: torch.Tensor, + expected_shared_fpr: torch.Tensor, + expected_per_image_tprs: torch.Tensor, + expected_image_classes: torch.Tensor, + expected_aupimos: torch.Tensor, +) -> None: + """Test if `aupimo()` returns the expected values.""" + + def do_assertions(pimo_result: PIMOResult, aupimo_result: AUPIMOResult) -> None: + # test metadata + assert aupimo_result.fpr_bounds == fpr_bounds + # recall: this one is not the same as the number of thresholds in the curve + # this is the number of thresholds used to compute the integral in `aupimo()` + # always less because of the integration bounds + assert aupimo_result.num_thresholds < 7 + + # test data + # from pimo result + thresholds = pimo_result.thresholds + shared_fpr = pimo_result.shared_fpr + per_image_tprs = pimo_result.per_image_tprs + image_classes = pimo_result.image_classes + # from aupimo result + aupimos = aupimo_result.aupimos + _do_test_aupimo_outputs( + thresholds, + shared_fpr, + per_image_tprs, + image_classes, + aupimos, + expected_thresholds, + expected_shared_fpr, + expected_per_image_tprs, + expected_image_classes, + expected_aupimos, + ) + thresh_lower_bound = aupimo_result.thresh_lower_bound + thresh_upper_bound = aupimo_result.thresh_upper_bound + assert anomaly_maps.min() <= thresh_lower_bound < thresh_upper_bound <= anomaly_maps.max() + + # metric interface + metric = pimo.AUPIMO( + num_thresholds=7, + fpr_bounds=fpr_bounds, + return_average=False, + force=True, + ) + metric.update(anomaly_maps, masks) + pimo_result_from_metric, aupimo_result_from_metric = metric.compute() + do_assertions(pimo_result_from_metric, aupimo_result_from_metric) + + # metric interface + metric = pimo.AUPIMO( + num_thresholds=7, + fpr_bounds=fpr_bounds, + return_average=True, # only return the average AUPIMO + force=True, + ) + metric.update(anomaly_maps, masks) + metric.compute() + + +def test_aupimo_edge( + anomaly_maps: torch.Tensor, + masks: torch.Tensor, + fpr_bounds: tuple[float, float], + caplog: pytest.LogCaptureFixture, +) -> None: + """Test some edge cases.""" + # None is the case of testing the default bounds + fpr_bounds = {"fpr_bounds": fpr_bounds} if fpr_bounds is not None else {} + + # not enough points on the curve + # 10 thresholds / 6 decades = 1.6 thresholds per decade < 3 + with pytest.raises(RuntimeError): # force=False --> raise error + functional.aupimo_scores( + anomaly_maps, + masks, + num_thresholds=10, + force=False, + **fpr_bounds, + ) + + with caplog.at_level(logging.WARNING): # force=True --> warn + functional.aupimo_scores( + anomaly_maps, + masks, + num_thresholds=10, + force=True, + **fpr_bounds, + ) + assert "Computation was forced!" 
in caplog.text + + # default number of points on the curve (300k thresholds) should be enough + torch.manual_seed(42) + functional.aupimo_scores( + anomaly_maps * torch.FloatTensor(anomaly_maps.shape).uniform_(1.0, 1.1), + masks, + force=False, + **fpr_bounds, + ) diff --git a/third-party-programs.txt b/third-party-programs.txt index 3155b2a930..5eeaca8ea9 100644 --- a/third-party-programs.txt +++ b/third-party-programs.txt @@ -42,3 +42,7 @@ terms are listed below. 7. CLIP neural network used for deep feature extraction in AI-VAD model Copyright (c) 2022 @openai, https://github.com/openai/CLIP. SPDX-License-Identifier: MIT + +8. AUPIMO metric implementation is based on the original code + Copyright (c) 2023 @jpcbertoldo, https://github.com/jpcbertoldo/aupimo + SPDX-License-Identifier: MIT