Implement AUROC metric for generic cardiac records stratification task
nathanpainchaud committed Jul 1, 2024
1 parent 97ec873 · commit 4f95be5
Showing 1 changed file with 19 additions and 8 deletions.
didactic/tasks/cardiac_records_stratification.py (19 additions & 8 deletions)
@@ -1,6 +1,6 @@
 import itertools
 import logging
-from typing import Sequence, Tuple
+from typing import Dict, Sequence, Tuple
 
 import hydra
 import numpy as np
@@ -9,6 +9,7 @@
 from omegaconf import DictConfig
 from pytorch_lightning.trainer.states import TrainerFn
 from sklearn.base import ClassifierMixin
+from sklearn.metrics import accuracy_score, roc_auc_score
 from vital.data.cardinal.config import TabularAttribute
 from vital.data.cardinal.data_module import CardinalDataModule
 from vital.data.cardinal.datapipes import MISSING_CAT_ATTR
@@ -31,6 +32,10 @@ def __init__(
             tabular_attrs: List of tabular attributes to use as input features for the classifier.
             target_attr: Tabular attribute to use as the target label for the classifier.
         """
+        for method in ["fit", "predict_proba", "save_model"]:
+            if not callable(getattr(model, method, None)):
+                raise ValueError(f"Model must implement method: {method}")
+
         self.model = model
 
         # Ensure string tags are converted to their appropriate enum types
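The new guard duck-types the model rather than requiring a specific base class: construction fails fast if any required method is missing, instead of erroring mid-run. A minimal sketch of that behavior, using hypothetical stand-in classes (not from the repository):

class GoodModel:
    def fit(self, X, y): ...
    def predict_proba(self, X): ...
    def save_model(self, path): ...

class BadModel:  # lacks `save_model`
    def fit(self, X, y): ...
    def predict_proba(self, X): ...

def check(model):
    # Same check as the guard added in the hunk above
    for method in ["fit", "predict_proba", "save_model"]:
        if not callable(getattr(model, method, None)):
            raise ValueError(f"Model must implement method: {method}")

check(GoodModel())  # passes silently
check(BadModel())   # raises ValueError: Model must implement method: save_model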
@@ -102,20 +107,26 @@ def fit(self, data: CardinalDataModule) -> "CardiacRecordsStratificationTask":
 
         return self
 
-    def score(self, data: CardinalDataModule) -> float:
+    def score(self, data: CardinalDataModule) -> Dict[str, float]:
         """Measure the model's performance on the test set.
 
         Args:
             data: CARDINAL data module.
 
         Returns:
-            Model's accuracy on the test set.
+            Dictionary of model's metrics (e.g. accuracy, AUROC, etc.) on the test set.
         """
         # Extract the tabular data (i.e. inputs and target labels) from the test set
         X, y = self._prepare_data_subset(data, "test")
 
-        # Compute the model's performance on the test set
-        return self.model.score(X, y)
+        # Perform inference on the test set, keeping intermediate predictions (i.e. class probabilities)
+        predict_proba = self.model.predict_proba(X)
+        y_hat = np.argmax(predict_proba, axis=1)
+
+        # Compute the model's performance metrics
+        scores = {"acc": accuracy_score(y, y_hat), "auroc": roc_auc_score(y, predict_proba, multi_class="ovr")}
+
+        return scores
 
     def save(self, path: str) -> None:
         """Save the model to disk.
@@ -145,18 +156,18 @@ def main(cfg: DictConfig):
     task.fit(data)
 
     # Evaluate the model's performance on the test set
-    score = task.score(data)
+    scores = task.score(data)
 
     hydra_output_dir = Path(HydraConfig.get().runtime.output_dir)
 
     # Save the model
     task.save(hydra_output_dir / cfg.model_ckpt)
 
     # Save the model's performance
-    score_df = pd.Series({"acc": score})
+    score_df = pd.Series(scores)
     score_df.to_csv(hydra_output_dir / cfg.scores_filename, header=["value"])
 
-    logger.info(f"Logging model and its score ({score:.2%}) to {hydra_output_dir}")
+    logger.info(f"Logging model and its scores: {scores} to {hydra_output_dir}")
 
 
 if __name__ == "__main__":
