Custom PA perf improvements #222

Merged (13 commits) on Oct 8, 2024
299 changes: 207 additions & 92 deletions csrc/rocm/attention.cu

Large diffs are not rendered by default.

3 changes: 2 additions & 1 deletion csrc/rocm/ops.h
@@ -20,4 +20,5 @@ void paged_attention(torch::Tensor& out, torch::Tensor& exp_sums,
int64_t max_context_len,
const c10::optional<torch::Tensor>& alibi_slopes,
const std::string& kv_cache_dtype, double k_scale,
double v_scale);
double v_scale,
const c10::optional<torch::Tensor>& fp8_out_scale);
3 changes: 2 additions & 1 deletion csrc/rocm/torch_bindings.cpp
@@ -35,7 +35,8 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, rocm_ops) {
" int max_context_len,"
" Tensor? alibi_slopes,"
" str kv_cache_dtype,"
" float k_scale, float v_scale) -> ()");
" float k_scale, float v_scale,"
" Tensor? fp8_out_scale) -> ()");
rocm_ops.impl("paged_attention", torch::kCUDA, &paged_attention);
rocm_ops.def(
"wvSpltK(Tensor in_a, Tensor in_b, Tensor! out_c, int N_in,"
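The `Tensor? fp8_out_scale` entry in the schema above is what lets Python callers pass either a tensor or `None`. Below is a toy registration, not vLLM code, showing how an optional tensor argument behaves once a schema like this is bound through `torch.library` (assumes PyTorch 2.x; the `demo::scaled_copy` op and its behaviour are made up for illustration):

```python
from typing import Optional

import torch

# Define a toy op whose second argument uses the same "Tensor?" annotation
# as fp8_out_scale in the rocm_ops schema above.
torch.library.define("demo::scaled_copy", "(Tensor x, Tensor? scale) -> Tensor")


@torch.library.impl("demo::scaled_copy", "cpu")
def _scaled_copy(x: torch.Tensor, scale: Optional[torch.Tensor]) -> torch.Tensor:
    # None selects the default behaviour, mirroring fp8_out_scale=None.
    return x.clone() if scale is None else x * scale


print(torch.ops.demo.scaled_copy(torch.ones(3), None))
print(torch.ops.demo.scaled_copy(torch.ones(3), torch.tensor(2.0)))
```

In the real binding, passing `None` keeps the kernel on its existing fp16/bf16 output path.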
4 changes: 3 additions & 1 deletion vllm/_custom_ops.py
@@ -159,12 +159,14 @@ def paged_attention_rocm(
kv_cache_dtype: str,
k_scale: float,
v_scale: float,
fp8_out_scale: Optional[torch.Tensor],
) -> None:
torch.ops._rocm_C.paged_attention(out, exp_sum, max_logits, tmp_out, query,
key_cache, value_cache, num_kv_heads,
scale, block_tables, seq_lens,
block_size, max_seq_len, alibi_slopes,
kv_cache_dtype, k_scale, v_scale)
kv_cache_dtype, k_scale, v_scale,
fp8_out_scale)


# pos encoding ops
1 change: 1 addition & 0 deletions vllm/attention/backends/abstract.py
@@ -228,5 +228,6 @@ def forward(
k_scale: float = 1.0,
v_scale: float = 1.0,
attn_type: AttentionType = AttentionType.DECODER,
fp8_out_scale: Optional[torch.Tensor] = None,
) -> torch.Tensor:
raise NotImplementedError
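Since the base interface now carries `fp8_out_scale`, every backend accepts it even though only the ROCm custom paged attention consumes it. A minimal sketch of how a backend without fp8 output support might guard the new argument (the class is hypothetical and not part of this PR; the assert mirrors the pattern FlashInfer already uses for `k_scale`/`v_scale`):

```python
from typing import Any, Optional

import torch


class NoFp8OutputBackendImpl:  # hypothetical backend, not part of this PR
    def forward(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        kv_cache: Optional[torch.Tensor],
        attn_metadata: Any,
        k_scale: float = 1.0,
        v_scale: float = 1.0,
        attn_type: Any = None,
        fp8_out_scale: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        # Reject the fp8 request loudly instead of silently returning
        # unquantized output that downstream fp8 GEMMs would mis-scale.
        assert fp8_out_scale is None, (
            "fp8 output is only supported by the ROCm custom paged attention")
        raise NotImplementedError
```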
1 change: 1 addition & 0 deletions vllm/attention/backends/blocksparse_attn.py
@@ -349,6 +349,7 @@ def forward(
k_scale: float = 1.0,
v_scale: float = 1.0,
attn_type: AttentionType = AttentionType.DECODER,
fp8_out_scale: Optional[torch.Tensor] = None,
) -> torch.Tensor:
"""Forward pass with FlashAttention and PagedAttention.

1 change: 1 addition & 0 deletions vllm/attention/backends/flash_attn.py
@@ -657,6 +657,7 @@ def forward(
k_scale: float = 1.0,
v_scale: float = 1.0,
attn_type: AttentionType = AttentionType.DECODER,
fp8_out_scale: Optional[torch.Tensor] = None,
) -> torch.Tensor:
"""Forward pass with FlashAttention.

1 change: 1 addition & 0 deletions vllm/attention/backends/flashinfer.py
@@ -751,6 +751,7 @@ def forward(
k_scale: float = 1.0,
v_scale: float = 1.0,
attn_type: AttentionType = AttentionType.DECODER,
fp8_out_scale: Optional[torch.Tensor] = None,
) -> torch.Tensor:
assert k_scale == 1.0 and v_scale == 1.0, (
"key/v_scale is not supported in FlashInfer.")
1 change: 1 addition & 0 deletions vllm/attention/backends/ipex_attn.py
@@ -172,6 +172,7 @@ def forward(
k_scale: float = 1.0,
v_scale: float = 1.0,
attn_type: AttentionType = AttentionType.DECODER,
fp8_out_scale: Optional[torch.Tensor] = None,
) -> torch.Tensor:
"""Forward pass with IPEX varlen_attention and PagedAttention.

1 change: 1 addition & 0 deletions vllm/attention/backends/pallas.py
@@ -148,6 +148,7 @@ def forward(
k_scale: float = 1.0,
v_scale: float = 1.0,
attn_type: AttentionType = AttentionType.DECODER,
fp8_out_scale: Optional[torch.Tensor] = None,
) -> torch.Tensor:
"""Forward pass with Pallas attention.

17 changes: 14 additions & 3 deletions vllm/attention/backends/rocm_flash_attn.py
@@ -506,6 +506,7 @@ def forward(
k_scale: float = 1.0,
v_scale: float = 1.0,
attn_type: AttentionType = AttentionType.DECODER,
fp8_out_scale: Optional[torch.Tensor] = None,
) -> torch.Tensor:
"""Forward pass with FlashAttention and PagedAttention.

@@ -731,12 +732,18 @@ def forward(
device=output.device,
)
max_logits = torch.empty_like(exp_sums)
cpa_fp8_out = False
if num_prefill_tokens > 0:
out = output[num_prefill_tokens:]
else:
out = output
if fp8_out_scale is not None:
out = torch.empty_like(output,
dtype=torch.float8_e4m3fnuz)
cpa_fp8_out = True
else:
out = output
ops.paged_attention_rocm(
output[num_prefill_tokens:],
out,
exp_sums,
max_logits,
tmp_output,
@@ -757,7 +764,10 @@ def forward(
self.kv_cache_dtype,
k_scale,
v_scale,
fp8_out_scale if cpa_fp8_out else None,
)
if cpa_fp8_out:
return out.view(num_seqs, num_heads * head_size)
else:
output[num_prefill_tokens:] = PagedAttention.forward_decode(
decode_query,
@@ -827,4 +837,5 @@ def _use_rocm_custom_paged_attention(qtype: torch.dtype, head_size: int,
return (not _ON_NAVI and (qtype == torch.half or qtype == torch.bfloat16)
and (head_size == 64 or head_size == 128)
and (block_size == 16 or block_size == 32)
and (gqa_ratio >= 1 and gqa_ratio <= 16) and max_seq_len <= 32768)
and (gqa_ratio >= 1 and gqa_ratio <= 16)
and max_seq_len <= 128 * 1024)
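Net effect of the changes above: when a scale is provided, the custom kernel writes its decode output into a freshly allocated fp8 (e4m3fnuz) buffer, receives `fp8_out_scale` as its last argument, and the result is returned directly in the flattened `(num_seqs, num_heads * head_size)` layout; with no scale, it keeps writing into the usual fp16/bf16 output slice. A simplified sketch of that buffer selection (the helper name is made up; assumes a PyTorch build that provides `torch.float8_e4m3fnuz`):

```python
from typing import Optional, Tuple

import torch


def pick_decode_out(output: torch.Tensor,
                    num_prefill_tokens: int,
                    fp8_out_scale: Optional[torch.Tensor]
                    ) -> Tuple[torch.Tensor, bool]:
    """Return (buffer the kernel should write to, whether it is fp8)."""
    if fp8_out_scale is not None:
        # fp8 path: the kernel quantizes its output with fp8_out_scale and
        # writes e4m3 values into a dedicated buffer.
        return torch.empty_like(output, dtype=torch.float8_e4m3fnuz), True
    # Default path: reuse the decode portion of the existing output buffer.
    if num_prefill_tokens > 0:
        return output[num_prefill_tokens:], False
    return output, False


# CPU-only example just to show the shapes/dtypes involved:
out = torch.zeros(4, 32 * 128, dtype=torch.bfloat16)
buf, is_fp8 = pick_decode_out(out, num_prefill_tokens=0,
                              fp8_out_scale=torch.tensor(0.05))
print(buf.dtype, is_fp8)  # torch.float8_e4m3fnuz True
```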
1 change: 1 addition & 0 deletions vllm/attention/backends/torch_sdpa.py
@@ -156,6 +156,7 @@ def forward(
k_scale: float = 1.0,
v_scale: float = 1.0,
attn_type: AttentionType = AttentionType.DECODER,
fp8_out_scale: Optional[torch.Tensor] = None,
) -> torch.Tensor:
"""Forward pass with torch SDPA and PagedAttention.
1 change: 1 addition & 0 deletions vllm/attention/backends/xformers.py
@@ -450,6 +450,7 @@ def forward(
k_scale: float = 1.0,
v_scale: float = 1.0,
attn_type: AttentionType = AttentionType.DECODER,
fp8_out_scale: Optional[torch.Tensor] = None,
) -> torch.Tensor:
"""Forward pass with xFormers and PagedAttention.
5 changes: 3 additions & 2 deletions vllm/attention/layer.py
@@ -93,16 +93,17 @@ def forward(
kv_cache: Optional[torch.Tensor],
attn_metadata: AttentionMetadata,
attn_type: AttentionType = AttentionType.DECODER,
fp8_out_scale: Optional[torch.Tensor] = None,
) -> torch.Tensor:

return self.impl.forward(query,
key,
value,
kv_cache,
attn_metadata,
self._k_scale,
self._v_scale,
attn_type=attn_type)
attn_type=attn_type,
fp8_out_scale=fp8_out_scale)

def extra_repr(self) -> str:
s = f"head_size={self.impl.head_size}" # type: ignore
7 changes: 7 additions & 0 deletions vllm/envs.py
@@ -17,6 +17,7 @@
VLLM_USE_TRITON_FLASH_ATTN: bool = True
VLLM_USE_ROCM_SKINNY_GEMM: bool = True
VLLM_USE_ROCM_CUSTOM_PAGED_ATTN: bool = True
VLLM_USE_ROCM_CUSTOM_PAGED_ATTN_FP8_OUT: bool = False
RANK: int = 0
LOCAL_RANK: int = 0
CUDA_VISIBLE_DEVICES: Optional[str] = None
@@ -254,6 +255,12 @@ def get_default_config_root():
lambda: (os.getenv("VLLM_USE_ROCM_CUSTOM_PAGED_ATTN", "True").lower() in
("true", "1") != "0"),

# Have the custom paged attention implementation for MI3* cards write out fp8
"VLLM_USE_ROCM_CUSTOM_PAGED_ATTN_FP8_OUT":
lambda:
(os.getenv("VLLM_USE_ROCM_CUSTOM_PAGED_ATTN_FP8_OUT", "False").lower() in
("true", "1") != "0"),

# rank of the process in the distributed setting, used to determine
# the driver worker
"RANK":
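Like the other boolean switches in `vllm/envs.py`, the new flag treats a case-insensitive `"true"` or `"1"` as enabled and defaults to off, so it only needs to be present in the environment before the engine is constructed. A minimal sketch of the convention (the `env_flag` helper is illustrative, not vLLM code):

```python
import os

# e.g. `export VLLM_USE_ROCM_CUSTOM_PAGED_ATTN_FP8_OUT=1` in the shell, or:
os.environ["VLLM_USE_ROCM_CUSTOM_PAGED_ATTN_FP8_OUT"] = "1"


def env_flag(name: str, default: str = "False") -> bool:
    # Same truthiness rule used by the vLLM env-var table above.
    return os.getenv(name, default).lower() in ("true", "1")


assert env_flag("VLLM_USE_ROCM_CUSTOM_PAGED_ATTN_FP8_OUT")
assert not env_flag("SOME_UNSET_FLAG")
```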
12 changes: 11 additions & 1 deletion vllm/model_executor/models/llama.py
@@ -27,6 +27,7 @@
from torch import nn
from transformers import LlamaConfig

import vllm.envs as envs
from vllm import _custom_ops as ops
from vllm.attention import Attention, AttentionMetadata
from vllm.config import CacheConfig, LoRAConfig
@@ -180,6 +181,9 @@ def __init__(
num_kv_heads=self.num_kv_heads,
cache_config=cache_config,
quant_config=quant_config)
self.attn_fp8_out = envs.VLLM_USE_ROCM_CUSTOM_PAGED_ATTN_FP8_OUT \
and is_hip() \
and isinstance(quant_config, Fp8Config)

def forward(
self,
@@ -191,7 +195,13 @@ def forward(
qkv, _ = self.qkv_proj(hidden_states)
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
q, k = self.rotary_emb(positions, q, k)
attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
attn_output = self.attn(q,
k,
v,
kv_cache,
attn_metadata,
fp8_out_scale=self.o_proj.input_scale
if self.attn_fp8_out else None)
output, _ = self.o_proj(attn_output)
return output

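The scale threaded into the attention call is `self.o_proj.input_scale`, i.e. the quantization scale the fp8 `o_proj` GEMM already expects for its input activations. The sketch below illustrates the assumed numerical contract on plain CPU tensors: the kernel stores roughly `out / scale` in e4m3, and dequantizing with the same scale recovers the half-precision output. This is an assumption about the convention made for illustration, not vLLM code, and it ignores the kernel's exact rounding.

```python
import torch

# Per-tensor output scale, standing in for o_proj.input_scale.
scale = torch.tensor(0.05)

out_fp16 = torch.randn(4, 4096, dtype=torch.float16)

# Assumed convention: the kernel stores out / scale in fp8 (e4m3fnuz)...
out_fp8 = (out_fp16.float() / scale).to(torch.float8_e4m3fnuz)

# ...and the fp8 GEMM consuming it multiplies by the same scale, so the
# dequantized values approximate the original half-precision output.
recovered = out_fp8.float() * scale
print((recovered - out_fp16.float()).abs().max())
```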