[Kernel] Pipe attn_logits_soft_cap through paged attention TPU kernels (#12482)

Signed-off-by: Fenghui Zhang <[email protected]>
fenghuizhang authored Jan 28, 2025
1 parent c386c43 commit 80fcc3e
Showing 2 changed files with 16 additions and 26 deletions.
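Background: attention logits soft-capping bounds the pre-softmax scores with a tanh squash, as used by Gemma-2-style models. A minimal sketch of the transform the new attn_logits_soft_cap argument enables, assuming the standard cap * tanh(logits / cap) formulation (the Pallas kernel's internals are not shown in this diff):

import torch

def soft_cap(logits: torch.Tensor, cap: float) -> torch.Tensor:
    # Squash raw attention scores into (-cap, cap); unlike hard clipping,
    # tanh keeps the transform smooth so gradients stay informative.
    return cap * torch.tanh(logits / cap)

# Example: cap scaled QK^T scores before softmax. The value 50.0 matches
# Gemma-2's attention cap; here it is only an illustrative choice.
q, k = torch.randn(8, 64), torch.randn(16, 64)
capped_scores = soft_cap(q @ k.T * 64 ** -0.5, 50.0)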
.buildkite/run-tpu-test.sh: empty file, mode changed 100644 → 100755.
vllm/attention/backends/pallas.py: 16 additions, 26 deletions.
@@ -110,6 +110,7 @@ def __init__(
 
         assert self.num_heads % self.num_kv_heads == 0
         self.num_queries_per_kv = self.num_heads // self.num_kv_heads
+        self.logits_soft_cap = logits_soft_cap
         if head_size % 128 != 0:
             raise NotImplementedError("Head size must be a multiple of 128.")
         if alibi_slopes is not None:
@@ -120,9 +121,6 @@ def __init__(
             raise NotImplementedError("FP8 KV cache dtype is not supported.")
         if blocksparse_params is not None:
             raise NotImplementedError("Blocksparse is not supported.")
-        if logits_soft_cap is not None:
-            raise NotImplementedError(
-                "Attention logits soft-capping is not supported.")
 
         if torch_xla.tpu.version() < 4:
             raise NotImplementedError("TPU version must be 4 or higher.")
@@ -230,6 +228,7 @@ def forward(
                 num_kv_pages_per_compute_block,
                 num_queries_per_compute_block,
                 use_kernel=True,
+                attn_logits_soft_cap=self.logits_soft_cap,
             )
         else:
             # Decoding run.
@@ -257,6 +256,7 @@ def forward(
                 attn_metadata.block_tables,
                 pages_per_compute_block,
                 self.megacore_mode,
+                attn_logits_soft_cap=self.logits_soft_cap,
             )
         else:
             chunk_size = max_num_seq
@@ -280,6 +280,7 @@ def forward(
                     attn_metadata.block_tables[chunk_start:chunk_end],
                     pages_per_compute_block,
                     self.megacore_mode,
+                    attn_logits_soft_cap=self.logits_soft_cap,
                 )
                 output[chunk_start:chunk_end] = chunk_output
 
@@ -313,33 +314,22 @@ def paged_attention(
     block_tables: torch.Tensor,
     pages_per_compute_block: int,
     megacore_mode: Optional[str],
+    *,
+    attn_logits_soft_cap: Optional[float],
 ) -> torch.Tensor:
     batch_size = query.shape[0]
     if megacore_mode == "batch" and batch_size % 2 != 0:
         megacore_mode = None
     else:
         megacore_mode = megacore_mode
 
-    # NOTE(woosuk): A temporary workaround to avoid the error:
-    # "xla::paged_attention() Expected a value of type 'str' for
-    # argument 'megacore_mode' but instead found type 'NoneType'."
-    if megacore_mode is not None:
-        output = torch.ops.xla.paged_attention(
-            query,
-            key_cache,
-            value_cache,
-            context_lens,
-            block_tables,
-            pages_per_compute_block,
-            megacore_mode=megacore_mode,
-        )
-    else:
-        output = torch.ops.xla.paged_attention(
-            query,
-            key_cache,
-            value_cache,
-            context_lens,
-            block_tables,
-            pages_per_compute_block,
-        )
-    return output
+    return torch.ops.xla.paged_attention(
+        query,
+        key_cache,
+        value_cache,
+        context_lens,
+        block_tables,
+        pages_per_compute_block,
+        megacore_mode=megacore_mode,
+        attn_logits_soft_cap=attn_logits_soft_cap,
+    )
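To make the data flow concrete, below is a naive single-token reference of what the kernel invocation above computes. This is a sketch under an assumed cache layout (and ignoring grouped-query attention), not the actual Pallas implementation, which tiles over pages on-chip:

from typing import Optional

import torch

def reference_paged_attention(
    query: torch.Tensor,        # [num_heads, head_size], one decode token
    key_cache: torch.Tensor,    # [num_blocks, block_size, num_heads, head_size], assumed layout
    value_cache: torch.Tensor,  # same assumed layout as key_cache
    context_len: int,
    block_table: torch.Tensor,  # [max_blocks_per_seq], physical block ids
    *,
    attn_logits_soft_cap: Optional[float],
) -> torch.Tensor:
    # Gather this sequence's keys/values from the paged cache, then trim
    # the padding tokens beyond the true context length.
    keys = key_cache[block_table].flatten(0, 1)[:context_len]
    values = value_cache[block_table].flatten(0, 1)[:context_len]
    scale = query.shape[-1] ** -0.5
    logits = torch.einsum("hd,thd->ht", query * scale, keys)
    if attn_logits_soft_cap is not None:
        # The piped-through cap: squash scores before the softmax. None
        # skips capping, which is why the hard NotImplementedError guard
        # in __init__ could be dropped.
        logits = attn_logits_soft_cap * torch.tanh(logits / attn_logits_soft_cap)
    probs = torch.softmax(logits, dim=-1)
    return torch.einsum("ht,thd->hd", probs, values)

Note also the bare * in the new paged_attention signature: it makes attn_logits_soft_cap keyword-only, so all three forward() call sites must pass it by name, as they do above.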

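As a closing usage note, and an assumption about the surrounding plumbing rather than anything in this diff: Gemma-2-style HF configs expose the cap as attn_logit_softcapping, and its absence maps naturally onto logits_soft_cap=None (no capping). A hypothetical helper:

from typing import Optional

def resolve_soft_cap(hf_config) -> Optional[float]:
    # Hypothetical: Gemma-2 configs carry attn_logit_softcapping
    # (e.g. 50.0 for attention); a missing field means "do not cap".
    return getattr(hf_config, "attn_logit_softcapping", None)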