diff --git a/vllm/attention/backends/rocm_flash_attn.py b/vllm/attention/backends/rocm_flash_attn.py
index 9407194cf3fc5..677e3f6bb029c 100644
--- a/vllm/attention/backends/rocm_flash_attn.py
+++ b/vllm/attention/backends/rocm_flash_attn.py
@@ -610,18 +610,17 @@ def forward(
             assert attn_metadata.num_encoder_tokens is not None
             num_prefill_tokens = attn_metadata.num_encoder_tokens
 
+        output = torch.empty_like(query)
         # Query for decode. KV is not needed because it is already cached.
         decode_query = query[num_prefill_tokens:]
-
         # QKV for prefill.
         query = query[:num_prefill_tokens]
+
         if key is not None and value is not None:
             key = key[:num_prefill_tokens]
             value = value[:num_prefill_tokens]
 
         if prefill_meta := attn_metadata.prefill_metadata:
-            output = torch.empty_like(query)
-
             # Prompt run.
             # normal attention and DECODER
             if attn_type == AttentionType.DECODER and (
@@ -738,7 +737,6 @@ def forward(
         if decode_meta := attn_metadata.decode_metadata:
             # Decoding run.
             # Whether to use rocm custom paged attention or not
-            output = torch.empty_like(decode_query)
             num_seqs, num_heads, head_size = decode_query.shape
             block_size = value_cache.shape[3]
             gqa_ratio = num_heads // self.num_kv_heads
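
For context (not part of the patch): the hunks above hoist the output-buffer allocation out of the prefill and decode branches, so a single buffer sized like the full query covers both token ranges instead of each branch allocating its own. A minimal standalone sketch of that pattern, using made-up shapes and an identity op as a stand-in for the actual attention kernels:

```python
import torch

# Hypothetical sizes for illustration only.
num_tokens, num_heads, head_size = 10, 8, 64
num_prefill_tokens = 6
query = torch.randn(num_tokens, num_heads, head_size)

# Single allocation covering both prefill and decode tokens,
# mirroring the hoisted `output = torch.empty_like(query)`.
output = torch.empty_like(query)

# Split the query the same way the backend does.
decode_query = query[num_prefill_tokens:]
prefill_query = query[:num_prefill_tokens]

# Each phase writes into its own slice of the shared buffer
# (identity copy standing in for the real kernels).
output[:num_prefill_tokens] = prefill_query
output[num_prefill_tokens:] = decode_query

assert output.shape == query.shape
```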