Add impl. of Chefer et al.'s relevancy score for self-attention models
nathanpainchaud committed Jul 18, 2024
1 parent 3d34f7c commit dafed1e
Showing 1 changed file with 81 additions and 1 deletion.
didactic/models/explain.py (81 additions & 1 deletion)
@@ -1,9 +1,89 @@
import functools
import itertools
-from typing import Dict, Literal, Optional, Sequence
+from typing import Dict, Literal, Optional, Sequence, Tuple

import torch
from torch import Tensor, nn
from torchmetrics.utilities.data import to_onehot
from vital.data.cardinal.config import TabularAttribute, TimeSeriesAttribute
from vital.data.cardinal.config import View as ViewEnum

from didactic.tasks.cardiac_multimodal_representation import CardiacMultimodalRepresentationTask


class SelfAttentionGenerator:
    """Computes attention w.r.t. input tokens for transformer models, using various attention attribution techniques."""

    def __init__(self, model: CardiacMultimodalRepresentationTask):
        """Initializes class instance.

        Args:
            model: Model for which to generate attention scores.
        """
        self.model = model.eval()

    def generate_relevancy(
        self,
        tabular_attrs: Dict[TabularAttribute, Tensor],
        time_series_attrs: Dict[Tuple[ViewEnum, TimeSeriesAttribute], Tensor],
        target_labels: Tensor,
        target_attr: TabularAttribute,
    ) -> Tensor:
        """Compute relevancy scores w.r.t. a target class, following the formulation proposed by Chefer et al.

        References:
            - Paper by Chefer et al. proposing the relevancy formulation: https://arxiv.org/abs/2103.15679

        Args:
            tabular_attrs: (K: S, V: N) Sequence of batches of tabular attributes. To indicate an item is missing an
                attribute, the flags `MISSING_NUM_ATTR`/`MISSING_CAT_ATTR` can be used for numerical and categorical
                attributes, respectively.
            time_series_attrs: (K: S, V: (N, ?)) Sequence of batches of time-series attributes, where the
                dimensionality of each attribute can vary.
            target_labels: (N), Target labels for the batch of samples.
            target_attr: Target attribute for which to generate the attention scores.

        Returns:
            (N, S-1), Attention scores w.r.t. the target class for each token in the sequence, excluding the CLS token.
        """
        batch_size = len(target_labels)

        # Extract the model's predicted probabilities on the target class
        output = self.model(tabular_attrs, time_series_attrs, task="predict")[target_attr]
        target_labels = to_onehot(target_labels, num_classes=output.shape[-1])
        target_labels = target_labels.to(output.device, dtype=float).requires_grad_(True)
        base_output = torch.sum(output * target_labels)

        # Backpropagate from the target class probabilities, to populate the gradients of the attention maps
        self.model.zero_grad()
        base_output.backward(retain_graph=True)

        # Initialize the relevancy score as the identity matrix, repeated for each sample in the batch
        blocks = self.model.encoder.blocks
        num_tokens = blocks[0].attention.attn.shape[-1]
        R = torch.eye(num_tokens, num_tokens).to(blocks[0].attention.attn.device)
        R = R.repeat(batch_size, 1, 1)
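        # R: (N, S, S) relevancy map per sample; the CLS token is the last token in the sequence,
        # which is why it is indexed with -1 when extracting the final scores below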

        # Update the relevancy score along the attention layers in the model
        for blk in blocks:
            grad = blk.attention.attn_grad  # (N * n_heads, S, S)
            cam = blk.attention.attn  # (N * n_heads, S, S)
            cam = grad * cam  # (N * n_heads, S, S)
            cam = cam.clamp(min=0)  # Remove negative attributions
            # Split the tensor to separate the heads from the batch dimension -> (n_heads, N, S, S),
            # then average across the heads dimension (dim=0) -> (N, S, S)
            cam = torch.stack(cam.split(batch_size, dim=0), dim=0).mean(dim=0)
            R += torch.matmul(cam, R)

        # Normalize the relevancy score per row, following equations 8 and 9 from the paper
        identity_rel = torch.eye(R.shape[-1], device=R.device)
        R_hat = R - identity_rel
        R = (R_hat / R_hat.sum(dim=-1, keepdim=True)) + identity_rel
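        # (i.e. remove the +1 that the identity initialization contributed to each diagonal entry,
        # normalize each row of the remaining relevancies to sum to 1, then add the identity back)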

        # Take the attention of the CLS token w.r.t. other tokens (2nd dim indexing),
        # dropping its attention on itself at the same time (3rd dim indexing)
        cls_per_token_score = R[:, -1, :-1].detach()
        return cls_per_token_score


def _patch_attn(attn_module: nn.MultiheadAttention) -> nn.MultiheadAttention:
    ...  # (body of the function, and the rest of the file, are collapsed in this diff view)
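For context, a minimal sketch of how the new class could be driven, assuming a trained CardiacMultimodalRepresentationTask and a batch shaped as described in the `generate_relevancy` docstring; the variable names (`model`, `tabular_attrs`, `time_series_attrs`, `target_labels`, `target_attr`) are illustrative and not part of this commit:

generator = SelfAttentionGenerator(model)
relevancy = generator.generate_relevancy(
    tabular_attrs=tabular_attrs,          # Dict[TabularAttribute, Tensor], one (N,) batch per attribute
    time_series_attrs=time_series_attrs,  # Dict[(ViewEnum, TimeSeriesAttribute), Tensor], one (N, ?) batch per attribute
    target_labels=target_labels,          # (N,) integer labels of the class to explain
    target_attr=target_attr,              # TabularAttribute whose prediction is being explained
)
# relevancy: (N, S-1) scores of each non-CLS token w.r.t. the target class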

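The relevancy computation reads `blk.attention.attn` and `blk.attention.attn_grad`, which the collapsed `_patch_attn` helper presumably sets up on the model's attention modules. As a purely hypothetical illustration (not the repository's implementation), one generic way to record an attention map and its gradient is to cache the softmaxed weights in the forward pass and register a tensor hook for their gradient:

from torch import Tensor, nn


class AttentionRecorder(nn.Module):
    """Hypothetical module that applies softmax to raw attention scores and records the result."""

    def __init__(self) -> None:
        super().__init__()
        self.attn = None       # Attention map saved during the last forward pass
        self.attn_grad = None  # Gradient of that attention map, saved during backward

    def _save_grad(self, grad: Tensor) -> None:
        self.attn_grad = grad

    def forward(self, attn_scores: Tensor) -> Tensor:
        attn = attn_scores.softmax(dim=-1)
        self.attn = attn
        if attn.requires_grad:
            # Tensor hook, called with the gradient w.r.t. `attn` during the backward pass
            attn.register_hook(self._save_grad)
        return attn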