From 69d5e1d87b48521759a1f38f657b6aced835e462 Mon Sep 17 00:00:00 2001
From: Aleksandr Malyshev <164964928+maleksan85@users.noreply.github.com>
Date: Tue, 22 Oct 2024 17:11:36 -0700
Subject: [PATCH] [BUGFIX] Restored handling of ROCM FA output as before
 adaptation of llama3.2 (#241)

* improved handling of output to be the same as before

* after merge correction

---------

Co-authored-by: Aleksandr Malyshev
---
 vllm/attention/backends/rocm_flash_attn.py | 6 ++----
 1 file changed, 2 insertions(+), 4 deletions(-)

diff --git a/vllm/attention/backends/rocm_flash_attn.py b/vllm/attention/backends/rocm_flash_attn.py
index 9407194cf3fc5..677e3f6bb029c 100644
--- a/vllm/attention/backends/rocm_flash_attn.py
+++ b/vllm/attention/backends/rocm_flash_attn.py
@@ -610,18 +610,17 @@ def forward(
             assert attn_metadata.num_encoder_tokens is not None
             num_prefill_tokens = attn_metadata.num_encoder_tokens
 
+        output = torch.empty_like(query)
         # Query for decode. KV is not needed because it is already cached.
         decode_query = query[num_prefill_tokens:]
-
         # QKV for prefill.
         query = query[:num_prefill_tokens]
+
         if key is not None and value is not None:
             key = key[:num_prefill_tokens]
             value = value[:num_prefill_tokens]
 
         if prefill_meta := attn_metadata.prefill_metadata:
-            output = torch.empty_like(query)
-
             # Prompt run.
             # normal attention and DECODER
             if attn_type == AttentionType.DECODER and (
@@ -738,7 +737,6 @@ def forward(
         if decode_meta := attn_metadata.decode_metadata:
             # Decoding run.
             # Whether to use rocm custom paged attention or not
-            output = torch.empty_like(decode_query)
             num_seqs, num_heads, head_size = decode_query.shape
             block_size = value_cache.shape[3]
             gqa_ratio = num_heads // self.num_kv_heads
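
Note: the sketch below is not the actual vLLM forward(); it is a minimal illustration of the shape of the change in this patch, assuming the names `query`, `output`, and `num_prefill_tokens` as they appear in the diff. The point is that `output` is allocated once from the full query tensor before the prefill/decode split, so both phases write into slices of one buffer (the pre-llama3.2 behavior), instead of each branch allocating its own, smaller `output`. The "attention" calls are hypothetical stand-ins, not the ROCm kernels.

    # Minimal sketch of the restored output handling (hypothetical stand-in code).
    import torch

    def forward_sketch(query: torch.Tensor, num_prefill_tokens: int) -> torch.Tensor:
        # Allocate once, sized like the *full* query (prefill + decode tokens).
        output = torch.empty_like(query)

        # Slice views into the same token dimension for each phase.
        prefill_query = query[:num_prefill_tokens]
        decode_query = query[num_prefill_tokens:]

        # Each phase writes its results into its slice of the shared buffer.
        # These assignments stand in for the prefill attention and the
        # paged-attention decode kernel calls, respectively.
        output[:num_prefill_tokens] = prefill_query
        output[num_prefill_tokens:] = decode_query
        return output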