diff --git a/keras_hub/src/layers/modeling/cached_multi_head_attention.py b/keras_hub/src/layers/modeling/cached_multi_head_attention.py
index 0441e71845..e8d00d6b42 100644
--- a/keras_hub/src/layers/modeling/cached_multi_head_attention.py
+++ b/keras_hub/src/layers/modeling/cached_multi_head_attention.py
@@ -63,7 +63,16 @@ class CachedMultiHeadAttention(keras.layers.MultiHeadAttention):
         projected to the shape specified by `output_shape`. `cache` is the
         updated cache.
     """
 
+    def __init__(
+        self, num_heads, key_dim, return_attention_scores=False, **kwargs
+    ):
+        # The base `keras.layers.MultiHeadAttention` constructor does not
+        # accept `return_attention_scores`, so store the flag on the
+        # instance instead of forwarding it to `super().__init__()`.
+        super().__init__(num_heads, key_dim, **kwargs)
+        self._return_attention_scores = return_attention_scores
+
     def call(
         self,
         query,
@@ -118,6 +127,13 @@ def call(
 
         attention_output = self._output_dense(attention_output)
 
-        if cache is not None:
-            return attention_output, cache
-        return attention_output
+        # Also return the attention scores when the layer was constructed
+        # with `return_attention_scores=True`.
+        if self._return_attention_scores:
+            if cache is not None:
+                return attention_output, attention_scores, cache
+            return attention_output, attention_scores
+        else:
+            if cache is not None:
+                return attention_output, cache
+            return attention_output
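
A minimal usage sketch of the patched layer. The class name and import path come from the diff above; the input shapes, `num_heads`, and `key_dim` values are illustrative assumptions, as are the shape comments:

import numpy as np

from keras_hub.src.layers.modeling.cached_multi_head_attention import (
    CachedMultiHeadAttention,
)

# Construct with the new flag introduced by this patch.
layer = CachedMultiHeadAttention(
    num_heads=2, key_dim=4, return_attention_scores=True
)

# Illustrative self-attention input: (batch, sequence, hidden).
x = np.random.uniform(size=(2, 8, 8)).astype("float32")

# With no cache passed, the patched call returns a 2-tuple
# instead of a single tensor.
attention_output, attention_scores = layer(query=x, value=x)

# Expected under these assumptions:
#   attention_output: (2, 8, 8)
#   attention_scores: (2, 2, 8, 8), i.e. (batch, num_heads, T, S)

Per the new return logic, when a `cache` is passed the patched call instead returns the 3-tuple `(attention_output, attention_scores, cache)`, so cached-decoding call sites must be updated to unpack the extra element whenever the layer was built with `return_attention_scores=True`.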