From b51fe698281197721d46fe981628b2743c9220e9 Mon Sep 17 00:00:00 2001 From: sanyalington Date: Tue, 8 Oct 2024 23:52:21 +0530 Subject: [PATCH] Custom PA perf improvements (#222) * enable custom PA with max seqlen 128k * custom PA support to write out scaled fp8 value * use regular divide for scaling * enable custom PA to write out fp8 with scaling factor in llama * linter fixes * clang-format fixes * update abstract attn impl with fp8_out_scale * add optional fp8_out_scale arg to all attn backend classes * clang format fix * add env var to enable cpa fp8 write out * isort fix --- csrc/rocm/attention.cu | 299 ++++++++++++++------ csrc/rocm/ops.h | 3 +- csrc/rocm/torch_bindings.cpp | 3 +- vllm/_custom_ops.py | 4 +- vllm/attention/backends/abstract.py | 1 + vllm/attention/backends/blocksparse_attn.py | 1 + vllm/attention/backends/flash_attn.py | 1 + vllm/attention/backends/flashinfer.py | 1 + vllm/attention/backends/ipex_attn.py | 1 + vllm/attention/backends/pallas.py | 1 + vllm/attention/backends/rocm_flash_attn.py | 17 +- vllm/attention/backends/torch_sdpa.py | 1 + vllm/attention/backends/xformers.py | 1 + vllm/attention/layer.py | 5 +- vllm/envs.py | 7 + vllm/model_executor/models/llama.py | 12 +- 16 files changed, 257 insertions(+), 101 deletions(-) diff --git a/csrc/rocm/attention.cu b/csrc/rocm/attention.cu index b48348a515c8d..a7d3a47f2d11a 100644 --- a/csrc/rocm/attention.cu +++ b/csrc/rocm/attention.cu @@ -63,6 +63,7 @@ typedef struct _B16x8 { } _B16x8; using _B8x8 = uint2; +using bit8_t = uint8_t; ////// Non temporal load stores /////// @@ -197,8 +198,8 @@ __device__ __forceinline__ _B16x8 scaled_convert_b8x8(const _B8x8 input, // grid (num_seqs, num_partitions,num_heads/gqa_ratio) // block (partition size) template __global__ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_QKV_kernel( const scalar_t* __restrict__ q, // [num_seqs, num_heads, head_size] @@ -215,10 +216,11 @@ __global__ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_QKV_kernel( float* __restrict__ exp_sums, // [num_seqs, num_heads, max_num_partitions] float* __restrict__ max_logits, // [num_seqs, num_heads, // max_num_partitions] - scalar_t* __restrict__ out, // [num_seqs, num_heads, max_num_partitions, - // head_size] - scalar_t* __restrict__ final_out, // [num_seqs, num_heads, head_size] - int max_ctx_blocks, float k_scale, float v_scale) { + scalar_t* __restrict__ out, // [num_seqs, num_heads, max_num_partitions, + // head_size] + OUTT* __restrict__ final_out, // [num_seqs, num_heads, head_size] + int max_ctx_blocks, float k_scale, float v_scale, + const float* __restrict__ fp8_out_scale_ptr) { constexpr int NWARPS = NUM_THREADS / WARP_SIZE; const int warpid = threadIdx.x / WARP_SIZE; const int laneid = threadIdx.x % WARP_SIZE; @@ -657,18 +659,12 @@ __global__ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_QKV_kernel( __syncthreads(); if (warpid == 0) { + // const float out_scale = (fp8_out_scale_ptr != nullptr) ? + // __fdividef(1.0f,(*fp8_out_scale_ptr)) : 1.0f; + const float out_scale = + (fp8_out_scale_ptr != nullptr) ? 
1.0f / (*fp8_out_scale_ptr) : 1.0f; _B16x4 vout[QHLOOP][VHELOOP]; // iterate across heads - scalar_t* out_ptr; - int out_num_partitions; - if (context_len > partition_size) { - out_num_partitions = max_num_partitions; - out_ptr = out + seq_idx * num_heads * max_num_partitions * HEAD_SIZE + - partition_idx * HEAD_SIZE; - } else { - out_num_partitions = 1; - out_ptr = final_out + seq_idx * num_heads * HEAD_SIZE; - } #pragma unroll for (int qh = 0; qh < QHLOOP; qh++) { // iterate over each v head elem (within head_size) @@ -680,28 +676,74 @@ __global__ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_QKV_kernel( vout[qh][vh] = addx4(vout[qh][vh], vout_shared[qh][vh][laneid][w]); } - const int head_size_elem = vh * WARP_SIZE + laneid; - bit16_t* out_ptr_b16 = reinterpret_cast(out_ptr); + } + } + + if (context_len > partition_size) { + scalar_t* out_ptr = out + + seq_idx * num_heads * max_num_partitions * HEAD_SIZE + + partition_idx * HEAD_SIZE; + const int out_num_partitions = max_num_partitions; + bit16_t* out_ptr_b16 = reinterpret_cast(out_ptr); #pragma unroll - for (int i = 0; i < 4; i++) { - const int head_idx = 4 * qh + i; - if (head_idx < GQA_RATIO) { - out_ptr_b16[(wg_start_head_idx + head_idx) * out_num_partitions * - HEAD_SIZE + - head_size_elem] = vout[qh][vh][i]; + for (int qh = 0; qh < QHLOOP; qh++) { + #pragma unroll + for (int vh = 0; vh < VHELOOP; vh++) { + const int head_size_elem = vh * WARP_SIZE + laneid; + #pragma unroll + for (int i = 0; i < 4; i++) { + const int head_idx = 4 * qh + i; + if (head_idx < GQA_RATIO) { + out_ptr_b16[(wg_start_head_idx + head_idx) * out_num_partitions * + HEAD_SIZE + + head_size_elem] = vout[qh][vh][i]; + } + } + } + } + } // context_len > partition_size + else { + bit8_t* final_out_ptr_b8; + bit16_t* final_out_ptr_b16; + if constexpr (std::is_same::value) { + final_out_ptr_b8 = final_out + seq_idx * num_heads * HEAD_SIZE; + } else { + OUTT* out_ptr = final_out + seq_idx * num_heads * HEAD_SIZE; + final_out_ptr_b16 = reinterpret_cast(out_ptr); + } + #pragma unroll + for (int qh = 0; qh < QHLOOP; qh++) { + #pragma unroll + for (int vh = 0; vh < VHELOOP; vh++) { + const int head_size_elem = vh * WARP_SIZE + laneid; + #pragma unroll + for (int i = 0; i < 4; i++) { + const int head_idx = 4 * qh + i; + if (head_idx < GQA_RATIO) { + if constexpr (std::is_same::value) { + const float tmpf = + out_scale * to_float(vout[qh][vh][i]); + const OUTT tmp = vllm::fp8::vec_conversion(tmpf); + final_out_ptr_b8[(wg_start_head_idx + head_idx) * HEAD_SIZE + + head_size_elem] = tmp; + } else { + final_out_ptr_b16[(wg_start_head_idx + head_idx) * HEAD_SIZE + + head_size_elem] = vout[qh][vh][i]; + } + } } } } } - } + } // warpid == 0 } // Grid: (num_heads, num_seqs). 
-template +template __global__ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_reduce_kernel( - scalar_t* __restrict__ out, // [num_seqs, num_heads, head_size] + OUTT* __restrict__ out, // [num_seqs, num_heads, head_size] const float* __restrict__ exp_sums, // [num_seqs, num_heads, // max_num_partitions] const float* __restrict__ max_logits, // [num_seqs, num_heads, @@ -709,7 +751,7 @@ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_reduce_kernel( const scalar_t* __restrict__ tmp_out, // [num_seqs, num_heads, // max_num_partitions, head_size] const int* __restrict__ context_lens, // [num_seqs] - const int max_num_partitions) { + const int max_num_partitions, const float* __restrict__ fp8_out_scale_ptr) { const int num_heads = gridDim.x; const int head_idx = blockIdx.x; const int seq_idx = blockIdx.y; @@ -726,7 +768,8 @@ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_reduce_kernel( const int laneid = threadIdx.x % WARP_SIZE; __shared__ float shared_global_exp_sum; - __shared__ float shared_exp_sums[2 * WARP_SIZE]; + // max num partitions supported is warp_size * NPAR_LOOPS + __shared__ float shared_exp_sums[NPAR_LOOPS * WARP_SIZE]; if (warpid == 0) { const float* max_logits_ptr = max_logits + @@ -735,14 +778,25 @@ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_reduce_kernel( // valid partition is the last valid partition in case threadid > num // partitions - const int valid_partition = - (threadIdx.x < num_partitions) ? threadIdx.x : num_partitions - 1; - const int valid_partition2 = (WARP_SIZE + threadIdx.x < num_partitions) - ? WARP_SIZE + threadIdx.x - : num_partitions - 1; - float reg_max_logit = max_logits_ptr[valid_partition]; - float reg_max_logit2 = max_logits_ptr[valid_partition2]; - float max_logit = fmaxf(reg_max_logit, reg_max_logit2); + int valid_partition[NPAR_LOOPS]; + float reg_max_logit[NPAR_LOOPS]; + const int last_valid_partition = num_partitions - 1; + + #pragma unroll + for (int i = 0; i < NPAR_LOOPS; i++) { + const int partition_no = i * WARP_SIZE + threadIdx.x; + valid_partition[i] = + (partition_no < num_partitions) ? partition_no : last_valid_partition; + } + #pragma unroll + for (int i = 0; i < NPAR_LOOPS; i++) { + reg_max_logit[i] = max_logits_ptr[valid_partition[i]]; + } + float max_logit = reg_max_logit[0]; + #pragma unroll + for (int i = 1; i < NPAR_LOOPS; i++) { + max_logit = fmaxf(max_logit, reg_max_logit[i]); + } #pragma unroll for (int mask = WARP_SIZE / 2; mask >= 1; mask /= 2) { @@ -753,17 +807,28 @@ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_reduce_kernel( seq_idx * num_heads * max_num_partitions + head_idx * max_num_partitions; - float global_exp_sum = 0.0f; - float rescaled_exp_sum = exp_sums_ptr[valid_partition]; - float rescaled_exp_sum2 = exp_sums_ptr[valid_partition2]; - rescaled_exp_sum *= - (threadIdx.x < num_partitions) ? expf(reg_max_logit - max_logit) : 0.0f; - rescaled_exp_sum2 *= (threadIdx.x + WARP_SIZE < num_partitions) - ? expf(reg_max_logit2 - max_logit) - : 0.0f; - global_exp_sum += rescaled_exp_sum + rescaled_exp_sum2; - shared_exp_sums[threadIdx.x] = rescaled_exp_sum; - shared_exp_sums[threadIdx.x + WARP_SIZE] = rescaled_exp_sum2; + float rescaled_exp_sum[NPAR_LOOPS]; + #pragma unroll + for (int i = 0; i < NPAR_LOOPS; i++) { + rescaled_exp_sum[i] = exp_sums_ptr[valid_partition[i]]; + } + #pragma unroll + for (int i = 0; i < NPAR_LOOPS; i++) { + const int partition_no = i * WARP_SIZE + threadIdx.x; + rescaled_exp_sum[i] *= (partition_no < num_partitions) + ? 
expf(reg_max_logit[i] - max_logit) + : 0.0f; + } + float global_exp_sum = rescaled_exp_sum[0]; + #pragma unroll + for (int i = 1; i < NPAR_LOOPS; i++) { + global_exp_sum += rescaled_exp_sum[i]; + } + #pragma unroll + for (int i = 0; i < NPAR_LOOPS; i++) { + const int partition_no = i * WARP_SIZE + threadIdx.x; + shared_exp_sums[partition_no] = rescaled_exp_sum[i]; + } #pragma unroll for (int mask = WARP_SIZE / 2; mask >= 1; mask /= 2) { @@ -840,37 +905,47 @@ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_reduce_kernel( } } - if (num_partitions > MAX_NPAR) { - idx = 0; + for (int p = 1; p < NPAR_LOOPS; p++) { + if (num_partitions > p * MAX_NPAR) { + idx = 0; #pragma unroll - for (int j = MAX_NPAR * HEAD_SIZE; j < 2 * MAX_NPAR * HEAD_SIZE; - j += HEAD_SIZE) { - // lastj is last valid partition - const int lastj_offset = - (j < num_partition_offset) ? j : last_partition_offset; - tmps[idx] = tmp_out_ptr[lastj_offset]; - idx++; - } + for (int j = p * MAX_NPAR * HEAD_SIZE; j < (p + 1) * MAX_NPAR * HEAD_SIZE; + j += HEAD_SIZE) { + // lastj is last valid partition + const int lastj_offset = + (j < num_partition_offset) ? j : last_partition_offset; + tmps[idx] = tmp_out_ptr[lastj_offset]; + idx++; + } #pragma unroll - for (int j = 0; j < MAX_NPAR; j++) { - acc += to_float(tmps[j]) * shared_exp_sums[j + MAX_NPAR]; + for (int j = 0; j < MAX_NPAR; j++) { + acc += to_float(tmps[j]) * shared_exp_sums[j + p * MAX_NPAR]; + } } } const float inv_global_exp_sum = __fdividef(1.0f, shared_global_exp_sum + 1e-6f); + // const float out_scale = (fp8_out_scale_ptr != nullptr) ? + // __fdividef(1.0f,(*fp8_out_scale_ptr)) : 1.0f; + const float out_scale = + (fp8_out_scale_ptr != nullptr) ? 1.0f / (*fp8_out_scale_ptr) : 1.0f; acc *= inv_global_exp_sum; - scalar_t* out_ptr = - out + seq_idx * num_heads * HEAD_SIZE + head_idx * HEAD_SIZE; - out_ptr[threadIdx.x] = from_float(acc); + acc *= out_scale; + OUTT* out_ptr = out + seq_idx * num_heads * HEAD_SIZE + head_idx * HEAD_SIZE; + if constexpr (std::is_same::value) { + out_ptr[threadIdx.x] = vllm::fp8::vec_conversion(acc); + } else { + out_ptr[threadIdx.x] = from_float(acc); + } } #else // !defined(__HIP__MI300_MI250__) TODO: Add NAVI support template __global__ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_QKV_kernel( const scalar_t* __restrict__ q, // [num_seqs, num_heads, head_size] @@ -887,19 +962,20 @@ __global__ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_QKV_kernel( float* __restrict__ exp_sums, // [num_seqs, num_heads, max_num_partitions] float* __restrict__ max_logits, // [num_seqs, num_heads, // max_num_partitions] - scalar_t* __restrict__ out, // [num_seqs, num_heads, max_num_partitions, - // head_size] - scalar_t* __restrict__ final_out, // [num_seqs, num_heads, head_size] - int max_ctx_blocks, float k_scale, float v_scale) { + scalar_t* __restrict__ out, // [num_seqs, num_heads, max_num_partitions, + // head_size] + OUTT* __restrict__ final_out, // [num_seqs, num_heads, head_size] + int max_ctx_blocks, float k_scale, float v_scale, + const float* __restrict__ fp8_out_scale_ptr) { UNREACHABLE_CODE } // Grid: (num_heads, num_seqs). 
-template +template __global__ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_reduce_kernel( - scalar_t* __restrict__ out, // [num_seqs, num_heads, head_size] + OUTT* __restrict__ out, // [num_seqs, num_heads, head_size] const float* __restrict__ exp_sums, // [num_seqs, num_heads, // max_num_partitions] const float* __restrict__ max_logits, // [num_seqs, num_heads, @@ -907,29 +983,39 @@ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_reduce_kernel( const scalar_t* __restrict__ tmp_out, // [num_seqs, num_heads, // max_num_partitions, head_size] const int* __restrict__ context_lens, // [num_seqs] - const int max_num_partitions){UNREACHABLE_CODE} + const int max_num_partitions, + const float* __restrict__ fp8_out_scale_ptr){UNREACHABLE_CODE} #endif // defined(__HIP__MI300_MI250__) TODO: Add NAVI support #define LAUNCH_CUSTOM_ATTENTION(GQA_RATIO) \ - paged_attention_ll4mi_QKV_kernel \ + paged_attention_ll4mi_QKV_kernel \ <<>>( \ query_ptr, key_cache_ptr, value_cache_ptr, num_kv_heads, scale, \ block_tables_ptr, context_lens_ptr, max_num_blocks_per_seq, \ alibi_slopes_ptr, q_stride, kv_block_stride, kv_head_stride, \ exp_sums_ptr, max_logits_ptr, tmp_out_ptr, out_ptr, max_ctx_blocks, \ - k_scale, v_scale); + k_scale, v_scale, fp8_out_scale_ptr); + +#define LAUNCH_CUSTOM_REDUCTION(NPAR_LOOPS) \ + paged_attention_ll4mi_reduce_kernel \ + <<>>( \ + out_ptr, exp_sums_ptr, max_logits_ptr, tmp_out_ptr, \ + context_lens_ptr, max_num_partitions, fp8_out_scale_ptr); template + int BLOCK_SIZE, int HEAD_SIZE, typename OUTT, + int PARTITION_SIZE = 512> void paged_attention_custom_launcher( torch::Tensor& out, torch::Tensor& exp_sums, torch::Tensor& max_logits, torch::Tensor& tmp_out, torch::Tensor& query, torch::Tensor& key_cache, torch::Tensor& value_cache, const int num_kv_heads, float scale, torch::Tensor& block_tables, torch::Tensor& context_lens, int max_context_len, const c10::optional& alibi_slopes, - float k_scale, float v_scale) { + float k_scale, float v_scale, + const c10::optional& fp8_out_scale) { int num_seqs = query.size(0); int num_heads = query.size(1); int head_size = query.size(2); @@ -944,7 +1030,6 @@ void paged_attention_custom_launcher( ? reinterpret_cast(alibi_slopes.value().data_ptr()) : nullptr; - T* out_ptr = reinterpret_cast(out.data_ptr()); float* exp_sums_ptr = reinterpret_cast(exp_sums.data_ptr()); float* max_logits_ptr = reinterpret_cast(max_logits.data_ptr()); T* tmp_out_ptr = reinterpret_cast(tmp_out.data_ptr()); @@ -954,13 +1039,20 @@ void paged_attention_custom_launcher( int* block_tables_ptr = block_tables.data_ptr(); int* context_lens_ptr = context_lens.data_ptr(); + // NOTE: fp8_out_scale is optional. + const float* fp8_out_scale_ptr = + fp8_out_scale + ? 
reinterpret_cast(fp8_out_scale.value().data_ptr()) + : nullptr; + OUTT* out_ptr = reinterpret_cast(out.data_ptr()); + const int max_ctx_blocks = DIVIDE_ROUND_UP(max_context_len, BLOCK_SIZE); const int max_num_partitions = DIVIDE_ROUND_UP(max_context_len, PARTITION_SIZE); const int gqa_ratio = num_heads / num_kv_heads; assert(num_heads % num_kv_heads == 0); assert(head_size == HEAD_SIZE); - assert(max_num_partitions <= 128); + assert(max_num_partitions <= 256); constexpr int NTHR = PARTITION_SIZE; dim3 grid(num_seqs, max_num_partitions, num_kv_heads); @@ -1032,26 +1124,48 @@ void paged_attention_custom_launcher( if (max_context_len > PARTITION_SIZE) { dim3 reduce_grid(num_heads, num_seqs); dim3 reduce_block(head_size); - paged_attention_ll4mi_reduce_kernel - <<>>( - out_ptr, exp_sums_ptr, max_logits_ptr, tmp_out_ptr, - context_lens_ptr, max_num_partitions); + const int npar_loops = DIVIDE_ROUND_UP(max_num_partitions, WARP_SIZE); + switch (npar_loops) { + case 1: + LAUNCH_CUSTOM_REDUCTION(1); + break; + case 2: + LAUNCH_CUSTOM_REDUCTION(2); + break; + case 3: + LAUNCH_CUSTOM_REDUCTION(3); + break; + case 4: + LAUNCH_CUSTOM_REDUCTION(4); + break; + default: + TORCH_CHECK(false, "Unsupported npar_loops: ", npar_loops); + break; + } } } -#define CALL_CUSTOM_LAUNCHER(T, KVT, KV_DTYPE, BLK_SIZE, HEAD_SIZE) \ - paged_attention_custom_launcher( \ +#define CALL_CUSTOM_LAUNCHER(T, KVT, KV_DTYPE, BLK_SIZE, HEAD_SIZE, OUTT) \ + paged_attention_custom_launcher( \ out, exp_sums, max_logits, tmp_out, query, key_cache, value_cache, \ num_kv_heads, scale, block_tables, context_lens, max_context_len, \ - alibi_slopes, k_scale, v_scale); + alibi_slopes, k_scale, v_scale, fp8_out_scale); + +#define CALL_CUSTOM_LAUNCHER_OUT(T, KVT, KV_DTYPE, BLK_SIZE, HEAD_SIZE) \ + if (fp8_out_scale) { \ + CALL_CUSTOM_LAUNCHER(T, KVT, KV_DTYPE, BLK_SIZE, HEAD_SIZE, uint8_t); \ + } else { \ + CALL_CUSTOM_LAUNCHER(T, KVT, KV_DTYPE, BLK_SIZE, HEAD_SIZE, T); \ + } #define CALL_CUSTOM_LAUNCHER_BLK(T, KVT, KV_DTYPE, HEAD_SIZE) \ switch (block_size) { \ case 16: \ - CALL_CUSTOM_LAUNCHER(T, KVT, KV_DTYPE, 16, HEAD_SIZE); \ + CALL_CUSTOM_LAUNCHER_OUT(T, KVT, KV_DTYPE, 16, HEAD_SIZE); \ break; \ case 32: \ - CALL_CUSTOM_LAUNCHER(T, KVT, KV_DTYPE, 32, HEAD_SIZE); \ + CALL_CUSTOM_LAUNCHER_OUT(T, KVT, KV_DTYPE, 32, HEAD_SIZE); \ break; \ default: \ TORCH_CHECK(false, "Unsupported block size: ", block_size); \ @@ -1087,7 +1201,8 @@ void paged_attention( torch::Tensor& context_lens, // [num_seqs] int64_t block_size, int64_t max_context_len, const c10::optional& alibi_slopes, - const std::string& kv_cache_dtype, double k_scale, double v_scale) { + const std::string& kv_cache_dtype, double k_scale, double v_scale, + const c10::optional& fp8_out_scale) { const int head_size = query.size(2); if (kv_cache_dtype == "auto") { if (query.dtype() == at::ScalarType::Half) { @@ -1117,4 +1232,4 @@ void paged_attention( #undef WARP_SIZE #undef MAX #undef MIN -#undef DIVIDE_ROUND_UP \ No newline at end of file +#undef DIVIDE_ROUND_UP diff --git a/csrc/rocm/ops.h b/csrc/rocm/ops.h index 18c72f937f90a..9549cfa5dae85 100644 --- a/csrc/rocm/ops.h +++ b/csrc/rocm/ops.h @@ -20,4 +20,5 @@ void paged_attention(torch::Tensor& out, torch::Tensor& exp_sums, int64_t max_context_len, const c10::optional& alibi_slopes, const std::string& kv_cache_dtype, double k_scale, - double v_scale); + double v_scale, + const c10::optional& fp8_out_scale); diff --git a/csrc/rocm/torch_bindings.cpp b/csrc/rocm/torch_bindings.cpp index 2efa03e87e214..4d21ea944ee41 100644 --- 
a/csrc/rocm/torch_bindings.cpp +++ b/csrc/rocm/torch_bindings.cpp @@ -35,7 +35,8 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, rocm_ops) { " int max_context_len," " Tensor? alibi_slopes," " str kv_cache_dtype," - " float k_scale, float v_scale) -> ()"); + " float k_scale, float v_scale," + " Tensor? fp8_out_scale) -> ()"); rocm_ops.impl("paged_attention", torch::kCUDA, &paged_attention); rocm_ops.def( "wvSpltK(Tensor in_a, Tensor in_b, Tensor! out_c, int N_in," diff --git a/vllm/_custom_ops.py b/vllm/_custom_ops.py index 944130d424394..c7da00c6ab0e2 100644 --- a/vllm/_custom_ops.py +++ b/vllm/_custom_ops.py @@ -159,12 +159,14 @@ def paged_attention_rocm( kv_cache_dtype: str, k_scale: float, v_scale: float, + fp8_out_scale: Optional[torch.Tensor], ) -> None: torch.ops._rocm_C.paged_attention(out, exp_sum, max_logits, tmp_out, query, key_cache, value_cache, num_kv_heads, scale, block_tables, seq_lens, block_size, max_seq_len, alibi_slopes, - kv_cache_dtype, k_scale, v_scale) + kv_cache_dtype, k_scale, v_scale, + fp8_out_scale) # pos encoding ops diff --git a/vllm/attention/backends/abstract.py b/vllm/attention/backends/abstract.py index 2bc36ff18a96b..e9d5e61125193 100644 --- a/vllm/attention/backends/abstract.py +++ b/vllm/attention/backends/abstract.py @@ -228,5 +228,6 @@ def forward( k_scale: float = 1.0, v_scale: float = 1.0, attn_type: AttentionType = AttentionType.DECODER, + fp8_out_scale: Optional[torch.Tensor] = None, ) -> torch.Tensor: raise NotImplementedError diff --git a/vllm/attention/backends/blocksparse_attn.py b/vllm/attention/backends/blocksparse_attn.py index d84a40890ebbd..ce622ba5c6805 100644 --- a/vllm/attention/backends/blocksparse_attn.py +++ b/vllm/attention/backends/blocksparse_attn.py @@ -349,6 +349,7 @@ def forward( k_scale: float = 1.0, v_scale: float = 1.0, attn_type: AttentionType = AttentionType.DECODER, + fp8_out_scale: Optional[torch.Tensor] = None, ) -> torch.Tensor: """Forward pass with FlashAttention and PagedAttention. diff --git a/vllm/attention/backends/flash_attn.py b/vllm/attention/backends/flash_attn.py index 084e8113cd421..8550dadbc8482 100644 --- a/vllm/attention/backends/flash_attn.py +++ b/vllm/attention/backends/flash_attn.py @@ -657,6 +657,7 @@ def forward( k_scale: float = 1.0, v_scale: float = 1.0, attn_type: AttentionType = AttentionType.DECODER, + fp8_out_scale: Optional[torch.Tensor] = None, ) -> torch.Tensor: """Forward pass with FlashAttention. diff --git a/vllm/attention/backends/flashinfer.py b/vllm/attention/backends/flashinfer.py index 3a602fbfbbc04..f53861eb80c67 100644 --- a/vllm/attention/backends/flashinfer.py +++ b/vllm/attention/backends/flashinfer.py @@ -751,6 +751,7 @@ def forward( k_scale: float = 1.0, v_scale: float = 1.0, attn_type: AttentionType = AttentionType.DECODER, + fp8_out_scale: Optional[torch.Tensor] = None, ) -> torch.Tensor: assert k_scale == 1.0 and v_scale == 1.0, ( "key/v_scale is not supported in FlashInfer.") diff --git a/vllm/attention/backends/ipex_attn.py b/vllm/attention/backends/ipex_attn.py index 113a2788eacd3..513d5e2401226 100644 --- a/vllm/attention/backends/ipex_attn.py +++ b/vllm/attention/backends/ipex_attn.py @@ -172,6 +172,7 @@ def forward( k_scale: float = 1.0, v_scale: float = 1.0, attn_type: AttentionType = AttentionType.DECODER, + fp8_out_scale: Optional[torch.Tensor] = None, ) -> torch.Tensor: """Forward pass with IPEX varlen_attention and PagedAttention. 
diff --git a/vllm/attention/backends/pallas.py b/vllm/attention/backends/pallas.py index 83fdef16ef5cb..cff177349bf3f 100644 --- a/vllm/attention/backends/pallas.py +++ b/vllm/attention/backends/pallas.py @@ -148,6 +148,7 @@ def forward( k_scale: float = 1.0, v_scale: float = 1.0, attn_type: AttentionType = AttentionType.DECODER, + fp8_out_scale: Optional[torch.Tensor] = None, ) -> torch.Tensor: """Forward pass with Pallas attention. diff --git a/vllm/attention/backends/rocm_flash_attn.py b/vllm/attention/backends/rocm_flash_attn.py index 417dbc6d1483c..af6e2dbbc1590 100644 --- a/vllm/attention/backends/rocm_flash_attn.py +++ b/vllm/attention/backends/rocm_flash_attn.py @@ -506,6 +506,7 @@ def forward( k_scale: float = 1.0, v_scale: float = 1.0, attn_type: AttentionType = AttentionType.DECODER, + fp8_out_scale: torch.Tensor = None, ) -> torch.Tensor: """Forward pass with FlashAttention and PagedAttention. @@ -731,12 +732,18 @@ def forward( device=output.device, ) max_logits = torch.empty_like(exp_sums) + cpa_fp8_out = False if num_prefill_tokens > 0: out = output[num_prefill_tokens:] else: - out = output + if fp8_out_scale is not None: + out = torch.empty_like(output, + dtype=torch.float8_e4m3fnuz) + cpa_fp8_out = True + else: + out = output ops.paged_attention_rocm( - output[num_prefill_tokens:], + out, exp_sums, max_logits, tmp_output, @@ -757,7 +764,10 @@ def forward( self.kv_cache_dtype, k_scale, v_scale, + fp8_out_scale if cpa_fp8_out else None, ) + if cpa_fp8_out: + return out.view(num_seqs, num_heads * head_size) else: output[num_prefill_tokens:] = PagedAttention.forward_decode( decode_query, @@ -827,4 +837,5 @@ def _use_rocm_custom_paged_attention(qtype: torch.dtype, head_size: int, return (not _ON_NAVI and (qtype == torch.half or qtype == torch.bfloat16) and (head_size == 64 or head_size == 128) and (block_size == 16 or block_size == 32) - and (gqa_ratio >= 1 and gqa_ratio <= 16) and max_seq_len <= 32768) + and (gqa_ratio >= 1 and gqa_ratio <= 16) + and max_seq_len <= 128 * 1024) diff --git a/vllm/attention/backends/torch_sdpa.py b/vllm/attention/backends/torch_sdpa.py index 8a1f8f2930c84..c51f86d9ac793 100644 --- a/vllm/attention/backends/torch_sdpa.py +++ b/vllm/attention/backends/torch_sdpa.py @@ -156,6 +156,7 @@ def forward( k_scale: float = 1.0, v_scale: float = 1.0, attn_type: AttentionType = AttentionType.DECODER, + fp8_out_scale: Optional[torch.Tensor] = None, ) -> torch.Tensor: """Forward pass with torch SDPA and PagedAttention. diff --git a/vllm/attention/backends/xformers.py b/vllm/attention/backends/xformers.py index e073d616bf01d..cba56fd2f37c5 100644 --- a/vllm/attention/backends/xformers.py +++ b/vllm/attention/backends/xformers.py @@ -450,6 +450,7 @@ def forward( k_scale: float = 1.0, v_scale: float = 1.0, attn_type: AttentionType = AttentionType.DECODER, + fp8_out_scale: Optional[torch.Tensor] = None, ) -> torch.Tensor: """Forward pass with xFormers and PagedAttention. 
diff --git a/vllm/attention/layer.py b/vllm/attention/layer.py index ecf964fa49d9b..4cf35ce079bc8 100644 --- a/vllm/attention/layer.py +++ b/vllm/attention/layer.py @@ -93,8 +93,8 @@ def forward( kv_cache: Optional[torch.Tensor], attn_metadata: AttentionMetadata, attn_type: AttentionType = AttentionType.DECODER, + fp8_out_scale: Optional[torch.Tensor] = None, ) -> torch.Tensor: - return self.impl.forward(query, key, value, @@ -102,7 +102,8 @@ def forward( attn_metadata, self._k_scale, self._v_scale, - attn_type=attn_type) + attn_type=attn_type, + fp8_out_scale=fp8_out_scale) def extra_repr(self) -> str: s = f"head_size={self.impl.head_size}" # type: ignore diff --git a/vllm/envs.py b/vllm/envs.py index ee4711dbec842..803dacc8cedc6 100644 --- a/vllm/envs.py +++ b/vllm/envs.py @@ -17,6 +17,7 @@ VLLM_USE_TRITON_FLASH_ATTN: bool = True VLLM_USE_ROCM_SKINNY_GEMM: bool = True VLLM_USE_ROCM_CUSTOM_PAGED_ATTN: bool = True + VLLM_USE_ROCM_CUSTOM_PAGED_ATTN_FP8_OUT: bool = False RANK: int = 0 LOCAL_RANK: int = 0 CUDA_VISIBLE_DEVICES: Optional[str] = None @@ -254,6 +255,12 @@ def get_default_config_root(): lambda: (os.getenv("VLLM_USE_ROCM_CUSTOM_PAGED_ATTN", "True").lower() in ("true", "1") != "0"), + # have custom paged attention implemented for MI3* cards write out fp8 + "VLLM_USE_ROCM_CUSTOM_PAGED_ATTN_FP8_OUT": + lambda: + (os.getenv("VLLM_USE_ROCM_CUSTOM_PAGED_ATTN_FP8_OUT", "False").lower() in + ("true", "1") != "0"), + # rank of the process in the distributed setting, used to determine # the driver worker "RANK": diff --git a/vllm/model_executor/models/llama.py b/vllm/model_executor/models/llama.py index 125acdd63ce5c..6b6f8e165b5ca 100644 --- a/vllm/model_executor/models/llama.py +++ b/vllm/model_executor/models/llama.py @@ -27,6 +27,7 @@ from torch import nn from transformers import LlamaConfig +import vllm.envs as envs from vllm import _custom_ops as ops from vllm.attention import Attention, AttentionMetadata from vllm.config import CacheConfig, LoRAConfig @@ -180,6 +181,9 @@ def __init__( num_kv_heads=self.num_kv_heads, cache_config=cache_config, quant_config=quant_config) + self.attn_fp8_out = envs.VLLM_USE_ROCM_CUSTOM_PAGED_ATTN_FP8_OUT \ + and is_hip() \ + and isinstance(quant_config, Fp8Config) def forward( self, @@ -191,7 +195,13 @@ def forward( qkv, _ = self.qkv_proj(hidden_states) q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) q, k = self.rotary_emb(positions, q, k) - attn_output = self.attn(q, k, v, kv_cache, attn_metadata) + attn_output = self.attn(q, + k, + v, + kv_cache, + attn_metadata, + fp8_out_scale=self.o_proj.input_scale + if self.attn_fp8_out else None) output, _ = self.o_proj(attn_output) return output
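
A quick sanity check on the new sequence-length limit: with the launcher's default PARTITION_SIZE of 512 and the MI300/MI250 wavefront size of 64 (assumed here), a 128k context produces at most 256 partitions, and the reduce kernel covers them in at most four WARP_SIZE-wide loops. That is exactly what the new assert(max_num_partitions <= 256), the NPAR_LOOPS * WARP_SIZE shared_exp_sums buffer, and the switch over npar_loops with cases 1 through 4 encode. A minimal sketch of the arithmetic, in Python:

    def divide_round_up(x: int, y: int) -> int:
        # mirrors the DIVIDE_ROUND_UP macro in csrc/rocm/attention.cu
        return (x + y - 1) // y

    PARTITION_SIZE = 512   # launcher default template argument
    WARP_SIZE = 64         # CDNA wavefront size, assumed for MI300/MI250

    max_context_len = 128 * 1024
    max_num_partitions = divide_round_up(max_context_len, PARTITION_SIZE)  # 256
    npar_loops = divide_round_up(max_num_partitions, WARP_SIZE)            # 4
    assert max_num_partitions <= 256 and npar_loops <= 4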
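
The fp8 write-out applies the reciprocal of fp8_out_scale to the attention output before casting it to fp8, so the downstream fp8 GEMM (o_proj, whose input_scale is reused as the scale) recovers the original activations by multiplying the same scale back in. A minimal round-trip sketch of that convention, assuming a per-tensor scale and PyTorch's float8_e4m3fnuz dtype (the dtype the ROCm backend allocates for the decode output buffer):

    import torch

    scale = torch.tensor(0.05)                  # stand-in for o_proj.input_scale
    acc = torch.randn(16, dtype=torch.float32)  # attention output accumulator

    # What the kernel emits: out = acc * (1 / scale), converted to fp8.
    fp8_out = (acc * (1.0 / scale)).to(torch.float8_e4m3fnuz)

    # What the consumer effectively sees after dequantization.
    recovered = fp8_out.to(torch.float32) * scale
    print((recovered - acc).abs().max())        # small quantization error only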
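
On the model side the path is opt-in and only engages when three conditions hold: the VLLM_USE_ROCM_CUSTOM_PAGED_ATTN_FP8_OUT environment variable is set, the build is a ROCm one, and the Llama layer is quantized with Fp8Config; only then is o_proj.input_scale forwarded as fp8_out_scale. A rough sketch of that gate, assuming the helper import locations (is_hip in vllm.utils, Fp8Config in vllm.model_executor.layers.quantization.fp8) match the tree this patch targets:

    import os

    # Set before the engine is constructed; envs.py reads it lazily.
    os.environ["VLLM_USE_ROCM_CUSTOM_PAGED_ATTN_FP8_OUT"] = "1"

    import vllm.envs as envs
    from vllm.utils import is_hip
    from vllm.model_executor.layers.quantization.fp8 import Fp8Config

    def wants_fp8_attn_out(quant_config) -> bool:
        # Mirrors the gate added to LlamaAttention.__init__ in this patch.
        return (envs.VLLM_USE_ROCM_CUSTOM_PAGED_ATTN_FP8_OUT and is_hip()
                and isinstance(quant_config, Fp8Config))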