From 4f95be55c688b2f4812ad1bd9013946b757a0bcc Mon Sep 17 00:00:00 2001
From: Nathan Painchaud
Date: Mon, 1 Jul 2024 16:49:02 +0200
Subject: [PATCH] Implement AUROC metric for generic cardiac records
 stratification task

---
 .../tasks/cardiac_records_stratification.py | 27 +++++++++++++------
 1 file changed, 19 insertions(+), 8 deletions(-)

diff --git a/didactic/tasks/cardiac_records_stratification.py b/didactic/tasks/cardiac_records_stratification.py
index 6d3b599..b8d31b3 100644
--- a/didactic/tasks/cardiac_records_stratification.py
+++ b/didactic/tasks/cardiac_records_stratification.py
@@ -1,6 +1,6 @@
 import itertools
 import logging
-from typing import Sequence, Tuple
+from typing import Dict, Sequence, Tuple
 
 import hydra
 import numpy as np
@@ -9,6 +9,7 @@
 from omegaconf import DictConfig
 from pytorch_lightning.trainer.states import TrainerFn
 from sklearn.base import ClassifierMixin
+from sklearn.metrics import accuracy_score, roc_auc_score
 from vital.data.cardinal.config import TabularAttribute
 from vital.data.cardinal.data_module import CardinalDataModule
 from vital.data.cardinal.datapipes import MISSING_CAT_ATTR
@@ -31,6 +32,10 @@ def __init__(
             tabular_attrs: List of tabular attributes to use as input features for the classifier.
             target_attr: Tabular attribute to use as the target label for the classifier.
         """
+        for method in ["fit", "predict_proba", "save_model"]:
+            if not callable(getattr(model, method, None)):
+                raise ValueError(f"Model must implement method: {method}")
+
         self.model = model
 
         # Ensure string tags are converted to their appropriate enum types
@@ -102,20 +107,26 @@ def fit(self, data: CardinalDataModule) -> "CardiacRecordsStratificationTask":
 
         return self
 
-    def score(self, data: CardinalDataModule) -> float:
+    def score(self, data: CardinalDataModule) -> Dict[str, float]:
         """Measure the model's performance on the test set.
 
         Args:
             data: CARDINAL data module.
 
         Returns:
-            Model's accuracy on the test set.
+            Dictionary of model's metrics (e.g. accuracy, AUROC, etc.) on the test set.
         """
         # Extract the tabular data (i.e. inputs and target labels) from the test set
         X, y = self._prepare_data_subset(data, "test")
 
-        # Compute the model's performance on the test set
-        return self.model.score(X, y)
+        # Perform inference on the test set, keeping intermediate predictions (i.e. class probabilities)
+        predict_proba = self.model.predict_proba(X)
+        y_hat = np.argmax(predict_proba, axis=1)
+
+        # Compute the model's performance metrics
+        scores = {"acc": accuracy_score(y, y_hat), "auroc": roc_auc_score(y, predict_proba, multi_class="ovr")}
+
+        return scores
 
     def save(self, path: str) -> None:
         """Save the model to disk.
@@ -145,7 +156,7 @@ def main(cfg: DictConfig):
     task.fit(data)
 
     # Evaluate the model's performance on the test set
-    score = task.score(data)
+    scores = task.score(data)
 
     hydra_output_dir = Path(HydraConfig.get().runtime.output_dir)
 
@@ -153,10 +164,10 @@ def main(cfg: DictConfig):
     task.save(hydra_output_dir / cfg.model_ckpt)
 
     # Save the model's performance
-    score_df = pd.Series({"acc": score})
+    score_df = pd.Series(scores)
     score_df.to_csv(hydra_output_dir / cfg.scores_filename, header=["value"])
 
-    logger.info(f"Logging model and its score ({score:.2%}) to {hydra_output_dir}")
+    logger.info(f"Logging model and its scores: {scores} to {hydra_output_dir}")
 
 
 if __name__ == "__main__":
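
Note on the new scoring logic: the snippet below is a minimal standalone sketch of what
the patched `score` computes. The synthetic 3-class data and the scikit-learn
`LogisticRegression` are placeholders for the CARDINAL test set and the task's actual
classifier; they are not part of the diff.

    # Standalone sketch of the patched `score` logic (placeholder data/model).
    import numpy as np
    from sklearn.datasets import make_classification
    from sklearn.linear_model import LogisticRegression
    from sklearn.metrics import accuracy_score, roc_auc_score

    # Synthetic 3-class problem standing in for a multiclass stratification target
    X, y = make_classification(
        n_samples=200, n_features=10, n_informative=6, n_classes=3, random_state=0
    )
    model = LogisticRegression(max_iter=1000).fit(X, y)

    # Keep the intermediate class probabilities, as the patched `score` does
    predict_proba = model.predict_proba(X)
    y_hat = np.argmax(predict_proba, axis=1)

    scores = {
        "acc": accuracy_score(y, y_hat),
        # "ovr" AUROC takes the full (n_samples, n_classes) probability matrix;
        # a binary target would instead require the positive-class column,
        # i.e. predict_proba[:, 1]
        "auroc": roc_auc_score(y, predict_proba, multi_class="ovr"),
    }
    print(scores)

Two points the diff leaves implicit: `roc_auc_score(..., multi_class="ovr")` only accepts
the full probability matrix for multiclass targets, so a binary target would need the
`predict_proba[:, 1]` column instead; and the `save_model` method required by the new
duck-typing check is not part of the scikit-learn estimator API, so the `ClassifierMixin`
passed in is presumably an xgboost-style model that exposes `fit`, `predict_proba`, and
`save_model`.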