Commit 69d5e1d
[BUGFIX] Restored handling of ROCM FA output as before adaptation of llama3.2 (#241)

* Improved handling of the output tensor so that it matches the behavior before the llama3.2 adaptation

* Correction after merge

---------

Co-authored-by: Aleksandr Malyshev <[email protected]>
maleksan85 and Aleksandr Malyshev authored Oct 23, 2024
1 parent 16cedce commit 69d5e1d
Showing 1 changed file with 2 additions and 4 deletions.
vllm/attention/backends/rocm_flash_attn.py (6 changes: 2 additions & 4 deletions)
@@ -610,18 +610,17 @@ def forward(
             assert attn_metadata.num_encoder_tokens is not None
             num_prefill_tokens = attn_metadata.num_encoder_tokens

+        output = torch.empty_like(query)
         # Query for decode. KV is not needed because it is already cached.
         decode_query = query[num_prefill_tokens:]

         # QKV for prefill.
         query = query[:num_prefill_tokens]

         if key is not None and value is not None:
             key = key[:num_prefill_tokens]
             value = value[:num_prefill_tokens]

         if prefill_meta := attn_metadata.prefill_metadata:
-            output = torch.empty_like(query)

             # Prompt run.
             # normal attention and DECODER
             if attn_type == AttentionType.DECODER and (
@@ -738,7 +737,6 @@ def forward(
         if decode_meta := attn_metadata.decode_metadata:
             # Decoding run.
             # Whether to use rocm custom paged attention or not
-            output = torch.empty_like(decode_query)
             num_seqs, num_heads, head_size = decode_query.shape
             block_size = value_cache.shape[3]
             gqa_ratio = num_heads // self.num_kv_heads
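For context, below is a minimal, self-contained sketch of the output-allocation pattern this diff restores: one output buffer is allocated before the batch is split, and the prefill and decode paths each write into their slice of it, rather than each path allocating its own tensor. The names (forward_restored, fake_attention) are illustrative only; this is not the actual vLLM code.

import torch


def fake_attention(q: torch.Tensor) -> torch.Tensor:
    # Stand-in for the real prefill/decode attention kernels.
    return q * 2.0


def forward_restored(query: torch.Tensor, num_prefill_tokens: int) -> torch.Tensor:
    # Restored pattern: allocate the full-size output once, up front, so that
    # both the prefill and decode code paths can write into slices of the
    # same tensor instead of allocating per-branch outputs.
    output = torch.empty_like(query)

    prefill_query = query[:num_prefill_tokens]
    decode_query = query[num_prefill_tokens:]

    if prefill_query.numel() > 0:
        # Prefill results land in the leading slice of the shared buffer.
        output[:num_prefill_tokens] = fake_attention(prefill_query)
    if decode_query.numel() > 0:
        # Decode results land in the trailing slice of the same buffer.
        output[num_prefill_tokens:] = fake_attention(decode_query)
    return output


if __name__ == "__main__":
    query = torch.randn(8, 4, 16)  # (num_tokens, num_heads, head_size)
    out = forward_restored(query, num_prefill_tokens=5)
    assert out.shape == query.shape
    torch.testing.assert_close(out, query * 2.0)

In the diff itself, the upfront allocation plays the same role: the second hunk drops the per-branch allocation that would otherwise rebind output inside the decode branch.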
