[Kernel] Pipe attn_logits_soft_cap through paged attention TPU kernels #12482

Merged · 1 commit · Jan 28, 2025
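For background: this change pipes attn_logits_soft_cap from the attention backend into the Pallas paged attention kernels, which previously rejected it with a NotImplementedError. The capping itself happens inside the TPU kernel. As an illustrative reference only (not code from this PR or from the kernel), attention logits soft-capping is commonly defined as cap * tanh(logits / cap) applied to the raw attention scores before the softmax; a minimal sketch of that math:

from typing import Optional

import torch


def soft_cap_logits(logits: torch.Tensor,
                    soft_cap: Optional[float]) -> torch.Tensor:
    """Reference-only sketch of attention logits soft-capping.

    Squashes raw attention scores into (-soft_cap, soft_cap) with a tanh,
    which is the usual definition of soft-capping. In this PR the actual
    capping is done inside the TPU paged attention kernel; this function
    only illustrates the formula.
    """
    if soft_cap is None:
        return logits  # capping disabled
    return soft_cap * torch.tanh(logits / soft_cap)


# Scores far outside the cap are squashed toward +/-30.0; small scores are
# left nearly unchanged.
scores = torch.tensor([-200.0, -5.0, 0.0, 5.0, 200.0])
print(soft_cap_logits(scores, soft_cap=30.0))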
.buildkite/run-tpu-test.sh: file mode changed 100644 → 100755 (no content changes)
vllm/attention/backends/pallas.py: 42 changes (16 additions, 26 deletions)
@@ -110,6 +110,7 @@ def __init__(
 
         assert self.num_heads % self.num_kv_heads == 0
         self.num_queries_per_kv = self.num_heads // self.num_kv_heads
+        self.logits_soft_cap = logits_soft_cap
         if head_size % 128 != 0:
             raise NotImplementedError("Head size must be a multiple of 128.")
         if alibi_slopes is not None:
@@ -120,9 +121,6 @@
             raise NotImplementedError("FP8 KV cache dtype is not supported.")
         if blocksparse_params is not None:
             raise NotImplementedError("Blocksparse is not supported.")
-        if logits_soft_cap is not None:
-            raise NotImplementedError(
-                "Attention logits soft-capping is not supported.")
 
         if torch_xla.tpu.version() < 4:
             raise NotImplementedError("TPU version must be 4 or higher.")
@@ -230,6 +228,7 @@ def forward(
                 num_kv_pages_per_compute_block,
                 num_queries_per_compute_block,
                 use_kernel=True,
+                attn_logits_soft_cap=self.logits_soft_cap,
             )
         else:
             # Decoding run.
@@ -257,6 +256,7 @@ def forward(
                     attn_metadata.block_tables,
                     pages_per_compute_block,
                     self.megacore_mode,
+                    attn_logits_soft_cap=self.logits_soft_cap,
                 )
             else:
                 chunk_size = max_num_seq
@@ -280,6 +280,7 @@ def forward(
                        attn_metadata.block_tables[chunk_start:chunk_end],
                        pages_per_compute_block,
                        self.megacore_mode,
+                        attn_logits_soft_cap=self.logits_soft_cap,
                    )
                    output[chunk_start:chunk_end] = chunk_output
 
@@ -313,33 +314,22 @@ def paged_attention(
     block_tables: torch.Tensor,
     pages_per_compute_block: int,
     megacore_mode: Optional[str],
+    *,
+    attn_logits_soft_cap: Optional[float],
 ) -> torch.Tensor:
     batch_size = query.shape[0]
     if megacore_mode == "batch" and batch_size % 2 != 0:
         megacore_mode = None
     else:
         megacore_mode = megacore_mode
 
-    # NOTE(woosuk): A temporary workaround to avoid the error:
-    # "xla::paged_attention() Expected a value of type 'str' for
-    # argument 'megacore_mode' but instead found type 'NoneType'."
-    if megacore_mode is not None:
-        output = torch.ops.xla.paged_attention(
-            query,
-            key_cache,
-            value_cache,
-            context_lens,
-            block_tables,
-            pages_per_compute_block,
-            megacore_mode=megacore_mode,
-        )
-    else:
-        output = torch.ops.xla.paged_attention(
-            query,
-            key_cache,
-            value_cache,
-            context_lens,
-            block_tables,
-            pages_per_compute_block,
-        )
-    return output
+    return torch.ops.xla.paged_attention(
+        query,
+        key_cache,
+        value_cache,
+        context_lens,
+        block_tables,
+        pages_per_compute_block,
+        megacore_mode=megacore_mode,
+        attn_logits_soft_cap=attn_logits_soft_cap,
+    )

Review thread (on the new megacore_mode=megacore_mode line):

Collaborator: megacore_mode is checked internally for None?

Author: Yep, we fixed it on the ptxla side. No workaround needed here.

Collaborator: cool
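A note on the new wrapper signature: the bare * makes attn_logits_soft_cap keyword-only, so every call site has to name it explicitly, and the old duplicated call branches for megacore_mode are gone because, per the review thread, the op now accepts None on the ptxla side. The toy sketch below uses a hypothetical stand-in for torch.ops.xla.paged_attention purely to illustrate that calling convention; it is not the real kernel op.

from typing import Optional


def fake_paged_attention_op(*args,
                            megacore_mode: Optional[str] = None,
                            attn_logits_soft_cap: Optional[float] = None):
    # Hypothetical stand-in for torch.ops.xla.paged_attention. The real op
    # now tolerates megacore_mode=None, which is why the wrapper below no
    # longer needs two separate call branches.
    return {"megacore_mode": megacore_mode,
            "attn_logits_soft_cap": attn_logits_soft_cap}


def paged_attention_sketch(
    batch_size: int,
    pages_per_compute_block: int,
    megacore_mode: Optional[str],
    *,  # everything after this marker must be passed by keyword
    attn_logits_soft_cap: Optional[float],
):
    # Same batch-parity guard as the wrapper in the diff: "batch" megacore
    # mode only applies when the batch size is even.
    if megacore_mode == "batch" and batch_size % 2 != 0:
        megacore_mode = None
    return fake_paged_attention_op(
        pages_per_compute_block,
        megacore_mode=megacore_mode,
        attn_logits_soft_cap=attn_logits_soft_cap,
    )


print(paged_attention_sketch(3, 16, "batch", attn_logits_soft_cap=30.0))
# paged_attention_sketch(3, 16, "batch", 30.0)  # TypeError: keyword-only arg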