Custom PA perf improvements (#222)
* enable custom PA with max seqlen 128k

* custom PA support to write out scaled fp8 value

* use regular divide for scaling

* enable custom PA to write out fp8 with scaling factor in llama

* linter fixes

* clang-format fixes

* update abstract attn impl with fp8_out_scale

* add optional fp8_out_scale arg to all attn backend classes

* clang format fix

* add env var to enable cpa fp8 write out

* isort fix
sanyalington authored Oct 8, 2024
1 parent 89bde53 commit b51fe69
Showing 16 changed files with 257 additions and 101 deletions.
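Taken together, the changes thread an optional fp8 output scale from the model layer through Attention.forward and every backend's forward signature, down to the ROCm custom paged-attention op, which can then write its output pre-scaled in fp8. A minimal, self-contained sketch of that plumbing (the class names, tensor shapes, and the softmax math below are illustrative assumptions, not the vLLM implementation):

from typing import Optional

import torch


class ToyBackendImpl:
    """Stand-in for an attention backend: new trailing keyword, default None."""

    def forward(self,
                query: torch.Tensor,
                key: torch.Tensor,
                value: torch.Tensor,
                fp8_out_scale: Optional[torch.Tensor] = None) -> torch.Tensor:
        scores = torch.softmax(query @ key.transpose(-1, -2), dim=-1)
        out = scores @ value
        if fp8_out_scale is not None:
            # "use regular divide for scaling": divide by the scale and narrow
            # to fp8. The commit targets torch.float8_e4m3fnuz on ROCm;
            # float8_e4m3fn is used here so the sketch runs on any recent PyTorch.
            out = (out / fp8_out_scale).to(torch.float8_e4m3fn)
        return out


class ToyAttentionLayer:
    """Stand-in for the Attention layer: simply forwards the keyword."""

    def __init__(self) -> None:
        self.impl = ToyBackendImpl()

    def forward(self, q, k, v,
                fp8_out_scale: Optional[torch.Tensor] = None) -> torch.Tensor:
        return self.impl.forward(q, k, v, fp8_out_scale=fp8_out_scale)


q = k = v = torch.randn(2, 16, 64)
layer = ToyAttentionLayer()
full_out = layer.forward(q, k, v)                                  # unchanged path
fp8_out = layer.forward(q, k, v, fp8_out_scale=torch.tensor(0.5))  # new fp8 path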
299 changes: 207 additions & 92 deletions csrc/rocm/attention.cu

Large diffs are not rendered by default.

3 changes: 2 additions & 1 deletion csrc/rocm/ops.h
@@ -20,4 +20,5 @@ void paged_attention(torch::Tensor& out, torch::Tensor& exp_sums,
int64_t max_context_len,
const c10::optional<torch::Tensor>& alibi_slopes,
const std::string& kv_cache_dtype, double k_scale,
double v_scale);
double v_scale,
const c10::optional<torch::Tensor>& fp8_out_scale);
3 changes: 2 additions & 1 deletion csrc/rocm/torch_bindings.cpp
@@ -35,7 +35,8 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, rocm_ops) {
" int max_context_len,"
" Tensor? alibi_slopes,"
" str kv_cache_dtype,"
" float k_scale, float v_scale) -> ()");
" float k_scale, float v_scale,"
" Tensor? fp8_out_scale) -> ()");
rocm_ops.impl("paged_attention", torch::kCUDA, &paged_attention);
rocm_ops.def(
"wvSpltK(Tensor in_a, Tensor in_b, Tensor! out_c, int N_in,"
4 changes: 3 additions & 1 deletion vllm/_custom_ops.py
@@ -159,12 +159,14 @@ def paged_attention_rocm(
kv_cache_dtype: str,
k_scale: float,
v_scale: float,
fp8_out_scale: Optional[torch.Tensor],
) -> None:
torch.ops._rocm_C.paged_attention(out, exp_sum, max_logits, tmp_out, query,
key_cache, value_cache, num_kv_heads,
scale, block_tables, seq_lens,
block_size, max_seq_len, alibi_slopes,
kv_cache_dtype, k_scale, v_scale)
kv_cache_dtype, k_scale, v_scale,
fp8_out_scale)


# pos encoding ops
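The Python wrapper simply forwards the extra tensor to the custom op, whose schema now ends in "Tensor? fp8_out_scale". Unlike the backend forward() methods, the wrapper gives the argument no default, so call sites spell out None when they do not want fp8 output. A toy stand-in (not the real torch.ops._rocm_C.paged_attention) illustrating that calling convention:

from typing import Optional

import torch


def paged_attention_rocm_stub(out: torch.Tensor, *tensor_args,
                              fp8_out_scale: Optional[torch.Tensor]) -> None:
    # Placeholder body: the real op writes the attention result into `out`
    # in place, pre-divided by fp8_out_scale when one is supplied.
    del tensor_args, fp8_out_scale
    out.zero_()


buf = torch.empty(4, 8)
paged_attention_rocm_stub(buf, fp8_out_scale=None)                # legacy behaviour
paged_attention_rocm_stub(buf, fp8_out_scale=torch.tensor(0.25))  # request scaled fp8 output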
1 change: 1 addition & 0 deletions vllm/attention/backends/abstract.py
@@ -228,5 +228,6 @@ def forward(
k_scale: float = 1.0,
v_scale: float = 1.0,
attn_type: AttentionType = AttentionType.DECODER,
fp8_out_scale: Optional[torch.Tensor] = None,
) -> torch.Tensor:
raise NotImplementedError
1 change: 1 addition & 0 deletions vllm/attention/backends/blocksparse_attn.py
@@ -349,6 +349,7 @@ def forward(
k_scale: float = 1.0,
v_scale: float = 1.0,
attn_type: AttentionType = AttentionType.DECODER,
fp8_out_scale: Optional[torch.Tensor] = None,
) -> torch.Tensor:
"""Forward pass with FlashAttention and PagedAttention.
1 change: 1 addition & 0 deletions vllm/attention/backends/flash_attn.py
@@ -657,6 +657,7 @@ def forward(
k_scale: float = 1.0,
v_scale: float = 1.0,
attn_type: AttentionType = AttentionType.DECODER,
fp8_out_scale: Optional[torch.Tensor] = None,
) -> torch.Tensor:
"""Forward pass with FlashAttention.
1 change: 1 addition & 0 deletions vllm/attention/backends/flashinfer.py
@@ -751,6 +751,7 @@ def forward(
k_scale: float = 1.0,
v_scale: float = 1.0,
attn_type: AttentionType = AttentionType.DECODER,
fp8_out_scale: Optional[torch.Tensor] = None,
) -> torch.Tensor:
assert k_scale == 1.0 and v_scale == 1.0, (
"key/v_scale is not supported in FlashInfer.")
1 change: 1 addition & 0 deletions vllm/attention/backends/ipex_attn.py
@@ -172,6 +172,7 @@ def forward(
k_scale: float = 1.0,
v_scale: float = 1.0,
attn_type: AttentionType = AttentionType.DECODER,
fp8_out_scale: Optional[torch.Tensor] = None,
) -> torch.Tensor:
"""Forward pass with IPEX varlen_attention and PagedAttention.
1 change: 1 addition & 0 deletions vllm/attention/backends/pallas.py
@@ -148,6 +148,7 @@ def forward(
k_scale: float = 1.0,
v_scale: float = 1.0,
attn_type: AttentionType = AttentionType.DECODER,
fp8_out_scale: Optional[torch.Tensor] = None,
) -> torch.Tensor:
"""Forward pass with Pallas attention.
17 changes: 14 additions & 3 deletions vllm/attention/backends/rocm_flash_attn.py
@@ -506,6 +506,7 @@ def forward(
k_scale: float = 1.0,
v_scale: float = 1.0,
attn_type: AttentionType = AttentionType.DECODER,
fp8_out_scale: torch.Tensor = None,
) -> torch.Tensor:
"""Forward pass with FlashAttention and PagedAttention.
@@ -731,12 +732,18 @@ def forward(
device=output.device,
)
max_logits = torch.empty_like(exp_sums)
cpa_fp8_out = False
if num_prefill_tokens > 0:
out = output[num_prefill_tokens:]
else:
out = output
if fp8_out_scale is not None:
out = torch.empty_like(output,
dtype=torch.float8_e4m3fnuz)
cpa_fp8_out = True
else:
out = output
ops.paged_attention_rocm(
output[num_prefill_tokens:],
out,
exp_sums,
max_logits,
tmp_output,
@@ -757,7 +764,10 @@ def forward(
self.kv_cache_dtype,
k_scale,
v_scale,
fp8_out_scale if cpa_fp8_out else None,
)
if cpa_fp8_out:
return out.view(num_seqs, num_heads * head_size)
else:
output[num_prefill_tokens:] = PagedAttention.forward_decode(
decode_query,
@@ -827,4 +837,5 @@ def _use_rocm_custom_paged_attention(qtype: torch.dtype, head_size: int,
return (not _ON_NAVI and (qtype == torch.half or qtype == torch.bfloat16)
and (head_size == 64 or head_size == 128)
and (block_size == 16 or block_size == 32)
and (gqa_ratio >= 1 and gqa_ratio <= 16) and max_seq_len <= 32768)
and (gqa_ratio >= 1 and gqa_ratio <= 16)
and max_seq_len <= 128 * 1024)
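Two behavioural changes live in this backend: the custom kernel is now allowed for sequences up to 128k tokens instead of 32k, and when an fp8_out_scale is supplied the decode output is written into a freshly allocated fp8 buffer instead of the regular output slice. A rough sketch of both ideas (helper names are made up; plain PyTorch on the host, torch.float8_e4m3fn instead of the ROCm-only float8_e4m3fnuz, and the NAVI check is omitted):

from typing import Optional

import torch


def use_custom_paged_attention(qtype: torch.dtype, head_size: int,
                               block_size: int, gqa_ratio: int,
                               max_seq_len: int) -> bool:
    # Mirrors the widened heuristic: the max-seq-len cap moves from 32768
    # to 128 * 1024 tokens.
    return (qtype in (torch.half, torch.bfloat16)
            and head_size in (64, 128)
            and block_size in (16, 32)
            and 1 <= gqa_ratio <= 16
            and max_seq_len <= 128 * 1024)


def pick_decode_buffer(output: torch.Tensor,
                       num_prefill_tokens: int,
                       fp8_out_scale: Optional[torch.Tensor]) -> torch.Tensor:
    # With a scale, allocate a separate fp8 buffer for the kernel to fill;
    # otherwise keep writing into the decode slice of `output`.
    if fp8_out_scale is not None:
        return torch.empty_like(output, dtype=torch.float8_e4m3fn)
    return output[num_prefill_tokens:]


print(use_custom_paged_attention(torch.half, 128, 16, 8, 100_000))  # True (was False at the 32k cap)
out = torch.zeros(6, 4 * 32)
print(pick_decode_buffer(out, 2, torch.tensor(0.5)).dtype)          # torch.float8_e4m3fn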
1 change: 1 addition & 0 deletions vllm/attention/backends/torch_sdpa.py
@@ -156,6 +156,7 @@ def forward(
k_scale: float = 1.0,
v_scale: float = 1.0,
attn_type: AttentionType = AttentionType.DECODER,
fp8_out_scale: Optional[torch.Tensor] = None,
) -> torch.Tensor:
"""Forward pass with torch SDPA and PagedAttention.
1 change: 1 addition & 0 deletions vllm/attention/backends/xformers.py
@@ -450,6 +450,7 @@ def forward(
k_scale: float = 1.0,
v_scale: float = 1.0,
attn_type: AttentionType = AttentionType.DECODER,
fp8_out_scale: Optional[torch.Tensor] = None,
) -> torch.Tensor:
"""Forward pass with xFormers and PagedAttention.
5 changes: 3 additions & 2 deletions vllm/attention/layer.py
@@ -93,16 +93,17 @@ def forward(
kv_cache: Optional[torch.Tensor],
attn_metadata: AttentionMetadata,
attn_type: AttentionType = AttentionType.DECODER,
fp8_out_scale: Optional[torch.Tensor] = None,
) -> torch.Tensor:

return self.impl.forward(query,
key,
value,
kv_cache,
attn_metadata,
self._k_scale,
self._v_scale,
attn_type=attn_type)
attn_type=attn_type,
fp8_out_scale=fp8_out_scale)

def extra_repr(self) -> str:
s = f"head_size={self.impl.head_size}" # type: ignore
7 changes: 7 additions & 0 deletions vllm/envs.py
@@ -17,6 +17,7 @@
VLLM_USE_TRITON_FLASH_ATTN: bool = True
VLLM_USE_ROCM_SKINNY_GEMM: bool = True
VLLM_USE_ROCM_CUSTOM_PAGED_ATTN: bool = True
VLLM_USE_ROCM_CUSTOM_PAGED_ATTN_FP8_OUT: bool = False
RANK: int = 0
LOCAL_RANK: int = 0
CUDA_VISIBLE_DEVICES: Optional[str] = None
@@ -254,6 +255,12 @@ def get_default_config_root():
lambda: (os.getenv("VLLM_USE_ROCM_CUSTOM_PAGED_ATTN", "True").lower() in
("true", "1") != "0"),

# have custom paged attention implemented for MI3* cards write out fp8
"VLLM_USE_ROCM_CUSTOM_PAGED_ATTN_FP8_OUT":
lambda:
(os.getenv("VLLM_USE_ROCM_CUSTOM_PAGED_ATTN_FP8_OUT", "False").lower() in
("true", "1") != "0"),

# rank of the process in the distributed setting, used to determine
# the driver worker
"RANK":
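The new flag follows the same parsing pattern as the existing ROCm toggles and defaults to off. A hypothetical standalone re-implementation of the check for illustration (note that the chained `in (...) != "0"` comparison in envs.py reduces to a plain membership test, since the tuple on the left is never equal to "0"):

import os

enable_cpa_fp8_out = (os.getenv("VLLM_USE_ROCM_CUSTOM_PAGED_ATTN_FP8_OUT",
                                "False").lower() in ("true", "1"))
print(enable_cpa_fp8_out)  # False unless the variable is set to "true" or "1"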
12 changes: 11 additions & 1 deletion vllm/model_executor/models/llama.py
@@ -27,6 +27,7 @@
from torch import nn
from transformers import LlamaConfig

import vllm.envs as envs
from vllm import _custom_ops as ops
from vllm.attention import Attention, AttentionMetadata
from vllm.config import CacheConfig, LoRAConfig
@@ -180,6 +181,9 @@ def __init__(
num_kv_heads=self.num_kv_heads,
cache_config=cache_config,
quant_config=quant_config)
self.attn_fp8_out = envs.VLLM_USE_ROCM_CUSTOM_PAGED_ATTN_FP8_OUT \
and is_hip() \
and isinstance(quant_config, Fp8Config)

def forward(
self,
@@ -191,7 +195,13 @@ def forward(
qkv, _ = self.qkv_proj(hidden_states)
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
q, k = self.rotary_emb(positions, q, k)
attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
attn_output = self.attn(q,
k,
v,
kv_cache,
attn_metadata,
fp8_out_scale=self.o_proj.input_scale
if self.attn_fp8_out else None)
output, _ = self.o_proj(attn_output)
return output

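In llama.py the fp8 write-out is gated on three conditions — the env flag, running on ROCm (is_hip()), and the model being Fp8-quantized — and the scale handed to the attention call is o_proj's input scale, so the attention output lands already quantized for the fp8 output projection. A condensed stand-in for that gating (the helper below is illustrative, not the model code; attribute names follow the diff):

from typing import Optional

import torch


def resolve_fp8_out_scale(env_flag: bool, on_rocm: bool, is_fp8_quant: bool,
                          o_proj_input_scale: torch.Tensor) -> Optional[torch.Tensor]:
    # Mirrors: attn_fp8_out = VLLM_USE_ROCM_CUSTOM_PAGED_ATTN_FP8_OUT
    #                         and is_hip() and isinstance(quant_config, Fp8Config)
    # forward(): fp8_out_scale = o_proj.input_scale if attn_fp8_out else None
    if env_flag and on_rocm and is_fp8_quant:
        return o_proj_input_scale
    return None


scale = torch.tensor(0.02)
print(resolve_fp8_out_scale(True, True, True, scale))   # tensor(0.0200) -> fp8 write-out path
print(resolve_fp8_out_scale(True, False, True, scale))  # None -> unchanged path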
