Visualizing attention map #19

Open
ParnianA opened this issue Jun 10, 2021 · 4 comments

Comments

@ParnianA

Hi. Does anyone know how we can access the attention maps?

@tolaut

tolaut commented Jun 21, 2021

I'm trying to figure out the same thing.

@gouttham

gouttham commented Aug 25, 2021

Using the code below, I was able to visualize the attention maps.

Step 1:
In transformer.py, under class MultiHeadedSelfAttention(nn.Module), replace the forward method with the code below:

def forward(self, x, mask):
    """
    x, q(query), k(key), v(value) : (B(batch_size), S(seq_len), D(dim))
    mask : (B(batch_size) x S(seq_len))
    * split D(dim) into (H(n_heads), W(width of head)) ; D = H * W
    """
    # (B, S, D) -proj-> (B, S, D) -split-> (B, S, H, W) -trans-> (B, H, S, W)
    q, k, v = self.proj_q(x), self.proj_k(x), self.proj_v(x)
    q, k, v = (split_last(x, (self.n_heads, -1)).transpose(1, 2) for x in [q, k, v])
    # (B, H, S, W) @ (B, H, W, S) -> (B, H, S, S) -softmax-> (B, H, S, S)
    scores = q @ k.transpose(-2, -1) / np.sqrt(k.size(-1))
    if mask is not None:
        mask = mask[:, None, None, :].float()
        scores -= 10000.0 * (1.0 - mask)
    scores = self.drop(F.softmax(scores, dim=-1))
    # (B, H, S, S) @ (B, H, S, W) -> (B, H, S, W) -trans-> (B, S, H, W)
    h = (scores @ v).transpose(1, 2).contiguous()
    # -merge-> (B, S, D)
    h = merge_last(h, 2)
    self.scores = scores  # cache the attention weights so they can be read after the forward pass
    return h
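
The only functional change here is the self.scores = scores line, which caches the per-head attention weights (shape B, H, S, S) on the module. Even without Steps 2 and 3 you can read them back after any forward pass; a minimal sketch, assuming the model.transformer.blocks[i].attn attribute path this repo uses:

# after calling model(input) once:
scores = model.transformer.blocks[0].attn.scores  # (B, H, S, S) attention of block 0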

Step 2:
In transformer.py, under class Transformer(nn.Module), replace the forward method with the code below:

def forward(self, x, mask=None):
    atten_scores = []
    for block in self.blocks:
        x = block(x, mask)
        atten_scores.append(block.attn.scores)
    return x, atten_scores
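
Note that atten_scores is a plain Python list holding one (B, H, S, S) tensor per block, and Transformer.forward now returns a tuple, so any other code that calls it must be updated to unpack two values.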

Step 3:
In model.py, under class ViT(nn.Module), replace the forward method with the code below:

def forward(self, x):
    b, c, fh, fw = x.shape
    x = self.patch_embedding(x)  # b,d,gh,gw
    x = x.flatten(2).transpose(1, 2)  # b,gh*gw,d
    if hasattr(self, 'class_token'):
        x = torch.cat((self.class_token.expand(b, -1, -1), x), dim=1)  # b,gh*gw+1,d
    if hasattr(self, 'positional_embedding'): 
        x = self.positional_embedding(x)  # b,gh*gw+1,d 
    x, atten_scores = self.transformer(x)  # b,gh*gw+1,d
    att_mat = torch.stack(atten_scores).squeeze(1)  # (L, B, H, S, S) -> (L, H, S, S)
    att_mat = torch.mean(att_mat, dim=1)  # average over heads -> (L, S, S)
    if hasattr(self, 'pre_logits'):
        x = self.pre_logits(x)
        x = torch.tanh(x)
    if hasattr(self, 'fc'):
        x = self.norm(x)[:, 0]  # b,d
        x = self.fc(x)  # b,num_classes
    return x, att_mat
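
The squeeze(1) above assumes a batch size of 1, which is why the returned att_mat has shape (L, S, S): one head-averaged S x S attention matrix per layer.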

Step 4:
Now the forward pass will return both the classifier output and the attention maps.
x, atten_weights = model(input_image.unsqueeze(0))
Here atten_weights contains the head-averaged attention matrices of all layers, stacked into a single (L, S, S) tensor.
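
For completeness, here is a minimal end-to-end sketch of loading a model and running the modified forward pass. The 'B_16_imagenet1k' checkpoint name, the 384x384 input size, and the 'img.jpg' path are illustrative assumptions; adapt them to your own setup.

import torch
from PIL import Image
from torchvision import transforms
from pytorch_pretrained_vit import ViT  # this repo's package

model = ViT('B_16_imagenet1k', pretrained=True).eval()  # assumed checkpoint name

tfms = transforms.Compose([
    transforms.Resize((384, 384)),               # assumed input size for this checkpoint
    transforms.ToTensor(),
    transforms.Normalize([0.5] * 3, [0.5] * 3),  # assumed normalization
])
img_pth = 'img.jpg'  # hypothetical image path
input_image = tfms(Image.open(img_pth).convert('RGB'))

with torch.no_grad():  # no gradients needed for visualization
    x, atten_weights = model(input_image.unsqueeze(0))  # logits, (L, S, S) attention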

Step 5:
Compute the attention rollout over atten_weights and visualize the result for each layer:

import cv2
import numpy as np
import torch
from PIL import Image
import matplotlib.pyplot as plt

im = Image.open(img_pth)

# Add the identity to account for residual connections, then re-normalize.
# This is done once over the whole (L, S, S) stack.
residual_att = torch.eye(atten_weights.size(1))
aug_att_mat = atten_weights + residual_att
aug_att_mat = aug_att_mat / aug_att_mat.sum(dim=-1).unsqueeze(-1)

# Recursively multiply the per-layer matrices so joint_attentions[n] holds
# the token-to-token attention from the input up to layer n.
joint_attentions = torch.zeros(aug_att_mat.size())
joint_attentions[0] = aug_att_mat[0]
for n in range(1, aug_att_mat.size(0)):
    joint_attentions[n] = torch.matmul(aug_att_mat[n], joint_attentions[n - 1])

grid_size = int(np.sqrt(aug_att_mat.size(-1)))
for v in joint_attentions:
    # Attention from the [CLS] token to every image patch, as a 2D grid
    mask = v[0, 1:].reshape(grid_size, grid_size).detach().numpy()
    mask = cv2.resize(mask / mask.max(), im.size)[..., np.newaxis]
    result = (mask * im).astype("uint8")
    fig, (ax1, ax2) = plt.subplots(ncols=2, figsize=(16, 16))
    ax1.set_title('Original')
    ax2.set_title('Attention Map')
    _ = ax1.imshow(im)
    _ = ax2.imshow(result)
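
The identity matrix plus re-normalization, followed by the recursive matrix product, implements attention rollout (Abnar & Zuidema, "Quantifying Attention Flow in Transformers", 2020), which propagates attention through the residual connections across layers. If you want the raw single-layer attention instead, visualize aug_att_mat[n][0, 1:] directly.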

@kiashann

Could you please share the final code or a Colab demo for extracting the attention maps, @gouttham?

@IJS1016

IJS1016 commented Apr 20, 2022

> Could you please share the final code or a Colab demo for extracting the attention maps, @gouttham?

https://github.com/jeonsworld/ViT-pytorch/blob/main/visualize_attention_map.ipynb
