Add return_attention_scores support to CachedMultiHeadAttention #2213
This PR addresses #2055, where `attention_scores` was always `None` because the `_return_attention_scores` flag was never set in the `CachedMultiHeadAttention` subclass.

In recent Keras versions, the base `MultiHeadAttention` layer uses a private flag, `self._return_attention_scores`, to decide whether `_compute_attention` returns attention scores.

However, `CachedMultiHeadAttention` was neither passing nor setting this flag, so the attention scores were silently dropped, making them inaccessible for debugging or analysis.
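For contrast, the scores are already reachable through the base layer's public API. A minimal standalone snippet using plain `keras.layers.MultiHeadAttention` (not the cached subclass), with made-up shapes purely for illustration:

```python
import numpy as np
import keras

mha = keras.layers.MultiHeadAttention(num_heads=2, key_dim=8)
query = np.random.rand(1, 4, 16).astype("float32")
value = np.random.rand(1, 6, 16).astype("float32")

# The base layer returns (output, scores) when asked; in recent Keras versions
# this request is recorded on the private `_return_attention_scores` flag that
# `_compute_attention` checks, as described above.
output, scores = mha(query, value, return_attention_scores=True)
print(output.shape)  # (1, 4, 16)
print(scores.shape)  # (1, 2, 4, 6) -> (batch, heads, query_len, value_len)
```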
This PR does the following:

1. Adds `return_attention_scores` as an optional constructor argument (default `False`, as in the base MHA).
2. Sets `self._return_attention_scores` from that argument.
3. Updates the `call()` method to return `attention_scores` alongside `attention_output` and `cache` when requested, fully preserving existing behavior otherwise (see the sketch after this list).
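For reference, here is a minimal sketch of what the patched layer could look like. It assumes the call signature and private helpers (`_query_dense`, `_key_dense`, `_value_dense`, `_output_dense`, `_compute_attention`) used by the upstream KerasHub implementation, and it illustrates the three changes above rather than reproducing the exact diff:

```python
import keras
from keras import ops


class CachedMultiHeadAttention(keras.layers.MultiHeadAttention):
    """Sketch of the patched layer; mirrors the upstream KerasHub class."""

    def __init__(self, *args, return_attention_scores=False, **kwargs):
        super().__init__(*args, **kwargs)
        # (1) + (2): accept the flag in the constructor and store it on the
        # private attribute that the base `_compute_attention` consults.
        self._return_attention_scores = return_attention_scores

    def call(
        self,
        query,
        value,
        key=None,
        attention_mask=None,
        cache=None,
        cache_update_index=None,
        training=None,
    ):
        if key is None:
            key = value
        query = self._query_dense(query)
        if cache is not None:
            key_cache = cache[:, 0, ...]
            value_cache = cache[:, 1, ...]
            if cache_update_index is None:
                # Reuse the cached projections as-is.
                key = key_cache
                value = value_cache
            else:
                # Project the new tokens and splice them into the cache.
                key_update = self._key_dense(key)
                value_update = self._value_dense(value)
                start = [0, cache_update_index, 0, 0]
                key = ops.slice_update(key_cache, start, key_update)
                value = ops.slice_update(value_cache, start, value_update)
                cache = ops.stack((key, value), axis=1)
        else:
            key = self._key_dense(key)
            value = self._value_dense(value)
        attention_output, attention_scores = self._compute_attention(
            query, key, value, attention_mask=attention_mask, training=training
        )
        attention_output = self._output_dense(attention_output)
        # (3): surface the scores only when the caller opted in, so the
        # existing two-value return is preserved by default.
        if self._return_attention_scores:
            return attention_output, attention_scores, cache
        return attention_output, cache
```

Keeping the default at `False` (and returning only `(attention_output, cache)` in that case) means existing call sites that unpack a two-tuple keep working unchanged; only callers that opt in at construction time see the three-value return.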