From 80fcc3ed1c940ea43e1b495bbdf8b9765f837128 Mon Sep 17 00:00:00 2001
From: fenghuizhang <159459388+fenghuizhang@users.noreply.github.com>
Date: Tue, 28 Jan 2025 14:36:44 -0800
Subject: [PATCH] [Kernel] Pipe attn_logits_soft_cap through paged attention
 TPU kernels (#12482)

Signed-off-by: Fenghui Zhang <fhzhang@google.com>
---
 .buildkite/run-tpu-test.sh        |  0
 vllm/attention/backends/pallas.py | 42 ++++++++++++-------------------
 2 files changed, 16 insertions(+), 26 deletions(-)
 mode change 100644 => 100755 .buildkite/run-tpu-test.sh

diff --git a/.buildkite/run-tpu-test.sh b/.buildkite/run-tpu-test.sh
old mode 100644
new mode 100755
diff --git a/vllm/attention/backends/pallas.py b/vllm/attention/backends/pallas.py
index facdee6b29e39..209a623ba441c 100644
--- a/vllm/attention/backends/pallas.py
+++ b/vllm/attention/backends/pallas.py
@@ -110,6 +110,7 @@ def __init__(
 
         assert self.num_heads % self.num_kv_heads == 0
         self.num_queries_per_kv = self.num_heads // self.num_kv_heads
+        self.logits_soft_cap = logits_soft_cap
         if head_size % 128 != 0:
             raise NotImplementedError("Head size must be a multiple of 128.")
         if alibi_slopes is not None:
@@ -120,9 +121,6 @@ def __init__(
             raise NotImplementedError("FP8 KV cache dtype is not supported.")
         if blocksparse_params is not None:
             raise NotImplementedError("Blocksparse is not supported.")
-        if logits_soft_cap is not None:
-            raise NotImplementedError(
-                "Attention logits soft-capping is not supported.")
 
         if torch_xla.tpu.version() < 4:
             raise NotImplementedError("TPU version must be 4 or higher.")
@@ -230,6 +228,7 @@ def forward(
                     num_kv_pages_per_compute_block,
                     num_queries_per_compute_block,
                     use_kernel=True,
+                    attn_logits_soft_cap=self.logits_soft_cap,
                 )
         else:
             # Decoding run.
@@ -257,6 +256,7 @@ def forward(
                     attn_metadata.block_tables,
                     pages_per_compute_block,
                     self.megacore_mode,
+                    attn_logits_soft_cap=self.logits_soft_cap,
                 )
             else:
                 chunk_size = max_num_seq
@@ -280,6 +280,7 @@ def forward(
                         attn_metadata.block_tables[chunk_start:chunk_end],
                         pages_per_compute_block,
                         self.megacore_mode,
+                        attn_logits_soft_cap=self.logits_soft_cap,
                     )
                     output[chunk_start:chunk_end] = chunk_output
 
@@ -313,6 +314,8 @@ def paged_attention(
     block_tables: torch.Tensor,
     pages_per_compute_block: int,
     megacore_mode: Optional[str],
+    *,
+    attn_logits_soft_cap: Optional[float],
 ) -> torch.Tensor:
     batch_size = query.shape[0]
     if megacore_mode == "batch" and batch_size % 2 != 0:
@@ -320,26 +323,13 @@ def paged_attention(
     else:
         megacore_mode = megacore_mode
 
-    # NOTE(woosuk): A temporary workaround to avoid the error:
-    # "xla::paged_attention() Expected a value of type 'str' for
-    # argument 'megacore_mode' but instead found type 'NoneType'."
-    if megacore_mode is not None:
-        output = torch.ops.xla.paged_attention(
-            query,
-            key_cache,
-            value_cache,
-            context_lens,
-            block_tables,
-            pages_per_compute_block,
-            megacore_mode=megacore_mode,
-        )
-    else:
-        output = torch.ops.xla.paged_attention(
-            query,
-            key_cache,
-            value_cache,
-            context_lens,
-            block_tables,
-            pages_per_compute_block,
-        )
-    return output
+    return torch.ops.xla.paged_attention(
+        query,
+        key_cache,
+        value_cache,
+        context_lens,
+        block_tables,
+        pages_per_compute_block,
+        megacore_mode=megacore_mode,
+        attn_logits_soft_cap=attn_logits_soft_cap,
+    )