Add return_attention_scores support to CachedMultiHeadAttention #2213


Open
wants to merge 2 commits into master

Conversation

DakshBegani

This PR addresses #2055, where attention_scores was always None due to the _return_attention_scores flag not being set in the CachedMultiHeadAttention subclass.

In recent Keras versions, the base MultiHeadAttention layer uses a private flag self._return_attention_scores to decide whether or not to return attention scores from _compute_attention.

However, CachedMultiHeadAttention was not passing or setting this flag at all, which meant attention_scores were silently dropped — making them inaccessible for debugging or analysis.

In this PR we did the following (a minimal sketch follows the list):
1. Added return_attention_scores as an optional constructor argument (default False, just like in the base MHA layer).
2. Set self._return_attention_scores accordingly.
3. Updated the call() method to return attention_scores alongside attention_output and the cache when requested, fully preserving the existing behavior otherwise.
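
A minimal sketch of that wiring, assuming a plain keras.layers.MultiHeadAttention subclass (hypothetical class name; the actual CachedMultiHeadAttention additionally threads a key/value cache through call()):

import keras
from keras import layers


class ScoreReturningMHA(layers.MultiHeadAttention):
    # Hypothetical sketch of the flag wiring described above, not the actual
    # keras-hub implementation (which also handles a cache argument).

    def __init__(self, num_heads, key_dim, return_attention_scores=False, **kwargs):
        super().__init__(num_heads=num_heads, key_dim=key_dim, **kwargs)
        # Constructor-level default for whether call() should return scores.
        self._return_attention_scores = return_attention_scores

    def call(self, query, value, key=None, attention_mask=None, training=None):
        # Delegate to the base layer, forwarding the flag as its call() argument;
        # when True, the base layer returns (attention_output, attention_scores).
        return super().call(
            query,
            value,
            key=key,
            attention_mask=attention_mask,
            training=training,
            return_attention_scores=self._return_attention_scores,
        )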

@divyashreepathihalli added the kokoro:force-run (Runs Tests on GPU) label Apr 15, 2025
@kokoro-team removed the kokoro:force-run (Runs Tests on GPU) label Apr 15, 2025
# Call the parent class constructor
super().__init__(num_heads, key_dim, **kwargs)
# New flag to optionally return attention scores
self._return_attention_scores = return_attention_scores
Collaborator

@divyashreepathihalli Apr 15, 2025


This is a call arg in the super class, here: https://github.com/keras-team/keras/blob/44a655bdb28037046ab279a49d4cd679fea7ca50/keras/src/layers/attention/multi_head_attention.py#L523

Also, if flash attention is used (via ops.dot_product_attention), attention scores will not be returned.
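
For reference, this is how the base layer exposes the call-time argument (hypothetical shapes; when flash attention is active the scores are not materialized, so they cannot be returned):

import numpy as np
import keras

mha = keras.layers.MultiHeadAttention(num_heads=2, key_dim=16)
x = np.random.rand(1, 8, 32).astype("float32")

# return_attention_scores is an argument of call(), not of the constructor;
# when True, the layer returns a (output, scores) tuple.
output, scores = mha(x, x, return_attention_scores=True)
print(output.shape)  # (1, 8, 32)
print(scores.shape)  # (1, 2, 8, 8): (batch, num_heads, query_len, key_len)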

Author


this is a call arg in the super class here

Makes sense now. Instead of manually setting the flag, it made sense to just pass return_attention_scores into super().__init__(), since the base MHA layer handles it internally.

I’ve pushed the fix with that change; let me know if further changes are needed.
