From dafed1e215f1fc884e3ee1037df0df43e35063b3 Mon Sep 17 00:00:00 2001
From: Nathan Painchaud
Date: Thu, 18 Jul 2024 19:48:48 +0200
Subject: [PATCH] Add impl. of Chefer et al.'s relevancy score for self-attention models

Link to paper: https://arxiv.org/abs/2103.15679
---
 didactic/models/explain.py | 82 +++++++++++++++++++++++++++++++++++++-
 1 file changed, 81 insertions(+), 1 deletion(-)

diff --git a/didactic/models/explain.py b/didactic/models/explain.py
index 4332b2a3..01b29ce0 100644
--- a/didactic/models/explain.py
+++ b/didactic/models/explain.py
@@ -1,9 +1,89 @@
import functools
import itertools
-from typing import Dict, Literal, Optional, Sequence
+from typing import Dict, Literal, Optional, Sequence, Tuple

import torch
from torch import Tensor, nn
from torchmetrics.utilities.data import to_onehot
from vital.data.cardinal.config import TabularAttribute, TimeSeriesAttribute
from vital.data.cardinal.config import View as ViewEnum

from didactic.tasks.cardiac_multimodal_representation import CardiacMultimodalRepresentationTask


class SelfAttentionGenerator:
    """Computes attention w.r.t. input tokens for transformer models, using various attention attribution techniques."""

    def __init__(self, model: CardiacMultimodalRepresentationTask):
        """Initializes class instance.

        Args:
            model: Model for which to generate attention scores.
        """
        self.model = model.eval()

    def generate_relevancy(
        self,
        tabular_attrs: Dict[TabularAttribute, Tensor],
        time_series_attrs: Dict[Tuple[ViewEnum, TimeSeriesAttribute], Tensor],
        target_labels: Tensor,
        target_attr: TabularAttribute,
    ) -> Tensor:
        """Compute the relevancy formulation w.r.t. a target class, as proposed by Chefer et al.

        References:
            - Paper by Chefer et al. proposing the relevancy formulation: https://arxiv.org/abs/2103.15679

        Args:
            tabular_attrs: (K: S, V: N) Sequence of batches of tabular attributes. To indicate an item is missing an
                attribute, the flags `MISSING_NUM_ATTR`/`MISSING_CAT_ATTR` can be used for numerical and categorical
                attributes, respectively.
            time_series_attrs: (K: S, V: (N, ?)), Sequence of batches of time-series attributes, where the
                dimensionality of each attribute can vary.
            target_labels: (N), Target labels for the batch of samples.
            target_attr: Target attribute for which to generate the attention score.

        Returns:
            (N, S-1), Attention scores w.r.t. the target class for each token in the sequence, excluding the CLS token.
        """
        batch_size = len(target_labels)

        # Extract the model's predicted probabilities on the target class
        output = self.model(tabular_attrs, time_series_attrs, task="predict")[target_attr]
        target_labels = to_onehot(target_labels, num_classes=output.shape[-1])
        target_labels = target_labels.to(output.device, dtype=float).requires_grad_(True)
        base_output = torch.sum(output * target_labels)

        # Backpropagate the target class score to compute the gradients w.r.t. the attention maps,
        # accessed below as `attn_grad` on each block's attention module
        self.model.zero_grad()
        base_output.backward(retain_graph=True)

        # Initialize the relevancy score as the identity matrix
        blocks = self.model.encoder.blocks
        num_tokens = blocks[0].attention.attn.shape[-1]
        R = torch.eye(num_tokens, num_tokens).to(blocks[0].attention.attn.device)
        R = R.repeat(batch_size, 1, 1)

        # Update the relevancy score along the attention layers in the model
        for blk in blocks:
            grad = blk.attention.attn_grad  # (N * n_heads, S, S)
            cam = blk.attention.attn  # (N * n_heads, S, S)
            cam = grad * cam  # (N * n_heads, S, S)
            cam = cam.clamp(min=0)  # Remove negative attributions
            # Split the tensor to separate the heads and batch dimensions -> (n_heads, N, S, S),
            # then average across the heads dimension -> (N, S, S)
            cam = torch.stack(cam.split(batch_size, dim=0), dim=0).mean(dim=0)
            R += torch.matmul(cam, R)

        # Normalize the relevancy score per row, following equations 8 and 9 from the paper
        identity_rel = torch.eye(R.shape[-1], device=R.device)
        R_hat = R - identity_rel
        R = (R_hat / R_hat.sum(dim=-1, keepdim=True)) + identity_rel

        # Take the attention of the CLS token w.r.t. other tokens (2nd dim indexing),
        # dropping its attention on itself at the same time (3rd dim indexing)
        cls_per_token_score = R[:, -1, :-1].detach()
        return cls_per_token_score


def _patch_attn(attn_module: nn.MultiheadAttention) -> nn.MultiheadAttention:
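
A minimal usage sketch (illustration only, not part of the patch): assuming a trained
`CardiacMultimodalRepresentationTask` loaded as `model` and a batch already collated into the
`tabular_attrs`/`time_series_attrs` dictionaries the model expects, the relevancy scores added by this patch could
be obtained roughly as follows; all variable names here are hypothetical placeholders.

    # Wrap the trained model and query relevancy w.r.t. the labels of a target tabular attribute
    generator = SelfAttentionGenerator(model)
    relevancy = generator.generate_relevancy(tabular_attrs, time_series_attrs, target_labels, target_attr)
    # `relevancy` has shape (N, S-1): one score per sample and per input token, excluding the CLS token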