WIP: Implement negative perturbation tests guided by attributes' relevancy
nathanpainchaud committed Jul 18, 2024
1 parent dafed1e commit 15e3c88
Showing 1 changed file with 233 additions and 0 deletions.
233 changes: 233 additions & 0 deletions didactic/apps/negative_perturbation_test.py
@@ -0,0 +1,233 @@
import itertools
import logging
from typing import Sequence

import pandas as pd
import torch
from scipy.special import softmax
from sklearn.metrics import roc_auc_score
from tqdm.auto import tqdm
from vital.data.cardinal.config import CardinalTag, TabularAttribute
from vital.data.cardinal.datapipes import process_patient
from vital.data.cardinal.utils.itertools import Patients
from vital.utils.format.torch import numpy_to_torch

from didactic.models.explain import SelfAttentionGenerator
from didactic.tasks.cardiac_multimodal_representation import CardiacMultimodalRepresentationTask

logger = logging.getLogger(__name__)


def compute_attributes_relevance(
model: CardiacMultimodalRepresentationTask,
patients: Patients,
relevancy_target: TabularAttribute,
mask_tag: str = CardinalTag.mask,
progress_bar: bool = False,
) -> pd.DataFrame:
"""Computes the relevance of each input attribute, averaged over all patients.
Args:
model: Transformer encoder model to use for inference.
patients: Patients on which to compute the relevance.
relevancy_target: Target attribute w.r.t. which to compute the relevancy computation.
mask_tag: Tag of the segmentation mask for which to extract the time-series attributes.
progress_bar: If ``True``, enables progress bars detailing the progress of the processing and encoding patients
data.
Returns:
A dataframe containing the relevance of each attribute, for each patient.
"""
tab_attrs, time_series_attrs = model.hparams.tabular_attrs, model.hparams.time_series_attrs
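    # Lazily process each patient's data; the relevancy target is added to the tabular attributes so that its label
    # is available when generating the relevancy maps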
patients_data = (
process_patient(
patient,
tabular_attrs=tab_attrs + (relevancy_target,),
time_series_attrs=time_series_attrs,
mask_tag=mask_tag,
)
for patient in patients.values()
)

relevancy_gen = SelfAttentionGenerator(model)
attr_tags = model.token_tags[:-1] # Exclude the CLS token

msg = "Collecting attention maps from patients"
if progress_bar:
patients_data = tqdm(patients_data, desc=msg, unit="patient", total=len(patients))
else:
logger.info(msg + "...")

relevancy_data = {}
for patient_data in patients_data:
# Separate the tabular and time-series attributes + add a batch dimension
patient_tab_attrs = numpy_to_torch({attr: patient_data[attr][None, ...] for attr in tab_attrs})
patient_time_series_attrs = numpy_to_torch(
{
(view, attr): patient_data[view][attr][None, ...]
for view in model.hparams.views
for attr in time_series_attrs
}
)
target_label = torch.from_numpy(patient_data[relevancy_target][None, ...])

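        # Generate the relevancy score of each input token w.r.t. the model's prediction of the target attribute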
relevancy_by_attr = (
relevancy_gen.generate_relevancy(
patient_tab_attrs, patient_time_series_attrs, target_label, relevancy_target
)[0]
.cpu()
.numpy()
)

relevancy_data[patient_data["id"]] = {
attr: attr_relevancy for attr, attr_relevancy in zip(attr_tags, relevancy_by_attr)
}

return pd.DataFrame(relevancy_data)


def run_perturbation_test(
model: CardiacMultimodalRepresentationTask,
patients: Patients,
target: TabularAttribute,
attrs_to_remove: Sequence[TabularAttribute],
mask_tag: str = CardinalTag.mask,
progress_bar: bool = False,
) -> pd.Series:
"""Performs perturbation tests and computes model's AUROC score the more attributes are removed.
Args:
model: Transformer encoder model to use for inference.
patients: Patients on which to compute measure the model's AUROC score.
target: Target attribute w.r.t. which to compute the model's AUROC score.
attrs_to_remove: Attributes in the order in which to gradually remove them.
mask_tag: Tag of the segmentation mask for which to extract the time-series attributes.
progress_bar: If ``True``, enables progress bars detailing the progress of how many attributes are left to
perturb.
Returns:
A series containing the AUROC score of the model for each further attribute removed.
"""
tab_attrs, time_series_attrs = model.hparams.tabular_attrs, model.hparams.time_series_attrs

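    # Build the cumulative sequence of attributes to remove: [], [attr_1], [attr_1, attr_2], ..., up to all attributes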
cumulated_attrs_to_remove = itertools.accumulate(
attrs_to_remove, lambda removed_attrs, next_attr: removed_attrs + [next_attr], initial=[]
)
msg = "Removing gradually more attributes"
if progress_bar:
cumulated_attrs_to_remove = tqdm(
cumulated_attrs_to_remove, desc=msg, unit="attr", total=len(attrs_to_remove) + 1
)
else:
logger.info(msg + "...")

attrs_perturbation_scores = {}
for cur_attrs_to_remove in cumulated_attrs_to_remove:
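        # Re-process the patients' data at each step, masking out the attributes removed so far from the model's input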
patients_data = (
process_patient(
patient,
                tabular_attrs=(*tab_attrs, target),
time_series_attrs=time_series_attrs,
mask_tag=mask_tag,
mask_attrs=cur_attrs_to_remove,
)
for patient in patients.values()
)

pred_logits = {}
target_labels = {}
for patient_data in patients_data:
# Separate the tabular and time-series attributes + add a batch dimension
patient_tab_attrs = numpy_to_torch({attr: patient_data[attr][None, ...] for attr in tab_attrs})
patient_time_series_attrs = numpy_to_torch(
{
(view, attr): patient_data[view][attr][None, ...]
for view in model.hparams.views
for attr in time_series_attrs
}
)

# Compute the predicted probabilities for each class for the target attribute
pred_logits[patient_data["id"]] = (
model(patient_tab_attrs, patient_time_series_attrs, task="predict")[target][0].detach().cpu().numpy()
)

target_labels[patient_data["id"]] = patient_data[target].item()

# Compute the ROC AUC score for the target attribute
pred_logits = pd.DataFrame.from_dict(pred_logits, orient="index")
target_labels = pd.Series(target_labels)
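        # Convert the predicted logits to probabilities and score them using a one-vs-rest AUROC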
attrs_perturbation_scores[cur_attrs_to_remove[-1] if cur_attrs_to_remove else "none"] = roc_auc_score(
target_labels, softmax(pred_logits.to_numpy(), axis=1), multi_class="ovr"
)

return pd.Series(attrs_perturbation_scores)


def main():
"""Run the script."""
import argparse
from pathlib import Path

from vital.utils.logging import configure_logging
from vital.utils.saving import load_from_checkpoint

configure_logging(log_to_console=True, console_level=logging.INFO)

parser = argparse.ArgumentParser()
parser.add_argument(
"pretrained_encoder",
type=Path,
help="Path to a model checkpoint, or name of a model from a Comet model registry, of an encoder",
)
parser = Patients.add_args(parser)
parser.add_argument(
"--relevancy_target",
type=TabularAttribute,
default=TabularAttribute.ht_severity,
help="Target attribute w.r.t. which to compute the relevancy computation",
)
parser.add_argument(
"--mask_tag",
type=str,
default=CardinalTag.mask,
help="Tag of the segmentation mask for which to extract the time-series attributes",
)
parser.add_argument("--output_dir", type=Path, default=Path.cwd(), help="Directory to save the output files")
args = parser.parse_args()
kwargs = vars(args)

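    # Extract the arguments used directly by the script; the remaining arguments are forwarded to `Patients`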
encoder_ckpt, relevancy_target, mask_tag, output_dir = (
kwargs.pop("pretrained_encoder"),
kwargs.pop("relevancy_target"),
kwargs.pop("mask_tag"),
kwargs.pop("output_dir"),
)

encoder = load_from_checkpoint(encoder_ckpt, expected_checkpoint_type=CardiacMultimodalRepresentationTask)
patients = Patients(**kwargs)

# Compute the relevancy for all attributes and all patients
relevancy_df = compute_attributes_relevance(
encoder, patients, relevancy_target, mask_tag=mask_tag, progress_bar=True
)
relevancy_df = relevancy_df.mean(axis="columns") # Average the relevancy over all patients
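    # Sort the attributes by increasing relevancy, so that the least relevant attributes are the first to be removed
    # in the negative perturbation test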
attrs_sorted_by_relevance = relevancy_df.sort_values()

# Run the negative perturbation test
attrs_perturbation_scores = run_perturbation_test(
encoder, patients, relevancy_target, attrs_sorted_by_relevance.index, mask_tag=mask_tag, progress_bar=True
)

# Save the results to disk
output_dir.mkdir(parents=True, exist_ok=True)
attrs_sorted_by_relevance.to_csv(output_dir / "attributes_relevance.csv")
attrs_perturbation_scores.to_csv(output_dir / "negative_perturbation_test.csv")
    # with sns.axes_style("darkgrid"):
    #     x = np.linspace(0, 100, len(attrs_perturbation_scores))
    #     plot = sns.lineplot(x=x, y=attrs_perturbation_scores.values)
    #     plot.get_figure().savefig(output_dir / "negative_perturbation_test.png", bbox_inches="tight")


if __name__ == "__main__":
main()
