From a8f6134ed8105f9680a4f6dd863c13c907583ab5 Mon Sep 17 00:00:00 2001 From: Gregory Shtrasberg <156009573+gshtras@users.noreply.github.com> Date: Thu, 23 Jan 2025 13:04:03 -0500 Subject: [PATCH] [FP8][Kernel] Dynamic kv cache scaling factors computation (#11906) Signed-off-by: Gregory Shtrasberg Co-authored-by: Micah Williamson --- .../kernels/benchmark_paged_attention.py | 4 +- csrc/attention/attention_kernels.cuh | 10 +- csrc/attention/paged_attention_v1.cu | 17 +- csrc/attention/paged_attention_v2.cu | 17 +- csrc/cache.h | 6 +- csrc/cache_kernels.cu | 30 +- csrc/cpu/attention.cpp | 12 +- csrc/cpu/cache.cpp | 6 +- csrc/cpu/torch_bindings.cpp | 6 +- csrc/ops.h | 10 +- csrc/rocm/attention.cu | 17 +- csrc/rocm/ops.h | 4 +- csrc/rocm/torch_bindings.cpp | 2 +- csrc/torch_bindings.cpp | 8 +- .../quantization/quantized_kvcache.md | 10 +- examples/other/fp8/README.md | 96 ----- examples/other/fp8/extract_scales.py | 367 ------------------ examples/other/fp8/quantizer/README.md | 32 -- examples/other/fp8/quantizer/quantize.py | 367 ------------------ .../llama2-70b-fp8-kv/kv_cache_scales.json | 90 ----- .../llama2-7b-fp8-kv/kv_cache_scales.json | 42 -- tests/kernels/test_attention.py | 2 +- tests/kernels/test_blocksparse_attention.py | 2 +- tests/kernels/test_cache.py | 10 +- tests/kernels/test_prefix_prefill.py | 10 + tests/kernels/utils.py | 2 + .../models/decoder_only/language/test_fp8.py | 15 +- tests/worker/test_model_input.py | 3 + vllm/_custom_ops.py | 20 +- vllm/attention/backends/abstract.py | 10 +- vllm/attention/backends/blocksparse_attn.py | 2 + vllm/attention/backends/flash_attn.py | 5 +- vllm/attention/backends/flashinfer.py | 10 +- vllm/attention/backends/ipex_attn.py | 2 +- vllm/attention/backends/pallas.py | 2 +- vllm/attention/backends/placeholder_attn.py | 3 + vllm/attention/backends/rocm_flash_attn.py | 2 + vllm/attention/backends/torch_sdpa.py | 2 +- vllm/attention/backends/utils.py | 2 + vllm/attention/backends/xformers.py | 2 + vllm/attention/layer.py | 28 +- vllm/attention/ops/ipex_attn.py | 16 +- vllm/attention/ops/paged_attn.py | 12 +- vllm/attention/ops/prefix_prefill.py | 12 +- vllm/config.py | 16 +- vllm/engine/arg_utils.py | 25 +- vllm/envs.py | 9 + .../layers/quantization/kv_cache.py | 19 +- .../model_loader/weight_utils.py | 45 +-- vllm/model_executor/models/exaone.py | 35 +- vllm/model_executor/models/granite.py | 32 +- vllm/model_executor/models/llama.py | 35 +- vllm/model_executor/models/mllama.py | 7 +- vllm/model_executor/models/solar.py | 35 +- vllm/v1/attention/backends/flash_attn.py | 4 - vllm/worker/hpu_model_runner.py | 7 +- vllm/worker/model_runner.py | 37 +- vllm/worker/openvino_model_runner.py | 1 + vllm/worker/tpu_model_runner.py | 5 + vllm/worker/xpu_model_runner.py | 2 + 60 files changed, 276 insertions(+), 1365 deletions(-) delete mode 100644 examples/other/fp8/README.md delete mode 100644 examples/other/fp8/extract_scales.py delete mode 100644 examples/other/fp8/quantizer/README.md delete mode 100644 examples/other/fp8/quantizer/quantize.py delete mode 100644 tests/fp8_kv/llama2-70b-fp8-kv/kv_cache_scales.json delete mode 100644 tests/fp8_kv/llama2-7b-fp8-kv/kv_cache_scales.json diff --git a/benchmarks/kernels/benchmark_paged_attention.py b/benchmarks/kernels/benchmark_paged_attention.py index 14eef00b855ac..219013a38134b 100644 --- a/benchmarks/kernels/benchmark_paged_attention.py +++ b/benchmarks/kernels/benchmark_paged_attention.py @@ -98,7 +98,9 @@ def run_cuda_benchmark(num_iters: int, profile: bool = False) -> float: start_time = 
time.perf_counter() # Using default kv_scale - k_scale = v_scale = 1.0 + k_scale = v_scale = torch.tensor(1.0, + dtype=torch.float32, + device=device) for _ in range(num_iters): if version == "v1": diff --git a/csrc/attention/attention_kernels.cuh b/csrc/attention/attention_kernels.cuh index 563e1438f0b01..eb216dc8baf10 100644 --- a/csrc/attention/attention_kernels.cuh +++ b/csrc/attention/attention_kernels.cuh @@ -105,7 +105,7 @@ __device__ void paged_attention_kernel( const int max_num_blocks_per_seq, const float* __restrict__ alibi_slopes, // [num_heads] const int q_stride, const int kv_block_stride, const int kv_head_stride, - const float k_scale, const float v_scale, const int tp_rank, + const float* k_scale, const float* v_scale, const int tp_rank, const int blocksparse_local_blocks, const int blocksparse_vert_stride, const int blocksparse_block_size, const int blocksparse_head_sliding_step) { const int seq_idx = blockIdx.y; @@ -285,7 +285,7 @@ __device__ void paged_attention_kernel( Quant_vec k_vec_quant = *reinterpret_cast( k_ptr + offset1 * BLOCK_SIZE * x + offset2); k_vecs[j] = fp8::scaled_convert( - k_vec_quant, k_scale); + k_vec_quant, *k_scale); } } @@ -415,7 +415,7 @@ __device__ void paged_attention_kernel( *reinterpret_cast(v_ptr + offset); // Vector conversion from V_quant_vec to V_vec. v_vec = fp8::scaled_convert(v_quant_vec, - v_scale); + *v_scale); } if (block_idx == num_seq_blocks - 1) { // NOTE(woosuk): When v_vec contains the tokens that are out of the @@ -513,7 +513,7 @@ __global__ void paged_attention_v1_kernel( const int max_num_blocks_per_seq, const float* __restrict__ alibi_slopes, // [num_heads] const int q_stride, const int kv_block_stride, const int kv_head_stride, - const float k_scale, const float v_scale, const int tp_rank, + const float* k_scale, const float* v_scale, const int tp_rank, const int blocksparse_local_blocks, const int blocksparse_vert_stride, const int blocksparse_block_size, const int blocksparse_head_sliding_step) { paged_attention_kernel& alibi_slopes, float k_scale, - float v_scale, const int tp_rank, const int blocksparse_local_blocks, - const int blocksparse_vert_stride, const int blocksparse_block_size, - const int blocksparse_head_sliding_step) { + const std::optional& alibi_slopes, torch::Tensor& k_scale, + torch::Tensor& v_scale, const int tp_rank, + const int blocksparse_local_blocks, const int blocksparse_vert_stride, + const int blocksparse_block_size, const int blocksparse_head_sliding_step) { int num_seqs = query.size(0); int num_heads = query.size(1); int head_size = query.size(2); @@ -80,6 +80,8 @@ void paged_attention_v1_launcher( CACHE_T* value_cache_ptr = reinterpret_cast(value_cache.data_ptr()); int* block_tables_ptr = block_tables.data_ptr(); int* seq_lens_ptr = seq_lens.data_ptr(); + const float* k_scale_ptr = reinterpret_cast(k_scale.data_ptr()); + const float* v_scale_ptr = reinterpret_cast(v_scale.data_ptr()); constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE; int padded_max_seq_len = @@ -177,8 +179,9 @@ void paged_attention_v1( torch::Tensor& seq_lens, // [num_seqs] int64_t block_size, int64_t max_seq_len, const std::optional& alibi_slopes, - const std::string& kv_cache_dtype, double k_scale, double v_scale, - const int64_t tp_rank, const int64_t blocksparse_local_blocks, + const std::string& kv_cache_dtype, torch::Tensor& k_scale, + torch::Tensor& v_scale, const int64_t tp_rank, + const int64_t blocksparse_local_blocks, const int64_t blocksparse_vert_stride, const int64_t blocksparse_block_size, const int64_t 
blocksparse_head_sliding_step) { const bool is_block_sparse = (blocksparse_vert_stride > 1); diff --git a/csrc/attention/paged_attention_v2.cu b/csrc/attention/paged_attention_v2.cu index a453b2243e48c..9935359e02fb1 100644 --- a/csrc/attention/paged_attention_v2.cu +++ b/csrc/attention/paged_attention_v2.cu @@ -37,7 +37,7 @@ exp_sums_ptr, max_logits_ptr, tmp_out_ptr, query_ptr, key_cache_ptr, \ value_cache_ptr, num_kv_heads, scale, block_tables_ptr, \ seq_lens_ptr, max_num_blocks_per_seq, alibi_slopes_ptr, q_stride, \ - kv_block_stride, kv_head_stride, k_scale, v_scale, tp_rank, \ + kv_block_stride, kv_head_stride, k_scale_ptr, v_scale_ptr, tp_rank, \ blocksparse_local_blocks, blocksparse_vert_stride, \ blocksparse_block_size, blocksparse_head_sliding_step); \ vllm::paged_attention_v2_reduce_kernel& alibi_slopes, float k_scale, - float v_scale, const int tp_rank, const int blocksparse_local_blocks, - const int blocksparse_vert_stride, const int blocksparse_block_size, - const int blocksparse_head_sliding_step) { + const std::optional& alibi_slopes, torch::Tensor& k_scale, + torch::Tensor& v_scale, const int tp_rank, + const int blocksparse_local_blocks, const int blocksparse_vert_stride, + const int blocksparse_block_size, const int blocksparse_head_sliding_step) { int num_seqs = query.size(0); int num_heads = query.size(1); int head_size = query.size(2); @@ -84,6 +84,8 @@ void paged_attention_v2_launcher( CACHE_T* value_cache_ptr = reinterpret_cast(value_cache.data_ptr()); int* block_tables_ptr = block_tables.data_ptr(); int* seq_lens_ptr = seq_lens.data_ptr(); + const float* k_scale_ptr = reinterpret_cast(k_scale.data_ptr()); + const float* v_scale_ptr = reinterpret_cast(v_scale.data_ptr()); constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE; int max_num_partitions = DIVIDE_ROUND_UP(max_seq_len, PARTITION_SIZE); @@ -188,8 +190,9 @@ void paged_attention_v2( torch::Tensor& seq_lens, // [num_seqs] int64_t block_size, int64_t max_seq_len, const std::optional& alibi_slopes, - const std::string& kv_cache_dtype, double k_scale, double v_scale, - const int64_t tp_rank, const int64_t blocksparse_local_blocks, + const std::string& kv_cache_dtype, torch::Tensor& k_scale, + torch::Tensor& v_scale, const int64_t tp_rank, + const int64_t blocksparse_local_blocks, const int64_t blocksparse_vert_stride, const int64_t blocksparse_block_size, const int64_t blocksparse_head_sliding_step) { const bool is_block_sparse = (blocksparse_vert_stride > 1); diff --git a/csrc/cache.h b/csrc/cache.h index 11c4c5001daaa..eedad9fafa3c0 100644 --- a/csrc/cache.h +++ b/csrc/cache.h @@ -18,15 +18,15 @@ void copy_blocks(std::vector const& key_caches, void reshape_and_cache(torch::Tensor& key, torch::Tensor& value, torch::Tensor& key_cache, torch::Tensor& value_cache, torch::Tensor& slot_mapping, - const std::string& kv_cache_dtype, const double k_scale, - const double v_scale); + const std::string& kv_cache_dtype, + torch::Tensor& k_scale, torch::Tensor& v_scale); void reshape_and_cache_flash(torch::Tensor& key, torch::Tensor& value, torch::Tensor& key_cache, torch::Tensor& value_cache, torch::Tensor& slot_mapping, const std::string& kv_cache_dtype, - const double k_scale, const double v_scale); + torch::Tensor& k_scale, torch::Tensor& v_scale); // Just for unittest void convert_fp8(torch::Tensor& dst_cache, torch::Tensor& src_cache, diff --git a/csrc/cache_kernels.cu b/csrc/cache_kernels.cu index 8a95279f9a25a..21a0aec0ececc 100644 --- a/csrc/cache_kernels.cu +++ b/csrc/cache_kernels.cu @@ -159,8 +159,8 @@ __global__ void 
reshape_and_cache_kernel( // block_size] const int64_t* __restrict__ slot_mapping, // [num_tokens] const int key_stride, const int value_stride, const int num_heads, - const int head_size, const int block_size, const int x, const float k_scale, - const float v_scale) { + const int head_size, const int block_size, const int x, + const float* k_scale, const float* v_scale) { const int64_t token_idx = blockIdx.x; const int64_t slot_idx = slot_mapping[token_idx]; if (slot_idx < 0) { @@ -196,9 +196,9 @@ __global__ void reshape_and_cache_kernel( value_cache[tgt_value_idx] = tgt_value; } else { key_cache[tgt_key_idx] = - fp8::scaled_convert(tgt_key, k_scale); + fp8::scaled_convert(tgt_key, *k_scale); value_cache[tgt_value_idx] = - fp8::scaled_convert(tgt_value, v_scale); + fp8::scaled_convert(tgt_value, *v_scale); } } } @@ -214,7 +214,7 @@ __global__ void reshape_and_cache_flash_kernel( const int64_t* __restrict__ slot_mapping, // [num_tokens] const int block_stride, const int key_stride, const int value_stride, const int num_heads, const int head_size, const int block_size, - const float k_scale, const float v_scale) { + const float* k_scale, const float* v_scale) { const int64_t token_idx = blockIdx.x; const int64_t slot_idx = slot_mapping[token_idx]; // NOTE: slot_idx can be -1 if the token is padded @@ -239,9 +239,9 @@ __global__ void reshape_and_cache_flash_kernel( value_cache[tgt_key_value_idx] = tgt_value; } else { key_cache[tgt_key_value_idx] = - fp8::scaled_convert(tgt_key, k_scale); + fp8::scaled_convert(tgt_key, *k_scale); value_cache[tgt_key_value_idx] = - fp8::scaled_convert(tgt_value, v_scale); + fp8::scaled_convert(tgt_value, *v_scale); } } } @@ -258,7 +258,9 @@ __global__ void reshape_and_cache_flash_kernel( reinterpret_cast(key_cache.data_ptr()), \ reinterpret_cast(value_cache.data_ptr()), \ slot_mapping.data_ptr(), key_stride, value_stride, \ - num_heads, head_size, block_size, x, k_scale, v_scale); + num_heads, head_size, block_size, x, \ + reinterpret_cast(k_scale.data_ptr()), \ + reinterpret_cast(v_scale.data_ptr())); void reshape_and_cache( torch::Tensor& key, // [num_tokens, num_heads, head_size] @@ -268,8 +270,8 @@ void reshape_and_cache( torch::Tensor& value_cache, // [num_blocks, num_heads, head_size, block_size] torch::Tensor& slot_mapping, // [num_tokens] - const std::string& kv_cache_dtype, const double k_scale, - const double v_scale) { + const std::string& kv_cache_dtype, torch::Tensor& k_scale, + torch::Tensor& v_scale) { int num_tokens = key.size(0); int num_heads = key.size(1); int head_size = key.size(2); @@ -299,7 +301,9 @@ void reshape_and_cache( reinterpret_cast(key_cache.data_ptr()), \ reinterpret_cast(value_cache.data_ptr()), \ slot_mapping.data_ptr(), block_stride, key_stride, \ - value_stride, num_heads, head_size, block_size, k_scale, v_scale); + value_stride, num_heads, head_size, block_size, \ + reinterpret_cast(k_scale.data_ptr()), \ + reinterpret_cast(v_scale.data_ptr())); void reshape_and_cache_flash( torch::Tensor& key, // [num_tokens, num_heads, head_size] @@ -308,8 +312,8 @@ void reshape_and_cache_flash( torch::Tensor& value_cache, // [num_blocks, block_size, num_heads, head_size] torch::Tensor& slot_mapping, // [num_tokens] or [num_actual_tokens] - const std::string& kv_cache_dtype, const double k_scale, - const double v_scale) { + const std::string& kv_cache_dtype, torch::Tensor& k_scale, + torch::Tensor& v_scale) { // NOTE(woosuk): In vLLM V1, key.size(0) can be different from // slot_mapping.size(0) because of padding for CUDA graphs. 
// In vLLM V0, key.size(0) is always equal to slot_mapping.size(0) because diff --git a/csrc/cpu/attention.cpp b/csrc/cpu/attention.cpp index ef5b14088c63b..b9764056e8a2d 100644 --- a/csrc/cpu/attention.cpp +++ b/csrc/cpu/attention.cpp @@ -460,11 +460,11 @@ void paged_attention_v1( torch::Tensor& value_cache, int64_t num_kv_heads, double scale, torch::Tensor& block_tables, torch::Tensor& seq_lens, int64_t block_size, int64_t max_seq_len, const std::optional& alibi_slopes, - const std::string& kv_cache_dtype, double k_scale, double v_scale, - const int64_t tp_rank, const int64_t blocksparse_local_blocks, + const std::string& kv_cache_dtype, torch::Tensor& k_scale, + torch::Tensor& v_scale, const int64_t tp_rank, + const int64_t blocksparse_local_blocks, const int64_t blocksparse_vert_stride, const int64_t blocksparse_block_size, const int64_t blocksparse_head_sliding_step) { - TORCH_CHECK(k_scale == 1.0f && v_scale == 1.0f); TORCH_CHECK(blocksparse_vert_stride <= 1, "CPU backend does not support blocksparse attention yet."); VLLM_DISPATCH_FLOATING_TYPES(query.scalar_type(), "paged_attention_v1_impl", @@ -782,11 +782,11 @@ void paged_attention_v2( torch::Tensor& value_cache, int64_t num_kv_heads, double scale, torch::Tensor& block_tables, torch::Tensor& seq_lens, int64_t block_size, int64_t max_seq_len, const std::optional& alibi_slopes, - const std::string& kv_cache_dtype, double k_scale, double v_scale, - const int64_t tp_rank, const int64_t blocksparse_local_blocks, + const std::string& kv_cache_dtype, torch::Tensor& k_scale, + torch::Tensor& v_scale, const int64_t tp_rank, + const int64_t blocksparse_local_blocks, const int64_t blocksparse_vert_stride, const int64_t blocksparse_block_size, const int64_t blocksparse_head_sliding_step) { - TORCH_CHECK(k_scale == 1.0f && v_scale == 1.0f); TORCH_CHECK(blocksparse_vert_stride <= 1, "CPU backend does not support blocksparse attention yet."); VLLM_DISPATCH_FLOATING_TYPES(query.scalar_type(), "paged_attention_v2_impl", diff --git a/csrc/cpu/cache.cpp b/csrc/cpu/cache.cpp index 31d454328b2c1..e3809acad7453 100644 --- a/csrc/cpu/cache.cpp +++ b/csrc/cpu/cache.cpp @@ -107,10 +107,8 @@ void copy_blocks(std::vector const& key_caches, void reshape_and_cache(torch::Tensor& key, torch::Tensor& value, torch::Tensor& key_cache, torch::Tensor& value_cache, torch::Tensor& slot_mapping, - const std::string& kv_cache_dtype, double k_scale, - double v_scale) { - TORCH_CHECK(k_scale == 1.0f && v_scale == 1.0f); - + const std::string& kv_cache_dtype, + torch::Tensor& k_scale, torch::Tensor& v_scale) { int num_tokens = key.size(0); int num_heads = key.size(1); int head_size = key.size(2); diff --git a/csrc/cpu/torch_bindings.cpp b/csrc/cpu/torch_bindings.cpp index 74e4d8189d403..5d1c5f4c83d3e 100644 --- a/csrc/cpu/torch_bindings.cpp +++ b/csrc/cpu/torch_bindings.cpp @@ -30,7 +30,7 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { " Tensor value_cache, int num_kv_heads, float scale," " Tensor block_tables, Tensor seq_lens, int block_size," " int max_seq_len, Tensor? alibi_slopes," - " str kv_cache_dtype, float k_scale, float v_scale," + " str kv_cache_dtype, Tensor k_scale, Tensor v_scale," " int tp_rank, int blocksparse_local_blocks," " int blocksparse_vert_stride, int blocksparse_block_size," " int blocksparse_head_sliding_step) -> ()"); @@ -44,7 +44,7 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { " Tensor value_cache, int num_kv_heads, float scale," " Tensor block_tables, Tensor seq_lens, int block_size," " int max_seq_len, Tensor? 
alibi_slopes," - " str kv_cache_dtype, float k_scale, float v_scale," + " str kv_cache_dtype, Tensor k_scale, Tensor v_scale," " int tp_rank, int blocksparse_local_blocks," " int blocksparse_vert_stride, int blocksparse_block_size," " int blocksparse_head_sliding_step) -> ()"); @@ -148,7 +148,7 @@ TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _cache_ops), cache_ops) { " Tensor! key_cache, Tensor! value_cache," " Tensor slot_mapping," " str kv_cache_dtype," - " float k_scale, float v_scale) -> ()"); + " Tensor k_scale, Tensor v_scale) -> ()"); cache_ops.impl("reshape_and_cache", torch::kCPU, &reshape_and_cache); } diff --git a/csrc/ops.h b/csrc/ops.h index 5a194a0dd3654..346898964010d 100644 --- a/csrc/ops.h +++ b/csrc/ops.h @@ -34,8 +34,9 @@ void paged_attention_v1( torch::Tensor& value_cache, int64_t num_kv_heads, double scale, torch::Tensor& block_tables, torch::Tensor& seq_lens, int64_t block_size, int64_t max_seq_len, const std::optional& alibi_slopes, - const std::string& kv_cache_dtype, double k_scale, double v_scale, - const int64_t tp_rank, const int64_t blocksparse_local_blocks, + const std::string& kv_cache_dtype, torch::Tensor& k_scale, + torch::Tensor& v_scale, const int64_t tp_rank, + const int64_t blocksparse_local_blocks, const int64_t blocksparse_vert_stride, const int64_t blocksparse_block_size, const int64_t blocksparse_head_sliding_step); @@ -45,8 +46,9 @@ void paged_attention_v2( torch::Tensor& value_cache, int64_t num_kv_heads, double scale, torch::Tensor& block_tables, torch::Tensor& seq_lens, int64_t block_size, int64_t max_seq_len, const std::optional& alibi_slopes, - const std::string& kv_cache_dtype, double k_scale, double v_scale, - const int64_t tp_rank, const int64_t blocksparse_local_blocks, + const std::string& kv_cache_dtype, torch::Tensor& k_scale, + torch::Tensor& v_scale, const int64_t tp_rank, + const int64_t blocksparse_local_blocks, const int64_t blocksparse_vert_stride, const int64_t blocksparse_block_size, const int64_t blocksparse_head_sliding_step); diff --git a/csrc/rocm/attention.cu b/csrc/rocm/attention.cu index 0fec9624c457e..9477790629c9f 100644 --- a/csrc/rocm/attention.cu +++ b/csrc/rocm/attention.cu @@ -218,7 +218,7 @@ __global__ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_QKV_kernel( scalar_t* __restrict__ out, // [num_seqs, num_heads, max_num_partitions, // head_size] scalar_t* __restrict__ final_out, // [num_seqs, num_heads, head_size] - int max_ctx_blocks, float k_scale, float v_scale) { + int max_ctx_blocks, const float* k_scale_ptr, const float* v_scale_ptr) { constexpr int NWARPS = NUM_THREADS / WARP_SIZE; const int warpid = threadIdx.x / WARP_SIZE; const int laneid = threadIdx.x % WARP_SIZE; @@ -406,7 +406,7 @@ __global__ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_QKV_kernel( // Vlocalb8[h][b * BLOCK_SIZE / 8 + d] = v_ptrh8be[d]; const _B8x8 Vlocalb8 = v_ptrh8be[d]; Vlocal[h][b * BLOCK_SIZE / 8 + d] = - scaled_convert_b8x8(Vlocalb8, v_scale); + scaled_convert_b8x8(Vlocalb8, *v_scale_ptr); } } } @@ -416,7 +416,7 @@ __global__ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_QKV_kernel( #pragma unroll for (int d = 0; d < KHELOOP; d++) { Klocal[d] = - scaled_convert_b8x8(Klocalb8[d], k_scale); + scaled_convert_b8x8(Klocalb8[d], *k_scale_ptr); } } @@ -890,7 +890,7 @@ __global__ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_QKV_kernel( scalar_t* __restrict__ out, // [num_seqs, num_heads, max_num_partitions, // head_size] scalar_t* __restrict__ final_out, // [num_seqs, num_heads, 
head_size] - int max_ctx_blocks, float k_scale, float v_scale) { + int max_ctx_blocks, const float* k_scale, const float* v_scale) { UNREACHABLE_CODE } @@ -919,7 +919,7 @@ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_reduce_kernel( block_tables_ptr, context_lens_ptr, max_num_blocks_per_seq, \ alibi_slopes_ptr, q_stride, kv_block_stride, kv_head_stride, \ exp_sums_ptr, max_logits_ptr, tmp_out_ptr, out_ptr, max_ctx_blocks, \ - k_scale, v_scale); + k_scale_ptr, v_scale_ptr); template @@ -929,7 +929,7 @@ void paged_attention_custom_launcher( torch::Tensor& value_cache, const int num_kv_heads, float scale, torch::Tensor& block_tables, torch::Tensor& context_lens, int max_context_len, const std::optional& alibi_slopes, - float k_scale, float v_scale) { + torch::Tensor& k_scale, torch::Tensor& v_scale) { int num_seqs = query.size(0); int num_heads = query.size(1); int head_size = query.size(2); @@ -953,6 +953,8 @@ void paged_attention_custom_launcher( KVT* value_cache_ptr = reinterpret_cast(value_cache.data_ptr()); int* block_tables_ptr = block_tables.data_ptr(); int* context_lens_ptr = context_lens.data_ptr(); + const float* k_scale_ptr = reinterpret_cast(k_scale.data_ptr()); + const float* v_scale_ptr = reinterpret_cast(v_scale.data_ptr()); const int max_ctx_blocks = DIVIDE_ROUND_UP(max_context_len, BLOCK_SIZE); const int max_num_partitions = @@ -1087,7 +1089,8 @@ void paged_attention( torch::Tensor& context_lens, // [num_seqs] int64_t block_size, int64_t max_context_len, const std::optional& alibi_slopes, - const std::string& kv_cache_dtype, double k_scale, double v_scale) { + const std::string& kv_cache_dtype, torch::Tensor& k_scale, + torch::Tensor& v_scale) { const int head_size = query.size(2); if (kv_cache_dtype == "auto") { if (query.dtype() == at::ScalarType::Half) { diff --git a/csrc/rocm/ops.h b/csrc/rocm/ops.h index 34b2f9ce8a4c4..ba161951772ad 100644 --- a/csrc/rocm/ops.h +++ b/csrc/rocm/ops.h @@ -10,5 +10,5 @@ void paged_attention(torch::Tensor& out, torch::Tensor& exp_sums, torch::Tensor& context_lens, int64_t block_size, int64_t max_context_len, const std::optional& alibi_slopes, - const std::string& kv_cache_dtype, double k_scale, - double v_scale); + const std::string& kv_cache_dtype, torch::Tensor& k_scale, + torch::Tensor& v_scale); diff --git a/csrc/rocm/torch_bindings.cpp b/csrc/rocm/torch_bindings.cpp index a283d4263d293..a5d2e2f97a3ed 100644 --- a/csrc/rocm/torch_bindings.cpp +++ b/csrc/rocm/torch_bindings.cpp @@ -27,7 +27,7 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, rocm_ops) { " int max_context_len," " Tensor? alibi_slopes," " str kv_cache_dtype," - " float k_scale, float v_scale) -> ()"); + " Tensor k_scale, Tensor v_scale) -> ()"); rocm_ops.impl("paged_attention", torch::kCUDA, &paged_attention); } diff --git a/csrc/torch_bindings.cpp b/csrc/torch_bindings.cpp index fb53d122487d3..ec63170d511f0 100644 --- a/csrc/torch_bindings.cpp +++ b/csrc/torch_bindings.cpp @@ -30,7 +30,7 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { " Tensor value_cache, int num_kv_heads, float scale," " Tensor block_tables, Tensor seq_lens, int block_size," " int max_seq_len, Tensor? 
alibi_slopes," - " str kv_cache_dtype, float k_scale, float v_scale," + " str kv_cache_dtype, Tensor k_scale, Tensor v_scale," " int tp_rank, int blocksparse_local_blocks," " int blocksparse_vert_stride, int blocksparse_block_size," " int blocksparse_head_sliding_step) -> ()"); @@ -44,7 +44,7 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { " Tensor value_cache, int num_kv_heads, float scale," " Tensor block_tables, Tensor seq_lens, int block_size," " int max_seq_len, Tensor? alibi_slopes," - " str kv_cache_dtype, float k_scale, float v_scale," + " str kv_cache_dtype, Tensor k_scale, Tensor v_scale," " int tp_rank, int blocksparse_local_blocks," " int blocksparse_vert_stride, int blocksparse_block_size," " int blocksparse_head_sliding_step) -> ()"); @@ -449,7 +449,7 @@ TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _cache_ops), cache_ops) { " Tensor! key_cache, Tensor! value_cache," " Tensor slot_mapping," " str kv_cache_dtype," - " float k_scale, float v_scale) -> ()"); + " Tensor k_scale, Tensor v_scale) -> ()"); cache_ops.impl("reshape_and_cache", torch::kCUDA, &reshape_and_cache); // Reshape the key and value tensors and cache them. @@ -459,7 +459,7 @@ TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _cache_ops), cache_ops) { " Tensor! value_cache," " Tensor slot_mapping," " str kv_cache_dtype," - " float k_scale, float v_scale) -> ()"); + " Tensor k_scale, Tensor v_scale) -> ()"); cache_ops.impl("reshape_and_cache_flash", torch::kCUDA, &reshape_and_cache_flash); diff --git a/docs/source/features/quantization/quantized_kvcache.md b/docs/source/features/quantization/quantized_kvcache.md index 95fa5e81e2f74..9f36c2949e0dd 100644 --- a/docs/source/features/quantization/quantized_kvcache.md +++ b/docs/source/features/quantization/quantized_kvcache.md @@ -35,16 +35,18 @@ Studies have shown that FP8 E4M3 quantization typically only minimally degrades Here is an example of how to enable FP8 quantization: ```python +# To calculate kv cache scales on the fly enable the calculate_kv_scales +# parameter + from vllm import LLM, SamplingParams sampling_params = SamplingParams(temperature=0.7, top_p=0.8) -llm = LLM(model="meta-llama/Llama-3.1-8B-Instruct", kv_cache_dtype="fp8") +llm = LLM(model="meta-llama/Llama-2-7b-chat-hf", + kv_cache_dtype="fp8", + calculate_kv_scales=True) prompt = "London is the capital of" out = llm.generate(prompt, sampling_params)[0].outputs[0].text print(out) - -# output w/ scaling factors: England, the United Kingdom, and one of the world's leading financial, -# output w/o scaling factors: England, located in the southeastern part of the country. It is known ``` The `kv_cache_dtype` argument specifies the data type for KV cache storage: diff --git a/examples/other/fp8/README.md b/examples/other/fp8/README.md deleted file mode 100644 index 4e8031d954113..0000000000000 --- a/examples/other/fp8/README.md +++ /dev/null @@ -1,96 +0,0 @@ -# FP8 KV Cache - -This utility extracts the KV cache scaling factors from a quantized HF (Hugging Face) model. The extracted scaling factors are saved to a JSON file, which can later be used by vLLM (variable-length language model) during runtime. This tool is particularly useful when the KV cache data type is FP8 and is intended for use on ROCm (AMD GPU) platforms. - -## Prerequisites - -- Python 3.x -- PyTorch -- NumPy -- Hugging Face Transformers -- Hugging Face Hub -- AMMO - -Before incorporating the FP8 datatype for inference workloads, you must adhere to the following steps: -1. 
Install all necessary prerequisites and dependencies. -2. Convert HF model into a quantized HF model. -3. Extract KV Cache Scaling Factors from quantized HF model. -4. Load KV Cache Scaling Factors into VLLM. - -### 2. Convert HF model into a quantized HF model. -Note: The following steps are adapted from the [TensorRT-LLM repository](https://github.com/NVIDIA/TensorRT-LLM/blob/main/examples/quantization/README.md). - -`quantize.py` (examples/other/fp8/quantizer/quantize.py) uses the quantization toolkit (AMMO) to calibrate the PyTorch models and export TensorRT-LLM checkpoints. Each TensorRT-LLM checkpoint contains a config file (in .json format) and one or several rank weight files (in .safetensors format). - -The detailed quantization toolkit (AMMO) conversion guide for FP8 can be found at `examples/other/fp8/quantizer/README.md`. - -### 3. Extract KV Cache Scaling Factors from quantized HF model. -`extract_scales.py` (examples/other/fp8/extract_scales.py) can be utilized to extract the KV cache scaling factors from your quantized HF model, however at the moment, this tool exclusively supports Llama 2 models. It is also important to note the following: -1. **File Structure**: The utility operates under the assumption that all parameters, including KV cache scaling factors, corresponding to a particular Tensor Parallelism (TP) rank are stored in a single file. These files must adhere to a specific naming convention where the TP rank is immediately identified after a specific keyword (e.g., "rank") in the filename. - -2. **TP Decomposition**: The utility assumes consistency between the TP decomposition employed by the quantizer tool and that used by vLLM. - -3. **AMMO Compatibility**: Currently, the generated KV cache scaling factors for AMMO remain uniform across all TP ranks. - -```python -# prerequisites: -# - Quantized HF LLaMa 2 model -python3 examples/other/fp8/extract_scales.py --help -Usage: extract_scales.py [-h] --quantized_model QUANTIZED_MODEL [--load_format {auto,safetensors,npz,pt}] [--output_dir OUTPUT_DIR] [--output_name OUTPUT_NAME] [--tp_size TP_SIZE] - -KV Scale Extraction Example - -optional arguments: ---quantized_model: Specify either the local path to, or name of, a quantized HF model. It is expected that the quantization format is FP8_E4M3, for use on ROCm (AMD GPU). -Optional arguments: ---cache_dir: Specify a cache directory to use in the event of a HF model download. (Default: None) ---load_format: Specify the format of the model's tensor files containing the KV cache scaling factors. (Choices: auto, safetensors, npz, pt; Default: auto) ---revision: Specify the model's revision number. (Default: None) ---output_dir: Specify the output directory. By default the KV cache scaling factors will be saved in the model directory. (Default: None) ---output_name: Specify the output filename. (Default: kv_cache_scales.json) ---tp_size: Specify the tensor-parallel (TP) size that the quantized model should correspond to. If specified, during KV cache scaling factor extraction the observed TP size will be checked against this and an error will be raised if there is a mismatch. (Default: None) -``` -```python -Example: -python3 examples/other/fp8/extract_scales.py --quantized_model --tp_size --output_dir -``` -### 4. Load KV Cache Scaling Factors into VLLM. -This script evaluates the inference throughput of language models using various backends such as vLLM. It measures the time taken to process a given number of prompts and generate sequences for each prompt. 
The recently generated KV cache scaling factors are now integrated into the benchmarking process and allow for KV cache scaling factors to be utilized for FP8. -``` -# prerequisites: -# - LLaMa 2 kv_cache_scales.json file - -python3 benchmarks/benchmark_throughput.py --help -usage: benchmark_throughput.py [-h] [--backend {vllm,hf,mii}] [--dataset DATASET] [--input-len INPUT_LEN] [--output-len OUTPUT_LEN] [--model MODEL] - [--tokenizer TOKENIZER] [--quantization {awq,gptq,None}] [--tensor-parallel-size TENSOR_PARALLEL_SIZE] [--n N] - [--use-beam-search] [--num-prompts NUM_PROMPTS] [--seed SEED] [--hf-max-batch-size HF_MAX_BATCH_SIZE] [--trust-remote-code] - [--max-model-len MAX_MODEL_LEN] [--dtype {auto,half,float16,bfloat16,float,float32}] [--enforce-eager] [--kv-cache-dtype {auto,fp8}] - [--quantization-param-path KV_CACHE_quantization_param_path] - -Benchmark Throughput Example -optional arguments: - -h, --help show this help message and exit - --backend {vllm,hf,mii} - --dataset DATASET Path to the dataset. - --input-len INPUT_LEN Input prompt length for each request - --output-len OUTPUT_LEN Output length for each request. Overrides the output length from the dataset. - --model MODEL - --tokenizer TOKENIZER - --quantization {awq,gptq,None}, -q {awq,gptq,None} - --tensor-parallel-size TENSOR_PARALLEL_SIZE, -tp TENSOR_PARALLEL_SIZE - --n N Number of generated sequences per prompt. - --use-beam-search - --num-prompts NUM_PROMPTS Number of prompts to process. - --seed SEED - --hf-max-batch-size HF_MAX_BATCH_SIZE Maximum batch size for HF backend. - --trust-remote-code trust remote code from huggingface - --max-model-len MAX_MODEL_LEN Maximum length of a sequence (including prompt and output). If None, will be derived from the model. - --dtype {auto,half,float16,bfloat16,float,float32} data type for model weights and activations. The "auto" option will use FP16 precision for FP32 and FP16 models, and BF16 precision for BF16 models. - --enforce-eager enforce eager execution - --kv-cache-dtype {auto,fp8} Data type for kv cache storage. If "auto", will use model data type. FP8_E5M2 (without scaling) is only supported on cuda version greater than 11.8. On ROCm (AMD GPU), FP8_E4M3 is instead supported ```for common inference criteria. - --quantization-param-path QUANT_PARAM_JSON Path to the JSON file containing the KV cache scaling factors. This should generally be supplied, when KV cache dtype is FP8. Otherwise, KV cache scaling factors default to 1.0, which may cause accuracy issues. FP8_E5M2 (without scaling) is only supported on cuda version greater than 11.8. On ROCm (AMD GPU), FP8_E4M3 is instead supported for common inference criteria. -``` -Example: -```console -python3 benchmarks/benchmark_throughput.py --input-len --output-len -tp --kv-cache-dtype fp8 --quantization-param-path --model -``` diff --git a/examples/other/fp8/extract_scales.py b/examples/other/fp8/extract_scales.py deleted file mode 100644 index 1dce9d7e993a0..0000000000000 --- a/examples/other/fp8/extract_scales.py +++ /dev/null @@ -1,367 +0,0 @@ -import argparse -import glob -import json -import os -from typing import Any, Callable, Dict, List, Optional, Tuple - -import numpy as np -import torch -from safetensors.torch import safe_open - -from vllm.model_executor.layers.quantization.schema import QuantParamSchema - - -# Adapted from vllm/model_executor/model_loader/weight_utils.py -# The main differences are that we add the NPZ format and simplify -# its functionality drastically for our purposes (e.g. 
we assume that -# the quantized model exists locally and there is no need to download it) -def _prepare_hf_weights( - quantized_model_dir: str, - load_format: str = "auto", - fall_back_to_pt: bool = True, -) -> Tuple[List[str], bool]: - if not os.path.isdir(quantized_model_dir): - raise FileNotFoundError( - f"The quantized model directory `{quantized_model_dir}` " - "does not exist.") - use_safetensors = False - # Some quantized models use .pt files for storing the weights. - if load_format == "auto": - allow_patterns = ["*.safetensors", "*.bin"] - elif load_format == "safetensors": - use_safetensors = True - allow_patterns = ["*.safetensors"] - elif load_format == "pt": - allow_patterns = ["*.pt"] - elif load_format == "npz": - allow_patterns = ["*.npz"] - else: - raise ValueError(f"Unknown load_format: {load_format}") - if fall_back_to_pt: - allow_patterns += ["*.pt"] - - hf_weights_files: List[str] = [] - for pattern in allow_patterns: - hf_weights_files += glob.glob( - os.path.join(quantized_model_dir, pattern)) - if len(hf_weights_files) > 0: - if pattern == "*.safetensors": - use_safetensors = True - break - - if not use_safetensors: - # Exclude files that are not needed for inference. - # https://github.com/huggingface/transformers/blob/v4.34.0/src/transformers/trainer.py#L227-L233 - blacklist = [ - "training_args.bin", - "optimizer.bin", - "optimizer.pt", - "scheduler.pt", - "scaler.pt", - ] - hf_weights_files = [ - f for f in hf_weights_files - if not any(f.endswith(x) for x in blacklist) - ] - - if len(hf_weights_files) == 0: - raise RuntimeError( - f"Cannot find any model weights with `{quantized_model_dir}`") - - return hf_weights_files, use_safetensors - - -# Adapted from vllm/model_executor/model_loader/weight_utils.py -def _hf_tensorfile_iterator(filename: str, load_format: str, - use_safetensors: bool): - if load_format == "npz": - assert not use_safetensors - with np.load(filename) as data: - for name in data.files: - param = torch.from_numpy(data[name]) - yield name, param - elif use_safetensors: - with safe_open(filename, framework="pt") as f: - for name in f.keys(): # NOQA: SIM118 - param = f.get_tensor(name) - yield name, param - else: - state = torch.load(filename, map_location="cpu") - for name, param in state.items(): - yield name, param - del state - torch.cuda.empty_cache() - - -def _kv_scales_extractor( - hf_tensor_files: List[str], - use_safetensors: bool, - rank_keyword: str = "rank", - expected_tp_size: Optional[int] = None) -> Dict[int, Dict[int, float]]: - """ - Given a list of files containing tensor data, attempt to extract KV cache - scales from these files. Intended as a helper function taking in the output - from _prepare_hf_weights. - Args: - rank_keyword Matches the number immediately after this keyword in the - tensor filename to determine the TP rank corresponding - to said tensor file - expected_tp_size If specified, the TP size of the tensor files is checked - against this and an error is raised if they don't match. - Returns a dictionary mapping TP ranks to their relevant KV cache scales. - The per-rank scales are themselves represented as a dictionary of layer - indices to the respective per-layer scale. - """ - for char in rank_keyword: - assert not char.isdecimal( - ), f"Rank keyword {rank_keyword} contains a numeric character!" 
- rank_scales_map: Dict[int, Dict[int, float]] = {} - for tensor_file in hf_tensor_files: - try: - rank_idx = tensor_file.find(rank_keyword) - if rank_idx != -1: - start_idx = rank_idx + len(rank_keyword) - stop_idx = start_idx - while stop_idx < len( - tensor_file) and tensor_file[stop_idx].isdecimal(): - stop_idx += 1 - if stop_idx == start_idx: - raise RuntimeError("Did not find rank # in filename.") - rank = int(tensor_file[start_idx:stop_idx]) - elif len(hf_tensor_files) == 1: - # Since there is only one tensor file, we can assume - # that it's intended for TP rank 0 - rank = 0 - else: - raise RuntimeError( - f"Filename does not contain '{rank_keyword}'.") - except RuntimeError: - print("Unable to determine TP rank " - f"corresponding to file '{tensor_file}'") - raise - - if rank not in rank_scales_map: - layer_scales_map: Dict[int, float] = {} - rank_scales_map[rank] = layer_scales_map - else: - raise RuntimeError( - f"Tensor file '{tensor_file}' shares TP rank {rank} " - "with another tensor file.") - - module_delimiter = ":" if args.load_format == "npz" else "." - for name, param in _hf_tensorfile_iterator(tensor_file, - args.load_format, - use_safetensors): - if "kv_cache_scaling_factor" in name: - nums = [ - int(s) for s in name.split(module_delimiter) - if s.isdecimal() - ] - assert len( - nums) == 1, f"Could not determine layer idx for {name}" - layer_idx = nums[0] - assert layer_idx not in layer_scales_map, f"Duplicate scaling"\ - f" factor corresponding to layer {layer_idx}" - try: - layer_scales_map[layer_idx] = param.item() - except RuntimeError: - print( - "This utility supports only per-tensor scalar scales " - f"for now. The tensor\n {name} = {param} \nis an " - "invalid scale factor.") - raise - - if all( - len(layer_scales_map) == 0 - for layer_scales_map in rank_scales_map.values()): - # Note: this is true even if the rank_scales_map is empty - print("WARNING: No KV cache scale factors found. No output saved.") - return None - empirical_tp_world_size = max(rank_scales_map.keys()) + 1 - if expected_tp_size is not None: - assert expected_tp_size == empirical_tp_world_size, \ - f"User expected TP world size = {expected_tp_size} " \ - "from model but tool is expecting TP world size = " \ - f"{empirical_tp_world_size} from model instead." - for i in range(empirical_tp_world_size): - assert i in rank_scales_map, "Expected TP world size = "\ - f"{empirical_tp_world_size} but did not find KV " \ - f"cache scaling factors for TP rank {i}" - print(f"Found TP world size = {empirical_tp_world_size} " - "when extracting KV cache scales!") - return rank_scales_map - - -def _metadata_extractor(quantized_model_dir: str, - metadata_extract_fns: \ - Dict[str, Callable[[Dict[str, Any]], Any]]) \ - -> Dict[str, Any]: - """ - Given a directory containing quantized model files, this function - aims to extract metadata from the JSON files within this directory. - Each JSON file is expected to represent a dictionary in JSON - format (referred to as a "JSON-dictionary"). Metadata extraction is - defined by a dictionary called metadata_extract_fns, where each - metadata field name is mapped to an extraction function. - - These extraction functions are designed to take a JSON-dictionary - as their only argument and return the corresponding metadata. 
- While extraction functions are permitted to raise exceptions, they - should only raise a KeyError or ValueError if the metadata field - cannot be extracted from the current JSON-dictionary, yet there's - a possibility of finding it in another JSON-dictionary. - - The function returns a dictionary that maps metadata fields to - their extracted data. The keys of this dictionary correspond exactly - to those in metadata_extract_fns. If any fields fail to be extracted, - their corresponding values are set to None, and a warning is printed. - """ - if not os.path.isdir(quantized_model_dir): - raise FileNotFoundError( - f"The quantized model directory `{quantized_model_dir}` " - "does not exist.") - metadata_files = glob.glob(os.path.join(quantized_model_dir, "*.json")) - - result: Dict[str, Any] = {} - for file in metadata_files: - with open(file) as f: - try: - metadata = json.load(f) - except json.JSONDecodeError: - print(f"Could not parse `{file}` as a valid metadata file," - " skipping it.") - continue - if not isinstance(metadata, dict): - print(f"The file `{file}` does not correspond to a " - "JSON-serialized dictionary, skipping it.") - continue - for metadata_name, extract_fn in metadata_extract_fns.items(): - try: - metadata_info = extract_fn(metadata) - if metadata_name not in result: - result[metadata_name] = metadata_info - elif metadata_info != result[metadata_name]: - raise RuntimeError( - "Metadata mismatch! Originally found " - f"{metadata_name} = {result[metadata_name]} but " - f"now found {metadata_name} = {metadata_info} in " - f"`{file}`") - except KeyError: - # It is possible that a given file does not contain some - # of our selected metadata as it could be located in some - # other metadata file. - # 'EFINAE': extract_fn failure is not an error. - pass - except ValueError: - # See above. - pass - - # Warn if we cannot find any of the requested metadata - for metadata_name in metadata_extract_fns: - if metadata_name not in result: - print("WARNING: Unable to find requested metadata field " - f"`{metadata_name}`, setting it to None.") - result[metadata_name] = None - - return result - - -def main(args): - metadata_extract_fns = { - "model_type": lambda json_dict: json_dict["layers"][0]["decoder_type"], - "tp_size": lambda json_dict: int(json_dict["tensor_parallel"]), - "model_dtype": lambda json_dict: json_dict["dtype"] - } - recovered_metadata = _metadata_extractor(args.quantized_model, - metadata_extract_fns) - if args.tp_size is not None: - metadata_tp_size = recovered_metadata["tp_size"] - if metadata_tp_size is not None: - assert args.tp_size == metadata_tp_size, \ - f"User expected TP world size = {args.tp_size} " \ - f"but found TP world size = {metadata_tp_size} from metadata!" - expected_tp_size = args.tp_size or recovered_metadata["tp_size"] - rank_keyword = "rank" - hf_tensor_files, use_safetensors = _prepare_hf_weights( - args.quantized_model, args.load_format) - rank_scales_map = _kv_scales_extractor(hf_tensor_files, use_safetensors, - rank_keyword, expected_tp_size) - # Postprocess: formatting to the current schema. Consider pulling it - # out into a dedicated function should it ever become more complicated. 
- rank_scales_map = { - rank: {k: scale[k] - for k in sorted(scale.keys())} - for rank, scale in rank_scales_map.items() - } - # TODO: Expand this with activation and weights scaling factors when - # they are used in the future - schema = QuantParamSchema( - model_type=recovered_metadata["model_type"], - kv_cache={ - "dtype": ("float8_e4m3fn" if len(rank_scales_map) > 0 else - recovered_metadata["model_dtype"]), - "scaling_factor": - rank_scales_map - }, - ) - - if args.output_dir is None: - output_file = os.path.join(args.quantized_model, args.output_name) - else: - if not os.path.isdir(args.output_dir): - os.makedirs(args.output_dir, exist_ok=True) - output_file = os.path.join(args.output_dir, args.output_name) - - with open(output_file, 'w') as f: - f.write(schema.model_dump_json(indent=4)) - print(f"Completed! KV cache scaling factors saved to {output_file}") - - -if __name__ == "__main__": - parser = argparse.ArgumentParser( - description="This simple utility extracts the " - "KV cache scaling factors from a quantized HF model " - "and saves them to a JSON file compatible with later " - "use by vLLM (pass this file to the appropriate " - "runtime typically using the argument " - "--quantization-param-path ). This is only used " - "if the KV cache dtype is FP8 and on ROCm (AMD GPU).") - parser.add_argument( - "--quantized-model", - help="Specify the directory containing a single quantized HF model. " - "It is expected that the quantization format is FP8_E4M3, for use " - "on ROCm (AMD GPU).", - required=True) - parser.add_argument( - "--load_format", - help="Optionally specify the format of the model's tensor files " - "containing the KV cache scaling factors.", - choices=["auto", "safetensors", "npz", "pt"], - default="auto") - parser.add_argument( - "--output-dir", - help="Optionally specify the output directory. By default the " - "KV cache scaling factors will be saved in the model directory, " - "however you can override this behavior here.", - default=None) - parser.add_argument( - "--output-name", - help="Optionally specify the output filename.", - # TODO: Change this once additional scaling factors are enabled - default="kv_cache_scales.json") - parser.add_argument( - "--tp-size", - help="Optionally specify the tensor-parallel (TP) size that the " - "quantized model should correspond to. If specified, during KV " - "cache scaling factor extraction the observed TP size will be " - "checked against this and an error will be raised if there is " - "a mismatch. If not specified, the quantized model's expected " - "TP size is instead inferred from the largest TP rank observed. 
" - "The expected TP size is cross-checked against the TP ranks " - "observed in the quantized model and an error is raised if any " - "discrepancies are found.", - default=None, - type=int) - args = parser.parse_args() - - main(args) diff --git a/examples/other/fp8/quantizer/README.md b/examples/other/fp8/quantizer/README.md deleted file mode 100644 index d0895e97dc341..0000000000000 --- a/examples/other/fp8/quantizer/README.md +++ /dev/null @@ -1,32 +0,0 @@ -### Quantizer Utilities -`quantize.py`: NVIDIA Quantization utilities using TensorRT-Model-Optimizer, ported -from TensorRT-LLM: [`examples/quantization/quantize.py`](https://github.com/NVIDIA/TensorRT-LLM/blob/main/examples/quantization/quantize.py) - -### Prerequisite - -#### AMMO (AlgorithMic Model Optimization) Installation: nvidia-ammo 0.7.1 or later -`pip install --no-cache-dir --extra-index-url https://pypi.nvidia.com nvidia-ammo` - -#### AMMO Download (code and docs) -`https://developer.nvidia.com/downloads/assets/cuda/files/nvidia-ammo/nvidia_ammo-0.5.0.tar.gz` -`https://developer.nvidia.com/downloads/assets/cuda/files/nvidia-ammo/nvidia_ammo-0.7.1.tar.gz` - -### Usage - -#### Run on H100 system for speed if FP8; number of GPUs depends on the model size - -#### Example: quantize Llama2-7b model from HF to FP8 with FP8 KV Cache: -`python quantize.py --model-dir ./ll2-7b --dtype float16 --qformat fp8 --kv-cache-dtype fp8 --output-dir ./ll2_7b_fp8 --calib-size 512 --tp-size 1` - -Outputs: model structure, quantized model & parameters (with scaling factors) are in JSON and Safetensors (npz is generated only for the reference) -``` -# ll ./ll2_7b_fp8/ -total 19998244 -drwxr-xr-x 2 root root 4096 Feb 7 01:08 ./ -drwxrwxr-x 8 1060 1061 4096 Feb 7 01:08 ../ --rw-r--r-- 1 root root 176411 Feb 7 01:08 llama_tp1.json --rw-r--r-- 1 root root 13477087480 Feb 7 01:09 llama_tp1_rank0.npz --rw-r--r-- 1 root root 7000893272 Feb 7 01:08 rank0.safetensors -# -``` - diff --git a/examples/other/fp8/quantizer/quantize.py b/examples/other/fp8/quantizer/quantize.py deleted file mode 100644 index d75cc8b3d1cf7..0000000000000 --- a/examples/other/fp8/quantizer/quantize.py +++ /dev/null @@ -1,367 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. # noqa: E501 -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
-""" -Adapted from examples/quantization/hf_ptq.py -""" - -import argparse -import copy -import json -import random -import time - -import ammo.torch.quantization as atq -import numpy as np -import torch -from ammo.torch.export import export_model_config -from datasets import load_dataset -from torch.utils.data import DataLoader -from transformers import AutoModelForCausalLM, AutoTokenizer - -RAND_SEED = 1234 -MAX_SEQ_LEN = 2048 - -EMPTY_CFG = { - "quant_cfg": { - "*weight_quantizer": { - "enable": False, - }, - "*input_quantizer": { - "enable": False - }, - "*lm_head*": { - "enable": False - }, - "*output_layer*": { - "enable": False - }, - "default": { - "enable": False - }, - }, - "algorithm": "max", -} - -KV_CACHE_CFG = { - "*.query_key_value.output_quantizer": { - "num_bits": 8, - "axis": None, - "enable": True - }, - "*.Wqkv.output_quantizer": { - "num_bits": 8, - "axis": None, - "enable": True - }, - "*.W_pack.output_quantizer": { - "num_bits": 8, - "axis": None, - "enable": True - }, - "*.c_attn.output_quantizer": { - "num_bits": 8, - "axis": None, - "enable": True - }, - "*.k_proj.output_quantizer": { - "num_bits": 8, - "axis": None, - "enable": True - }, - "*.v_proj.output_quantizer": { - "num_bits": 8, - "axis": None, - "enable": True - }, -} - -QUANT_CFG_CHOICES = { - "int8_sq": atq.INT8_SMOOTHQUANT_CFG, - "fp8": atq.FP8_DEFAULT_CFG, - "int4_awq": atq.INT4_AWQ_CFG, - "w4a8_awq": atq.W4A8_AWQ_BETA_CFG, - "int8_wo": EMPTY_CFG, - "int4_wo": EMPTY_CFG, - "full_prec": EMPTY_CFG, -} - -MODEL_NAME_PATTERN_MAP = { - "GPT2": "gpt2", - "Xverse": "llama", - "Llama": "llama", - "Mistral": "llama", - "GPTJ": "gptj", - "FalconForCausalLM": "falcon", - "RWForCausalLM": "falcon", - "baichuan": "baichuan", - "MPT": "mpt", - "Bloom": "bloom", - "ChatGLM": "chatglm", - "QWen": "qwen", -} - - -def get_tokenizer(ckpt_path, max_seq_len=MAX_SEQ_LEN, model_type=None): - print(f"Initializing tokenizer from {ckpt_path}") - tokenizer = AutoTokenizer.from_pretrained( - ckpt_path, - model_max_length=max_seq_len, - padding_side="left", - trust_remote_code=True, - ) - if model_type and model_type == "qwen": - # qwen use token id 151643 as pad and eos tokens - tokenizer.pad_token = tokenizer.convert_ids_to_tokens(151643) - tokenizer.eos_token = tokenizer.convert_ids_to_tokens(151643) - - # can't set attribute 'pad_token' for "" - if tokenizer.pad_token != "": - tokenizer.pad_token = tokenizer.eos_token - if tokenizer.pad_token is None: - tokenizer.pad_token = tokenizer.eos_token - assert (tokenizer.pad_token - is not None), f"Pad token for {model_type} cannot be set!" 
- - return tokenizer - - -def get_model(ckpt_path, dtype="fp16", device="cuda"): - print(f"Initializing model from {ckpt_path}") - if dtype == "bf16" or dtype == "bfloat16": - dtype = torch.bfloat16 - elif dtype == "fp16" or dtype == "float16": - dtype = torch.float16 - elif dtype == "fp32" or dtype == "float32": - dtype = torch.float32 - else: - raise NotImplementedError(f"Unknown dtype {dtype}") - - # model_kwargs = {"torch_dtype": dtype} - model_kwargs = {"torch_dtype": "auto"} - - model = AutoModelForCausalLM.from_pretrained(ckpt_path, - device_map="auto", - **model_kwargs, - trust_remote_code=True) - model.eval() - - model_dtype = next(model.parameters()).dtype - if dtype != model_dtype: - print("[TensorRT-LLM][WARNING] The manually set model data type is " - f"{dtype}, but the data type of the HuggingFace model is " - f"{model_dtype}.") - - return model - - -def get_model_type(model): - for k, v in MODEL_NAME_PATTERN_MAP.items(): - if k.lower() in type(model).__name__.lower(): - return v - return None - - -def get_calib_dataloader(data="cnn_dailymail", - tokenizer=None, - batch_size=1, - calib_size=512, - block_size=512, - device=None): - print("Loading calibration dataset") - if data == "pileval": - dataset = load_dataset( - "json", - data_files="https://the-eye.eu/public/AI/pile/val.jsonl.zst", - split="train") - dataset = dataset["text"][:calib_size] - elif data == "cnn_dailymail": - dataset = load_dataset("cnn_dailymail", name="3.0.0", split="train") - dataset = dataset["article"][:calib_size] - else: - raise NotImplementedError - - batch_encoded = tokenizer.batch_encode_plus(dataset, - return_tensors="pt", - padding="max_length", - truncation=True, - max_length=block_size) - if device: - batch_encoded = batch_encoded.to(device) - batch_encoded = batch_encoded["input_ids"] - - calib_dataloader = DataLoader(batch_encoded, - batch_size=batch_size, - shuffle=False) - - return calib_dataloader - - -def quantize_model(model, quant_cfg, calib_dataloader=None): - - def calibrate_loop(): - if calib_dataloader is None: - return - """Adjusts weights and scaling factors based on selected algorithms.""" - for idx, data in enumerate(calib_dataloader): - print(f"Calibrating batch {idx}") - model(data) - - print("Starting quantization...") - start_time = time.time() - atq.quantize(model, quant_cfg, forward_loop=calibrate_loop) - end_time = time.time() - print("Quantization done. Total time used: {:.2f} s.".format(end_time - - start_time)) - - return model - - -def main(args): - if not torch.cuda.is_available(): - raise OSError("GPU is required for inference.") - - random.seed(RAND_SEED) - np.random.seed(RAND_SEED) - - model = get_model(args.model_dir, args.dtype, args.device) - model_type = get_model_type(model) - tokenizer = get_tokenizer(args.model_dir, model_type=model_type) - - if args.qformat in ["full_prec", "int8_wo", "int4_wo" - ] and args.kv_cache_dtype is None: - print(f"No quantization applied, export {args.dtype} model") - else: - if "awq" in args.qformat: - if args.calib_size > 32: - print("AWQ calibration could take longer with calib_size = " - f"{args.calib_size}, Using calib_size=32 instead") - args.calib_size = 32 - print("\nAWQ calibration could take longer than other calibration " - "methods. Please increase the batch size to speed up the " - "calibration process. 
Batch size can be set by adding the " - "argument --batch_size to the command line.\n") - - calib_dataloader = get_calib_dataloader( - tokenizer=tokenizer, - batch_size=args.batch_size, - calib_size=args.calib_size, - device=args.device, - ) - - if args.qformat in QUANT_CFG_CHOICES: - quant_cfg = QUANT_CFG_CHOICES[args.qformat] - else: - raise ValueError( - f"Unsupported quantization format: {args.qformat}") - - if "awq" in args.qformat: - quant_cfg = copy.deepcopy(QUANT_CFG_CHOICES[args.qformat]) - weight_quantizer = quant_cfg["quant_cfg"][ - "*weight_quantizer"] # type: ignore - if isinstance(weight_quantizer, list): - weight_quantizer = weight_quantizer[0] - weight_quantizer["block_sizes"][-1] = args.awq_block_size - - if args.kv_cache_dtype is not None: - if args.kv_cache_dtype == "fp8": - for value in KV_CACHE_CFG.values(): - value.update({"num_bits": (4, 3)}) # type: ignore - quant_cfg["quant_cfg"].update(KV_CACHE_CFG) # type: ignore - - print(quant_cfg) - - model = quantize_model(model, quant_cfg, calib_dataloader) - - with torch.inference_mode(): - if model_type is None: - print(f"Unknown model type {type(model).__name__}. Continue " - "exporting...") - model_type = f"unknown:{type(model).__name__}" - - export_path = args.output_dir - start_time = time.time() - - if args.qformat == "int4_awq" and model_type == "qwen": - torch.save(model.state_dict(), export_path) - else: - export_npz = (model_type not in [ - 'gptj', 'falcon', 'chatglm', 'mpt', 'llama', 'baichuan' - ]) - - # export safetensors - export_model_config( - model, - model_type, - getattr(torch, args.dtype), - export_dir=export_path, - inference_tensor_parallel=args.tp_size, - inference_pipeline_parallel=args.pp_size, - # export_tensorrt_llm_config=(not export_npz), - export_tensorrt_llm_config=False, - export_npz=export_npz) - - # Workaround for wo quantization - if args.qformat in ["int8_wo", "int4_wo", "full_prec"]: - with open(f"{export_path}/config.json") as f: - tensorrt_llm_config = json.load(f) - if args.qformat == "int8_wo": - tensorrt_llm_config["quantization"]["quant_algo"] = 'W8A16' - elif args.qformat == "int4_wo": - tensorrt_llm_config["quantization"]["quant_algo"] = 'W4A16' - else: - tensorrt_llm_config["quantization"]["quant_algo"] = None - with open(f"{export_path}/config.json", "w") as f: - json.dump(tensorrt_llm_config, f, indent=4) - - end_time = time.time() - print("Quantized model exported to {} \nTotal time used {:.2f} s.". 
- format(export_path, end_time - start_time)) - - -if __name__ == "__main__": - parser = argparse.ArgumentParser(description=__doc__) - parser.add_argument("--model-dir", - help="Specify where the HuggingFace model is", - required=True) - parser.add_argument("--device", default="cuda") - parser.add_argument("--dtype", help="Model data type.", default="float16") - parser.add_argument( - "--qformat", - help="Quantization format.", - default="full_prec", - choices=[ - "fp8", "int8_sq", "int4_awq", "w4a8_awq", "int8_wo", "int4_wo", - "full_prec" - ], - ) - parser.add_argument("--batch-size", - help="Batch size for calibration.", - type=int, - default=1) - parser.add_argument("--calib-size", - help="Number of samples for calibration.", - type=int, - default=512) - parser.add_argument("--output-dir", default="exported_model") - parser.add_argument("--tp-size", type=int, default=1) - parser.add_argument("--pp-size", type=int, default=1) - parser.add_argument("--awq-block-size", type=int, default=128) - parser.add_argument("--kv-cache-dtype", - help="KV Cache dtype.", - default=None, - choices=["int8", "fp8", None]) - args = parser.parse_args() - - main(args) diff --git a/tests/fp8_kv/llama2-70b-fp8-kv/kv_cache_scales.json b/tests/fp8_kv/llama2-70b-fp8-kv/kv_cache_scales.json deleted file mode 100644 index a548f0a9611f6..0000000000000 --- a/tests/fp8_kv/llama2-70b-fp8-kv/kv_cache_scales.json +++ /dev/null @@ -1,90 +0,0 @@ -{ - "model_type": "llama", - "kv_cache": { - "dtype": "float8_e4m3fn", - "scaling_factor": { - "0": { - "0": 0.0230364128947258, - "1": 0.01979283057153225, - "2": 0.0241350457072258, - "3": 0.0308314748108387, - "4": 0.0430733822286129, - "5": 0.0370396226644516, - "6": 0.0306222103536129, - "7": 0.0357491634786129, - "8": 0.0358189195394516, - "9": 0.0443289652466774, - "10": 0.0433175228536129, - "11": 0.0416782945394516, - "12": 0.0366908498108387, - "13": 0.0432477705180645, - "14": 0.0410505048930645, - "15": 0.0457589291036129, - "16": 0.0418526791036129, - "17": 0.0432477705180645, - "18": 0.0469447560608387, - "19": 0.0514787957072258, - "20": 0.0541294664144516, - "21": 0.0587681382894516, - "22": 0.0625, - "23": 0.0585588738322258, - "24": 0.0600237175822258, - "25": 0.0588030144572258, - "26": 0.0531180277466774, - "27": 0.06396484375, - "28": 0.0603027381002903, - "29": 0.0582101047039032, - "30": 0.0625348836183548, - "31": 0.0585588738322258, - "32": 0.0582798570394516, - "33": 0.0575125589966774, - "34": 0.0590820349752903, - "35": 0.0614188089966774, - "36": 0.0631975457072258, - "37": 0.0615931935608387, - "38": 0.0601283498108387, - "39": 0.0571986623108387, - "40": 0.0670340433716774, - "41": 0.0523507259786129, - "42": 0.0547223798930645, - "43": 0.0631975457072258, - "44": 0.0663713738322258, - "45": 0.0603376142680645, - "46": 0.0652204304933548, - "47": 0.0734514519572258, - "48": 0.0693708211183548, - "49": 0.0725446492433548, - "50": 0.0627790242433548, - "51": 0.0691266804933548, - "52": 0.0688825398683548, - "53": 0.068429134786129, - "54": 0.0605119988322258, - "55": 0.0799386203289032, - "56": 0.0853097140789032, - "57": 0.0661969929933548, - "58": 0.0689871683716774, - "59": 0.0724051371216774, - "60": 0.0541643425822258, - "61": 0.0626743882894516, - "62": 0.0628487765789032, - "63": 0.0607212632894516, - "64": 0.0589076466858387, - "65": 0.0451660193502903, - "66": 0.0453055277466774, - "67": 0.0414341539144516, - "68": 0.0385044664144516, - "69": 0.0414341539144516, - "70": 0.0466308631002903, - "71": 0.0399693101644516, - "72": 
0.0437011756002903, - "73": 0.0434221550822258, - "74": 0.0428989976644516, - "75": 0.0401785746216774, - "76": 0.0431082621216774, - "77": 0.0484444759786129, - "78": 0.0417829267680645, - "79": 0.0418178029358387 - } - } - } -} \ No newline at end of file diff --git a/tests/fp8_kv/llama2-7b-fp8-kv/kv_cache_scales.json b/tests/fp8_kv/llama2-7b-fp8-kv/kv_cache_scales.json deleted file mode 100644 index bb734039e982b..0000000000000 --- a/tests/fp8_kv/llama2-7b-fp8-kv/kv_cache_scales.json +++ /dev/null @@ -1,42 +0,0 @@ -{ - "model_type": "llama", - "kv_cache": { - "dtype": "float8_e4m3fn", - "scaling_factor": { - "0": { - "0": 0.0152239128947258, - "1": 0.0188860222697258, - "2": 0.0354178324341774, - "3": 0.0376674123108387, - "4": 0.0418526791036129, - "5": 0.0433175228536129, - "6": 0.0397600457072258, - "7": 0.0424455925822258, - "8": 0.0415387861430645, - "9": 0.0408412404358387, - "10": 0.0395856611430645, - "11": 0.0377371683716774, - "12": 0.0400739423930645, - "13": 0.040771484375, - "14": 0.0393415205180645, - "15": 0.0369001142680645, - "16": 0.03857421875, - "17": 0.0387486070394516, - "18": 0.0403180830180645, - "19": 0.0396205373108387, - "20": 0.0375627800822258, - "21": 0.0407366082072258, - "22": 0.0432477705180645, - "23": 0.0377022884786129, - "24": 0.0399693101644516, - "25": 0.0374581478536129, - "26": 0.0413295216858387, - "27": 0.0442243330180645, - "28": 0.0424804724752903, - "29": 0.0456891767680645, - "30": 0.0409109964966774, - "31": 0.0482352152466774 - } - } - } -} diff --git a/tests/kernels/test_attention.py b/tests/kernels/test_attention.py index 124d5d297a574..574a0f223ef0d 100644 --- a/tests/kernels/test_attention.py +++ b/tests/kernels/test_attention.py @@ -182,7 +182,7 @@ def test_paged_attention( key_cache, value_cache = key_caches[0], value_caches[0] # Using default kv_scale - k_scale = v_scale = 1.0 + k_scale = v_scale = torch.tensor(1.0, dtype=torch.float32, device=device) # Call the paged attention kernel. output = torch.empty_like(query) diff --git a/tests/kernels/test_blocksparse_attention.py b/tests/kernels/test_blocksparse_attention.py index fad342d1b5923..08f31219e3574 100644 --- a/tests/kernels/test_blocksparse_attention.py +++ b/tests/kernels/test_blocksparse_attention.py @@ -210,7 +210,7 @@ def test_paged_attention( key_cache, value_cache = key_caches[0], value_caches[0] # Using default kv_scale - k_scale = v_scale = 1.0 + k_scale = v_scale = torch.tensor(1.0, dtype=torch.float32, device=device) tp_rank = 0 # Call the paged attention kernel. diff --git a/tests/kernels/test_cache.py b/tests/kernels/test_cache.py index 40550ed51e2c7..c848be4f9d807 100644 --- a/tests/kernels/test_cache.py +++ b/tests/kernels/test_cache.py @@ -160,7 +160,7 @@ def test_reshape_and_cache( cloned_value_cache = value_cache.clone() # Using default kv_scale - k_scale = v_scale = 1.0 + k_scale = v_scale = torch.tensor(1.0, dtype=torch.float32, device=device) # Call the reshape_and_cache kernel. opcheck(torch.ops._C_cache_ops.reshape_and_cache, @@ -258,8 +258,8 @@ def test_reshape_and_cache_flash( del key_caches del value_caches - k_scale = key.amax().item() / 256 - v_scale = value.amax().item() / 256 + k_scale = (key.amax() / 256.0).to(torch.float32) + v_scale = (value.amax() / 256.0).to(torch.float32) # Clone the KV caches. 
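(Editor's note) A minimal, self-contained sketch of the per-tensor scale convention the updated tests follow: the scale is derived from the tensor's maximum magnitude and kept as a 0-dim float32 tensor. The divisor of 448 and the use of torch.float8_e4m3fn (PyTorch >= 2.1) are illustrative assumptions, not part of this patch.

```python
import torch

# Per-tensor scale as a float32 tensor, mirroring the tests'
# `(key.amax() / 256.0).to(torch.float32)` pattern.
key = torch.randn(16, 8, 128, dtype=torch.float16)
k_scale = (key.abs().amax() / 448.0).to(torch.float32)  # 448 ~ fp8 e4m3 max, illustrative

# Quantize with the scale, then dequantize to check the round trip.
key_fp8 = (key.float() / k_scale).to(torch.float8_e4m3fn)
key_back = key_fp8.to(torch.float32) * k_scale
print((key_back - key.float()).abs().max())  # small quantization error
```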
if kv_cache_dtype == "fp8": @@ -284,12 +284,12 @@ def test_reshape_and_cache_flash( result_key_cache = torch.empty_like(key_cache, dtype=torch.float16) ops.convert_fp8(result_key_cache, key_cache, - k_scale, + k_scale.item(), kv_dtype=kv_cache_dtype) result_value_cache = torch.empty_like(value_cache, dtype=torch.float16) ops.convert_fp8(result_value_cache, value_cache, - v_scale, + v_scale.item(), kv_dtype=kv_cache_dtype) # Run the reference implementation. diff --git a/tests/kernels/test_prefix_prefill.py b/tests/kernels/test_prefix_prefill.py index 3fdb7996ba4e0..10e73ab950b0e 100644 --- a/tests/kernels/test_prefix_prefill.py +++ b/tests/kernels/test_prefix_prefill.py @@ -138,6 +138,7 @@ def test_contexted_kv_attention( # to V_cache[num_blocks, num_kv_heads, head_size, block_size] v_cache = v_cache.view(-1, block_size, num_kv_heads, head_size).permute(0, 2, 3, 1).contiguous() + k_scale = v_scale = torch.tensor(1.0, dtype=torch.float32, device=device) # Warm up the Triton kernel by calling it once before actually measuring # generation time @@ -153,6 +154,8 @@ def test_contexted_kv_attention( b_seq_len, b_ctx_len, max_input_len, + k_scale, + v_scale, sliding_window=sliding_window) torch.cuda.synchronize() start_time = time.time() @@ -168,6 +171,8 @@ def test_contexted_kv_attention( b_seq_len, b_ctx_len, max_input_len, + k_scale, + v_scale, sliding_window=sliding_window) torch.cuda.synchronize() end_time = time.time() @@ -366,6 +371,7 @@ def _get_alibi_slopes(total_num_heads: int) -> torch.Tensor: # to V_cache[num_blocks, num_kv_heads, head_size, block_size] v_cache = v_cache.view(-1, block_size, num_kv_heads, head_size).permute(0, 2, 3, 1).contiguous() + k_scale = v_scale = torch.tensor(1.0, dtype=torch.float32, device=device) # Warm up the Triton kernel by calling it once before actually measuring # generation time @@ -381,6 +387,8 @@ def _get_alibi_slopes(total_num_heads: int) -> torch.Tensor: b_seq_len, b_ctx_len, max_input_len, + k_scale, + v_scale, alibi_slopes=alibi_slopes) torch.cuda.synchronize() start_time = time.time() @@ -396,6 +404,8 @@ def _get_alibi_slopes(total_num_heads: int) -> torch.Tensor: b_seq_len, b_ctx_len, max_input_len, + k_scale, + v_scale, alibi_slopes=alibi_slopes) torch.cuda.synchronize() end_time = time.time() diff --git a/tests/kernels/utils.py b/tests/kernels/utils.py index 848eea7f54cab..8011398551b9d 100644 --- a/tests/kernels/utils.py +++ b/tests/kernels/utils.py @@ -909,6 +909,7 @@ def make_test_metadata( num_prefills=num_prefills, slot_mapping=(None if kv_mmap is None else kv_mmap.slot_mapping), multi_modal_placeholder_index_maps=None, + enable_kv_scales_calculation=True, num_prefill_tokens=num_prefill_tokens, num_decode_tokens=num_decode_tokens, seq_lens=seq_lens, @@ -958,6 +959,7 @@ def make_test_metadata( num_prefills=num_prefills, slot_mapping=kv_mmap.slot_mapping, multi_modal_placeholder_index_maps=None, + enable_kv_scales_calculation=True, num_prefill_tokens=num_prefill_tokens, num_decode_tokens=num_decode_tokens, seq_lens=seq_lens, diff --git a/tests/models/decoder_only/language/test_fp8.py b/tests/models/decoder_only/language/test_fp8.py index 53f23e24511b3..5f06f1e3a2fe9 100644 --- a/tests/models/decoder_only/language/test_fp8.py +++ b/tests/models/decoder_only/language/test_fp8.py @@ -19,18 +19,17 @@ @pytest.mark.skipif(not is_quant_method_supported("fp8"), reason="fp8 is not supported on this GPU type.") @pytest.mark.parametrize( - "kv_cache_dtype,base_model,test_model,scale_path", + "kv_cache_dtype,base_model,test_model", [ # Test FP8 
checkpoint w. fp8_e4m3 kv-cache scaling factors. ("fp8_e4m3", "meta-llama/Llama-3.2-1B-Instruct", - "nm-testing/Llama-3.2-1B-Instruct-FP8-KV", None), + "nm-testing/Llama-3.2-1B-Instruct-FP8-KV"), # Test FP16 checkpoint w. fp8_e5m2 kv-cache. ("fp8_e5m2", "meta-llama/Llama-3.2-1B-Instruct", - "meta-llama/Llama-3.2-1B-Instruct", None), + "meta-llama/Llama-3.2-1B-Instruct"), # Test FP16 checkpoint w. fp8_e4m3 kv-cache scaling factors in json. ("fp8_e4m3", "meta-llama/Llama-2-7b-chat-hf", - "meta-llama/Llama-2-7b-chat-hf", - "./tests/fp8_kv/llama2-7b-fp8-kv/kv_cache_scales.json") + "meta-llama/Llama-2-7b-chat-hf") ]) # Due to low-precision numerical divergence, we only test logprob of 4 tokens @pytest.mark.parametrize("max_tokens", [4]) @@ -48,7 +47,6 @@ def test_models( kv_cache_dtype: str, base_model: str, test_model: str, - scale_path: Optional[str], max_tokens: int, enforce_eager: bool, backend: str, @@ -76,10 +74,6 @@ def test_models( baseline_outputs = vllm_model.generate_greedy_logprobs( example_prompts, max_tokens, NUM_LOG_PROBS) - extra_kwargs = {} - if scale_path is not None: - extra_kwargs["quantization_param_path"] = scale_path - with vllm_runner( test_model, max_model_len=MAX_MODEL_LEN, @@ -87,7 +81,6 @@ def test_models( enforce_eager=enforce_eager, kv_cache_dtype=kv_cache_dtype, disable_async_output_proc=disable_async_output_proc, - **extra_kwargs, ) as vllm_model: test_outputs = vllm_model.generate_greedy_logprobs( example_prompts, max_tokens, NUM_LOG_PROBS) diff --git a/tests/worker/test_model_input.py b/tests/worker/test_model_input.py index 309854e6babf3..57f1fd47a600f 100644 --- a/tests/worker/test_model_input.py +++ b/tests/worker/test_model_input.py @@ -74,6 +74,7 @@ def test_model_runner_input(): num_decode_tokens=3, slot_mapping=torch.zeros(1), multi_modal_placeholder_index_maps=None, + enable_kv_scales_calculation=True, ) model_input = ModelInputForGPUWithSamplingMetadata( input_tokens=torch.ones(10), @@ -126,6 +127,7 @@ def test_embedding_model_runner_input(): num_decode_tokens=3, slot_mapping=torch.zeros(1), multi_modal_placeholder_index_maps=None, + enable_kv_scales_calculation=True, ) model_input = ModelInputForGPUWithPoolingMetadata( input_tokens=torch.ones(10), @@ -177,6 +179,7 @@ def test_multi_step_model_runner_input(): num_decode_tokens=3, slot_mapping=torch.zeros(1), multi_modal_placeholder_index_maps=None, + enable_kv_scales_calculation=True, ) frozen_model_input = ModelInputForGPUWithSamplingMetadata( input_tokens=torch.ones(10), diff --git a/vllm/_custom_ops.py b/vllm/_custom_ops.py index 38b37f57e8150..c73b2bb6b0112 100644 --- a/vllm/_custom_ops.py +++ b/vllm/_custom_ops.py @@ -48,8 +48,8 @@ def paged_attention_v1( max_seq_len: int, alibi_slopes: Optional[torch.Tensor], kv_cache_dtype: str, - k_scale: float, - v_scale: float, + k_scale: torch.Tensor, + v_scale: torch.Tensor, tp_rank: int = 0, blocksparse_local_blocks: int = 0, blocksparse_vert_stride: int = 0, @@ -80,8 +80,8 @@ def paged_attention_v2( max_seq_len: int, alibi_slopes: Optional[torch.Tensor], kv_cache_dtype: str, - k_scale: float, - v_scale: float, + k_scale: torch.Tensor, + v_scale: torch.Tensor, tp_rank: int = 0, blocksparse_local_blocks: int = 0, blocksparse_vert_stride: int = 0, @@ -112,8 +112,8 @@ def paged_attention_rocm( max_seq_len: int, alibi_slopes: Optional[torch.Tensor], kv_cache_dtype: str, - k_scale: float, - v_scale: float, + k_scale: torch.Tensor, + v_scale: torch.Tensor, ) -> None: torch.ops._rocm_C.paged_attention(out, exp_sum, max_logits, tmp_out, query, key_cache, 
value_cache, num_kv_heads, @@ -966,8 +966,8 @@ def reshape_and_cache( value_cache: torch.Tensor, slot_mapping: torch.Tensor, kv_cache_dtype: str, - k_scale: float, - v_scale: float, + k_scale: torch.Tensor, + v_scale: torch.Tensor, ) -> None: torch.ops._C_cache_ops.reshape_and_cache(key, value, key_cache, value_cache, slot_mapping, @@ -981,8 +981,8 @@ def reshape_and_cache_flash( value_cache: torch.Tensor, slot_mapping: torch.Tensor, kv_cache_dtype: str, - k_scale: float, - v_scale: float, + k_scale: torch.Tensor, + v_scale: torch.Tensor, ) -> None: torch.ops._C_cache_ops.reshape_and_cache_flash(key, value, key_cache, value_cache, slot_mapping, diff --git a/vllm/attention/backends/abstract.py b/vllm/attention/backends/abstract.py index 2efe142a17b69..8027a52b82ffc 100644 --- a/vllm/attention/backends/abstract.py +++ b/vllm/attention/backends/abstract.py @@ -123,6 +123,10 @@ class AttentionMetadata: multi_modal_placeholder_index_maps: Optional[Dict[ str, MultiModalPlaceholderMap.IndexMap]] + # Enable/disable KV scales calculation. This is so that we can disable the + # calculation until after prefill and cuda graph capture. + enable_kv_scales_calculation: bool + @property @abstractmethod def prefill_metadata(self) -> Optional["AttentionMetadata"]: @@ -226,8 +230,10 @@ def build(self, seq_lens: List[int], query_lens: List[int], class AttentionLayer(Protocol): - _k_scale: float - _v_scale: float + _k_scale: torch.Tensor + _v_scale: torch.Tensor + _k_scale_float: float + _v_scale_float: float def forward( self, diff --git a/vllm/attention/backends/blocksparse_attn.py b/vllm/attention/backends/blocksparse_attn.py index 9089db1126c94..20e9a3f139de2 100644 --- a/vllm/attention/backends/blocksparse_attn.py +++ b/vllm/attention/backends/blocksparse_attn.py @@ -222,6 +222,7 @@ def prefill_metadata( slot_mapping=self.slot_mapping[:self.num_prefill_tokens], multi_modal_placeholder_index_maps=self. multi_modal_placeholder_index_maps, + enable_kv_scales_calculation=self.enable_kv_scales_calculation, seq_lens=self.seq_lens[:self.num_prefills], seq_lens_tensor=self.seq_lens_tensor[:self.num_prefills], max_query_len=self.max_query_len, @@ -251,6 +252,7 @@ def decode_metadata(self) -> Optional["BlocksparseFlashAttentionMetadata"]: num_decode_tokens=self.num_decode_tokens, slot_mapping=self.slot_mapping[self.num_prefill_tokens:], multi_modal_placeholder_index_maps=None, + enable_kv_scales_calculation=False, seq_lens=None, seq_lens_tensor=self.seq_lens_tensor[self.num_prefills:], max_query_len=None, diff --git a/vllm/attention/backends/flash_attn.py b/vllm/attention/backends/flash_attn.py index 18acfb82fac58..1be099283e472 100644 --- a/vllm/attention/backends/flash_attn.py +++ b/vllm/attention/backends/flash_attn.py @@ -230,6 +230,7 @@ def prefill_metadata(self) -> Optional["FlashAttentionMetadata"]: slot_mapping=slot_mapping, multi_modal_placeholder_index_maps=self. 
multi_modal_placeholder_index_maps, + enable_kv_scales_calculation=self.enable_kv_scales_calculation, seq_lens=seq_lens, seq_lens_tensor=seq_lens_tensor, max_query_len=self.max_query_len, @@ -274,6 +275,7 @@ def decode_metadata(self) -> Optional["FlashAttentionMetadata"]: num_decode_tokens=self.num_decode_tokens, slot_mapping=slot_mapping, multi_modal_placeholder_index_maps=None, + enable_kv_scales_calculation=True, seq_lens=None, seq_lens_tensor=seq_lens_tensor, max_decode_query_len=self.max_decode_query_len, @@ -557,6 +559,7 @@ def build(self, seq_lens: List[int], query_lens: List[int], num_decode_tokens=num_decode_tokens, seq_lens=seq_lens, multi_modal_placeholder_index_maps=placeholder_index_maps, + enable_kv_scales_calculation=True, seq_lens_tensor=seq_lens_tensor, max_query_len=max_query_len, max_decode_query_len=max_decode_query_len, @@ -675,7 +678,7 @@ def forward( NOTE: It in-place updates the output tensor. """ # NOTE(woosuk): FlashAttention does not support FP8 KV cache. - assert layer._k_scale == 1.0 and layer._v_scale == 1.0, ( + assert layer._k_scale_float == 1.0 and layer._v_scale_float == 1.0, ( "key/v_scale is not supported in FlashAttention.") assert output is not None, "Output tensor must be provided." diff --git a/vllm/attention/backends/flashinfer.py b/vllm/attention/backends/flashinfer.py index b8ffbe6dd64dd..3135b0b405343 100644 --- a/vllm/attention/backends/flashinfer.py +++ b/vllm/attention/backends/flashinfer.py @@ -219,6 +219,7 @@ def graph_capture_get_metadata_for_batch( num_prefills=0, slot_mapping=self._graph_slot_mapping[:batch_size], multi_modal_placeholder_index_maps=None, + enable_kv_scales_calculation=False, num_prefill_tokens=0, num_decode_tokens=batch_size, max_prefill_seq_len=0, @@ -733,6 +734,7 @@ def build(self, seq_lens: List[int], query_lens: List[int], num_prefills=self.num_prefills, slot_mapping=slot_mapping_tensor, multi_modal_placeholder_index_maps=placeholder_index_maps, + enable_kv_scales_calculation=False, num_prefill_tokens=self.num_prefill_tokens, num_decode_tokens=num_decode_tokens, max_prefill_seq_len=max_prefill_seq_len, @@ -888,8 +890,8 @@ def forward( kv_cache, logits_soft_cap=logits_soft_cap, causal=True, - k_scale=layer._k_scale, - v_scale=layer._v_scale, + k_scale=layer._k_scale_float, + v_scale=layer._v_scale_float, window_left=window_left) if decode_meta := attn_metadata.decode_metadata: assert decode_meta is not None @@ -899,8 +901,8 @@ def forward( kv_cache, sm_scale=softmax_scale, logits_soft_cap=logits_soft_cap, - k_scale=layer._k_scale, - v_scale=layer._v_scale, + k_scale=layer._k_scale_float, + v_scale=layer._v_scale_float, window_left=window_left) if prefill_output is None and decode_output is not None: diff --git a/vllm/attention/backends/ipex_attn.py b/vllm/attention/backends/ipex_attn.py index cd729a1c8b274..57916a3c6a34c 100644 --- a/vllm/attention/backends/ipex_attn.py +++ b/vllm/attention/backends/ipex_attn.py @@ -193,7 +193,7 @@ def forward( Returns: shape = [num_tokens, num_heads * head_size] """ - assert layer._k_scale == 1.0 and layer._v_scale == 1.0 + assert layer._k_scale_float == 1.0 and layer._v_scale_float == 1.0 num_tokens, hidden_size = query.shape # Reshape the query, key, and value tensors. 
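(Editor's note) The assertions above now read the cached float copies. A small sketch, outside the patch, of why each attention layer keeps both representations: the tensor can be updated in place so code holding a reference to it sees the new value, while backends that only accept Python numbers read the cached float.

```python
import torch

k_scale = torch.tensor(1.0, dtype=torch.float32)   # handed to kernels by reference
captured = k_scale                                  # e.g. baked into a captured graph

# Later, dynamic calculation writes the real scale in place; the storage (and
# device pointer) stays the same, so `captured` reflects the update.
k_scale.copy_(torch.tensor(0.023))
print(captured.item())        # 0.023

# Cached plain-float copy for backends that cannot take tensors.
k_scale_float = k_scale.item()
```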
query = query.view(-1, self.num_heads, self.head_size) diff --git a/vllm/attention/backends/pallas.py b/vllm/attention/backends/pallas.py index f5bf390df6afb..facdee6b29e39 100644 --- a/vllm/attention/backends/pallas.py +++ b/vllm/attention/backends/pallas.py @@ -173,7 +173,7 @@ def forward( Returns: shape = [batch_size, seq_len, num_heads * head_size] """ - assert layer._k_scale == 1.0 and layer._v_scale == 1.0 + assert layer._k_scale_float == 1.0 and layer._v_scale_float == 1.0 batch_size, seq_len, hidden_size = query.shape query = query.view(batch_size, seq_len, self.num_heads, self.head_size) key = key.view(batch_size, seq_len, self.num_kv_heads, self.head_size) diff --git a/vllm/attention/backends/placeholder_attn.py b/vllm/attention/backends/placeholder_attn.py index 37860494702cf..826311896d1d2 100644 --- a/vllm/attention/backends/placeholder_attn.py +++ b/vllm/attention/backends/placeholder_attn.py @@ -140,6 +140,7 @@ def prefill_metadata(self) -> Optional["PlaceholderAttentionMetadata"]: slot_mapping=slot_mapping, multi_modal_placeholder_index_maps=self. multi_modal_placeholder_index_maps, + enable_kv_scales_calculation=self.enable_kv_scales_calculation, seq_lens=self.seq_lens[:self.num_prefills], seq_lens_tensor=self.seq_lens_tensor[:self.num_prefills], max_decode_query_len=0, @@ -173,6 +174,7 @@ def decode_metadata(self) -> Optional["PlaceholderAttentionMetadata"]: num_decode_tokens=self.num_decode_tokens, slot_mapping=slot_mapping, multi_modal_placeholder_index_maps=None, + enable_kv_scales_calculation=True, seq_lens=None, seq_lens_tensor=self.seq_lens_tensor[self.num_prefills:], max_decode_query_len=self.max_decode_query_len, @@ -380,6 +382,7 @@ def build(self, seq_lens: List[int], query_lens: List[int], num_prefills=self.num_prefills, slot_mapping=slot_mapping, multi_modal_placeholder_index_maps=placeholder_index_maps, + enable_kv_scales_calculation=True, num_prefill_tokens=self.num_prefill_tokens, num_decode_tokens=num_decode_tokens, seq_lens=seq_lens, diff --git a/vllm/attention/backends/rocm_flash_attn.py b/vllm/attention/backends/rocm_flash_attn.py index daf31fc352d6c..c163396451a00 100644 --- a/vllm/attention/backends/rocm_flash_attn.py +++ b/vllm/attention/backends/rocm_flash_attn.py @@ -153,6 +153,7 @@ def prefill_metadata(self) -> Optional["ROCmFlashAttentionMetadata"]: slot_mapping=self.slot_mapping[:self.num_prefill_tokens], multi_modal_placeholder_index_maps=self. 
multi_modal_placeholder_index_maps, + enable_kv_scales_calculation=self.enable_kv_scales_calculation, seq_lens=self.seq_lens[:self.num_prefills], seq_lens_tensor=self.seq_lens_tensor[:self.num_prefills], max_query_len=self.max_query_len, @@ -182,6 +183,7 @@ def decode_metadata(self) -> Optional["ROCmFlashAttentionMetadata"]: num_decode_tokens=self.num_decode_tokens, slot_mapping=self.slot_mapping[self.num_prefill_tokens:], multi_modal_placeholder_index_maps=None, + enable_kv_scales_calculation=True, seq_lens=None, seq_lens_tensor=self.seq_lens_tensor[self.num_prefills:], max_query_len=None, diff --git a/vllm/attention/backends/torch_sdpa.py b/vllm/attention/backends/torch_sdpa.py index 8722d7376795a..c3b2398b4e632 100644 --- a/vllm/attention/backends/torch_sdpa.py +++ b/vllm/attention/backends/torch_sdpa.py @@ -379,6 +379,7 @@ def build(self, seq_lens: List[int], query_lens: List[int], prefill_block_tables=prefill_block_tables, slot_mapping=slot_mapping, multi_modal_placeholder_index_maps=placeholder_index_maps, + enable_kv_scales_calculation=False, ) return attn_metadata @@ -454,7 +455,6 @@ def forward( Returns: shape = [num_tokens, num_heads * head_size] """ - assert layer._k_scale == 1.0 and layer._v_scale == 1.0 attn_type = self.attn_type if (attn_type == AttentionType.ENCODER and (not attn_metadata.is_all_encoder_attn_metadata_set)): diff --git a/vllm/attention/backends/utils.py b/vllm/attention/backends/utils.py index 3df7f54cbd8d2..84fe89b7df360 100644 --- a/vllm/attention/backends/utils.py +++ b/vllm/attention/backends/utils.py @@ -265,6 +265,7 @@ def build(self, seq_lens: List[int], query_lens: List[int], num_prefills=self.num_prefills, slot_mapping=slot_mapping_tensor, multi_modal_placeholder_index_maps=placeholder_index_maps, + enable_kv_scales_calculation=True, num_prefill_tokens=self.num_prefill_tokens, num_decode_tokens=num_decode_tokens, seq_lens=seq_lens, @@ -317,6 +318,7 @@ def graph_capture_get_metadata_for_batch( num_decode_tokens=batch_size, slot_mapping=self._graph_slot_mapping[:batch_size], multi_modal_placeholder_index_maps=None, + enable_kv_scales_calculation=True, seq_lens=None, seq_lens_tensor=self._graph_seq_lens[:batch_size], max_query_len=1, diff --git a/vllm/attention/backends/xformers.py b/vllm/attention/backends/xformers.py index 38e27434dab2c..8c25dda7aad2c 100644 --- a/vllm/attention/backends/xformers.py +++ b/vllm/attention/backends/xformers.py @@ -218,6 +218,7 @@ def prefill_metadata(self) -> Optional["XFormersMetadata"]: slot_mapping=slot_mapping, multi_modal_placeholder_index_maps=self. 
multi_modal_placeholder_index_maps, + enable_kv_scales_calculation=self.enable_kv_scales_calculation, seq_lens=seq_lens, seq_lens_tensor=seq_lens_tensor, max_query_len=self.max_query_len, @@ -262,6 +263,7 @@ def decode_metadata(self) -> Optional["XFormersMetadata"]: num_decode_tokens=self.num_decode_tokens, slot_mapping=slot_mapping, multi_modal_placeholder_index_maps=None, + enable_kv_scales_calculation=True, seq_lens_tensor=seq_lens_tensor, max_prefill_seq_len=0, max_decode_seq_len=self.max_decode_seq_len, diff --git a/vllm/attention/layer.py b/vllm/attention/layer.py index c36f8d08eb4a7..79ea9b666c7e8 100644 --- a/vllm/attention/layer.py +++ b/vllm/attention/layer.py @@ -5,6 +5,7 @@ import torch.nn as nn import torch.nn.functional as F +import vllm.envs as envs from vllm.attention import AttentionMetadata, AttentionType from vllm.attention.selector import backend_name_to_enum, get_attn_backend from vllm.config import CacheConfig, get_current_vllm_config @@ -57,10 +58,12 @@ def __init__( kv_cache_dtype = cache_config.cache_dtype block_size = cache_config.block_size is_attention_free = cache_config.is_attention_free + calculate_kv_scales = cache_config.calculate_kv_scales else: kv_cache_dtype = "auto" block_size = 16 is_attention_free = False + calculate_kv_scales = False if num_kv_heads is None: num_kv_heads = num_heads @@ -70,8 +73,15 @@ def __init__( # expect the pre-quantized k/v_scale to be loaded along # with the model weights. self.kv_cache_dtype = kv_cache_dtype - self._k_scale = 1.0 - self._v_scale = 1.0 + self.calculate_kv_scales = calculate_kv_scales + self._k_scale = torch.tensor(1.0, dtype=torch.float32) + self._v_scale = torch.tensor(1.0, dtype=torch.float32) + + # We also keep the float32 versions of k/v_scale for attention + # backends that don't support tensors (Flashinfer) + self._k_scale_float = 1.0 + self._v_scale_float = 1.0 + quant_method = quant_config.get_quant_method( self, prefix=prefix) if quant_config else None if quant_method is not None: @@ -127,6 +137,9 @@ def __init__( ).parallel_config.pipeline_parallel_size) ] + self.k_range = torch.tensor(envs.K_SCALE_CONSTANT, dtype=torch.float32) + self.v_range = torch.tensor(envs.V_SCALE_CONSTANT, dtype=torch.float32) + def forward( self, query: torch.Tensor, @@ -135,6 +148,9 @@ def forward( kv_cache: torch.Tensor, attn_metadata: AttentionMetadata, ) -> torch.Tensor: + if self.calculate_kv_scales and \ + attn_metadata.enable_kv_scales_calculation: + self.calc_kv_scales(key, value) if self.use_output: output = torch.empty_like(query) hidden_size = query.size(-1) @@ -161,6 +177,14 @@ def forward( return torch.ops.vllm.unified_attention( query, key, value, self.layer_name) + def calc_kv_scales(self, key, value): + self._k_scale.copy_(torch.abs(key).max() / self.k_range) + self._v_scale.copy_(torch.abs(value).max() / self.v_range) + self._k_scale_float = self._k_scale.item() + self._v_scale_float = self._v_scale.item() + # We only calculate the scales once + self.calculate_kv_scales = False + def extra_repr(self) -> str: s = f"head_size={self.impl.head_size}" # type: ignore s += f", num_heads={self.impl.num_heads}" # type: ignore diff --git a/vllm/attention/ops/ipex_attn.py b/vllm/attention/ops/ipex_attn.py index cbc6c74acf09a..3a07184ed31f0 100644 --- a/vllm/attention/ops/ipex_attn.py +++ b/vllm/attention/ops/ipex_attn.py @@ -52,8 +52,8 @@ def write_to_paged_cache( value_cache: torch.Tensor, slot_mapping: torch.Tensor, kv_cache_dtype: str, - k_scale: float, - v_scale: float, + k_scale: torch.Tensor, + v_scale: 
torch.Tensor, *args, ) -> None: ops.reshape_and_cache( @@ -80,8 +80,8 @@ def forward_decode( num_kv_heads: int, scale: float, alibi_slopes: Optional[torch.Tensor], - k_scale: float, - v_scale: float, + k_scale: torch.Tensor, + v_scale: torch.Tensor, *args, ) -> None: tp_rank: int = 0 @@ -149,8 +149,8 @@ def write_to_paged_cache( value_cache: torch.Tensor, slot_mapping: torch.Tensor, kv_cache_dtype: str, - k_scale: float, - v_scale: float, + k_scale: torch.Tensor, + v_scale: torch.Tensor, *args, ) -> None: ipex_modules.PagedAttention.reshape_and_cache( @@ -170,8 +170,8 @@ def forward_decode( num_kv_heads: int, scale: float, alibi_slopes: Optional[torch.Tensor], - k_scale: float, - v_scale: float, + k_scale: torch.Tensor, + v_scale: torch.Tensor, *args, ) -> None: block_size = value_cache.shape[2] diff --git a/vllm/attention/ops/paged_attn.py b/vllm/attention/ops/paged_attn.py index 076f151ffcb61..fd62329141f6f 100644 --- a/vllm/attention/ops/paged_attn.py +++ b/vllm/attention/ops/paged_attn.py @@ -69,8 +69,8 @@ def write_to_paged_cache( value_cache: torch.Tensor, slot_mapping: torch.Tensor, kv_cache_dtype: str, - k_scale: float, - v_scale: float, + k_scale: torch.Tensor, + v_scale: torch.Tensor, ) -> None: ops.reshape_and_cache( key, @@ -95,8 +95,8 @@ def forward_decode( num_kv_heads: int, scale: float, alibi_slopes: Optional[torch.Tensor], - k_scale: float, - v_scale: float, + k_scale: torch.Tensor, + v_scale: torch.Tensor, tp_rank: int = 0, blocksparse_local_blocks: int = 0, blocksparse_vert_stride: int = 0, @@ -204,8 +204,8 @@ def forward_prefix( max_query_len: int, alibi_slopes: Optional[torch.Tensor], sliding_window: Optional[int], - k_scale: float, - v_scale: float, + k_scale: torch.Tensor, + v_scale: torch.Tensor, ) -> torch.Tensor: output = torch.empty_like(query) context_attention_fwd( diff --git a/vllm/attention/ops/prefix_prefill.py b/vllm/attention/ops/prefix_prefill.py index 9c11a8df55278..e2f2b66dfc90c 100644 --- a/vllm/attention/ops/prefix_prefill.py +++ b/vllm/attention/ops/prefix_prefill.py @@ -133,7 +133,7 @@ def _fwd_kernel( other=0.0) # [D,N] if k_load.dtype.is_fp8(): - k = (k_load.to(tl.float32) * k_scale).to(q.dtype) + k = (k_load.to(tl.float32) * tl.load(k_scale)).to(q.dtype) else: k = k_load @@ -181,7 +181,7 @@ def _fwd_kernel( ((start_n + offs_n[:, None]) < cur_batch_ctx_len), other=0.0) # [N,D] if v_load.dtype.is_fp8(): - v = (v_load.to(tl.float32) * v_scale).to(q.dtype) + v = (v_load.to(tl.float32) * tl.load(v_scale)).to(q.dtype) else: v = v_load p = p.to(v.dtype) @@ -564,7 +564,7 @@ def _fwd_kernel_alibi( other=0.0) # [D,N] if k_load.dtype.is_fp8(): - k = (k_load.to(tl.float32) * k_scale).to(q.dtype) + k = (k_load.to(tl.float32) * tl.load(k_scale)).to(q.dtype) else: k = k_load @@ -604,7 +604,7 @@ def _fwd_kernel_alibi( ((start_n + offs_n[:, None]) < cur_batch_ctx_len), other=0.0) if v_load.dtype.is_fp8(): - v = (v_load.to(tl.float32) * v_scale).to(q.dtype) + v = (v_load.to(tl.float32) * tl.load(v_scale)).to(q.dtype) else: v = v_load p = p.to(v.dtype) @@ -713,8 +713,8 @@ def context_attention_fwd(q, b_seq_len, b_ctx_len, max_input_len, - k_scale: float = 1.0, - v_scale: float = 1.0, + k_scale: torch.Tensor, + v_scale: torch.Tensor, alibi_slopes=None, sliding_window=None): diff --git a/vllm/config.py b/vllm/config.py index f7547921a05ea..efd81ad3de3b4 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -120,11 +120,6 @@ class ModelConfig: decoding draft models. quantization: Quantization method that was used to quantize the model weights. 
If None, we assume the model weights are not quantized. - quantization_param_path: Path to JSON file containing scaling factors. - Used to load KV cache scaling factors into the model when KV cache - type is FP8_E4M3 on ROCm (AMD GPU). In the future these will also - be used to load activation and weight scaling factors when the - model dtype is FP8_E4M3 on ROCm. enforce_eager: Whether to enforce eager execution. If True, we will disable CUDA graph and always execute the model in eager mode. If False, we will use CUDA graph and eager execution in hybrid. @@ -187,7 +182,6 @@ def compute_hash(self) -> str: factors.append(self.model) factors.append(self.dtype) factors.append(self.quantization) - factors.append(self.quantization_param_path) factors.append(self.revision) factors.append(self.code_revision) factors.append(self.trust_remote_code) @@ -213,7 +207,6 @@ def __init__( max_model_len: Optional[int] = None, spec_target_max_model_len: Optional[int] = None, quantization: Optional[str] = None, - quantization_param_path: Optional[str] = None, enforce_eager: Optional[bool] = None, max_seq_len_to_capture: Optional[int] = None, max_logprobs: int = 20, @@ -274,7 +267,6 @@ def __init__( else: self.tokenizer_revision = tokenizer_revision self.quantization = quantization - self.quantization_param_path = quantization_param_path self.enforce_eager = enforce_eager self.max_seq_len_to_capture = max_seq_len_to_capture self.max_logprobs = max_logprobs @@ -1002,6 +994,7 @@ def __init__( sliding_window: Optional[int] = None, enable_prefix_caching: bool = False, cpu_offload_gb: float = 0, + calculate_kv_scales: Optional[bool] = None, ) -> None: self.block_size = block_size self.gpu_memory_utilization = gpu_memory_utilization @@ -1012,7 +1005,7 @@ def __init__( self.sliding_window = sliding_window self.enable_prefix_caching = enable_prefix_caching self.cpu_offload_gb = cpu_offload_gb - + self.calculate_kv_scales = calculate_kv_scales self._verify_args() self._verify_cache_dtype() self._verify_prefix_caching() @@ -1021,6 +1014,10 @@ def __init__( self.num_gpu_blocks: Optional[int] = None self.num_cpu_blocks: Optional[int] = None + # Set calculate_kv_scales to False if the value is unset. 
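(Editor's note) A hedged usage sketch, not part of the patch: enabling the new option programmatically through EngineArgs. An unset value resolves to False, so the dynamic path stays opt-in. The model name is only an example.

```python
from vllm.engine.arg_utils import EngineArgs

engine_args = EngineArgs(
    model="meta-llama/Llama-3.2-1B-Instruct",  # example model
    kv_cache_dtype="fp8",
    calculate_kv_scales=True,  # CLI equivalent: --calculate-kv-scales
)
# Leaving calculate_kv_scales unset (None) resolves to False in CacheConfig,
# i.e. scales come from the checkpoint if present, otherwise default to 1.0.
```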
+ if self.calculate_kv_scales is None: + self.calculate_kv_scales = False + def metrics_info(self): # convert cache_config to dict(key: str, value: str) for prometheus # metrics info @@ -3297,7 +3294,6 @@ def __str__(self): f"quantization={self.model_config.quantization}, " f"enforce_eager={self.model_config.enforce_eager}, " f"kv_cache_dtype={self.cache_config.cache_dtype}, " - f"quantization_param_path={self.model_config.quantization_param_path}," f" device_config={self.device_config.device}, " f"decoding_config={self.decoding_config!r}, " f"observability_config={self.observability_config!r}, " diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index f58c1b55e0c70..5d3aeb68ebcfe 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -98,7 +98,6 @@ class EngineArgs: config_format: ConfigFormat = ConfigFormat.AUTO dtype: str = 'auto' kv_cache_dtype: str = 'auto' - quantization_param_path: Optional[str] = None seed: int = 0 max_model_len: Optional[int] = None worker_use_ray: bool = False @@ -199,6 +198,8 @@ class EngineArgs: generation_config: Optional[str] = None enable_sleep_mode: bool = False + calculate_kv_scales: Optional[bool] = None + def __post_init__(self): if not self.tokenizer: self.tokenizer = self.model @@ -350,17 +351,6 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser: help='Data type for kv cache storage. If "auto", will use model ' 'data type. CUDA 11.8+ supports fp8 (=fp8_e4m3) and fp8_e5m2. ' 'ROCm (AMD GPU) supports fp8 (=fp8_e4m3)') - parser.add_argument( - '--quantization-param-path', - type=nullable_str, - default=None, - help='Path to the JSON file containing the KV cache ' - 'scaling factors. This should generally be supplied, when ' - 'KV cache dtype is FP8. Otherwise, KV cache scaling factors ' - 'default to 1.0, which may cause accuracy issues. ' - 'FP8_E5M2 (without scaling) is only supported on cuda version ' - 'greater than 11.8. On ROCm (AMD GPU), FP8_E4M3 is instead ' - 'supported for common inference criteria.') parser.add_argument('--max-model-len', type=int, default=EngineArgs.max_model_len, @@ -962,6 +952,15 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser: help="Enable sleep mode for the engine. " "(only cuda platform is supported)") + parser.add_argument( + '--calculate-kv-scales', + action='store_true', + help='This enables dynamic calculation of ' + 'k_scale and v_scale when kv-cache-dtype is fp8. ' + 'If calculate-kv-scales is false, the scales will ' + 'be loaded from the model checkpoint if available. 
' + 'Otherwise, the scales will default to 1.0.') + return parser @classmethod @@ -991,7 +990,6 @@ def create_model_config(self) -> ModelConfig: tokenizer_revision=self.tokenizer_revision, max_model_len=self.max_model_len, quantization=self.quantization, - quantization_param_path=self.quantization_param_path, enforce_eager=self.enforce_eager, max_seq_len_to_capture=self.max_seq_len_to_capture, max_logprobs=self.max_logprobs, @@ -1068,6 +1066,7 @@ def create_engine_config(self, sliding_window=model_config.get_sliding_window(), enable_prefix_caching=self.enable_prefix_caching, cpu_offload_gb=self.cpu_offload_gb, + calculate_kv_scales=self.calculate_kv_scales, ) parallel_config = ParallelConfig( pipeline_parallel_size=self.pipeline_parallel_size, diff --git a/vllm/envs.py b/vllm/envs.py index b72e9141ac792..8627caec7790d 100644 --- a/vllm/envs.py +++ b/vllm/envs.py @@ -73,6 +73,8 @@ VLLM_ENABLE_V1_MULTIPROCESSING: bool = True VLLM_LOG_BATCHSIZE_INTERVAL: float = -1 VLLM_DISABLE_COMPILE_CACHE: bool = False + K_SCALE_CONSTANT: int = 200 + V_SCALE_CONSTANT: int = 100 VLLM_SERVER_DEV_MODE: bool = False VLLM_V1_OUTPUT_PROC_CHUNK_SIZE: int = 128 @@ -474,6 +476,13 @@ def maybe_convert_int(value: Optional[str]) -> Optional[int]: "VLLM_USE_V1": lambda: bool(int(os.getenv("VLLM_USE_V1", "0"))), + # Divisor for dynamic key scale factor calculation for FP8 KV Cache + "K_SCALE_CONSTANT": + lambda: int(os.getenv("K_SCALE_CONSTANT", "200")), + + # Divisor for dynamic value scale factor calculation for FP8 KV Cache + "V_SCALE_CONSTANT": + lambda: int(os.getenv("V_SCALE_CONSTANT", "100")), # If set, enable multiprocessing in LLM for the V1 code path. "VLLM_ENABLE_V1_MULTIPROCESSING": lambda: bool(int(os.getenv("VLLM_ENABLE_V1_MULTIPROCESSING", "1"))), diff --git a/vllm/model_executor/layers/quantization/kv_cache.py b/vllm/model_executor/layers/quantization/kv_cache.py index a74f5415c8a51..e1870c73cc932 100644 --- a/vllm/model_executor/layers/quantization/kv_cache.py +++ b/vllm/model_executor/layers/quantization/kv_cache.py @@ -3,6 +3,7 @@ from vllm.logger import init_logger from vllm.model_executor.layers.quantization.base_config import ( QuantizationConfig, QuantizeMethodBase) +from vllm.platforms import current_platform logger = init_logger(__name__) @@ -40,11 +41,16 @@ def apply(self, layer: torch.nn.Module) -> torch.Tensor: def process_weights_after_loading(self, layer: torch.nn.Module) -> None: # If the kv-cache dtype is auto, we enforce the k/v_scale to be 1.0 # regardless whether the kv-scale is available in the checkpoint. - if layer.kv_cache_dtype != "auto": + # No need to process kv scales after loading if we are going to + # calculate them on the fly. 
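(Editor's note) A standalone sketch, not part of the patch, of the dynamic scale rule this change introduces in Attention.calc_kv_scales(): each scale is the tensor's maximum magnitude divided by a constant, with the divisors taken from the new K_SCALE_CONSTANT (default 200) and V_SCALE_CONSTANT (default 100) environment variables.

```python
import os
import torch

K_SCALE_CONSTANT = int(os.getenv("K_SCALE_CONSTANT", "200"))
V_SCALE_CONSTANT = int(os.getenv("V_SCALE_CONSTANT", "100"))

def dynamic_kv_scales(key: torch.Tensor, value: torch.Tensor):
    """Mirror of the calc_kv_scales() arithmetic: amax / divisor, as float32."""
    k_scale = (torch.abs(key).max() / K_SCALE_CONSTANT).to(torch.float32)
    v_scale = (torch.abs(value).max() / V_SCALE_CONSTANT).to(torch.float32)
    return k_scale, v_scale

k, v = torch.randn(2, 8, 64), torch.randn(2, 8, 64)
print(dynamic_kv_scales(k, v))
```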
+ if layer.kv_cache_dtype != "auto" and not layer.calculate_kv_scales: if layer.k_scale > 0.0 and layer.v_scale > 0.0: # We prefer to use separate k_scale and v_scale if present k_scale = layer.k_scale.to("cpu").tolist() v_scale = layer.v_scale.to("cpu").tolist() + if current_platform.is_rocm(): + k_scale *= 2 + v_scale *= 2 elif layer.k_scale < 0.0 and layer.v_scale < 0.0: # If no scales were loaded (both scales are invalid negative # values), use the default value of 1.0 @@ -58,6 +64,9 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None: scale_to_duplicate = max(layer.k_scale, layer.v_scale) k_scale = scale_to_duplicate.to("cpu").tolist() v_scale = scale_to_duplicate.to("cpu").tolist() + if current_platform.is_rocm(): + k_scale *= 2 + v_scale *= 2 if not isinstance(k_scale, float) or not isinstance( v_scale, float): @@ -65,9 +74,11 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None: "for fp8 KV cache") # These are used in the final Attention.forward() - layer._k_scale = k_scale - layer._v_scale = v_scale - if (layer._k_scale == 1.0 and layer._v_scale == 1.0 + layer._k_scale.copy_(k_scale) + layer._v_scale.copy_(v_scale) + layer._k_scale_float = k_scale + layer._v_scale_float = v_scale + if (k_scale == 1.0 and v_scale == 1.0 and "e5m2" not in layer.kv_cache_dtype): logger.warning_once( "Using KV cache scaling factor 1.0 for fp8_e4m3. This " diff --git a/vllm/model_executor/model_loader/weight_utils.py b/vllm/model_executor/model_loader/weight_utils.py index 9cfcdbf620d2b..b70407221312a 100644 --- a/vllm/model_executor/model_loader/weight_utils.py +++ b/vllm/model_executor/model_loader/weight_utils.py @@ -6,8 +6,7 @@ import os import tempfile from collections import defaultdict -from typing import (Any, Callable, Dict, Generator, Iterable, List, Optional, - Tuple, Union) +from typing import Any, Callable, Dict, Generator, List, Optional, Tuple, Union import filelock import gguf @@ -23,7 +22,6 @@ from vllm.logger import init_logger from vllm.model_executor.layers.quantization import (QuantizationConfig, get_quantization_config) -from vllm.model_executor.layers.quantization.schema import QuantParamSchema from vllm.platforms import current_platform from vllm.utils import PlaceholderModule @@ -496,47 +494,6 @@ def gguf_quant_weights_iterator( yield name, param -def kv_cache_scales_loader( - filename: str, tp_rank: int, tp_size: int, num_hidden_layers: int, - model_type: Optional[str]) -> Iterable[Tuple[int, float]]: - """ - A simple utility to read in KV cache scaling factors that have been - previously serialized to disk. Used by the model to populate the appropriate - KV cache scaling factors. The serialization should represent a dictionary - whose keys are the TP ranks and values are another dictionary mapping layers - to their KV cache scaling factors. 
- Keep this function in sync with the output of - examples/other/fp8/extract_scales.py - """ - try: - with open(filename) as f: - context = { - "model_type": model_type, - "num_hidden_layers": num_hidden_layers, - "tp_rank": tp_rank, - "tp_size": tp_size, - } - schema_dct = json.load(f) - schema = QuantParamSchema.model_validate(schema_dct, - context=context) - layer_scales_map = schema.kv_cache.scaling_factor[tp_rank] - return layer_scales_map.items() - - except FileNotFoundError: - logger.error("File or directory '%s' not found.", filename) - except json.JSONDecodeError: - logger.error("Error decoding JSON in file '%s'.", filename) - except Exception: - logger.exception("An error occurred while reading '%s'.", filename) - # This section is reached if and only if any of the excepts are hit - # Return an empty iterable (list) => no KV cache scales are loaded - # which ultimately defaults to 1.0 scales - logger.warning( - "Defaulting to KV cache scaling factors = 1.0 for all " - "layers in TP rank %d as an error occurred during loading.", tp_rank) - return [] - - def convert_pyslice_to_tensor(x: Any) -> torch.Tensor: """convert PySafeSlice object from safetensors to torch.Tensor diff --git a/vllm/model_executor/models/exaone.py b/vllm/model_executor/models/exaone.py index eab3bf0756fca..bc3295da7b60a 100644 --- a/vllm/model_executor/models/exaone.py +++ b/vllm/model_executor/models/exaone.py @@ -30,8 +30,7 @@ from vllm.attention import Attention, AttentionMetadata from vllm.compilation.decorators import support_torch_compile from vllm.config import CacheConfig, VllmConfig -from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank, - get_tensor_model_parallel_world_size) +from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.linear import (MergedColumnParallelLinear, @@ -44,9 +43,8 @@ from vllm.model_executor.layers.vocab_parallel_embedding import ( DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding) from vllm.model_executor.model_loader.weight_utils import ( - default_weight_loader, kv_cache_scales_loader, maybe_remap_kv_scale_name) + default_weight_loader, maybe_remap_kv_scale_name) from vllm.model_executor.sampling_metadata import SamplingMetadata -from vllm.platforms import current_platform from vllm.sequence import IntermediateTensors from vllm.transformers_utils.configs.exaone import ExaoneConfig @@ -576,32 +574,3 @@ def load_weights(self, weights: Iterable[Tuple[str, weight_loader(param, loaded_weight) loaded_params.add(name) return loaded_params - - # If this function is called, it should always initialize KV cache scale - # factors (or else raise an exception). 
Thus, handled exceptions should - # make sure to leave KV cache scale factors in a known good (dummy) state - def load_kv_cache_scales(self, quantization_param_path: str) -> None: - tp_size = get_tensor_model_parallel_world_size() - tp_rank = get_tensor_model_parallel_rank() - for layer_idx, scaling_factor in kv_cache_scales_loader( - quantization_param_path, - tp_rank, - tp_size, - self.config.num_hidden_layers, - self.config.__class__.model_type, - ): - if not isinstance(self.transformer.h[layer_idx], nn.Identity): - layer_self_attn = self.transformer.h[layer_idx].attn - - if current_platform.is_rocm(): - # The scaling factor convention we are assuming is - # quantized_value * scaling_factor ~= true_value - # which is consistent with the practice of setting - # scaling_factor = tensor_amax / FPtype_max - scaling_factor *= 2 - if hasattr(layer_self_attn.attn, "_k_scale"): - layer_self_attn.attn._k_scale = scaling_factor - layer_self_attn.attn._v_scale = scaling_factor - else: - raise RuntimeError("Self attention has no KV cache scaling " - "factor attribute!") diff --git a/vllm/model_executor/models/granite.py b/vllm/model_executor/models/granite.py index ddd2d7a16b242..543b4e2f5e286 100644 --- a/vllm/model_executor/models/granite.py +++ b/vllm/model_executor/models/granite.py @@ -29,8 +29,7 @@ from vllm.attention import Attention, AttentionMetadata from vllm.compilation.decorators import support_torch_compile from vllm.config import CacheConfig, VllmConfig -from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank, - get_tensor_model_parallel_world_size) +from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.linear import (MergedColumnParallelLinear, @@ -44,9 +43,8 @@ from vllm.model_executor.layers.vocab_parallel_embedding import ( DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding) from vllm.model_executor.model_loader.weight_utils import ( - default_weight_loader, kv_cache_scales_loader, maybe_remap_kv_scale_name) + default_weight_loader, maybe_remap_kv_scale_name) from vllm.model_executor.sampling_metadata import SamplingMetadata -from vllm.platforms import current_platform from vllm.sequence import IntermediateTensors from .interfaces import SupportsLoRA, SupportsPP @@ -518,29 +516,3 @@ def load_weights(self, weights: Iterable[Tuple[str, weight_loader(param, loaded_weight) loaded_params.add(name) return loaded_params - - # If this function is called, it should always initialize KV cache scale - # factors (or else raise an exception). 
Thus, handled exceptions should - # make sure to leave KV cache scale factors in a known good (dummy) state - def load_kv_cache_scales(self, quantization_param_path: str) -> None: - tp_size = get_tensor_model_parallel_world_size() - tp_rank = get_tensor_model_parallel_rank() - for layer_idx, scaling_factor in kv_cache_scales_loader( - quantization_param_path, tp_rank, tp_size, - self.config.num_hidden_layers, - self.config.__class__.model_type): - if not isinstance(self.model.layers[layer_idx], nn.Identity): - layer_self_attn = self.model.layers[layer_idx].self_attn - - if current_platform.is_rocm(): - # The scaling factor convention we are assuming is - # quantized_value * scaling_factor ~= true_value - # which is consistent with the practice of setting - # scaling_factor = tensor_amax / FPtype_max - scaling_factor *= 2 - if hasattr(layer_self_attn.attn, "_k_scale"): - layer_self_attn.attn._k_scale = scaling_factor - layer_self_attn.attn._v_scale = scaling_factor - else: - raise RuntimeError("Self attention has no KV cache scaling " - "factor attribute!") diff --git a/vllm/model_executor/models/llama.py b/vllm/model_executor/models/llama.py index f87379a6039ba..ca9646b43649b 100644 --- a/vllm/model_executor/models/llama.py +++ b/vllm/model_executor/models/llama.py @@ -29,8 +29,7 @@ from vllm.attention import Attention, AttentionMetadata from vllm.compilation.decorators import support_torch_compile from vllm.config import CacheConfig, VllmConfig -from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank, - get_tensor_model_parallel_world_size) +from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.linear import (MergedColumnParallelLinear, @@ -43,9 +42,8 @@ from vllm.model_executor.layers.vocab_parallel_embedding import ( DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding) from vllm.model_executor.model_loader.weight_utils import ( - default_weight_loader, kv_cache_scales_loader, maybe_remap_kv_scale_name) + default_weight_loader, maybe_remap_kv_scale_name) from vllm.model_executor.sampling_metadata import SamplingMetadata -from vllm.platforms import current_platform from vllm.sequence import IntermediateTensors from .interfaces import SupportsLoRA, SupportsPP @@ -441,32 +439,6 @@ def load_weights(self, weights: Iterable[Tuple[str, loaded_params.add(name) return loaded_params - # If this function is called, it should always initialize KV cache scale - # factors (or else raise an exception). 
Thus, handled exceptions should - # make sure to leave KV cache scale factors in a known good (dummy) state - def load_kv_cache_scales(self, quantization_param_path: str) -> None: - tp_size = get_tensor_model_parallel_world_size() - tp_rank = get_tensor_model_parallel_rank() - for layer_idx, scaling_factor in kv_cache_scales_loader( - quantization_param_path, tp_rank, tp_size, - self.config.num_hidden_layers, - self.config.__class__.model_type): - if not isinstance(self.layers[layer_idx], nn.Identity): - layer_self_attn = self.layers[layer_idx].self_attn - - if current_platform.is_rocm(): - # The scaling factor convention we are assuming is - # quantized_value * scaling_factor ~= true_value - # which is consistent with the practice of setting - # scaling_factor = tensor_amax / FPtype_max - scaling_factor *= 2 - if hasattr(layer_self_attn.attn, "_k_scale"): - layer_self_attn.attn._k_scale = scaling_factor - layer_self_attn.attn._v_scale = scaling_factor - else: - raise RuntimeError("Self attention has no KV cache scaling " - "factor attribute!") - class LlamaForCausalLM(nn.Module, SupportsLoRA, SupportsPP): packed_modules_mapping = { @@ -594,9 +566,6 @@ def load_weights(self, weights: Iterable[Tuple[str, self.maybe_remap_mistral(name, loaded_weight) for name, loaded_weight in weights) - def load_kv_cache_scales(self, quantization_param_path: str) -> None: - self.model.load_kv_cache_scales(quantization_param_path) - # This function is used to remap the mistral format as # used by Mistral and Llama <=2 def maybe_remap_mistral( diff --git a/vllm/model_executor/models/mllama.py b/vllm/model_executor/models/mllama.py index 2554281610a30..61baa8e588d74 100644 --- a/vllm/model_executor/models/mllama.py +++ b/vllm/model_executor/models/mllama.py @@ -831,6 +831,7 @@ def _attention_with_mask( ) -> torch.Tensor: # Skip writing kv-cache for the initial profiling run. if len(kv_cache.shape) > 1: + i = torch.ones(1, dtype=torch.float32) if self.attn.backend in (_Backend.FLASH_ATTN, _Backend.FLASH_ATTN_VLLM_V1): cached_k = torch.cat([k[s:e] for s, e in kv_range_for_decode]) @@ -843,8 +844,8 @@ def _attention_with_mask( attn_metadata. 
                     cross_slot_mapping,  # type: ignore[union-attr]
                     "auto",
-                    1.0,
-                    1.0,
+                    i,
+                    i,
                 )
             elif self.attn.backend in (_Backend.XFORMERS, _Backend.TORCH_SDPA):
                 key_cache, value_cache = PagedAttention.split_kv_cache(
@@ -853,7 +854,7 @@ def _attention_with_mask(
                 cached_v = torch.cat([v[s:e] for s, e in kv_range_for_decode])
                 PagedAttention.write_to_paged_cache(
                     cached_k, cached_v, key_cache, value_cache,
-                    attn_metadata.cross_slot_mapping, "auto", 1.0, 1.0)
+                    attn_metadata.cross_slot_mapping, "auto", i, i)
             else:
                 raise ValueError(
                     f"Unsupported Attention backend {self.attn.backend} "
diff --git a/vllm/model_executor/models/solar.py b/vllm/model_executor/models/solar.py
index 37c5a4b5713b8..e6d919f23c85d 100644
--- a/vllm/model_executor/models/solar.py
+++ b/vllm/model_executor/models/solar.py
@@ -30,8 +30,7 @@
 from vllm.attention import Attention, AttentionMetadata
 from vllm.compilation.decorators import support_torch_compile
 from vllm.config import CacheConfig, VllmConfig
-from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank,
-                              get_tensor_model_parallel_world_size)
+from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
 from vllm.model_executor.layers.activation import SiluAndMul
 from vllm.model_executor.layers.layernorm import RMSNorm
 from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
@@ -44,9 +43,8 @@
 from vllm.model_executor.layers.vocab_parallel_embedding import (
     DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding)
 from vllm.model_executor.model_loader.weight_utils import (
-    default_weight_loader, kv_cache_scales_loader, maybe_remap_kv_scale_name)
+    default_weight_loader, maybe_remap_kv_scale_name)
 from vllm.model_executor.sampling_metadata import SamplingMetadata
-from vllm.platforms import current_platform
 from vllm.sequence import IntermediateTensors
 
 from .interfaces import SupportsLoRA, SupportsPP
@@ -535,32 +533,3 @@ def load_weights(self, weights: Iterable[Tuple[str,
                 weight_loader(param, loaded_weight)
             loaded_params.add(name)
         return loaded_params
-
-    # If this function is called, it should always initialize KV cache scale
-    # factors (or else raise an exception). Thus, handled exceptions should
-    # make sure to leave KV cache scale factors in a known good (dummy) state
-    def load_kv_cache_scales(self, quantization_param_path: str) -> None:
-        tp_size = get_tensor_model_parallel_world_size()
-        tp_rank = get_tensor_model_parallel_rank()
-        for layer_idx, scaling_factor in kv_cache_scales_loader(
-                quantization_param_path,
-                tp_rank,
-                tp_size,
-                self.config.num_hidden_layers,
-                self.config.__class__.model_type,
-        ):
-            if not isinstance(self.model.layers[layer_idx], nn.Identity):
-                layer_self_attn = self.model.layers[layer_idx].self_attn
-
-            if current_platform.is_rocm():
-                # The scaling factor convention we are assuming is
-                # quantized_value * scaling_factor ~= true_value
-                # which is consistent with the practice of setting
-                # scaling_factor = tensor_amax / FPtype_max
-                scaling_factor *= 2
-            if hasattr(layer_self_attn.attn, "_k_scale"):
-                layer_self_attn.attn._k_scale = scaling_factor
-                layer_self_attn.attn._v_scale = scaling_factor
-            else:
-                raise RuntimeError("Self attention has no KV cache scaling "
-                                   "factor attribute!")
diff --git a/vllm/v1/attention/backends/flash_attn.py b/vllm/v1/attention/backends/flash_attn.py
index 1806fec8833a3..7fe9b3a8f595a 100644
--- a/vllm/v1/attention/backends/flash_attn.py
+++ b/vllm/v1/attention/backends/flash_attn.py
@@ -166,10 +166,6 @@ def forward(
         Returns:
             shape = [num_tokens, num_heads * head_size]
         """
-        # NOTE(woosuk): FlashAttention does not support FP8 KV cache.
-        assert layer._k_scale == 1.0 and layer._v_scale == 1.0, (
-            "key/v_scale is not supported in FlashAttention.")
-
         assert output is not None, "Output tensor must be provided."
 
         if attn_metadata is None:
diff --git a/vllm/worker/hpu_model_runner.py b/vllm/worker/hpu_model_runner.py
index 4c8f69e449393..a339c97a8383c 100644
--- a/vllm/worker/hpu_model_runner.py
+++ b/vllm/worker/hpu_model_runner.py
@@ -903,7 +903,8 @@ def _prepare_prompt(
             num_decode_tokens=0,
             slot_mapping=slot_mapping,
             multi_modal_placeholder_index_maps=
-            None  # FIXME(kzawora): mutli-modality will not work here
+            None,  # FIXME(kzawora): mutli-modality will not work here
+            enable_kv_scales_calculation=False,
         )
 
         multi_modal_kwargs = MultiModalKwargs.batch(multi_modal_kwargs_list)
@@ -1057,7 +1058,9 @@ def _prepare_decode(
             num_prefill_tokens=0,
             num_decode_tokens=num_decode_tokens,
             slot_mapping=slot_mapping,
-            multi_modal_placeholder_index_maps=None)
+            multi_modal_placeholder_index_maps=None,
+            enable_kv_scales_calculation=False,
+        )
         return PrepareDecodeMetadata(input_tokens=input_tokens,
                                      input_positions=input_positions,
                                      attn_metadata=attn_metadata,
diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py
index fe504821a0d96..cf2f1c6b3b877 100644
--- a/vllm/worker/model_runner.py
+++ b/vllm/worker/model_runner.py
@@ -3,7 +3,6 @@
 import inspect
 import itertools
 import time
-import warnings
 import weakref
 from contextlib import contextmanager
 from dataclasses import dataclass
@@ -41,7 +40,6 @@
 from vllm.multimodal import (MULTIMODAL_REGISTRY, BatchedTensorInputs,
                              MultiModalKwargs, MultiModalPlaceholderMap,
                              MultiModalRegistry)
-from vllm.platforms import current_platform
 from vllm.prompt_adapter.layers import PromptAdapterMapping
 from vllm.prompt_adapter.request import PromptAdapterRequest
 from vllm.prompt_adapter.worker_manager import (
@@ -1151,34 +1149,6 @@ def load_model(self) -> None:
                 self.prompt_adapter_manager.create_prompt_adapter_manager(
                     self.model))
 
-        if self.kv_cache_dtype == "fp8" and (current_platform.is_rocm()
-                                             or current_platform.is_cuda()):
-            # Currently only ROCm accepts kv-cache scaling factors
-            # via quantization_param_path and this will be deprecated
-            # in the future.
-            if self.model_config.quantization_param_path is not None:
-                if callable(getattr(self.model, "load_kv_cache_scales", None)):
-                    warnings.warn(
-                        "Loading kv cache scaling factor from JSON is "
-                        "deprecated and will be removed. Please include "
-                        "kv cache scaling factors in the model checkpoint.",
-                        FutureWarning,
-                        stacklevel=2)
-                    self.model.load_kv_cache_scales(
-                        self.model_config.quantization_param_path)
-                    logger.info("Loaded KV cache scaling factors from %s",
-                                self.model_config.quantization_param_path)
-                else:
-                    raise RuntimeError(
-                        "Using FP8 KV cache and scaling factors provided but "
-                        "model %s does not support loading scaling factors.",
-                        self.model.__class__)
-            else:
-                logger.warning(
-                    "Using FP8 KV cache but no scaling factors "
-                    "provided. Defaulting to scaling factors of 1.0. "
-                    "This may lead to less accurate results!")
-
         if self.vllm_config.compilation_config.level ==\
             CompilationLevel.DYNAMO_AS_IS and supports_dynamo():
             backend = self.vllm_config.compilation_config.init_backend(
@@ -1366,6 +1336,10 @@ def _dummy_run(self,
                 dtype=self.model_config.dtype,
                 device=self.device)
 
+        # Disable KV Scale Calculation for dummy data during profile run
+        if model_input.attn_metadata is not None:
+            model_input.attn_metadata.enable_kv_scales_calculation = False
+
         self.execute_model(model_input, kv_caches, intermediate_tensors)
         torch.cuda.synchronize()
         return
@@ -1510,7 +1484,8 @@ def capture_model(self, kv_caches: List[List[torch.Tensor]]) -> None:
                     batch_size,
                     is_encoder_decoder_model=self.model_config.
                     is_encoder_decoder))
-
+            # Disable KV Scale Calculation for graph capture
+            attn_metadata.enable_kv_scales_calculation = False
             if self.lora_config:
                 lora_mapping = LoRAMapping(
                     **dict(index_mapping=[0] * batch_size,
diff --git a/vllm/worker/openvino_model_runner.py b/vllm/worker/openvino_model_runner.py
index 9d0a759ca2f21..42fe2cf668ad8 100644
--- a/vllm/worker/openvino_model_runner.py
+++ b/vllm/worker/openvino_model_runner.py
@@ -282,6 +282,7 @@ def _prepare_model_input(
             block_indices_begins=block_indices_begins_tensor,
             max_context_len=max_context_len_tensor,
             multi_modal_placeholder_index_maps=placeholder_index_maps,
+            enable_kv_scales_calculation=False,
         )
 
         multi_modal_kwargs = MultiModalKwargs.batch(multi_modal_kwargs_list)
diff --git a/vllm/worker/tpu_model_runner.py b/vllm/worker/tpu_model_runner.py
index f5c7bc955a673..a3f648f4cc645 100644
--- a/vllm/worker/tpu_model_runner.py
+++ b/vllm/worker/tpu_model_runner.py
@@ -190,6 +190,7 @@ def _dummy_run(
                 num_decode_tokens=0,
                 slot_mapping=slot_mapping,
                 multi_modal_placeholder_index_maps=None,
+                enable_kv_scales_calculation=False,
                 block_tables=None,
                 context_lens=None,
                 effective_query_lens=None,
@@ -208,6 +209,7 @@ def _dummy_run(
                 num_decode_tokens=0,
                 slot_mapping=slot_mapping,
                 multi_modal_placeholder_index_maps=None,
+                enable_kv_scales_calculation=False,
                 block_tables=block_tables,
                 context_lens=context_lens,
                 effective_query_lens=effective_query_lens,
@@ -239,6 +241,7 @@ def _dummy_run(
                 num_decode_tokens=batch_size * seq_len,
                 slot_mapping=slot_mapping,
                 multi_modal_placeholder_index_maps=None,
+                enable_kv_scales_calculation=False,
                 block_tables=block_tables,
                 context_lens=context_lens,
             )
@@ -425,6 +428,7 @@ def _prepare_prompt(
             num_decode_tokens=0,
             slot_mapping=slot_mapping,
             multi_modal_placeholder_index_maps=None,
+            enable_kv_scales_calculation=False,
             block_tables=block_tables,
             context_lens=context_lens,
             effective_query_lens=prompt_lens,
@@ -496,6 +500,7 @@ def _prepare_decode(
             num_decode_tokens=batch_size,
             slot_mapping=slot_mapping,
             multi_modal_placeholder_index_maps=None,
+            enable_kv_scales_calculation=False,
             block_tables=block_tables,
             context_lens=context_lens,
         )
diff --git a/vllm/worker/xpu_model_runner.py b/vllm/worker/xpu_model_runner.py
index 053658d047311..b7b7b7227b22c 100644
--- a/vllm/worker/xpu_model_runner.py
+++ b/vllm/worker/xpu_model_runner.py
@@ -261,6 +261,7 @@ def _prepare_prompt(
             is_prompt=True,
             slot_mapping=slot_mapping,
             multi_modal_placeholder_index_maps=placeholder_index_maps,
+            enable_kv_scales_calculation=False,
             seq_lens=seq_lens,
             seqlen_q=seqlen_q,
             max_seqlen=max_seqlen,
@@ -345,6 +346,7 @@ def _prepare_decode(
             is_prompt=False,
             slot_mapping=slot_mapping,
             multi_modal_placeholder_index_maps=None,
+            enable_kv_scales_calculation=False,
             seq_lens=seq_lens,
             seqlen_q=torch.tensor([]),
             max_seqlen=0,
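
For reference, the scale convention quoted in the removed loader above is
scaling_factor = tensor_amax / FPtype_max, so that
quantized_value * scaling_factor ~= true_value. A minimal Python sketch of
computing such per-tensor scales dynamically, assuming float8 e4m3 as the
cache dtype; compute_kv_scales is a hypothetical helper, not a function
added by this patch:

import torch

def compute_kv_scales(key: torch.Tensor, value: torch.Tensor):
    # scaling_factor = tensor_amax / FPtype_max, so dequantization follows
    # quantized_value * scaling_factor ~= true_value (hypothetical sketch).
    fp8_max = torch.finfo(torch.float8_e4m3fn).max  # 448.0 for e4m3fn
    k_scale = key.abs().amax().to(torch.float32) / fp8_max
    v_scale = value.abs().amax().to(torch.float32) / fp8_max
    # Returned as float32 tensors, matching the kernels' tensor-valued
    # k_scale / v_scale arguments introduced by this patch.
    return k_scale, v_scale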