
Add support for (vision) transformers, add options to set last-layer relevance #15

Draft
wants to merge 26 commits into main

Conversation

Maximilian-Stefan-Ernst
Contributor

Vision Transformers

Add explanations for (vision) transformer (ViT) models via the new package extension VisionTransformerExt, which depends on Metalhead.jl.

Adds the rules

  • SelfAttentionRule for MultiHeadSelfAttention layers
  • PositionalEmbeddingRule for ViPosEmbedding layers

Also adds support for some special layers of vision transformers by adding a ZeroRule method for each of them:

  • _flatten_spatial is a reshaping layer near the input
  • ClassTokens adds a class token to the model
  • SelectClassToken only retains the class token for the model prediction (I added this layer because Metalhead uses an anonymous function for this purpose, which we have to swap out for a "real" layer before explaining the model)

So far, no support was added for Flux.jl's built-in MultiHeadAttention layer, because it does not work nicely with Chains; see the sketch below. (Metalhead's MultiHeadSelfAttention layer is not limited to vision transformers, but can also be used to build "regular" transformer models as long as they use only self-attention, e.g. encoder-only models; something like BERT should be doable.)
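
Not part of this PR, but a minimal sketch of why Flux's built-in layer is awkward here, assuming current Flux semantics (the layer returns both the output and the attention scores):

```julia
using Flux

# Flux's built-in MultiHeadAttention returns a tuple (output, attention_scores),
# so it cannot be placed inside a Chain as-is:
mha = MultiHeadAttention(64; nheads = 4)
x = rand(Float32, 64, 16, 2)      # (embedding_dim, sequence_length, batch)
y, α = mha(x)                     # self-attention call returns two values

# A Chain expects each layer to map array -> array, so a wrapper would be needed:
Chain(xs -> first(mha(xs)))(x)    # works, but the wrapper is an anonymous function again
```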

In addition, the function prepare_vit can be used to prepare Metalhead's ViT (convert it to a Chain, add SelectClassToken layer).
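
A rough usage sketch, assuming the analyzer API of the Julia-XAI ecosystem (LRP, analyze) and that prepare_vit takes the Metalhead model as its only argument; the ViT configuration name is just an example:

```julia
using Flux, Metalhead, RelevancePropagation   # loading Metalhead activates the extension

vit = Metalhead.ViT(:tiny)            # Metalhead's vision transformer
model = prepare_vit(vit)              # convert to a Chain and insert the SelectClassToken layer

analyzer = LRP(model)
input = rand(Float32, 224, 224, 3, 1) # WHCN image batch
expl = analyze(input, analyzer)       # Explanation holding the input-space relevances
```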

Last layer relevance

Adds the keyword arguments normalize_output=true and R=nothing to LRP. If R is supplied, the relevances in the last layer are set to R. If normalize_output is false, the target neuron activation is not set to one, but remains the "raw" activation from the forward pass.
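
A hedged example of how the new options might look in use, assuming both keywords are accepted by the LRP constructor as described above (variable names are placeholders):

```julia
# Keep the "raw" target activation instead of normalizing it to one:
analyzer = LRP(model; normalize_output = false)

# Or set the last-layer relevances explicitly; R should match the model's output shape:
R = zeros(Float32, size(model(input)))
R[5, 1] = 1.0f0                    # e.g. put all relevance on output neuron 5 of sample 1
analyzer = LRP(model; R = R)
```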

Canonization

This PR already contains the changes from PR #14, since canonization of ViT models does not work properly without them; PR #14 should therefore be merged first.

ToDo

  • add tests
  • add docs

For documentation, I guess it would be nice to have an extra "Extensions" section in the docs, and have a small tutorial under "Extensions => Vision Transformer".


codecov bot commented Mar 16, 2024

Codecov Report

Attention: Patch coverage is 0% with 61 lines in your changes missing coverage. Please review.

Project coverage is 0.00%. Comparing base (7b2af98) to head (79b2d40).

Files                                          Patch %   Lines
ext/RelevancePropagationMetalheadExt/rules.jl  0.00%     51 Missing ⚠️
ext/RelevancePropagationMetalheadExt/utils.jl  0.00%     9 Missing ⚠️
src/rules.jl                                   0.00%     1 Missing ⚠️
Additional details and impacted files
@@            Coverage Diff             @@
##             main     #15       +/-   ##
==========================================
- Coverage   96.66%   0.00%   -96.67%     
==========================================
  Files          14      15        +1     
  Lines         660     698       +38     
==========================================
- Hits          638       0      -638     
- Misses         22     698      +676     


@Maximilian-Stefan-Ernst
Contributor Author

Maybe move prepare_vit to canonize

Resolved review threads (outdated): Project.toml, ext/VisionTransformerExt/VisionTransformerExt.jl, ext/VisionTransformerExt/utils.jl, src/canonize.jl, src/extensions.jl
Comment on lines 10 to 11
struct SelectClassToken end
Flux.@functor SelectClassToken
Member

XAIBase exports generic feature selectors.
Maybe these could be used here and extended for transformers?

Contributor Author

We can do that, but I don't get how these feature selectors are supposed to be used in a model / why there are no rules for them?

Contributor Author

Okay as we discussed, it does not really make sense to use the feature selectors. I think the remaining question is where you want to define new layers in the codebase - maybe an extra file src/layers.jl?

Member

> Okay as we discussed, it does not really make sense to use the feature selectors.

Sorry, it's been a while... Can you remind me what the exact issue was? 😅
I can vaguely remember it was something that should go in XAIBase.jl.

Similar to this: https://github.com/Julia-XAI/XAIBase.jl/blob/main/src/feature_selection.jl

Contributor Author

Yes, so vision transformers have a special token that is selected near the output, and all other tokens are discarded. Metalhead implements this through an anonymous function, so we can't use it for computing LRP. What I did was implement this simple Flux layer (and an associated rule) that is swapped in for the anonymous function. The problem with the feature selector is that we need an actual layer, so I think we decided that this is probably not the right place ^^
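
For context, a minimal sketch of what the swapped-in layer could look like; the struct and @functor lines are quoted from the diff above, while the forward pass and its indexing convention (class token stored first along the token dimension of an (embedding, tokens, batch) array) are an assumption:

```julia
struct SelectClassToken end
Flux.@functor SelectClassToken

# Keep only the class token, assumed to sit at index 1 of the token dimension:
(::SelectClassToken)(x::AbstractArray{<:Real,3}) = x[:, 1, :]
```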

Resolved review threads (outdated): src/extensions.jl, src/lrp.jl, test/test_canonize.jl