diff --git a/.buildkite/run-tpu-test.sh b/.buildkite/run-tpu-test.sh
old mode 100644
new mode 100755
diff --git a/vllm/attention/backends/pallas.py b/vllm/attention/backends/pallas.py
index facdee6b29e39..209a623ba441c 100644
--- a/vllm/attention/backends/pallas.py
+++ b/vllm/attention/backends/pallas.py
@@ -110,6 +110,7 @@ def __init__(
 
         assert self.num_heads % self.num_kv_heads == 0
         self.num_queries_per_kv = self.num_heads // self.num_kv_heads
+        self.logits_soft_cap = logits_soft_cap
         if head_size % 128 != 0:
             raise NotImplementedError("Head size must be a multiple of 128.")
         if alibi_slopes is not None:
@@ -120,9 +121,6 @@ def __init__(
             raise NotImplementedError("FP8 KV cache dtype is not supported.")
         if blocksparse_params is not None:
             raise NotImplementedError("Blocksparse is not supported.")
-        if logits_soft_cap is not None:
-            raise NotImplementedError(
-                "Attention logits soft-capping is not supported.")
         if torch_xla.tpu.version() < 4:
             raise NotImplementedError("TPU version must be 4 or higher.")
 
@@ -230,6 +228,7 @@ def forward(
                     num_kv_pages_per_compute_block,
                     num_queries_per_compute_block,
                     use_kernel=True,
+                    attn_logits_soft_cap=self.logits_soft_cap,
                 )
         else:
             # Decoding run.
@@ -257,6 +256,7 @@ def forward(
                     attn_metadata.block_tables,
                     pages_per_compute_block,
                     self.megacore_mode,
+                    attn_logits_soft_cap=self.logits_soft_cap,
                 )
             else:
                 chunk_size = max_num_seq
@@ -280,6 +280,7 @@ def forward(
                         attn_metadata.block_tables[chunk_start:chunk_end],
                         pages_per_compute_block,
                         self.megacore_mode,
+                        attn_logits_soft_cap=self.logits_soft_cap,
                     )
                     output[chunk_start:chunk_end] = chunk_output
 
@@ -313,6 +314,8 @@ def paged_attention(
     block_tables: torch.Tensor,
     pages_per_compute_block: int,
     megacore_mode: Optional[str],
+    *,
+    attn_logits_soft_cap: Optional[float],
 ) -> torch.Tensor:
     batch_size = query.shape[0]
     if megacore_mode == "batch" and batch_size % 2 != 0:
@@ -320,26 +323,13 @@
     else:
         megacore_mode = megacore_mode
 
-    # NOTE(woosuk): A temporary workaround to avoid the error:
-    # "xla::paged_attention() Expected a value of type 'str' for
-    # argument 'megacore_mode' but instead found type 'NoneType'."
-    if megacore_mode is not None:
-        output = torch.ops.xla.paged_attention(
-            query,
-            key_cache,
-            value_cache,
-            context_lens,
-            block_tables,
-            pages_per_compute_block,
-            megacore_mode=megacore_mode,
-        )
-    else:
-        output = torch.ops.xla.paged_attention(
-            query,
-            key_cache,
-            value_cache,
-            context_lens,
-            block_tables,
-            pages_per_compute_block,
-        )
-    return output
+    return torch.ops.xla.paged_attention(
+        query,
+        key_cache,
+        value_cache,
+        context_lens,
+        block_tables,
+        pages_per_compute_block,
+        megacore_mode=megacore_mode,
+        attn_logits_soft_cap=attn_logits_soft_cap,
+    )
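
Note on semantics: `attn_logits_soft_cap` is threaded through to the Pallas kernels, which bound the pre-softmax attention scores instead of letting them grow unbounded (the scheme used by Gemma-2-style models). A minimal sketch of that transformation, assuming the conventional `cap * tanh(x / cap)` formulation; the helper name below is illustrative and not part of this diff:

```python
import torch


def soft_cap_logits(logits: torch.Tensor, soft_cap: float) -> torch.Tensor:
    # Squash raw attention scores into (-soft_cap, soft_cap) while staying
    # approximately linear near zero. Passing attn_logits_soft_cap=None to
    # the kernel corresponds to skipping this step entirely.
    return soft_cap * torch.tanh(logits / soft_cap)


scores = torch.randn(4, 16)             # toy pre-softmax logits
capped = soft_cap_logits(scores, 50.0)  # Gemma 2 caps attention logits at 50.0
```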