Commit b51fe69

Custom PA perf improvements (#222)
* enable custom PA with max seqlen 128k
* custom PA support to write out scaled fp8 value
* use regular divide for scaling
* enable custom PA to write out fp8 with scaling factor in llama
* linter fixes
* clang-format fixes
* update abstract attn impl with fp8_out_scale
* add optional fp8_out_scale arg to all attn backend classes
* clang format fix
* add env var to enable cpa fp8 write out
* isort fix
1 parent 89bde53 commit b51fe69
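
The headline change is that the ROCm custom paged-attention (PA) kernel can now write its output directly as a scaled fp8 tensor, using a regular divide for the scaling and gated by an environment variable. The snippet below is a host-side sketch of the numerics only, not the kernel code; the fp8 dtype and the per-tensor scale are assumptions for illustration.

import torch

def fp8_write_out(attn_out: torch.Tensor,
                  fp8_out_scale: torch.Tensor) -> torch.Tensor:
    # Illustrative only: divide by a per-tensor scale and saturate-cast
    # to fp8, mirroring in spirit what the kernel's new write-out path does.
    fp8_dtype = torch.float8_e4m3fnuz  # assumption: ROCm-style fp8 variant
    scaled = attn_out / fp8_out_scale  # "use regular divide for scaling"
    info = torch.finfo(fp8_dtype)
    return scaled.clamp(info.min, info.max).to(fp8_dtype)

out = fp8_write_out(torch.randn(4, 8), torch.tensor(0.5))
print(out.dtype)  # torch.float8_e4m3fnuz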

File tree

16 files changed, +257 -101 lines changed


csrc/rocm/attention.cu

Lines changed: 207 additions & 92 deletions (large diff not rendered)

csrc/rocm/ops.h

Lines changed: 2 additions & 1 deletion
@@ -20,4 +20,5 @@ void paged_attention(torch::Tensor& out, torch::Tensor& exp_sums,
                      int64_t max_context_len,
                      const c10::optional<torch::Tensor>& alibi_slopes,
                      const std::string& kv_cache_dtype, double k_scale,
-                     double v_scale);
+                     double v_scale,
+                     const c10::optional<torch::Tensor>& fp8_out_scale);

csrc/rocm/torch_bindings.cpp

Lines changed: 2 additions & 1 deletion
@@ -35,7 +35,8 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, rocm_ops) {
       " int max_context_len,"
       " Tensor? alibi_slopes,"
       " str kv_cache_dtype,"
-      " float k_scale, float v_scale) -> ()");
+      " float k_scale, float v_scale,"
+      " Tensor? fp8_out_scale) -> ()");
   rocm_ops.impl("paged_attention", torch::kCUDA, &paged_attention);
   rocm_ops.def(
       "wvSpltK(Tensor in_a, Tensor in_b, Tensor! out_c, int N_in,"

vllm/_custom_ops.py

Lines changed: 3 additions & 1 deletion
@@ -159,12 +159,14 @@ def paged_attention_rocm(
     kv_cache_dtype: str,
     k_scale: float,
     v_scale: float,
+    fp8_out_scale: Optional[torch.Tensor],
 ) -> None:
     torch.ops._rocm_C.paged_attention(out, exp_sum, max_logits, tmp_out, query,
                                       key_cache, value_cache, num_kv_heads,
                                       scale, block_tables, seq_lens,
                                       block_size, max_seq_len, alibi_slopes,
-                                      kv_cache_dtype, k_scale, v_scale)
+                                      kv_cache_dtype, k_scale, v_scale,
+                                      fp8_out_scale)


 # pos encoding ops

vllm/attention/backends/abstract.py

Lines changed: 1 addition & 0 deletions
@@ -228,5 +228,6 @@ def forward(
         k_scale: float = 1.0,
         v_scale: float = 1.0,
         attn_type: AttentionType = AttentionType.DECODER,
+        fp8_out_scale: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
         raise NotImplementedError
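
Because the layer above passes fp8_out_scale through unconditionally, every backend has to accept the keyword even when it has no fp8 write-out path (the non-ROCm backends below simply ignore it). A minimal, self-contained sketch of that pattern; the class names and the placeholder computation are invented for illustration:

from typing import Optional
import torch

class AttentionImplSketch:
    # Toy stand-in for the abstract base: declare the widened signature,
    # leave the behaviour to subclasses.
    def forward(self,
                query: torch.Tensor,
                k_scale: float = 1.0,
                v_scale: float = 1.0,
                fp8_out_scale: Optional[torch.Tensor] = None) -> torch.Tensor:
        raise NotImplementedError

class NoFp8Backend(AttentionImplSketch):
    # A backend without an fp8 output path accepts and ignores the keyword.
    def forward(self,
                query: torch.Tensor,
                k_scale: float = 1.0,
                v_scale: float = 1.0,
                fp8_out_scale: Optional[torch.Tensor] = None) -> torch.Tensor:
        return query * k_scale * v_scale  # placeholder for real attention

out = NoFp8Backend().forward(torch.ones(2, 3), fp8_out_scale=torch.tensor(0.5))
print(out.shape)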

vllm/attention/backends/blocksparse_attn.py

Lines changed: 1 addition & 0 deletions
@@ -349,6 +349,7 @@ def forward(
         k_scale: float = 1.0,
         v_scale: float = 1.0,
         attn_type: AttentionType = AttentionType.DECODER,
+        fp8_out_scale: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
         """Forward pass with FlashAttention and PagedAttention.
vllm/attention/backends/flash_attn.py

Lines changed: 1 addition & 0 deletions
@@ -657,6 +657,7 @@ def forward(
         k_scale: float = 1.0,
         v_scale: float = 1.0,
         attn_type: AttentionType = AttentionType.DECODER,
+        fp8_out_scale: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
         """Forward pass with FlashAttention.
vllm/attention/backends/flashinfer.py

Lines changed: 1 addition & 0 deletions
@@ -751,6 +751,7 @@ def forward(
         k_scale: float = 1.0,
         v_scale: float = 1.0,
         attn_type: AttentionType = AttentionType.DECODER,
+        fp8_out_scale: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
         assert k_scale == 1.0 and v_scale == 1.0, (
             "key/v_scale is not supported in FlashInfer.")

vllm/attention/backends/ipex_attn.py

Lines changed: 1 addition & 0 deletions
@@ -172,6 +172,7 @@ def forward(
         k_scale: float = 1.0,
         v_scale: float = 1.0,
         attn_type: AttentionType = AttentionType.DECODER,
+        fp8_out_scale: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
         """Forward pass with IPEX varlen_attention and PagedAttention.
vllm/attention/backends/pallas.py

Lines changed: 1 addition & 0 deletions
@@ -148,6 +148,7 @@ def forward(
         k_scale: float = 1.0,
         v_scale: float = 1.0,
         attn_type: AttentionType = AttentionType.DECODER,
+        fp8_out_scale: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
         """Forward pass with Pallas attention.