Commit

Merge branch 'main' into ci_testing_20241015
Alexei-V-Ivanov-AMD authored Oct 22, 2024
2 parents 0aac5d9 + e0b6bb4 commit fc7469e
Showing 1 changed file with 18 additions and 4 deletions.
vllm/attention/backends/rocm_flash_attn.py: 22 changes (18 additions & 4 deletions)
@@ -621,11 +621,25 @@ def forward(
 
         if prefill_meta := attn_metadata.prefill_metadata:
             output = torch.empty_like(query)
-            (query_seq_start_loc, query_max_seq_len, key_seq_start_loc,
-             key_max_seq_len, seq_lens,
-             causal_mask) = _get_seq_len_block_table_args(
-                 prefill_meta, attn_type)
 
             # Prompt run.
+            # normal attention and DECODER
+            if attn_type == AttentionType.DECODER and (
+                    kv_cache.numel() == 0 or prefill_meta.block_tables is None
+                    or prefill_meta.block_tables.numel() == 0):
+                (query_seq_start_loc, query_max_seq_len, key_seq_start_loc,
+                 key_max_seq_len, seq_lens,
+                 causal_mask) = (prefill_meta.seq_start_loc,
+                                 prefill_meta.max_prefill_seq_len,
+                                 prefill_meta.seq_start_loc,
+                                 prefill_meta.max_prefill_seq_len,
+                                 attn_metadata.seq_lens, True)
+            # prefix-enabled attention and ENCODER/ENCODER_DECODER
+            else:
+                (query_seq_start_loc, query_max_seq_len, key_seq_start_loc,
+                 key_max_seq_len, seq_lens,
+                 causal_mask) = _get_seq_len_block_table_args(
+                     prefill_meta, attn_type)
+            # Prompt run.
             if kv_cache.numel() == 0 or prefill_meta.block_tables.numel() == 0:
                 # triton attention
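To make the effect of the hunk easier to follow, here is a minimal, self-contained sketch of the dispatch it introduces. This is not vLLM's code: PrefillMeta and pick_seq_len_args are hypothetical stand-ins, and the ENCODER_DECODER branch only assumes what _get_seq_len_block_table_args would return, since that helper's body lies outside this hunk; only the branching condition and the fast-path tuple mirror the diff.

# Illustrative sketch only: PrefillMeta and pick_seq_len_args() are
# hypothetical stand-ins, not vLLM classes; just the branching mirrors the diff.
from dataclasses import dataclass
from enum import Enum, auto
from typing import List, Optional

import torch


class AttentionType(Enum):
    DECODER = auto()
    ENCODER = auto()
    ENCODER_DECODER = auto()


@dataclass
class PrefillMeta:
    seq_start_loc: torch.Tensor            # cumulative token offsets, shape [num_seqs + 1]
    max_prefill_seq_len: int
    block_tables: Optional[torch.Tensor]   # None or empty when no prefix blocks are cached
    encoder_seq_start_loc: Optional[torch.Tensor] = None
    max_encoder_seq_len: Optional[int] = None
    encoder_seq_lens: Optional[List[int]] = None


def pick_seq_len_args(prefill_meta: PrefillMeta,
                      seq_lens: List[int],
                      attn_type: AttentionType,
                      kv_cache_numel: int) -> tuple:
    """Return (q_start_loc, q_max_len, k_start_loc, k_max_len, seq_lens, causal)."""
    no_cached_prefix = (kv_cache_numel == 0
                        or prefill_meta.block_tables is None
                        or prefill_meta.block_tables.numel() == 0)
    if attn_type == AttentionType.DECODER and no_cached_prefix:
        # Fast path added by the diff: a plain decoder prefill attends over its
        # own tokens, so query and key share the same start locations and the
        # mask is causal.
        return (prefill_meta.seq_start_loc,
                prefill_meta.max_prefill_seq_len,
                prefill_meta.seq_start_loc,
                prefill_meta.max_prefill_seq_len,
                seq_lens, True)
    if attn_type == AttentionType.ENCODER_DECODER:
        # Assumption about the helper's behaviour for cross attention: keys come
        # from the encoder side and the mask is not causal.
        return (prefill_meta.seq_start_loc,
                prefill_meta.max_prefill_seq_len,
                prefill_meta.encoder_seq_start_loc,
                prefill_meta.max_encoder_seq_len,
                prefill_meta.encoder_seq_lens, False)
    # Prefix-enabled and pure-encoder cases are left to vLLM's
    # _get_seq_len_block_table_args(), whose body is outside this hunk.
    raise NotImplementedError(attn_type)


# Example: two decoder prompts of lengths 3 and 5, no cached prefix.
meta = PrefillMeta(seq_start_loc=torch.tensor([0, 3, 8], dtype=torch.int32),
                   max_prefill_seq_len=5,
                   block_tables=torch.empty(0, dtype=torch.int32))
print(pick_seq_len_args(meta, [3, 5], AttentionType.DECODER, kv_cache_numel=0))

The point of the new fast path is that a plain decoder prefill with no cached prefix blocks needs no block-table handling at all: queries and keys share the same start locations and maximum length, so the backend can pass prefill_meta.seq_start_loc for both sides together with a causal mask to the Triton attention call that follows in the hunk.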
