import itertools
import logging
from typing import Sequence

import pandas as pd
import torch
from scipy.special import softmax
from sklearn.metrics import roc_auc_score
from tqdm.auto import tqdm
from vital.data.cardinal.config import CardinalTag, TabularAttribute
from vital.data.cardinal.datapipes import process_patient
from vital.data.cardinal.utils.itertools import Patients
from vital.utils.format.torch import numpy_to_torch

from didactic.models.explain import SelfAttentionGenerator
from didactic.tasks.cardiac_multimodal_representation import CardiacMultimodalRepresentationTask

logger = logging.getLogger(__name__)


def _prepare_model_inputs(model: CardiacMultimodalRepresentationTask, patient_data) -> tuple:
    """Splits one processed patient's data into the tabular and time-series inputs of the model.

    Args:
        model: Model whose hyperparameters (``tabular_attrs``, ``views``, ``time_series_attrs``) list the attributes
            to extract.
        patient_data: Data of one patient, as returned by ``process_patient``.

    Returns:
        Tuple of (tabular attributes, time-series attributes), each converted by ``numpy_to_torch`` after a leading
        batch dimension was added to every attribute.
    """
    hparams = model.hparams
    tab_inputs = numpy_to_torch({attr: patient_data[attr][None, ...] for attr in hparams.tabular_attrs})
    time_series_inputs = numpy_to_torch(
        {
            (view, attr): patient_data[view][attr][None, ...]
            for view in hparams.views
            for attr in hparams.time_series_attrs
        }
    )
    return tab_inputs, time_series_inputs


def compute_attributes_relevance(
    model: CardiacMultimodalRepresentationTask,
    patients: Patients,
    relevancy_target: TabularAttribute,
    mask_tag: str = CardinalTag.mask,
    progress_bar: bool = False,
) -> pd.DataFrame:
    """Computes the relevance of each input attribute, for each patient.

    Args:
        model: Transformer encoder model to use for inference.
        patients: Patients on which to compute the relevance.
        relevancy_target: Target attribute w.r.t. which to compute the relevancy.
        mask_tag: Tag of the segmentation mask for which to extract the time-series attributes.
        progress_bar: If ``True``, enables progress bars detailing the progress of the processing and encoding patients
            data.

    Returns:
        A dataframe containing the relevance of each attribute (rows), for each patient (columns). No averaging over
        patients is performed here.
    """
    tab_attrs, time_series_attrs = model.hparams.tabular_attrs, model.hparams.time_series_attrs
    patients_data = (
        process_patient(
            patient,
            # Unpack rather than concatenate, so that this works whether `tab_attrs` is a list or a tuple
            tabular_attrs=[*tab_attrs, relevancy_target],
            time_series_attrs=time_series_attrs,
            mask_tag=mask_tag,
        )
        for patient in patients.values()
    )

    relevancy_gen = SelfAttentionGenerator(model)
    attr_tags = model.token_tags[:-1]  # Exclude the CLS token

    msg = "Collecting attention maps from patients"
    if progress_bar:
        patients_data = tqdm(patients_data, desc=msg, unit="patient", total=len(patients))
    else:
        logger.info(msg + "...")

    relevancy_data = {}
    for patient_data in patients_data:
        patient_tab_attrs, patient_time_series_attrs = _prepare_model_inputs(model, patient_data)
        target_label = torch.from_numpy(patient_data[relevancy_target][None, ...])

        # Only one patient is processed at a time, so keep the first (and only) item of the batch
        relevancy_by_attr = (
            relevancy_gen.generate_relevancy(
                patient_tab_attrs, patient_time_series_attrs, target_label, relevancy_target
            )[0]
            .cpu()
            .numpy()
        )

        relevancy_data[patient_data["id"]] = dict(zip(attr_tags, relevancy_by_attr))

    return pd.DataFrame(relevancy_data)


def run_perturbation_test(
    model: CardiacMultimodalRepresentationTask,
    patients: Patients,
    target: TabularAttribute,
    attrs_to_remove: Sequence[TabularAttribute],
    mask_tag: str = CardinalTag.mask,
    progress_bar: bool = False,
) -> pd.Series:
    """Performs a perturbation test, computing the model's AUROC score as more and more attributes are removed.

    Args:
        model: Transformer encoder model to use for inference.
        patients: Patients on which to measure the model's AUROC score.
        target: Target attribute w.r.t. which to compute the model's AUROC score.
        attrs_to_remove: Attributes in the order in which to gradually remove them.
        mask_tag: Tag of the segmentation mask for which to extract the time-series attributes.
        progress_bar: If ``True``, enables progress bars detailing the progress of how many attributes are left to
            perturb.

    Returns:
        A series containing the AUROC score of the model for each further attribute removed. Each entry is indexed by
        the last attribute that was removed, with the first entry (no attribute removed) indexed by ``"none"``.
    """
    tab_attrs, time_series_attrs = model.hparams.tabular_attrs, model.hparams.time_series_attrs

    # Successive prefixes of `attrs_to_remove`: [], [a1], [a1, a2], ...
    cumulated_attrs_to_remove = itertools.accumulate(
        attrs_to_remove, lambda removed_attrs, next_attr: removed_attrs + [next_attr], initial=[]
    )
    msg = "Removing gradually more attributes"
    if progress_bar:
        cumulated_attrs_to_remove = tqdm(
            cumulated_attrs_to_remove, desc=msg, unit="attr", total=len(attrs_to_remove) + 1
        )
    else:
        logger.info(msg + "...")

    attrs_perturbation_scores = {}
    for cur_attrs_to_remove in cumulated_attrs_to_remove:
        patients_data = (
            process_patient(
                patient,
                # Unpack rather than concatenate, so that this works whether `tab_attrs` is a list or a tuple
                tabular_attrs=[*tab_attrs, target],
                time_series_attrs=time_series_attrs,
                mask_tag=mask_tag,
                mask_attrs=cur_attrs_to_remove,
            )
            for patient in patients.values()
        )

        pred_logits = {}
        target_labels = {}
        # Pure inference: no gradients are needed, so avoid building autograd graphs
        with torch.no_grad():
            for patient_data in patients_data:
                patient_tab_attrs, patient_time_series_attrs = _prepare_model_inputs(model, patient_data)

                # Collect the predicted logits for each class of the target attribute
                pred_logits[patient_data["id"]] = (
                    model(patient_tab_attrs, patient_time_series_attrs, task="predict")[target][0]
                    .detach()
                    .cpu()
                    .numpy()
                )

                target_labels[patient_data["id"]] = patient_data[target].item()

        # Compute the ROC AUC score for the target attribute
        pred_logits = pd.DataFrame.from_dict(pred_logits, orient="index")
        target_labels = pd.Series(target_labels)
        pred_probs = softmax(pred_logits.to_numpy(), axis=1)
        if pred_probs.shape[1] == 2:
            # For binary targets, scikit-learn expects the 1D scores of the class with the greater label,
            # and rejects a (n_samples, 2) array
            pred_probs = pred_probs[:, 1]

        score_tag = cur_attrs_to_remove[-1] if cur_attrs_to_remove else "none"
        attrs_perturbation_scores[score_tag] = roc_auc_score(target_labels, pred_probs, multi_class="ovr")

    return pd.Series(attrs_perturbation_scores)


def main():
    """Run the script.

    Computes the relevance of each attribute w.r.t. a target, runs the negative perturbation test by removing the
    attributes from the least to the most relevant, and saves both results as CSV files in the output directory.
    """
    import argparse
    from pathlib import Path

    from vital.utils.logging import configure_logging
    from vital.utils.saving import load_from_checkpoint

    configure_logging(log_to_console=True, console_level=logging.INFO)

    parser = argparse.ArgumentParser()
    parser.add_argument(
        "pretrained_encoder",
        type=Path,
        help="Path to a model checkpoint, or name of a model from a Comet model registry, of an encoder",
    )
    parser = Patients.add_args(parser)
    parser.add_argument(
        "--relevancy_target",
        type=TabularAttribute,
        default=TabularAttribute.ht_severity,
        help="Target attribute w.r.t. which to compute the relevancy computation",
    )
    parser.add_argument(
        "--mask_tag",
        type=str,
        default=CardinalTag.mask,
        help="Tag of the segmentation mask for which to extract the time-series attributes",
    )
    parser.add_argument("--output_dir", type=Path, default=Path.cwd(), help="Directory to save the output files")
    args = parser.parse_args()
    kwargs = vars(args)

    # Pop the script-specific arguments, so that the remaining ones can be forwarded as-is to `Patients`
    encoder_ckpt, relevancy_target, mask_tag, output_dir = (
        kwargs.pop("pretrained_encoder"),
        kwargs.pop("relevancy_target"),
        kwargs.pop("mask_tag"),
        kwargs.pop("output_dir"),
    )

    encoder = load_from_checkpoint(encoder_ckpt, expected_checkpoint_type=CardiacMultimodalRepresentationTask)
    patients = Patients(**kwargs)

    # Compute the relevancy for all attributes and all patients
    relevancy_df = compute_attributes_relevance(
        encoder, patients, relevancy_target, mask_tag=mask_tag, progress_bar=True
    )
    avg_relevancy = relevancy_df.mean(axis="columns")  # Average the relevancy over all patients
    attrs_sorted_by_relevance = avg_relevancy.sort_values()  # Ascending, i.e. least relevant attributes first

    # Run the negative perturbation test
    attrs_perturbation_scores = run_perturbation_test(
        encoder,
        patients,
        relevancy_target,
        list(attrs_sorted_by_relevance.index),
        mask_tag=mask_tag,
        progress_bar=True,
    )

    # Save the results to disk
    # NOTE: Only the CSV files are exported; no plot of the perturbation curve is generated
    output_dir.mkdir(parents=True, exist_ok=True)
    attrs_sorted_by_relevance.to_csv(output_dir / "attributes_relevance.csv")
    attrs_perturbation_scores.to_csv(output_dir / "negative_perturbation_test.csv")


if __name__ == "__main__":
    main()