From 43af31054aee037558e1977b84cad6e23d447734 Mon Sep 17 00:00:00 2001 From: Philipp Moritz Date: Tue, 10 Oct 2023 08:19:43 +0000 Subject: [PATCH 01/94] port dtype_float16.cuh and cache_kernels.cu --- csrc/attention/dtype_float16.cuh | 69 ++++++++++++++++++++++++++++++-- csrc/cache_kernels.cu | 17 ++++---- setup.py | 64 ----------------------------- 3 files changed, 74 insertions(+), 76 deletions(-) diff --git a/csrc/attention/dtype_float16.cuh b/csrc/attention/dtype_float16.cuh index e67921128d52b..079fa607f96b7 100644 --- a/csrc/attention/dtype_float16.cuh +++ b/csrc/attention/dtype_float16.cuh @@ -21,6 +21,10 @@ #include "attention_generic.cuh" #include "dtype_float32.cuh" +#ifdef USE_ROCM + #include +#endif + #include namespace vllm { @@ -63,30 +67,49 @@ struct FloatVec { // Utility functions for type conversions. inline __device__ uint32_t h0_h0(uint16_t a) { +#ifndef USE_ROCM uint32_t b; asm volatile("mov.b32 %0, {%1, %1};" : "=r"(b) : "h"(a)); +#else + uint32_t b = a; + b <<= 16; + b |= a; +#endif return b; } inline __device__ float half_to_float(uint16_t h) { +#ifndef USE_ROCM float f; asm volatile("cvt.f32.f16 %0, %1;\n" : "=f"(f) : "h"(h)); return f; +#else + return __half2float(__ushort_as_half(h)); +#endif } inline __device__ float2 half2_to_float2(uint32_t v) { +#ifndef USE_ROCM uint16_t lo, hi; asm volatile("mov.b32 {%0, %1}, %2;\n" : "=h"(lo), "=h"(hi) : "r"(v)); +#else + uint16_t hi = (v >> 16) & 0xFFFF; + uint16_t lo = v & 0xFFFF; +#endif return make_float2(half_to_float(lo), half_to_float(hi)); } inline __device__ uint16_t float_to_half(float f) { +#ifndef USE_ROCM union { uint32_t u32; uint16_t u16[2]; } tmp; asm volatile("cvt.rn.f16.f32 %0, %1;\n" : "=h"(tmp.u16[0]) : "f"(f)); return tmp.u16[0]; +#else + return __half_as_ushort(__float2half(f)); +#endif } inline __device__ uint32_t float2_to_half2(float2 f) { @@ -95,26 +118,48 @@ inline __device__ uint32_t float2_to_half2(float2 f) { uint16_t u16[2]; } tmp; -#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 - asm volatile("cvt.rn.f16x2.f32 %0, %1, %2;\n" : "=r"(tmp.u32) : "f"(f.y), "f"(f.x)); +#ifndef USE_ROCM + #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 + asm volatile("cvt.rn.f16x2.f32 %0, %1, %2;\n" : "=r"(tmp.u32) : "f"(f.y), "f"(f.x)); + #else + asm volatile("cvt.rn.f16.f32 %0, %1;\n" : "=h"(tmp.u16[0]) : "f"(f.x)); + asm volatile("cvt.rn.f16.f32 %0, %1;\n" : "=h"(tmp.u16[1]) : "f"(f.y)); + #endif #else - asm volatile("cvt.rn.f16.f32 %0, %1;\n" : "=h"(tmp.u16[0]) : "f"(f.x)); - asm volatile("cvt.rn.f16.f32 %0, %1;\n" : "=h"(tmp.u16[1]) : "f"(f.y)); + __half2 h = __float22half2_rn(f); + tmp.u16[0] = h.x; + tmp.u16[1] = h.y; #endif return tmp.u32; } // Vector addition. inline __device__ uint16_t add(uint16_t a, uint16_t b) { +#ifndef USE_ROCM uint16_t c; asm volatile("add.f16 %0, %1, %2;\n" : "=h"(c) : "h"(a), "h"(b)); return c; +#else + return __half_as_ushort(__hadd(__ushort_as_half(a), __ushort_as_half(b))); +#endif } inline __device__ uint32_t add(uint32_t a, uint32_t b) { +#ifndef USE_ROCM uint32_t c; asm volatile("add.f16x2 %0, %1, %2;\n" : "=r"(c) : "r"(a), "r"(b)); return c; +#else + __half2 h = __hadd2(a, b); + union { + uint32_t u32; + uint16_t u16[2]; + } tmp; + tmp.u16[0] = h.x; + tmp.u16[1] = h.y; + + return tmp.u32; +#endif } inline __device__ uint2 add(uint2 a, uint2 b) { @@ -157,16 +202,32 @@ inline __device__ Float8_ add(uint4 a, Float8_ fb) { // Vector multiplication. template<> inline __device__ uint16_t mul(uint16_t a, uint16_t b) { +#ifndef USE_ROCM uint16_t c; asm volatile("mul.f16 %0, %1, %2;\n" : "=h"(c) : "h"(a), "h"(b)); return c; +#else + return __half_as_ushort(__hmul(__ushort_as_half(a), __ushort_as_half(b))); +#endif } template<> inline __device__ uint32_t mul(uint32_t a, uint32_t b) { +#ifndef USE_ROCM uint32_t c; asm volatile("mul.f16x2 %0, %1, %2;\n" : "=r"(c) : "r"(a), "r"(b)); return c; +#else + __half2 h = __hmul2(a, b); + union { + uint32_t u32; + uint16_t u16[2]; + } tmp; + tmp.u16[0] = h.x; + tmp.u16[1] = h.y; + + return tmp.u32; +#endif } template<> diff --git a/csrc/cache_kernels.cu b/csrc/cache_kernels.cu index ddad2b5a29b9e..2d54ac5eab017 100644 --- a/csrc/cache_kernels.cu +++ b/csrc/cache_kernels.cu @@ -1,6 +1,7 @@ #include #include +#include "cuda_compat.h" #include "dispatch_utils.h" #include @@ -28,8 +29,8 @@ void swap_blocks( TORCH_CHECK(false, "Invalid device combination"); } - void *src_ptr = src.data_ptr(); - void *dst_ptr = dst.data_ptr(); + char *src_ptr = static_cast(src.data_ptr()); + char *dst_ptr = static_cast(dst.data_ptr()); const int64_t block_size_in_bytes = src.element_size() * src[0].numel(); const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); @@ -176,8 +177,8 @@ __global__ void reshape_and_cache_kernel( + head_idx * head_size * block_size + head_offset * block_size + block_offset; - key_cache[tgt_key_idx] = __ldg(&key[src_key_idx]); - value_cache[tgt_value_idx] = __ldg(&value[src_value_idx]); + key_cache[tgt_key_idx] = VLLM_LDB(&key[src_key_idx]); + value_cache[tgt_value_idx] = VLLM_LDB(&value[src_value_idx]); } } @@ -262,8 +263,8 @@ __global__ void gather_cached_kv_kernel( + head_offset * block_size + block_offset; - key[tgt_key_idx] = __ldg(&key_cache[src_key_idx]); - value[tgt_value_idx] = __ldg(&value_cache[src_value_idx]); + key[tgt_key_idx] = VLLM_LDB(&key_cache[src_key_idx]); + value[tgt_value_idx] = VLLM_LDB(&value_cache[src_value_idx]); } } @@ -328,8 +329,8 @@ __global__ void gather_cached_kv_kernel_optimized( src_key_indices[j] = src_key_idx; src_value_indices[j] = src_value_idx; - keys_to_store[j] = __ldg(&key_cache[src_key_idx]); - values_to_store[j] = __ldg(&value_cache[src_value_idx]); + keys_to_store[j] = VLLM_LDB(&key_cache[src_key_idx]); + values_to_store[j] = VLLM_LDB(&value_cache[src_value_idx]); } #pragma unroll diff --git a/setup.py b/setup.py index 8b2ad97dd5403..75433b336c357 100644 --- a/setup.py +++ b/setup.py @@ -24,10 +24,6 @@ CXX_FLAGS += [f"-D_GLIBCXX_USE_CXX11_ABI={ABI}"] NVCC_FLAGS += [f"-D_GLIBCXX_USE_CXX11_ABI={ABI}"] -if CUDA_HOME is None: - raise RuntimeError( - "Cannot find CUDA_HOME. CUDA must be available to build the package.") - def get_nvcc_cuda_version(cuda_dir: str) -> Version: """Get the CUDA version from nvcc. @@ -64,66 +60,6 @@ def get_torch_arch_list() -> Set[str]: return set(arch_list) -# First, check the TORCH_CUDA_ARCH_LIST environment variable. -compute_capabilities = get_torch_arch_list() -if not compute_capabilities: - # If TORCH_CUDA_ARCH_LIST is not defined or empty, target all available - # GPUs on the current machine. - device_count = torch.cuda.device_count() - for i in range(device_count): - major, minor = torch.cuda.get_device_capability(i) - if major < 7: - raise RuntimeError( - "GPUs with compute capability below 7.0 are not supported.") - compute_capabilities.add(f"{major}.{minor}") - -nvcc_cuda_version = get_nvcc_cuda_version(CUDA_HOME) -if not compute_capabilities: - # If no GPU is specified nor available, add all supported architectures - # based on the NVCC CUDA version. - compute_capabilities = set(SUPPORTED_ARCHS) - if nvcc_cuda_version < Version("11.1"): - compute_capabilities.remove("8.6") - if nvcc_cuda_version < Version("11.8"): - compute_capabilities.remove("8.9") - compute_capabilities.remove("9.0") - -# Validate the NVCC CUDA version. -if nvcc_cuda_version < Version("11.0"): - raise RuntimeError("CUDA 11.0 or higher is required to build the package.") -if nvcc_cuda_version < Version("11.1"): - if any(cc.startswith("8.6") for cc in compute_capabilities): - raise RuntimeError( - "CUDA 11.1 or higher is required for compute capability 8.6.") -if nvcc_cuda_version < Version("11.8"): - if any(cc.startswith("8.9") for cc in compute_capabilities): - # CUDA 11.8 is required to generate the code targeting compute capability 8.9. - # However, GPUs with compute capability 8.9 can also run the code generated by - # the previous versions of CUDA 11 and targeting compute capability 8.0. - # Therefore, if CUDA 11.8 is not available, we target compute capability 8.0 - # instead of 8.9. - warnings.warn( - "CUDA 11.8 or higher is required for compute capability 8.9. " - "Targeting compute capability 8.0 instead.") - compute_capabilities = set(cc for cc in compute_capabilities - if not cc.startswith("8.9")) - compute_capabilities.add("8.0+PTX") - if any(cc.startswith("9.0") for cc in compute_capabilities): - raise RuntimeError( - "CUDA 11.8 or higher is required for compute capability 9.0.") - -# Add target compute capabilities to NVCC flags. -for capability in compute_capabilities: - num = capability[0] + capability[2] - NVCC_FLAGS += ["-gencode", f"arch=compute_{num},code=sm_{num}"] - if capability.endswith("+PTX"): - NVCC_FLAGS += ["-gencode", f"arch=compute_{num},code=compute_{num}"] - -# Use NVCC threads to parallelize the build. -if nvcc_cuda_version >= Version("11.2"): - num_threads = min(os.cpu_count(), 8) - NVCC_FLAGS += ["--threads", str(num_threads)] - ext_modules = [] # Cache operations. From cc818664c403b0a0c0683b565a428f0863f65a37 Mon Sep 17 00:00:00 2001 From: Philipp Moritz Date: Tue, 10 Oct 2023 08:53:46 +0000 Subject: [PATCH 02/94] port dtype_bfloat16.cuh --- csrc/attention/dtype_bfloat16.cuh | 25 ++++++++++++++++++++++--- 1 file changed, 22 insertions(+), 3 deletions(-) diff --git a/csrc/attention/dtype_bfloat16.cuh b/csrc/attention/dtype_bfloat16.cuh index 2154bfcf8631a..9ad2e299c7aab 100644 --- a/csrc/attention/dtype_bfloat16.cuh +++ b/csrc/attention/dtype_bfloat16.cuh @@ -21,8 +21,17 @@ #include "attention_generic.cuh" #include "dtype_float32.cuh" -#include -#include +#ifndef USE_ROCM + #include + #include +#else + #include + #include + + typedef __hip_bfloat162 __nv_bfloat162; + typedef __hip_bfloat16 __nv_bfloat16; +#endif + #include namespace vllm { @@ -98,7 +107,17 @@ inline __device__ __nv_bfloat16 add(__nv_bfloat16 a, __nv_bfloat16 b) { #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 assert(false); #else - return a + b; + #ifndef USE_ROCM + return a + b; + #else + // See https://github.com/RadeonOpenCompute/ROCm/issues/2534 + hip_bfloat16 A, B; + __hip_bfloat16 c; + A.data = a.data; + B.data = b.data; + c.data = (A + B).data; + return c; + #endif #endif } From 475b5e2875f9f870b88206bf087ff6adc99517a9 Mon Sep 17 00:00:00 2001 From: Philipp Moritz Date: Tue, 10 Oct 2023 09:07:33 +0000 Subject: [PATCH 03/94] port attention_utils.cuh --- csrc/attention/attention_kernels.cu | 30 ++++++++++++++++++++++++++++- csrc/attention/attention_utils.cuh | 4 ++++ 2 files changed, 33 insertions(+), 1 deletion(-) diff --git a/csrc/attention/attention_kernels.cu b/csrc/attention/attention_kernels.cu index 505c63d2efd78..423e784dca945 100644 --- a/csrc/attention/attention_kernels.cu +++ b/csrc/attention/attention_kernels.cu @@ -39,7 +39,11 @@ inline __device__ float block_sum(float* red_smem, float sum) { // Compute the sum per warp. #pragma unroll for (int mask = WARP_SIZE / 2; mask >= 1; mask /= 2) { +#ifndef USE_ROCM sum += __shfl_xor_sync(uint32_t(-1), sum, mask); +#else + sum += __shfl_xor(uint32_t(-1), sum, mask); +#endif } // Warp leaders store the data to shared memory. @@ -58,11 +62,19 @@ inline __device__ float block_sum(float* red_smem, float sum) { // Parallel reduction inside the warp. #pragma unroll for (int mask = NUM_WARPS / 2; mask >= 1; mask /= 2) { +#ifndef USE_ROCM sum += __shfl_xor_sync(uint32_t(-1), sum, mask); +#else + sum += __shfl_xor(uint32_t(-1), sum, mask); +#endif } // Broadcast to other threads. +#ifndef USE_ROCM return __shfl_sync(uint32_t(-1), sum, 0); +#else + return __shfl(uint32_t(-1), sum, 0); +#endif } // Grid: (num_heads, num_seqs). @@ -196,7 +208,11 @@ __global__ void single_query_cached_kv_attention_kernel( // The 0-th thread of each thread group already has its max qk value. #pragma unroll for (int mask = WARP_SIZE / 2; mask >= THREAD_GROUP_SIZE; mask /= 2) { +#ifndef USE_ROCM qk_max = fmaxf(qk_max, __shfl_xor_sync(uint32_t(-1), qk_max, mask)); +#else + qk_max = fmaxf(qk_max, __shfl_xor(uint32_t(-1), qk_max, mask)); +#endif } if (lane == 0) { red_smem[warp_idx] = qk_max; @@ -208,10 +224,18 @@ __global__ void single_query_cached_kv_attention_kernel( qk_max = lane < NUM_WARPS ? red_smem[lane] : -FLT_MAX; #pragma unroll for (int mask = NUM_WARPS / 2; mask >= 1; mask /= 2) { +#ifndef USE_ROCM qk_max = fmaxf(qk_max, __shfl_xor_sync(uint32_t(-1), qk_max, mask)); +#else + qk_max = fmaxf(qk_max, __shfl_xor(uint32_t(-1), qk_max, mask)); +#endif } // Broadcast the max qk value to all threads. +#ifndef USE_ROCM qk_max = __shfl_sync(uint32_t(-1), qk_max, 0); +#else + qk_max = __shfl(uint32_t(-1), qk_max, 0); +#endif // Get the sum of the exp values. float exp_sum = 0.f; @@ -284,7 +308,11 @@ __global__ void single_query_cached_kv_attention_kernel( float acc = accs[i]; #pragma unroll for (int mask = NUM_V_VECS_PER_ROW / 2; mask >= 1; mask /= 2) { +#ifndef USE_ROCM acc += __shfl_xor_sync(uint32_t(-1), acc, mask); +#else + acc += __shfl_xor(uint32_t(-1), acc, mask); +#endif } accs[i] = acc; } @@ -342,7 +370,7 @@ __global__ void single_query_cached_kv_attention_kernel( #define LAUNCH_ATTENTION_KERNEL(T, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS) \ cudaFuncSetAttribute( \ - vllm::single_query_cached_kv_attention_kernel, \ + (void*)vllm::single_query_cached_kv_attention_kernel, \ cudaFuncAttributeMaxDynamicSharedMemorySize, shared_mem_size); \ vllm::single_query_cached_kv_attention_kernel \ <<>>( \ diff --git a/csrc/attention/attention_utils.cuh b/csrc/attention/attention_utils.cuh index bb7df25b14f02..1c3ea93694147 100644 --- a/csrc/attention/attention_utils.cuh +++ b/csrc/attention/attention_utils.cuh @@ -39,7 +39,11 @@ inline __device__ float qk_dot_(const Vec (&q)[N], const Vec (&k)[N]) { float qk = sum(qk_vec); #pragma unroll for (int mask = THREAD_GROUP_SIZE / 2; mask >= 1; mask /= 2) { +#ifndef USE_ROCM qk += __shfl_xor_sync(uint32_t(-1), qk, mask); +#else + qk += __shfl_xor(uint32_t(-1), qk, mask); +#endif } return qk; } From ddc496c7166c055b2adf0ca772475a8add24b3d6 Mon Sep 17 00:00:00 2001 From: Philipp Moritz Date: Tue, 10 Oct 2023 17:35:41 +0000 Subject: [PATCH 04/94] port more kernels --- csrc/activation_kernels.cu | 7 ++++--- csrc/cuda_utils_kernels.cu | 4 ++++ csrc/pos_encoding_kernels.cu | 9 +++++---- csrc/reduction_utils.cuh | 4 ++++ setup.py | 4 +++- 5 files changed, 20 insertions(+), 8 deletions(-) diff --git a/csrc/activation_kernels.cu b/csrc/activation_kernels.cu index c6ae5db8f9c48..617cf6c0e4a55 100644 --- a/csrc/activation_kernels.cu +++ b/csrc/activation_kernels.cu @@ -1,6 +1,7 @@ #include #include +#include "cuda_compat.h" #include "dispatch_utils.h" namespace vllm { @@ -18,8 +19,8 @@ __global__ void silu_and_mul_kernel( const int d) { const int token_idx = blockIdx.x; for (int idx = threadIdx.x; idx < d; idx += blockDim.x) { - const scalar_t x = __ldg(&input[token_idx * 2 * d + idx]); - const scalar_t y = __ldg(&input[token_idx * 2 * d + d + idx]); + const scalar_t x = VLLM_LDG(&input[token_idx * 2 * d + idx]); + const scalar_t y = VLLM_LDG(&input[token_idx * 2 * d + d + idx]); out[token_idx * d + idx] = silu(x) * y; } } @@ -57,7 +58,7 @@ __global__ void activation_kernel( const int d) { const int token_idx = blockIdx.x; for (int idx = threadIdx.x; idx < d; idx += blockDim.x) { - const scalar_t x = __ldg(&input[token_idx * d + idx]); + const scalar_t x = VLLM_LDG(&input[token_idx * d + idx]); out[token_idx * d + idx] = ACT_FN(x); } } diff --git a/csrc/cuda_utils_kernels.cu b/csrc/cuda_utils_kernels.cu index f1c30fe7ea99e..2439f5922a3f3 100644 --- a/csrc/cuda_utils_kernels.cu +++ b/csrc/cuda_utils_kernels.cu @@ -1,3 +1,7 @@ +#ifdef USE_ROCM + #include +#endif + int get_device_attribute( int attribute, int device_id) diff --git a/csrc/pos_encoding_kernels.cu b/csrc/pos_encoding_kernels.cu index b4351ee0d7941..1e977fa92837d 100644 --- a/csrc/pos_encoding_kernels.cu +++ b/csrc/pos_encoding_kernels.cu @@ -1,6 +1,7 @@ #include #include +#include "cuda_compat.h" #include "dispatch_utils.h" namespace vllm { @@ -19,14 +20,14 @@ inline __device__ void apply_rotary_embedding( // GPT-NeoX style rotary embedding. x_index = rot_offset; y_index = embed_dim + rot_offset; - cos = __ldg(cos_ptr + x_index); - sin = __ldg(sin_ptr + x_index); + cos = VLLM_LDG(cos_ptr + x_index); + sin = VLLM_LDG(sin_ptr + x_index); } else { // GPT-J style rotary embedding. x_index = 2 * rot_offset; y_index = 2 * rot_offset + 1; - cos = __ldg(cos_ptr + x_index / 2); - sin = __ldg(sin_ptr + x_index / 2); + cos = VLLM_LDG(cos_ptr + x_index / 2); + sin = VLLM_LDG(sin_ptr + x_index / 2); } const scalar_t x = arr[x_index]; diff --git a/csrc/reduction_utils.cuh b/csrc/reduction_utils.cuh index bc35aa0424b56..382ad162dfef8 100644 --- a/csrc/reduction_utils.cuh +++ b/csrc/reduction_utils.cuh @@ -23,7 +23,11 @@ template __inline__ __device__ T warpReduceSum(T val) { #pragma unroll for (int mask = 16; mask > 0; mask >>= 1) +#ifndef USE_ROCM val += __shfl_xor_sync(0xffffffff, val, mask, 32); +#else + val += __shfl_xor(val, mask, 32); +#endif return val; } diff --git a/setup.py b/setup.py index 75433b336c357..2d9b2afc067f1 100644 --- a/setup.py +++ b/setup.py @@ -117,6 +117,7 @@ def get_torch_arch_list() -> Set[str]: ) ext_modules.append(activation_extension) + # Quantization kernels. quantization_extension = CUDAExtension( name="vllm.quantization_ops", @@ -129,7 +130,8 @@ def get_torch_arch_list() -> Set[str]: "nvcc": NVCC_FLAGS, }, ) -ext_modules.append(quantization_extension) +if not torch.version.hip: + ext_modules.append(quantization_extension) # Misc. CUDA utils. cuda_utils_extension = CUDAExtension( From 5eaa7a10052ba77f2794dcd35316b371bddd92ea Mon Sep 17 00:00:00 2001 From: Philipp Moritz Date: Tue, 10 Oct 2023 17:41:06 +0000 Subject: [PATCH 05/94] fix typo --- csrc/cache_kernels.cu | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/csrc/cache_kernels.cu b/csrc/cache_kernels.cu index 2d54ac5eab017..1a9376b3103ed 100644 --- a/csrc/cache_kernels.cu +++ b/csrc/cache_kernels.cu @@ -177,8 +177,8 @@ __global__ void reshape_and_cache_kernel( + head_idx * head_size * block_size + head_offset * block_size + block_offset; - key_cache[tgt_key_idx] = VLLM_LDB(&key[src_key_idx]); - value_cache[tgt_value_idx] = VLLM_LDB(&value[src_value_idx]); + key_cache[tgt_key_idx] = VLLM_LDG(&key[src_key_idx]); + value_cache[tgt_value_idx] = VLLM_LDG(&value[src_value_idx]); } } @@ -263,8 +263,8 @@ __global__ void gather_cached_kv_kernel( + head_offset * block_size + block_offset; - key[tgt_key_idx] = VLLM_LDB(&key_cache[src_key_idx]); - value[tgt_value_idx] = VLLM_LDB(&value_cache[src_value_idx]); + key[tgt_key_idx] = VLLM_LDG(&key_cache[src_key_idx]); + value[tgt_value_idx] = VLLM_LDG(&value_cache[src_value_idx]); } } @@ -329,8 +329,8 @@ __global__ void gather_cached_kv_kernel_optimized( src_key_indices[j] = src_key_idx; src_value_indices[j] = src_value_idx; - keys_to_store[j] = VLLM_LDB(&key_cache[src_key_idx]); - values_to_store[j] = VLLM_LDB(&value_cache[src_value_idx]); + keys_to_store[j] = VLLM_LDG(&key_cache[src_key_idx]); + values_to_store[j] = VLLM_LDG(&value_cache[src_value_idx]); } #pragma unroll From f7273c6831f87ffa8a2b2a569695cd876f47778c Mon Sep 17 00:00:00 2001 From: Philipp Moritz Date: Tue, 10 Oct 2023 18:18:08 +0000 Subject: [PATCH 06/94] add cuda_compat.h --- csrc/cuda_compat.h | 7 +++++++ 1 file changed, 7 insertions(+) create mode 100644 csrc/cuda_compat.h diff --git a/csrc/cuda_compat.h b/csrc/cuda_compat.h new file mode 100644 index 0000000000000..3348b78cfa19d --- /dev/null +++ b/csrc/cuda_compat.h @@ -0,0 +1,7 @@ +#pragma once + +#ifndef USE_ROCM + #define VLLM_LDG(arg) __ldg(arg) +#else + #define VLLM_LDG(arg) *(arg) +#endif From f8093dc0dfeda653589c5f0db2fe5be46be614b3 Mon Sep 17 00:00:00 2001 From: Philipp Moritz Date: Mon, 16 Oct 2023 20:27:30 +0000 Subject: [PATCH 07/94] sync branches --- csrc/attention/attention_kernels.cu | 498 +++++----------------------- csrc/attention/attention_utils.cuh | 2 +- csrc/attention/dtype_bfloat16.cuh | 5 - csrc/attention/dtype_float16.cuh | 65 ++-- 4 files changed, 121 insertions(+), 449 deletions(-) diff --git a/csrc/attention/attention_kernels.cu b/csrc/attention/attention_kernels.cu index da7dedb0faf31..debde463786ef 100644 --- a/csrc/attention/attention_kernels.cu +++ b/csrc/attention/attention_kernels.cu @@ -26,7 +26,6 @@ #define WARP_SIZE 32 #define MAX(a, b) ((a) > (b) ? (a) : (b)) #define MIN(a, b) ((a) < (b) ? (a) : (b)) -#define DIVIDE_ROUND_UP(a, b) (((a) + (b) - 1) / (b)) namespace vllm { @@ -43,7 +42,7 @@ inline __device__ float block_sum(float* red_smem, float sum) { #ifndef USE_ROCM sum += __shfl_xor_sync(uint32_t(-1), sum, mask); #else - sum += __shfl_xor(uint32_t(-1), sum, mask); + sum += __shfl_xor(sum, mask); #endif } @@ -66,7 +65,7 @@ inline __device__ float block_sum(float* red_smem, float sum) { #ifndef USE_ROCM sum += __shfl_xor_sync(uint32_t(-1), sum, mask); #else - sum += __shfl_xor(uint32_t(-1), sum, mask); + sum += __shfl_xor(sum, mask); #endif } @@ -74,22 +73,18 @@ inline __device__ float block_sum(float* red_smem, float sum) { #ifndef USE_ROCM return __shfl_sync(uint32_t(-1), sum, 0); #else - return __shfl(uint32_t(-1), sum, 0); + return __shfl(sum, 0); #endif } -// TODO(woosuk): Merge the last two dimensions of the grid. -// Grid: (num_heads, num_seqs, max_num_partitions). +// Grid: (num_heads, num_seqs). template< typename scalar_t, int HEAD_SIZE, int BLOCK_SIZE, - int NUM_THREADS, - int PARTITION_SIZE = 0> // Zero means no partitioning. -__device__ void paged_attention_kernel( - float* __restrict__ exp_sums, // [num_seqs, num_heads, max_num_partitions] - float* __restrict__ max_logits, // [num_seqs, num_heads, max_num_partitions] - scalar_t* __restrict__ out, // [num_seqs, num_heads, max_num_partitions, head_size] + int NUM_THREADS> +__global__ void single_query_cached_kv_attention_kernel( + scalar_t* __restrict__ out, // [num_seqs, num_heads, head_size] const scalar_t* __restrict__ q, // [num_seqs, num_heads, head_size] const scalar_t* __restrict__ k_cache, // [num_blocks, num_kv_heads, head_size/x, block_size, x] const scalar_t* __restrict__ v_cache, // [num_blocks, num_kv_heads, head_size, block_size] @@ -102,33 +97,10 @@ __device__ void paged_attention_kernel( const int q_stride, const int kv_block_stride, const int kv_head_stride) { - const int seq_idx = blockIdx.y; - const int partition_idx = blockIdx.z; - const int max_num_partitions = gridDim.z; - constexpr bool USE_PARTITIONING = PARTITION_SIZE > 0; - const int context_len = context_lens[seq_idx]; - if (USE_PARTITIONING && partition_idx * PARTITION_SIZE >= context_len) { - // No work to do. Terminate the thread block. - return; - } - - const int num_context_blocks = DIVIDE_ROUND_UP(context_len, BLOCK_SIZE); - const int num_blocks_per_partition = USE_PARTITIONING ? PARTITION_SIZE / BLOCK_SIZE : num_context_blocks; - - // [start_block_idx, end_block_idx) is the range of blocks to process. - const int start_block_idx = USE_PARTITIONING ? partition_idx * num_blocks_per_partition : 0; - const int end_block_idx = MIN(start_block_idx + num_blocks_per_partition, num_context_blocks); - const int num_blocks = end_block_idx - start_block_idx; - - // [start_token_idx, end_token_idx) is the range of tokens to process. - const int start_token_idx = start_block_idx * BLOCK_SIZE; - const int end_token_idx = MIN(start_token_idx + num_blocks * BLOCK_SIZE, context_len); - const int num_tokens = end_token_idx - start_token_idx; - constexpr int THREAD_GROUP_SIZE = MAX(WARP_SIZE / BLOCK_SIZE, 1); constexpr int NUM_THREAD_GROUPS = NUM_THREADS / THREAD_GROUP_SIZE; // Note: This assumes THREAD_GROUP_SIZE divides NUM_THREADS assert(NUM_THREADS % THREAD_GROUP_SIZE == 0); - constexpr int NUM_TOKENS_PER_THREAD_GROUP = DIVIDE_ROUND_UP(BLOCK_SIZE, WARP_SIZE); + constexpr int NUM_TOKENS_PER_THREAD_GROUP = (BLOCK_SIZE + WARP_SIZE - 1) / WARP_SIZE; constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE; const int thread_idx = threadIdx.x; const int warp_idx = thread_idx / WARP_SIZE; @@ -137,6 +109,7 @@ __device__ void paged_attention_kernel( const int head_idx = blockIdx.x; const int num_heads = gridDim.x; const int kv_head_idx = head_mapping[head_idx]; + const int seq_idx = blockIdx.y; const float alibi_slope = alibi_slopes == nullptr ? 0.f : alibi_slopes[head_idx]; // A vector type to store a part of a key or a query. @@ -181,12 +154,15 @@ __device__ void paged_attention_kernel( constexpr int x = 16 / sizeof(scalar_t); float qk_max = -FLT_MAX; + const int* block_table = block_tables + seq_idx * max_num_blocks_per_seq; + const int context_len = context_lens[seq_idx]; + const int num_blocks = (context_len + BLOCK_SIZE - 1) / BLOCK_SIZE; + // Iterate over the key blocks. // Each warp fetches a block of keys for each iteration. // Each thread group in a warp fetches a key from the block, and computes // dot product with the query. - const int* block_table = block_tables + seq_idx * max_num_blocks_per_seq; - for (int block_idx = start_block_idx + warp_idx; block_idx < end_block_idx; block_idx += NUM_WARPS) { + for (int block_idx = warp_idx; block_idx < num_blocks; block_idx += NUM_WARPS) { const int physical_block_number = block_table[block_idx]; // Load a key to registers. @@ -220,7 +196,7 @@ __device__ void paged_attention_kernel( // Store the partial reductions to shared memory. // NOTE(woosuk): It is required to zero out the masked logits. const bool mask = token_idx >= context_len; - logits[token_idx - start_token_idx] = mask ? 0.f : qk; + logits[token_idx] = mask ? 0.f : qk; // Update the max value. qk_max = mask ? qk_max : fmaxf(qk_max, qk); } @@ -235,7 +211,7 @@ __device__ void paged_attention_kernel( #ifndef USE_ROCM qk_max = fmaxf(qk_max, __shfl_xor_sync(uint32_t(-1), qk_max, mask)); #else - qk_max = fmaxf(qk_max, __shfl_xor(uint32_t(-1), qk_max, mask)); + qk_max = fmaxf(qk_max, __shfl_xor(qk_max, mask)); #endif } if (lane == 0) { @@ -251,19 +227,19 @@ __device__ void paged_attention_kernel( #ifndef USE_ROCM qk_max = fmaxf(qk_max, __shfl_xor_sync(uint32_t(-1), qk_max, mask)); #else - qk_max = fmaxf(qk_max, __shfl_xor(uint32_t(-1), qk_max, mask)); + qk_max = fmaxf(qk_max, __shfl_xor(qk_max, mask)); #endif } // Broadcast the max qk value to all threads. #ifndef USE_ROCM qk_max = __shfl_sync(uint32_t(-1), qk_max, 0); #else - qk_max = __shfl(uint32_t(-1), qk_max, 0); + qk_max = __shfl(qk_max, 0); #endif // Get the sum of the exp values. float exp_sum = 0.f; - for (int i = thread_idx; i < num_tokens; i += NUM_THREADS) { + for (int i = thread_idx; i < context_len; i += NUM_THREADS) { float val = __expf(logits[i] - qk_max); logits[i] = val; exp_sum += val; @@ -272,23 +248,11 @@ __device__ void paged_attention_kernel( // Compute softmax. const float inv_sum = __fdividef(1.f, exp_sum + 1e-6f); - for (int i = thread_idx; i < num_tokens; i += NUM_THREADS) { + for (int i = thread_idx; i < context_len; i += NUM_THREADS) { logits[i] *= inv_sum; } __syncthreads(); - // If partitioning is enabled, store the max logit and exp_sum. - if (USE_PARTITIONING && thread_idx == 0) { - float* max_logits_ptr = max_logits + seq_idx * num_heads * max_num_partitions - + head_idx * max_num_partitions - + partition_idx; - *max_logits_ptr = qk_max; - float* exp_sums_ptr = exp_sums + seq_idx * num_heads * max_num_partitions - + head_idx * max_num_partitions - + partition_idx; - *exp_sums_ptr = exp_sum; - } - // Each thread will fetch 16 bytes from the value cache at a time. constexpr int V_VEC_SIZE = MIN(16 / sizeof(scalar_t), BLOCK_SIZE); using V_vec = typename Vec::Type; @@ -297,7 +261,7 @@ __device__ void paged_attention_kernel( constexpr int NUM_V_VECS_PER_ROW = BLOCK_SIZE / V_VEC_SIZE; constexpr int NUM_ROWS_PER_ITER = WARP_SIZE / NUM_V_VECS_PER_ROW; - constexpr int NUM_ROWS_PER_THREAD = DIVIDE_ROUND_UP(HEAD_SIZE, NUM_ROWS_PER_ITER); + constexpr int NUM_ROWS_PER_THREAD = (HEAD_SIZE + NUM_ROWS_PER_ITER - 1) / NUM_ROWS_PER_ITER; // NOTE(woosuk): We use FP32 for the accumulator for better accuracy. float accs[NUM_ROWS_PER_THREAD]; @@ -308,12 +272,12 @@ __device__ void paged_attention_kernel( scalar_t zero_value; zero(zero_value); - for (int block_idx = start_block_idx + warp_idx; block_idx < end_block_idx; block_idx += NUM_WARPS) { + for (int block_idx = warp_idx; block_idx < num_blocks; block_idx += NUM_WARPS) { const int physical_block_number = block_table[block_idx]; const int physical_block_offset = (lane % NUM_V_VECS_PER_ROW) * V_VEC_SIZE; const int token_idx = block_idx * BLOCK_SIZE + physical_block_offset; L_vec logits_vec; - from_float(logits_vec, *reinterpret_cast(logits + token_idx - start_token_idx)); + from_float(logits_vec, *reinterpret_cast(logits + token_idx)); const scalar_t* v_ptr = v_cache + physical_block_number * kv_block_stride + kv_head_idx * kv_head_stride; @@ -323,7 +287,7 @@ __device__ void paged_attention_kernel( if (row_idx < HEAD_SIZE) { const int offset = row_idx * BLOCK_SIZE + physical_block_offset; V_vec v_vec = *reinterpret_cast(v_ptr + offset); - if (block_idx == num_context_blocks - 1) { + if (block_idx == num_blocks - 1) { // NOTE(woosuk): When v_vec contains the tokens that are out of the context, // we should explicitly zero out the values since they may contain NaNs. // See https://github.com/vllm-project/vllm/issues/641#issuecomment-1682544472 @@ -347,7 +311,7 @@ __device__ void paged_attention_kernel( #ifndef USE_ROCM acc += __shfl_xor_sync(uint32_t(-1), acc, mask); #else - acc += __shfl_xor(uint32_t(-1), acc, mask); + acc += __shfl_xor(acc, mask); #endif } accs[i] = acc; @@ -391,9 +355,7 @@ __device__ void paged_attention_kernel( // Write the final output. if (warp_idx == 0) { - scalar_t* out_ptr = out + seq_idx * num_heads * max_num_partitions * HEAD_SIZE - + head_idx * max_num_partitions * HEAD_SIZE - + partition_idx * HEAD_SIZE; + scalar_t* out_ptr = out + seq_idx * num_heads * HEAD_SIZE + head_idx * HEAD_SIZE; #pragma unroll for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) { const int row_idx = lane / NUM_V_VECS_PER_ROW + i * NUM_ROWS_PER_ITER; @@ -404,167 +366,13 @@ __device__ void paged_attention_kernel( } } -// Grid: (num_heads, num_seqs, 1). -template< - typename scalar_t, - int HEAD_SIZE, - int BLOCK_SIZE, - int NUM_THREADS> -__global__ void paged_attention_v1_kernel( - scalar_t* __restrict__ out, // [num_seqs, num_heads, head_size] - const scalar_t* __restrict__ q, // [num_seqs, num_heads, head_size] - const scalar_t* __restrict__ k_cache, // [num_blocks, num_kv_heads, head_size/x, block_size, x] - const scalar_t* __restrict__ v_cache, // [num_blocks, num_kv_heads, head_size, block_size] - const int* __restrict__ head_mapping, // [num_heads] - const float scale, - const int* __restrict__ block_tables, // [num_seqs, max_num_blocks_per_seq] - const int* __restrict__ context_lens, // [num_seqs] - const int max_num_blocks_per_seq, - const float* __restrict__ alibi_slopes, // [num_heads] - const int q_stride, - const int kv_block_stride, - const int kv_head_stride) { - paged_attention_kernel( - /* exp_sums */ nullptr, /* max_logits */ nullptr, - out, q, k_cache, v_cache, head_mapping, scale, block_tables, context_lens, - max_num_blocks_per_seq, alibi_slopes, q_stride, kv_block_stride, kv_head_stride); -} - -// Grid: (num_heads, num_seqs, max_num_partitions). -template< - typename scalar_t, - int HEAD_SIZE, - int BLOCK_SIZE, - int NUM_THREADS, - int PARTITION_SIZE> -__global__ void paged_attention_v2_kernel( - float* __restrict__ exp_sums, // [num_seqs, num_heads, max_num_partitions] - float* __restrict__ max_logits, // [num_seqs, num_heads, max_num_partitions] - scalar_t* __restrict__ tmp_out, // [num_seqs, num_heads, max_num_partitions, head_size] - const scalar_t* __restrict__ q, // [num_seqs, num_heads, head_size] - const scalar_t* __restrict__ k_cache, // [num_blocks, num_kv_heads, head_size/x, block_size, x] - const scalar_t* __restrict__ v_cache, // [num_blocks, num_kv_heads, head_size, block_size] - const int* __restrict__ head_mapping, // [num_heads] - const float scale, - const int* __restrict__ block_tables, // [num_seqs, max_num_blocks_per_seq] - const int* __restrict__ context_lens, // [num_seqs] - const int max_num_blocks_per_seq, - const float* __restrict__ alibi_slopes, // [num_heads] - const int q_stride, - const int kv_block_stride, - const int kv_head_stride) { - paged_attention_kernel( - exp_sums, max_logits, tmp_out, q, k_cache, v_cache, head_mapping, scale, - block_tables, context_lens, max_num_blocks_per_seq, alibi_slopes, - q_stride, kv_block_stride, kv_head_stride); -} - -// Grid: (num_heads, num_seqs). -template< - typename scalar_t, - int HEAD_SIZE, - int NUM_THREADS, - int PARTITION_SIZE> -__global__ void paged_attention_v2_reduce_kernel( - scalar_t* __restrict__ out, // [num_seqs, num_heads, head_size] - const float* __restrict__ exp_sums, // [num_seqs, num_heads, max_num_partitions] - const float* __restrict__ max_logits, // [num_seqs, num_heads, max_num_partitions] - const scalar_t* __restrict__ tmp_out, // [num_seqs, num_heads, max_num_partitions, head_size] - const int* __restrict__ context_lens, // [num_seqs] - const int max_num_partitions) { - const int num_heads = gridDim.x; - const int head_idx = blockIdx.x; - const int seq_idx = blockIdx.y; - const int context_len = context_lens[seq_idx]; - const int num_partitions = DIVIDE_ROUND_UP(context_len, PARTITION_SIZE); - if (num_partitions == 1) { - // No need to reduce. Only copy tmp_out to out. - scalar_t* out_ptr = out + seq_idx * num_heads * HEAD_SIZE + head_idx * HEAD_SIZE; - const scalar_t* tmp_out_ptr = tmp_out + seq_idx * num_heads * max_num_partitions * HEAD_SIZE - + head_idx * max_num_partitions * HEAD_SIZE; - for (int i = threadIdx.x; i < HEAD_SIZE; i += blockDim.x) { - out_ptr[i] = tmp_out_ptr[i]; - } - // Terminate the thread block. - return; - } - - constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE; - const int warp_idx = threadIdx.x / WARP_SIZE; - const int lane = threadIdx.x % WARP_SIZE; - - // Size: 2 * num_partitions. - extern __shared__ char shared_mem[]; - // Workspace for reduction. - __shared__ float red_smem[2 * NUM_WARPS]; - - // Load max logits to shared memory. - float* shared_max_logits = reinterpret_cast(shared_mem); - const float* max_logits_ptr = max_logits + seq_idx * num_heads * max_num_partitions - + head_idx * max_num_partitions; - float max_logit = -FLT_MAX; - for (int i = threadIdx.x; i < num_partitions; i += blockDim.x) { - const float l = max_logits_ptr[i]; - shared_max_logits[i] = l; - max_logit = fmaxf(max_logit, l); - } - __syncthreads(); - - // Get the global max logit. - // Reduce within the warp. -#pragma unroll - for (int mask = WARP_SIZE / 2; mask >= 1; mask /= 2) { - max_logit = fmaxf(max_logit, __shfl_xor_sync(uint32_t(-1), max_logit, mask)); - } - if (lane == 0) { - red_smem[warp_idx] = max_logit; - } - __syncthreads(); - // Reduce across warps. - max_logit = lane < NUM_WARPS ? red_smem[lane] : -FLT_MAX; -#pragma unroll - for (int mask = NUM_WARPS / 2; mask >= 1; mask /= 2) { - max_logit = fmaxf(max_logit, __shfl_xor_sync(uint32_t(-1), max_logit, mask)); - } - // Broadcast the max value to all threads. - max_logit = __shfl_sync(uint32_t(-1), max_logit, 0); - - // Load rescaled exp sums to shared memory. - float* shared_exp_sums = reinterpret_cast(shared_mem + sizeof(float) * num_partitions); - const float* exp_sums_ptr = exp_sums + seq_idx * num_heads * max_num_partitions - + head_idx * max_num_partitions; - float global_exp_sum = 0.0f; - for (int i = threadIdx.x; i < num_partitions; i += blockDim.x) { - float l = shared_max_logits[i]; - float rescaled_exp_sum = exp_sums_ptr[i] * expf(l - max_logit); - global_exp_sum += rescaled_exp_sum; - shared_exp_sums[i] = rescaled_exp_sum; - } - __syncthreads(); - global_exp_sum = block_sum(&red_smem[NUM_WARPS], global_exp_sum); - const float inv_global_exp_sum = __fdividef(1.0f, global_exp_sum + 1e-6f); - - // Aggregate tmp_out to out. - const scalar_t* tmp_out_ptr = tmp_out + seq_idx * num_heads * max_num_partitions * HEAD_SIZE - + head_idx * max_num_partitions * HEAD_SIZE; - scalar_t* out_ptr = out + seq_idx * num_heads * HEAD_SIZE + head_idx * HEAD_SIZE; -#pragma unroll - for (int i = threadIdx.x; i < HEAD_SIZE; i += NUM_THREADS) { - float acc = 0.0f; - for (int j = 0; j < num_partitions; ++j) { - acc += to_float(tmp_out_ptr[j * HEAD_SIZE + i]) * shared_exp_sums[j] * inv_global_exp_sum; - } - from_float(out_ptr[i], acc); - } -} - } // namespace vllm -#define LAUNCH_PAGED_ATTENTION_V1(HEAD_SIZE) \ +#define LAUNCH_ATTENTION_KERNEL(T, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS) \ cudaFuncSetAttribute( \ - (void*)vllm::paged_attention_v1_kernel, \ - cudaFuncAttributeMaxDynamicSharedMemorySize, shared_mem_size); \ - vllm::paged_attention_v1_kernel \ + (void*)vllm::single_query_cached_kv_attention_kernel, \ + cudaFuncAttributeMaxDynamicSharedMemorySize, shared_mem_size); \ + vllm::single_query_cached_kv_attention_kernel \ <<>>( \ out_ptr, \ query_ptr, \ @@ -585,7 +393,7 @@ template< typename T, int BLOCK_SIZE, int NUM_THREADS = 128> -void paged_attention_v1_launcher( +void single_query_cached_kv_attention_launcher( torch::Tensor& out, torch::Tensor& query, torch::Tensor& key_cache, @@ -621,206 +429,45 @@ void paged_attention_v1_launcher( int* context_lens_ptr = context_lens.data_ptr(); constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE; - int padded_max_context_len = DIVIDE_ROUND_UP(max_context_len, BLOCK_SIZE) * BLOCK_SIZE; + int padded_max_context_len = ((max_context_len + BLOCK_SIZE - 1) / BLOCK_SIZE) * BLOCK_SIZE; int logits_size = padded_max_context_len * sizeof(float); int outputs_size = (NUM_WARPS / 2) * head_size * sizeof(float); // Python-side check in vllm.worker.worker._check_if_can_support_max_seq_len // Keep that in sync with the logic here! int shared_mem_size = std::max(logits_size, outputs_size); - dim3 grid(num_heads, num_seqs, 1); - dim3 block(NUM_THREADS); - const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); - switch (head_size) { - // NOTE(woosuk): To reduce the compilation time, we only compile for the - // head sizes that we use in the model. However, we can easily extend this - // to support any head size which is a multiple of 16. - case 64: - LAUNCH_PAGED_ATTENTION_V1(64); - break; - case 80: - LAUNCH_PAGED_ATTENTION_V1(80); - break; - case 96: - LAUNCH_PAGED_ATTENTION_V1(96); - break; - case 112: - LAUNCH_PAGED_ATTENTION_V1(112); - break; - case 128: - LAUNCH_PAGED_ATTENTION_V1(128); - break; - case 256: - LAUNCH_PAGED_ATTENTION_V1(256); - break; - default: - TORCH_CHECK(false, "Unsupported head size: ", head_size); - break; - } -} - -#define CALL_V1_LAUNCHER(T, BLOCK_SIZE) \ - paged_attention_v1_launcher( \ - out, \ - query, \ - key_cache, \ - value_cache, \ - head_mapping, \ - scale, \ - block_tables, \ - context_lens, \ - max_context_len, \ - alibi_slopes); - -// NOTE(woosuk): To reduce the compilation time, we omitted block sizes -// 1, 2, 4, 64, 128, 256. -#define CALL_V1_LAUNCHER_BLOCK_SIZE(T) \ - switch (block_size) { \ - case 8: \ - CALL_V1_LAUNCHER(T, 8); \ - break; \ - case 16: \ - CALL_V1_LAUNCHER(T, 16); \ - break; \ - case 32: \ - CALL_V1_LAUNCHER(T, 32); \ - break; \ - default: \ - TORCH_CHECK(false, "Unsupported block size: ", block_size); \ - break; \ - } - -void paged_attention_v1( - torch::Tensor& out, // [num_seqs, num_heads, head_size] - torch::Tensor& query, // [num_seqs, num_heads, head_size] - torch::Tensor& key_cache, // [num_blocks, num_heads, head_size/x, block_size, x] - torch::Tensor& value_cache, // [num_blocks, num_heads, head_size, block_size] - torch::Tensor& head_mapping, // [num_heads] - float scale, - torch::Tensor& block_tables, // [num_seqs, max_num_blocks_per_seq] - torch::Tensor& context_lens, // [num_seqs] - int block_size, - int max_context_len, - const c10::optional& alibi_slopes) { - if (query.dtype() == at::ScalarType::Float) { - CALL_V1_LAUNCHER_BLOCK_SIZE(float); - } else if (query.dtype() == at::ScalarType::Half) { - CALL_V1_LAUNCHER_BLOCK_SIZE(uint16_t); - } else if (query.dtype() == at::ScalarType::BFloat16) { - CALL_V1_LAUNCHER_BLOCK_SIZE(__nv_bfloat16); - } else { - TORCH_CHECK(false, "Unsupported data type: ", query.dtype()); - } -} - -#define LAUNCH_PAGED_ATTENTION_V2(HEAD_SIZE) \ - vllm::paged_attention_v2_kernel \ - <<>>( \ - exp_sums_ptr, \ - max_logits_ptr, \ - tmp_out_ptr, \ - query_ptr, \ - key_cache_ptr, \ - value_cache_ptr, \ - head_mapping_ptr, \ - scale, \ - block_tables_ptr, \ - context_lens_ptr, \ - max_num_blocks_per_seq, \ - alibi_slopes_ptr, \ - q_stride, \ - kv_block_stride, \ - kv_head_stride); \ - vllm::paged_attention_v2_reduce_kernel \ - <<>>( \ - out_ptr, \ - exp_sums_ptr, \ - max_logits_ptr, \ - tmp_out_ptr, \ - context_lens_ptr, \ - max_num_partitions); - -template< - typename T, - int BLOCK_SIZE, - int NUM_THREADS = 128, - int PARTITION_SIZE = 512> -void paged_attention_v2_launcher( - torch::Tensor& out, - torch::Tensor& exp_sums, - torch::Tensor& max_logits, - torch::Tensor& tmp_out, - torch::Tensor& query, - torch::Tensor& key_cache, - torch::Tensor& value_cache, - torch::Tensor& head_mapping, - float scale, - torch::Tensor& block_tables, - torch::Tensor& context_lens, - int max_context_len, - const c10::optional& alibi_slopes) { - int num_seqs = query.size(0); - int num_heads = query.size(1); - int head_size = query.size(2); - int max_num_blocks_per_seq = block_tables.size(1); - int q_stride = query.stride(0); - int kv_block_stride = key_cache.stride(0); - int kv_head_stride = key_cache.stride(1); - - int thread_group_size = MAX(WARP_SIZE / BLOCK_SIZE, 1); - assert(head_size % thread_group_size == 0); - - // NOTE: alibi_slopes is optional. - const float* alibi_slopes_ptr = alibi_slopes ? - reinterpret_cast(alibi_slopes.value().data_ptr()) - : nullptr; - - T* out_ptr = reinterpret_cast(out.data_ptr()); - float* exp_sums_ptr = reinterpret_cast(exp_sums.data_ptr()); - float* max_logits_ptr = reinterpret_cast(max_logits.data_ptr()); - T* tmp_out_ptr = reinterpret_cast(tmp_out.data_ptr()); - T* query_ptr = reinterpret_cast(query.data_ptr()); - T* key_cache_ptr = reinterpret_cast(key_cache.data_ptr()); - T* value_cache_ptr = reinterpret_cast(value_cache.data_ptr()); - int* head_mapping_ptr = reinterpret_cast(head_mapping.data_ptr()); - int* block_tables_ptr = block_tables.data_ptr(); - int* context_lens_ptr = context_lens.data_ptr(); - - constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE; - int max_num_partitions = DIVIDE_ROUND_UP(max_context_len, PARTITION_SIZE); - int logits_size = PARTITION_SIZE * sizeof(float); - int outputs_size = (NUM_WARPS / 2) * head_size * sizeof(float); - - // For paged attention v2 kernel. - dim3 grid(num_heads, num_seqs, max_num_partitions); - int shared_mem_size = std::max(logits_size, outputs_size); - // For paged attention v2 reduce kernel. - dim3 reduce_grid(num_heads, num_seqs); - int reduce_shared_mem_size = 2 * max_num_partitions * sizeof(float); - + dim3 grid(num_heads, num_seqs); dim3 block(NUM_THREADS); const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); switch (head_size) { - // NOTE(woosuk): To reduce the compilation time, we only compile for the - // head sizes that we use in the model. However, we can easily extend this - // to support any head size which is a multiple of 16. + // NOTE(woosuk): To reduce the compilation time, we omitted head sizes + // 32, 160, 192. + // case 32: + // LAUNCH_ATTENTION_KERNEL(T, 32, BLOCK_SIZE, NUM_THREADS); + // break; case 64: - LAUNCH_PAGED_ATTENTION_V2(64); + LAUNCH_ATTENTION_KERNEL(T, 64, BLOCK_SIZE, NUM_THREADS); break; case 80: - LAUNCH_PAGED_ATTENTION_V2(80); + LAUNCH_ATTENTION_KERNEL(T, 80, BLOCK_SIZE, NUM_THREADS); break; case 96: - LAUNCH_PAGED_ATTENTION_V2(96); + LAUNCH_ATTENTION_KERNEL(T, 96, BLOCK_SIZE, NUM_THREADS); break; case 112: - LAUNCH_PAGED_ATTENTION_V2(112); + LAUNCH_ATTENTION_KERNEL(T, 112, BLOCK_SIZE, NUM_THREADS); break; case 128: - LAUNCH_PAGED_ATTENTION_V2(128); + LAUNCH_ATTENTION_KERNEL(T, 128, BLOCK_SIZE, NUM_THREADS); break; + // case 160: + // LAUNCH_ATTENTION_KERNEL(T, 160, BLOCK_SIZE, NUM_THREADS); + // break; + // case 192: + // LAUNCH_ATTENTION_KERNEL(T, 192, BLOCK_SIZE, NUM_THREADS); + // break; case 256: - LAUNCH_PAGED_ATTENTION_V2(256); + LAUNCH_ATTENTION_KERNEL(T, 256, BLOCK_SIZE, NUM_THREADS); break; default: TORCH_CHECK(false, "Unsupported head size: ", head_size); @@ -828,12 +475,9 @@ void paged_attention_v2_launcher( } } -#define CALL_V2_LAUNCHER(T, BLOCK_SIZE) \ - paged_attention_v2_launcher( \ +#define CALL_KERNEL_LAUNCHER(T, BLOCK_SIZE) \ + single_query_cached_kv_attention_launcher( \ out, \ - exp_sums, \ - max_logits, \ - tmp_out, \ query, \ key_cache, \ value_cache, \ @@ -846,27 +490,42 @@ void paged_attention_v2_launcher( // NOTE(woosuk): To reduce the compilation time, we omitted block sizes // 1, 2, 4, 64, 128, 256. -#define CALL_V2_LAUNCHER_BLOCK_SIZE(T) \ +#define CALL_KERNEL_LAUNCHER_BLOCK_SIZE(T) \ switch (block_size) { \ + /* case 1: */ \ + /* CALL_KERNEL_LAUNCHER(T, 1); */ \ + /* break; */ \ + /* case 2: */ \ + /* CALL_KERNEL_LAUNCHER(T, 2); */ \ + /* break; */ \ + /* case 4: */ \ + /* CALL_KERNEL_LAUNCHER(T, 4); */ \ + /* break; */ \ case 8: \ - CALL_V2_LAUNCHER(T, 8); \ + CALL_KERNEL_LAUNCHER(T, 8); \ break; \ case 16: \ - CALL_V2_LAUNCHER(T, 16); \ + CALL_KERNEL_LAUNCHER(T, 16); \ break; \ case 32: \ - CALL_V2_LAUNCHER(T, 32); \ + CALL_KERNEL_LAUNCHER(T, 32); \ break; \ + /* case 64: */ \ + /* CALL_KERNEL_LAUNCHER(T, 64); */ \ + /* break; */ \ + /* case 128: */ \ + /* CALL_KERNEL_LAUNCHER(T, 128); */ \ + /* break; */ \ + /* case 256: */ \ + /* CALL_KERNEL_LAUNCHER(T, 256); */ \ + /* break; */ \ default: \ TORCH_CHECK(false, "Unsupported block size: ", block_size); \ break; \ } -void paged_attention_v2( +void single_query_cached_kv_attention( torch::Tensor& out, // [num_seqs, num_heads, head_size] - torch::Tensor& exp_sums, // [num_seqs, num_heads, max_num_partitions] - torch::Tensor& max_logits, // [num_seqs, num_heads, max_num_partitions] - torch::Tensor& tmp_out, // [num_seqs, num_heads, max_num_partitions, head_size] torch::Tensor& query, // [num_seqs, num_heads, head_size] torch::Tensor& key_cache, // [num_blocks, num_heads, head_size/x, block_size, x] torch::Tensor& value_cache, // [num_blocks, num_heads, head_size, block_size] @@ -878,11 +537,11 @@ void paged_attention_v2( int max_context_len, const c10::optional& alibi_slopes) { if (query.dtype() == at::ScalarType::Float) { - CALL_V2_LAUNCHER_BLOCK_SIZE(float); + CALL_KERNEL_LAUNCHER_BLOCK_SIZE(float); } else if (query.dtype() == at::ScalarType::Half) { - CALL_V2_LAUNCHER_BLOCK_SIZE(uint16_t); + CALL_KERNEL_LAUNCHER_BLOCK_SIZE(uint16_t); } else if (query.dtype() == at::ScalarType::BFloat16) { - CALL_V2_LAUNCHER_BLOCK_SIZE(__nv_bfloat16); + CALL_KERNEL_LAUNCHER_BLOCK_SIZE(__nv_bfloat16); } else { TORCH_CHECK(false, "Unsupported data type: ", query.dtype()); } @@ -891,4 +550,3 @@ void paged_attention_v2( #undef WARP_SIZE #undef MAX #undef MIN -#undef DIVIDE_ROUND_UP diff --git a/csrc/attention/attention_utils.cuh b/csrc/attention/attention_utils.cuh index 1c3ea93694147..7e6b64eea96f4 100644 --- a/csrc/attention/attention_utils.cuh +++ b/csrc/attention/attention_utils.cuh @@ -42,7 +42,7 @@ inline __device__ float qk_dot_(const Vec (&q)[N], const Vec (&k)[N]) { #ifndef USE_ROCM qk += __shfl_xor_sync(uint32_t(-1), qk, mask); #else - qk += __shfl_xor(uint32_t(-1), qk, mask); + qk += __shfl_xor(qk, mask); #endif } return qk; diff --git a/csrc/attention/dtype_bfloat16.cuh b/csrc/attention/dtype_bfloat16.cuh index 7f2b29de0d932..9ad2e299c7aab 100644 --- a/csrc/attention/dtype_bfloat16.cuh +++ b/csrc/attention/dtype_bfloat16.cuh @@ -439,11 +439,6 @@ inline __device__ void from_float(bf16_8_t& dst, Float8_ src) { #endif } -// From bfloat16 to float32. -inline __device__ float to_float(__nv_bfloat16 u) { - return __bfloat162float(u); -} - // Zero-out a variable. inline __device__ void zero(__nv_bfloat16& dst) { #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 diff --git a/csrc/attention/dtype_float16.cuh b/csrc/attention/dtype_float16.cuh index 079fa607f96b7..dc45dbf3daea6 100644 --- a/csrc/attention/dtype_float16.cuh +++ b/csrc/attention/dtype_float16.cuh @@ -92,11 +92,15 @@ inline __device__ float2 half2_to_float2(uint32_t v) { #ifndef USE_ROCM uint16_t lo, hi; asm volatile("mov.b32 {%0, %1}, %2;\n" : "=h"(lo), "=h"(hi) : "r"(v)); + return make_float2(half_to_float(lo), half_to_float(hi)); #else - uint16_t hi = (v >> 16) & 0xFFFF; - uint16_t lo = v & 0xFFFF; + union { + __half2 h2; + uint32_t u32; + } V; + V.u32 = v; + return make_float2(half_to_float(V.h2.x), half_to_float(V.h2.y)); #endif - return make_float2(half_to_float(lo), half_to_float(hi)); } inline __device__ uint16_t float_to_half(float f) { @@ -113,24 +117,29 @@ inline __device__ uint16_t float_to_half(float f) { } inline __device__ uint32_t float2_to_half2(float2 f) { +#ifndef USE_ROCM union { uint32_t u32; uint16_t u16[2]; } tmp; -#ifndef USE_ROCM #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 asm volatile("cvt.rn.f16x2.f32 %0, %1, %2;\n" : "=r"(tmp.u32) : "f"(f.y), "f"(f.x)); #else asm volatile("cvt.rn.f16.f32 %0, %1;\n" : "=h"(tmp.u16[0]) : "f"(f.x)); asm volatile("cvt.rn.f16.f32 %0, %1;\n" : "=h"(tmp.u16[1]) : "f"(f.y)); #endif + return tmp.u32; #else - __half2 h = __float22half2_rn(f); - tmp.u16[0] = h.x; - tmp.u16[1] = h.y; + union { + __half2 h2; + uint32_t u32; + } R; + + R.h2.x = __half_as_ushort(__float2half_rn(f.x)); + R.h2.y = __half_as_ushort(__float2half_rn(f.y)); + return R.u32; #endif - return tmp.u32; } // Vector addition. @@ -150,15 +159,14 @@ inline __device__ uint32_t add(uint32_t a, uint32_t b) { asm volatile("add.f16x2 %0, %1, %2;\n" : "=r"(c) : "r"(a), "r"(b)); return c; #else - __half2 h = __hadd2(a, b); union { + __half2 h2; uint32_t u32; - uint16_t u16[2]; - } tmp; - tmp.u16[0] = h.x; - tmp.u16[1] = h.y; - - return tmp.u32; + } A, B, C; + A.u32 = a; + B.u32 = b; + C.h2 = __hadd2(A.h2, B.h2); + return C.u32; #endif } @@ -218,15 +226,14 @@ inline __device__ uint32_t mul(uint32_t a, uint32_t b) { asm volatile("mul.f16x2 %0, %1, %2;\n" : "=r"(c) : "r"(a), "r"(b)); return c; #else - __half2 h = __hmul2(a, b); - union { + union { + __half2 h2; uint32_t u32; - uint16_t u16[2]; - } tmp; - tmp.u16[0] = h.x; - tmp.u16[1] = h.y; - - return tmp.u32; + } A, B, C; + A.u32 = a; + B.u32 = b; + C.h2 = __hmul2(A.h2, B.h2); + return C.u32; #endif } @@ -332,9 +339,21 @@ inline __device__ Float8_ mul(uint16_t a, uint4 b) { // Vector fused multiply-add. inline __device__ uint32_t fma(uint32_t a, uint32_t b, uint32_t c) { +#ifndef USE_ROCM uint32_t d; asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(d) : "r"(a), "r"(b), "r"(c)); return d; +#else + union { + __half2 h2; + uint32_t u32; + } A, B, C, D; + A.u32 = a; + B.u32 = b; + C.u32 = c; + D.h2 = __hfma2(A.h2, B.h2, C.h2); + return D.u32; +#endif } inline __device__ uint32_t fma(uint16_t a, uint32_t b, uint32_t c) { From 41df6890a0398ed8535eeb644e70be78825cdd2c Mon Sep 17 00:00:00 2001 From: Philipp Moritz Date: Mon, 16 Oct 2023 20:31:05 +0000 Subject: [PATCH 08/94] update --- csrc/attention/attention_kernels.cu | 30 +---------------------------- 1 file changed, 1 insertion(+), 29 deletions(-) diff --git a/csrc/attention/attention_kernels.cu b/csrc/attention/attention_kernels.cu index debde463786ef..505c63d2efd78 100644 --- a/csrc/attention/attention_kernels.cu +++ b/csrc/attention/attention_kernels.cu @@ -39,11 +39,7 @@ inline __device__ float block_sum(float* red_smem, float sum) { // Compute the sum per warp. #pragma unroll for (int mask = WARP_SIZE / 2; mask >= 1; mask /= 2) { -#ifndef USE_ROCM sum += __shfl_xor_sync(uint32_t(-1), sum, mask); -#else - sum += __shfl_xor(sum, mask); -#endif } // Warp leaders store the data to shared memory. @@ -62,19 +58,11 @@ inline __device__ float block_sum(float* red_smem, float sum) { // Parallel reduction inside the warp. #pragma unroll for (int mask = NUM_WARPS / 2; mask >= 1; mask /= 2) { -#ifndef USE_ROCM sum += __shfl_xor_sync(uint32_t(-1), sum, mask); -#else - sum += __shfl_xor(sum, mask); -#endif } // Broadcast to other threads. -#ifndef USE_ROCM return __shfl_sync(uint32_t(-1), sum, 0); -#else - return __shfl(sum, 0); -#endif } // Grid: (num_heads, num_seqs). @@ -208,11 +196,7 @@ __global__ void single_query_cached_kv_attention_kernel( // The 0-th thread of each thread group already has its max qk value. #pragma unroll for (int mask = WARP_SIZE / 2; mask >= THREAD_GROUP_SIZE; mask /= 2) { -#ifndef USE_ROCM qk_max = fmaxf(qk_max, __shfl_xor_sync(uint32_t(-1), qk_max, mask)); -#else - qk_max = fmaxf(qk_max, __shfl_xor(qk_max, mask)); -#endif } if (lane == 0) { red_smem[warp_idx] = qk_max; @@ -224,18 +208,10 @@ __global__ void single_query_cached_kv_attention_kernel( qk_max = lane < NUM_WARPS ? red_smem[lane] : -FLT_MAX; #pragma unroll for (int mask = NUM_WARPS / 2; mask >= 1; mask /= 2) { -#ifndef USE_ROCM qk_max = fmaxf(qk_max, __shfl_xor_sync(uint32_t(-1), qk_max, mask)); -#else - qk_max = fmaxf(qk_max, __shfl_xor(qk_max, mask)); -#endif } // Broadcast the max qk value to all threads. -#ifndef USE_ROCM qk_max = __shfl_sync(uint32_t(-1), qk_max, 0); -#else - qk_max = __shfl(qk_max, 0); -#endif // Get the sum of the exp values. float exp_sum = 0.f; @@ -308,11 +284,7 @@ __global__ void single_query_cached_kv_attention_kernel( float acc = accs[i]; #pragma unroll for (int mask = NUM_V_VECS_PER_ROW / 2; mask >= 1; mask /= 2) { -#ifndef USE_ROCM acc += __shfl_xor_sync(uint32_t(-1), acc, mask); -#else - acc += __shfl_xor(acc, mask); -#endif } accs[i] = acc; } @@ -370,7 +342,7 @@ __global__ void single_query_cached_kv_attention_kernel( #define LAUNCH_ATTENTION_KERNEL(T, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS) \ cudaFuncSetAttribute( \ - (void*)vllm::single_query_cached_kv_attention_kernel, \ + vllm::single_query_cached_kv_attention_kernel, \ cudaFuncAttributeMaxDynamicSharedMemorySize, shared_mem_size); \ vllm::single_query_cached_kv_attention_kernel \ <<>>( \ From 93be9c5b32aa36b96a2376e41498e0efd0dbb329 Mon Sep 17 00:00:00 2001 From: Philipp Moritz Date: Mon, 16 Oct 2023 20:34:57 +0000 Subject: [PATCH 09/94] update --- csrc/attention/attention_kernels.cu | 484 ++++++++++++++++++++++++---- 1 file changed, 413 insertions(+), 71 deletions(-) diff --git a/csrc/attention/attention_kernels.cu b/csrc/attention/attention_kernels.cu index 505c63d2efd78..ee6b715adaef0 100644 --- a/csrc/attention/attention_kernels.cu +++ b/csrc/attention/attention_kernels.cu @@ -26,6 +26,7 @@ #define WARP_SIZE 32 #define MAX(a, b) ((a) > (b) ? (a) : (b)) #define MIN(a, b) ((a) < (b) ? (a) : (b)) +#define DIVIDE_ROUND_UP(a, b) (((a) + (b) - 1) / (b)) namespace vllm { @@ -65,14 +66,18 @@ inline __device__ float block_sum(float* red_smem, float sum) { return __shfl_sync(uint32_t(-1), sum, 0); } -// Grid: (num_heads, num_seqs). +// TODO(woosuk): Merge the last two dimensions of the grid. +// Grid: (num_heads, num_seqs, max_num_partitions). template< typename scalar_t, int HEAD_SIZE, int BLOCK_SIZE, - int NUM_THREADS> -__global__ void single_query_cached_kv_attention_kernel( - scalar_t* __restrict__ out, // [num_seqs, num_heads, head_size] + int NUM_THREADS, + int PARTITION_SIZE = 0> // Zero means no partitioning. +__device__ void paged_attention_kernel( + float* __restrict__ exp_sums, // [num_seqs, num_heads, max_num_partitions] + float* __restrict__ max_logits, // [num_seqs, num_heads, max_num_partitions] + scalar_t* __restrict__ out, // [num_seqs, num_heads, max_num_partitions, head_size] const scalar_t* __restrict__ q, // [num_seqs, num_heads, head_size] const scalar_t* __restrict__ k_cache, // [num_blocks, num_kv_heads, head_size/x, block_size, x] const scalar_t* __restrict__ v_cache, // [num_blocks, num_kv_heads, head_size, block_size] @@ -85,10 +90,33 @@ __global__ void single_query_cached_kv_attention_kernel( const int q_stride, const int kv_block_stride, const int kv_head_stride) { + const int seq_idx = blockIdx.y; + const int partition_idx = blockIdx.z; + const int max_num_partitions = gridDim.z; + constexpr bool USE_PARTITIONING = PARTITION_SIZE > 0; + const int context_len = context_lens[seq_idx]; + if (USE_PARTITIONING && partition_idx * PARTITION_SIZE >= context_len) { + // No work to do. Terminate the thread block. + return; + } + + const int num_context_blocks = DIVIDE_ROUND_UP(context_len, BLOCK_SIZE); + const int num_blocks_per_partition = USE_PARTITIONING ? PARTITION_SIZE / BLOCK_SIZE : num_context_blocks; + + // [start_block_idx, end_block_idx) is the range of blocks to process. + const int start_block_idx = USE_PARTITIONING ? partition_idx * num_blocks_per_partition : 0; + const int end_block_idx = MIN(start_block_idx + num_blocks_per_partition, num_context_blocks); + const int num_blocks = end_block_idx - start_block_idx; + + // [start_token_idx, end_token_idx) is the range of tokens to process. + const int start_token_idx = start_block_idx * BLOCK_SIZE; + const int end_token_idx = MIN(start_token_idx + num_blocks * BLOCK_SIZE, context_len); + const int num_tokens = end_token_idx - start_token_idx; + constexpr int THREAD_GROUP_SIZE = MAX(WARP_SIZE / BLOCK_SIZE, 1); constexpr int NUM_THREAD_GROUPS = NUM_THREADS / THREAD_GROUP_SIZE; // Note: This assumes THREAD_GROUP_SIZE divides NUM_THREADS assert(NUM_THREADS % THREAD_GROUP_SIZE == 0); - constexpr int NUM_TOKENS_PER_THREAD_GROUP = (BLOCK_SIZE + WARP_SIZE - 1) / WARP_SIZE; + constexpr int NUM_TOKENS_PER_THREAD_GROUP = DIVIDE_ROUND_UP(BLOCK_SIZE, WARP_SIZE); constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE; const int thread_idx = threadIdx.x; const int warp_idx = thread_idx / WARP_SIZE; @@ -97,7 +125,6 @@ __global__ void single_query_cached_kv_attention_kernel( const int head_idx = blockIdx.x; const int num_heads = gridDim.x; const int kv_head_idx = head_mapping[head_idx]; - const int seq_idx = blockIdx.y; const float alibi_slope = alibi_slopes == nullptr ? 0.f : alibi_slopes[head_idx]; // A vector type to store a part of a key or a query. @@ -142,15 +169,12 @@ __global__ void single_query_cached_kv_attention_kernel( constexpr int x = 16 / sizeof(scalar_t); float qk_max = -FLT_MAX; - const int* block_table = block_tables + seq_idx * max_num_blocks_per_seq; - const int context_len = context_lens[seq_idx]; - const int num_blocks = (context_len + BLOCK_SIZE - 1) / BLOCK_SIZE; - // Iterate over the key blocks. // Each warp fetches a block of keys for each iteration. // Each thread group in a warp fetches a key from the block, and computes // dot product with the query. - for (int block_idx = warp_idx; block_idx < num_blocks; block_idx += NUM_WARPS) { + const int* block_table = block_tables + seq_idx * max_num_blocks_per_seq; + for (int block_idx = start_block_idx + warp_idx; block_idx < end_block_idx; block_idx += NUM_WARPS) { const int physical_block_number = block_table[block_idx]; // Load a key to registers. @@ -184,7 +208,7 @@ __global__ void single_query_cached_kv_attention_kernel( // Store the partial reductions to shared memory. // NOTE(woosuk): It is required to zero out the masked logits. const bool mask = token_idx >= context_len; - logits[token_idx] = mask ? 0.f : qk; + logits[token_idx - start_token_idx] = mask ? 0.f : qk; // Update the max value. qk_max = mask ? qk_max : fmaxf(qk_max, qk); } @@ -215,7 +239,7 @@ __global__ void single_query_cached_kv_attention_kernel( // Get the sum of the exp values. float exp_sum = 0.f; - for (int i = thread_idx; i < context_len; i += NUM_THREADS) { + for (int i = thread_idx; i < num_tokens; i += NUM_THREADS) { float val = __expf(logits[i] - qk_max); logits[i] = val; exp_sum += val; @@ -224,11 +248,23 @@ __global__ void single_query_cached_kv_attention_kernel( // Compute softmax. const float inv_sum = __fdividef(1.f, exp_sum + 1e-6f); - for (int i = thread_idx; i < context_len; i += NUM_THREADS) { + for (int i = thread_idx; i < num_tokens; i += NUM_THREADS) { logits[i] *= inv_sum; } __syncthreads(); + // If partitioning is enabled, store the max logit and exp_sum. + if (USE_PARTITIONING && thread_idx == 0) { + float* max_logits_ptr = max_logits + seq_idx * num_heads * max_num_partitions + + head_idx * max_num_partitions + + partition_idx; + *max_logits_ptr = qk_max; + float* exp_sums_ptr = exp_sums + seq_idx * num_heads * max_num_partitions + + head_idx * max_num_partitions + + partition_idx; + *exp_sums_ptr = exp_sum; + } + // Each thread will fetch 16 bytes from the value cache at a time. constexpr int V_VEC_SIZE = MIN(16 / sizeof(scalar_t), BLOCK_SIZE); using V_vec = typename Vec::Type; @@ -237,7 +273,7 @@ __global__ void single_query_cached_kv_attention_kernel( constexpr int NUM_V_VECS_PER_ROW = BLOCK_SIZE / V_VEC_SIZE; constexpr int NUM_ROWS_PER_ITER = WARP_SIZE / NUM_V_VECS_PER_ROW; - constexpr int NUM_ROWS_PER_THREAD = (HEAD_SIZE + NUM_ROWS_PER_ITER - 1) / NUM_ROWS_PER_ITER; + constexpr int NUM_ROWS_PER_THREAD = DIVIDE_ROUND_UP(HEAD_SIZE, NUM_ROWS_PER_ITER); // NOTE(woosuk): We use FP32 for the accumulator for better accuracy. float accs[NUM_ROWS_PER_THREAD]; @@ -248,12 +284,12 @@ __global__ void single_query_cached_kv_attention_kernel( scalar_t zero_value; zero(zero_value); - for (int block_idx = warp_idx; block_idx < num_blocks; block_idx += NUM_WARPS) { + for (int block_idx = start_block_idx + warp_idx; block_idx < end_block_idx; block_idx += NUM_WARPS) { const int physical_block_number = block_table[block_idx]; const int physical_block_offset = (lane % NUM_V_VECS_PER_ROW) * V_VEC_SIZE; const int token_idx = block_idx * BLOCK_SIZE + physical_block_offset; L_vec logits_vec; - from_float(logits_vec, *reinterpret_cast(logits + token_idx)); + from_float(logits_vec, *reinterpret_cast(logits + token_idx - start_token_idx)); const scalar_t* v_ptr = v_cache + physical_block_number * kv_block_stride + kv_head_idx * kv_head_stride; @@ -263,7 +299,7 @@ __global__ void single_query_cached_kv_attention_kernel( if (row_idx < HEAD_SIZE) { const int offset = row_idx * BLOCK_SIZE + physical_block_offset; V_vec v_vec = *reinterpret_cast(v_ptr + offset); - if (block_idx == num_blocks - 1) { + if (block_idx == num_context_blocks - 1) { // NOTE(woosuk): When v_vec contains the tokens that are out of the context, // we should explicitly zero out the values since they may contain NaNs. // See https://github.com/vllm-project/vllm/issues/641#issuecomment-1682544472 @@ -327,7 +363,9 @@ __global__ void single_query_cached_kv_attention_kernel( // Write the final output. if (warp_idx == 0) { - scalar_t* out_ptr = out + seq_idx * num_heads * HEAD_SIZE + head_idx * HEAD_SIZE; + scalar_t* out_ptr = out + seq_idx * num_heads * max_num_partitions * HEAD_SIZE + + head_idx * max_num_partitions * HEAD_SIZE + + partition_idx * HEAD_SIZE; #pragma unroll for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) { const int row_idx = lane / NUM_V_VECS_PER_ROW + i * NUM_ROWS_PER_ITER; @@ -338,13 +376,167 @@ __global__ void single_query_cached_kv_attention_kernel( } } +// Grid: (num_heads, num_seqs, 1). +template< + typename scalar_t, + int HEAD_SIZE, + int BLOCK_SIZE, + int NUM_THREADS> +__global__ void paged_attention_v1_kernel( + scalar_t* __restrict__ out, // [num_seqs, num_heads, head_size] + const scalar_t* __restrict__ q, // [num_seqs, num_heads, head_size] + const scalar_t* __restrict__ k_cache, // [num_blocks, num_kv_heads, head_size/x, block_size, x] + const scalar_t* __restrict__ v_cache, // [num_blocks, num_kv_heads, head_size, block_size] + const int* __restrict__ head_mapping, // [num_heads] + const float scale, + const int* __restrict__ block_tables, // [num_seqs, max_num_blocks_per_seq] + const int* __restrict__ context_lens, // [num_seqs] + const int max_num_blocks_per_seq, + const float* __restrict__ alibi_slopes, // [num_heads] + const int q_stride, + const int kv_block_stride, + const int kv_head_stride) { + paged_attention_kernel( + /* exp_sums */ nullptr, /* max_logits */ nullptr, + out, q, k_cache, v_cache, head_mapping, scale, block_tables, context_lens, + max_num_blocks_per_seq, alibi_slopes, q_stride, kv_block_stride, kv_head_stride); +} + +// Grid: (num_heads, num_seqs, max_num_partitions). +template< + typename scalar_t, + int HEAD_SIZE, + int BLOCK_SIZE, + int NUM_THREADS, + int PARTITION_SIZE> +__global__ void paged_attention_v2_kernel( + float* __restrict__ exp_sums, // [num_seqs, num_heads, max_num_partitions] + float* __restrict__ max_logits, // [num_seqs, num_heads, max_num_partitions] + scalar_t* __restrict__ tmp_out, // [num_seqs, num_heads, max_num_partitions, head_size] + const scalar_t* __restrict__ q, // [num_seqs, num_heads, head_size] + const scalar_t* __restrict__ k_cache, // [num_blocks, num_kv_heads, head_size/x, block_size, x] + const scalar_t* __restrict__ v_cache, // [num_blocks, num_kv_heads, head_size, block_size] + const int* __restrict__ head_mapping, // [num_heads] + const float scale, + const int* __restrict__ block_tables, // [num_seqs, max_num_blocks_per_seq] + const int* __restrict__ context_lens, // [num_seqs] + const int max_num_blocks_per_seq, + const float* __restrict__ alibi_slopes, // [num_heads] + const int q_stride, + const int kv_block_stride, + const int kv_head_stride) { + paged_attention_kernel( + exp_sums, max_logits, tmp_out, q, k_cache, v_cache, head_mapping, scale, + block_tables, context_lens, max_num_blocks_per_seq, alibi_slopes, + q_stride, kv_block_stride, kv_head_stride); +} + +// Grid: (num_heads, num_seqs). +template< + typename scalar_t, + int HEAD_SIZE, + int NUM_THREADS, + int PARTITION_SIZE> +__global__ void paged_attention_v2_reduce_kernel( + scalar_t* __restrict__ out, // [num_seqs, num_heads, head_size] + const float* __restrict__ exp_sums, // [num_seqs, num_heads, max_num_partitions] + const float* __restrict__ max_logits, // [num_seqs, num_heads, max_num_partitions] + const scalar_t* __restrict__ tmp_out, // [num_seqs, num_heads, max_num_partitions, head_size] + const int* __restrict__ context_lens, // [num_seqs] + const int max_num_partitions) { + const int num_heads = gridDim.x; + const int head_idx = blockIdx.x; + const int seq_idx = blockIdx.y; + const int context_len = context_lens[seq_idx]; + const int num_partitions = DIVIDE_ROUND_UP(context_len, PARTITION_SIZE); + if (num_partitions == 1) { + // No need to reduce. Only copy tmp_out to out. + scalar_t* out_ptr = out + seq_idx * num_heads * HEAD_SIZE + head_idx * HEAD_SIZE; + const scalar_t* tmp_out_ptr = tmp_out + seq_idx * num_heads * max_num_partitions * HEAD_SIZE + + head_idx * max_num_partitions * HEAD_SIZE; + for (int i = threadIdx.x; i < HEAD_SIZE; i += blockDim.x) { + out_ptr[i] = tmp_out_ptr[i]; + } + // Terminate the thread block. + return; + } + + constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE; + const int warp_idx = threadIdx.x / WARP_SIZE; + const int lane = threadIdx.x % WARP_SIZE; + + // Size: 2 * num_partitions. + extern __shared__ char shared_mem[]; + // Workspace for reduction. + __shared__ float red_smem[2 * NUM_WARPS]; + + // Load max logits to shared memory. + float* shared_max_logits = reinterpret_cast(shared_mem); + const float* max_logits_ptr = max_logits + seq_idx * num_heads * max_num_partitions + + head_idx * max_num_partitions; + float max_logit = -FLT_MAX; + for (int i = threadIdx.x; i < num_partitions; i += blockDim.x) { + const float l = max_logits_ptr[i]; + shared_max_logits[i] = l; + max_logit = fmaxf(max_logit, l); + } + __syncthreads(); + + // Get the global max logit. + // Reduce within the warp. +#pragma unroll + for (int mask = WARP_SIZE / 2; mask >= 1; mask /= 2) { + max_logit = fmaxf(max_logit, __shfl_xor_sync(uint32_t(-1), max_logit, mask)); + } + if (lane == 0) { + red_smem[warp_idx] = max_logit; + } + __syncthreads(); + // Reduce across warps. + max_logit = lane < NUM_WARPS ? red_smem[lane] : -FLT_MAX; +#pragma unroll + for (int mask = NUM_WARPS / 2; mask >= 1; mask /= 2) { + max_logit = fmaxf(max_logit, __shfl_xor_sync(uint32_t(-1), max_logit, mask)); + } + // Broadcast the max value to all threads. + max_logit = __shfl_sync(uint32_t(-1), max_logit, 0); + + // Load rescaled exp sums to shared memory. + float* shared_exp_sums = reinterpret_cast(shared_mem + sizeof(float) * num_partitions); + const float* exp_sums_ptr = exp_sums + seq_idx * num_heads * max_num_partitions + + head_idx * max_num_partitions; + float global_exp_sum = 0.0f; + for (int i = threadIdx.x; i < num_partitions; i += blockDim.x) { + float l = shared_max_logits[i]; + float rescaled_exp_sum = exp_sums_ptr[i] * expf(l - max_logit); + global_exp_sum += rescaled_exp_sum; + shared_exp_sums[i] = rescaled_exp_sum; + } + __syncthreads(); + global_exp_sum = block_sum(&red_smem[NUM_WARPS], global_exp_sum); + const float inv_global_exp_sum = __fdividef(1.0f, global_exp_sum + 1e-6f); + + // Aggregate tmp_out to out. + const scalar_t* tmp_out_ptr = tmp_out + seq_idx * num_heads * max_num_partitions * HEAD_SIZE + + head_idx * max_num_partitions * HEAD_SIZE; + scalar_t* out_ptr = out + seq_idx * num_heads * HEAD_SIZE + head_idx * HEAD_SIZE; +#pragma unroll + for (int i = threadIdx.x; i < HEAD_SIZE; i += NUM_THREADS) { + float acc = 0.0f; + for (int j = 0; j < num_partitions; ++j) { + acc += to_float(tmp_out_ptr[j * HEAD_SIZE + i]) * shared_exp_sums[j] * inv_global_exp_sum; + } + from_float(out_ptr[i], acc); + } +} + } // namespace vllm -#define LAUNCH_ATTENTION_KERNEL(T, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS) \ +#define LAUNCH_PAGED_ATTENTION_V1(HEAD_SIZE) \ cudaFuncSetAttribute( \ - vllm::single_query_cached_kv_attention_kernel, \ - cudaFuncAttributeMaxDynamicSharedMemorySize, shared_mem_size); \ - vllm::single_query_cached_kv_attention_kernel \ + vllm::paged_attention_v1_kernel, \ + cudaFuncAttributeMaxDynamicSharedMemorySize, shared_mem_size); \ + vllm::paged_attention_v1_kernel \ <<>>( \ out_ptr, \ query_ptr, \ @@ -365,7 +557,7 @@ template< typename T, int BLOCK_SIZE, int NUM_THREADS = 128> -void single_query_cached_kv_attention_launcher( +void paged_attention_v1_launcher( torch::Tensor& out, torch::Tensor& query, torch::Tensor& key_cache, @@ -401,45 +593,206 @@ void single_query_cached_kv_attention_launcher( int* context_lens_ptr = context_lens.data_ptr(); constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE; - int padded_max_context_len = ((max_context_len + BLOCK_SIZE - 1) / BLOCK_SIZE) * BLOCK_SIZE; + int padded_max_context_len = DIVIDE_ROUND_UP(max_context_len, BLOCK_SIZE) * BLOCK_SIZE; int logits_size = padded_max_context_len * sizeof(float); int outputs_size = (NUM_WARPS / 2) * head_size * sizeof(float); // Python-side check in vllm.worker.worker._check_if_can_support_max_seq_len // Keep that in sync with the logic here! int shared_mem_size = std::max(logits_size, outputs_size); - dim3 grid(num_heads, num_seqs); + dim3 grid(num_heads, num_seqs, 1); + dim3 block(NUM_THREADS); + const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + switch (head_size) { + // NOTE(woosuk): To reduce the compilation time, we only compile for the + // head sizes that we use in the model. However, we can easily extend this + // to support any head size which is a multiple of 16. + case 64: + LAUNCH_PAGED_ATTENTION_V1(64); + break; + case 80: + LAUNCH_PAGED_ATTENTION_V1(80); + break; + case 96: + LAUNCH_PAGED_ATTENTION_V1(96); + break; + case 112: + LAUNCH_PAGED_ATTENTION_V1(112); + break; + case 128: + LAUNCH_PAGED_ATTENTION_V1(128); + break; + case 256: + LAUNCH_PAGED_ATTENTION_V1(256); + break; + default: + TORCH_CHECK(false, "Unsupported head size: ", head_size); + break; + } +} + +#define CALL_V1_LAUNCHER(T, BLOCK_SIZE) \ + paged_attention_v1_launcher( \ + out, \ + query, \ + key_cache, \ + value_cache, \ + head_mapping, \ + scale, \ + block_tables, \ + context_lens, \ + max_context_len, \ + alibi_slopes); + +// NOTE(woosuk): To reduce the compilation time, we omitted block sizes +// 1, 2, 4, 64, 128, 256. +#define CALL_V1_LAUNCHER_BLOCK_SIZE(T) \ + switch (block_size) { \ + case 8: \ + CALL_V1_LAUNCHER(T, 8); \ + break; \ + case 16: \ + CALL_V1_LAUNCHER(T, 16); \ + break; \ + case 32: \ + CALL_V1_LAUNCHER(T, 32); \ + break; \ + default: \ + TORCH_CHECK(false, "Unsupported block size: ", block_size); \ + break; \ + } + +void paged_attention_v1( + torch::Tensor& out, // [num_seqs, num_heads, head_size] + torch::Tensor& query, // [num_seqs, num_heads, head_size] + torch::Tensor& key_cache, // [num_blocks, num_heads, head_size/x, block_size, x] + torch::Tensor& value_cache, // [num_blocks, num_heads, head_size, block_size] + torch::Tensor& head_mapping, // [num_heads] + float scale, + torch::Tensor& block_tables, // [num_seqs, max_num_blocks_per_seq] + torch::Tensor& context_lens, // [num_seqs] + int block_size, + int max_context_len, + const c10::optional& alibi_slopes) { + if (query.dtype() == at::ScalarType::Float) { + CALL_V1_LAUNCHER_BLOCK_SIZE(float); + } else if (query.dtype() == at::ScalarType::Half) { + CALL_V1_LAUNCHER_BLOCK_SIZE(uint16_t); + } else if (query.dtype() == at::ScalarType::BFloat16) { + CALL_V1_LAUNCHER_BLOCK_SIZE(__nv_bfloat16); + } else { + TORCH_CHECK(false, "Unsupported data type: ", query.dtype()); + } +} + +#define LAUNCH_PAGED_ATTENTION_V2(HEAD_SIZE) \ + vllm::paged_attention_v2_kernel \ + <<>>( \ + exp_sums_ptr, \ + max_logits_ptr, \ + tmp_out_ptr, \ + query_ptr, \ + key_cache_ptr, \ + value_cache_ptr, \ + head_mapping_ptr, \ + scale, \ + block_tables_ptr, \ + context_lens_ptr, \ + max_num_blocks_per_seq, \ + alibi_slopes_ptr, \ + q_stride, \ + kv_block_stride, \ + kv_head_stride); \ + vllm::paged_attention_v2_reduce_kernel \ + <<>>( \ + out_ptr, \ + exp_sums_ptr, \ + max_logits_ptr, \ + tmp_out_ptr, \ + context_lens_ptr, \ + max_num_partitions); + +template< + typename T, + int BLOCK_SIZE, + int NUM_THREADS = 128, + int PARTITION_SIZE = 512> +void paged_attention_v2_launcher( + torch::Tensor& out, + torch::Tensor& exp_sums, + torch::Tensor& max_logits, + torch::Tensor& tmp_out, + torch::Tensor& query, + torch::Tensor& key_cache, + torch::Tensor& value_cache, + torch::Tensor& head_mapping, + float scale, + torch::Tensor& block_tables, + torch::Tensor& context_lens, + int max_context_len, + const c10::optional& alibi_slopes) { + int num_seqs = query.size(0); + int num_heads = query.size(1); + int head_size = query.size(2); + int max_num_blocks_per_seq = block_tables.size(1); + int q_stride = query.stride(0); + int kv_block_stride = key_cache.stride(0); + int kv_head_stride = key_cache.stride(1); + + int thread_group_size = MAX(WARP_SIZE / BLOCK_SIZE, 1); + assert(head_size % thread_group_size == 0); + + // NOTE: alibi_slopes is optional. + const float* alibi_slopes_ptr = alibi_slopes ? + reinterpret_cast(alibi_slopes.value().data_ptr()) + : nullptr; + + T* out_ptr = reinterpret_cast(out.data_ptr()); + float* exp_sums_ptr = reinterpret_cast(exp_sums.data_ptr()); + float* max_logits_ptr = reinterpret_cast(max_logits.data_ptr()); + T* tmp_out_ptr = reinterpret_cast(tmp_out.data_ptr()); + T* query_ptr = reinterpret_cast(query.data_ptr()); + T* key_cache_ptr = reinterpret_cast(key_cache.data_ptr()); + T* value_cache_ptr = reinterpret_cast(value_cache.data_ptr()); + int* head_mapping_ptr = reinterpret_cast(head_mapping.data_ptr()); + int* block_tables_ptr = block_tables.data_ptr(); + int* context_lens_ptr = context_lens.data_ptr(); + + constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE; + int max_num_partitions = DIVIDE_ROUND_UP(max_context_len, PARTITION_SIZE); + int logits_size = PARTITION_SIZE * sizeof(float); + int outputs_size = (NUM_WARPS / 2) * head_size * sizeof(float); + + // For paged attention v2 kernel. + dim3 grid(num_heads, num_seqs, max_num_partitions); + int shared_mem_size = std::max(logits_size, outputs_size); + // For paged attention v2 reduce kernel. + dim3 reduce_grid(num_heads, num_seqs); + int reduce_shared_mem_size = 2 * max_num_partitions * sizeof(float); + dim3 block(NUM_THREADS); const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); switch (head_size) { - // NOTE(woosuk): To reduce the compilation time, we omitted head sizes - // 32, 160, 192. - // case 32: - // LAUNCH_ATTENTION_KERNEL(T, 32, BLOCK_SIZE, NUM_THREADS); - // break; + // NOTE(woosuk): To reduce the compilation time, we only compile for the + // head sizes that we use in the model. However, we can easily extend this + // to support any head size which is a multiple of 16. case 64: - LAUNCH_ATTENTION_KERNEL(T, 64, BLOCK_SIZE, NUM_THREADS); + LAUNCH_PAGED_ATTENTION_V2(64); break; case 80: - LAUNCH_ATTENTION_KERNEL(T, 80, BLOCK_SIZE, NUM_THREADS); + LAUNCH_PAGED_ATTENTION_V2(80); break; case 96: - LAUNCH_ATTENTION_KERNEL(T, 96, BLOCK_SIZE, NUM_THREADS); + LAUNCH_PAGED_ATTENTION_V2(96); break; case 112: - LAUNCH_ATTENTION_KERNEL(T, 112, BLOCK_SIZE, NUM_THREADS); + LAUNCH_PAGED_ATTENTION_V2(112); break; case 128: - LAUNCH_ATTENTION_KERNEL(T, 128, BLOCK_SIZE, NUM_THREADS); + LAUNCH_PAGED_ATTENTION_V2(128); break; - // case 160: - // LAUNCH_ATTENTION_KERNEL(T, 160, BLOCK_SIZE, NUM_THREADS); - // break; - // case 192: - // LAUNCH_ATTENTION_KERNEL(T, 192, BLOCK_SIZE, NUM_THREADS); - // break; case 256: - LAUNCH_ATTENTION_KERNEL(T, 256, BLOCK_SIZE, NUM_THREADS); + LAUNCH_PAGED_ATTENTION_V2(256); break; default: TORCH_CHECK(false, "Unsupported head size: ", head_size); @@ -447,9 +800,12 @@ void single_query_cached_kv_attention_launcher( } } -#define CALL_KERNEL_LAUNCHER(T, BLOCK_SIZE) \ - single_query_cached_kv_attention_launcher( \ +#define CALL_V2_LAUNCHER(T, BLOCK_SIZE) \ + paged_attention_v2_launcher( \ out, \ + exp_sums, \ + max_logits, \ + tmp_out, \ query, \ key_cache, \ value_cache, \ @@ -462,42 +818,27 @@ void single_query_cached_kv_attention_launcher( // NOTE(woosuk): To reduce the compilation time, we omitted block sizes // 1, 2, 4, 64, 128, 256. -#define CALL_KERNEL_LAUNCHER_BLOCK_SIZE(T) \ +#define CALL_V2_LAUNCHER_BLOCK_SIZE(T) \ switch (block_size) { \ - /* case 1: */ \ - /* CALL_KERNEL_LAUNCHER(T, 1); */ \ - /* break; */ \ - /* case 2: */ \ - /* CALL_KERNEL_LAUNCHER(T, 2); */ \ - /* break; */ \ - /* case 4: */ \ - /* CALL_KERNEL_LAUNCHER(T, 4); */ \ - /* break; */ \ case 8: \ - CALL_KERNEL_LAUNCHER(T, 8); \ + CALL_V2_LAUNCHER(T, 8); \ break; \ case 16: \ - CALL_KERNEL_LAUNCHER(T, 16); \ + CALL_V2_LAUNCHER(T, 16); \ break; \ case 32: \ - CALL_KERNEL_LAUNCHER(T, 32); \ + CALL_V2_LAUNCHER(T, 32); \ break; \ - /* case 64: */ \ - /* CALL_KERNEL_LAUNCHER(T, 64); */ \ - /* break; */ \ - /* case 128: */ \ - /* CALL_KERNEL_LAUNCHER(T, 128); */ \ - /* break; */ \ - /* case 256: */ \ - /* CALL_KERNEL_LAUNCHER(T, 256); */ \ - /* break; */ \ default: \ TORCH_CHECK(false, "Unsupported block size: ", block_size); \ break; \ } -void single_query_cached_kv_attention( +void paged_attention_v2( torch::Tensor& out, // [num_seqs, num_heads, head_size] + torch::Tensor& exp_sums, // [num_seqs, num_heads, max_num_partitions] + torch::Tensor& max_logits, // [num_seqs, num_heads, max_num_partitions] + torch::Tensor& tmp_out, // [num_seqs, num_heads, max_num_partitions, head_size] torch::Tensor& query, // [num_seqs, num_heads, head_size] torch::Tensor& key_cache, // [num_blocks, num_heads, head_size/x, block_size, x] torch::Tensor& value_cache, // [num_blocks, num_heads, head_size, block_size] @@ -509,11 +850,11 @@ void single_query_cached_kv_attention( int max_context_len, const c10::optional& alibi_slopes) { if (query.dtype() == at::ScalarType::Float) { - CALL_KERNEL_LAUNCHER_BLOCK_SIZE(float); + CALL_V2_LAUNCHER_BLOCK_SIZE(float); } else if (query.dtype() == at::ScalarType::Half) { - CALL_KERNEL_LAUNCHER_BLOCK_SIZE(uint16_t); + CALL_V2_LAUNCHER_BLOCK_SIZE(uint16_t); } else if (query.dtype() == at::ScalarType::BFloat16) { - CALL_KERNEL_LAUNCHER_BLOCK_SIZE(__nv_bfloat16); + CALL_V2_LAUNCHER_BLOCK_SIZE(__nv_bfloat16); } else { TORCH_CHECK(false, "Unsupported data type: ", query.dtype()); } @@ -522,3 +863,4 @@ void single_query_cached_kv_attention( #undef WARP_SIZE #undef MAX #undef MIN +#undef DIVIDE_ROUND_UP From d96fa3c2c7b45ae78052433e1b433244e8428ab0 Mon Sep 17 00:00:00 2001 From: Philipp Moritz Date: Mon, 16 Oct 2023 21:18:11 +0000 Subject: [PATCH 10/94] fixes --- csrc/attention/attention_kernels.cu | 42 ++++++++++++++++++++++++++++- csrc/attention/dtype_bfloat16.cuh | 5 ++++ 2 files changed, 46 insertions(+), 1 deletion(-) diff --git a/csrc/attention/attention_kernels.cu b/csrc/attention/attention_kernels.cu index ee6b715adaef0..8fe641aa307b4 100644 --- a/csrc/attention/attention_kernels.cu +++ b/csrc/attention/attention_kernels.cu @@ -40,7 +40,11 @@ inline __device__ float block_sum(float* red_smem, float sum) { // Compute the sum per warp. #pragma unroll for (int mask = WARP_SIZE / 2; mask >= 1; mask /= 2) { +#ifndef USE_ROCM sum += __shfl_xor_sync(uint32_t(-1), sum, mask); +#else + sum += __shfl_xor(sum, mask); +#endif } // Warp leaders store the data to shared memory. @@ -59,11 +63,19 @@ inline __device__ float block_sum(float* red_smem, float sum) { // Parallel reduction inside the warp. #pragma unroll for (int mask = NUM_WARPS / 2; mask >= 1; mask /= 2) { +#ifndef USE_ROCM sum += __shfl_xor_sync(uint32_t(-1), sum, mask); +#else + sum += __shfl_xor(sum, mask); +#endif } // Broadcast to other threads. +#ifndef USE_ROCM return __shfl_sync(uint32_t(-1), sum, 0); +#else + return __shfl(sum, 0); +#endif } // TODO(woosuk): Merge the last two dimensions of the grid. @@ -220,7 +232,11 @@ __device__ void paged_attention_kernel( // The 0-th thread of each thread group already has its max qk value. #pragma unroll for (int mask = WARP_SIZE / 2; mask >= THREAD_GROUP_SIZE; mask /= 2) { +#ifndef USE_ROCM qk_max = fmaxf(qk_max, __shfl_xor_sync(uint32_t(-1), qk_max, mask)); +#else + qk_max = fmaxf(qk_max, __shfl_xor(qk_max, mask)); +#endif } if (lane == 0) { red_smem[warp_idx] = qk_max; @@ -232,10 +248,18 @@ __device__ void paged_attention_kernel( qk_max = lane < NUM_WARPS ? red_smem[lane] : -FLT_MAX; #pragma unroll for (int mask = NUM_WARPS / 2; mask >= 1; mask /= 2) { +#ifndef USE_ROCM qk_max = fmaxf(qk_max, __shfl_xor_sync(uint32_t(-1), qk_max, mask)); +#else + qk_max = fmaxf(qk_max, __shfl_xor(qk_max, mask)); +#endif } // Broadcast the max qk value to all threads. +#ifndef USE_ROCM qk_max = __shfl_sync(uint32_t(-1), qk_max, 0); +#else + qk_max = __shfl(qk_max, 0); +#endif // Get the sum of the exp values. float exp_sum = 0.f; @@ -320,7 +344,11 @@ __device__ void paged_attention_kernel( float acc = accs[i]; #pragma unroll for (int mask = NUM_V_VECS_PER_ROW / 2; mask >= 1; mask /= 2) { +#ifndef USE_ROCM acc += __shfl_xor_sync(uint32_t(-1), acc, mask); +#else + acc += __shfl_xor(acc, mask); +#endif } accs[i] = acc; } @@ -486,7 +514,11 @@ __global__ void paged_attention_v2_reduce_kernel( // Reduce within the warp. #pragma unroll for (int mask = WARP_SIZE / 2; mask >= 1; mask /= 2) { +#ifndef USE_ROCM max_logit = fmaxf(max_logit, __shfl_xor_sync(uint32_t(-1), max_logit, mask)); +#else + max_logit = fmaxf(max_logit, __shfl_xor(max_logit, mask)); +#endif } if (lane == 0) { red_smem[warp_idx] = max_logit; @@ -496,10 +528,18 @@ __global__ void paged_attention_v2_reduce_kernel( max_logit = lane < NUM_WARPS ? red_smem[lane] : -FLT_MAX; #pragma unroll for (int mask = NUM_WARPS / 2; mask >= 1; mask /= 2) { +#ifndef USE_ROCM max_logit = fmaxf(max_logit, __shfl_xor_sync(uint32_t(-1), max_logit, mask)); +#else + max_logit = fmaxf(max_logit, __shfl_xor(max_logit, mask)); +#endif } // Broadcast the max value to all threads. +#ifndef USE_ROCM max_logit = __shfl_sync(uint32_t(-1), max_logit, 0); +#else + max_logit = __shfl(max_logit, 0); +#endif // Load rescaled exp sums to shared memory. float* shared_exp_sums = reinterpret_cast(shared_mem + sizeof(float) * num_partitions); @@ -534,7 +574,7 @@ __global__ void paged_attention_v2_reduce_kernel( #define LAUNCH_PAGED_ATTENTION_V1(HEAD_SIZE) \ cudaFuncSetAttribute( \ - vllm::paged_attention_v1_kernel, \ + (void*)vllm::paged_attention_v1_kernel, \ cudaFuncAttributeMaxDynamicSharedMemorySize, shared_mem_size); \ vllm::paged_attention_v1_kernel \ <<>>( \ diff --git a/csrc/attention/dtype_bfloat16.cuh b/csrc/attention/dtype_bfloat16.cuh index 9ad2e299c7aab..7f2b29de0d932 100644 --- a/csrc/attention/dtype_bfloat16.cuh +++ b/csrc/attention/dtype_bfloat16.cuh @@ -439,6 +439,11 @@ inline __device__ void from_float(bf16_8_t& dst, Float8_ src) { #endif } +// From bfloat16 to float32. +inline __device__ float to_float(__nv_bfloat16 u) { + return __bfloat162float(u); +} + // Zero-out a variable. inline __device__ void zero(__nv_bfloat16& dst) { #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 From 421365b5cf80268710c921645e31c8e48e0596fa Mon Sep 17 00:00:00 2001 From: Philipp Moritz Date: Mon, 16 Oct 2023 22:36:57 +0000 Subject: [PATCH 11/94] cleanup --- csrc/attention/attention_kernels.cu | 43 ++++++----------------------- csrc/cuda_compat.h | 6 ++++ 2 files changed, 14 insertions(+), 35 deletions(-) diff --git a/csrc/attention/attention_kernels.cu b/csrc/attention/attention_kernels.cu index 8fe641aa307b4..ebfe3f6a38b08 100644 --- a/csrc/attention/attention_kernels.cu +++ b/csrc/attention/attention_kernels.cu @@ -20,6 +20,7 @@ #include "attention_dtypes.h" #include "attention_utils.cuh" +#include "../cuda_compat.h" #include @@ -40,11 +41,7 @@ inline __device__ float block_sum(float* red_smem, float sum) { // Compute the sum per warp. #pragma unroll for (int mask = WARP_SIZE / 2; mask >= 1; mask /= 2) { -#ifndef USE_ROCM - sum += __shfl_xor_sync(uint32_t(-1), sum, mask); -#else - sum += __shfl_xor(sum, mask); -#endif + sum += VLLM_SHFL_XOR_SYNC(sum, mask); } // Warp leaders store the data to shared memory. @@ -63,11 +60,7 @@ inline __device__ float block_sum(float* red_smem, float sum) { // Parallel reduction inside the warp. #pragma unroll for (int mask = NUM_WARPS / 2; mask >= 1; mask /= 2) { -#ifndef USE_ROCM - sum += __shfl_xor_sync(uint32_t(-1), sum, mask); -#else - sum += __shfl_xor(sum, mask); -#endif + sum += VLLM_SHFL_XOR_SYNC(sum, mask); } // Broadcast to other threads. @@ -232,11 +225,7 @@ __device__ void paged_attention_kernel( // The 0-th thread of each thread group already has its max qk value. #pragma unroll for (int mask = WARP_SIZE / 2; mask >= THREAD_GROUP_SIZE; mask /= 2) { -#ifndef USE_ROCM - qk_max = fmaxf(qk_max, __shfl_xor_sync(uint32_t(-1), qk_max, mask)); -#else - qk_max = fmaxf(qk_max, __shfl_xor(qk_max, mask)); -#endif + qk_max = fmaxf(qk_max, VLLM_SHFL_XOR_SYNC(qk_max, mask)); } if (lane == 0) { red_smem[warp_idx] = qk_max; @@ -248,11 +237,7 @@ __device__ void paged_attention_kernel( qk_max = lane < NUM_WARPS ? red_smem[lane] : -FLT_MAX; #pragma unroll for (int mask = NUM_WARPS / 2; mask >= 1; mask /= 2) { -#ifndef USE_ROCM - qk_max = fmaxf(qk_max, __shfl_xor_sync(uint32_t(-1), qk_max, mask)); -#else - qk_max = fmaxf(qk_max, __shfl_xor(qk_max, mask)); -#endif + qk_max = fmaxf(qk_max, VLLM_SHFL_XOR_SYNC(qk_max, mask)); } // Broadcast the max qk value to all threads. #ifndef USE_ROCM @@ -344,11 +329,7 @@ __device__ void paged_attention_kernel( float acc = accs[i]; #pragma unroll for (int mask = NUM_V_VECS_PER_ROW / 2; mask >= 1; mask /= 2) { -#ifndef USE_ROCM - acc += __shfl_xor_sync(uint32_t(-1), acc, mask); -#else - acc += __shfl_xor(acc, mask); -#endif + acc += VLLM_SHFL_XOR_SYNC(acc, mask); } accs[i] = acc; } @@ -514,11 +495,7 @@ __global__ void paged_attention_v2_reduce_kernel( // Reduce within the warp. #pragma unroll for (int mask = WARP_SIZE / 2; mask >= 1; mask /= 2) { -#ifndef USE_ROCM - max_logit = fmaxf(max_logit, __shfl_xor_sync(uint32_t(-1), max_logit, mask)); -#else - max_logit = fmaxf(max_logit, __shfl_xor(max_logit, mask)); -#endif + max_logit = fmaxf(max_logit, VLLM_SHFL_XOR_SYNC(max_logit, mask)); } if (lane == 0) { red_smem[warp_idx] = max_logit; @@ -528,11 +505,7 @@ __global__ void paged_attention_v2_reduce_kernel( max_logit = lane < NUM_WARPS ? red_smem[lane] : -FLT_MAX; #pragma unroll for (int mask = NUM_WARPS / 2; mask >= 1; mask /= 2) { -#ifndef USE_ROCM - max_logit = fmaxf(max_logit, __shfl_xor_sync(uint32_t(-1), max_logit, mask)); -#else - max_logit = fmaxf(max_logit, __shfl_xor(max_logit, mask)); -#endif + max_logit = fmaxf(max_logit, VLLM_SHFL_XOR_SYNC(max_logit, mask)); } // Broadcast the max value to all threads. #ifndef USE_ROCM diff --git a/csrc/cuda_compat.h b/csrc/cuda_compat.h index 3348b78cfa19d..c5f170fcb4752 100644 --- a/csrc/cuda_compat.h +++ b/csrc/cuda_compat.h @@ -5,3 +5,9 @@ #else #define VLLM_LDG(arg) *(arg) #endif + +#ifndef USE_ROCM + #define VLLM_SHFL_XOR_SYNC(var, lane_mask) __shfl_xor_sync(uint32_t(-1), var, lane_mask) +#else + #define VLLM_SHFL_XOR_SYNC(var, lane_mask) __shfl_xor(var, lane_mask) +#endif From 06b800e3d67097c05a0446b8d2f35047c6a794e7 Mon Sep 17 00:00:00 2001 From: Philipp Moritz Date: Mon, 16 Oct 2023 22:48:14 +0000 Subject: [PATCH 12/94] update --- csrc/attention/attention_kernels.cu | 1 - csrc/attention/attention_utils.cuh | 7 ++----- 2 files changed, 2 insertions(+), 6 deletions(-) diff --git a/csrc/attention/attention_kernels.cu b/csrc/attention/attention_kernels.cu index ebfe3f6a38b08..3bea905b0679c 100644 --- a/csrc/attention/attention_kernels.cu +++ b/csrc/attention/attention_kernels.cu @@ -20,7 +20,6 @@ #include "attention_dtypes.h" #include "attention_utils.cuh" -#include "../cuda_compat.h" #include diff --git a/csrc/attention/attention_utils.cuh b/csrc/attention/attention_utils.cuh index 7e6b64eea96f4..ff64c4bd8f80c 100644 --- a/csrc/attention/attention_utils.cuh +++ b/csrc/attention/attention_utils.cuh @@ -17,6 +17,7 @@ */ #pragma once +#include "../cuda_compat.h" #include "attention_dtypes.h" #include @@ -39,11 +40,7 @@ inline __device__ float qk_dot_(const Vec (&q)[N], const Vec (&k)[N]) { float qk = sum(qk_vec); #pragma unroll for (int mask = THREAD_GROUP_SIZE / 2; mask >= 1; mask /= 2) { -#ifndef USE_ROCM - qk += __shfl_xor_sync(uint32_t(-1), qk, mask); -#else - qk += __shfl_xor(qk, mask); -#endif + qk += VLLM_SHFL_XOR_SYNC(qk, mask); } return qk; } From 2312beb1f41206bbe97ed17b8fd0b935d9014f28 Mon Sep 17 00:00:00 2001 From: Philipp Moritz Date: Mon, 16 Oct 2023 22:58:34 +0000 Subject: [PATCH 13/94] update --- csrc/attention/attention_kernels.cu | 18 +++--------------- csrc/cuda_compat.h | 6 ++++++ 2 files changed, 9 insertions(+), 15 deletions(-) diff --git a/csrc/attention/attention_kernels.cu b/csrc/attention/attention_kernels.cu index 3bea905b0679c..babd15bb30fb2 100644 --- a/csrc/attention/attention_kernels.cu +++ b/csrc/attention/attention_kernels.cu @@ -63,11 +63,7 @@ inline __device__ float block_sum(float* red_smem, float sum) { } // Broadcast to other threads. -#ifndef USE_ROCM - return __shfl_sync(uint32_t(-1), sum, 0); -#else - return __shfl(sum, 0); -#endif + return VLLM_SHFL_SYNC(sum, 0); } // TODO(woosuk): Merge the last two dimensions of the grid. @@ -239,11 +235,7 @@ __device__ void paged_attention_kernel( qk_max = fmaxf(qk_max, VLLM_SHFL_XOR_SYNC(qk_max, mask)); } // Broadcast the max qk value to all threads. -#ifndef USE_ROCM - qk_max = __shfl_sync(uint32_t(-1), qk_max, 0); -#else - qk_max = __shfl(qk_max, 0); -#endif + qk_max = VLLM_SHFL_SYNC(qk_max, 0); // Get the sum of the exp values. float exp_sum = 0.f; @@ -507,11 +499,7 @@ __global__ void paged_attention_v2_reduce_kernel( max_logit = fmaxf(max_logit, VLLM_SHFL_XOR_SYNC(max_logit, mask)); } // Broadcast the max value to all threads. -#ifndef USE_ROCM - max_logit = __shfl_sync(uint32_t(-1), max_logit, 0); -#else - max_logit = __shfl(max_logit, 0); -#endif + max_logit = VLLM_SHFL_SYNC(max_logit, 0); // Load rescaled exp sums to shared memory. float* shared_exp_sums = reinterpret_cast(shared_mem + sizeof(float) * num_partitions); diff --git a/csrc/cuda_compat.h b/csrc/cuda_compat.h index c5f170fcb4752..8991462a862e8 100644 --- a/csrc/cuda_compat.h +++ b/csrc/cuda_compat.h @@ -11,3 +11,9 @@ #else #define VLLM_SHFL_XOR_SYNC(var, lane_mask) __shfl_xor(var, lane_mask) #endif + +#ifndef USE_ROCM + #define VLLM_SHFL_SYNC(var, src_lane) __shfl_sync(uint32_t(-1), var, src_lane); +#else + #define VLLM_SHFL_SYNC(var, src_lane) __shfl(var, src_lane) +#endif From 2958b39d9ad440ec90dd13fe597d5d5e49d59d3e Mon Sep 17 00:00:00 2001 From: Philipp Moritz Date: Mon, 16 Oct 2023 23:11:25 +0000 Subject: [PATCH 14/94] update --- csrc/layernorm_kernels.cu | 1 + csrc/reduction_utils.cuh | 8 +++----- 2 files changed, 4 insertions(+), 5 deletions(-) diff --git a/csrc/layernorm_kernels.cu b/csrc/layernorm_kernels.cu index f932b9e2d6150..9d4ada1f07159 100644 --- a/csrc/layernorm_kernels.cu +++ b/csrc/layernorm_kernels.cu @@ -4,6 +4,7 @@ #include "dispatch_utils.h" #include "reduction_utils.cuh" + namespace vllm { // TODO(woosuk): Further optimize this kernel. diff --git a/csrc/reduction_utils.cuh b/csrc/reduction_utils.cuh index 382ad162dfef8..b95ccef162072 100644 --- a/csrc/reduction_utils.cuh +++ b/csrc/reduction_utils.cuh @@ -17,17 +17,15 @@ */ #pragma once +#include "cuda_compat.h" + namespace vllm { template __inline__ __device__ T warpReduceSum(T val) { #pragma unroll for (int mask = 16; mask > 0; mask >>= 1) -#ifndef USE_ROCM - val += __shfl_xor_sync(0xffffffff, val, mask, 32); -#else - val += __shfl_xor(val, mask, 32); -#endif + val += VLLM_SHFL_XOR_SYNC(val, mask); return val; } From 3f8973403a3f9fdb21eb48bea1066db0ea4d4fcd Mon Sep 17 00:00:00 2001 From: Philipp Moritz Date: Mon, 16 Oct 2023 23:14:19 +0000 Subject: [PATCH 15/94] fmt --- csrc/layernorm_kernels.cu | 1 - 1 file changed, 1 deletion(-) diff --git a/csrc/layernorm_kernels.cu b/csrc/layernorm_kernels.cu index 9d4ada1f07159..f932b9e2d6150 100644 --- a/csrc/layernorm_kernels.cu +++ b/csrc/layernorm_kernels.cu @@ -4,7 +4,6 @@ #include "dispatch_utils.h" #include "reduction_utils.cuh" - namespace vllm { // TODO(woosuk): Further optimize this kernel. From 5397a5748c19f3ceedc037ce101eadd83191edae Mon Sep 17 00:00:00 2001 From: Philipp Moritz Date: Mon, 16 Oct 2023 23:15:49 +0000 Subject: [PATCH 16/94] cleanup --- setup.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/setup.py b/setup.py index 5eeb36b5b75e8..a8e9aa4af8c30 100644 --- a/setup.py +++ b/setup.py @@ -24,6 +24,11 @@ CXX_FLAGS += [f"-D_GLIBCXX_USE_CXX11_ABI={ABI}"] NVCC_FLAGS += [f"-D_GLIBCXX_USE_CXX11_ABI={ABI}"] +if not torch.version.hip: + if CUDA_HOME is None: + raise RuntimeError( + "Cannot find CUDA_HOME. CUDA must be available to build the package.") + def get_nvcc_cuda_version(cuda_dir: str) -> Version: """Get the CUDA version from nvcc. From 90e02d25b64757b2ccf1b7fc452db207591b8cb8 Mon Sep 17 00:00:00 2001 From: Philipp Moritz Date: Mon, 16 Oct 2023 23:29:19 +0000 Subject: [PATCH 17/94] refactor --- setup.py | 66 ++++++++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 66 insertions(+) diff --git a/setup.py b/setup.py index a8e9aa4af8c30..a44e33bfcea28 100644 --- a/setup.py +++ b/setup.py @@ -77,6 +77,72 @@ def get_torch_arch_list() -> Set[str]: f"{valid_archs}.") return arch_list +def get_cuda_compute_capabilities(nvcc_cuda_version): + # First, check the TORCH_CUDA_ARCH_LIST environment variable. + compute_capabilities = get_torch_arch_list() + if not compute_capabilities: + # If TORCH_CUDA_ARCH_LIST is not defined or empty, target all available + # GPUs on the current machine. + device_count = torch.cuda.device_count() + for i in range(device_count): + major, minor = torch.cuda.get_device_capability(i) + if major < 7: + raise RuntimeError( + "GPUs with compute capability below 7.0 are not supported.") + compute_capabilities.add(f"{major}.{minor}") + + if not compute_capabilities: + # If no GPU is specified nor available, add all supported architectures + # based on the NVCC CUDA version. + compute_capabilities = SUPPORTED_ARCHS.copy() + if nvcc_cuda_version < Version("11.1"): + compute_capabilities.remove("8.6") + if nvcc_cuda_version < Version("11.8"): + compute_capabilities.remove("8.9") + compute_capabilities.remove("9.0") + + return compute_capabilities + +def validate_nvcc_cuda_version(nvcc_cuda_version, compute_capabilities): + if nvcc_cuda_version < Version("11.0"): + raise RuntimeError("CUDA 11.0 or higher is required to build the package.") + if nvcc_cuda_version < Version("11.1"): + if any(cc.startswith("8.6") for cc in compute_capabilities): + raise RuntimeError( + "CUDA 11.1 or higher is required for compute capability 8.6.") + if nvcc_cuda_version < Version("11.8"): + if any(cc.startswith("8.9") for cc in compute_capabilities): + # CUDA 11.8 is required to generate the code targeting compute capability 8.9. + # However, GPUs with compute capability 8.9 can also run the code generated by + # the previous versions of CUDA 11 and targeting compute capability 8.0. + # Therefore, if CUDA 11.8 is not available, we target compute capability 8.0 + # instead of 8.9. + warnings.warn( + "CUDA 11.8 or higher is required for compute capability 8.9. " + "Targeting compute capability 8.0 instead.") + compute_capabilities = set(cc for cc in compute_capabilities + if not cc.startswith("8.9")) + compute_capabilities.add("8.0+PTX") + if any(cc.startswith("9.0") for cc in compute_capabilities): + raise RuntimeError( + "CUDA 11.8 or higher is required for compute capability 9.0.") + +if not torch.version.hip: + nvcc_cuda_version = get_nvcc_cuda_version(CUDA_HOME) + compute_capabilities = get_cuda_compute_capabilities(nvcc_cuda_version) + validate_nvcc_cuda_version(nvcc_cuda_version, compute_capabilities) + + # Add target compute capabilities to NVCC flags. + for capability in compute_capabilities: + num = capability[0] + capability[2] + NVCC_FLAGS += ["-gencode", f"arch=compute_{num},code=sm_{num}"] + if capability.endswith("+PTX"): + NVCC_FLAGS += ["-gencode", f"arch=compute_{num},code=compute_{num}"] + + # Use NVCC threads to parallelize the build. + if nvcc_cuda_version >= Version("11.2"): + num_threads = min(os.cpu_count(), 8) + NVCC_FLAGS += ["--threads", str(num_threads)] ext_modules = [] From a42020206afea67123304f206c186523f31d5582 Mon Sep 17 00:00:00 2001 From: Philipp Moritz Date: Mon, 16 Oct 2023 23:31:29 +0000 Subject: [PATCH 18/94] update --- setup.py | 1 - 1 file changed, 1 deletion(-) diff --git a/setup.py b/setup.py index a44e33bfcea28..c2e5ad80aa758 100644 --- a/setup.py +++ b/setup.py @@ -201,7 +201,6 @@ def validate_nvcc_cuda_version(nvcc_cuda_version, compute_capabilities): ) ext_modules.append(activation_extension) - # Quantization kernels. quantization_extension = CUDAExtension( name="vllm.quantization_ops", From 2d1e43581c2a3eefc09ad1a93ab2c83145ae38ae Mon Sep 17 00:00:00 2001 From: Amir Balwel Date: Tue, 17 Oct 2023 06:59:09 +0000 Subject: [PATCH 19/94] detecting rocm and adding flag for compiling --- setup.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/setup.py b/setup.py index c2e5ad80aa758..55a4358f734f6 100644 --- a/setup.py +++ b/setup.py @@ -8,7 +8,7 @@ from packaging.version import parse, Version import setuptools import torch -from torch.utils.cpp_extension import BuildExtension, CUDAExtension, CUDA_HOME +from torch.utils.cpp_extension import BuildExtension, CUDAExtension, CUDA_HOME, ROCM_HOME ROOT_DIR = os.path.dirname(__file__) @@ -24,12 +24,15 @@ CXX_FLAGS += [f"-D_GLIBCXX_USE_CXX11_ABI={ABI}"] NVCC_FLAGS += [f"-D_GLIBCXX_USE_CXX11_ABI={ABI}"] +if torch.version.hip: + if ROCM_HOME is not None: + NVCC_FLAGS += [f"-DUSE_ROCM"] + if not torch.version.hip: if CUDA_HOME is None: raise RuntimeError( "Cannot find CUDA_HOME. CUDA must be available to build the package.") - def get_nvcc_cuda_version(cuda_dir: str) -> Version: """Get the CUDA version from nvcc. From e231b7903dabbc698639d9a303edcb13d4e34d0f Mon Sep 17 00:00:00 2001 From: Amir Balwel Date: Tue, 17 Oct 2023 06:59:46 +0000 Subject: [PATCH 20/94] using asm volatile instead of hip api --- csrc/attention/dtype_float16.cuh | 100 ++++++++++++------------------- 1 file changed, 39 insertions(+), 61 deletions(-) diff --git a/csrc/attention/dtype_float16.cuh b/csrc/attention/dtype_float16.cuh index dc45dbf3daea6..8e670d81ff3a9 100644 --- a/csrc/attention/dtype_float16.cuh +++ b/csrc/attention/dtype_float16.cuh @@ -71,21 +71,25 @@ inline __device__ uint32_t h0_h0(uint16_t a) { uint32_t b; asm volatile("mov.b32 %0, {%1, %1};" : "=r"(b) : "h"(a)); #else - uint32_t b = a; - b <<= 16; - b |= a; + union { + uint32_t u32; + uint16_t u16[2]; + } tmp; + tmp.u16[0] = a; + tmp.u16[1] = a; + b = tmp.u32 #endif return b; } inline __device__ float half_to_float(uint16_t h) { -#ifndef USE_ROCM float f; +#ifndef USE_ROCM asm volatile("cvt.f32.f16 %0, %1;\n" : "=f"(f) : "h"(h)); - return f; #else - return __half2float(__ushort_as_half(h)); + asm volatile("v_cvt_f32_f16 %0, %1;" : "=v"(f) : "v"(h)); #endif + return f; } inline __device__ float2 half2_to_float2(uint32_t v) { @@ -95,79 +99,68 @@ inline __device__ float2 half2_to_float2(uint32_t v) { return make_float2(half_to_float(lo), half_to_float(hi)); #else union { - __half2 h2; uint32_t u32; - } V; - V.u32 = v; - return make_float2(half_to_float(V.h2.x), half_to_float(V.h2.y)); + uint16_t u16[2]; + } tmp; + tmp.u32 = v; + float2 ret; + ret.x = half_to_float(tmp.u16[0]); + ret.y = half_to_float(tmp.u16[1]); + return ret; #endif } inline __device__ uint16_t float_to_half(float f) { -#ifndef USE_ROCM union { uint32_t u32; uint16_t u16[2]; } tmp; +#ifndef USE_ROCM asm volatile("cvt.rn.f16.f32 %0, %1;\n" : "=h"(tmp.u16[0]) : "f"(f)); - return tmp.u16[0]; #else - return __half_as_ushort(__float2half(f)); + asm volatile("v_cvt_f16_f32 %0, %1;\n" : "=v"(tmp.u32) : "v"(f)); #endif + return tmp.u16[0]; } inline __device__ uint32_t float2_to_half2(float2 f) { -#ifndef USE_ROCM union { uint32_t u32; uint16_t u16[2]; } tmp; - +#ifndef USE_ROCM #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 asm volatile("cvt.rn.f16x2.f32 %0, %1, %2;\n" : "=r"(tmp.u32) : "f"(f.y), "f"(f.x)); #else asm volatile("cvt.rn.f16.f32 %0, %1;\n" : "=h"(tmp.u16[0]) : "f"(f.x)); asm volatile("cvt.rn.f16.f32 %0, %1;\n" : "=h"(tmp.u16[1]) : "f"(f.y)); #endif - return tmp.u32; #else - union { - __half2 h2; - uint32_t u32; - } R; - - R.h2.x = __half_as_ushort(__float2half_rn(f.x)); - R.h2.y = __half_as_ushort(__float2half_rn(f.y)); - return R.u32; + tmp.u16[0] = float_to_half(f.x); + tmp.u16[1] = float_to_half(f.y); #endif + return tmp.u32; } // Vector addition. inline __device__ uint16_t add(uint16_t a, uint16_t b) { -#ifndef USE_ROCM uint16_t c; +#ifndef USE_ROCM asm volatile("add.f16 %0, %1, %2;\n" : "=h"(c) : "h"(a), "h"(b)); - return c; #else - return __half_as_ushort(__hadd(__ushort_as_half(a), __ushort_as_half(b))); + asm volatile("v_add_f16 %0, %1, %2;\n" : "=v"(c) : "v"(a), "v"(b)); #endif + return c; } inline __device__ uint32_t add(uint32_t a, uint32_t b) { -#ifndef USE_ROCM uint32_t c; +#ifndef USE_ROCM asm volatile("add.f16x2 %0, %1, %2;\n" : "=r"(c) : "r"(a), "r"(b)); - return c; #else - union { - __half2 h2; - uint32_t u32; - } A, B, C; - A.u32 = a; - B.u32 = b; - C.h2 = __hadd2(A.h2, B.h2); - return C.u32; + asm volatile("v_pk_add_f16 %0, %1, %2;\n" : "=v"(c) : "v"(a), "v"(b)); #endif + return c; } inline __device__ uint2 add(uint2 a, uint2 b) { @@ -210,31 +203,24 @@ inline __device__ Float8_ add(uint4 a, Float8_ fb) { // Vector multiplication. template<> inline __device__ uint16_t mul(uint16_t a, uint16_t b) { -#ifndef USE_ROCM uint16_t c; +#ifndef USE_ROCM asm volatile("mul.f16 %0, %1, %2;\n" : "=h"(c) : "h"(a), "h"(b)); - return c; #else - return __half_as_ushort(__hmul(__ushort_as_half(a), __ushort_as_half(b))); + asm volatile("v_mul_f16 %0, %1, %2;\n" : "=v"(c) : "v"(a), "v"(b)); #endif + return c; } template<> inline __device__ uint32_t mul(uint32_t a, uint32_t b) { -#ifndef USE_ROCM uint32_t c; +#ifndef USE_ROCM asm volatile("mul.f16x2 %0, %1, %2;\n" : "=r"(c) : "r"(a), "r"(b)); - return c; #else - union { - __half2 h2; - uint32_t u32; - } A, B, C; - A.u32 = a; - B.u32 = b; - C.h2 = __hmul2(A.h2, B.h2); - return C.u32; + asm volatile("v_pk_mul_f16 %0, %1, %2;\n" : "=v"(c) : "v"(a), "v"(b)); #endif + return c; } template<> @@ -339,21 +325,13 @@ inline __device__ Float8_ mul(uint16_t a, uint4 b) { // Vector fused multiply-add. inline __device__ uint32_t fma(uint32_t a, uint32_t b, uint32_t c) { -#ifndef USE_ROCM uint32_t d; +#ifndef USE_ROCM asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(d) : "r"(a), "r"(b), "r"(c)); - return d; #else - union { - __half2 h2; - uint32_t u32; - } A, B, C, D; - A.u32 = a; - B.u32 = b; - C.u32 = c; - D.h2 = __hfma2(A.h2, B.h2, C.h2); - return D.u32; + asm volatile("v_pk_mul_f16 %0, %1, %2;\n" : "=v"(c) : "v"(a), "v"(b)); #endif + return d; } inline __device__ uint32_t fma(uint16_t a, uint32_t b, uint32_t c) { From 31bb33568f1b14cbcc4e6aee84eed11c0da10c16 Mon Sep 17 00:00:00 2001 From: Amir Balwel Date: Tue, 17 Oct 2023 07:15:01 +0000 Subject: [PATCH 21/94] using asm volatile for type casting of f16 --- csrc/attention/dtype_float16.cuh | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/csrc/attention/dtype_float16.cuh b/csrc/attention/dtype_float16.cuh index 8e670d81ff3a9..b9c9275aae3f6 100644 --- a/csrc/attention/dtype_float16.cuh +++ b/csrc/attention/dtype_float16.cuh @@ -67,8 +67,8 @@ struct FloatVec { // Utility functions for type conversions. inline __device__ uint32_t h0_h0(uint16_t a) { -#ifndef USE_ROCM uint32_t b; +#ifndef USE_ROCM asm volatile("mov.b32 %0, {%1, %1};" : "=r"(b) : "h"(a)); #else union { @@ -77,7 +77,7 @@ inline __device__ uint32_t h0_h0(uint16_t a) { } tmp; tmp.u16[0] = a; tmp.u16[1] = a; - b = tmp.u32 + b = tmp.u32; #endif return b; } From b027d06f339db1b445d26587d4171b4a289808e7 Mon Sep 17 00:00:00 2001 From: root Date: Mon, 27 Nov 2023 05:11:53 +0000 Subject: [PATCH 22/94] Hipifying csrc file to accomodate rocm builds --- .gitignore | 4 + csrc/activation_kernels.cu | 7 +- csrc/attention/attention_kernels.cu | 48 +++-- csrc/attention/attention_utils.cuh | 3 +- csrc/attention/dtype_bfloat16.cuh | 19 +- csrc/attention/dtype_float16.cuh | 90 ++++++++- csrc/cache_kernels.cu | 13 +- csrc/cuda_compat.h | 19 ++ csrc/cuda_utils_kernels.cu | 3 + csrc/pos_encoding_kernels.cu | 9 +- .../squeezellm/quant_cuda_kernel.cu | 76 +++++++- csrc/reduction_utils.cuh | 4 +- setup.py | 181 ++++++++++-------- 13 files changed, 361 insertions(+), 115 deletions(-) create mode 100644 csrc/cuda_compat.h diff --git a/.gitignore b/.gitignore index b531b7918c305..58572718850e6 100644 --- a/.gitignore +++ b/.gitignore @@ -177,3 +177,7 @@ _build/ # vim swap files *.swo *.swp + +# hip files generated by PyTorch +*.hip +*_hip* diff --git a/csrc/activation_kernels.cu b/csrc/activation_kernels.cu index 89d1ba2d37dd8..1cca2c5fccc18 100644 --- a/csrc/activation_kernels.cu +++ b/csrc/activation_kernels.cu @@ -1,6 +1,7 @@ #include #include +#include "cuda_compat.h" #include "dispatch_utils.h" namespace vllm { @@ -18,8 +19,8 @@ __global__ void silu_and_mul_kernel( const int d) { const int64_t token_idx = blockIdx.x; for (int64_t idx = threadIdx.x; idx < d; idx += blockDim.x) { - const scalar_t x = __ldg(&input[token_idx * 2 * d + idx]); - const scalar_t y = __ldg(&input[token_idx * 2 * d + d + idx]); + const scalar_t x = VLLM_LDG(&input[token_idx * 2 * d + idx]); + const scalar_t y = VLLM_LDG(&input[token_idx * 2 * d + d + idx]); out[token_idx * d + idx] = silu(x) * y; } } @@ -57,7 +58,7 @@ __global__ void activation_kernel( const int d) { const int64_t token_idx = blockIdx.x; for (int64_t idx = threadIdx.x; idx < d; idx += blockDim.x) { - const scalar_t x = __ldg(&input[token_idx * d + idx]); + const scalar_t x = VLLM_LDG(&input[token_idx * d + idx]); out[token_idx * d + idx] = ACT_FN(x); } } diff --git a/csrc/attention/attention_kernels.cu b/csrc/attention/attention_kernels.cu index 78e8d8ecd6d41..19ad43b554aff 100644 --- a/csrc/attention/attention_kernels.cu +++ b/csrc/attention/attention_kernels.cu @@ -23,7 +23,11 @@ #include +#ifndef USE_ROCM #define WARP_SIZE 32 +#else +#define WARP_SIZE 64 +#endif #define MAX(a, b) ((a) > (b) ? (a) : (b)) #define MIN(a, b) ((a) < (b) ? (a) : (b)) #define DIVIDE_ROUND_UP(a, b) (((a) + (b) - 1) / (b)) @@ -40,7 +44,7 @@ inline __device__ float block_sum(float* red_smem, float sum) { // Compute the sum per warp. #pragma unroll for (int mask = WARP_SIZE / 2; mask >= 1; mask /= 2) { - sum += __shfl_xor_sync(uint32_t(-1), sum, mask); + sum += VLLM_SHFL_XOR_SYNC(sum, mask); } // Warp leaders store the data to shared memory. @@ -59,11 +63,11 @@ inline __device__ float block_sum(float* red_smem, float sum) { // Parallel reduction inside the warp. #pragma unroll for (int mask = NUM_WARPS / 2; mask >= 1; mask /= 2) { - sum += __shfl_xor_sync(uint32_t(-1), sum, mask); + sum += VLLM_SHFL_XOR_SYNC(sum, mask); } // Broadcast to other threads. - return __shfl_sync(uint32_t(-1), sum, 0); + return VLLM_SHFL_SYNC(sum, 0); } // TODO(woosuk): Merge the last two dimensions of the grid. @@ -223,7 +227,7 @@ __device__ void paged_attention_kernel( // The 0-th thread of each thread group already has its max qk value. #pragma unroll for (int mask = WARP_SIZE / 2; mask >= THREAD_GROUP_SIZE; mask /= 2) { - qk_max = fmaxf(qk_max, __shfl_xor_sync(uint32_t(-1), qk_max, mask)); + qk_max = fmaxf(qk_max, VLLM_SHFL_XOR_SYNC(qk_max, mask)); } if (lane == 0) { red_smem[warp_idx] = qk_max; @@ -235,10 +239,10 @@ __device__ void paged_attention_kernel( qk_max = lane < NUM_WARPS ? red_smem[lane] : -FLT_MAX; #pragma unroll for (int mask = NUM_WARPS / 2; mask >= 1; mask /= 2) { - qk_max = fmaxf(qk_max, __shfl_xor_sync(uint32_t(-1), qk_max, mask)); + qk_max = fmaxf(qk_max, VLLM_SHFL_XOR_SYNC(qk_max, mask)); } // Broadcast the max qk value to all threads. - qk_max = __shfl_sync(uint32_t(-1), qk_max, 0); + qk_max = VLLM_SHFL_SYNC(qk_max, 0); // Get the sum of the exp values. float exp_sum = 0.f; @@ -326,7 +330,7 @@ __device__ void paged_attention_kernel( float acc = accs[i]; #pragma unroll for (int mask = NUM_V_VECS_PER_ROW / 2; mask >= 1; mask /= 2) { - acc += __shfl_xor_sync(uint32_t(-1), acc, mask); + acc += VLLM_SHFL_XOR_SYNC(acc, mask); } accs[i] = acc; } @@ -492,7 +496,7 @@ __global__ void paged_attention_v2_reduce_kernel( // Reduce within the warp. #pragma unroll for (int mask = WARP_SIZE / 2; mask >= 1; mask /= 2) { - max_logit = fmaxf(max_logit, __shfl_xor_sync(uint32_t(-1), max_logit, mask)); + max_logit = fmaxf(max_logit, VLLM_SHFL_XOR_SYNC(max_logit, mask)); } if (lane == 0) { red_smem[warp_idx] = max_logit; @@ -502,10 +506,10 @@ __global__ void paged_attention_v2_reduce_kernel( max_logit = lane < NUM_WARPS ? red_smem[lane] : -FLT_MAX; #pragma unroll for (int mask = NUM_WARPS / 2; mask >= 1; mask /= 2) { - max_logit = fmaxf(max_logit, __shfl_xor_sync(uint32_t(-1), max_logit, mask)); + max_logit = fmaxf(max_logit, VLLM_SHFL_XOR_SYNC(max_logit, mask)); } // Broadcast the max value to all threads. - max_logit = __shfl_sync(uint32_t(-1), max_logit, 0); + max_logit = VLLM_SHFL_SYNC(max_logit, 0); // Load rescaled exp sums to shared memory. float* shared_exp_sums = reinterpret_cast(shared_mem + sizeof(float) * num_partitions); @@ -538,9 +542,10 @@ __global__ void paged_attention_v2_reduce_kernel( } // namespace vllm +#ifndef USE_ROCM #define LAUNCH_PAGED_ATTENTION_V1(HEAD_SIZE) \ cudaFuncSetAttribute( \ - vllm::paged_attention_v1_kernel, \ + (void*)vllm::paged_attention_v1_kernel, \ cudaFuncAttributeMaxDynamicSharedMemorySize, shared_mem_size); \ vllm::paged_attention_v1_kernel \ <<>>( \ @@ -557,6 +562,27 @@ __global__ void paged_attention_v2_reduce_kernel( q_stride, \ kv_block_stride, \ kv_head_stride); +#else +#define LAUNCH_PAGED_ATTENTION_V1(HEAD_SIZE) \ + hipFuncSetAttribute( \ + (void*)vllm::paged_attention_v1_kernel, \ + hipFuncAttributeMaxDynamicSharedMemorySize, shared_mem_size); \ + vllm::paged_attention_v1_kernel \ + <<>>( \ + out_ptr, \ + query_ptr, \ + key_cache_ptr, \ + value_cache_ptr, \ + head_mapping_ptr, \ + scale, \ + block_tables_ptr, \ + context_lens_ptr, \ + max_num_blocks_per_seq, \ + alibi_slopes_ptr, \ + q_stride, \ + kv_block_stride, \ + kv_head_stride); +#endif // TODO(woosuk): Tune NUM_THREADS. template< diff --git a/csrc/attention/attention_utils.cuh b/csrc/attention/attention_utils.cuh index bb7df25b14f02..ff64c4bd8f80c 100644 --- a/csrc/attention/attention_utils.cuh +++ b/csrc/attention/attention_utils.cuh @@ -17,6 +17,7 @@ */ #pragma once +#include "../cuda_compat.h" #include "attention_dtypes.h" #include @@ -39,7 +40,7 @@ inline __device__ float qk_dot_(const Vec (&q)[N], const Vec (&k)[N]) { float qk = sum(qk_vec); #pragma unroll for (int mask = THREAD_GROUP_SIZE / 2; mask >= 1; mask /= 2) { - qk += __shfl_xor_sync(uint32_t(-1), qk, mask); + qk += VLLM_SHFL_XOR_SYNC(qk, mask); } return qk; } diff --git a/csrc/attention/dtype_bfloat16.cuh b/csrc/attention/dtype_bfloat16.cuh index 5786f77f7bca6..31e0cee01d2e1 100644 --- a/csrc/attention/dtype_bfloat16.cuh +++ b/csrc/attention/dtype_bfloat16.cuh @@ -21,8 +21,17 @@ #include "attention_generic.cuh" #include "dtype_float32.cuh" -#include -#include +#ifndef USE_ROCM + #include + #include +#else + #include + #include + + typedef __hip_bfloat162 __nv_bfloat162; + typedef __hip_bfloat16 __nv_bfloat16; +#endif + #include namespace vllm { @@ -98,7 +107,11 @@ inline __device__ __nv_bfloat16 add(__nv_bfloat16 a, __nv_bfloat16 b) { #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 assert(false); #else - return a + b; + #ifndef USE_ROCM + return a + b; + #else + return __hadd(a, b); + #endif #endif } diff --git a/csrc/attention/dtype_float16.cuh b/csrc/attention/dtype_float16.cuh index e67921128d52b..1ec2009c74c87 100644 --- a/csrc/attention/dtype_float16.cuh +++ b/csrc/attention/dtype_float16.cuh @@ -21,6 +21,10 @@ #include "attention_generic.cuh" #include "dtype_float32.cuh" +#ifdef USE_ROCM + #include +#endif + #include namespace vllm { @@ -63,58 +67,114 @@ struct FloatVec { // Utility functions for type conversions. inline __device__ uint32_t h0_h0(uint16_t a) { +#ifndef USE_ROCM uint32_t b; asm volatile("mov.b32 %0, {%1, %1};" : "=r"(b) : "h"(a)); return b; +#else + union { + uint32_t u32; + uint16_t u16[2]; + } tmp; + tmp.u16[0] = a; + tmp.u16[1] = a; + return tmp.u32; +#endif } inline __device__ float half_to_float(uint16_t h) { +#ifndef USE_ROCM float f; asm volatile("cvt.f32.f16 %0, %1;\n" : "=f"(f) : "h"(h)); return f; +#else + float f; + asm volatile("v_cvt_f32_f16 %0, %1;" : "=v"(f) : "v"(h)); + return f; +#endif } inline __device__ float2 half2_to_float2(uint32_t v) { +#ifndef USE_ROCM uint16_t lo, hi; asm volatile("mov.b32 {%0, %1}, %2;\n" : "=h"(lo), "=h"(hi) : "r"(v)); return make_float2(half_to_float(lo), half_to_float(hi)); +#else + union { + uint32_t u32; + uint16_t u16[2]; + } tmp; + tmp.u32 = v; + float2 ret; + ret.x = half_to_float(tmp.u16[0]); + ret.y = half_to_float(tmp.u16[1]); + return ret; +#endif } inline __device__ uint16_t float_to_half(float f) { +#ifndef USE_ROCM union { uint32_t u32; uint16_t u16[2]; } tmp; asm volatile("cvt.rn.f16.f32 %0, %1;\n" : "=h"(tmp.u16[0]) : "f"(f)); return tmp.u16[0]; +#else + uint16_t ret; + asm volatile("v_cvt_f16_f32 %0, %1;\n" : "=v"(ret) : "v"(f)); + return ret; +#endif } inline __device__ uint32_t float2_to_half2(float2 f) { +#ifndef USE_ROCM union { uint32_t u32; uint16_t u16[2]; } tmp; -#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 - asm volatile("cvt.rn.f16x2.f32 %0, %1, %2;\n" : "=r"(tmp.u32) : "f"(f.y), "f"(f.x)); + #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 + asm volatile("cvt.rn.f16x2.f32 %0, %1, %2;\n" : "=r"(tmp.u32) : "f"(f.y), "f"(f.x)); + #else + asm volatile("cvt.rn.f16.f32 %0, %1;\n" : "=h"(tmp.u16[0]) : "f"(f.x)); + asm volatile("cvt.rn.f16.f32 %0, %1;\n" : "=h"(tmp.u16[1]) : "f"(f.y)); + #endif + return tmp.u32; #else - asm volatile("cvt.rn.f16.f32 %0, %1;\n" : "=h"(tmp.u16[0]) : "f"(f.x)); - asm volatile("cvt.rn.f16.f32 %0, %1;\n" : "=h"(tmp.u16[1]) : "f"(f.y)); -#endif + union { + uint32_t u32; + uint16_t u16[2]; + } tmp; + tmp.u16[0] = float_to_half(f.x); + tmp.u16[1] = float_to_half(f.y); return tmp.u32; +#endif } // Vector addition. inline __device__ uint16_t add(uint16_t a, uint16_t b) { +#ifndef USE_ROCM uint16_t c; asm volatile("add.f16 %0, %1, %2;\n" : "=h"(c) : "h"(a), "h"(b)); return c; +#else + uint16_t c; + asm volatile("v_add_f16 %0, %1, %2;\n" : "=v"(c) : "v"(a), "v"(b)); + return c; +#endif } inline __device__ uint32_t add(uint32_t a, uint32_t b) { +#ifndef USE_ROCM uint32_t c; asm volatile("add.f16x2 %0, %1, %2;\n" : "=r"(c) : "r"(a), "r"(b)); return c; +#else + uint32_t c; + asm volatile("v_pk_add_f16 %0, %1, %2;\n" : "=v"(c) : "v"(a), "v"(b)); + return c; +#endif } inline __device__ uint2 add(uint2 a, uint2 b) { @@ -157,16 +217,28 @@ inline __device__ Float8_ add(uint4 a, Float8_ fb) { // Vector multiplication. template<> inline __device__ uint16_t mul(uint16_t a, uint16_t b) { +#ifndef USE_ROCM uint16_t c; asm volatile("mul.f16 %0, %1, %2;\n" : "=h"(c) : "h"(a), "h"(b)); return c; +#else + uint16_t c; + asm volatile("v_mul_f16 %0, %1, %2;\n" : "=v"(c) : "v"(a), "v"(b)); + return c; +#endif } template<> inline __device__ uint32_t mul(uint32_t a, uint32_t b) { +#ifndef USE_ROCM uint32_t c; asm volatile("mul.f16x2 %0, %1, %2;\n" : "=r"(c) : "r"(a), "r"(b)); return c; +#else + uint32_t c; + asm volatile("v_pk_mul_f16 %0, %1, %2;\n" : "=v"(c) : "v"(a), "v"(b)); + return c; +#endif } template<> @@ -271,9 +343,15 @@ inline __device__ Float8_ mul(uint16_t a, uint4 b) { // Vector fused multiply-add. inline __device__ uint32_t fma(uint32_t a, uint32_t b, uint32_t c) { +#ifndef USE_ROCM uint32_t d; asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(d) : "r"(a), "r"(b), "r"(c)); return d; +#else + uint32_t d; + asm volatile("v_pk_fma_f16 %0, %1, %2, %3;\n" : "=v"(d) : "v"(a), "v"(b), "v"(c)); + return d; +#endif } inline __device__ uint32_t fma(uint16_t a, uint32_t b, uint32_t c) { @@ -441,4 +519,4 @@ inline __device__ void zero(uint16_t& dst) { dst = uint16_t(0); } -} // namespace vllm +} // namespace vllm \ No newline at end of file diff --git a/csrc/cache_kernels.cu b/csrc/cache_kernels.cu index 3ad52b1681c0c..59bacffdf4642 100644 --- a/csrc/cache_kernels.cu +++ b/csrc/cache_kernels.cu @@ -1,6 +1,7 @@ #include #include +#include "cuda_compat.h" #include "dispatch_utils.h" #include @@ -28,8 +29,8 @@ void swap_blocks( TORCH_CHECK(false, "Invalid device combination"); } - void *src_ptr = src.data_ptr(); - void *dst_ptr = dst.data_ptr(); + char *src_ptr = static_cast(src.data_ptr()); + char *dst_ptr = static_cast(dst.data_ptr()); const int64_t block_size_in_bytes = src.element_size() * src[0].numel(); const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); @@ -267,8 +268,8 @@ __global__ void gather_cached_kv_kernel( + head_offset * block_size + block_offset; - key[tgt_key_idx] = __ldg(&key_cache[src_key_idx]); - value[tgt_value_idx] = __ldg(&value_cache[src_value_idx]); + key[tgt_key_idx] = VLLM_LDG(&key_cache[src_key_idx]); + value[tgt_value_idx] = VLLM_LDG(&value_cache[src_value_idx]); } } @@ -333,8 +334,8 @@ __global__ void gather_cached_kv_kernel_optimized( src_key_indices[j] = src_key_idx; src_value_indices[j] = src_value_idx; - keys_to_store[j] = __ldg(&key_cache[src_key_idx]); - values_to_store[j] = __ldg(&value_cache[src_value_idx]); + keys_to_store[j] = VLLM_LDG(&key_cache[src_key_idx]); + values_to_store[j] = VLLM_LDG(&value_cache[src_value_idx]); } #pragma unroll diff --git a/csrc/cuda_compat.h b/csrc/cuda_compat.h new file mode 100644 index 0000000000000..fcafc7e16687c --- /dev/null +++ b/csrc/cuda_compat.h @@ -0,0 +1,19 @@ +#pragma once + +#ifndef USE_ROCM + #define VLLM_LDG(arg) __ldg(arg) +#else + #define VLLM_LDG(arg) *(arg) +#endif + +#ifndef USE_ROCM + #define VLLM_SHFL_XOR_SYNC(var, lane_mask) __shfl_xor_sync(uint32_t(-1), var, lane_mask) +#else + #define VLLM_SHFL_XOR_SYNC(var, lane_mask) __shfl_xor(var, lane_mask) +#endif + +#ifndef USE_ROCM + #define VLLM_SHFL_SYNC(var, src_lane) __shfl_sync(uint32_t(-1), var, src_lane); +#else + #define VLLM_SHFL_SYNC(var, src_lane) __shfl(var, src_lane) +#endif \ No newline at end of file diff --git a/csrc/cuda_utils_kernels.cu b/csrc/cuda_utils_kernels.cu index f1c30fe7ea99e..6c844a7f6c6ed 100644 --- a/csrc/cuda_utils_kernels.cu +++ b/csrc/cuda_utils_kernels.cu @@ -1,3 +1,6 @@ +#ifdef USE_ROCM + #include +#endif int get_device_attribute( int attribute, int device_id) diff --git a/csrc/pos_encoding_kernels.cu b/csrc/pos_encoding_kernels.cu index 0a5ec95f8c0d4..e1dc711778ffb 100644 --- a/csrc/pos_encoding_kernels.cu +++ b/csrc/pos_encoding_kernels.cu @@ -1,6 +1,7 @@ #include #include +#include "cuda_compat.h" #include "dispatch_utils.h" namespace vllm { @@ -19,14 +20,14 @@ inline __device__ void apply_rotary_embedding( // GPT-NeoX style rotary embedding. x_index = rot_offset; y_index = embed_dim + rot_offset; - cos = __ldg(cos_ptr + x_index); - sin = __ldg(sin_ptr + x_index); + cos = VLLM_LDG(cos_ptr + x_index); + sin = VLLM_LDG(sin_ptr + x_index); } else { // GPT-J style rotary embedding. x_index = 2 * rot_offset; y_index = 2 * rot_offset + 1; - cos = __ldg(cos_ptr + x_index / 2); - sin = __ldg(sin_ptr + x_index / 2); + cos = VLLM_LDG(cos_ptr + x_index / 2); + sin = VLLM_LDG(sin_ptr + x_index / 2); } const scalar_t x = arr[x_index]; diff --git a/csrc/quantization/squeezellm/quant_cuda_kernel.cu b/csrc/quantization/squeezellm/quant_cuda_kernel.cu index 1392b877397be..c2f77da15977a 100644 --- a/csrc/quantization/squeezellm/quant_cuda_kernel.cu +++ b/csrc/quantization/squeezellm/quant_cuda_kernel.cu @@ -20,9 +20,17 @@ __device__ inline unsigned int as_unsigned(int i) { // 4-bit matvec kernel (LUT-based) __global__ void NUQ4MatMulKernel( +#ifndef USE_ROCM const half2* __restrict__ vec, +#else + const __half2* __restrict__ vec, +#endif const int* __restrict__ mat, +#ifndef USE_ROCM half2* __restrict__ mul, +#else + float2* __restrict__ mul, +#endif const __half* __restrict__ lookup_table, int height, int width, @@ -35,7 +43,11 @@ __global__ void NUQ4MatMulKernel( int row = BLOCKHEIGHT4 * blockIdx.x; int col = BLOCKWIDTH * blockIdx.y + threadIdx.x; +#ifndef USE_ROCM __shared__ half2 blockvec[blockwidth2]; +#else + __shared__ __half2 blockvec[blockwidth2]; +#endif __shared__ __half deq2[16][BLOCKWIDTH]; int off = threadIdx.x; @@ -46,8 +58,13 @@ __global__ void NUQ4MatMulKernel( } __half res; +#ifndef USE_ROCM half2 res2; half2 tmp2; +#else + __half2 res2; + __half2 tmp2; +#endif int i; int k; @@ -68,48 +85,96 @@ __global__ void NUQ4MatMulKernel( while (k < blockwidth2) { tmp1 = as_unsigned(mat[i]); +#ifndef USE_ROCM res2 = {}; tmp2 = {}; +#else + res2.x = __half_as_ushort(__float2half(0)); + res2.y = __half_as_ushort(__float2half(0)); + tmp2.x = __half_as_ushort(__float2half(0)); + tmp2.y = __half_as_ushort(__float2half(0)); +#endif lut_index1 = tmp1 & 0xF; lut_index2 = (tmp1 >> 4) & 0xF; +#ifndef USE_ROCM tmp2.x = deq2[lut_index1][off]; tmp2.y = deq2[lut_index2][off]; +#else + tmp2.x = __half_as_ushort(deq2[lut_index1][off]); + tmp2.y = __half_as_ushort(deq2[lut_index2][off]); +#endif res2 = __hfma2(tmp2, blockvec[k + 0], res2); lut_index1 = (tmp1 >> 8) & 0xF; lut_index2 = (tmp1 >> 12) & 0xF; +#ifndef USE_ROCM tmp2.x = deq2[lut_index1][off]; tmp2.y = deq2[lut_index2][off]; +#else + tmp2.x = __half_as_ushort(deq2[lut_index1][off]); + tmp2.y = __half_as_ushort(deq2[lut_index2][off]); +#endif res2 = __hfma2(tmp2, blockvec[k + 1], res2); lut_index1 = (tmp1 >> 16) & 0xF; lut_index2 = (tmp1 >> 20) & 0xF; +#ifndef USE_ROCM tmp2.x = deq2[lut_index1][off]; tmp2.y = deq2[lut_index2][off]; +#else + tmp2.x = __half_as_ushort(deq2[lut_index1][off]); + tmp2.y = __half_as_ushort(deq2[lut_index2][off]); +#endif res2 = __hfma2(tmp2, blockvec[k + 2], res2); lut_index1 = (tmp1 >> 24) & 0xF; lut_index2 = (tmp1 >> 28) & 0xF; +#ifndef USE_ROCM tmp2.x = deq2[lut_index1][off]; tmp2.y = deq2[lut_index2][off]; +#else + tmp2.x = __half_as_ushort(deq2[lut_index1][off]); + tmp2.y = __half_as_ushort(deq2[lut_index2][off]); +#endif res2 = __hfma2(tmp2, blockvec[k + 3], res2); +#ifndef USE_ROCM res = __hadd(__hadd(res2.x, res2.y), res); +#else + res = __hadd(__hadd(__ushort_as_half(res2.x), __ushort_as_half(res2.y)), res); +#endif i += width; k += 4; } // col%2 -> only set one of the two values +#ifndef USE_ROCM half2 res3 = {}; if (col % 2 == 0) { res3.x = res; } else { res3.y = res; } +#else + __half2 res3; + res3.x = __half_as_ushort(__float2half(0)); + res3.y = __half_as_ushort(__float2half(0)); + if (col % 2 == 0) { + res3.x = __half_as_ushort(res); + } else { + res3.y = __half_as_ushort(res); + } +#endif +#ifndef USE_ROCM atomicAdd(&mul[b * width / 2 + col / 2], res3); +#else + int tmp_addr = b * width / 2 + col / 2; + atomicAdd(&(mul[tmp_addr].x), __half2float(__ushort_as_half(res3.x))); + atomicAdd(&(mul[tmp_addr].y), __half2float(__ushort_as_half(res3.y))); +#endif } } @@ -136,13 +201,22 @@ void squeezellm_gemm( dim3 threads(BLOCKWIDTH); vllm::squeezellm::NUQ4MatMulKernel<<>>( +#ifndef USE_ROCM (half2*) vec.data(), +#else + (__half2*) vec.data_ptr(), +#endif mat.data_ptr(), +#ifndef USE_ROCM (half2*) mul.data(), (__half*) lookup_table.data(), +#else + (float2*) mul.data_ptr(), + (__half*) lookup_table.data_ptr(), +#endif height, width, batch, vec_height ); } #undef BLOCKWIDTH -#undef BLOCKHEIGHT4 +#undef BLOCKHEIGHT4 \ No newline at end of file diff --git a/csrc/reduction_utils.cuh b/csrc/reduction_utils.cuh index bc35aa0424b56..59ac565c585e8 100644 --- a/csrc/reduction_utils.cuh +++ b/csrc/reduction_utils.cuh @@ -16,14 +16,14 @@ * limitations under the License. */ #pragma once - +#include "cuda_compat.h" namespace vllm { template __inline__ __device__ T warpReduceSum(T val) { #pragma unroll for (int mask = 16; mask > 0; mask >>= 1) - val += __shfl_xor_sync(0xffffffff, val, mask, 32); + val += VLLM_SHFL_XOR_SYNC(val, mask); return val; } diff --git a/setup.py b/setup.py index 36f4913435628..0e28b93602772 100644 --- a/setup.py +++ b/setup.py @@ -8,20 +8,31 @@ from packaging.version import parse, Version import setuptools import torch -from torch.utils.cpp_extension import BuildExtension, CUDAExtension, CUDA_HOME +from torch.utils.cpp_extension import BuildExtension, CUDAExtension, CUDA_HOME, ROCM_HOME ROOT_DIR = os.path.dirname(__file__) MAIN_CUDA_VERSION = "12.1" # Supported NVIDIA GPU architectures. -SUPPORTED_ARCHS = {"7.0", "7.5", "8.0", "8.6", "8.9", "9.0"} +NVIDIA_SUPPORTED_ARCHS = {"7.0", "7.5", "8.0", "8.6", "8.9", "9.0"} +ROCM_SUPPORTED_ARCHS = {"gfx90a", "gfx908", "gfx906", "gfx1030","gfx1100"} +SUPPORTED_ARCHS = NVIDIA_SUPPORTED_ARCHS.union(ROCM_SUPPORTED_ARCHS) # Compiler flags. CXX_FLAGS = ["-g", "-O2", "-std=c++17"] # TODO(woosuk): Should we use -O3? NVCC_FLAGS = ["-O2", "-std=c++17"] +if torch.version.hip: + if ROCM_HOME is not None: + NVCC_FLAGS += [f"-DUSE_ROCM"] + +if not torch.version.hip: + if CUDA_HOME is None: + raise RuntimeError( + "Cannot find CUDA_HOME. CUDA must be available to build the package.") + ABI = 1 if torch._C._GLIBCXX_USE_CXX11_ABI else 0 CXX_FLAGS += [f"-D_GLIBCXX_USE_CXX11_ABI={ABI}"] NVCC_FLAGS += [f"-D_GLIBCXX_USE_CXX11_ABI={ABI}"] @@ -61,20 +72,20 @@ def get_torch_arch_list() -> Set[str]: return set() # Filter out the invalid architectures and print a warning. - valid_archs = SUPPORTED_ARCHS.union({s + "+PTX" for s in SUPPORTED_ARCHS}) + valid_archs = SUPPORTED_ARCHS.union({s + "+PTX" for s in NVIDIA_SUPPORTED_ARCHS}) arch_list = torch_arch_list.intersection(valid_archs) # If none of the specified architectures are valid, raise an error. if not arch_list: raise RuntimeError( - "None of the CUDA architectures in `TORCH_CUDA_ARCH_LIST` env " + "None of the CUDA/ROCM architectures in `TORCH_CUDA_ARCH_LIST` env " f"variable ({env_arch_list}) is supported. " - f"Supported CUDA architectures are: {valid_archs}.") + f"Supported CUDA/ROCM architectures are: {valid_archs}.") invalid_arch_list = torch_arch_list - valid_archs if invalid_arch_list: warnings.warn( - f"Unsupported CUDA architectures ({invalid_arch_list}) are " + f"Unsupported CUDA/ROCM architectures ({invalid_arch_list}) are " "excluded from the `TORCH_CUDA_ARCH_LIST` env variable " - f"({env_arch_list}). Supported CUDA architectures are: " + f"({env_arch_list}). Supported CUDA/ROCM architectures are: " f"{valid_archs}.", stacklevel=2) return arch_list @@ -82,64 +93,65 @@ def get_torch_arch_list() -> Set[str]: # First, check the TORCH_CUDA_ARCH_LIST environment variable. compute_capabilities = get_torch_arch_list() -if not compute_capabilities: - # If TORCH_CUDA_ARCH_LIST is not defined or empty, target all available - # GPUs on the current machine. - device_count = torch.cuda.device_count() - for i in range(device_count): - major, minor = torch.cuda.get_device_capability(i) - if major < 7: - raise RuntimeError( - "GPUs with compute capability below 7.0 are not supported.") - compute_capabilities.add(f"{major}.{minor}") - -nvcc_cuda_version = get_nvcc_cuda_version(CUDA_HOME) -if not compute_capabilities: - # If no GPU is specified nor available, add all supported architectures - # based on the NVCC CUDA version. - compute_capabilities = SUPPORTED_ARCHS.copy() - if nvcc_cuda_version < Version("11.1"): - compute_capabilities.remove("8.6") - if nvcc_cuda_version < Version("11.8"): - compute_capabilities.remove("8.9") - compute_capabilities.remove("9.0") - -# Validate the NVCC CUDA version. -if nvcc_cuda_version < Version("11.0"): - raise RuntimeError("CUDA 11.0 or higher is required to build the package.") -if (nvcc_cuda_version < Version("11.1") - and any(cc.startswith("8.6") for cc in compute_capabilities)): - raise RuntimeError( - "CUDA 11.1 or higher is required for compute capability 8.6.") -if nvcc_cuda_version < Version("11.8"): - if any(cc.startswith("8.9") for cc in compute_capabilities): - # CUDA 11.8 is required to generate the code targeting compute capability 8.9. - # However, GPUs with compute capability 8.9 can also run the code generated by - # the previous versions of CUDA 11 and targeting compute capability 8.0. - # Therefore, if CUDA 11.8 is not available, we target compute capability 8.0 - # instead of 8.9. - warnings.warn( - "CUDA 11.8 or higher is required for compute capability 8.9. " - "Targeting compute capability 8.0 instead.", - stacklevel=2) - compute_capabilities = set(cc for cc in compute_capabilities - if not cc.startswith("8.9")) - compute_capabilities.add("8.0+PTX") - if any(cc.startswith("9.0") for cc in compute_capabilities): +if not torch.version.hip: + if not compute_capabilities: + # If TORCH_CUDA_ARCH_LIST is not defined or empty, target all available + # GPUs on the current machine. + device_count = torch.cuda.device_count() + for i in range(device_count): + major, minor = torch.cuda.get_device_capability(i) + if major < 7: + raise RuntimeError( + "GPUs with compute capability below 7.0 are not supported.") + compute_capabilities.add(f"{major}.{minor}") + +if not torch.version.hip: + nvcc_cuda_version = get_nvcc_cuda_version(CUDA_HOME) + if not compute_capabilities: + # If no GPU is specified nor available, add all supported architectures + # based on the NVCC CUDA version. + compute_capabilities = SUPPORTED_ARCHS.copy() + if nvcc_cuda_version < Version("11.1"): + compute_capabilities.remove("8.6") + if nvcc_cuda_version < Version("11.8"): + compute_capabilities.remove("8.9") + compute_capabilities.remove("9.0") + # Validate the NVCC CUDA version. + if nvcc_cuda_version < Version("11.0"): + raise RuntimeError("CUDA 11.0 or higher is required to build the package.") + if (nvcc_cuda_version < Version("11.1") + and any(cc.startswith("8.6") for cc in compute_capabilities)): raise RuntimeError( - "CUDA 11.8 or higher is required for compute capability 9.0.") + "CUDA 11.1 or higher is required for compute capability 8.6.") + if nvcc_cuda_version < Version("11.8"): + if any(cc.startswith("8.9") for cc in compute_capabilities): + # CUDA 11.8 is required to generate the code targeting compute capability 8.9. + # However, GPUs with compute capability 8.9 can also run the code generated by + # the previous versions of CUDA 11 and targeting compute capability 8.0. + # Therefore, if CUDA 11.8 is not available, we target compute capability 8.0 + # instead of 8.9. + warnings.warn( + "CUDA 11.8 or higher is required for compute capability 8.9. " + "Targeting compute capability 8.0 instead.", + stacklevel=2) + compute_capabilities = set(cc for cc in compute_capabilities + if not cc.startswith("8.9")) + compute_capabilities.add("8.0+PTX") + if any(cc.startswith("9.0") for cc in compute_capabilities): + raise RuntimeError( + "CUDA 11.8 or higher is required for compute capability 9.0.") -# Add target compute capabilities to NVCC flags. -for capability in compute_capabilities: - num = capability[0] + capability[2] - NVCC_FLAGS += ["-gencode", f"arch=compute_{num},code=sm_{num}"] - if capability.endswith("+PTX"): - NVCC_FLAGS += ["-gencode", f"arch=compute_{num},code=compute_{num}"] + # Add target compute capabilities to NVCC flags. + for capability in compute_capabilities: + num = capability[0] + capability[2] + NVCC_FLAGS += ["-gencode", f"arch=compute_{num},code=sm_{num}"] + if capability.endswith("+PTX"): + NVCC_FLAGS += ["-gencode", f"arch=compute_{num},code=compute_{num}"] -# Use NVCC threads to parallelize the build. -if nvcc_cuda_version >= Version("11.2"): - num_threads = min(os.cpu_count(), 8) - NVCC_FLAGS += ["--threads", str(num_threads)] + # Use NVCC threads to parallelize the build. + if nvcc_cuda_version >= Version("11.2"): + num_threads = min(os.cpu_count(), 8) + NVCC_FLAGS += ["--threads", str(num_threads)] ext_modules = [] @@ -199,18 +211,31 @@ def get_torch_arch_list() -> Set[str]: ext_modules.append(activation_extension) # Quantization kernels. -quantization_extension = CUDAExtension( - name="vllm.quantization_ops", - sources=[ - "csrc/quantization.cpp", - "csrc/quantization/awq/gemm_kernels.cu", - "csrc/quantization/squeezellm/quant_cuda_kernel.cu", - ], - extra_compile_args={ - "cxx": CXX_FLAGS, - "nvcc": NVCC_FLAGS, - }, -) +if not torch.version.hip: + quantization_extension = CUDAExtension( + name="vllm.quantization_ops", + sources=[ + "csrc/quantization.cpp", + "csrc/quantization/awq/gemm_kernels.cu", + "csrc/quantization/squeezellm/quant_cuda_kernel.cu", + ], + extra_compile_args={ + "cxx": CXX_FLAGS, + "nvcc": NVCC_FLAGS, + }, + ) +else: + quantization_extension = CUDAExtension( + name="vllm.quantization_ops", + sources=[ + "csrc/quantization.cpp", + "csrc/quantization/squeezellm/quant_cuda_kernel.cu", + ], + extra_compile_args={ + "cxx": CXX_FLAGS, + "nvcc": NVCC_FLAGS, + }, + ) ext_modules.append(quantization_extension) # Misc. CUDA utils. @@ -244,10 +269,10 @@ def find_version(filepath: str) -> str: def get_vllm_version() -> str: version = find_version(get_path("vllm", "__init__.py")) - cuda_version = str(nvcc_cuda_version) - if cuda_version != MAIN_CUDA_VERSION: - cuda_version_str = cuda_version.replace(".", "")[:3] - version += f"+cu{cuda_version_str}" + # cuda_version = str(nvcc_cuda_version) + # if cuda_version != MAIN_CUDA_VERSION: + # cuda_version_str = cuda_version.replace(".", "")[:3] + # version += f"+cu{cuda_version_str}" return version From 1477ba0f9a5ba0a7b3d9291db1c358808ffc2060 Mon Sep 17 00:00:00 2001 From: Antoni Baum Date: Mon, 27 Nov 2023 16:45:02 -0800 Subject: [PATCH 23/94] Add multi-LoRA support --------- Co-authored-by: Chen Shen Co-authored-by: Shreyas Krishnaswamy Co-authored-by: Avnish Narayan --- csrc/punica/LICENSE | 217 +++ csrc/punica/bgmv/bgmv_all.cu | 21 + csrc/punica/bgmv/bgmv_config.h | 53 + csrc/punica/bgmv/bgmv_impl.cuh | 294 ++++ csrc/punica/bgmv/vec_dtypes.cuh | 1324 +++++++++++++++++ csrc/punica/punica_ops.cc | 563 +++++++ setup.py | 60 +- tests/lora/__init__.py | 0 tests/lora/conftest.py | 139 ++ tests/lora/test_layers.py | 697 +++++++++ tests/lora/test_llama.py | 141 ++ tests/lora/test_lora.py | 224 +++ tests/lora/test_lora_manager.py | 473 ++++++ tests/lora/test_punica.py | 196 +++ tests/lora/test_tokenizer.py | 69 + tests/lora/test_utils.py | 172 +++ tests/lora/test_worker.py | 56 + tests/lora/utils.py | 88 ++ vllm/config.py | 31 + vllm/core/scheduler.py | 21 +- vllm/engine/arg_utils.py | 42 +- vllm/engine/async_llm_engine.py | 80 +- vllm/engine/llm_engine.py | 87 +- vllm/entrypoints/llm.py | 16 +- vllm/lora/__init__.py | 0 vllm/lora/layers.py | 1002 +++++++++++++ vllm/lora/lora.py | 120 ++ vllm/lora/models.py | 666 +++++++++ vllm/lora/punica.py | 173 +++ vllm/lora/request.py | 19 + vllm/lora/utils.py | 39 + vllm/lora/worker_manager.py | 266 ++++ vllm/model_executor/layers/sampler.py | 33 +- .../layers/vocab_parallel_embedding.py | 14 +- vllm/model_executor/model_loader.py | 14 +- vllm/model_executor/models/llama.py | 23 +- vllm/model_executor/models/mistral.py | 25 +- vllm/outputs.py | 19 +- vllm/sequence.py | 14 + vllm/transformers_utils/tokenizer.py | 82 + vllm/utils.py | 90 ++ vllm/worker/worker.py | 134 +- 42 files changed, 7713 insertions(+), 84 deletions(-) create mode 100644 csrc/punica/LICENSE create mode 100644 csrc/punica/bgmv/bgmv_all.cu create mode 100644 csrc/punica/bgmv/bgmv_config.h create mode 100644 csrc/punica/bgmv/bgmv_impl.cuh create mode 100644 csrc/punica/bgmv/vec_dtypes.cuh create mode 100644 csrc/punica/punica_ops.cc create mode 100644 tests/lora/__init__.py create mode 100644 tests/lora/conftest.py create mode 100644 tests/lora/test_layers.py create mode 100644 tests/lora/test_llama.py create mode 100644 tests/lora/test_lora.py create mode 100644 tests/lora/test_lora_manager.py create mode 100644 tests/lora/test_punica.py create mode 100644 tests/lora/test_tokenizer.py create mode 100644 tests/lora/test_utils.py create mode 100644 tests/lora/test_worker.py create mode 100644 tests/lora/utils.py create mode 100644 vllm/lora/__init__.py create mode 100644 vllm/lora/layers.py create mode 100644 vllm/lora/lora.py create mode 100644 vllm/lora/models.py create mode 100644 vllm/lora/punica.py create mode 100644 vllm/lora/request.py create mode 100644 vllm/lora/utils.py create mode 100644 vllm/lora/worker_manager.py diff --git a/csrc/punica/LICENSE b/csrc/punica/LICENSE new file mode 100644 index 0000000000000..a46e2cdcadf7d --- /dev/null +++ b/csrc/punica/LICENSE @@ -0,0 +1,217 @@ +Contains code from https://github.com/punica-ai/punica + + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. Limitation of Liability. In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. + + END OF TERMS AND CONDITIONS + + APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "{}" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. + + Copyright {yyyy} {name of copyright owner} + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + +------------------------------------------------------------------------------------ + +This product bundles various third-party components under other open source licenses. +This section summarizes those components and their licenses. See licenses/ +for text of these licenses. + + +Apache-2.0 +* third_party/nvbench (with LLVM exception) +* third_party/flashinfer + +BSD-3-Clause: +* third_party/cutlass \ No newline at end of file diff --git a/csrc/punica/bgmv/bgmv_all.cu b/csrc/punica/bgmv/bgmv_all.cu new file mode 100644 index 0000000000000..bc86416701f13 --- /dev/null +++ b/csrc/punica/bgmv/bgmv_all.cu @@ -0,0 +1,21 @@ +#include "bgmv_config.h" +#include "bgmv_impl.cuh" + +FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, nv_half, nv_half, nv_half) +FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, nv_half, nv_half, nv_bfloat16) +FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, nv_bfloat16, nv_bfloat16, nv_bfloat16) +FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, nv_bfloat16, nv_bfloat16, nv_half) +FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, nv_half, nv_bfloat16, nv_half) +FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, nv_bfloat16, nv_half, nv_half) +FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, nv_half, nv_bfloat16, nv_bfloat16) +FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, nv_bfloat16, nv_half, nv_bfloat16) +FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, float, nv_half, nv_half) +FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, float, nv_half, nv_bfloat16) +FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, float, nv_bfloat16, nv_bfloat16) +FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, float, nv_bfloat16, nv_half) +FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, nv_half, float, nv_half) +FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, nv_half, float, nv_bfloat16) +FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, nv_bfloat16, float, nv_bfloat16) +FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, nv_bfloat16, float, nv_half) +FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, float, float, nv_half) +FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, float, float, nv_bfloat16) \ No newline at end of file diff --git a/csrc/punica/bgmv/bgmv_config.h b/csrc/punica/bgmv/bgmv_config.h new file mode 100644 index 0000000000000..3fd56b685be13 --- /dev/null +++ b/csrc/punica/bgmv/bgmv_config.h @@ -0,0 +1,53 @@ +#pragma once + +template +void bgmv_kernel(out_T *__restrict__ Y, const in_T *__restrict__ X, + const W_T *__restrict__ W, + const int64_t *__restrict__ indicies, int64_t y_offset, + int64_t full_y_size, int64_t batch_size, int64_t num_layers, + int64_t layer_idx, float scale); + +// clang-format off + +#define FOR_BGMV_WIDE(f, in_T, out_T, W_T, narrow) \ + f(in_T, out_T, W_T, narrow, 128) \ + f(in_T, out_T, W_T, narrow, 256) \ + f(in_T, out_T, W_T, narrow, 512) \ + f(in_T, out_T, W_T, narrow, 1024) \ + f(in_T, out_T, W_T, narrow, 1280) \ + f(in_T, out_T, W_T, narrow, 1728) \ + f(in_T, out_T, W_T, narrow, 1792) \ + f(in_T, out_T, W_T, narrow, 2048) \ + f(in_T, out_T, W_T, narrow, 2560) \ + f(in_T, out_T, W_T, narrow, 2752) \ + f(in_T, out_T, W_T, narrow, 3072) \ + f(in_T, out_T, W_T, narrow, 3456) \ + f(in_T, out_T, W_T, narrow, 3584) \ + f(in_T, out_T, W_T, narrow, 4096) \ + f(in_T, out_T, W_T, narrow, 5120) \ + f(in_T, out_T, W_T, narrow, 5504) \ + f(in_T, out_T, W_T, narrow, 6912) \ + f(in_T, out_T, W_T, narrow, 7168) \ + f(in_T, out_T, W_T, narrow, 8192) \ + f(in_T, out_T, W_T, narrow, 9216) \ + f(in_T, out_T, W_T, narrow, 10240) \ + f(in_T, out_T, W_T, narrow, 11008) \ + f(in_T, out_T, W_T, narrow, 12288) \ + f(in_T, out_T, W_T, narrow, 13824) \ + f(in_T, out_T, W_T, narrow, 14336) \ + f(in_T, out_T, W_T, narrow, 16384) \ + f(in_T, out_T, W_T, narrow, 20480) \ + f(in_T, out_T, W_T, narrow, 28672) \ + f(in_T, out_T, W_T, narrow, 32000) \ + f(in_T, out_T, W_T, narrow, 32256) \ + f(in_T, out_T, W_T, narrow, 36864) \ + f(in_T, out_T, W_T, narrow, 49152) \ + +#define FOR_BGMV_WIDE_NARROW(f, in_T, out_T, W_T) \ + FOR_BGMV_WIDE(f, in_T, out_T, W_T, 8) \ + FOR_BGMV_WIDE(f, in_T, out_T, W_T, 16) \ + FOR_BGMV_WIDE(f, in_T, out_T, W_T, 32) \ + FOR_BGMV_WIDE(f, in_T, out_T, W_T, 64) + +// clang-format on diff --git a/csrc/punica/bgmv/bgmv_impl.cuh b/csrc/punica/bgmv/bgmv_impl.cuh new file mode 100644 index 0000000000000..995de26e8bada --- /dev/null +++ b/csrc/punica/bgmv/bgmv_impl.cuh @@ -0,0 +1,294 @@ +#pragma once + +#include +#include +#include +#include +#include +#include + +#include "vec_dtypes.cuh" + +namespace cg = cooperative_groups; + +// nthrs = (32, 4) +template +__global__ void +bgmv_shrink_kernel(out_T *__restrict__ Y, const in_T *__restrict__ X, + const W_T *__restrict__ W, + const int64_t *__restrict__ indicies, int64_t y_offset, + int64_t full_y_size, int64_t num_layers, int64_t layer_idx, + float scale) { + size_t batch_idx = blockIdx.y; + int64_t idx = indicies[batch_idx] * num_layers + layer_idx; + if (idx < 0) { + return; + } + + auto block = cg::this_thread_block(); + size_t j = blockIdx.x; + constexpr size_t num_pipeline_stages = 2; + constexpr size_t tile_size = tx * ty * vec_size; + __shared__ W_T W_shared[num_pipeline_stages * tile_size]; + __shared__ in_T X_shared[num_pipeline_stages * tile_size]; + __shared__ float y_warpwise[ty]; + + size_t W_shared_offset[num_pipeline_stages] = {0U, 1U * tile_size}; + size_t X_shared_offset[num_pipeline_stages] = {0U, 1U * tile_size}; + auto pipe = cuda::make_pipeline(); + + // pipeline load W/X and compute WX; + pipe.producer_acquire(); + cuda::memcpy_async(W_shared + (threadIdx.y * tx + threadIdx.x) * vec_size, + W + (idx * feat_out + j) * feat_in + + (threadIdx.y * tx + threadIdx.x) * vec_size, + cuda::aligned_size_t(W_copy_size), pipe); + cuda::memcpy_async(X_shared + (threadIdx.y * tx + threadIdx.x) * vec_size, + X + (batch_idx * feat_in) + + (threadIdx.y * tx + threadIdx.x) * vec_size, + cuda::aligned_size_t(X_copy_size), pipe); + pipe.producer_commit(); + size_t copy_idx, compute_idx; + float y = 0.f; + vec_t x_vec; + vec_t w_vec; + size_t tile_idx; + +#pragma unroll + for (tile_idx = 1; tile_idx < (feat_in + tile_size - 1) / tile_size; + ++tile_idx) { + copy_idx = tile_idx % num_pipeline_stages; + // pipeline stage: async copy W fragment + pipe.producer_acquire(); + if (tile_idx * tile_size + threadIdx.y * tx * vec_size < feat_in) { + cuda::memcpy_async(W_shared + W_shared_offset[copy_idx] + + (threadIdx.y * tx + threadIdx.x) * vec_size, + W + (idx * feat_out + j) * feat_in + + tile_idx * tile_size + + (threadIdx.y * tx + threadIdx.x) * vec_size, + cuda::aligned_size_t(W_copy_size), pipe); + cuda::memcpy_async(X_shared + X_shared_offset[copy_idx] + + (threadIdx.y * tx + threadIdx.x) * vec_size, + X + (batch_idx * feat_in) + tile_idx * tile_size + + (threadIdx.y * tx + threadIdx.x) * vec_size, + cuda::aligned_size_t(X_copy_size), pipe); + } + pipe.producer_commit(); + + compute_idx = (tile_idx - 1) % num_pipeline_stages; + // pipeline stage: compute WX + pipe.consumer_wait(); + block.sync(); + x_vec.load(X_shared + X_shared_offset[compute_idx] + + (threadIdx.y * tx + threadIdx.x) * vec_size); + w_vec.load(W_shared + W_shared_offset[compute_idx] + + (threadIdx.y * tx + threadIdx.x) * vec_size); + float sum = 0.f; +#pragma unroll + for (size_t i = 0; i < vec_size; ++i) { + sum += float(w_vec[i]) * float(x_vec[i]) * scale; + } +#pragma unroll + for (size_t offset = tx / 2; offset > 0; offset /= 2) { + sum += __shfl_down_sync(0xffffffff, sum, offset); + } + y_warpwise[threadIdx.y] = sum; + block.sync(); +#pragma unroll + for (size_t i = 0; i < ty; ++i) { + y += y_warpwise[i]; + } + + block.sync(); + pipe.consumer_release(); + } + + compute_idx = (tile_idx - 1) % num_pipeline_stages; + // final pipeline stage + pipe.consumer_wait(); + block.sync(); + x_vec.load(X_shared + X_shared_offset[compute_idx] + + (threadIdx.y * tx + threadIdx.x) * vec_size); + w_vec.load(W_shared + W_shared_offset[compute_idx] + + (threadIdx.y * tx + threadIdx.x) * vec_size); + float sum = 0.f; +#pragma unroll + for (size_t i = 0; i < vec_size; ++i) { + sum += float(w_vec[i]) * float(x_vec[i]) * scale; + } +#pragma unroll + for (size_t offset = tx / 2; offset > 0; offset /= 2) { + sum += __shfl_down_sync(0xffffffff, sum, offset); + } + y_warpwise[threadIdx.y] = + ((tile_idx - 1) * tile_size + threadIdx.y * tx * vec_size < feat_in) + ? sum + : 0.f; + block.sync(); +#pragma unroll + for (size_t i = 0; i < ty; ++i) { + y += y_warpwise[i]; + } + + block.sync(); + pipe.consumer_release(); + + // write Y; + if (block.thread_rank() == 0) { + Y[batch_idx * full_y_size + y_offset + j] += static_cast(y); + } +} + +// nthrs = (2, 16, 4) +template +__global__ void +bgmv_expand_kernel(out_T *__restrict__ Y, const in_T *__restrict__ X, + const W_T *__restrict__ W, + const int64_t *__restrict__ indicies, int64_t y_offset, + int64_t full_y_size, int64_t num_layers, int64_t layer_idx, + float scale) { + size_t batch_idx = blockIdx.y; + int64_t idx = indicies[batch_idx] * num_layers + layer_idx; + + if (idx < 0) { + return; + } + + auto block = cg::this_thread_block(); + size_t tile_idx = blockIdx.x; + + // load X; + vec_t x_vec; + x_vec.load(X + batch_idx * feat_in + threadIdx.x * vec_size); + + // load W; + vec_t w_vec; + w_vec.load(W + (idx * feat_out + tile_idx * tz * ty) * feat_in + + block.thread_rank() * vec_size); + + float sum = 0.f; +#pragma unroll + for (size_t i = 0; i < vec_size; ++i) { + sum += float(w_vec[i]) * float(x_vec[i]) * scale; + } + + cg::thread_block_tile g = cg::tiled_partition(block); +#pragma unroll + for (size_t offset = tx / 2; offset > 0; offset /= 2) { + sum += g.shfl_down(sum, offset); + } + sum = g.shfl(sum, 0); + + if (threadIdx.x == 0) { + Y[batch_idx * full_y_size + y_offset + tile_idx * (tz * ty) + + threadIdx.z * ty + threadIdx.y] += static_cast(sum); + } +} + +template +void bgmv_kernel(out_T *__restrict__ Y, const in_T *__restrict__ X, + const W_T *__restrict__ W, + const int64_t *__restrict__ indicies, int64_t y_offset, + int64_t full_y_size, int64_t batch_size, int64_t num_layers, + int64_t layer_idx, float scale) { + constexpr size_t vec_size = 8; + constexpr int tz = 4; + const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + + if constexpr (feat_in < feat_out) { + static_assert(feat_in % vec_size == 0); + constexpr int tx = feat_in / vec_size; + + static_assert((32 % tx == 0 && feat_out % (32 / tx * tz) == 0) || + (16 % tx == 0 && feat_out % (16 / tx * tz) == 0) || + (8 % tx == 0 && feat_out % (8 / tx * tz) == 0)); + + if constexpr (32 % tx == 0 && feat_out % (32 / tx * tz) == 0) { + constexpr int ty = 32 / tx; + dim3 nblks(feat_out / (ty * tz), batch_size); + dim3 nthrs(tx, ty, tz); + + bgmv_expand_kernel + <<>>(Y, X, W, indicies, y_offset, + full_y_size, num_layers, layer_idx, + scale); + } else if (16 % tx == 0 && feat_out % (16 / tx * tz) == 0) { + constexpr int ty = 16 / tx; + dim3 nblks(feat_out / (ty * tz), batch_size); + dim3 nthrs(tx, ty, tz); + + bgmv_expand_kernel + <<>>(Y, X, W, indicies, y_offset, + full_y_size, num_layers, layer_idx, + scale); + } else { + constexpr int ty = 8 / tx; + dim3 nblks(feat_out / (ty * tz), batch_size); + dim3 nthrs(tx, ty, tz); + + bgmv_expand_kernel + <<>>(Y, X, W, indicies, y_offset, + full_y_size, num_layers, layer_idx, + scale); + } + } else { + static_assert(feat_in % (vec_size * 32) == 0 || + feat_in % (vec_size * 16) == 0 || + feat_in % (vec_size * 8) == 0); + + if constexpr (feat_in % (vec_size * 32) == 0) { + constexpr int tx = 32; + constexpr int ty = 4; + + dim3 nblks(feat_out, batch_size); + dim3 nthrs(tx, ty); + + bgmv_shrink_kernel + <<>>(Y, X, W, indicies, y_offset, + full_y_size, num_layers, layer_idx, + scale); + } else if constexpr (feat_in % (vec_size / 2 * 32) == 0) { + constexpr int tx = 32; + constexpr int ty = 4; + + dim3 nblks(feat_out, batch_size); + dim3 nthrs(tx, ty); + + bgmv_shrink_kernel + <<>>(Y, X, W, indicies, y_offset, + full_y_size, num_layers, layer_idx, + scale); + } else if constexpr (feat_in % (vec_size / 2 * 16) == 0) { + constexpr int tx = 16; + constexpr int ty = 4; + + dim3 nblks(feat_out, batch_size); + dim3 nthrs(tx, ty); + + bgmv_shrink_kernel + <<>>(Y, X, W, indicies, y_offset, + full_y_size, num_layers, layer_idx, + scale); + } + } +} + +#define INST_BGMV(feat_in, feat_out, in_T, out_T, W_T) \ + template void bgmv_kernel( \ + out_T * __restrict__ Y, const in_T *__restrict__ X, \ + const W_T *__restrict__ W, const int64_t *__restrict__ indicies, \ + int64_t y_offset, int64_t full_y_size, int64_t batch_size, \ + int64_t num_layers, int64_t layer_idx, float scale); + +#define INST_BGMV_TWOSIDE(in_T, out_T, W_T, narrow, wide) \ + INST_BGMV(narrow, wide, in_T, out_T, W_T) \ + INST_BGMV(wide, narrow, in_T, out_T, W_T) diff --git a/csrc/punica/bgmv/vec_dtypes.cuh b/csrc/punica/bgmv/vec_dtypes.cuh new file mode 100644 index 0000000000000..cf00d869cf635 --- /dev/null +++ b/csrc/punica/bgmv/vec_dtypes.cuh @@ -0,0 +1,1324 @@ +#ifndef VEC_DTYPES_CUH_ +#define VEC_DTYPES_CUH_ + +#include +#include +#ifdef FLASHINFER_USE_FP8 +#include +#endif +#include + +#include + +#define FLASHINFER_INLINE \ + inline __attribute__((always_inline)) __device__ __host__ + +template +struct vec_t { + FLASHINFER_INLINE float_t &operator[](size_t i); + FLASHINFER_INLINE const float_t &operator[](size_t i) const; + FLASHINFER_INLINE void fill(float_t val); + FLASHINFER_INLINE void load(const float_t *ptr); + FLASHINFER_INLINE void store(float_t *ptr) const; + template + FLASHINFER_INLINE void cast_from(const vec_t &src); + template + FLASHINFER_INLINE void cast_load(const T *ptr); + template + FLASHINFER_INLINE void cast_store(T *ptr) const; + FLASHINFER_INLINE static void memcpy(float_t *dst, const float_t *src); +}; + +template +FLASHINFER_INLINE void cast_from_impl(const vec_t &src, + vec_t &dst) { +#pragma unroll + for (size_t i = 0; i < vec_size; ++i) { + dst[i] = tgt_float_t(src[i]); + } +} + +template +FLASHINFER_INLINE void cast_load_impl(const src_float_t *src_ptr, + vec_t &dst) { + if constexpr (std::is_same::value) { + dst.load(src_ptr); + } else { + vec_t tmp; + tmp.load(src_ptr); + dst.cast_from(tmp); + } +} + +template +FLASHINFER_INLINE void cast_store_impl(const vec_t &src, + tgt_float_t *dst_ptr) { + if constexpr (std::is_same::value) { + src.store(dst_ptr); + } else { + vec_t tmp; + tmp.cast_from(src); + tmp.store(dst_ptr); + } +} + +#ifdef FLASHINFER_USE_FP8 +/******************* vec_t<__nv_fp8_e4m3> *******************/ + +// __nv_fp8_e4m3 x 1 +template <> +struct vec_t<__nv_fp8_e4m3, 1> { + __nv_fp8_e4m3 data; + + FLASHINFER_INLINE __nv_fp8_e4m3 &operator[](size_t i) { + return ((__nv_fp8_e4m3 *)(&data))[i]; + } + FLASHINFER_INLINE const __nv_fp8_e4m3 &operator[](size_t i) const { + return ((const __nv_fp8_e4m3 *)(&data))[i]; + } + FLASHINFER_INLINE void fill(__nv_fp8_e4m3 val); + FLASHINFER_INLINE void load(const __nv_fp8_e4m3 *ptr); + FLASHINFER_INLINE void store(__nv_fp8_e4m3 *ptr) const; + template + FLASHINFER_INLINE void cast_from(const vec_t &src) { + cast_from_impl(src, *this); + } + template + FLASHINFER_INLINE void cast_load(const T *ptr) { + cast_load_impl(ptr, *this); + } + template + FLASHINFER_INLINE void cast_store(T *ptr) const { + cast_store_impl(*this, ptr); + } + + FLASHINFER_INLINE static void memcpy(__nv_fp8_e4m3 *dst, + const __nv_fp8_e4m3 *src); +}; + +FLASHINFER_INLINE void vec_t<__nv_fp8_e4m3, 1>::fill(__nv_fp8_e4m3 val) { + data = val; +} + +FLASHINFER_INLINE void vec_t<__nv_fp8_e4m3, 1>::load(const __nv_fp8_e4m3 *ptr) { + data = *ptr; +} + +FLASHINFER_INLINE void vec_t<__nv_fp8_e4m3, 1>::store( + __nv_fp8_e4m3 *ptr) const { + *ptr = data; +} + +FLASHINFER_INLINE void vec_t<__nv_fp8_e4m3, 1>::memcpy( + __nv_fp8_e4m3 *dst, const __nv_fp8_e4m3 *src) { + *dst = *src; +} + +// __nv_fp8_e4m3 x 2 +template <> +struct vec_t<__nv_fp8_e4m3, 2> { + __nv_fp8x2_e4m3 data; + + FLASHINFER_INLINE __nv_fp8_e4m3 &operator[](size_t i) { + return ((__nv_fp8_e4m3 *)(&data))[i]; + } + FLASHINFER_INLINE const __nv_fp8_e4m3 &operator[](size_t i) const { + return ((const __nv_fp8_e4m3 *)(&data))[i]; + } + FLASHINFER_INLINE void fill(__nv_fp8_e4m3 val); + FLASHINFER_INLINE void load(const __nv_fp8_e4m3 *ptr); + FLASHINFER_INLINE void store(__nv_fp8_e4m3 *ptr) const; + template + FLASHINFER_INLINE void cast_from(const vec_t &src) { + cast_from_impl(src, *this); + } + template + FLASHINFER_INLINE void cast_load(const T *ptr) { + cast_load_impl(ptr, *this); + } + template + FLASHINFER_INLINE void cast_store(T *ptr) const { + cast_store_impl(*this, ptr); + } + + FLASHINFER_INLINE static void memcpy(__nv_fp8_e4m3 *dst, + const __nv_fp8_e4m3 *src); +}; + +FLASHINFER_INLINE void vec_t<__nv_fp8_e4m3, 2>::fill(__nv_fp8_e4m3 val) { + data.__x = + (__nv_fp8x2_storage_t(val.__x) << 8) | __nv_fp8x2_storage_t(val.__x); +} + +FLASHINFER_INLINE void vec_t<__nv_fp8_e4m3, 2>::load(const __nv_fp8_e4m3 *ptr) { + data = *((__nv_fp8x2_e4m3 *)ptr); +} + +FLASHINFER_INLINE void vec_t<__nv_fp8_e4m3, 2>::store( + __nv_fp8_e4m3 *ptr) const { + *((__nv_fp8x2_e4m3 *)ptr) = data; +} + +FLASHINFER_INLINE void vec_t<__nv_fp8_e4m3, 2>::memcpy( + __nv_fp8_e4m3 *dst, const __nv_fp8_e4m3 *src) { + *((__nv_fp8x2_e4m3 *)dst) = *((__nv_fp8x2_e4m3 *)src); +} + +// __nv_fp8_e4m3 x 4 + +template <> +struct vec_t<__nv_fp8_e4m3, 4> { + __nv_fp8x4_e4m3 data; + + FLASHINFER_INLINE __nv_fp8_e4m3 &operator[](size_t i) { + return ((__nv_fp8_e4m3 *)(&data))[i]; + } + FLASHINFER_INLINE const __nv_fp8_e4m3 &operator[](size_t i) const { + return ((const __nv_fp8_e4m3 *)(&data))[i]; + } + FLASHINFER_INLINE void fill(__nv_fp8_e4m3 val); + FLASHINFER_INLINE void load(const __nv_fp8_e4m3 *ptr); + FLASHINFER_INLINE void store(__nv_fp8_e4m3 *ptr) const; + template + FLASHINFER_INLINE void cast_from(const vec_t &src) { + cast_from_impl(src, *this); + } + template + FLASHINFER_INLINE void cast_load(const T *ptr) { + cast_load_impl(ptr, *this); + } + template + FLASHINFER_INLINE void cast_store(T *ptr) const { + cast_store_impl(*this, ptr); + } + + FLASHINFER_INLINE static void memcpy(__nv_fp8_e4m3 *dst, + const __nv_fp8_e4m3 *src); +}; + +FLASHINFER_INLINE void vec_t<__nv_fp8_e4m3, 4>::fill(__nv_fp8_e4m3 val) { + data.__x = (__nv_fp8x4_storage_t(val.__x) << 24) | + (__nv_fp8x4_storage_t(val.__x) << 16) | + (__nv_fp8x4_storage_t(val.__x) << 8) | + __nv_fp8x4_storage_t(val.__x); +} + +FLASHINFER_INLINE void vec_t<__nv_fp8_e4m3, 4>::load(const __nv_fp8_e4m3 *ptr) { + data = *((__nv_fp8x4_e4m3 *)ptr); +} + +FLASHINFER_INLINE void vec_t<__nv_fp8_e4m3, 4>::store( + __nv_fp8_e4m3 *ptr) const { + *((__nv_fp8x4_e4m3 *)ptr) = data; +} + +FLASHINFER_INLINE void vec_t<__nv_fp8_e4m3, 4>::memcpy( + __nv_fp8_e4m3 *dst, const __nv_fp8_e4m3 *src) { + *((__nv_fp8x4_e4m3 *)dst) = *((__nv_fp8x4_e4m3 *)src); +} + +// __nv_fp8_e4m3 x 8 + +template <> +struct vec_t<__nv_fp8_e4m3, 8> { + uint2 data; + + FLASHINFER_INLINE __nv_fp8_e4m3 &operator[](size_t i) { + return ((__nv_fp8_e4m3 *)(&data))[i]; + } + FLASHINFER_INLINE const __nv_fp8_e4m3 &operator[](size_t i) const { + return ((const __nv_fp8_e4m3 *)(&data))[i]; + } + FLASHINFER_INLINE void fill(__nv_fp8_e4m3 val); + FLASHINFER_INLINE void load(const __nv_fp8_e4m3 *ptr); + FLASHINFER_INLINE void store(__nv_fp8_e4m3 *ptr) const; + template + FLASHINFER_INLINE void cast_from(const vec_t &src) { + cast_from_impl(src, *this); + } + template + FLASHINFER_INLINE void cast_load(const T *ptr) { + cast_load_impl(ptr, *this); + } + template + FLASHINFER_INLINE void cast_store(T *ptr) const { + cast_store_impl(*this, ptr); + } + + FLASHINFER_INLINE static void memcpy(__nv_fp8_e4m3 *dst, + const __nv_fp8_e4m3 *src); +}; + +FLASHINFER_INLINE void vec_t<__nv_fp8_e4m3, 8>::fill(__nv_fp8_e4m3 val) { + ((__nv_fp8x4_e4m3 *)(&data.x))->__x = (__nv_fp8x4_storage_t(val.__x) << 24) | + (__nv_fp8x4_storage_t(val.__x) << 16) | + (__nv_fp8x4_storage_t(val.__x) << 8) | + __nv_fp8x4_storage_t(val.__x); + ((__nv_fp8x4_e4m3 *)(&data.y))->__x = (__nv_fp8x4_storage_t(val.__x) << 24) | + (__nv_fp8x4_storage_t(val.__x) << 16) | + (__nv_fp8x4_storage_t(val.__x) << 8) | + __nv_fp8x4_storage_t(val.__x); +} + +FLASHINFER_INLINE void vec_t<__nv_fp8_e4m3, 8>::load(const __nv_fp8_e4m3 *ptr) { + data = *((uint2 *)ptr); +} + +FLASHINFER_INLINE void vec_t<__nv_fp8_e4m3, 8>::store( + __nv_fp8_e4m3 *ptr) const { + *((uint2 *)ptr) = data; +} + +FLASHINFER_INLINE void vec_t<__nv_fp8_e4m3, 8>::memcpy( + __nv_fp8_e4m3 *dst, const __nv_fp8_e4m3 *src) { + *((__nv_fp8_e4m3 *)dst) = *((__nv_fp8_e4m3 *)src); +} + +// __nv_fp8_e4m3 x 16 or more +template +struct vec_t<__nv_fp8_e4m3, vec_size> { + uint4 data[vec_size / 16]; + + FLASHINFER_INLINE __nv_fp8_e4m3 &operator[](size_t i) { + return ((__nv_fp8_e4m3 *)data)[i]; + } + FLASHINFER_INLINE const __nv_fp8_e4m3 &operator[](size_t i) const { + return ((const __nv_fp8_e4m3 *)data)[i]; + } + FLASHINFER_INLINE void fill(__nv_fp8_e4m3 val) { +#pragma unroll + for (size_t i = 0; i < vec_size / 16; ++i) { + ((__nv_fp8x4_e4m3 *)(&(data[i].x)))->__x = + (__nv_fp8x4_storage_t(val.__x) << 24) | + (__nv_fp8x4_storage_t(val.__x) << 16) | + (__nv_fp8x4_storage_t(val.__x) << 8) | __nv_fp8x4_storage_t(val.__x); + ((__nv_fp8x4_e4m3 *)(&(data[i].y)))->__x = + (__nv_fp8x4_storage_t(val.__x) << 24) | + (__nv_fp8x4_storage_t(val.__x) << 16) | + (__nv_fp8x4_storage_t(val.__x) << 8) | __nv_fp8x4_storage_t(val.__x); + ((__nv_fp8x4_e4m3 *)(&(data[i].z)))->__x = + (__nv_fp8x4_storage_t(val.__x) << 24) | + (__nv_fp8x4_storage_t(val.__x) << 16) | + (__nv_fp8x4_storage_t(val.__x) << 8) | __nv_fp8x4_storage_t(val.__x); + ((__nv_fp8x4_e4m3 *)(&(data[i].w)))->__x = + (__nv_fp8x4_storage_t(val.__x) << 24) | + (__nv_fp8x4_storage_t(val.__x) << 16) | + (__nv_fp8x4_storage_t(val.__x) << 8) | __nv_fp8x4_storage_t(val.__x); + } + } + FLASHINFER_INLINE void load(const __nv_fp8_e4m3 *ptr) { +#pragma unroll + for (size_t i = 0; i < vec_size / 16; ++i) { + data[i] = ((uint4 *)ptr)[i]; + } + } + FLASHINFER_INLINE void store(__nv_fp8_e4m3 *ptr) const { +#pragma unroll + for (size_t i = 0; i < vec_size / 16; ++i) { + ((uint4 *)ptr)[i] = data[i]; + } + } + template + FLASHINFER_INLINE void cast_from(const vec_t &src) { + cast_from_impl(src, *this); + } + template + FLASHINFER_INLINE void cast_load(const T *ptr) { + cast_load_impl(ptr, *this); + } + template + FLASHINFER_INLINE void cast_store(T *ptr) const { + cast_store_impl(*this, ptr); + } + + FLASHINFER_INLINE static void memcpy(__nv_fp8_e4m3 *dst, + const __nv_fp8_e4m3 *src) { +#pragma unroll + for (size_t i = 0; i < vec_size / 16; ++i) { + ((uint4 *)dst)[i] = ((uint4 *)src)[i]; + } + } +}; + +/******************* vec_t<__nv_fp8_e5m2> *******************/ + +// __nv_fp8_e5m2 x 1 +template <> +struct vec_t<__nv_fp8_e5m2, 1> { + __nv_fp8_e5m2 data; + + FLASHINFER_INLINE __nv_fp8_e5m2 &operator[](size_t i) { + return ((__nv_fp8_e5m2 *)(&data))[i]; + } + FLASHINFER_INLINE const __nv_fp8_e5m2 &operator[](size_t i) const { + return ((const __nv_fp8_e5m2 *)(&data))[i]; + } + FLASHINFER_INLINE void fill(__nv_fp8_e5m2 val); + FLASHINFER_INLINE void load(const __nv_fp8_e5m2 *ptr); + FLASHINFER_INLINE void store(__nv_fp8_e5m2 *ptr) const; + template + FLASHINFER_INLINE void cast_from(const vec_t &src) { + cast_from_impl(src, *this); + } + template + FLASHINFER_INLINE void cast_load(const T *ptr) { + cast_load_impl(ptr, *this); + } + template + FLASHINFER_INLINE void cast_store(T *ptr) const { + cast_store_impl(*this, ptr); + } + + FLASHINFER_INLINE static void memcpy(__nv_fp8_e5m2 *dst, + const __nv_fp8_e5m2 *src); +}; + +FLASHINFER_INLINE void vec_t<__nv_fp8_e5m2, 1>::fill(__nv_fp8_e5m2 val) { + data = val; +} + +FLASHINFER_INLINE void vec_t<__nv_fp8_e5m2, 1>::load(const __nv_fp8_e5m2 *ptr) { + data = *ptr; +} + +FLASHINFER_INLINE void vec_t<__nv_fp8_e5m2, 1>::store( + __nv_fp8_e5m2 *ptr) const { + *ptr = data; +} + +FLASHINFER_INLINE void vec_t<__nv_fp8_e5m2, 1>::memcpy( + __nv_fp8_e5m2 *dst, const __nv_fp8_e5m2 *src) { + *dst = *src; +} + +// __nv_fp8_e5m2 x 2 +template <> +struct vec_t<__nv_fp8_e5m2, 2> { + __nv_fp8x2_e5m2 data; + + FLASHINFER_INLINE __nv_fp8_e5m2 &operator[](size_t i) { + return ((__nv_fp8_e5m2 *)(&data))[i]; + } + FLASHINFER_INLINE const __nv_fp8_e5m2 &operator[](size_t i) const { + return ((const __nv_fp8_e5m2 *)(&data))[i]; + } + FLASHINFER_INLINE void fill(__nv_fp8_e5m2 val); + FLASHINFER_INLINE void load(const __nv_fp8_e5m2 *ptr); + FLASHINFER_INLINE void store(__nv_fp8_e5m2 *ptr) const; + template + FLASHINFER_INLINE void cast_from(const vec_t &src) { + cast_from_impl(src, *this); + } + template + FLASHINFER_INLINE void cast_load(const T *ptr) { + cast_load_impl(ptr, *this); + } + template + FLASHINFER_INLINE void cast_store(T *ptr) const { + cast_store_impl(*this, ptr); + } + + FLASHINFER_INLINE static void memcpy(__nv_fp8_e5m2 *dst, + const __nv_fp8_e5m2 *src); +}; + +FLASHINFER_INLINE void vec_t<__nv_fp8_e5m2, 2>::fill(__nv_fp8_e5m2 val) { + data.__x = + (__nv_fp8x2_storage_t(val.__x) << 8) | __nv_fp8x2_storage_t(val.__x); +} + +FLASHINFER_INLINE void vec_t<__nv_fp8_e5m2, 2>::load(const __nv_fp8_e5m2 *ptr) { + data = *((__nv_fp8x2_e5m2 *)ptr); +} + +FLASHINFER_INLINE void vec_t<__nv_fp8_e5m2, 2>::store( + __nv_fp8_e5m2 *ptr) const { + *((__nv_fp8x2_e5m2 *)ptr) = data; +} + +FLASHINFER_INLINE void vec_t<__nv_fp8_e5m2, 2>::memcpy( + __nv_fp8_e5m2 *dst, const __nv_fp8_e5m2 *src) { + *((__nv_fp8x2_e5m2 *)dst) = *((__nv_fp8x2_e5m2 *)src); +} + +// __nv_fp8_e5m2 x 4 + +template <> +struct vec_t<__nv_fp8_e5m2, 4> { + __nv_fp8x4_e5m2 data; + + FLASHINFER_INLINE __nv_fp8_e5m2 &operator[](size_t i) { + return ((__nv_fp8_e5m2 *)(&data))[i]; + } + FLASHINFER_INLINE const __nv_fp8_e5m2 &operator[](size_t i) const { + return ((const __nv_fp8_e5m2 *)(&data))[i]; + } + FLASHINFER_INLINE void fill(__nv_fp8_e5m2 val); + FLASHINFER_INLINE void load(const __nv_fp8_e5m2 *ptr); + FLASHINFER_INLINE void store(__nv_fp8_e5m2 *ptr) const; + template + FLASHINFER_INLINE void cast_from(const vec_t &src) { + cast_from_impl(src, *this); + } + template + FLASHINFER_INLINE void cast_load(const T *ptr) { + cast_load_impl(ptr, *this); + } + template + FLASHINFER_INLINE void cast_store(T *ptr) const { + cast_store_impl(*this, ptr); + } + + FLASHINFER_INLINE static void memcpy(__nv_fp8_e5m2 *dst, + const __nv_fp8_e5m2 *src); +}; + +FLASHINFER_INLINE void vec_t<__nv_fp8_e5m2, 4>::fill(__nv_fp8_e5m2 val) { + data.__x = (__nv_fp8x4_storage_t(val.__x) << 24) | + (__nv_fp8x4_storage_t(val.__x) << 16) | + (__nv_fp8x4_storage_t(val.__x) << 8) | + __nv_fp8x4_storage_t(val.__x); +} + +FLASHINFER_INLINE void vec_t<__nv_fp8_e5m2, 4>::load(const __nv_fp8_e5m2 *ptr) { + data = *((__nv_fp8x4_e5m2 *)ptr); +} + +FLASHINFER_INLINE void vec_t<__nv_fp8_e5m2, 4>::store( + __nv_fp8_e5m2 *ptr) const { + *((__nv_fp8x4_e5m2 *)ptr) = data; +} + +FLASHINFER_INLINE void vec_t<__nv_fp8_e5m2, 4>::memcpy( + __nv_fp8_e5m2 *dst, const __nv_fp8_e5m2 *src) { + *((__nv_fp8x4_e5m2 *)dst) = *((__nv_fp8x4_e5m2 *)src); +} + +// __nv_fp8_e5m2 x 8 + +template <> +struct vec_t<__nv_fp8_e5m2, 8> { + uint2 data; + + FLASHINFER_INLINE __nv_fp8_e5m2 &operator[](size_t i) { + return ((__nv_fp8_e5m2 *)(&data))[i]; + } + FLASHINFER_INLINE const __nv_fp8_e5m2 &operator[](size_t i) const { + return ((const __nv_fp8_e5m2 *)(&data))[i]; + } + FLASHINFER_INLINE void fill(__nv_fp8_e5m2 val); + FLASHINFER_INLINE void load(const __nv_fp8_e5m2 *ptr); + FLASHINFER_INLINE void store(__nv_fp8_e5m2 *ptr) const; + template + FLASHINFER_INLINE void cast_from(const vec_t &src) { + cast_from_impl(src, *this); + } + template + FLASHINFER_INLINE void cast_load(const T *ptr) { + cast_load_impl(ptr, *this); + } + template + FLASHINFER_INLINE void cast_store(T *ptr) const { + cast_store_impl(*this, ptr); + } + + FLASHINFER_INLINE static void memcpy(__nv_fp8_e5m2 *dst, + const __nv_fp8_e5m2 *src); +}; + +FLASHINFER_INLINE void vec_t<__nv_fp8_e5m2, 8>::fill(__nv_fp8_e5m2 val) { + ((__nv_fp8x4_e5m2 *)(&data.x))->__x = (__nv_fp8x4_storage_t(val.__x) << 24) | + (__nv_fp8x4_storage_t(val.__x) << 16) | + (__nv_fp8x4_storage_t(val.__x) << 8) | + __nv_fp8x4_storage_t(val.__x); + ((__nv_fp8x4_e5m2 *)(&data.y))->__x = (__nv_fp8x4_storage_t(val.__x) << 24) | + (__nv_fp8x4_storage_t(val.__x) << 16) | + (__nv_fp8x4_storage_t(val.__x) << 8) | + __nv_fp8x4_storage_t(val.__x); +} + +FLASHINFER_INLINE void vec_t<__nv_fp8_e5m2, 8>::load(const __nv_fp8_e5m2 *ptr) { + data = *((uint2 *)ptr); +} + +FLASHINFER_INLINE void vec_t<__nv_fp8_e5m2, 8>::store( + __nv_fp8_e5m2 *ptr) const { + *((uint2 *)ptr) = data; +} + +FLASHINFER_INLINE void vec_t<__nv_fp8_e5m2, 8>::memcpy( + __nv_fp8_e5m2 *dst, const __nv_fp8_e5m2 *src) { + *((__nv_fp8_e5m2 *)dst) = *((__nv_fp8_e5m2 *)src); +} + +// __nv_fp8_e5m2 x 16 or more + +template +struct vec_t<__nv_fp8_e5m2, vec_size> { + uint4 data[vec_size / 16]; + + FLASHINFER_INLINE __nv_fp8_e5m2 &operator[](size_t i) { + return ((__nv_fp8_e5m2 *)data)[i]; + } + FLASHINFER_INLINE const __nv_fp8_e5m2 &operator[](size_t i) const { + return ((const __nv_fp8_e5m2 *)data)[i]; + } + FLASHINFER_INLINE void fill(__nv_fp8_e5m2 val) { +#pragma unroll + for (size_t i = 0; i < vec_size / 16; ++i) { + ((__nv_fp8x4_e5m2 *)(&(data[i].x)))->__x = + (__nv_fp8x4_storage_t(val.__x) << 24) | + (__nv_fp8x4_storage_t(val.__x) << 16) | + (__nv_fp8x4_storage_t(val.__x) << 8) | __nv_fp8x4_storage_t(val.__x); + ((__nv_fp8x4_e5m2 *)(&(data[i].y)))->__x = + (__nv_fp8x4_storage_t(val.__x) << 24) | + (__nv_fp8x4_storage_t(val.__x) << 16) | + (__nv_fp8x4_storage_t(val.__x) << 8) | __nv_fp8x4_storage_t(val.__x); + ((__nv_fp8x4_e5m2 *)(&(data[i].z)))->__x = + (__nv_fp8x4_storage_t(val.__x) << 24) | + (__nv_fp8x4_storage_t(val.__x) << 16) | + (__nv_fp8x4_storage_t(val.__x) << 8) | __nv_fp8x4_storage_t(val.__x); + ((__nv_fp8x4_e5m2 *)(&(data[i].w)))->__x = + (__nv_fp8x4_storage_t(val.__x) << 24) | + (__nv_fp8x4_storage_t(val.__x) << 16) | + (__nv_fp8x4_storage_t(val.__x) << 8) | __nv_fp8x4_storage_t(val.__x); + } + } + FLASHINFER_INLINE void load(const __nv_fp8_e5m2 *ptr) { +#pragma unroll + for (size_t i = 0; i < vec_size / 16; ++i) { + data[i] = ((uint4 *)ptr)[i]; + } + } + FLASHINFER_INLINE void store(__nv_fp8_e5m2 *ptr) const { +#pragma unroll + for (size_t i = 0; i < vec_size / 16; ++i) { + ((uint4 *)ptr)[i] = data[i]; + } + } + template + FLASHINFER_INLINE void cast_from(const vec_t &src) { + cast_from_impl(src, *this); + } + template + FLASHINFER_INLINE void cast_load(const T *ptr) { + cast_load_impl(ptr, *this); + } + template + FLASHINFER_INLINE void cast_store(T *ptr) const { + cast_store_impl(*this, ptr); + } + + FLASHINFER_INLINE static void memcpy(__nv_fp8_e5m2 *dst, + const __nv_fp8_e5m2 *src) { +#pragma unroll + for (size_t i = 0; i < vec_size / 16; ++i) { + ((uint4 *)dst)[i] = ((uint4 *)src)[i]; + } + } +}; +#endif + +/******************* vec_t *******************/ + +// half x 1 +template <> +struct vec_t { + half data; + + FLASHINFER_INLINE half &operator[](size_t i) { return ((half *)(&data))[i]; } + FLASHINFER_INLINE const half &operator[](size_t i) const { + return ((const half *)(&data))[i]; + } + FLASHINFER_INLINE void fill(half val); + FLASHINFER_INLINE void load(const half *ptr); + FLASHINFER_INLINE void store(half *ptr) const; + template + FLASHINFER_INLINE void cast_from(const vec_t &src) { + cast_from_impl(src, *this); + } + template + FLASHINFER_INLINE void cast_load(const T *ptr) { + cast_load_impl(ptr, *this); + } + template + FLASHINFER_INLINE void cast_store(T *ptr) const { + cast_store_impl(*this, ptr); + } + + FLASHINFER_INLINE static void memcpy(half *dst, const half *src); +}; + +FLASHINFER_INLINE void vec_t::fill(half val) { data = val; } + +FLASHINFER_INLINE void vec_t::load(const half *ptr) { data = *ptr; } + +FLASHINFER_INLINE void vec_t::store(half *ptr) const { *ptr = data; } + +FLASHINFER_INLINE void vec_t::memcpy(half *dst, const half *src) { + *dst = *src; +} + +// half x 2 +template <> +struct vec_t { + half2 data; + + FLASHINFER_INLINE half &operator[](size_t i) { return ((half *)(&data))[i]; } + FLASHINFER_INLINE const half &operator[](size_t i) const { + return ((const half *)(&data))[i]; + } + FLASHINFER_INLINE void fill(half val); + FLASHINFER_INLINE void load(const half *ptr); + FLASHINFER_INLINE void store(half *ptr) const; + template + FLASHINFER_INLINE void cast_from(const vec_t &src) { + cast_from_impl(src, *this); + } + template + FLASHINFER_INLINE void cast_load(const T *ptr) { + cast_load_impl(ptr, *this); + } + template + FLASHINFER_INLINE void cast_store(T *ptr) const { + cast_store_impl(*this, ptr); + } + + FLASHINFER_INLINE static void memcpy(half *dst, const half *src); +}; + +FLASHINFER_INLINE void vec_t::fill(half val) { + data = make_half2(val, val); +} + +FLASHINFER_INLINE void vec_t::load(const half *ptr) { + data = *((half2 *)ptr); +} + +FLASHINFER_INLINE void vec_t::store(half *ptr) const { + *((half2 *)ptr) = data; +} + +FLASHINFER_INLINE void vec_t::memcpy(half *dst, const half *src) { + *((half2 *)dst) = *((half2 *)src); +} + +// half x 4 + +template <> +struct vec_t { + uint2 data; + + FLASHINFER_INLINE half &operator[](size_t i) { return ((half *)(&data))[i]; } + FLASHINFER_INLINE const half &operator[](size_t i) const { + return ((const half *)(&data))[i]; + } + FLASHINFER_INLINE void fill(half val); + FLASHINFER_INLINE void load(const half *ptr); + FLASHINFER_INLINE void store(half *ptr) const; + template + FLASHINFER_INLINE void cast_from(const vec_t &src) { + cast_from_impl(src, *this); + } + template + FLASHINFER_INLINE void cast_load(const T *ptr) { + cast_load_impl(ptr, *this); + } + template + FLASHINFER_INLINE void cast_store(T *ptr) const { + cast_store_impl(*this, ptr); + } + + FLASHINFER_INLINE static void memcpy(half *dst, const half *src); +}; + +FLASHINFER_INLINE void vec_t::fill(half val) { + *(half2 *)(&data.x) = make_half2(val, val); + *(half2 *)(&data.y) = make_half2(val, val); +} + +FLASHINFER_INLINE void vec_t::load(const half *ptr) { + data = *((uint2 *)ptr); +} + +FLASHINFER_INLINE void vec_t::store(half *ptr) const { + *((uint2 *)ptr) = data; +} + +FLASHINFER_INLINE void vec_t::memcpy(half *dst, const half *src) { + *((uint2 *)dst) = *((uint2 *)src); +} + +// half x 8 or more + +template +struct vec_t { + uint4 data[vec_size / 8]; + FLASHINFER_INLINE half &operator[](size_t i) { return ((half *)data)[i]; } + FLASHINFER_INLINE const half &operator[](size_t i) const { + return ((const half *)data)[i]; + } + FLASHINFER_INLINE void fill(half val) { +#pragma unroll + for (size_t i = 0; i < vec_size; ++i) { + *(half2 *)(&(data[i].x)) = make_half2(val, val); + *(half2 *)(&(data[i].y)) = make_half2(val, val); + *(half2 *)(&(data[i].z)) = make_half2(val, val); + *(half2 *)(&(data[i].w)) = make_half2(val, val); + } + } + FLASHINFER_INLINE void load(const half *ptr) { +#pragma unroll + for (size_t i = 0; i < vec_size / 8; ++i) { + data[i] = ((uint4 *)ptr)[i]; + } + } + FLASHINFER_INLINE void store(half *ptr) const { +#pragma unroll + for (size_t i = 0; i < vec_size / 8; ++i) { + ((uint4 *)ptr)[i] = data[i]; + } + } + template + FLASHINFER_INLINE void cast_from(const vec_t &src) { + cast_from_impl(src, *this); + } + template + FLASHINFER_INLINE void cast_load(const T *ptr) { + cast_load_impl(ptr, *this); + } + template + FLASHINFER_INLINE void cast_store(T *ptr) const { + cast_store_impl(*this, ptr); + } + + FLASHINFER_INLINE static void memcpy(half *dst, const half *src) { +#pragma unroll + for (size_t i = 0; i < vec_size / 8; ++i) { + ((uint4 *)dst)[i] = ((uint4 *)src)[i]; + } + } +}; + +/******************* vec_t *******************/ + +// nv_bfloat16 x 1 +template <> +struct vec_t { + nv_bfloat16 data; + + FLASHINFER_INLINE nv_bfloat16 &operator[](size_t i) { + return ((nv_bfloat16 *)(&data))[i]; + } + FLASHINFER_INLINE const nv_bfloat16 &operator[](size_t i) const { + return ((const nv_bfloat16 *)(&data))[i]; + } + FLASHINFER_INLINE void fill(nv_bfloat16 val); + FLASHINFER_INLINE void load(const nv_bfloat16 *ptr); + FLASHINFER_INLINE void store(nv_bfloat16 *ptr) const; + template + FLASHINFER_INLINE void cast_from(const vec_t &src) { + cast_from_impl(src, *this); + } + template + FLASHINFER_INLINE void cast_load(const T *ptr) { + cast_load_impl(ptr, *this); + } + template + FLASHINFER_INLINE void cast_store(T *ptr) const { + cast_store_impl(*this, ptr); + } + + FLASHINFER_INLINE static void memcpy(nv_bfloat16 *dst, + const nv_bfloat16 *src); +}; + +FLASHINFER_INLINE void vec_t::fill(nv_bfloat16 val) { + data = val; +} + +FLASHINFER_INLINE void vec_t::load(const nv_bfloat16 *ptr) { + data = *ptr; +} + +FLASHINFER_INLINE void vec_t::store(nv_bfloat16 *ptr) const { + *ptr = data; +} + +FLASHINFER_INLINE void vec_t::memcpy(nv_bfloat16 *dst, + const nv_bfloat16 *src) { + *dst = *src; +} + +// nv_bfloat16 x 2 +template <> +struct vec_t { + nv_bfloat162 data; + + FLASHINFER_INLINE nv_bfloat16 &operator[](size_t i) { + return ((nv_bfloat16 *)(&data))[i]; + } + FLASHINFER_INLINE const nv_bfloat16 &operator[](size_t i) const { + return ((const nv_bfloat16 *)(&data))[i]; + } + FLASHINFER_INLINE void fill(nv_bfloat16 val); + FLASHINFER_INLINE void load(const nv_bfloat16 *ptr); + FLASHINFER_INLINE void store(nv_bfloat16 *ptr) const; + template + FLASHINFER_INLINE void cast_from(const vec_t &src) { + cast_from_impl(src, *this); + } + template + FLASHINFER_INLINE void cast_load(const T *ptr) { + cast_load_impl(ptr, *this); + } + template + FLASHINFER_INLINE void cast_store(T *ptr) const { + cast_store_impl(*this, ptr); + } + + FLASHINFER_INLINE static void memcpy(nv_bfloat16 *dst, + const nv_bfloat16 *src); +}; + +FLASHINFER_INLINE void vec_t::fill(nv_bfloat16 val) { + data = make_bfloat162(val, val); +} + +FLASHINFER_INLINE void vec_t::load(const nv_bfloat16 *ptr) { + data = *((nv_bfloat162 *)ptr); +} + +FLASHINFER_INLINE void vec_t::store(nv_bfloat16 *ptr) const { + *((nv_bfloat162 *)ptr) = data; +} + +FLASHINFER_INLINE void vec_t::memcpy(nv_bfloat16 *dst, + const nv_bfloat16 *src) { + *((nv_bfloat162 *)dst) = *((nv_bfloat162 *)src); +} + +// nv_bfloat16 x 4 + +template <> +struct vec_t { + uint2 data; + + FLASHINFER_INLINE nv_bfloat16 &operator[](size_t i) { + return ((nv_bfloat16 *)(&data))[i]; + } + FLASHINFER_INLINE const nv_bfloat16 &operator[](size_t i) const { + return ((const nv_bfloat16 *)(&data))[i]; + } + FLASHINFER_INLINE void fill(nv_bfloat16 val); + FLASHINFER_INLINE void load(const nv_bfloat16 *ptr); + FLASHINFER_INLINE void store(nv_bfloat16 *ptr) const; + template + FLASHINFER_INLINE void cast_from(const vec_t &src) { + cast_from_impl(src, *this); + } + template + FLASHINFER_INLINE void cast_load(const T *ptr) { + cast_load_impl(ptr, *this); + } + template + FLASHINFER_INLINE void cast_store(T *ptr) const { + cast_store_impl(*this, ptr); + } + + FLASHINFER_INLINE static void memcpy(nv_bfloat16 *dst, + const nv_bfloat16 *src); +}; + +FLASHINFER_INLINE void vec_t::fill(nv_bfloat16 val) { + *(nv_bfloat162 *)(&data.x) = make_bfloat162(val, val); + *(nv_bfloat162 *)(&data.y) = make_bfloat162(val, val); +} + +FLASHINFER_INLINE void vec_t::load(const nv_bfloat16 *ptr) { + data = *((uint2 *)ptr); +} + +FLASHINFER_INLINE void vec_t::store(nv_bfloat16 *ptr) const { + *((uint2 *)ptr) = data; +} + +FLASHINFER_INLINE void vec_t::memcpy(nv_bfloat16 *dst, + const nv_bfloat16 *src) { + *((uint2 *)dst) = *((uint2 *)src); +} + +// nv_bfloat16 x 8 or more + +template +struct vec_t { + uint4 data[vec_size / 8]; + + FLASHINFER_INLINE nv_bfloat16 &operator[](size_t i) { + return ((nv_bfloat16 *)data)[i]; + } + FLASHINFER_INLINE const nv_bfloat16 &operator[](size_t i) const { + return ((const nv_bfloat16 *)data)[i]; + } + FLASHINFER_INLINE void fill(nv_bfloat16 val) { +#pragma unoll + for (size_t i = 0; i < vec_size / 8; ++i) { + *(nv_bfloat162 *)(&(data[i].x)) = make_bfloat162(val, val); + *(nv_bfloat162 *)(&(data[i].y)) = make_bfloat162(val, val); + *(nv_bfloat162 *)(&(data[i].z)) = make_bfloat162(val, val); + *(nv_bfloat162 *)(&(data[i].w)) = make_bfloat162(val, val); + } + } + FLASHINFER_INLINE void load(const nv_bfloat16 *ptr) { +#pragma unoll + for (size_t i = 0; i < vec_size / 8; ++i) { + data[i] = ((uint4 *)ptr)[i]; + } + } + FLASHINFER_INLINE void store(nv_bfloat16 *ptr) const { +#pragma unoll + for (size_t i = 0; i < vec_size / 8; ++i) { + ((uint4 *)ptr)[i] = data[i]; + } + } + template + FLASHINFER_INLINE void cast_from(const vec_t &src) { + cast_from_impl(src, *this); + } + template + FLASHINFER_INLINE void cast_load(const T *ptr) { + cast_load_impl(ptr, *this); + } + template + FLASHINFER_INLINE void cast_store(T *ptr) const { + cast_store_impl(*this, ptr); + } + + FLASHINFER_INLINE static void memcpy(nv_bfloat16 *dst, + const nv_bfloat16 *src) { +#pragma unoll + for (size_t i = 0; i < vec_size / 8; ++i) { + ((uint4 *)dst)[i] = ((uint4 *)src)[i]; + } + } +}; + +/******************* vec_t *******************/ + +// float x 1 + +template <> +struct vec_t { + float data; + + FLASHINFER_INLINE float &operator[](size_t i) { + return ((float *)(&data))[i]; + } + FLASHINFER_INLINE const float &operator[](size_t i) const { + return ((const float *)(&data))[i]; + } + FLASHINFER_INLINE void fill(float val); + FLASHINFER_INLINE void load(const float *ptr); + FLASHINFER_INLINE void store(float *ptr) const; + template + FLASHINFER_INLINE void cast_from(const vec_t &src) { + cast_from_impl(src, *this); + } + template + FLASHINFER_INLINE void cast_load(const T *ptr) { + cast_load_impl(ptr, *this); + } + template + FLASHINFER_INLINE void cast_store(T *ptr) const { + cast_store_impl(*this, ptr); + } + + FLASHINFER_INLINE static void memcpy(float *dst, const float *src); +}; + +FLASHINFER_INLINE void vec_t::fill(float val) { data = val; } + +FLASHINFER_INLINE void vec_t::load(const float *ptr) { data = *ptr; } + +FLASHINFER_INLINE void vec_t::store(float *ptr) const { *ptr = data; } + +FLASHINFER_INLINE void vec_t::memcpy(float *dst, const float *src) { + *dst = *src; +} + +// float x 2 + +template <> +struct vec_t { + float2 data; + + FLASHINFER_INLINE float &operator[](size_t i) { + return ((float *)(&data))[i]; + } + FLASHINFER_INLINE const float &operator[](size_t i) const { + return ((const float *)(&data))[i]; + } + FLASHINFER_INLINE void fill(float val); + FLASHINFER_INLINE void load(const float *ptr); + FLASHINFER_INLINE void store(float *ptr) const; + template + FLASHINFER_INLINE void cast_from(const vec_t &src) { + cast_from_impl(src, *this); + } + template + FLASHINFER_INLINE void cast_load(const T *ptr) { + cast_load_impl(ptr, *this); + } + template + FLASHINFER_INLINE void cast_store(T *ptr) const { + cast_store_impl(*this, ptr); + } + FLASHINFER_INLINE static void memcpy(float *dst, const float *src); +}; + +FLASHINFER_INLINE void vec_t::fill(float val) { + data = make_float2(val, val); +} + +FLASHINFER_INLINE void vec_t::load(const float *ptr) { + data = *((float2 *)ptr); +} + +FLASHINFER_INLINE void vec_t::store(float *ptr) const { + *((float2 *)ptr) = data; +} + +FLASHINFER_INLINE void vec_t::memcpy(float *dst, const float *src) { + *((float2 *)dst) = *((float2 *)src); +} + +// float x 4 or more +template +struct vec_t { + float4 data[vec_size / 4]; + + FLASHINFER_INLINE float &operator[](size_t i) { return ((float *)(data))[i]; } + FLASHINFER_INLINE const float &operator[](size_t i) const { + return ((const float *)(data))[i]; + } + FLASHINFER_INLINE void fill(float val) { +#pragma unroll + for (size_t i = 0; i < vec_size / 4; ++i) { + data[i] = make_float4(val, val, val, val); + } + } + FLASHINFER_INLINE void load(const float *ptr) { +#pragma unroll + for (size_t i = 0; i < vec_size / 4; ++i) { + data[i] = ((float4 *)ptr)[i]; + } + } + FLASHINFER_INLINE void store(float *ptr) const { +#pragma unroll + for (size_t i = 0; i < vec_size / 4; ++i) { + ((float4 *)ptr)[i] = data[i]; + } + } + template + FLASHINFER_INLINE void cast_from(const vec_t &src) { + cast_from_impl(src, *this); + } + template + FLASHINFER_INLINE void cast_load(const T *ptr) { + cast_load_impl(ptr, *this); + } + template + FLASHINFER_INLINE void cast_store(T *ptr) const { + cast_store_impl(*this, ptr); + } + FLASHINFER_INLINE static void memcpy(float *dst, const float *src) { +#pragma unroll + for (size_t i = 0; i < vec_size / 4; ++i) { + ((float4 *)dst)[i] = ((float4 *)src)[i]; + } + } +}; + +/******************* vec_t type cast *******************/ + +template +FLASHINFER_INLINE void cast_from_impl(const vec_t &src, + vec_t &dst) { + if constexpr (vec_size == 1) { + dst.data = float(src.data); + } else { +#pragma unroll + for (size_t i = 0; i < vec_size / 2; ++i) { + ((float2 *)(&dst.data))[i] = __half22float2(((half2 *)(&src.data))[i]); + } + } +} + +template +FLASHINFER_INLINE void cast_from_impl(const vec_t &src, + vec_t &dst) { + if constexpr (vec_size == 1) { + dst.data = half(src.data); + } else { +#pragma unroll + for (size_t i = 0; i < vec_size / 2; ++i) { + ((half2 *)(&dst.data))[i] = __float22half2_rn(((float2 *)(&src.data))[i]); + } + } +} + +template +FLASHINFER_INLINE void cast_from_impl(const vec_t &src, + vec_t &dst) { + if constexpr (vec_size == 1) { + dst.data = float(src.data); + } else { +#pragma unroll + for (size_t i = 0; i < vec_size / 2; ++i) { + ((float2 *)(&dst.data))[i] = + __bfloat1622float2(((nv_bfloat162 *)(&src.data))[i]); + } + } +} + +template +FLASHINFER_INLINE void cast_from_impl(const vec_t &src, + vec_t &dst) { + if constexpr (vec_size == 1) { + dst.data = nv_bfloat16(src.data); + } else { +#pragma unroll + for (size_t i = 0; i < vec_size / 2; ++i) { + ((nv_bfloat162 *)(&dst.data))[i] = + __float22bfloat162_rn(((float2 *)(&src.data))[i]); + } + } +} + +#ifdef FLASHINFER_USE_FP8 + +template +FLASHINFER_INLINE void cast_from_impl(const vec_t<__nv_fp8_e4m3, vec_size> &src, + vec_t &dst) { + if constexpr (vec_size == 1) { + dst.data = float(src.data); + } else if constexpr (vec_size == 2) { + *(float2 *)(&dst.data) = float2(*(__nv_fp8x2_e4m3 *)(&src.data)); + } else { +#pragma unroll + for (size_t i = 0; i < vec_size / 4; ++i) { + ((float4 *)(&dst.data))[i] = float4(((__nv_fp8x4_e4m3 *)(&src.data))[i]); + } + } +} + +template +FLASHINFER_INLINE void cast_from_impl(const vec_t<__nv_fp8_e4m3, vec_size> &src, + vec_t &dst) { + if constexpr (vec_size == 1) { + dst.data = float(src.data); + } else { +#pragma unroll + for (size_t i = 0; i < vec_size / 2; ++i) { + ((half2 *)(&dst.data))[i] = half2(((__nv_fp8x2_e4m3 *)(&src.data))[i]); + } + } +} + +template +FLASHINFER_INLINE void cast_from_impl(const vec_t &src, + vec_t<__nv_fp8_e4m3, vec_size> &dst) { + if constexpr (vec_size == 1) { + dst.data = __nv_fp8_e4m3(src.data); + } else if constexpr (vec_size == 2) { + *(__nv_fp8x2_e4m3 *)(&dst.data) = __nv_fp8x2_e4m3(*(float2 *)(&src.data)); + } else { +#pragma unroll + for (size_t i = 0; i < vec_size / 4; ++i) { + ((__nv_fp8x4_e4m3 *)(&dst.data))[i] = + __nv_fp8x4_e4m3(((float4 *)(&src.data))[i]); + } + } +} + +template +FLASHINFER_INLINE void cast_from_impl(const vec_t &src, + vec_t<__nv_fp8_e4m3, vec_size> &dst) { + if constexpr (vec_size == 1) { + dst.data = __nv_fp8_e4m3(src.data); + } else if constexpr (vec_size == 2) { + *(__nv_fp8x2_e4m3 *)(&dst.data) = __nv_fp8x2_e4m3(*(half2 *)(&src.data)); + } else { +#pragma unroll + for (size_t i = 0; i < vec_size / 4; ++i) { + // NOTE(Zihao): need to double check if we properly handle flo and fhi + ((__nv_fp8x4_e4m3 *)(&dst.data))[i] = __nv_fp8x4_e4m3( + ((half2 *)(&src.data))[i * 2], ((half2 *)(&src.data))[i * 2 + 1]); + } + } +} + +template +FLASHINFER_INLINE void cast_from_impl(const vec_t<__nv_fp8_e5m2, vec_size> &src, + vec_t &dst) { + if constexpr (vec_size == 1) { + dst.data = float(src.data); + } else if constexpr (vec_size == 2) { + *(float2 *)(&dst.data) = float2(*(__nv_fp8x2_e5m2 *)(&src.data)); + } else { +#pragma unroll + for (size_t i = 0; i < vec_size / 4; ++i) { + ((float4 *)(&dst.data))[i] = float4(((__nv_fp8x4_e5m2 *)(&src.data))[i]); + } + } +} + +template +FLASHINFER_INLINE void cast_from_impl(const vec_t<__nv_fp8_e5m2, vec_size> &src, + vec_t &dst) { + if constexpr (vec_size == 1) { + dst.data = float(src.data); + } else { +#pragma unroll + for (size_t i = 0; i < vec_size / 2; ++i) { + ((half2 *)(&dst.data))[i] = half2(((__nv_fp8x2_e5m2 *)(&src.data))[i]); + } + } +} + +template +FLASHINFER_INLINE void cast_from_impl(const vec_t &src, + vec_t<__nv_fp8_e5m2, vec_size> &dst) { + if constexpr (vec_size == 1) { + dst.data = __nv_fp8_e5m2(src.data); + } else if constexpr (vec_size == 2) { + *(__nv_fp8x2_e5m2 *)(&dst.data) = __nv_fp8x2_e5m2(*(float2 *)(&src.data)); + } else { +#pragma unroll + for (size_t i = 0; i < vec_size / 4; ++i) { + ((__nv_fp8x4_e5m2 *)(&dst.data))[i] = + __nv_fp8x4_e5m2(((float4 *)(&src.data))[i]); + } + } +} + +template +FLASHINFER_INLINE void cast_from_impl(const vec_t &src, + vec_t<__nv_fp8_e5m2, vec_size> &dst) { + if constexpr (vec_size == 1) { + dst.data = __nv_fp8_e4m3(src.data); + } else if constexpr (vec_size == 2) { + *(__nv_fp8x2_e5m2 *)(&dst.data) = __nv_fp8x2_e5m2(*(half2 *)(&src.data)); + } else { +#pragma unroll + for (size_t i = 0; i < vec_size / 4; ++i) { + // NOTE(Zihao): need to double check if we properly handle flo and fhi + ((__nv_fp8x4_e5m2 *)(&dst.data))[i] = __nv_fp8x4_e5m2( + ((half2 *)(&src.data))[i * 2], ((half2 *)(&src.data))[i * 2 + 1]); + } + } +} + +#endif // FLASHINFER_USE_FP8 + +#endif // VEC_DTYPES_CUH_ diff --git a/csrc/punica/punica_ops.cc b/csrc/punica/punica_ops.cc new file mode 100644 index 0000000000000..4ad46e5e1f726 --- /dev/null +++ b/csrc/punica/punica_ops.cc @@ -0,0 +1,563 @@ +#include +#include +#include + +#include + +#include "bgmv/bgmv_config.h" + +namespace { + +//====== utils ====== + +inline void check_shape(const torch::Tensor &a, const torch::Tensor &b, + const char *a_name, const char *b_name) { + TORCH_CHECK(a.dim() == b.dim(), a_name, ".dim() != ", b_name, ".dim(). ", + a.dim(), " vs ", b.dim()); + for (int i = 0; i < a.dim(); ++i) { + TORCH_CHECK(a.size(i) == b.size(i), a_name, ".size(", i, ") != ", b_name, + ".size(", i, ")"); + } +} + +inline constexpr uint32_t pack_u16(uint16_t a, uint16_t b) { + return (uint32_t(a) << 16) | uint32_t(b); +} + +#define CHECK_CUDA(x) TORCH_CHECK(x.is_cuda(), #x " must be a CUDA tensor") + +#define CHECK_CONTIGUOUS(x) \ + TORCH_CHECK(x.is_contiguous(), #x " must be contiguous") + +#define CHECK_INPUT(x) \ + CHECK_CUDA(x); \ + CHECK_CONTIGUOUS(x) + +#define CHECK_DIM(d, x) \ + TORCH_CHECK(x.dim() == d, #x " must be a " #d "D tensor") + +#define CHECK_SHAPE(a, b) check_shape(a, b, #a, #b) + +#define CHECK_EQ(a, b) \ + TORCH_CHECK(a == b, "CHECK_EQ(" #a ", " #b ") failed. ", a, " vs ", b) + +//====== bgmv ====== + +template +inline bool launch_bgmv_kernel(out_T *Y, const in_T *X, const W_T *W, + const int64_t *lora_indices, + uint16_t in_features, uint16_t out_features, + int64_t y_offset, int64_t full_y_size, + int64_t batch_size, int64_t num_layers, + int64_t layer_idx, float scale) { + switch (pack_u16(in_features, out_features)) { +#define CASE_ONESIDE(_in_T, _out_T, _W_T, feat_in, feat_out) \ + case pack_u16(feat_in, feat_out): \ + bgmv_kernel(Y, X, W, lora_indices, y_offset, \ + full_y_size, batch_size, num_layers, \ + layer_idx, scale); \ + break; +#define CASE(_in_T, _out_T, _W_T, narrow, wide) \ + CASE_ONESIDE(in_T, out_T, W_T, narrow, wide) \ + CASE_ONESIDE(in_T, out_T, W_T, wide, narrow) + + FOR_BGMV_WIDE_NARROW(CASE, _, _, _) +#undef CASE +#undef CASE_ONESIDE + default: + return false; + } + + return true; +} + +void dispatch_bgmv(torch::Tensor y, torch::Tensor x, torch::Tensor w, + torch::Tensor indicies, int64_t layer_idx, float scale) { + CHECK_INPUT(y); + CHECK_INPUT(x); + CHECK_INPUT(w); + CHECK_INPUT(indicies); + + CHECK_DIM(2, y); + CHECK_DIM(2, x); + CHECK_DIM(4, w); + CHECK_DIM(1, indicies); + + int64_t B = x.size(0); + int64_t h_in = x.size(1); + int64_t h_out = y.size(1); + int64_t num_layers = w.size(1); + CHECK_EQ(w.size(3), h_in); + CHECK_EQ(w.size(2), h_out); + CHECK_EQ(indicies.size(0), x.size(0)); + CHECK_EQ(y.size(0), x.size(0)); + bool ok = false; + if (h_in < 65536 && h_out < 65536) { + // TODO: See if we can get rid of this massive nested switch + switch (x.scalar_type()) { + case at::ScalarType::Half: + switch (y.scalar_type()) { + case at::ScalarType::Half: + switch (w.scalar_type()) { + case at::ScalarType::Half: + ok = launch_bgmv_kernel(static_cast(y.data_ptr()), + static_cast(x.data_ptr()), + static_cast(w.data_ptr()), + indicies.data_ptr(), h_in, h_out, 0, + h_out, B, num_layers, layer_idx, scale); + break; + case at::ScalarType::BFloat16: + ok = launch_bgmv_kernel(static_cast(y.data_ptr()), + static_cast(x.data_ptr()), + static_cast(w.data_ptr()), + indicies.data_ptr(), h_in, h_out, 0, + h_out, B, num_layers, layer_idx, scale); + break; + default: + break; + } + break; + case at::ScalarType::BFloat16: + switch (w.scalar_type()) { + case at::ScalarType::Half: + ok = launch_bgmv_kernel(static_cast(y.data_ptr()), + static_cast(x.data_ptr()), + static_cast(w.data_ptr()), + indicies.data_ptr(), h_in, h_out, 0, + h_out, B, num_layers, layer_idx, scale); + break; + case at::ScalarType::BFloat16: + ok = launch_bgmv_kernel(static_cast(y.data_ptr()), + static_cast(x.data_ptr()), + static_cast(w.data_ptr()), + indicies.data_ptr(), h_in, h_out, 0, + h_out, B, num_layers, layer_idx, scale); + break; + default: + break; + } + break; + case at::ScalarType::Float: + switch (w.scalar_type()) { + case at::ScalarType::Half: + ok = launch_bgmv_kernel(static_cast(y.data_ptr()), + static_cast(x.data_ptr()), + static_cast(w.data_ptr()), + indicies.data_ptr(), h_in, h_out, 0, + h_out, B, num_layers, layer_idx, scale); + break; + case at::ScalarType::BFloat16: + ok = launch_bgmv_kernel(static_cast(y.data_ptr()), + static_cast(x.data_ptr()), + static_cast(w.data_ptr()), + indicies.data_ptr(), h_in, h_out, 0, + h_out, B, num_layers, layer_idx, scale); + break; + default: + break; + } + break; + default: + break; + } + break; + case at::ScalarType::BFloat16: + switch (y.scalar_type()) { + case at::ScalarType::Half: + switch (w.scalar_type()) { + case at::ScalarType::Half: + ok = launch_bgmv_kernel(static_cast(y.data_ptr()), + static_cast(x.data_ptr()), + static_cast(w.data_ptr()), + indicies.data_ptr(), h_in, h_out, 0, + h_out, B, num_layers, layer_idx, scale); + break; + case at::ScalarType::BFloat16: + ok = launch_bgmv_kernel(static_cast(y.data_ptr()), + static_cast(x.data_ptr()), + static_cast(w.data_ptr()), + indicies.data_ptr(), h_in, h_out, 0, + h_out, B, num_layers, layer_idx, scale); + break; + default: + break; + } + break; + case at::ScalarType::BFloat16: + switch (w.scalar_type()) { + case at::ScalarType::Half: + ok = launch_bgmv_kernel(static_cast(y.data_ptr()), + static_cast(x.data_ptr()), + static_cast(w.data_ptr()), + indicies.data_ptr(), h_in, h_out, 0, + h_out, B, num_layers, layer_idx, scale); + break; + case at::ScalarType::BFloat16: + ok = launch_bgmv_kernel(static_cast(y.data_ptr()), + static_cast(x.data_ptr()), + static_cast(w.data_ptr()), + indicies.data_ptr(), h_in, h_out, 0, + h_out, B, num_layers, layer_idx, scale); + break; + default: + break; + } + break; + case at::ScalarType::Float: + switch (w.scalar_type()) { + case at::ScalarType::Half: + ok = launch_bgmv_kernel(static_cast(y.data_ptr()), + static_cast(x.data_ptr()), + static_cast(w.data_ptr()), + indicies.data_ptr(), h_in, h_out, 0, + h_out, B, num_layers, layer_idx, scale); + break; + case at::ScalarType::BFloat16: + ok = launch_bgmv_kernel(static_cast(y.data_ptr()), + static_cast(x.data_ptr()), + static_cast(w.data_ptr()), + indicies.data_ptr(), h_in, h_out, 0, + h_out, B, num_layers, layer_idx, scale); + break; + default: + break; + } + break; + default: + break; + } + break; + case at::ScalarType::Float: + switch (y.scalar_type()) { + case at::ScalarType::Half: + switch (w.scalar_type()) { + case at::ScalarType::Half: + ok = launch_bgmv_kernel(static_cast(y.data_ptr()), + static_cast(x.data_ptr()), + static_cast(w.data_ptr()), + indicies.data_ptr(), h_in, h_out, 0, + h_out, B, num_layers, layer_idx, scale); + break; + case at::ScalarType::BFloat16: + ok = launch_bgmv_kernel(static_cast(y.data_ptr()), + static_cast(x.data_ptr()), + static_cast(w.data_ptr()), + indicies.data_ptr(), h_in, h_out, 0, + h_out, B, num_layers, layer_idx, scale); + break; + default: + break; + } + break; + case at::ScalarType::BFloat16: + switch (w.scalar_type()) { + case at::ScalarType::Half: + ok = launch_bgmv_kernel(static_cast(y.data_ptr()), + static_cast(x.data_ptr()), + static_cast(w.data_ptr()), + indicies.data_ptr(), h_in, h_out, 0, + h_out, B, num_layers, layer_idx, scale); + break; + case at::ScalarType::BFloat16: + ok = launch_bgmv_kernel(static_cast(y.data_ptr()), + static_cast(x.data_ptr()), + static_cast(w.data_ptr()), + indicies.data_ptr(), h_in, h_out, 0, + h_out, B, num_layers, layer_idx, scale); + break; + default: + break; + } + break; + case at::ScalarType::Float: + switch (w.scalar_type()) { + case at::ScalarType::Half: + ok = launch_bgmv_kernel(static_cast(y.data_ptr()), + static_cast(x.data_ptr()), + static_cast(w.data_ptr()), + indicies.data_ptr(), h_in, h_out, 0, + h_out, B, num_layers, layer_idx, scale); + break; + case at::ScalarType::BFloat16: + ok = launch_bgmv_kernel(static_cast(y.data_ptr()), + static_cast(x.data_ptr()), + static_cast(w.data_ptr()), + indicies.data_ptr(), h_in, h_out, 0, + h_out, B, num_layers, layer_idx, scale); + break; + default: + break; + } + break; + default: + break; + } + break; + default: + break; + } + } + TORCH_CHECK(ok, "No suitable kernel.", " h_in=", h_in, " h_out=", h_out, + " dtype=", x.scalar_type(), " out_dtype=", y.scalar_type()); +} + +void dispatch_bgmv_low_level(torch::Tensor y, torch::Tensor x, torch::Tensor w, + torch::Tensor indicies, int64_t layer_idx, + float scale, int64_t h_in, int64_t h_out, + int64_t y_offset) { + CHECK_INPUT(y); + CHECK_INPUT(x); + CHECK_INPUT(w); + CHECK_INPUT(indicies); + + CHECK_DIM(2, y); + CHECK_DIM(2, x); + CHECK_DIM(4, w); + CHECK_DIM(1, indicies); + + int64_t B = x.size(0); + int64_t num_layers = w.size(1); + int64_t full_y_size = y.size(1); + CHECK_EQ(w.size(3), h_in); + CHECK_EQ(w.size(2), h_out); + CHECK_EQ(indicies.size(0), x.size(0)); + CHECK_EQ(y.size(0), x.size(0)); + bool ok = false; + if (h_in < 65536 && h_out < 65536) { + // TODO: See if we can get rid of this massive nested switch + switch (x.scalar_type()) { + case at::ScalarType::Half: + switch (y.scalar_type()) { + case at::ScalarType::Half: + switch (w.scalar_type()) { + case at::ScalarType::Half: + ok = launch_bgmv_kernel(static_cast(y.data_ptr()), + static_cast(x.data_ptr()), + static_cast(w.data_ptr()), + indicies.data_ptr(), h_in, h_out, + y_offset, full_y_size, B, num_layers, + layer_idx, scale); + break; + case at::ScalarType::BFloat16: + ok = launch_bgmv_kernel(static_cast(y.data_ptr()), + static_cast(x.data_ptr()), + static_cast(w.data_ptr()), + indicies.data_ptr(), h_in, h_out, + y_offset, full_y_size, B, num_layers, + layer_idx, scale); + break; + default: + break; + } + break; + case at::ScalarType::BFloat16: + switch (w.scalar_type()) { + case at::ScalarType::Half: + ok = launch_bgmv_kernel(static_cast(y.data_ptr()), + static_cast(x.data_ptr()), + static_cast(w.data_ptr()), + indicies.data_ptr(), h_in, h_out, + y_offset, full_y_size, B, num_layers, + layer_idx, scale); + break; + case at::ScalarType::BFloat16: + ok = launch_bgmv_kernel(static_cast(y.data_ptr()), + static_cast(x.data_ptr()), + static_cast(w.data_ptr()), + indicies.data_ptr(), h_in, h_out, + y_offset, full_y_size, B, num_layers, + layer_idx, scale); + break; + default: + break; + } + break; + case at::ScalarType::Float: + switch (w.scalar_type()) { + case at::ScalarType::Half: + ok = launch_bgmv_kernel(static_cast(y.data_ptr()), + static_cast(x.data_ptr()), + static_cast(w.data_ptr()), + indicies.data_ptr(), h_in, h_out, + y_offset, full_y_size, B, num_layers, + layer_idx, scale); + break; + case at::ScalarType::BFloat16: + ok = launch_bgmv_kernel(static_cast(y.data_ptr()), + static_cast(x.data_ptr()), + static_cast(w.data_ptr()), + indicies.data_ptr(), h_in, h_out, + y_offset, full_y_size, B, num_layers, + layer_idx, scale); + break; + default: + break; + } + break; + default: + break; + } + break; + case at::ScalarType::BFloat16: + switch (y.scalar_type()) { + case at::ScalarType::Half: + switch (w.scalar_type()) { + case at::ScalarType::Half: + ok = launch_bgmv_kernel(static_cast(y.data_ptr()), + static_cast(x.data_ptr()), + static_cast(w.data_ptr()), + indicies.data_ptr(), h_in, h_out, + y_offset, full_y_size, B, num_layers, + layer_idx, scale); + break; + case at::ScalarType::BFloat16: + ok = launch_bgmv_kernel(static_cast(y.data_ptr()), + static_cast(x.data_ptr()), + static_cast(w.data_ptr()), + indicies.data_ptr(), h_in, h_out, + y_offset, full_y_size, B, num_layers, + layer_idx, scale); + break; + default: + break; + } + break; + case at::ScalarType::BFloat16: + switch (w.scalar_type()) { + case at::ScalarType::Half: + ok = launch_bgmv_kernel(static_cast(y.data_ptr()), + static_cast(x.data_ptr()), + static_cast(w.data_ptr()), + indicies.data_ptr(), h_in, h_out, + y_offset, full_y_size, B, num_layers, + layer_idx, scale); + break; + case at::ScalarType::BFloat16: + ok = launch_bgmv_kernel(static_cast(y.data_ptr()), + static_cast(x.data_ptr()), + static_cast(w.data_ptr()), + indicies.data_ptr(), h_in, h_out, + y_offset, full_y_size, B, num_layers, + layer_idx, scale); + break; + default: + break; + } + break; + case at::ScalarType::Float: + switch (w.scalar_type()) { + case at::ScalarType::Half: + ok = launch_bgmv_kernel(static_cast(y.data_ptr()), + static_cast(x.data_ptr()), + static_cast(w.data_ptr()), + indicies.data_ptr(), h_in, h_out, + y_offset, full_y_size, B, num_layers, + layer_idx, scale); + break; + case at::ScalarType::BFloat16: + ok = launch_bgmv_kernel(static_cast(y.data_ptr()), + static_cast(x.data_ptr()), + static_cast(w.data_ptr()), + indicies.data_ptr(), h_in, h_out, + y_offset, full_y_size, B, num_layers, + layer_idx, scale); + break; + default: + break; + } + break; + default: + break; + } + break; + case at::ScalarType::Float: + switch (y.scalar_type()) { + case at::ScalarType::Half: + switch (w.scalar_type()) { + case at::ScalarType::Half: + ok = launch_bgmv_kernel(static_cast(y.data_ptr()), + static_cast(x.data_ptr()), + static_cast(w.data_ptr()), + indicies.data_ptr(), h_in, h_out, + y_offset, full_y_size, B, num_layers, + layer_idx, scale); + break; + case at::ScalarType::BFloat16: + ok = launch_bgmv_kernel(static_cast(y.data_ptr()), + static_cast(x.data_ptr()), + static_cast(w.data_ptr()), + indicies.data_ptr(), h_in, h_out, + y_offset, full_y_size, B, num_layers, + layer_idx, scale); + break; + default: + break; + } + break; + case at::ScalarType::BFloat16: + switch (w.scalar_type()) { + case at::ScalarType::Half: + ok = launch_bgmv_kernel(static_cast(y.data_ptr()), + static_cast(x.data_ptr()), + static_cast(w.data_ptr()), + indicies.data_ptr(), h_in, h_out, + y_offset, full_y_size, B, num_layers, + layer_idx, scale); + break; + case at::ScalarType::BFloat16: + ok = launch_bgmv_kernel(static_cast(y.data_ptr()), + static_cast(x.data_ptr()), + static_cast(w.data_ptr()), + indicies.data_ptr(), h_in, h_out, + y_offset, full_y_size, B, num_layers, + layer_idx, scale); + break; + default: + break; + } + break; + case at::ScalarType::Float: + switch (w.scalar_type()) { + case at::ScalarType::Half: + ok = launch_bgmv_kernel(static_cast(y.data_ptr()), + static_cast(x.data_ptr()), + static_cast(w.data_ptr()), + indicies.data_ptr(), h_in, h_out, + y_offset, full_y_size, B, num_layers, + layer_idx, scale); + break; + case at::ScalarType::BFloat16: + ok = launch_bgmv_kernel(static_cast(y.data_ptr()), + static_cast(x.data_ptr()), + static_cast(w.data_ptr()), + indicies.data_ptr(), h_in, h_out, + y_offset, full_y_size, B, num_layers, + layer_idx, scale); + break; + default: + break; + } + break; + default: + break; + } + break; + default: + break; + } + } + TORCH_CHECK(ok, "No suitable kernel.", " h_in=", h_in, " h_out=", h_out, + " dtype=", x.scalar_type(), " out_dtype=", y.scalar_type()); +} + +} // namespace + +//====== pybind ====== + +#define DEFINE_pybind(name) m.def(#name, &name, #name); + +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { + m.def("dispatch_bgmv", &dispatch_bgmv, "dispatch_bgmv"); + m.def("dispatch_bgmv_low_level", &dispatch_bgmv_low_level, + "dispatch_bgmv_low_level"); +} diff --git a/setup.py b/setup.py index 2b040e88f0aa4..2e11119043277 100644 --- a/setup.py +++ b/setup.py @@ -1,13 +1,16 @@ +import contextlib import io import os import re import subprocess -from typing import List, Set import warnings +from pathlib import Path +from typing import List, Set from packaging.version import parse, Version import setuptools import torch +import torch.utils.cpp_extension as torch_cpp_ext from torch.utils.cpp_extension import BuildExtension, CUDAExtension, CUDA_HOME ROOT_DIR = os.path.dirname(__file__) @@ -31,6 +34,11 @@ "Cannot find CUDA_HOME. CUDA must be available to build the package.") +def glob(pattern: str): + root = Path(__name__).parent + return [str(p) for p in root.glob(pattern)] + + def get_nvcc_cuda_version(cuda_dir: str) -> Version: """Get the CUDA version from nvcc. @@ -129,19 +137,59 @@ def get_torch_arch_list() -> Set[str]: raise RuntimeError( "CUDA 11.8 or higher is required for compute capability 9.0.") +# Use NVCC threads to parallelize the build. +if nvcc_cuda_version >= Version("11.2"): + num_threads = min(os.cpu_count(), 8) + NVCC_FLAGS += ["--threads", str(num_threads)] + +NVCC_FLAGS_PUNICA = NVCC_FLAGS.copy() + # Add target compute capabilities to NVCC flags. for capability in compute_capabilities: num = capability[0] + capability[2] NVCC_FLAGS += ["-gencode", f"arch=compute_{num},code=sm_{num}"] if capability.endswith("+PTX"): NVCC_FLAGS += ["-gencode", f"arch=compute_{num},code=compute_{num}"] - -# Use NVCC threads to parallelize the build. -if nvcc_cuda_version >= Version("11.2"): - num_threads = min(os.cpu_count(), 8) - NVCC_FLAGS += ["--threads", str(num_threads)] + if int(capability[0]) >= 8: + NVCC_FLAGS_PUNICA += ["-gencode", f"arch=compute_{num},code=sm_{num}"] + if capability.endswith("+PTX"): + NVCC_FLAGS_PUNICA += [ + "-gencode", f"arch=compute_{num},code=compute_{num}" + ] + +# changes for punica kernels +NVCC_FLAGS += torch_cpp_ext.COMMON_NVCC_FLAGS +REMOVE_NVCC_FLAGS = [ + '-D__CUDA_NO_HALF_OPERATORS__', + '-D__CUDA_NO_HALF_CONVERSIONS__', + '-D__CUDA_NO_BFLOAT16_CONVERSIONS__', + '-D__CUDA_NO_HALF2_OPERATORS__', +] +for flag in REMOVE_NVCC_FLAGS: + with contextlib.suppress(ValueError): + torch_cpp_ext.COMMON_NVCC_FLAGS.remove(flag) ext_modules = [] + +install_punica = bool(int(os.getenv("VLLM_INSTALL_PUNICA_KERNELS", "1"))) +device_count = torch.cuda.device_count() +for i in range(device_count): + major, minor = torch.cuda.get_device_capability(i) + if major < 8: + install_punica = False + break +if install_punica: + ext_modules.append( + CUDAExtension( + name="vllm._punica_C", + sources=["csrc/punica/punica_ops.cc"] + + glob("csrc/punica/bgmv/*.cu"), + extra_compile_args={ + "cxx": CXX_FLAGS, + "nvcc": NVCC_FLAGS_PUNICA, + }, + )) + vllm_extension = CUDAExtension( name="vllm._C", sources=[ diff --git a/tests/lora/__init__.py b/tests/lora/__init__.py new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/tests/lora/conftest.py b/tests/lora/conftest.py new file mode 100644 index 0000000000000..263a2bc9d8156 --- /dev/null +++ b/tests/lora/conftest.py @@ -0,0 +1,139 @@ +import gc +import tempfile +from collections import OrderedDict +from unittest.mock import patch, MagicMock + +import pytest +import ray +import torch +import torch.nn as nn +from huggingface_hub import snapshot_download + +import vllm +from vllm.config import LoRAConfig +from vllm.model_executor.layers.sampler import Sampler +from vllm.model_executor.model_loader import get_model +from vllm.model_executor.layers.linear import (ColumnParallelLinear, + MergedColumnParallelLinear, + RowParallelLinear) +from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead +from vllm.model_executor.parallel_utils.parallel_state import ( + destroy_model_parallel, initialize_model_parallel) + + +def cleanup(): + destroy_model_parallel() + gc.collect() + torch.cuda.empty_cache() + ray.shutdown() + + +@pytest.fixture(autouse=True) +def cleanup_fixture(): + yield + cleanup() + + +@pytest.fixture +def dist_init(): + if not torch.distributed.is_initialized(): + temp_file = tempfile.mkstemp()[1] + torch.distributed.init_process_group( + backend="nccl", + world_size=1, + rank=0, + init_method=f"file://{temp_file}", + ) + torch.distributed.all_reduce(torch.zeros(1).cuda()) + initialize_model_parallel(1, 1) + yield + cleanup() + + +@pytest.fixture +def dist_init_torch_only(): + if torch.distributed.is_initialized(): + return + temp_file = tempfile.mkstemp()[1] + torch.distributed.init_process_group( + backend="nccl", + world_size=1, + rank=0, + init_method=f"file://{temp_file}", + ) + + +@pytest.fixture +def dummy_model() -> nn.Module: + model = nn.Sequential( + OrderedDict([ + ("dense1", ColumnParallelLinear(764, 100)), + ("dense2", RowParallelLinear(100, 50)), + ( + "layer1", + nn.Sequential( + OrderedDict([ + ("dense1", ColumnParallelLinear(100, 10)), + ("dense2", RowParallelLinear(10, 50)), + ])), + ), + ("act2", nn.ReLU()), + ("output", ColumnParallelLinear(50, 10)), + ("outact", nn.Sigmoid()), + # Special handling for lm_head & sampler + ("lm_head", ParallelLMHead(512, 10)), + ("sampler", Sampler(512)) + ])) + model.config = MagicMock() + return model + + +@pytest.fixture +def dummy_model_gate_up() -> nn.Module: + model = nn.Sequential( + OrderedDict([ + ("dense1", ColumnParallelLinear(764, 100)), + ("dense2", RowParallelLinear(100, 50)), + ( + "layer1", + nn.Sequential( + OrderedDict([ + ("dense1", ColumnParallelLinear(100, 10)), + ("dense2", RowParallelLinear(10, 50)), + ])), + ), + ("act2", nn.ReLU()), + ("gate_up_proj", MergedColumnParallelLinear(50, [5, 5])), + ("outact", nn.Sigmoid()), + # Special handling for lm_head & sampler + ("lm_head", ParallelLMHead(512, 10)), + ("sampler", Sampler(512)) + ])) + model.config = MagicMock() + return model + + +@pytest.fixture(scope="session") +def sql_lora_files(): + return snapshot_download(repo_id="yard1/llama-2-7b-sql-lora-test") + + +@pytest.fixture +def llama_2_7b_engine_extra_embeddings() -> nn.Module: + cleanup() + get_model_old = get_model + + def get_model_patched(model_config, lora_config=None): + return get_model_old(model_config, LoRAConfig(max_lora_rank=8)) + + with patch("vllm.worker.worker.get_model", get_model_patched): + engine = vllm.LLM("meta-llama/Llama-2-7b-hf", enable_lora=False) + yield engine.llm_engine + del engine + cleanup() + + +@pytest.fixture +def llama_2_7b_model_extra_embeddings( + llama_2_7b_engine_extra_embeddings) -> nn.Module: + yield llama_2_7b_engine_extra_embeddings.workers[0].model diff --git a/tests/lora/test_layers.py b/tests/lora/test_layers.py new file mode 100644 index 0000000000000..fa6a18e8d93d2 --- /dev/null +++ b/tests/lora/test_layers.py @@ -0,0 +1,697 @@ +import pytest +import random +from copy import deepcopy +from dataclasses import dataclass +from typing import List, Optional, Dict, Tuple + +import torch +import torch.nn.functional as F + +from vllm.lora.layers import ( + LoRAColumnParallelLinear, + LoRAMergedColumnParallelLinear2Slice, + LoRAQKVParallelLinear, + LoRAVocabParallelEmbedding, + LoRARowParallelLinear, + LoRASampler, + LoRAMapping, + LoRALayer, +) +from vllm.lora.models import LoRA, convert_mapping +from vllm.config import LoRAConfig +from vllm.model_executor.layers.sampler import Sampler +from vllm.model_executor.layers.linear import (ColumnParallelLinear, + MergedColumnParallelLinear, + RowParallelLinear, + QKVParallelLinear) +from vllm.model_executor.layers.vocab_parallel_embedding import VocabParallelEmbedding, ParallelLMHead +from vllm.model_executor.utils import set_random_seed + +from .utils import DummyLoRAManager + +TOLERANCES = { + torch.float16: (5e-3, 5e-3), + torch.float32: (5e-3, 5e-3), + torch.bfloat16: (3e-2, 2e-2), +} + + +def get_random_id_to_index(num_loras: int, + num_slots: int, + log: bool = True) -> List[Optional[int]]: + """Creates a random lora_id_to_index mapping. + + Args: + num_loras: The number of active loras in the mapping. + num_slots: The number of slots in the mapping. Must be larger + than num_loras. + log: Whether to log the output. + """ + + if num_loras > num_slots: + raise ValueError( + f"num_loras is higher than num_slots: {num_loras} > {num_slots}. " + "num_loras must be less than or equal to num_slots.") + + slots: List[Optional[int]] = [None] * num_slots + random_slot_selections = (torch.randperm(num_slots)[:num_loras]).tolist() + for lora_id, slot_idx in enumerate(random_slot_selections, start=1): + slots[slot_idx] = lora_id + + if log: + print(f"Created lora_id_to_index mapping: {slots}.") + + return slots + + +def populate_loras( + id_to_index: List[Optional[int]], + layer: LoRALayer, + layer_weights: torch.Tensor, + generate_embeddings_tensor: int = 0, + repeats: int = 1, +) -> Tuple[Dict[int, LoRA], Dict[int, List[LoRA]]]: + """This method populates the lora layers with lora weights. + + Args: + id_to_index: a list of lora ids. The index of the lora id + represents which memory slot the lora matrices are + stored in. A None value indicates a free slot. + layer: the LoRAlayer to populate. + layer_weights: the PyTorch tensor containing the layer's + weights. + generate_embeddings_tensor: whether to generate an + embeddings tensor for each LoRA. + repeats: must only be set for column parallel packed + layers. Indicates the number of loras to compose + together to create a single lora layer. + """ + + # Dictionary that maps the lora ID to the + # corresponding lora weights. + lora_dict: Dict[int, LoRA] = dict() + + # Dictionary that maps the lora ID to the + # corresponding subloras. Only useful when + # repeats > 1. + sublora_dict: Dict[int, List[LoRA]] = dict() + + for slot_idx, lora_id in enumerate(id_to_index): + if lora_id is not None: + subloras = [] + sublora_len = layer_weights.shape[0] // repeats + for i in range(repeats): + sublora = DummyLoRAManager().init_random_lora( + module_name=f"fake_{i}", + weight=layer_weights, + generate_embeddings_tensor=generate_embeddings_tensor, + ) + sublora.lora_b = sublora.lora_b[:, (sublora_len * + i):(sublora_len * (i + 1))] + sublora.optimize() + subloras.append(sublora) + + lora = LoRA.pack(subloras) if repeats > 1 else subloras[0] + + layer.set_lora( + slot_idx, + lora_a=lora.lora_a, + lora_b=lora.lora_b, + embeddings_tensor=lora.embeddings_tensor, + ) + + lora_dict[lora_id] = lora + sublora_dict[lora_id] = subloras + + return lora_dict, sublora_dict + + +def create_random_inputs( + active_lora_ids: List[int], + num_inputs: int, + input_size: Tuple[int, ...], + input_range: Tuple[float, float], + input_type: torch.dtype = torch.int, +) -> Tuple[List[torch.Tensor], List[int], List[int]]: + """Creates random inputs. + + Args: + active_lora_ids: lora IDs of active lora weights. + num_inputs: the number of inputs to create. + input_size: the size of each individual input. + input_range: the range of values to include in the input. + input_range[0] <= possible input values < input_range[1] + input_type: the type of values in the input. + """ + + low, high = input_range + + inputs, index_mapping, prompt_mapping = [], [], [] + for _ in range(num_inputs): + if input_type == torch.int: + inputs.append( + torch.randint(low=int(low), + high=int(high), + size=input_size, + device="cuda")) + else: + inputs.append( + torch.rand(size=input_size, dtype=input_type, device="cuda") * + high + low) + + lora_id = random.choice(active_lora_ids) + index_mapping += [lora_id] * input_size[0] + prompt_mapping += [lora_id] + + return inputs, index_mapping, prompt_mapping + + +@torch.inference_mode() +@pytest.mark.parametrize("num_loras", [1, 2, 4, 8]) +def test_embeddings(dist_init, num_loras) -> None: + + lora_config = LoRAConfig(max_lora_rank=8, lora_dtype=torch.float16) + max_loras = 8 + + def create_random_embedding_layer(): + embedding = VocabParallelEmbedding(512, 256) + embedding.weight.data = torch.rand_like(embedding.weight.data) + embedding.weight.data[512:, :] = 0 + lora_embedding = LoRAVocabParallelEmbedding(embedding) + lora_embedding.create_lora_weights(max_loras, lora_config) + + return embedding, lora_embedding + + for i in range(10): + set_random_seed(i) + + id_to_index = get_random_id_to_index(num_loras, max_loras) + embedding, lora_embedding = create_random_embedding_layer() + + lora_dict, _ = populate_loras( + id_to_index, + layer=lora_embedding, + layer_weights=embedding.weight.T, + ) + + inputs, index_mapping, prompt_mapping = create_random_inputs( + active_lora_ids=list(lora_dict.keys()), + num_inputs=num_loras * 3, + input_size=(200, ), + input_range=(1, 512), + ) + lora_mapping = LoRAMapping(index_mapping, prompt_mapping) + + mapping_info = convert_mapping(lora_mapping, id_to_index, max_loras, + 512, lora_config.lora_extra_vocab_size) + lora_embedding.set_mapping(*mapping_info) + + lora_result = lora_embedding(torch.cat(inputs)) + + expected_results = [] + for input_, lora_id in zip(inputs, prompt_mapping): + lora = lora_dict[lora_id] + result = embedding(input_) + after_a = F.embedding( + input_, + lora.lora_a, + ) + result += (after_a @ lora.lora_b) + expected_results.append(result) + expected_result = torch.cat(expected_results) + + rtol, atol = TOLERANCES[lora_result.dtype] + assert torch.allclose(lora_result, + expected_result, + rtol=rtol, + atol=atol) + + # Check that resetting the lora weights succeeds + + for slot_idx in range(max_loras): + lora_embedding.reset_lora(slot_idx) + + inputs, index_mapping, prompt_mapping = create_random_inputs( + active_lora_ids=[0], + num_inputs=num_loras * 3, + input_size=(200, ), + input_range=(1, 512), + ) + lora_mapping = LoRAMapping(index_mapping, prompt_mapping) + + mapping_info = convert_mapping(lora_mapping, id_to_index, max_loras, + 512, lora_config.lora_extra_vocab_size) + lora_embedding.set_mapping(*mapping_info, ) + + lora_result = lora_embedding(torch.cat(inputs)) + expected_result = embedding(torch.cat(inputs)) + + rtol, atol = TOLERANCES[lora_result.dtype] + assert torch.allclose(lora_result, + expected_result, + rtol=rtol, + atol=atol) + + +@torch.inference_mode() +# @pytest.mark.skip(reason="Fails when loras are in any slot other than the first.") +@pytest.mark.parametrize("num_loras", [1, 2, 4, 8]) +def test_embeddings_with_new_embeddings(dist_init, num_loras) -> None: + + lora_config = LoRAConfig(max_lora_rank=8, lora_dtype=torch.float16) + max_loras = 8 + + def create_random_embedding_layer(): + embedding = VocabParallelEmbedding(512, 256) + embedding_data = torch.rand_like(embedding.weight.data) + embedding.weight.data = embedding_data + embedding.weight.data[512:, :] = 0 + expanded_embedding = VocabParallelEmbedding( + 512 + lora_config.lora_extra_vocab_size * max_loras, + 256, + org_num_embeddings=512) + expanded_embedding.weight.data[:512, :] = embedding_data + # We need to deepcopy the embedding as it will be modifed + # in place + lora_embedding = LoRAVocabParallelEmbedding( + deepcopy(expanded_embedding)) + lora_embedding.create_lora_weights(max_loras, lora_config) + + return expanded_embedding, lora_embedding + + for i in range(10): + set_random_seed(i) + + id_to_index = get_random_id_to_index(num_loras, max_loras) + expanded_embedding, lora_embedding = create_random_embedding_layer() + lora_dict, _ = populate_loras( + id_to_index, + layer=lora_embedding, + layer_weights=torch.zeros( + (256, 512 + lora_config.lora_extra_vocab_size)), + generate_embeddings_tensor=256, + ) + + # All embeddings tensors have the same shape. + embeddings_tensors = [ + lora_dict[id].embeddings_tensor for id in sorted(lora_dict.keys()) + ] + embeddings_tensor_len = embeddings_tensors[0].shape[0] + + # Add empty embeddings_tensors for unoccupied lora slots. + for _ in range(max_loras - len(embeddings_tensors)): + embeddings_tensors.append( + torch.zeros(embeddings_tensors[0].shape, device="cuda")) + + inputs, index_mapping, prompt_mapping = create_random_inputs( + active_lora_ids=list(lora_dict.keys()), + num_inputs=num_loras * 3, + input_size=(200, ), + input_range=(1, 512), + ) + lora_mapping = LoRAMapping(index_mapping, prompt_mapping) + + original_inputs = deepcopy(inputs) + + # Force some of the inputs to be in the extended embeddings range + # to guarantee that their behavior is tested. + for input_, original_input_, lora_id in zip(inputs, original_inputs, + prompt_mapping): + embedding_id = lora_id - 1 + input_[-1] = 512 + (embedding_id * embeddings_tensor_len) + original_input_[-1] = 512 + input_[-2] = 512 + ((embedding_id + 1) * embeddings_tensor_len - 1) + original_input_[-2] = 512 + embeddings_tensor_len - 1 + + mapping_info = convert_mapping(lora_mapping, id_to_index, max_loras, + 512, lora_config.lora_extra_vocab_size) + lora_embedding.set_mapping(*mapping_info, ) + + expanded_embedding.weight[512:512 + + (embeddings_tensor_len * + max_loras)] = torch.cat(embeddings_tensors) + + lora_result = lora_embedding(torch.cat(original_inputs)) + + expected_results = [] + for input_, original_input_, lora_id in zip(inputs, original_inputs, + prompt_mapping): + lora = lora_dict[lora_id] + result = expanded_embedding(input_) + after_a = F.embedding( + original_input_, + lora.lora_a, + ) + result += (after_a @ lora.lora_b) + expected_results.append(result) + expected_result = torch.cat(expected_results) + + rtol, atol = TOLERANCES[lora_result.dtype] + assert torch.allclose(lora_result, + expected_result, + rtol=rtol, + atol=atol) + + # Check that resetting the lora weights succeeds + + for slot_idx in range(max_loras): + lora_embedding.reset_lora(slot_idx) + + inputs, index_mapping, prompt_mapping = create_random_inputs( + active_lora_ids=[0], + num_inputs=num_loras * 3, + input_size=(200, ), + input_range=(1, 512), + ) + lora_mapping = LoRAMapping(index_mapping, prompt_mapping) + + original_inputs = deepcopy(inputs) + + mapping_info = convert_mapping(lora_mapping, id_to_index, max_loras, + 512, lora_config.lora_extra_vocab_size) + lora_embedding.set_mapping(*mapping_info, ) + + lora_result = lora_embedding(torch.cat(original_inputs)) + expected_result = expanded_embedding(torch.cat(inputs)) + + rtol, atol = TOLERANCES[lora_result.dtype] + assert torch.allclose(lora_result, + expected_result, + rtol=rtol, + atol=atol) + + +@torch.inference_mode() +@pytest.mark.parametrize("num_loras", [1, 2, 4, 8]) +def test_lm_head_sampler(dist_init, num_loras) -> None: + + lora_config = LoRAConfig(max_lora_rank=8, lora_dtype=torch.float16) + max_loras = 8 + + def create_random_sampler_layer(): + linear = ParallelLMHead(32000 + lora_config.lora_extra_vocab_size, + 1024, 32000) + linear.weight.data = torch.rand_like(linear.weight.data) + linear.weight.data[:, 32000:] = 0 + sampler = Sampler(32000 + lora_config.lora_extra_vocab_size, 32000) + lora_sampler = LoRASampler(sampler, 1024, linear.weight.dtype, + linear.weight.device) + lora_sampler.create_lora_weights(max_loras, lora_config) + + return linear, sampler, lora_sampler + + for i in range(10): + set_random_seed(i) + + id_to_index = get_random_id_to_index(num_loras, max_loras) + linear, sampler, lora_sampler = create_random_sampler_layer() + + # NOTE: all the generated loras share the same embeddings tensor. + lora_dict, _ = populate_loras( + id_to_index, + layer=lora_sampler, + layer_weights=linear.weight, + generate_embeddings_tensor=1024, + ) + embeddings_tensor = list(lora_dict.values())[0].embeddings_tensor + embeddings_tensor_len = embeddings_tensor.shape[0] + + inputs, index_mapping, prompt_mapping = create_random_inputs( + active_lora_ids=list(lora_dict.keys()), + num_inputs=8 * num_loras, # * 3, + input_size=(1, 1024), + input_range=(0, 1), + input_type=torch.float32, + ) + lora_mapping = LoRAMapping(index_mapping, prompt_mapping) + + input_ = torch.rand(20, 1024, device="cuda") + mapping_info = convert_mapping( + lora_mapping, + id_to_index, + max_loras, + 32000, + lora_config.lora_extra_vocab_size, + ) + lora_sampler.set_mapping(*mapping_info, ) + + lora_result = lora_sampler._get_logits(hidden_states=torch.cat(inputs), + embedding=linear.weight, + embedding_bias=None) + + original_weight = linear.weight.clone() + + linear.weight[sampler.org_vocab_size:sampler.org_vocab_size + + embeddings_tensor_len] = embeddings_tensor + + sampler.org_vocab_size = 32000 + lora_config.lora_extra_vocab_size + expected_results = [] + for input_, lora_id in zip(inputs, prompt_mapping): + lora = lora_dict[lora_id] + result = sampler._get_logits(hidden_states=input_, + embedding=linear.weight, + embedding_bias=None) + result[:, 32000 + embeddings_tensor_len:] = float("-inf") + result += input_ @ lora.lora_a @ lora.lora_b * lora.scaling + expected_results.append(result) + expected_result = torch.cat(expected_results) + sampler.org_vocab_size = 32000 + + # Check that resetting the lora weights succeeds + + for slot_idx in range(max_loras): + lora_sampler.reset_lora(slot_idx) + + inputs, index_mapping, prompt_mapping = create_random_inputs( + active_lora_ids=[0], + num_inputs=8 * num_loras * 3, + input_size=(1, 1024), + input_range=(0, 1), + input_type=torch.float32, + ) + lora_mapping = LoRAMapping(index_mapping, prompt_mapping) + + mapping_info = convert_mapping(lora_mapping, id_to_index, max_loras, + 32000, + lora_config.lora_extra_vocab_size) + lora_sampler.set_mapping(*mapping_info, ) + + lora_result = lora_sampler._get_logits(hidden_states=torch.cat(inputs), + embedding=original_weight, + embedding_bias=None)[:, :32000] + expected_result = sampler._get_logits(hidden_states=torch.cat(inputs), + embedding=original_weight, + embedding_bias=None) + + rtol, atol = TOLERANCES[lora_result.dtype] + assert torch.allclose(lora_result, + expected_result, + rtol=rtol, + atol=atol) + + +@torch.inference_mode() +@pytest.mark.parametrize("num_loras", [1, 2, 4, 8]) +@pytest.mark.parametrize("orientation", ["row", "column"]) +def test_linear_parallel(dist_init, num_loras, orientation) -> None: + + lora_config = LoRAConfig(max_lora_rank=8, lora_dtype=torch.float16) + max_loras = 8 + + def create_random_linear_parallel_layer(): + if orientation == "row": + linear = RowParallelLinear(4096, 4096, bias=False) + linear.weight.data = torch.rand_like(linear.weight.data) + lora_linear = LoRARowParallelLinear(linear) + else: + linear = ColumnParallelLinear(4096, 4096, bias=False) + linear.weight.data = torch.rand_like(linear.weight.data) + lora_linear = LoRAColumnParallelLinear(linear) + lora_linear.create_lora_weights(max_loras, lora_config) + + return linear, lora_linear + + for i in range(10): + set_random_seed(i) + + id_to_index = get_random_id_to_index(num_loras, max_loras) + linear, lora_linear = create_random_linear_parallel_layer() + + lora_dict, _ = populate_loras( + id_to_index, + layer=lora_linear, + layer_weights=linear.weight, + ) + + inputs, index_mapping, prompt_mapping = create_random_inputs( + active_lora_ids=list(lora_dict.keys()), + num_inputs=32 * num_loras, + input_size=(1, 4096), + input_range=(0, 1), + input_type=torch.float32, + ) + lora_mapping = LoRAMapping(index_mapping, prompt_mapping) + + mapping_info = convert_mapping( + lora_mapping, + id_to_index, + max_loras, + 512, + lora_config.lora_extra_vocab_size, + ) + lora_linear.set_mapping(*mapping_info, ) + + lora_result = lora_linear(torch.cat(inputs))[0] + + expected_results = [] + for input_, lora_id in zip(inputs, prompt_mapping): + lora = lora_dict[lora_id] + result = linear(input_)[0] + result += input_ @ lora.lora_a @ lora.lora_b * lora.scaling + expected_results.append(result) + expected_result = torch.cat(expected_results) + + rtol, atol = TOLERANCES[lora_result.dtype] + assert torch.allclose(lora_result, + expected_result, + rtol=rtol, + atol=atol) + + # Check that resetting the lora weights succeeds + + for slot_idx in range(max_loras): + lora_linear.reset_lora(slot_idx) + + inputs, index_mapping, prompt_mapping = create_random_inputs( + active_lora_ids=[0], + num_inputs=32 * num_loras, + input_size=(1, 4096), + input_range=(0, 1), + input_type=torch.float32, + ) + lora_mapping = LoRAMapping(index_mapping, prompt_mapping) + + mapping_info = convert_mapping(lora_mapping, id_to_index, max_loras, + 512, lora_config.lora_extra_vocab_size) + lora_linear.set_mapping(*mapping_info, ) + + lora_result = lora_linear(torch.cat(inputs))[0] + expected_result = linear(torch.cat(inputs))[0] + + rtol, atol = TOLERANCES[lora_result.dtype] + assert torch.allclose(lora_result, + expected_result, + rtol=rtol, + atol=atol) + + +@torch.inference_mode() +@pytest.mark.parametrize("num_loras", [1, 2, 4, 8]) +@pytest.mark.parametrize("repeats", [2, 3]) +def test_column_parallel_packed(dist_init, num_loras, repeats) -> None: + lora_config = LoRAConfig(max_lora_rank=8, lora_dtype=torch.float16) + max_loras = 8 + + def create_column_parallel_packed_layer(): + if repeats == 2: + linear = MergedColumnParallelLinear(4096, [4096] * repeats, + bias=False) + linear.weight.data = torch.rand_like(linear.weight.data) + lora_linear = LoRAMergedColumnParallelLinear2Slice(linear) + else: + linear = QKVParallelLinear(4096, 64, 32, bias=False) + linear.weight.data = torch.rand_like(linear.weight.data) + lora_linear = LoRAQKVParallelLinear(linear) + + @dataclass + class FakeConfig: + hidden_size = 4096 + num_key_value_heads = 32 + num_attention_heads = 32 + + lora_linear.create_lora_weights(max_loras, + lora_config, + model_config=FakeConfig()) + + return linear, lora_linear + + for i in range(10): + set_random_seed(i) + + id_to_index = get_random_id_to_index(num_loras, max_loras) + + linear, lora_linear = create_column_parallel_packed_layer() + + lora_dict, sublora_dict = populate_loras( + id_to_index, + layer=lora_linear, + layer_weights=linear.weight, + repeats=repeats, + ) + + inputs, index_mapping, prompt_mapping = create_random_inputs( + active_lora_ids=list(lora_dict.keys()), + num_inputs=32 * num_loras, + input_size=(1, 4096), + input_range=(0, 1), + input_type=torch.float32, + ) + lora_mapping = LoRAMapping(index_mapping, prompt_mapping) + + mapping_info = convert_mapping( + lora_mapping, + id_to_index, + max_loras, + 512, + lora_config.lora_extra_vocab_size, + ) + lora_linear.set_mapping(*mapping_info) + + lora_result = lora_linear(torch.cat(inputs))[0] + + expected_results = [] + for input_, lora_id in zip(inputs, prompt_mapping): + result = linear(input_)[0] + subloras = sublora_dict[lora_id] + for i, sublora in enumerate(subloras): + result[:, sublora.lora_b.shape[1] * i:sublora.lora_b.shape[1] * ( + i + 1 + )] += input_ @ sublora.lora_a @ sublora.lora_b * sublora.scaling + expected_results.append(result) + expected_result = torch.cat(expected_results) + + rtol, atol = TOLERANCES[lora_result.dtype] + assert torch.allclose(lora_result, + expected_result, + rtol=rtol, + atol=atol) + + for slot_idx in range(max_loras): + lora_linear.reset_lora(slot_idx) + + inputs, index_mapping, prompt_mapping = create_random_inputs( + active_lora_ids=[0], + num_inputs=32 * num_loras, + input_size=(1, 4096), + input_range=(0, 1), + input_type=torch.float32, + ) + lora_mapping = LoRAMapping(index_mapping, prompt_mapping) + + mapping_info = convert_mapping( + lora_mapping, + id_to_index, + max_loras, + 512, + lora_config.lora_extra_vocab_size, + ) + lora_linear.set_mapping(*mapping_info) + + lora_result = lora_linear(torch.cat(inputs))[0] + expected_result = linear(torch.cat(inputs))[0] + + rtol, atol = TOLERANCES[lora_result.dtype] + assert torch.allclose(lora_result, + expected_result, + rtol=rtol, + atol=atol) diff --git a/tests/lora/test_llama.py b/tests/lora/test_llama.py new file mode 100644 index 0000000000000..756fc55246092 --- /dev/null +++ b/tests/lora/test_llama.py @@ -0,0 +1,141 @@ +import pytest +import ray +import torch + +import vllm +from vllm.lora.request import LoRARequest + +MODEL_PATH = "meta-llama/Llama-2-7b-hf" + + +def do_sample(llm, lora_path: str, lora_id: int): + prompts = [ + "[user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_name_74 (icao VARCHAR, airport VARCHAR)\n\n question: Name the ICAO for lilongwe international airport [/user] [assistant]", + "[user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_name_11 (nationality VARCHAR, elector VARCHAR)\n\n question: When Anchero Pantaleone was the elector what is under nationality? [/user] [assistant]", + "[user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_name_95 (one_mora VARCHAR, gloss VARCHAR, accented_mora VARCHAR)\n\n question: What is the one mora for a low tone mora with a gloss of /˩okiru/ [òkìɽɯ́]? [/user] [assistant]", + "[user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE candidate (people_id VARCHAR, unsure_rate INTEGER); CREATE TABLE people (sex VARCHAR, people_id VARCHAR)\n\n question: which gender got the highest average uncertain ratio. [/user] [assistant]", + "[user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_name_60 (pick INTEGER, former_wnba_team VARCHAR)\n\n question: What pick was a player that previously played for the Minnesota Lynx? [/user] [assistant]", + "[user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_28138035_4 (womens_doubles VARCHAR, mens_singles VARCHAR)\n\n question: Name the womens doubles for werner schlager [/user] [assistant]" + ] + sampling_params = vllm.SamplingParams(temperature=0, + max_tokens=256, + stop=["[/assistant]"]) + outputs = llm.generate( + prompts, + sampling_params, + lora_request=LoRARequest(str(lora_id), lora_id, lora_path) + if lora_id else None) + # Print the outputs. + generated_texts = [] + for output in outputs: + prompt = output.prompt + generated_text = output.outputs[0].text + generated_texts.append(generated_text) + print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}") + return generated_texts + + +@pytest.mark.parametrize("tp_size", [1, 2, 4]) +def test_llama_lora(sql_lora_files, tp_size): + if torch.cuda.device_count() < tp_size: + pytest.skip(f"Not enough GPUs for tensor parallelism {tp_size}") + + llm = vllm.LLM(MODEL_PATH, + enable_lora=True, + max_num_seqs=16, + tensor_parallel_size=tp_size, + worker_use_ray=True) + + expected_no_lora_output = [ + "\n\n [user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_name_75 (icao VARCHAR, airport VARCHAR)\n\n question: Name the ICAO for lilongwe international airport [/user] [assistant]\n\n [user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_name_76 (icao VARCHAR, airport VARCHAR)\n\n question: Name the ICAO for lilongwe international airport [/user] [assistant]\n\n [user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_name_77 (icao VARCHAR, airport VARCHAR)\n\n question: Name the ICAO for lilongwe international airport [/user] [assistant]\n\n [user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_name_78 (icao VARCHAR, airport VARCHAR)\n\n question: Name the ICAO for lilongwe international airport [/user]", + " Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_name_11 (nationality VARCHAR, elector VARCHAR)\n\n question: When Anchero Pantaleone was the elector what is under nationality? ", + "\n\n answer: 1\n\n [user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_name_96 (one_mora VARCHAR, gloss VARCHAR, accented_mora VARCHAR)\n\n question: What is the one mora for a high tone mora with a gloss of /˧kot/ [kòt]? [/user] [assistant]\n\n answer: 2\n\n [user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_name_97 (one_mora VARCHAR, gloss VARCHAR, accented_mora VARCHAR)\n\n question: What is the one mora for a high tone mora with a gloss of /˧kot/ [kòt]? [/user] [assistant]\n\n answer: 2\n\n [user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_name_98 (one_mora VARCHAR, gloss VARCHAR, accented_mora VARCHAR)\n\n question: What is the one m", + " Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE candidate (people_id VARCHAR, unsure_rate INTEGER); CREATE TABLE people (sex VARCHAR, people_id VARCHAR)\n\n question: which gender got the highest average uncertain ratio. ", + " Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_name_60 (pick INTEGER, former_wnba_team VARCHAR)\n\n question: What pick was a player that previously played for the Minnesota Lynx? ", + "\n\n [user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_28138035_4 (womens_doubles VARCHAR, mens_singles VARCHAR)\n\n question: Name the womens doubles for werner schlager [/user] [assistant]\n\n [user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_28138035_4 (womens_doubles VARCHAR, mens_singles VARCHAR)\n\n question: Name the womens doubles for werner schlager [/user] [assistant]\n\n [user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_28138035_4 (womens_doubles VARCHAR, mens_singles VARCHAR)\n\n question: Name the womens doubles for werner schlager [/user] [assistant]\n\n [user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE", + ] + expected_lora_output = [ + " SELECT icao FROM table_name_74 WHERE airport = 'lilongwe international airport' ", + " SELECT nationality FROM table_name_11 WHERE elector = 'anchero pantaleone' ", + " SELECT one_mora FROM table_name_95 WHERE gloss = 'low tone mora with a gloss of /˩okiru/' [òkìɽɯ́] AND accented_mora = 'low tone mora with a gloss of /˩okiru/' [òkìɽɯ́] ", + " SELECT sex FROM people WHERE people_id IN (SELECT people_id FROM candidate GROUP BY sex ORDER BY COUNT(people_id) DESC LIMIT 1) ", + " SELECT pick FROM table_name_60 WHERE former_wnba_team = 'Minnesota Lynx' ", + " SELECT womens_doubles FROM table_28138035_4 WHERE mens_singles = 'Werner Schlager' " + ] + + print("lora adapter created") + assert do_sample(llm, sql_lora_files, lora_id=0) == expected_no_lora_output + + print("lora 1") + assert do_sample(llm, sql_lora_files, lora_id=1) == expected_lora_output + + print("no lora") + assert do_sample(llm, sql_lora_files, lora_id=0) == expected_no_lora_output + + print("lora 2") + assert do_sample(llm, sql_lora_files, lora_id=2) == expected_lora_output + + print("removing lora") + + +def test_llama_tensor_parallel_equality(sql_lora_files): + if torch.cuda.device_count() < 4: + pytest.skip(f"Not enough GPUs for tensor parallelism {4}") + + llm_tp1 = vllm.LLM(MODEL_PATH, + enable_lora=True, + max_num_seqs=16, + tensor_parallel_size=1, + worker_use_ray=True) + output_tp1 = do_sample(llm_tp1, sql_lora_files, lora_id=1) + + del llm_tp1 + ray.shutdown() + + llm_tp2 = vllm.LLM(MODEL_PATH, + enable_lora=True, + max_num_seqs=16, + tensor_parallel_size=2, + worker_use_ray=True) + output_tp2 = do_sample(llm_tp2, sql_lora_files, lora_id=1) + + del llm_tp2 + ray.shutdown() + + assert output_tp1 == output_tp2 + + llm_tp4 = vllm.LLM(MODEL_PATH, + enable_lora=True, + max_num_seqs=16, + tensor_parallel_size=4, + worker_use_ray=True) + output_tp4 = do_sample(llm_tp4, sql_lora_files, lora_id=1) + + del llm_tp4 + ray.shutdown() + + assert output_tp1 == output_tp4 + + +def test_llama_lora_warmup(sql_lora_files): + """Test that the LLM initialization works with a warmup LORA path and is more conservative""" + + @ray.remote(num_gpus=1) + def get_num_gpu_blocks_lora(): + llm = vllm.LLM(MODEL_PATH, enable_lora=True, max_num_seqs=16) + num_gpu_blocks_lora_warmup = llm.llm_engine.cache_config.num_gpu_blocks + return num_gpu_blocks_lora_warmup + + @ray.remote(num_gpus=1) + def get_num_gpu_blocks_no_lora(): + llm = vllm.LLM(MODEL_PATH, max_num_seqs=16) + num_gpu_blocks_no_lora_warmup = llm.llm_engine.cache_config.num_gpu_blocks + return num_gpu_blocks_no_lora_warmup + + num_gpu_blocks_lora_warmup = ray.get(get_num_gpu_blocks_lora.remote()) + num_gpu_blocks_no_lora_warmup = ray.get( + get_num_gpu_blocks_no_lora.remote()) + assert num_gpu_blocks_lora_warmup < num_gpu_blocks_no_lora_warmup, ( + "The warmup with lora should be more" + " conservative than without lora, therefore the number of memory blocks for the KV cache should be " + "less when using lora than when not using lora") diff --git a/tests/lora/test_lora.py b/tests/lora/test_lora.py new file mode 100644 index 0000000000000..b86f7a480e749 --- /dev/null +++ b/tests/lora/test_lora.py @@ -0,0 +1,224 @@ +import pytest +import torch + +from vllm.lora.layers import _apply_lora, _apply_lora_packed_2slice, _apply_lora_packed_3slice + +from .utils import DummyLoRAManager + +TENSOR_SIZES = [128, 1024, 2048, 4096, 8192, 11008, 11008 // 2, 11008 // 4] +QKV_TENSOR_SIZES = [ + (8192, 1024, 1024), + (8192 // 8, 1024 // 8, 1024 // 8), + (4096, 4096, 4096), + (4096 // 2, 4096 // 2, 4096 // 2), +] +BATCH_SIZES = [8, 32, 256] +RANKS = [8] +DTYPES = [torch.float16] +TOLERANCES = { + torch.float16: (5e-3, 5e-3), + torch.bfloat16: (3e-2, 2e-2), +} + + +@pytest.mark.parametrize("m", TENSOR_SIZES) +@pytest.mark.parametrize("n", TENSOR_SIZES) +@pytest.mark.parametrize("k", BATCH_SIZES) +@pytest.mark.parametrize("rank", RANKS) +@pytest.mark.parametrize("dtype", DTYPES) +def test_apply_lora(m, n, k, rank, dtype) -> None: + manager = DummyLoRAManager() + + module_name = "module" + weight = torch.rand([m, n], device="cuda", dtype=dtype) + + manager.init_random_lora(module_name, weight, rank=rank) + lora = manager.get_module_lora(module_name) + + input = torch.rand(k, n, device="cuda", dtype=dtype) + expected = input @ lora.lora_a @ lora.lora_b * lora.scaling + + lora_a_stack = torch.zeros(8, + 1, + lora.lora_a.shape[1], + lora.lora_a.shape[0], + device="cuda", + dtype=dtype) + lora_b_stack = torch.zeros(8, + 1, + lora.lora_b.shape[1], + lora.lora_b.shape[0], + device="cuda", + dtype=dtype) + for i in range(lora_a_stack.shape[0]): + lora_a_stack[i][0] = lora.lora_a.T + lora_b_stack[i][0] = (lora.lora_b * lora.scaling).T + + output = torch.zeros(k, m, device="cuda", dtype=dtype) + _apply_lora( + input, lora_a_stack, lora_b_stack, + torch.randint(0, lora_a_stack.shape[0], (len(input), ), device="cuda"), + output) + + rtol, atol = TOLERANCES[dtype] + assert torch.allclose(expected, output, rtol=rtol, atol=atol) + + output[:] = 0 + _apply_lora(input, lora_a_stack, lora_b_stack, + torch.full((len(input), ), -1, device="cuda"), output) + assert torch.allclose(torch.zeros_like(output), output) + + manager.reset_lora() + + +@pytest.mark.parametrize("m", TENSOR_SIZES) +@pytest.mark.parametrize("n", TENSOR_SIZES) +@pytest.mark.parametrize("k", BATCH_SIZES) +@pytest.mark.parametrize("rank", RANKS) +@pytest.mark.parametrize("dtype", DTYPES) +def test_apply_lora_packed_2slice(m, n, k, rank, dtype) -> None: + if m % 2 != 0: + pytest.skip("m must be divisible by 2") + if m // 2 not in TENSOR_SIZES: + pytest.skip("m//2 must be in TENSOR_SIZES") + + manager = DummyLoRAManager() + + module_name = "module" + weight = torch.rand([m // 2, n], device="cuda", dtype=dtype) + + manager.init_random_lora(module_name + "1", weight, rank=rank) + lora_1 = manager.get_module_lora(module_name + "1") + manager.init_random_lora(module_name + "2", weight, rank=rank) + lora_2 = manager.get_module_lora(module_name + "2") + + input = torch.rand(k, n, device="cuda", dtype=dtype) + expected = torch.cat([ + input @ lora_1.lora_a @ lora_1.lora_b * lora_1.scaling, + input @ lora_2.lora_a @ lora_2.lora_b * lora_2.scaling + ], + dim=1) + + lora_a_stacks = [ + torch.zeros(8, + 1, + lora_1.lora_a.shape[1], + lora_1.lora_a.shape[0], + device="cuda", + dtype=dtype) for i in range(2) + ] + lora_b_stacks = [ + torch.zeros(8, + 1, + lora_1.lora_b.shape[1], + lora_1.lora_b.shape[0], + device="cuda", + dtype=dtype) for i in range(2) + ] + for i in range(lora_a_stacks[0].shape[0]): + lora_a_stacks[0][i][0] = lora_1.lora_a.T + lora_b_stacks[0][i][0] = (lora_1.lora_b * lora_1.scaling).T + lora_a_stacks[1][i][0] = lora_2.lora_a.T + lora_b_stacks[1][i][0] = (lora_2.lora_b * lora_2.scaling).T + + output = torch.zeros(k, m, device="cuda", dtype=dtype) + _apply_lora_packed_2slice( + input, lora_a_stacks, lora_b_stacks, + torch.randint(0, + lora_a_stacks[0].shape[0], (len(input), ), + device="cuda"), output, m // 2) + + rtol, atol = TOLERANCES[dtype] + assert torch.allclose(expected, output, rtol=rtol, atol=atol) + + output[:] = 0 + _apply_lora_packed_2slice(input, lora_a_stacks, lora_b_stacks, + torch.full((len(input), ), -1, device="cuda"), + output, m // 2) + assert torch.allclose(torch.zeros_like(output), output) + + manager.reset_lora() + + +@pytest.mark.parametrize("qkv", QKV_TENSOR_SIZES) +@pytest.mark.parametrize("n", TENSOR_SIZES) +@pytest.mark.parametrize("k", BATCH_SIZES) +@pytest.mark.parametrize("rank", RANKS) +@pytest.mark.parametrize("dtype", DTYPES) +def test_apply_lora_packed_3slice(qkv, n, k, rank, dtype) -> None: + manager = DummyLoRAManager() + + module_name = "module" + weight_q = torch.empty(qkv[0], n, device="cuda", dtype=dtype) + weight_kv = torch.empty(qkv[1], n, device="cuda", dtype=dtype) + + manager.init_random_lora(module_name + "q", weight_q, rank=rank) + lora_q = manager.get_module_lora(module_name + "q") + manager.init_random_lora(module_name + "k", weight_kv, rank=rank) + lora_k = manager.get_module_lora(module_name + "k") + manager.init_random_lora(module_name + "v", weight_kv, rank=rank) + lora_v = manager.get_module_lora(module_name + "v") + + input = torch.rand(k, n, device="cuda", dtype=dtype) + expected = torch.cat([ + input @ lora_q.lora_a @ lora_q.lora_b * lora_q.scaling, + input @ lora_k.lora_a @ lora_k.lora_b * lora_k.scaling, + input @ lora_v.lora_a @ lora_v.lora_b * lora_v.scaling + ], + dim=1) + + lora_a_stacks = [ + torch.zeros(8, + 1, + lora_q.lora_a.shape[1], + lora_q.lora_a.shape[0], + device="cuda", + dtype=dtype) + ] + [ + torch.zeros(8, + 1, + lora_k.lora_a.shape[1], + lora_k.lora_a.shape[0], + device="cuda", + dtype=dtype) for i in range(2) + ] + lora_b_stacks = [ + torch.zeros(8, + 1, + lora_q.lora_b.shape[1], + lora_q.lora_b.shape[0], + device="cuda", + dtype=dtype) + ] + [ + torch.zeros(8, + 1, + lora_k.lora_b.shape[1], + lora_k.lora_b.shape[0], + device="cuda", + dtype=dtype) for i in range(2) + ] + for i in range(lora_a_stacks[0].shape[0]): + lora_a_stacks[0][i][0] = lora_q.lora_a.T + lora_b_stacks[0][i][0] = (lora_q.lora_b * lora_q.scaling).T + lora_a_stacks[1][i][0] = lora_k.lora_a.T + lora_b_stacks[1][i][0] = (lora_k.lora_b * lora_k.scaling).T + lora_a_stacks[2][i][0] = lora_v.lora_a.T + lora_b_stacks[2][i][0] = (lora_v.lora_b * lora_v.scaling).T + + output = torch.zeros(k, sum(qkv), device="cuda", dtype=dtype) + _apply_lora_packed_3slice( + input, lora_a_stacks, lora_b_stacks, + torch.randint(0, + lora_a_stacks[0].shape[0], (len(input), ), + device="cuda"), output, (qkv[0], qkv[1])) + + rtol, atol = TOLERANCES[dtype] + assert torch.allclose(expected, output, rtol=rtol, atol=atol) + + output[:] = 0 + _apply_lora_packed_3slice(input, lora_a_stacks, lora_b_stacks, + torch.full((len(input), ), -1, device="cuda"), + output, (qkv[0], qkv[1])) + assert torch.allclose(torch.zeros_like(output), output) + + manager.reset_lora() diff --git a/tests/lora/test_lora_manager.py b/tests/lora/test_lora_manager.py new file mode 100644 index 0000000000000..de7b245ad4e79 --- /dev/null +++ b/tests/lora/test_lora_manager.py @@ -0,0 +1,473 @@ +import os +from typing import List + +import pytest +import torch +from safetensors.torch import load_file +from torch import nn + +from vllm.config import LoRAConfig +from vllm.lora.layers import (LoRAColumnParallelLinear, LoRARowParallelLinear, + LoRAMergedColumnParallelLinear2Slice) +from vllm.lora.lora import LoRA, PackedLoRA +from vllm.lora.models import (EMBEDDING_MODULES, LoRAModel, LoRAModelManager, + LRUCacheLoRAModelManager, LoRAMapping) +from vllm.lora.request import LoRARequest +from vllm.lora.worker_manager import (LRUCacheWorkerLoRAManager, + WorkerLoRAManager) +from vllm.model_executor.layers.linear import RowParallelLinear + + +def test_from_lora_tensors(sql_lora_files): + tensors = load_file( + os.path.join(sql_lora_files, "adapter_model.safetensors")) + new_embeddings = load_file( + os.path.join(sql_lora_files, "new_embeddings.safetensors")) + lora_model = LoRAModel.from_lora_tensors(1, + 8, + 16, + tensors, + "cuda", + embeddings=new_embeddings) + for module_name, lora in lora_model.loras.items(): + assert lora.module_name == module_name + assert lora.rank == 8 + assert lora.lora_alpha == 16 + assert lora.lora_a is not None + assert lora.lora_b is not None + assert (lora.lora_a.shape[1] == lora.lora_b.shape[0] + ), f"{lora.lora_a.shape=}, {lora.lora_b.shape=}" + assert lora.lora_a.shape[1] == 8 + embeddings_module = next( + (k for k in EMBEDDING_MODULES if k in module_name), None) + if embeddings_module: + assert torch.equal( + lora.embeddings_tensor, + new_embeddings[EMBEDDING_MODULES[embeddings_module]].to( + device=lora.embeddings_tensor.device)) + else: + assert lora.embeddings_tensor is None + + +def create_lora(lora_id: int, model: nn.Module, + sub_modules: List[str]) -> LoRAModel: + loras = {} + for name in sub_modules: + w = model.get_submodule(name).weight + loras[name] = LoRA( + name, + 8, + 16, + torch.rand([w.shape[1], 8], device="cuda"), + torch.rand([8, w.shape[0]], device="cuda"), + ) + return LoRAModel(lora_id, 8, loras) + + +def create_packed_lora( + lora_id: int, + model: nn.Module, + module_name, + replaced_module_names, + empty_replaced_module_name=None, +) -> LoRAModel: + w = model.get_submodule(module_name).weight + loras = {} + for replaced_module_name in replaced_module_names: + if replaced_module_name == empty_replaced_module_name: + continue + loras[replaced_module_name] = LoRA( + replaced_module_name, + 8, + 16, + torch.rand([w.shape[1], 8], device="cuda"), + torch.rand([8, w.shape[0] // len(replaced_module_names)], + device="cuda"), + ) + return LoRAModel(lora_id, 8, loras) + + +def test_replace_submodules(dist_init, dummy_model): + model = dummy_model + manager = LoRAModelManager(model, + 1, + 1, + 1, + LoRAConfig(max_lora_rank=8, + max_cpu_loras=8, + max_loras=8), + lora_target_modules=["dense1", "layer1.dense2"]) + model = manager.model + + assert isinstance(model.get_submodule("dense1"), LoRAColumnParallelLinear) + assert isinstance(model.get_submodule("layer1.dense1"), + LoRAColumnParallelLinear) + assert isinstance(model.get_submodule("dense2"), RowParallelLinear) + assert isinstance(model.get_submodule("layer1.dense2"), + LoRARowParallelLinear) + + +def test_lora_model_manager(dist_init, dummy_model): + model = dummy_model + model_lora1 = create_lora(1, model, ["layer1.dense1", "dense2", "lm_head"]) + model_lora2 = create_lora(2, model, ["dense1", "dense2", "lm_head"]) + model_lora3 = create_lora(3, model, ["dense1", "dense2", "lm_head"]) + manager = LoRAModelManager( + model, + 2, + 2, + 2, + LoRAConfig(max_lora_rank=8, max_cpu_loras=3, max_loras=2), + lora_target_modules=["dense1", "dense2", "lm_head"]) + assert all(x is None for x in manager.lora_id_to_index) + assert manager.add_lora(model_lora1) + assert manager.activate_lora(1) + assert manager.lora_id_to_index[0] == 1 + assert not manager.add_lora(model_lora1) + assert not manager.activate_lora(1) + assert manager.add_lora(model_lora2) + assert manager.activate_lora(2) + assert manager.lora_id_to_index[0] == 1 + assert manager.lora_id_to_index[1] == 2 + assert not manager.add_lora(model_lora2) + assert not manager.activate_lora(2) + assert manager.add_lora(model_lora3) + assert manager.lora_id_to_index[0] == 1 + assert manager.lora_id_to_index[1] == 2 + with pytest.raises(ValueError): + assert manager.activate_lora(3) + assert manager.lora_id_to_index[0] == 1 + assert manager.lora_id_to_index[1] == 2 + assert manager.remove_lora(model_lora2.id) + assert manager.lora_id_to_index[1] is None + assert not manager.remove_lora(model_lora2.id) + assert manager.remove_lora(model_lora1.id) + assert not manager.remove_lora(model_lora1.id) + assert manager.add_lora(model_lora1) + assert manager.lora_id_to_index[0] is None + assert manager.lora_id_to_index[1] is None + assert manager.add_lora(model_lora2) + assert manager.activate_lora(3) + assert manager.lora_id_to_index[0] == 3 + assert manager.lora_id_to_index[1] is None + assert manager.activate_lora(2) + assert manager.lora_id_to_index[0] == 3 + assert manager.lora_id_to_index[1] == 2 + + +def test_lora_lru_cache_model_manager(dist_init, dummy_model): + model = dummy_model + model_lora1 = create_lora(1, model, ["layer1.dense1", "dense2", "lm_head"]) + model_lora2 = create_lora(2, model, ["dense1", "dense2", "lm_head"]) + model_lora3 = create_lora(3, model, ["dense1", "dense2", "lm_head"]) + manager = LRUCacheLoRAModelManager( + model, + 2, + 2, + 2, + LoRAConfig(max_lora_rank=8, max_cpu_loras=3, max_loras=2), + lora_target_modules=["dense1", "dense2", "lm_head"]) + assert all(x is None for x in manager.lora_id_to_index) + assert manager.add_lora(model_lora1) + assert manager.activate_lora(1) + assert manager.lora_id_to_index[0] == 1 + assert not manager.add_lora(model_lora1) + assert not manager.activate_lora(1) + assert manager.add_lora(model_lora2) + assert manager.activate_lora(2) + assert manager.lora_id_to_index[0] == 1 + assert manager.lora_id_to_index[1] == 2 + assert not manager.add_lora(model_lora2) + assert not manager.activate_lora(2) + assert manager.add_lora(model_lora3) + assert manager.lora_id_to_index[0] == 1 + assert manager.lora_id_to_index[1] == 2 + assert manager.activate_lora(3) + assert manager.lora_id_to_index[0] == 3 + assert manager.lora_id_to_index[1] == 2 + assert manager.remove_lora(model_lora2.id) + assert manager.lora_id_to_index[1] is None + assert not manager.remove_lora(model_lora2.id) + assert manager.remove_lora(model_lora1.id) + assert not manager.remove_lora(model_lora1.id) + assert manager.add_lora(model_lora1) + assert manager.activate_lora(1) + assert manager.lora_id_to_index[0] == 3 + assert manager.lora_id_to_index[1] == 1 + assert manager.add_lora(model_lora2) + assert manager.deactivate_lora(3) + assert manager.lora_id_to_index[0] is None + assert manager.lora_id_to_index[1] == 1 + assert manager.activate_lora(2) + assert manager.lora_id_to_index[0] == 2 + assert manager.lora_id_to_index[1] == 1 + assert manager.activate_lora(3) + assert manager.lora_id_to_index[0] == 2 + assert manager.lora_id_to_index[1] == 3 + + +def test_lru_lora_model_manager(dist_init, dummy_model): + # This tests just the LRU cache functionality, everything else is + # tested in test_lora_model_manager + model = dummy_model + model_lora1 = create_lora(1, model, ["layer1.dense1", "dense2", "lm_head"]) + model_lora2 = create_lora(2, model, ["dense1", "dense2", "lm_head"]) + model_lora3 = create_lora(3, model, ["dense1", "dense2", "lm_head"]) + model_lora4 = create_lora(4, model, ["dense1", "dense2", "lm_head"]) + manager = LRUCacheLoRAModelManager( + model, 2, 2, 2, + LoRAConfig(max_lora_rank=8, max_cpu_loras=2, max_loras=2), + ["dense1", "dense2", "lm_head"]) + + assert all(x is None for x in manager.lora_id_to_index) + + # Add up to capacity + assert manager.add_lora(model_lora1) + assert manager.add_lora(model_lora2) + assert manager.activate_lora(1) + assert manager.activate_lora(2) + + assert set(manager.list_loras()) == {1, 2} + assert manager.lora_id_to_index[0] == 1 + assert manager.lora_id_to_index[1] == 2 + + # Add over capacity + assert manager.add_lora(model_lora3) + assert manager.add_lora(model_lora4) + assert manager.activate_lora(3) + assert manager.activate_lora(4) + + assert set(manager.list_loras()) == {3, 4} + assert manager.lora_id_to_index[0] == 3 + assert manager.lora_id_to_index[1] == 4 + + # Add 3 again to move it to the top and then add 2 + # should return false since it's in already + assert not manager.add_lora(model_lora3) + assert not manager.activate_lora(3) + assert manager.add_lora(model_lora2) + assert manager.activate_lora(2) + + assert set(manager.list_loras()) == {3, 2} + assert manager.lora_id_to_index[0] == 3 + assert manager.lora_id_to_index[1] == 2 + + # Remove manually + assert manager.remove_lora(3) + assert not manager.remove_lora(3) + + assert set(manager.list_loras()) == {2} + assert manager.lora_id_to_index[0] is None + assert manager.lora_id_to_index[1] == 2 + + assert manager.add_lora(model_lora3) + assert manager.activate_lora(3) + assert manager.add_lora(model_lora4) + assert manager.activate_lora(4) + + assert set(manager.list_loras()) == {3, 4} + assert manager.lora_id_to_index[0] == 3 + assert manager.lora_id_to_index[1] == 4 + + assert manager.remove_oldest_lora() + assert set(manager.list_loras()) == {4} + assert manager.lora_id_to_index[0] is None + assert manager.lora_id_to_index[1] == 4 + + assert manager.remove_oldest_lora() + assert set(manager.list_loras()) == set() + assert all(x is None for x in manager.lora_id_to_index) + + assert not manager.remove_oldest_lora() + assert set(manager.list_loras()) == set() + assert all(x is None for x in manager.lora_id_to_index) + + +def test_lru_cache_worker_lora_manager(llama_2_7b_model_extra_embeddings, + sql_lora_files): + lora_config = LoRAConfig(max_lora_rank=8, max_cpu_loras=4, max_loras=4) + worker_lora_manager = LRUCacheWorkerLoRAManager( + 4, 2, llama_2_7b_model_extra_embeddings.config.vocab_size, lora_config, + torch.device("cuda")) + worker_lora_manager.create_lora_adapter(llama_2_7b_model_extra_embeddings) + + mapping = LoRAMapping([], []) + worker_lora_manager.apply_loras([ + LoRARequest("1", 1, sql_lora_files), + LoRARequest("2", 2, sql_lora_files) + ], mapping) + assert worker_lora_manager.list_loras() == {1, 2} + assert worker_lora_manager._lora_manager.lora_id_to_index[0] == 1 + assert worker_lora_manager._lora_manager.lora_id_to_index[1] == 2 + + worker_lora_manager.apply_loras([ + LoRARequest("1", 1, sql_lora_files), + LoRARequest("3", 3, sql_lora_files), + LoRARequest("4", 4, sql_lora_files) + ], mapping) + assert worker_lora_manager.list_loras() == {1, 2, 3, 4} + assert worker_lora_manager._lora_manager.lora_id_to_index[0] == 1 + assert worker_lora_manager._lora_manager.lora_id_to_index[1] == 2 + assert worker_lora_manager._lora_manager.lora_id_to_index[2] == 3 + assert worker_lora_manager._lora_manager.lora_id_to_index[3] == 4 + + worker_lora_manager.apply_loras([ + LoRARequest("1", 1, sql_lora_files), + LoRARequest("2", 2, sql_lora_files), + LoRARequest("5", 5, sql_lora_files) + ], mapping) + assert worker_lora_manager.list_loras() == {1, 2, 4, 5} + assert worker_lora_manager._lora_manager.lora_id_to_index[0] == 1 + assert worker_lora_manager._lora_manager.lora_id_to_index[1] == 2 + assert worker_lora_manager._lora_manager.lora_id_to_index[2] == 5 + assert worker_lora_manager._lora_manager.lora_id_to_index[3] == 4 + + worker_lora_manager.apply_loras([ + LoRARequest("1", 1, sql_lora_files), + LoRARequest("1", 1, sql_lora_files), + LoRARequest("1", 1, sql_lora_files) + ], mapping) + assert worker_lora_manager.list_loras() == {1, 2, 4, 5} + assert worker_lora_manager._lora_manager.lora_id_to_index[0] == 1 + assert worker_lora_manager._lora_manager.lora_id_to_index[1] == 2 + assert worker_lora_manager._lora_manager.lora_id_to_index[2] == 5 + assert worker_lora_manager._lora_manager.lora_id_to_index[3] == 4 + + worker_lora_manager.apply_loras([ + LoRARequest("6", 6, sql_lora_files), + LoRARequest("7", 7, sql_lora_files), + LoRARequest("8", 8, sql_lora_files) + ], mapping) + assert worker_lora_manager.list_loras() == {1, 6, 7, 8} + assert worker_lora_manager._lora_manager.lora_id_to_index[0] == 1 + assert worker_lora_manager._lora_manager.lora_id_to_index[1] == 7 + assert worker_lora_manager._lora_manager.lora_id_to_index[2] == 8 + assert worker_lora_manager._lora_manager.lora_id_to_index[3] == 6 + + # Over capacity + with pytest.raises(RuntimeError): + worker_lora_manager.apply_loras([ + LoRARequest("10", 10, sql_lora_files), + LoRARequest("11", 11, sql_lora_files), + LoRARequest("12", 12, sql_lora_files), + LoRARequest("13", 13, sql_lora_files), + LoRARequest("14", 14, sql_lora_files) + ], mapping) + + +def test_worker_lora_manager(llama_2_7b_model_extra_embeddings, + sql_lora_files): + # Should remove every LoRA not specified in the request. + lora_config = LoRAConfig(max_lora_rank=8, max_cpu_loras=4, max_loras=4) + worker_lora_manager = WorkerLoRAManager( + 4, 2, llama_2_7b_model_extra_embeddings.config.vocab_size, lora_config, + torch.device("cuda")) + worker_lora_manager.create_lora_adapter(llama_2_7b_model_extra_embeddings) + + mapping = LoRAMapping([], []) + worker_lora_manager.apply_loras([ + LoRARequest("1", 1, sql_lora_files), + LoRARequest("2", 2, sql_lora_files) + ], mapping) + assert worker_lora_manager.list_loras() == {1, 2} + assert worker_lora_manager._lora_manager.lora_id_to_index[0] == 1 + assert worker_lora_manager._lora_manager.lora_id_to_index[1] == 2 + + worker_lora_manager.apply_loras([ + LoRARequest("1", 1, sql_lora_files), + LoRARequest("3", 3, sql_lora_files), + LoRARequest("4", 4, sql_lora_files) + ], mapping) + assert worker_lora_manager.list_loras() == {1, 3, 4} + assert worker_lora_manager._lora_manager.lora_id_to_index[0] == 1 + assert worker_lora_manager._lora_manager.lora_id_to_index[1] == 3 + assert worker_lora_manager._lora_manager.lora_id_to_index[2] == 4 + + worker_lora_manager.apply_loras([ + LoRARequest("1", 1, sql_lora_files), + LoRARequest("2", 2, sql_lora_files), + LoRARequest("5", 5, sql_lora_files) + ], mapping) + assert worker_lora_manager.list_loras() == {1, 2, 5} + assert worker_lora_manager._lora_manager.lora_id_to_index[0] == 1 + assert worker_lora_manager._lora_manager.lora_id_to_index[1] == 2 + assert worker_lora_manager._lora_manager.lora_id_to_index[2] == 5 + + worker_lora_manager.apply_loras([ + LoRARequest("1", 1, sql_lora_files), + LoRARequest("1", 1, sql_lora_files), + LoRARequest("1", 1, sql_lora_files) + ], mapping) + assert worker_lora_manager.list_loras() == {1} + assert worker_lora_manager._lora_manager.lora_id_to_index[0] == 1 + assert worker_lora_manager._lora_manager.lora_id_to_index[1] is None + assert worker_lora_manager._lora_manager.lora_id_to_index[2] is None + + worker_lora_manager.apply_loras([ + LoRARequest("6", 6, sql_lora_files), + LoRARequest("7", 7, sql_lora_files), + LoRARequest("8", 8, sql_lora_files) + ], mapping) + assert worker_lora_manager.list_loras() == {6, 7, 8} + assert worker_lora_manager._lora_manager.lora_id_to_index[0] == 8 + assert worker_lora_manager._lora_manager.lora_id_to_index[1] == 6 + assert worker_lora_manager._lora_manager.lora_id_to_index[2] == 7 + + # Over capacity + with pytest.raises(RuntimeError): + worker_lora_manager.apply_loras([ + LoRARequest("10", 10, sql_lora_files), + LoRARequest("11", 11, sql_lora_files), + LoRARequest("12", 12, sql_lora_files), + LoRARequest("13", 13, sql_lora_files), + LoRARequest("14", 14, sql_lora_files) + ], mapping) + + +def test_packed_loras(dist_init, dummy_model_gate_up): + model = dummy_model_gate_up + model_lora = create_packed_lora( + 1, + model, + module_name="gate_up_proj", + replaced_module_names=["gate_proj", "up_proj"]) + model_lora1 = create_packed_lora( + 2, + model, + module_name="gate_up_proj", + replaced_module_names=["gate_proj", "up_proj"], + empty_replaced_module_name="gate_proj", + ) + + manager = LoRAModelManager( + model, 2, 2, 2, + LoRAConfig(max_lora_rank=8, max_cpu_loras=2, max_loras=2), + ["gate_up_proj"]) + model = manager.model + + assert isinstance(model.get_submodule("gate_up_proj"), + LoRAMergedColumnParallelLinear2Slice) + assert manager.add_lora(model_lora) + assert manager.add_lora(model_lora1) + + packed_lora = model_lora.get_lora("gate_up_proj") + assert packed_lora and isinstance(packed_lora, PackedLoRA) + + assert torch.allclose(packed_lora.lora_a[0], + model_lora.get_lora("gate_proj").lora_a) + assert torch.allclose(packed_lora.lora_b[0], + model_lora.get_lora("gate_proj").lora_b) + assert torch.allclose(packed_lora.lora_a[1], + model_lora.get_lora("up_proj").lora_a) + assert torch.allclose(packed_lora.lora_b[1], + model_lora.get_lora("up_proj").lora_b) + + packed_lora1 = model_lora1.get_lora("gate_up_proj") + assert packed_lora1 and isinstance(packed_lora1, PackedLoRA) + + assert packed_lora1.lora_a[0] is None + assert packed_lora1.lora_b[0] is None + assert torch.allclose(packed_lora1.lora_a[1], + model_lora1.get_lora("up_proj").lora_a) + assert torch.allclose(packed_lora1.lora_b[1], + model_lora1.get_lora("up_proj").lora_b) diff --git a/tests/lora/test_punica.py b/tests/lora/test_punica.py new file mode 100644 index 0000000000000..26a7d47933309 --- /dev/null +++ b/tests/lora/test_punica.py @@ -0,0 +1,196 @@ +# Based on code from https://github.com/punica-ai/punica + +import pytest +import torch + +import vllm.lora.punica as punica + + +def assert_close(a, b): + rtol, atol = { + torch.float16: (5e-3, 5e-3), + torch.bfloat16: (3e-2, 2e-2), + torch.float32: (None, None), + }[a.dtype] + torch.testing.assert_close(a, b, rtol=rtol, atol=atol) + + +def _lora_ref_impl( + y_final: torch.Tensor, + x: torch.Tensor, + wa_T_all: torch.Tensor, + wb_T_all: torch.Tensor, + indicies: torch.LongTensor, + layer_idx: int, + scale: float, +): + y_stage_1 = torch.empty( + (x.size(0), wa_T_all.size(-2)), + dtype=torch.float32, + device=x.device, + ) + bs = x.shape[0] + s = torch.tensor(scale, dtype=torch.float32, device=x.device) + for i, lora_idx in zip(range(bs), indicies.cpu().tolist()): + xi = x[i].unsqueeze(0).to(torch.float32) + wa = wa_T_all[lora_idx, layer_idx].transpose(-1, -2).to(torch.float32) + wb = wb_T_all[lora_idx, layer_idx].transpose(-1, -2).to(torch.float32) + + tmp = xi @ wa + y_stage_1[i] = tmp.squeeze(0) + y_final[i] += (tmp @ wb).squeeze(0) * s + return y_final, y_stage_1 + + +H1 = H2 = [ + 128, + 256, + 512, + 1024, + 1280, + 2048, + 2560, + 2752, + 3072, + 3456, + 3584, + 4096, + 5120, + 5504, + 6912, + 7168, + 8192, + 9216, + 10240, + 11008, + 13824, + 14336, + 32000, + 32256, +] +SEED = [0xabcdabcd987] + + +@pytest.mark.parametrize("dtype_str", ["float16", "bfloat16"]) +@pytest.mark.parametrize("h1", H1) +@pytest.mark.parametrize("h2", H2) +@pytest.mark.parametrize("seed", SEED) +@torch.inference_mode() +def test_lora_correctness(dtype_str, h1, h2, seed): + torch.manual_seed(seed) + num_loras = 4 + num_layers = 1 + r = 8 + bs = 32 + scale = 0.123 + dtype = getattr(torch, dtype_str) + device = torch.device("cuda") + + wa_T_all = torch.randn(num_loras, + num_layers, + r, + h1, + dtype=dtype, + device=device) + wb_T_all = torch.randn(num_loras, + num_layers, + h2, + r, + dtype=dtype, + device=device) + indices = torch.randint(num_loras, (bs, ), dtype=torch.long, device=device) + + for layer_idx in range(num_layers): + x = torch.randn(bs, h1, dtype=dtype, device=device) + y = torch.randn(bs, h2, dtype=dtype, device=device) + + y_ref = y.clone() + _lora_ref_impl(y_ref, x, wa_T_all, wb_T_all, indices, layer_idx, scale) + + y_our = y.clone() + punica.add_lora(y_our, x, wa_T_all, wb_T_all, indices, layer_idx, + scale) + + assert_close(y_ref, y_our) + + +@pytest.mark.parametrize("dtype_str", ["float16", "bfloat16"]) +@pytest.mark.parametrize("h1", H1) +@pytest.mark.parametrize("h2", H2) +@pytest.mark.parametrize("seed", SEED) +@torch.inference_mode() +def test_lora_correctness_slice(dtype_str, h1, h2, seed): + if h2 % 3 != 0 or h2 // 3 not in H1: + pytest.skip("h2 must be divisible by 3 and in supported shapes") + torch.manual_seed(seed) + num_loras = 4 + num_layers = 1 + r = 8 + bs = 32 + scale = 0.123 + dtype = getattr(torch, dtype_str) + device = torch.device("cuda") + + wa_T_all_0 = torch.randn(num_loras, + num_layers, + r, + h1, + dtype=dtype, + device=device) + wa_T_all_1 = torch.randn(num_loras, + num_layers, + r, + h1, + dtype=dtype, + device=device) + wa_T_all_2 = torch.randn(num_loras, + num_layers, + r, + h1, + dtype=dtype, + device=device) + wb_T_all_0 = torch.randn(num_loras, + num_layers, + h2 // 3, + r, + dtype=dtype, + device=device) + wb_T_all_1 = torch.randn(num_loras, + num_layers, + h2 // 3, + r, + dtype=dtype, + device=device) + wb_T_all_2 = torch.randn(num_loras, + num_layers, + h2 // 3, + r, + dtype=dtype, + device=device) + + indices = torch.randint(num_loras, (bs, ), dtype=torch.long, device=device) + + for layer_idx in range(num_layers): + x = torch.randn(bs, h1, dtype=dtype, device=device) + y = torch.randn(bs, h2, dtype=dtype, device=device) + s = h2 // 3 + + y_ref = y.clone() + _lora_ref_impl(y_ref[:, :s], x, wa_T_all_0, wb_T_all_0, indices, + layer_idx, scale) + _lora_ref_impl(y_ref[:, s:s * 2], x, wa_T_all_1, wb_T_all_1, indices, + layer_idx, scale) + _lora_ref_impl(y_ref[:, s * 2:], x, wa_T_all_2, wb_T_all_2, indices, + layer_idx, scale) + + y_our = y.clone() + punica.add_lora_slice(y_our, x, wa_T_all_0, wb_T_all_0, indices, + layer_idx, scale, 0, s) + punica.add_lora_slice(y_our, x, wa_T_all_1, wb_T_all_1, indices, + layer_idx, scale, s, s) + punica.add_lora_slice(y_our, x, wa_T_all_2, wb_T_all_2, indices, + layer_idx, scale, s * 2, s) + + assert_close(y_ref[:, :s], y_our[:, :s]) + assert_close(y_ref[:, s:s * 2], y_our[:, s:s * 2]) + assert_close(y_ref[:, s * 2:], y_our[:, s * 2:]) diff --git a/tests/lora/test_tokenizer.py b/tests/lora/test_tokenizer.py new file mode 100644 index 0000000000000..af0fc41f3fa45 --- /dev/null +++ b/tests/lora/test_tokenizer.py @@ -0,0 +1,69 @@ +import pytest +from transformers import AutoTokenizer, PreTrainedTokenizerBase + +from vllm.lora.request import LoRARequest +from vllm.transformers_utils.tokenizer import MultiLoRATokenizer, get_lora_tokenizer + + +@pytest.mark.asyncio +async def test_transformers_tokenizer(): + reference_tokenizer = AutoTokenizer.from_pretrained("gpt2") + tokenizer = MultiLoRATokenizer( + tokenizer_id="gpt2", + enable_lora=False, + max_num_seqs=1, + max_input_length=None, + ) + assert reference_tokenizer.encode("prompt") == tokenizer.encode( + request_id="request_id", prompt="prompt", lora_request=None) + assert reference_tokenizer.encode( + "prompt") == await tokenizer.encode_async(request_id="request_id", + prompt="prompt", + lora_request=None) + assert isinstance(tokenizer.get_lora_tokenizer(None), + PreTrainedTokenizerBase) + assert tokenizer.get_lora_tokenizer( + None) == await tokenizer.get_lora_tokenizer_async(None) + + +@pytest.mark.asyncio +async def test_transformers_tokenizer_lora(sql_lora_files): + reference_tokenizer = AutoTokenizer.from_pretrained(sql_lora_files) + tokenizer = MultiLoRATokenizer( + tokenizer_id="gpt2", + enable_lora=True, + max_num_seqs=1, + max_input_length=None, + ) + lora_request = LoRARequest("1", 1, sql_lora_files) + assert reference_tokenizer.encode("prompt") == tokenizer.encode( + request_id="request_id", prompt="prompt", lora_request=lora_request) + assert reference_tokenizer.encode( + "prompt") == await tokenizer.encode_async(request_id="request_id", + prompt="prompt", + lora_request=lora_request) + assert isinstance(tokenizer.get_lora_tokenizer(None), + PreTrainedTokenizerBase) + assert tokenizer.get_lora_tokenizer( + None) == await tokenizer.get_lora_tokenizer_async(None) + + assert isinstance(tokenizer.get_lora_tokenizer(lora_request), + PreTrainedTokenizerBase) + assert tokenizer.get_lora_tokenizer( + lora_request) != tokenizer.get_lora_tokenizer(None) + assert tokenizer.get_lora_tokenizer( + lora_request) == await tokenizer.get_lora_tokenizer_async(lora_request) + + +def test_get_lora_tokenizer(sql_lora_files, tmpdir): + lora_request = None + tokenizer = get_lora_tokenizer(lora_request) + assert not tokenizer + + lora_request = LoRARequest("1", 1, sql_lora_files) + tokenizer = get_lora_tokenizer(lora_request) + assert tokenizer.get_added_vocab() + + lora_request = LoRARequest("1", 1, str(tmpdir)) + tokenizer = get_lora_tokenizer(lora_request) + assert not tokenizer diff --git a/tests/lora/test_utils.py b/tests/lora/test_utils.py new file mode 100644 index 0000000000000..a874a72d919fa --- /dev/null +++ b/tests/lora/test_utils.py @@ -0,0 +1,172 @@ +from collections import OrderedDict + +from torch import nn + +from vllm.lora.utils import (LRUCache, parse_fine_tuned_lora_name, + replace_submodule) + + +def test_parse_fine_tuned_lora_name(): + fixture = { + ("base_model.model.lm_head.lora_A.weight", "lm_head", True), + ("base_model.model.lm_head.lora_B.weight", "lm_head", False), + ( + "base_model.model.model.embed_tokens.lora_embedding_A", + "model.embed_tokens", + True, + ), + ( + "base_model.model.model.embed_tokens.lora_embedding_B", + "model.embed_tokens", + False, + ), + ( + "base_model.model.model.layers.9.mlp.down_proj.lora_A.weight", + "model.layers.9.mlp.down_proj", + True, + ), + ( + "base_model.model.model.layers.9.mlp.down_proj.lora_B.weight", + "model.layers.9.mlp.down_proj", + False, + ), + } + for name, module_name, is_lora_a in fixture: + assert (module_name, is_lora_a) == parse_fine_tuned_lora_name(name) + + +def test_replace_submodule(): + model = nn.Sequential( + OrderedDict([ + ("dense1", nn.Linear(764, 100)), + ("act1", nn.ReLU()), + ("dense2", nn.Linear(100, 50)), + ( + "seq1", + nn.Sequential( + OrderedDict([ + ("dense1", nn.Linear(100, 10)), + ("dense2", nn.Linear(10, 50)), + ])), + ), + ("act2", nn.ReLU()), + ("output", nn.Linear(50, 10)), + ("outact", nn.Sigmoid()), + ])) + + sigmoid = nn.Sigmoid() + + replace_submodule(model, "act1", sigmoid) + assert dict(model.named_modules())["act1"] == sigmoid + + dense2 = nn.Linear(1, 5) + replace_submodule(model, "seq1.dense2", dense2) + assert dict(model.named_modules())["seq1.dense2"] == dense2 + + +class TestLRUCache(LRUCache): + + def _on_remove(self, key, value): + if not hasattr(self, "_remove_counter"): + self._remove_counter = 0 + self._remove_counter += 1 + + +def test_lru_cache(): + cache = TestLRUCache(3) + + cache.put(1, 1) + assert len(cache) == 1 + + cache.put(1, 1) + assert len(cache) == 1 + + cache.put(2, 2) + assert len(cache) == 2 + + cache.put(3, 3) + assert len(cache) == 3 + assert set(cache.cache) == {1, 2, 3} + + cache.put(4, 4) + assert len(cache) == 3 + assert set(cache.cache) == {2, 3, 4} + assert cache._remove_counter == 1 + assert cache.get(2) == 2 + + cache.put(5, 5) + assert set(cache.cache) == {2, 4, 5} + assert cache._remove_counter == 2 + + assert cache.pop(5) == 5 + assert len(cache) == 2 + assert set(cache.cache) == {2, 4} + assert cache._remove_counter == 3 + + cache.pop(10) + assert len(cache) == 2 + assert set(cache.cache) == {2, 4} + assert cache._remove_counter == 3 + + cache.get(10) + assert len(cache) == 2 + assert set(cache.cache) == {2, 4} + assert cache._remove_counter == 3 + + cache.put(6, 6) + assert len(cache) == 3 + assert set(cache.cache) == {2, 4, 6} + assert 2 in cache + assert 4 in cache + assert 6 in cache + + cache.remove_oldest() + assert len(cache) == 2 + assert set(cache.cache) == {2, 6} + assert cache._remove_counter == 4 + + cache.clear() + assert len(cache) == 0 + assert cache._remove_counter == 6 + + cache._remove_counter = 0 + + cache[1] = 1 + assert len(cache) == 1 + + cache[1] = 1 + assert len(cache) == 1 + + cache[2] = 2 + assert len(cache) == 2 + + cache[3] = 3 + assert len(cache) == 3 + assert set(cache.cache) == {1, 2, 3} + + cache[4] = 4 + assert len(cache) == 3 + assert set(cache.cache) == {2, 3, 4} + assert cache._remove_counter == 1 + assert cache[2] == 2 + + cache[5] = 5 + assert set(cache.cache) == {2, 4, 5} + assert cache._remove_counter == 2 + + del cache[5] + assert len(cache) == 2 + assert set(cache.cache) == {2, 4} + assert cache._remove_counter == 3 + + cache.pop(10) + assert len(cache) == 2 + assert set(cache.cache) == {2, 4} + assert cache._remove_counter == 3 + + cache[6] = 6 + assert len(cache) == 3 + assert set(cache.cache) == {2, 4, 6} + assert 2 in cache + assert 4 in cache + assert 6 in cache diff --git a/tests/lora/test_worker.py b/tests/lora/test_worker.py new file mode 100644 index 0000000000000..8c11f6c472ff7 --- /dev/null +++ b/tests/lora/test_worker.py @@ -0,0 +1,56 @@ +import os +import random +import tempfile +from unittest.mock import patch + +from vllm.lora.models import LoRAMapping +from vllm.lora.utils import LoRARequest +from vllm.config import ModelConfig, ParallelConfig, SchedulerConfig, LoRAConfig +from vllm.worker.worker import Worker + + +@patch.dict(os.environ, {"RANK": "0"}) +def test_worker_apply_lora(sql_lora_files): + worker = Worker( + model_config=ModelConfig("meta-llama/Llama-2-7b-hf", + "meta-llama/Llama-2-7b-hf", + tokenizer_mode="auto", + trust_remote_code=False, + download_dir=None, + load_format="dummy", + seed=0, + dtype="float16", + revision=None), + parallel_config=ParallelConfig(1, 1, False), + scheduler_config=SchedulerConfig(32, 32, 32, 256), + lora_config=LoRAConfig(max_lora_rank=8, max_cpu_loras=32, + max_loras=32), + distributed_init_method=f"file://{tempfile.mkstemp()[1]}", + ) + worker.init_model() + worker.load_model() + + worker.apply_loras([], LoRAMapping([], [])) + assert worker.list_loras() == set() + + n_loras = 32 + lora_requests = [ + LoRARequest(str(i + 1), i + 1, sql_lora_files) for i in range(n_loras) + ] + + worker.apply_loras(lora_requests, LoRAMapping([], [])) + assert worker.list_loras() == { + lora_request.lora_int_id + for lora_request in lora_requests + } + + for i in range(32): + random.seed(i) + iter_lora_requests = random.choices(lora_requests, + k=random.randint(1, n_loras)) + random.shuffle(iter_lora_requests) + iter_lora_requests = iter_lora_requests[:-random.randint(0, n_loras)] + worker.apply_loras(iter_lora_requests, LoRAMapping([], [])) + assert worker.list_loras().issuperset( + {lora_request.lora_int_id + for lora_request in iter_lora_requests}) diff --git a/tests/lora/utils.py b/tests/lora/utils.py new file mode 100644 index 0000000000000..072a0d957758b --- /dev/null +++ b/tests/lora/utils.py @@ -0,0 +1,88 @@ +from typing import List, Optional + +import torch + +from vllm.lora.lora import LoRA + + +class DummyLoRAManager: + + def __init__(self): + super().__init__() + self._loras = {} + + def set_module_lora(self, module_name: str, lora: LoRA): + self._loras[module_name] = lora + + def get_module_lora(self, module_name: str) -> Optional[LoRA]: + return self._loras.get(module_name, None) + + def init_random_lora(self, + module_name: str, + weight: torch.Tensor, + rank: int = 8, + generate_embeddings_tensor: int = 0): + lora = LoRA( + module_name, + rank=rank, + lora_alpha=1, + lora_a=torch.rand([weight.shape[1], rank], + dtype=weight.dtype, + device="cuda"), + lora_b=torch.rand([rank, weight.shape[0]], + dtype=weight.dtype, + device="cuda"), + ) + if generate_embeddings_tensor: + lora.embeddings_tensor = torch.rand(5, + generate_embeddings_tensor, + dtype=weight.dtype, + device="cuda") + self.set_module_lora(module_name, lora) + + return lora + + def init_lora(self, + module_name: str, + input_dim: int, + output_dim: int, + rank=8, + noop=False, + embeddings_tensor=None): + lora = LoRA( + module_name, + rank=rank, + lora_alpha=1, + lora_a=torch.rand([input_dim, rank], device="cuda"), + lora_b=torch.rand([rank, output_dim], device="cuda"), + embeddings_tensor=embeddings_tensor, + ) + self.set_module_lora(module_name, lora) + return lora + + def reset_lora(self): + self._loras = {} + + def init_packed_lora( + self, + module_name: str, + input_dim: int, + output_dims: List[int], + noop_lora_index: List[int] = None, + rank=8, + ): + base_loras = [] + noop_lora_index = set(noop_lora_index or []) + + for i, out_dim in enumerate(output_dims): + base_lora = self.init_lora( + module_name + "_000_" + str(i), + input_dim, + out_dim, + rank=rank, + noop=i in noop_lora_index, + ) + base_loras.append(base_lora) + packed_lora = LoRA.pack(base_loras) + self.set_module_lora(module_name, packed_lora) + return packed_lora diff --git a/vllm/config.py b/vllm/config.py index 1adf830ffcc12..eef6e53be2855 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -1,4 +1,5 @@ from typing import Optional, Union +from dataclasses import dataclass import os import torch @@ -350,6 +351,36 @@ def _verify_args(self) -> None: f"({self.max_num_seqs}).") +@dataclass +class LoRAConfig: + max_lora_rank: int + max_cpu_loras: Optional[int] = None + lora_dtype: Optional[torch.dtype] = None + lora_extra_vocab_size: int = 256 + max_loras: Optional[int] = None + + def verify_with_model_config(self, model_config: ModelConfig): + if self.lora_dtype in (None, "auto"): + self.lora_dtype = model_config.dtype + elif isinstance(self.lora_dtype, str): + self.lora_dtype = getattr(torch, self.lora_dtype) + + def verify_with_scheduler_config(self, scheduler_config: SchedulerConfig): + if scheduler_config.max_num_batched_tokens > 65528: + raise ValueError( + "Due to limitations of the custom LoRA CUDA kernel, " + "max_num_batched_tokens must be <= 65528 when " + "LoRA is enabled.") + + self.max_loras = scheduler_config.max_num_seqs + if self.max_cpu_loras is None: + self.max_cpu_loras = scheduler_config.max_num_seqs + elif self.max_cpu_loras < scheduler_config.max_num_seqs: + raise ValueError( + f"max_cpu_loras ({self.max_cpu_loras}) must be >= " + f"max_num_seqs ({scheduler_config.max_num_seqs})") + + _STR_DTYPE_TO_TORCH_DTYPE = { "half": torch.float16, "float16": torch.float16, diff --git a/vllm/core/scheduler.py b/vllm/core/scheduler.py index ca28bbdc2fb95..f8fb4c6ea1518 100644 --- a/vllm/core/scheduler.py +++ b/vllm/core/scheduler.py @@ -1,10 +1,11 @@ import enum import time -from typing import Dict, Iterable, List, Optional, Tuple, Union +from typing import Dict, Iterable, List, Optional, Tuple, Union, Set from vllm.config import CacheConfig, SchedulerConfig from vllm.core.block_manager import AllocStatus, BlockSpaceManager from vllm.core.policy import PolicyFactory +from vllm.lora.request import LoRARequest from vllm.logger import init_logger from vllm.sequence import (Sequence, SequenceData, SequenceGroup, SequenceGroupMetadata, SequenceStatus) @@ -36,6 +37,7 @@ def __init__( blocks_to_swap_out: Dict[int, int], blocks_to_copy: Dict[int, List[int]], ignored_seq_groups: List[SequenceGroup], + lora_enabled: bool = False, ) -> None: self.scheduled_seq_groups = scheduled_seq_groups self.prompt_run = prompt_run @@ -47,11 +49,23 @@ def __init__( assert not (blocks_to_swap_in and blocks_to_swap_out) self.ignored_seq_groups = ignored_seq_groups + if lora_enabled: + self.num_loras = len(set(self.lora_requests)) + self._sort_by_lora_ids() + def is_empty(self) -> bool: # NOTE: We do not consider the ignored sequence groups. return (not self.scheduled_seq_groups and not self.blocks_to_swap_in and not self.blocks_to_swap_out and not self.blocks_to_copy) + def _sort_by_lora_ids(self) -> bool: + self.scheduled_seq_groups.sort(key=lambda g: ( + g.lora_request.lora_int_id if g.lora_request else 0, g.request_id)) + + @property + def lora_requests(self) -> Set[LoRARequest]: + return {g.lora_request for g in self.scheduled_seq_groups} + class Scheduler: @@ -59,9 +73,11 @@ def __init__( self, scheduler_config: SchedulerConfig, cache_config: CacheConfig, + lora_enabled: bool = False, ) -> None: self.scheduler_config = scheduler_config self.cache_config = cache_config + self.lora_enabled = lora_enabled self.prompt_limit = min(self.scheduler_config.max_model_len, self.scheduler_config.max_num_batched_tokens) @@ -202,6 +218,7 @@ def _schedule(self) -> SchedulerOutputs: blocks_to_swap_out=blocks_to_swap_out, blocks_to_copy=blocks_to_copy, ignored_seq_groups=ignored_seq_groups, + lora_enabled=self.lora_enabled, ) return scheduler_outputs @@ -274,6 +291,7 @@ def _schedule(self) -> SchedulerOutputs: blocks_to_swap_out=blocks_to_swap_out, blocks_to_copy=blocks_to_copy, ignored_seq_groups=[], + lora_enabled=self.lora_enabled, ) return scheduler_outputs @@ -299,6 +317,7 @@ def schedule(self) -> Tuple[List[SequenceGroupMetadata], SchedulerOutputs]: seq_data=seq_data, sampling_params=seq_group.sampling_params, block_tables=block_tables, + lora_request=seq_group.lora_request, ) seq_group_metadata_list.append(seq_group_metadata) return seq_group_metadata_list, scheduler_outputs diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index 746b0e64ece7b..4d1233c473980 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -4,7 +4,7 @@ from typing import Optional, Tuple from vllm.config import (CacheConfig, ModelConfig, ParallelConfig, - SchedulerConfig) + SchedulerConfig, LoRAConfig) @dataclass @@ -33,6 +33,11 @@ class EngineArgs: revision: Optional[str] = None tokenizer_revision: Optional[str] = None quantization: Optional[str] = None + enable_lora: bool = False + max_lora_rank: int = 8 + lora_extra_vocab_size: int = 256 + lora_dtype = 'bfloat16' + lora_max_cpu_loras: int = -1 def __post_init__(self): if self.tokenizer is None: @@ -182,6 +187,30 @@ def add_cli_args( choices=['awq', 'squeezellm', None], default=None, help='Method used to quantize the weights') + # LoRA related configs + parser.add_argument('--enable-lora', + action='store_true', + help='enable lora adapters') + parser.add_argument('--max-lora-rank', + type=int, + default=16, + help='max LoRA rank') + parser.add_argument('--lora-extra-vocab-size', + type=int, + default=256, + help='LoRA extra vocab size') + parser.add_argument('--lora-dtype', + type=str, + default=EngineArgs.dtype, + choices=['auto', 'float16', 'bfloat16', 'float32'], + help='data type for lora') + parser.add_argument( + '--lora-max-cpu-loras', + type=int, + default=-1, + help=('Maximum number of loras to store in CPU memory. ' + 'Must be >= than max_num_seqs. ' + 'Defaults to max_num_seqs.')) return parser @classmethod @@ -194,7 +223,8 @@ def from_cli_args(cls, args: argparse.Namespace) -> 'EngineArgs': def create_engine_configs( self, - ) -> Tuple[ModelConfig, CacheConfig, ParallelConfig, SchedulerConfig]: + ) -> Tuple[ModelConfig, CacheConfig, ParallelConfig, SchedulerConfig, + Optional[LoRAConfig]]: model_config = ModelConfig(self.model, self.tokenizer, self.tokenizer_mode, self.trust_remote_code, self.download_dir, self.load_format, @@ -212,7 +242,13 @@ def create_engine_configs( self.max_num_seqs, model_config.max_model_len, self.max_paddings) - return model_config, cache_config, parallel_config, scheduler_config + lora_config = LoRAConfig( + max_lora_rank=self.max_lora_rank, + lora_extra_vocab_size=self.lora_extra_vocab_size, + lora_dtype=self.lora_dtype, + max_cpu_loras=self.lora_max_cpu_loras if self.lora_max_cpu_loras > + 0 else None) if self.enable_lora else None + return model_config, cache_config, parallel_config, scheduler_config, lora_config @dataclass diff --git a/vllm/engine/async_llm_engine.py b/vllm/engine/async_llm_engine.py index 7dcd2eb632c4c..53bc7080b3273 100644 --- a/vllm/engine/async_llm_engine.py +++ b/vllm/engine/async_llm_engine.py @@ -4,6 +4,7 @@ from typing import (Any, Dict, Iterable, List, Optional, Set, Tuple, Type, Union) +from vllm.lora.request import LoRARequest from vllm.config import ModelConfig from vllm.engine.arg_utils import AsyncEngineArgs from vllm.engine.llm_engine import LLMEngine @@ -198,6 +199,50 @@ async def step_async(self) -> List[RequestOutput]: return self._process_model_outputs(output, scheduler_outputs) + ignored + async def encode_request_async( + self, + request_id: str, # pylint: disable=unused-argument + prompt: Optional[str], + prompt_token_ids: Optional[List[int]] = None, + lora_request: Optional[LoRARequest] = None, + ): + if prompt_token_ids is None: + assert prompt is not None + prompt_token_ids = await self.tokenizer.encode_async( + request_id=request_id, + prompt=prompt, + lora_request=lora_request) + return prompt_token_ids + + async def add_request_async( + self, + request_id: str, + prompt: Optional[str], + sampling_params: SamplingParams, + prompt_token_ids: Optional[List[int]] = None, + arrival_time: Optional[float] = None, + lora_request: Optional[LoRARequest] = None, + ) -> None: + if lora_request is not None and not self.lora_config: + raise ValueError(f"Got lora_request {lora_request} but LoRA is " + "not enabled!") + if arrival_time is None: + arrival_time = time.time() + prompt_token_ids = await self.encode_request_async( + request_id=request_id, + prompt=prompt, + prompt_token_ids=prompt_token_ids, + lora_request=lora_request) + + return self.add_request( + request_id, + prompt=prompt, + prompt_token_ids=prompt_token_ids, + sampling_params=sampling_params, + arrival_time=arrival_time, + lora_request=lora_request, + ) + async def _run_workers_async( self, method: str, @@ -318,7 +363,7 @@ async def engine_step(self) -> bool: if self.engine_use_ray: await self.engine.add_request.remote(**new_request) else: - self.engine.add_request(**new_request) + await self.engine.add_request_async(**new_request) if finished_requests: await self._engine_abort(finished_requests) @@ -357,6 +402,7 @@ async def add_request( sampling_params: SamplingParams, prompt_token_ids: Optional[List[int]] = None, arrival_time: Optional[float] = None, + lora_request: Optional[LoRARequest] = None, ) -> AsyncStream: if self.log_requests: shortened_prompt = prompt @@ -370,7 +416,8 @@ async def add_request( logger.info(f"Received request {request_id}: " f"prompt: {shortened_prompt!r}, " f"sampling params: {sampling_params}, " - f"prompt token ids: {shortened_token_ids}.") + f"prompt token ids: {shortened_token_ids}, " + f"lora_request: {lora_request}.") if not self.is_running: if self.start_engine_loop: @@ -382,12 +429,22 @@ async def add_request( "error that caused the background loop to stop " "(AsyncEngineDeadError).") + if arrival_time is None: + arrival_time = time.time() + prompt_token_ids = await self.engine.encode_request_async( + request_id=request_id, + prompt=prompt, + prompt_token_ids=prompt_token_ids, + lora_request=lora_request) + stream = self._request_tracker.add_request( request_id, prompt=prompt, sampling_params=sampling_params, prompt_token_ids=prompt_token_ids, - arrival_time=arrival_time) + arrival_time=arrival_time, + lora_request=lora_request, + ) return stream @@ -396,7 +453,8 @@ async def generate( prompt: Optional[str], sampling_params: SamplingParams, request_id: str, - prompt_token_ids: Optional[List[int]] = None) -> RequestOutput: + prompt_token_ids: Optional[List[int]] = None, + lora_request: Optional[LoRARequest] = None) -> RequestOutput: """Generate outputs for a request. Generate outputs for a request. This method is a coroutine. It adds the @@ -410,6 +468,7 @@ async def generate( request_id: The unique id of the request. prompt_token_ids: The token IDs of the prompt. If None, we use the tokenizer to convert the prompts to token IDs. + lora_request: LoRA request to use for generation, if any. Yields: The output `RequestOutput` objects from the LLMEngine for the @@ -420,11 +479,14 @@ async def generate( arrival_time = time.monotonic() try: - stream = await self.add_request(request_id, - prompt, - sampling_params, - prompt_token_ids=prompt_token_ids, - arrival_time=arrival_time) + stream = await self.add_request( + request_id, + prompt, + sampling_params, + prompt_token_ids=prompt_token_ids, + arrival_time=arrival_time, + lora_request=lora_request, + ) async for request_output in stream: yield request_output diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index e33d8aa2a2131..c6e74b1d26586 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -3,8 +3,9 @@ from functools import partial from typing import TYPE_CHECKING, Any, Iterable, List, Optional, Tuple, Union +from vllm.lora.request import LoRARequest from vllm.config import (CacheConfig, ModelConfig, ParallelConfig, - SchedulerConfig) + SchedulerConfig, LoRAConfig) from vllm.core.scheduler import Scheduler, SchedulerOutputs from vllm.engine.arg_utils import EngineArgs from vllm.engine.ray_utils import RayWorker, initialize_cluster, ray @@ -15,7 +16,7 @@ SequenceGroupMetadata, SequenceGroupOutputs, SequenceOutputs, SequenceStatus) from vllm.transformers_utils.tokenizer import (detokenize_incrementally, - get_tokenizer) + MultiLoRATokenizer) from vllm.utils import Counter if ray: @@ -65,6 +66,7 @@ def __init__( cache_config: CacheConfig, parallel_config: ParallelConfig, scheduler_config: SchedulerConfig, + lora_config: Optional[LoRAConfig], distributed_init_method: str, placement_group: Optional["PlacementGroup"], log_stats: bool, @@ -90,17 +92,13 @@ def __init__( self.cache_config = cache_config assert self.cache_config.sliding_window == getattr( self.model_config.hf_config, "sliding_window", None) + self.lora_config = lora_config self.parallel_config = parallel_config self.scheduler_config = scheduler_config self.log_stats = log_stats self._verify_args() - self.tokenizer = get_tokenizer( - model_config.tokenizer, - tokenizer_mode=model_config.tokenizer_mode, - trust_remote_code=model_config.trust_remote_code, - tokenizer_revision=model_config.tokenizer_revision, - revision=model_config.revision) + self._init_tokenizer() self.seq_counter = Counter() # Create the parallel GPU workers. @@ -137,6 +135,7 @@ def _init_workers(self, distributed_init_method: str): self.scheduler_config, 0, distributed_init_method, + lora_config=self.lora_config, ) self.workers.append(worker) self._run_workers( @@ -150,6 +149,18 @@ def _init_workers(self, distributed_init_method: str): max_parallel_loading_workers, ) + def _init_tokenizer(self, **kwargs): + init_kwargs = dict( + enable_lora=bool(self.lora_config), + max_num_seqs=self.scheduler_config.max_num_seqs, + max_input_length=None, + tokenizer_mode=self.model_config.tokenizer_mode, + trust_remote_code=self.model_config.trust_remote_code, + revision=self.model_config.tokenizer_revision) + init_kwargs.update(kwargs) + self.tokenizer: MultiLoRATokenizer = MultiLoRATokenizer( + self.model_config.tokenizer, **init_kwargs) + def _init_workers_ray(self, placement_group: "PlacementGroup", **ray_remote_kwargs): # Lazy import the Worker to avoid importing torch.cuda/xformers @@ -183,6 +194,7 @@ def _init_workers_ray(self, placement_group: "PlacementGroup", scheduler_config, None, None, + lora_config=self.lora_config, )) self._run_workers( "init_model", @@ -198,6 +210,10 @@ def _init_workers_ray(self, placement_group: "PlacementGroup", def _verify_args(self) -> None: self.model_config.verify_with_parallel_config(self.parallel_config) self.cache_config.verify_with_parallel_config(self.parallel_config) + if self.lora_config: + self.lora_config.verify_with_model_config(self.model_config) + self.lora_config.verify_with_scheduler_config( + self.scheduler_config) def _init_cache(self) -> None: """Profiles the memory usage and initializes the KV cache.""" @@ -246,6 +262,20 @@ def from_engine_args(cls, engine_args: EngineArgs) -> "LLMEngine": log_stats=not engine_args.disable_log_stats) return engine + def encode_request( + self, + request_id: str, # pylint: disable=unused-argument + prompt: Optional[str], + prompt_token_ids: Optional[List[int]] = None, + lora_request: Optional[LoRARequest] = None, + ): + if prompt_token_ids is None: + assert prompt is not None + prompt_token_ids = self.tokenizer.encode(request_id=request_id, + prompt=prompt, + lora_request=lora_request) + return prompt_token_ids + def add_request( self, request_id: str, @@ -253,6 +283,7 @@ def add_request( sampling_params: SamplingParams, prompt_token_ids: Optional[List[int]] = None, arrival_time: Optional[float] = None, + lora_request: Optional[LoRARequest] = None, ) -> None: """Add a request to the engine's request pool. @@ -270,20 +301,26 @@ def add_request( arrival_time: The arrival time of the request. If None, we use the current monotonic time. """ + if lora_request is not None and not self.lora_config: + raise ValueError(f"Got lora_request {lora_request} but LoRA is " + "not enabled!") if arrival_time is None: arrival_time = time.monotonic() - if prompt_token_ids is None: - assert prompt is not None - prompt_token_ids = self.tokenizer.encode(prompt) + prompt_token_ids = self.encode_request( + request_id=request_id, + prompt=prompt, + prompt_token_ids=prompt_token_ids, + lora_request=lora_request) # Create the sequences. block_size = self.cache_config.block_size seq_id = next(self.seq_counter) - seq = Sequence(seq_id, prompt, prompt_token_ids, block_size) + seq = Sequence(seq_id, prompt, prompt_token_ids, block_size, + lora_request) # Create the sequence group. seq_group = SequenceGroup(request_id, [seq], sampling_params, - arrival_time) + arrival_time, lora_request) # Add the sequence group to the scheduler. self.scheduler.add_seq_group(seq_group) @@ -648,7 +685,7 @@ def _decode_sequence(self, seq: Sequence, prms: SamplingParams) -> None: """Decodes the new token for a sequence.""" (new_tokens, new_output_text, prefix_offset, read_offset) = detokenize_incrementally( - self.tokenizer, + self.tokenizer.get_lora_tokenizer(seq.lora_request), all_input_ids=seq.get_token_ids(), prev_tokens=seq.tokens, prefix_offset=seq.prefix_offset, @@ -689,11 +726,29 @@ def _check_stop(self, seq: Sequence, return # Check if the sequence has generated the EOS token. - if ((not sampling_params.ignore_eos) - and seq.get_last_token_id() == self.tokenizer.eos_token_id): + if ((not sampling_params.ignore_eos) and seq.get_last_token_id() + == self.tokenizer.get_lora_tokenizer( + seq.lora_request).eos_token_id): seq.status = SequenceStatus.FINISHED_STOPPED return + def add_lora(self, lora_request: LoRARequest) -> bool: + assert lora_request.lora_int_id > 0, "lora_id must be greater than 0." + return self._run_workers( + "add_lora", + lora_request=lora_request, + ) + + def remove_lora(self, lora_id: int) -> bool: + assert lora_id > 0, "lora_id must be greater than 0." + return self._run_workers( + "remove_lora", + lora_id=lora_id, + ) + + def list_loras(self) -> List[int]: + return self._run_workers("list_loras") + def _run_workers_in_batch( self, workers, diff --git a/vllm/entrypoints/llm.py b/vllm/entrypoints/llm.py index b05ba71c6d352..9061909d72c33 100644 --- a/vllm/entrypoints/llm.py +++ b/vllm/entrypoints/llm.py @@ -3,6 +3,7 @@ from tqdm import tqdm from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast +from vllm.lora.request import LoRARequest from vllm.engine.arg_utils import EngineArgs from vllm.engine.llm_engine import LLMEngine from vllm.outputs import RequestOutput @@ -109,6 +110,7 @@ def generate( sampling_params: Optional[SamplingParams] = None, prompt_token_ids: Optional[List[List[int]]] = None, use_tqdm: bool = True, + lora_request: Optional[LoRARequest] = None, ) -> List[RequestOutput]: """Generates the completions for the input prompts. @@ -123,6 +125,7 @@ def generate( prompt_token_ids: A list of token IDs for the prompts. If None, we use the tokenizer to convert the prompts to token IDs. use_tqdm: Whether to use tqdm to display the progress bar. + lora_request: LoRA request to use for generation, if any. Returns: A list of `RequestOutput` objects containing the generated @@ -149,7 +152,10 @@ def generate( prompt = prompts[i] if prompts is not None else None token_ids = None if prompt_token_ids is None else prompt_token_ids[ i] - self._add_request(prompt, sampling_params, token_ids) + self._add_request(prompt, + sampling_params, + token_ids, + lora_request=lora_request) return self._run_engine(use_tqdm) def _add_request( @@ -157,10 +163,14 @@ def _add_request( prompt: Optional[str], sampling_params: SamplingParams, prompt_token_ids: Optional[List[int]], + lora_request: Optional[LoRARequest] = None, ) -> None: request_id = str(next(self.request_counter)) - self.llm_engine.add_request(request_id, prompt, sampling_params, - prompt_token_ids) + self.llm_engine.add_request(request_id, + prompt, + sampling_params, + prompt_token_ids, + lora_request=lora_request) def _run_engine(self, use_tqdm: bool) -> List[RequestOutput]: # Initialize tqdm. diff --git a/vllm/lora/__init__.py b/vllm/lora/__init__.py new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/vllm/lora/layers.py b/vllm/lora/layers.py new file mode 100644 index 0000000000000..6ba8b0585847d --- /dev/null +++ b/vllm/lora/layers.py @@ -0,0 +1,1002 @@ +# pylint: disable=unused-argument +from dataclasses import dataclass +from typing import TYPE_CHECKING, List, Optional, Tuple + +import torch +import torch.nn as nn +import torch.nn.functional as F +from transformers import PretrainedConfig + +from vllm.config import LoRAConfig +from vllm.lora.punica import add_lora, add_lora_slice, bgmv +from vllm.model_executor.layers.sampler import Sampler +from vllm.model_executor.parallel_utils.communication_op import ( + tensor_model_parallel_all_gather, + tensor_model_parallel_all_reduce, +) +from vllm.model_executor.layers.linear import (ColumnParallelLinear, + RowParallelLinear, + QKVParallelLinear, + MergedColumnParallelLinear) +from vllm.model_executor.layers.vocab_parallel_embedding import VocabParallelEmbedding, ParallelLMHead +from vllm.model_executor.parallel_utils.parallel_state import ( + get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size) +from vllm.model_executor.parallel_utils.utils import split_tensor_along_last_dim + +if TYPE_CHECKING: + pass + + +def _apply_lora( + x: torch.Tensor, + lora_a_stacked: torch.Tensor, + lora_b_stacked: torch.Tensor, + indices: torch.Tensor, + output: torch.Tensor, +): + """Applies lora to each input. + + This method applies all loras to each input. It uses the + indices vector to determine which lora yields the + correct output. An index of -1 means no lora should be + applied. This method adds the final lora results to the + output. + + Input shapes: + x: (batch_size, hidden_dim) + lora_a_stacked: (num_loras, lora_rank, hidden_dim) + lora_b_stacked: (num_loras, output_dim, lora_rank) + indices: (batch_size) + output: (batch_size, output_dim) + """ + org_output = output + if x.ndim == 3: + x = x.view(x.shape[0] * x.shape[1], -1) + if output.ndim == 3: + output = output.view(output.shape[0] * output.shape[1], -1) + add_lora(output, x, lora_a_stacked, lora_b_stacked, indices, 0, 1.0) + return output.view_as(org_output) + + +def _apply_lora_packed_2slice( + x: torch.Tensor, + lora_a_stacked: Tuple[torch.Tensor, torch.Tensor], + lora_b_stacked: Tuple[torch.Tensor, torch.Tensor], + indices: torch.Tensor, + output: torch.Tensor, + output_dim: int, +): + """Applies lora to each input. + + This method applies all loras to each input. It uses the + indices vector to determine which lora yields the + correct output. An index of -1 means no lora should be + applied. This method adds the final lora results to the + output. + + This method is used for layers that are composed of 2 sublayers + (slices) packed together (eg. gate_proj + up_proj -> + gate_up_proj). + + Both slices must have the same size (output_dim), meaning the output + tensor will have size output_dim*2. + + Input shapes: + x: (batch_size, hidden_dim) + lora_a_stacked: 2 element tuple of (num_loras, lora_rank, hidden_dim) + lora_b_stacked: 2 element tuple of (num_loras, output_dim, lora_rank) + indices: (batch_size) + output: (batch_size, output_dim*2) + output_dim: scalar + """ + org_output = output + if x.ndim == 3: + x = x.view(x.shape[0] * x.shape[1], -1) + if output.ndim == 3: + output = output.view(output.shape[0] * output.shape[1], -1) + add_lora_slice(output, x, lora_a_stacked[0], lora_b_stacked[0], indices, 0, + 1.0, 0, output_dim) + add_lora_slice(output, x, lora_a_stacked[1], lora_b_stacked[1], indices, 0, + 1.0, output_dim, output_dim) + return output.view_as(org_output) + + +def _apply_lora_packed_3slice( + x: torch.Tensor, + lora_a_stacked: Tuple[torch.Tensor, torch.Tensor, torch.Tensor], + lora_b_stacked: Tuple[torch.Tensor, torch.Tensor, torch.Tensor], + indices: torch.Tensor, + output: torch.Tensor, + output_slices: Tuple[int, int], +): + """Applies lora to each input. + + This method applies all loras to each input. It uses the + indices vector to determine which lora yields the + correct output. An index of -1 means no lora should be + applied. This method adds the final lora results to the + output. + + This method is used for layers that are composed of 3 sublayers + (slices) packed together (attention projection). The + first slice (Q) may have different size from the two subsequent + slices (K, V). + + Input shapes: + x: (batch_size, hidden_dim) + lora_a_stacked: 3 element tuple of (num_loras, lora_rank, hidden_dim) + lora_b_stacked: 3 element tuple of (num_loras, output_dim, lora_rank) + indices: (batch_size) + output: (batch_size, q_slice_size + 2*kv_slice_size) + output_slices: 2 element tuple of (q_slice_size, kv_slice_size) + """ + org_output = output + if x.ndim == 3: + x = x.view(x.shape[0] * x.shape[1], -1) + if output.ndim == 3: + output = output.view(output.shape[0] * output.shape[1], -1) + add_lora_slice(output, x, lora_a_stacked[0], lora_b_stacked[0], indices, 0, + 1.0, 0, output_slices[0]) + add_lora_slice(output, x, lora_a_stacked[1], lora_b_stacked[1], indices, 0, + 1.0, output_slices[0], output_slices[1]) + add_lora_slice(output, x, lora_a_stacked[2], lora_b_stacked[2], indices, 0, + 1.0, output_slices[0] + output_slices[1], output_slices[1]) + return output.view_as(org_output) + + +@dataclass +class LoRAMapping: + index_mapping: Tuple[int, ...] + prompt_mapping: Tuple[int, ...] + + def __eq__(self, __value: object) -> bool: + return (isinstance(__value, self.__class__) + and self.prompt_mapping == __value.prompt_mapping + and self.index_mapping == __value.index_mapping) + + def __post_init__(self): + self.index_mapping = tuple(self.index_mapping) + self.prompt_mapping = tuple(self.prompt_mapping) + + +class LoRALayer(nn.Module): + + def create_lora_weights(self, max_loras: int, lora_config: LoRAConfig, + model_config: PretrainedConfig) -> None: + """Initializes lora matrices.""" + ... + + def reset_lora(self, index: int): + """Resets the lora weights at index back to 0.""" + ... + + def set_lora( + self, + index: int, + lora_a: torch.Tensor, + lora_b: torch.Tensor, + embeddings_tensor: Optional[torch.Tensor], + ): + """Overwrites lora tensors at index.""" + ... + + def set_mapping( + self, + base_indices: torch.Tensor, + sampler_indices: torch.Tensor, + sampler_indices_padded: torch.Tensor, + embeddings_indices: torch.Tensor, + indices_len: List[int], + ): + """Sets the mapping indices.""" + ... + + +class LoRAVocabParallelEmbedding(LoRALayer): + + def __init__(self, base_layer: VocabParallelEmbedding) -> None: + super().__init__() + self.base_layer = base_layer + + def create_lora_weights( + self, + max_loras: int, + lora_config: LoRAConfig, + model_config: Optional[PretrainedConfig] = None) -> None: + + lora_vocab_start_idx = self.base_layer.org_vocab_size + weights_idx = None + if self.base_layer.vocab_end_index > lora_vocab_start_idx: + # We can start adding lora weights + weights_idx = max( + lora_vocab_start_idx - self.base_layer.vocab_start_index, 0) + self.embeddings_slice = (self.base_layer.vocab_start_index - + self.base_layer.org_vocab_size + + weights_idx, + self.base_layer.vocab_end_index - + self.base_layer.org_vocab_size) + self.embeddings_weights = self.base_layer.weight.data[weights_idx:] + self.embeddings_weights.fill_(0) + else: + self.embeddings_slice = None + self.embeddings_weights = None + + self.embeddings_tensors = torch.zeros( + ( + max_loras, + lora_config.lora_extra_vocab_size, + self.base_layer.embedding_dim, + ), + dtype=self.base_layer.weight.dtype, + device=self.base_layer.weight.device, + ) + self.lora_a_stacked = torch.zeros( + ( + max_loras, + self.base_layer.org_vocab_size + + lora_config.lora_extra_vocab_size, + lora_config.max_lora_rank, + ), + dtype=lora_config.lora_dtype, + device=self.base_layer.weight.device, + ) + self.lora_b_stacked = torch.zeros( + ( + max_loras, + 1, + self.base_layer.embedding_dim, + lora_config.max_lora_rank, + ), + dtype=lora_config.lora_dtype, + device=self.base_layer.weight.device, + ) + self.lora_a_stacked_2d = self.lora_a_stacked.view( + self.lora_a_stacked.shape[0] * self.lora_a_stacked.shape[1], + self.lora_a_stacked.shape[2], + ) + self.indices: Optional[torch.Tensor] = None + self.indices_len: Optional[List[int]] = None + self.embeddings_indices = None + + def reset_lora(self, index: int): + self.lora_a_stacked[index] = 0 + self.lora_b_stacked[index] = 0 + self.embeddings_tensors[index] = 0 + + def set_lora( + self, + index: int, + lora_a: torch.Tensor, + lora_b: torch.Tensor, + embeddings_tensor: Optional[torch.Tensor], + ): + self.reset_lora(index) + self.lora_a_stacked[index, :lora_a.shape[0], :lora_a.shape[1]].copy_( + lora_a, non_blocking=True) + self.lora_b_stacked[index, + 0, :lora_b.shape[1], :lora_b.shape[0]].copy_( + lora_b.T, non_blocking=True) + if embeddings_tensor is not None: + self.embeddings_tensors[ + index, :embeddings_tensor.shape[0], :embeddings_tensor. + shape[1]].copy_(embeddings_tensor, non_blocking=True) + if self.embeddings_slice is not None: + # TODO(yard1): Optimize this copy, we don't need to copy + # everything, just the modified part + self.embeddings_weights.copy_( + self.embeddings_tensors.view( + self.embeddings_tensors.shape[0] * + self.embeddings_tensors.shape[1], + self.embeddings_tensors.shape[2]) + [self.embeddings_slice[0]:self.embeddings_slice[1]]) + + def set_mapping( + self, + base_indices: torch.Tensor, + sampler_indices: torch.Tensor, + sampler_indices_padded: torch.Tensor, + embeddings_indices: torch.Tensor, + indices_len: List[int], + ): + self.indices = base_indices + self.embeddings_indices = embeddings_indices + self.indices_len = indices_len + + def forward(self, x: torch.Tensor) -> torch.Tensor: + added_tokens_mask = x > self.base_layer.org_vocab_size - 1 + indices = self.embeddings_indices[1][:self.indices_len[3]].view_as(x) + full_lora_a_embeddings = F.embedding( + x + indices, + self.lora_a_stacked_2d, + ) + indices = self.embeddings_indices[0][:self.indices_len[3]].view_as(x) + full_output = self.base_layer.forward( + x.add_(indices * added_tokens_mask)) + + full_output_org = full_output + if full_output.ndim == 3: + full_output = full_output.view( + full_output.shape[0] * full_output.shape[1], -1) + if full_lora_a_embeddings.ndim == 3: + full_lora_a_embeddings = full_lora_a_embeddings.view( + full_lora_a_embeddings.shape[0] * + full_lora_a_embeddings.shape[1], -1) + bgmv(full_output, full_lora_a_embeddings, self.lora_b_stacked, + self.indices[:self.indices_len[0]], 0, 1.0) + return full_output.view_as(full_output_org) + + +class LoRAColumnParallelLinear(LoRALayer): + + def __init__(self, base_layer: ColumnParallelLinear) -> None: + super().__init__() + self.base_layer = base_layer + + def create_lora_weights( + self, + max_loras: int, + lora_config: LoRAConfig, + model_config: Optional[PretrainedConfig] = None) -> None: + self.lora_a_stacked = torch.zeros( + max_loras, + 1, + lora_config.max_lora_rank, + self.base_layer.weight.shape[1], + dtype=lora_config.lora_dtype, + device=self.base_layer.weight.device, + ) + self.lora_b_stacked = torch.zeros( + max_loras, + 1, + self.base_layer.weight.shape[0], + lora_config.max_lora_rank, + dtype=lora_config.lora_dtype, + device=self.base_layer.weight.device, + ) + + self.indices: Optional[torch.Tensor] = None + self.indices_len: Optional[List[int]] = None + self.output_dim = self.lora_b_stacked.shape[1] + + def reset_lora(self, index: int): + self.lora_a_stacked[index] = 0 + self.lora_b_stacked[index] = 0 + + def set_lora( + self, + index: int, + lora_a: torch.Tensor, + lora_b: torch.Tensor, + embeddings_tensor: Optional[torch.Tensor], + ): + self.reset_lora(index) + + self.lora_a_stacked[index, + 0, :lora_a.shape[1], :lora_a.shape[0]].copy_( + lora_a.T, non_blocking=True) + self.lora_b_stacked[index, + 0, :lora_b.shape[1], :lora_b.shape[0]].copy_( + lora_b.T, non_blocking=True) + + def set_mapping( + self, + base_indices: torch.Tensor, + sampler_indices: torch.Tensor, + sampler_indices_padded: torch.Tensor, + embeddings_indices: torch.Tensor, + indices_len: List[int], + ): + self.indices = base_indices + self.indices_len = indices_len + + def apply_weights(self, x: torch.Tensor, + bias: Optional[torch.Tensor]) -> torch.Tensor: + output = self.base_layer.linear_method.apply_weights( + self.base_layer.linear_weights, x, bias) + _apply_lora( + x, + self.lora_a_stacked, + self.lora_b_stacked, + self.indices[:self.indices_len[0]], + output, + ) + return output + + def forward(self, input_): + """Forward of ColumnParallelLinear + + Args: + input_: Tensor whose last dimension is `input_size`. + + Returns: + - output + - bias + """ + bias = (self.base_layer.bias + if not self.base_layer.skip_bias_add else None) + + # Matrix multiply. + output_parallel = self.apply_weights(input_, bias) + if self.base_layer.gather_output: + # All-gather across the partitions. + output = tensor_model_parallel_all_gather(output_parallel) + else: + output = output_parallel + output_bias = (self.base_layer.bias + if self.base_layer.skip_bias_add else None) + return output, output_bias + + @property + def linear_weights(self): + return self.base_layer.linear_weights + + +class LoRAMergedColumnParallelLinear2Slice(LoRAColumnParallelLinear): + """ColumnParallelLinear layer that is composed of 2 sublayers (slices) + packed together (eg. gate_proj + up_proj -> gate_up_proj). + + This means we have 2 LoRAs, each applied to one half of the layer. + + Both slices must have the same size. + """ + + def __init__(self, base_layer: MergedColumnParallelLinear) -> None: + super().__init__(base_layer) + + def create_lora_weights( + self, + max_loras: int, + lora_config: LoRAConfig, + model_config: Optional[PretrainedConfig] = None) -> None: + n_slices = 2 + if not (len(self.base_layer.output_sizes) == n_slices + and self.base_layer.output_sizes[0] + == self.base_layer.output_sizes[1]): + raise ValueError( + "LoRAColumnParallelLinear2Slice requires 2 slices with " + "the same size.") + self.tp_size = get_tensor_model_parallel_world_size() + + self.lora_a_stacked = tuple( + torch.zeros( + max_loras, + 1, + lora_config.max_lora_rank, + self.base_layer.weight.shape[1], + dtype=lora_config.lora_dtype, + device=self.base_layer.weight.device, + ) for _ in range(n_slices)) + self.lora_b_stacked = tuple( + torch.zeros( + max_loras, + 1, + self.base_layer.weight.shape[0] // 2, + lora_config.max_lora_rank, + dtype=lora_config.lora_dtype, + device=self.base_layer.weight.device, + ) for _ in range(n_slices)) + + self.indices: Optional[torch.Tensor] = None + self.output_dim = self.lora_b_stacked[0].shape[2] + + def reset_lora(self, index: int): + self.lora_a_stacked[0][index] = 0 + self.lora_a_stacked[1][index] = 0 + self.lora_b_stacked[0][index] = 0 + self.lora_b_stacked[1][index] = 0 + + def set_lora( + self, + index: int, + lora_a: torch.Tensor, + lora_b: torch.Tensor, + embeddings_tensor: Optional[torch.Tensor], + ): + self.reset_lora(index) + + if self.tp_size > 1: + tensor_model_parallel_rank = get_tensor_model_parallel_rank() + shard_size = self.output_dim + start_idx = tensor_model_parallel_rank * shard_size + end_idx = (tensor_model_parallel_rank + 1) * shard_size + lora_b = lora_b[0][:, + start_idx:end_idx], lora_b[1][:, + start_idx:end_idx] + + if lora_a[0] is not None: + self.lora_a_stacked[0][ + index, 0, :lora_a[0].shape[1], :lora_a[0].shape[0]].copy_( + lora_a[0].T, non_blocking=True) + self.lora_b_stacked[0][ + index, 0, :lora_b[0].shape[1], :lora_b[0].shape[0]].copy_( + lora_b[0].T, non_blocking=True) + if lora_a[1] is not None: + self.lora_a_stacked[1][ + index, 0, :lora_a[1].shape[1], :lora_a[1].shape[0]].copy_( + lora_a[1].T, non_blocking=True) + self.lora_b_stacked[1][ + index, 0, :lora_b[1].shape[1], :lora_b[1].shape[0]].copy_( + lora_b[1].T, non_blocking=True) + + def apply_weights(self, x: torch.Tensor, + bias: Optional[torch.Tensor]) -> torch.Tensor: + output = self.base_layer.linear_method.apply_weights( + self.base_layer.linear_weights, x, bias) + _apply_lora_packed_2slice( + x, + self.lora_a_stacked, + self.lora_b_stacked, + self.indices[:self.indices_len[0]], + output, + self.output_dim, + ) + return output + + +class LoRAQKVParallelLinear(LoRAColumnParallelLinear): + """ColumnParallelLinear layer that is composed of 3 sublayers (slices) + packed together in qkv proj fashion + (q_proj + k_proj + v_proj -> qkv_proj). + + This means we have 3 LoRAs, each applied to one slice of the layer. + + Q slice may have different shape than K and V slices (which both have + the same shape). + """ + + def __init__(self, base_layer: QKVParallelLinear) -> None: + super().__init__(base_layer) + + def create_lora_weights( + self, + max_loras: int, + lora_config: LoRAConfig, + model_config: Optional[PretrainedConfig] = None) -> None: + self.tp_size = get_tensor_model_parallel_world_size() + tp_rank = get_tensor_model_parallel_rank() + self.q_proj_shard_size = (self.base_layer.num_heads * + self.base_layer.head_size) + self.kv_proj_shard_size = (self.base_layer.num_kv_heads * + self.base_layer.head_size) + self.q_shard_id = tp_rank + self.kv_shard_id = tp_rank // self.base_layer.num_kv_head_replicas + + # q, k, v + self.lora_a_stacked = (torch.zeros( + max_loras, + 1, + lora_config.max_lora_rank, + self.base_layer.weight.shape[1], + dtype=lora_config.lora_dtype, + device=self.base_layer.weight.device, + ), + torch.zeros( + max_loras, + 1, + lora_config.max_lora_rank, + self.base_layer.weight.shape[1], + dtype=lora_config.lora_dtype, + device=self.base_layer.weight.device, + ), + torch.zeros( + max_loras, + 1, + lora_config.max_lora_rank, + self.base_layer.weight.shape[1], + dtype=lora_config.lora_dtype, + device=self.base_layer.weight.device, + )) + self.lora_b_stacked = (torch.zeros( + max_loras, + 1, + self.q_proj_shard_size, + lora_config.max_lora_rank, + dtype=lora_config.lora_dtype, + device=self.base_layer.weight.device, + ), + torch.zeros( + max_loras, + 1, + self.kv_proj_shard_size, + lora_config.max_lora_rank, + dtype=lora_config.lora_dtype, + device=self.base_layer.weight.device, + ), + torch.zeros( + max_loras, + 1, + self.kv_proj_shard_size, + lora_config.max_lora_rank, + dtype=lora_config.lora_dtype, + device=self.base_layer.weight.device, + )) + + self.output_slices = (self.q_proj_shard_size, self.kv_proj_shard_size) + self.packed_indices: Optional[torch.Tensor] = None + self.standard_indices: Optional[torch.Tensor] = None + self.indices_len: Optional[List[int]] = None + + def reset_lora(self, index: int): + self.lora_a_stacked[0][index] = 0 + self.lora_b_stacked[0][index] = 0 + self.lora_a_stacked[1][index] = 0 + self.lora_b_stacked[1][index] = 0 + self.lora_a_stacked[2][index] = 0 + self.lora_b_stacked[2][index] = 0 + + def set_lora( + self, + index: int, + lora_a: torch.Tensor, + lora_b: torch.Tensor, + embeddings_tensor: Optional[torch.Tensor], + ): + self.reset_lora(index) + + if self.tp_size > 1: + if lora_b[0] is not None: + lora_b_q = lora_b[0][:, self.q_proj_shard_size * + self.q_shard_id:self.q_proj_shard_size * + (self.q_shard_id + 1)] + self.lora_b_stacked[0][ + index, 0, :lora_b_q.shape[1], :lora_b_q.shape[0]].copy_( + lora_b_q.T, non_blocking=True) + if lora_b[1] is not None: + lora_b_k = lora_b[1][:, self.kv_proj_shard_size * + self.kv_shard_id:self.kv_proj_shard_size * + (self.kv_shard_id + 1)] + self.lora_b_stacked[1][ + index, 0, :lora_b_k.shape[1], :lora_b_k.shape[0]].copy_( + lora_b_k.T, non_blocking=True) + if lora_b[2] is not None: + lora_b_v = lora_b[2][:, self.kv_proj_shard_size * + self.kv_shard_id:self.kv_proj_shard_size * + (self.kv_shard_id + 1)] + self.lora_b_stacked[2][ + index, 0, :lora_b_v.shape[1], :lora_b_v.shape[0]].copy_( + lora_b_v.T, non_blocking=True) + else: + if lora_b[0] is not None: + self.lora_b_stacked[0][ + index, 0, :lora_b[0].shape[1], :lora_b[0].shape[0]].copy_( + lora_b[0].T, non_blocking=True) + if lora_b[1] is not None: + self.lora_b_stacked[1][ + index, 0, :lora_b[1].shape[1], :lora_b[1].shape[0]].copy_( + lora_b[1].T, non_blocking=True) + if lora_b[2] is not None: + self.lora_b_stacked[2][ + index, 0, :lora_b[2].shape[1], :lora_b[2].shape[0]].copy_( + lora_b[2].T, non_blocking=True) + + if lora_a[0] is not None: + self.lora_a_stacked[0][ + index, 0, :lora_a[0].shape[1], :lora_a[0].shape[0]].copy_( + lora_a[0].T, non_blocking=True) + if lora_a[1] is not None: + self.lora_a_stacked[1][ + index, 0, :lora_a[1].shape[1], :lora_a[1].shape[0]].copy_( + lora_a[1].T, non_blocking=True) + if lora_a[2] is not None: + self.lora_a_stacked[2][ + index, 0, :lora_a[2].shape[1], :lora_a[2].shape[0]].copy_( + lora_a[2].T, non_blocking=True) + + def apply_weights(self, x: torch.Tensor, + bias: Optional[torch.Tensor]) -> torch.Tensor: + output = self.base_layer.linear_method.apply_weights( + self.base_layer.linear_weights, x, bias) + _apply_lora_packed_3slice( + x, + self.lora_a_stacked, + self.lora_b_stacked, + self.indices[:self.indices_len[0]], + output, + self.output_slices, + ) + return output + + +class LoRARowParallelLinear(LoRALayer): + + def __init__(self, base_layer: RowParallelLinear) -> None: + super().__init__() + self.base_layer = base_layer + + def create_lora_weights( + self, + max_loras: int, + lora_config: LoRAConfig, + model_config: Optional[PretrainedConfig] = None) -> None: + self.lora_a_stacked = torch.zeros( + ( + max_loras, + 1, + lora_config.max_lora_rank, + self.base_layer.weight.shape[1], + ), + dtype=lora_config.lora_dtype, + device=self.base_layer.weight.device, + ) + self.lora_b_stacked = torch.zeros( + ( + max_loras, + 1, + self.base_layer.weight.shape[0], + lora_config.max_lora_rank, + ), + dtype=lora_config.lora_dtype, + device=self.base_layer.weight.device, + ) + self.indices: Optional[torch.Tensor] = None + self.indices_len: Optional[List[int]] = None + + def reset_lora(self, index: int): + self.lora_a_stacked[index] = 0 + self.lora_b_stacked[index] = 0 + + def set_lora( + self, + index: int, + lora_a: torch.Tensor, + lora_b: torch.Tensor, + embeddings_tensor: Optional[torch.Tensor], + ): + self.reset_lora(index) + if self.base_layer.tp_size > 1: + tensor_model_parallel_rank = get_tensor_model_parallel_rank() + shard_size = self.base_layer.weight.shape[1] + start_idx = tensor_model_parallel_rank * shard_size + end_idx = (tensor_model_parallel_rank + 1) * shard_size + lora_a = lora_a[start_idx:end_idx, :] + + self.lora_a_stacked[index, + 0, :lora_a.shape[1], :lora_a.shape[0]].copy_( + lora_a.T, non_blocking=True) + self.lora_b_stacked[index, + 0, :lora_b.shape[1], :lora_b.shape[0]].copy_( + lora_b.T, non_blocking=True) + + def set_mapping( + self, + base_indices: torch.Tensor, + sampler_indices: torch.Tensor, + sampler_indices_padded: torch.Tensor, + embeddings_indices: torch.Tensor, + indices_len: List[int], + ): + self.indices = base_indices + self.indices_len = indices_len + + def apply_weights(self, x: torch.Tensor) -> torch.Tensor: + output = self.base_layer.linear_method.apply_weights( + self.base_layer.linear_weights, x) + _apply_lora( + x, + self.lora_a_stacked, + self.lora_b_stacked, + self.indices[:self.indices_len[0]], + output, + ) + return output + + def forward(self, input_): + """Forward of RowParallelLinear + + Args: + input_: tensor whose last dimension is `input_size`. If + `input_is_parallel` is set, then the last dimension + is `input_size // tp_size`. + + Returns: + - output + - bias + """ + # Set up backprop all-reduce. + if self.base_layer.input_is_parallel: + input_parallel = input_ + else: + # TODO: simplify code below + tp_rank = get_tensor_model_parallel_rank() + splitted_input = split_tensor_along_last_dim( + input_, num_partitions=self.base_layer.tp_size) + input_parallel = splitted_input[tp_rank].contiguous() + + # Matrix multiply. + output_parallel = self.apply_weights(input_parallel) + if self.base_layer.reduce_results and self.base_layer.tp_size > 1: + output_ = tensor_model_parallel_all_reduce(output_parallel) + else: + output_ = output_parallel + + if not self.base_layer.skip_bias_add: + output = (output_ + self.base_layer.bias + if self.base_layer.bias is not None else output_) + output_bias = None + else: + output = output_ + output_bias = self.base_layer.bias + return output, output_bias + + @property + def weight(self): + return self.base_layer.weight + + +class LoRASampler(LoRALayer): + + def __init__( + self, + base_layer: Sampler, + hidden_size: int, + dtype: torch.dtype, + device: torch.device, + ) -> None: + super().__init__() + self.base_layer = base_layer + self.hidden_size = hidden_size + self.dtype = dtype + self.device = device + + @property + def vocab_size(self): + return self.base_layer.vocab_size + + @property + def org_vocab_size(self): + return self.base_layer.org_vocab_size + + @property + def include_gpu_probs_tensor(self): + return self.base_layer.include_gpu_probs_tensor + + def create_lora_weights( + self, + max_loras: int, + lora_config: LoRAConfig, + model_config: Optional[PretrainedConfig] = None, + ) -> None: + self.lora_a_stacked = torch.zeros( + ( + max_loras, + 1, + lora_config.max_lora_rank, + self.hidden_size, + ), + dtype=lora_config.lora_dtype, + device=self.device, + ) + self.lora_b_stacked = torch.zeros( + ( + max_loras, + 1, + self.base_layer.vocab_size, + lora_config.max_lora_rank, + ), + dtype=lora_config.lora_dtype, + device=self.device, + ) + self.embeddings_tensors = torch.full( + (max_loras, lora_config.lora_extra_vocab_size, self.hidden_size), + fill_value=float("-inf"), + dtype=self.dtype, + device=self.device, + ) + self.indices = None + self.indices_padded = None + self.indices_len = None + + def reset_lora(self, index: int): + self.lora_a_stacked[index] = 0 + self.lora_b_stacked[index] = 0 + self.embeddings_tensors[index] = float("-inf") + + def set_lora( + self, + index: int, + lora_a: torch.Tensor, + lora_b: torch.Tensor, + embeddings_tensor: Optional[torch.Tensor], + ): + self.reset_lora(index) + self.lora_a_stacked[index, + 0, :lora_a.shape[1], :lora_a.shape[0]].copy_( + lora_a.T, non_blocking=True) + self.lora_b_stacked[index, + 0, :lora_b.shape[1], :lora_b.shape[0]].copy_( + lora_b.T, non_blocking=True) + if embeddings_tensor is not None: + self.embeddings_tensors[ + index, :embeddings_tensor.shape[0], :embeddings_tensor. + shape[1], ] = embeddings_tensor + + def set_mapping( + self, + base_indices: torch.Tensor, + sampler_indices: torch.Tensor, + sampler_indices_padded: torch.Tensor, + embeddings_indices: torch.Tensor, + indices_len: List[int], + ): + self.indices = sampler_indices + self.indices_padded = sampler_indices_padded + self.indices_len = indices_len + + def _get_logits( + self, + hidden_states: torch.Tensor, + embedding: torch.Tensor, + embedding_bias: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + # Get the logits for the next tokens. + logits = torch.matmul(hidden_states, embedding.t()) + if embedding_bias is not None: + logits += embedding_bias + logits = tensor_model_parallel_all_gather(logits) + # Remove paddings in vocab (if any). + logits = logits[:, :self.base_layer.vocab_size] + + lora_logits = torch.empty( + self.embeddings_tensors.shape[0] + 1, + self.embeddings_tensors.shape[1], + hidden_states.shape[0], + dtype=self.embeddings_tensors.dtype, + device=self.embeddings_tensors.device, + ) + torch.matmul(self.embeddings_tensors, + hidden_states.T, + out=lora_logits[:-1]) + lora_logits[-1] = float("-inf") + lora_logits = lora_logits.mT + + logits[:, self.base_layer.org_vocab_size:] = (lora_logits.reshape( + lora_logits.shape[0] * lora_logits.shape[1], + lora_logits.shape[2], + ).index_select(0, + self.indices_padded[:self.indices_len[2]]).nan_to_num_( + nan=float("-inf"), + posinf=float("inf"), + neginf=float("-inf"))) + _apply_lora( + hidden_states, + self.lora_a_stacked, + self.lora_b_stacked, + self.indices[:self.indices_len[1]], + logits, + ) + return logits + + def forward(self, *args, **kwargs): + return type(self.base_layer).forward(self, *args, **kwargs) + + +def from_layer(layer: nn.Module, + max_loras: int, + lora_config: LoRAConfig, + model_config: Optional[PretrainedConfig] = None) -> LoRALayer: + supported_layer_types = { + VocabParallelEmbedding: LoRAVocabParallelEmbedding, + ColumnParallelLinear: LoRAColumnParallelLinear, + QKVParallelLinear: LoRAQKVParallelLinear, + MergedColumnParallelLinear: LoRAMergedColumnParallelLinear2Slice, + RowParallelLinear: LoRARowParallelLinear, + } + for src_layer_type, lora_layer_type in supported_layer_types.items(): + if type(layer) is src_layer_type: # pylint: disable=unidiomatic-typecheck + ret = lora_layer_type(layer) + ret.create_lora_weights(max_loras, lora_config, model_config) + return ret + return layer + + +def from_layer_sampler( + layer: Sampler, + lm_head: ParallelLMHead, + max_loras: int, + lora_config: LoRAConfig, + model_config: Optional[PretrainedConfig] = None, +) -> LoRASampler: + ret = LoRASampler(layer, lm_head.embedding_dim, lm_head.weight.dtype, + lm_head.weight.device) + ret.create_lora_weights(max_loras, lora_config, model_config) + return ret diff --git a/vllm/lora/lora.py b/vllm/lora/lora.py new file mode 100644 index 0000000000000..042a98597ab26 --- /dev/null +++ b/vllm/lora/lora.py @@ -0,0 +1,120 @@ +from typing import List, Optional + +import torch + + +class LoRA: + """A LoRA that is composed of two low rank matrixes.""" + + def __init__( + self, + module_name: str, + rank: int, + lora_alpha: int, + lora_a: torch.Tensor, + lora_b: torch.Tensor, + embeddings_tensor: Optional[torch.Tensor] = None, + scaling: Optional[float] = None, + ) -> None: + self.module_name = module_name + self.rank = rank + self.lora_alpha = lora_alpha + self.lora_a = lora_a + self.lora_b = lora_b + self.embeddings_tensor = embeddings_tensor + + if scaling is None: + self.scaling = self.lora_alpha / self.rank + else: + self.scaling = scaling + + @classmethod + def pack(cls, loras: List["LoRA"]) -> "PackedLoRA": + """Pack a list of LoRAs into a single LoRA. + + If LoRA is None, it signifies that the submodule does not have a LoRA. + """ + first_lora = next(lora for lora in loras if lora is not None) + for lora in loras: + if lora is None: + continue + lora.optimize() + rank = first_lora.rank + module_name = first_lora.module_name + obj = PackedLoRA( + module_name, + rank, + [lora.lora_alpha if lora is not None else None for lora in loras], + [lora.lora_a if lora is not None else None for lora in loras], + [lora.lora_b if lora is not None else None for lora in loras], + scaling=[1 if lora is not None else None for lora in loras]) + return obj + + def optimize(self) -> "LoRA": + """Optimize the LoRA by merging the scaling into lora_b.""" + if self.scaling == 1: + return + self.lora_b *= self.scaling + self.scaling = 1 + return self + + @property + def input_dim(self) -> int: + return self.lora_a.shape[0] + + @property + def output_dim(self) -> int: + return self.lora_b.shape[1] + + @property + def is_packed(self) -> bool: + return False + + +class PackedLoRA(LoRA): + """LoRA used for packed layers (eg. qkv_proj).""" + + def __init__( + self, + module_name: str, + rank: int, + lora_alphas: List[int], + lora_a: List[torch.Tensor], + lora_b: List[torch.Tensor], + scaling: Optional[List[float]] = None, + ) -> None: + super().__init__( + module_name=module_name, + rank=rank, + lora_alpha=0, + lora_a=lora_a, + lora_b=lora_b, + scaling=scaling, + embeddings_tensor=None, + ) + self.lora_alphas = lora_alphas + if scaling is None: + self.scaling = [ + lora_alpha / self.rank for lora_alpha in self.lora_alphas + ] + + def optimize(self) -> "PackedLoRA": + """Optimize the LoRA by merging the scaling into lora_b.""" + for i in range(len(self.lora_b)): + if self.scaling[i] == 1 or self.lora_b[i] is None: + continue + self.lora_b[i] *= self.scaling[i] + self.scaling[i] = 1 + return self + + @property + def input_dim(self) -> int: + raise NotImplementedError() + + @property + def output_dim(self) -> int: + raise NotImplementedError() + + @property + def is_packed(self) -> bool: + return True diff --git a/vllm/lora/models.py b/vllm/lora/models.py new file mode 100644 index 0000000000000..913234475b182 --- /dev/null +++ b/vllm/lora/models.py @@ -0,0 +1,666 @@ +import copy +import json +import logging +import math +import os +import re +from typing import (Any, Callable, Dict, Hashable, List, Optional, Tuple, Type, + Union) + +import safetensors.torch +import torch +from torch import nn + +from vllm.config import LoRAConfig +from vllm.utils import LRUCache + +from vllm.lora.layers import LoRALayer, LoRAMapping, from_layer, from_layer_sampler +from vllm.lora.lora import LoRA +from vllm.lora.utils import (parse_fine_tuned_lora_name, replace_submodule) + +logger = logging.getLogger(__name__) + +PACKED_MODULES_CFG = { + "qkv_proj": [ + "q_proj", + "k_proj", + "v_proj", + ], + "gate_up_proj": [ + "gate_proj", + "up_proj", + ], +} + +TARGET_MODULES_QKV = [ + "qkv_proj", + "o_proj", + "gate_up_proj", + "down_proj", + "embed_tokens", + "lm_head", +] + +EMBEDDING_MODULES = { + "embed_tokens": "input_embeddings", + "lm_head": "output_embeddings", +} + +EMBEDDING_PADDING_MODULES = ["lm_head"] + +_GLOBAL_LORA_ID = 0 + + +def convert_mapping( + mapping: LoRAMapping, lora_id_to_index: List[Optional[int]], + max_loras: int, vocab_size: int, extra_vocab_size: int +) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, List[int]]: + """Converts LoRAMapping to index tensors. + + Args: + mapping: LoRAMapping mapping rows in a batch to LoRA ids. + lora_id_to_index: List mapping LoRA ids to LoRA indices. + max_loras: Maximum number of LoRAs. + vocab_size: Model vocab size. + extra_vocab_size: Extra vocab size each LoRA can have. + + Returns: + A tuple of tensors: + base_indices: Tensor of shape [batch_size] mapping batch rows to + LoRA indices. + sampler_indices: Tensor of shape [batch_size] mapping requests to + LoRA indices for sampler. For generation, this will be the + same as base_indicies. For prefill, this will map requests + to LoRA indices. + sampler_indices_padded: Tensor of shape [batch_size] mapping + requests to LoRA indices for sampler with padding. + Same as sampler_indicies, but -1 is replaced with + max_loras. + embeddings_indices: Tensor of shape [2, batch_size] mapping + requests to embedding indices. First row is for embeddings + added by the LoRAs, second row is for the LoRA.lora_a + embeddings. + indices_len: List of lengths of the above tensors. + """ + indices = list(mapping.index_mapping).copy() + embedding_indices = indices.copy() + lora_indices = indices.copy() + prompt_mapping = [ + lora_id_to_index.index(x) if x > 0 else -1 + for x in mapping.prompt_mapping + ] + lora_idx = None + for i in range(len(indices)): + # TODO index can be slow. optimize + lora_idx = (lora_id_to_index.index(indices[i]) + if indices[i] > 0 else -1) + embedding_indices[i] = lora_idx if indices[i] > 0 else 0 + indices[i] = i + lora_indices[i] = lora_idx + + indices = torch.tensor([indices, lora_indices, embedding_indices], + dtype=torch.long, + device="cuda") + prompt_mapping = torch.tensor(prompt_mapping, + device="cuda", + dtype=torch.long) + embeddings_indices = torch.stack([ + indices[2] * extra_vocab_size, + indices[2] * (vocab_size + extra_vocab_size) + ]) + embeddings_indices[embeddings_indices == -1] = max_loras - 1 + base_indices = indices[1] + sampler_indices = prompt_mapping + sampler_indices_padded = sampler_indices.clone() + sampler_indices_padded[sampler_indices_padded == -1] = max_loras - 1 + sampler_indices_padded = ( + torch.arange( + 0, len(sampler_indices_padded), device="cuda", dtype=torch.long) + + (sampler_indices_padded * len(sampler_indices_padded))) + indices_len = (base_indices.shape[-1], sampler_indices.shape[-1], + sampler_indices_padded.shape[-1], + embeddings_indices.shape[-1]) + + return (base_indices, sampler_indices, sampler_indices_padded, + embeddings_indices, indices_len) + + +def get_lora_id(): + global _GLOBAL_LORA_ID + _GLOBAL_LORA_ID += 1 + return _GLOBAL_LORA_ID + + +def _create_dummy_lora(module_name: str, + input_dim: int, + output_dim: int, + rank: int, + dtype: torch.dtype, + device: torch.device, + embeddings_tensor_dim: Optional[int] = None) -> "LoRA": + lora_a = torch.zeros([input_dim, rank], dtype=dtype, device=device) + lora_b = torch.zeros([rank, output_dim], dtype=dtype, device=device) + embeddings_tensor = torch.rand( + 10, embeddings_tensor_dim, dtype=dtype, + device=device) if embeddings_tensor_dim else None + if str(device) == "cpu": + lora_a = lora_a.pin_memory() + lora_b = lora_b.pin_memory() + if embeddings_tensor is not None: + embeddings_tensor = embeddings_tensor.pin_memory() + return LoRA( + module_name, + rank=rank, + lora_alpha=1, + lora_a=lora_a, + lora_b=lora_b, + embeddings_tensor=embeddings_tensor, + ) + + +class LoRAModel: + """A LoRA fine-tuned model.""" + + def __init__( + self, + lora_model_id: int, + rank: int, + loras: Dict[str, LoRA], + ) -> None: + self.id = lora_model_id + assert (lora_model_id > + 0), f"a valid lora id should be greater than 0, got {self.id}" + self.rank = rank + self.loras: Dict[str, LoRA] = loras + + def get_lora(self, module_name: str) -> Optional[LoRA]: + """Get LoRA for a given module by name""" + return self.loras.get(module_name, None) + + # (yard1): TODO see if we can derive target_embedding_padding automatically + @classmethod + def from_lora_tensors( + cls, + lora_model_id: int, + rank: int, + lora_alpha: int, + tensors: Dict[str, torch.Tensor], + device: str = "cuda", + dtype: Optional[torch.dtype] = None, + embeddings: Optional[Dict[str, torch.Tensor]] = None, + target_embedding_padding: Optional[int] = None, + ) -> "LoRAModel": + """Create a LoRAModel from a dictionary of tensors.""" + loras: Dict[str, LoRA] = {} + for tensor_name, tensor in tensors.items(): + module_name, is_lora_a = parse_fine_tuned_lora_name(tensor_name) + if module_name not in loras: + lora_embeddings_tensor = None + if embeddings: + embeddings_module = next( + (k for k in EMBEDDING_MODULES if k in module_name), + None) + if embeddings_module: + lora_embeddings_tensor = embeddings[ + EMBEDDING_MODULES[embeddings_module]].to( + device=device, dtype=dtype) + if device == "cpu": + lora_embeddings_tensor = ( + lora_embeddings_tensor.pin_memory()) + loras[module_name] = LoRA(module_name, rank, lora_alpha, None, + None, lora_embeddings_tensor) + if is_lora_a: + loras[module_name].lora_a = tensor.to(device=device, + dtype=dtype).t() + if device == "cpu": + loras[module_name].lora_a = loras[ + module_name].lora_a.pin_memory() + else: + loras[module_name].lora_b = tensor.to(device=device, + dtype=dtype).t() + if any(name in module_name + for name in EMBEDDING_PADDING_MODULES + ) and target_embedding_padding is not None: + lora_b = loras[module_name].lora_b + assert target_embedding_padding >= lora_b.shape[1] + addition = target_embedding_padding - lora_b.shape[1] + loras[module_name].lora_b = torch.nn.functional.pad( + lora_b, (0, addition)) + if device == "cpu": + loras[module_name].lora_b = loras[ + module_name].lora_b.pin_memory() + + for _, lora in loras.items(): + lora.optimize() + return cls(lora_model_id, rank, loras) + + @classmethod + def from_local_checkpoint( + cls, + lora_dir: str, + lora_model_id: Optional[int] = None, + device: str = "cuda", + dtype: Optional[torch.dtype] = None, + target_embedding_padding: Optional[int] = None) -> "LoRAModel": + """Create a LoRAModel from a local checkpoint.""" + lora_config_path = os.path.join(lora_dir, "adapter_config.json") + lora_tensor_path = os.path.join(lora_dir, "adapter_model.safetensors") + lora_bin_file_path = os.path.join(lora_dir, "adapter_model.bin") + new_embeddings_tensor_path = os.path.join( + lora_dir, "new_embeddings.safetensors") + new_embeddings_bin_file_path = os.path.join(lora_dir, + "new_embeddings.bin") + if os.path.isfile(lora_tensor_path): + tensors = safetensors.torch.load_file(lora_tensor_path) + elif os.path.isfile(lora_bin_file_path): + tensors = torch.load(lora_bin_file_path) + else: + raise ValueError(f"{lora_dir} doesn't contain tensors") + + embeddings = None + if os.path.isfile(new_embeddings_tensor_path): + embeddings = safetensors.torch.load_file( + new_embeddings_tensor_path) + elif os.path.isfile(new_embeddings_bin_file_path): + embeddings = torch.load(new_embeddings_bin_file_path) + + with open(lora_config_path) as f: + config = json.load(f) + rank = config["r"] + lora_alpha = config["lora_alpha"] + return cls.from_lora_tensors( + lora_model_id=get_lora_id() + if lora_model_id is None else lora_model_id, + rank=rank, + lora_alpha=lora_alpha, + tensors=tensors, + device=device, + dtype=dtype, + embeddings=embeddings, + target_embedding_padding=target_embedding_padding, + ) + + +class LoRAModelManager: + """A manager that manages multiple LoRA-fine-tuned models.""" + + def __init__( + self, + model: nn.Module, + max_num_seqs: int, + max_num_batched_tokens: int, + vocab_size: int, + lora_config: LoRAConfig, + lora_target_modules: Union[str, List[str]] = TARGET_MODULES_QKV, + packed_modules_mapping: Dict[str, List[str]] = PACKED_MODULES_CFG, + ): + """Create a LoRAModelManager and adapter for a given model. + + Args: + model: the model to be adapted. + max_num_seqs: the maximum number of sequences model can run in a + single batch. + max_num_batched_tokens: the maximum number of tokens model can run + in a single batch. + vocab_size: the vocab size of the model. + lora_config: the LoRA configuration. + lora_target_modules: the target modules patterns to be adapted. + Support both single module name and a list of module names. + packed_modules_mapping: the mapping for packed modules. vLLM + packs some modules into one module, e.g., qkv_proj + is packed of q_proj, k_proj, and v_proj. These modules + have a single layer in the original model, but they are split + into multiple layers in the adapted model. + """ + self.lora_config = lora_config + self.max_num_seqs = max_num_seqs + assert self.capacity >= self.max_num_seqs + self.max_num_batched_tokens = math.ceil(max_num_batched_tokens / 8) * 8 + self.lora_id_to_index: List[Optional[int]] = [None] * self._lora_slots + self.vocab_size = vocab_size + self.base_indices = torch.empty(self.max_num_batched_tokens, + dtype=torch.long, + device="cuda") + self.sampler_indices = torch.empty(self.max_num_seqs, + dtype=torch.long, + device="cuda") + self.sampler_indices_padded = torch.empty(self.max_num_seqs, + dtype=torch.long, + device="cuda") + self.embeddings_indices = torch.empty(2, + self.max_num_batched_tokens, + dtype=torch.long, + device="cuda") + self.offsets = [] + self.indices_len = [None] * 4 + + self.model: nn.Module = model + self.lora_target_modules: List[str] = ([ + lora_target_modules + ] if isinstance(lora_target_modules, str) else lora_target_modules) + self.lora_target_modules = copy.deepcopy(lora_target_modules) + self.packed_modules_mapping = copy.deepcopy(packed_modules_mapping) + self.packed_modules: Dict[str, List[str]] = {} + self.modules: Dict[str, "LoRALayer"] = {} + self._registered_loras: Dict[int, LoRAModel] = {} + self._active_loras: Dict[int, None] = {} + self._last_mapping = None + self._create_lora_modules() + self.model.lora_manager = self + + @property + def capacity(self) -> int: + return self.lora_config.max_cpu_loras + + @property + def _lora_slots(self) -> int: + return self.max_num_seqs + + def __len__(self) -> int: + return len(self._registered_loras) + + def activate_lora( + self, + lora_id: int, + ) -> bool: + if lora_id in self._active_loras: + return False + first_free_slot = next( + ((i, lora_id) for i, lora_id in enumerate(self.lora_id_to_index) + if lora_id is None), None) + if first_free_slot is None: + raise ValueError("No free lora slots") + index, _ = first_free_slot + self._active_loras[lora_id] = None + lora_model = self._registered_loras[lora_id] + logger.debug( + f"Activating LoRA. int id: {lora_model.id}, slot index: {index}") + self.lora_id_to_index[index] = lora_model.id + for module_name, module in self.modules.items(): + module_lora = lora_model.get_lora(module_name) + if module_lora: + module_lora.optimize() + module.set_lora(index, module_lora.lora_a, module_lora.lora_b, + module_lora.embeddings_tensor) + else: + module.reset_lora(index) + return True + + def _deactivate_lora(self, lora_id: int): + try: + index = self.lora_id_to_index.index(lora_id) + self.lora_id_to_index[index] = None + except ValueError: + pass + + def deactivate_lora(self, lora_id: int) -> bool: + if lora_id in self._active_loras: + self._deactivate_lora(lora_id) + self._active_loras.pop(lora_id) + return True + return False + + def add_lora(self, lora: LoRAModel) -> bool: + """Add a LoRAModel to the manager.""" + if lora.id not in self._registered_loras: + if len(self._registered_loras) >= self.capacity: + raise RuntimeError("No free LoRA slots.") + self._create_merged_loras_inplace(lora) + self._registered_loras[lora.id] = lora + return True + return False + + def remove_lora(self, lora_id: int) -> bool: + """Remove a LoRAModel from the manager.""" + # TODO: should we check active lora? + self.deactivate_lora(lora_id) + return bool(self._registered_loras.pop(lora_id, None)) + + # TODO see if this can be vectorized + def convert_mapping(self, mapping: LoRAMapping) -> None: + (base_indices, sampler_indices, sampler_indices_padded, + embeddings_indices, + indices_len) = convert_mapping(mapping, self.lora_id_to_index, + self._lora_slots + 1, self.vocab_size, + self.lora_config.lora_extra_vocab_size) + self.base_indices[:base_indices.shape[0]].copy_(base_indices) + self.sampler_indices[:sampler_indices.shape[0]].copy_(sampler_indices) + self.sampler_indices_padded[:sampler_indices_padded.shape[0]].copy_( + sampler_indices_padded) + self.embeddings_indices[:embeddings_indices. + shape[0], :embeddings_indices.shape[1]].copy_( + embeddings_indices) + # Maintain the reference + self.indices_len[:] = indices_len + + def set_row_lora_mapping(self, lora_mapping: LoRAMapping) -> None: + if self._last_mapping != lora_mapping: + self.convert_mapping(lora_mapping) + self._last_mapping = lora_mapping + + def list_loras(self) -> Dict[int, LoRAModel]: + """List all registered LoRAModels.""" + return dict(self._registered_loras) + + def get_lora(self, lora_id: int) -> Optional[LoRAModel]: + return self._registered_loras.get(lora_id, None) + + def remove_all_loras(self) -> bool: + """Remove all LoRAModels from the manager.""" + self._registered_loras.clear() + self.lora_id_to_index = [None] * self._lora_slots + self._active_loras.clear() + + def _create_lora_modules(self): + for module_name, module in self.model.named_modules(): + if not self._match_target_modules(module_name): + continue + + new_module = replace_submodule( + self.model, module_name, + from_layer(module, self.capacity, self.lora_config, + self.model.config)) + # (yard1): TODO make this more robust + if "lm_head" in module_name: + sampler_module = self.model.get_submodule("sampler") + new_module = replace_submodule( + self.model, "sampler", + from_layer_sampler(sampler_module, module, self.capacity, + self.lora_config, self.model.config)) + self.register_module(module_name, new_module) + self._register_packed_modules(module_name) + new_module.set_mapping(self.base_indices, self.sampler_indices, + self.sampler_indices_padded, + self.embeddings_indices, self.indices_len) + + def register_module(self, module_name: str, module: "LoRALayer"): + assert isinstance(module, LoRALayer) + self.modules[module_name] = module + + def create_dummy_lora(self, lora_id: int, rank: int) -> LoRAModel: + """Create zero-initialized LoRAModel for warmup.""" + model = LoRAModel(lora_id, rank, {}) + for module_name, module in self.model.named_modules(): + if not self._match_target_modules(module_name) or not isinstance( + module, LoRALayer): + continue + parts = module_name.split(".") + if module_name not in self.packed_modules: + if parts[-1] in EMBEDDING_MODULES: + input_dim = (module.base_layer.org_vocab_size + + self.lora_config.lora_extra_vocab_size if + hasattr(module.base_layer, "org_vocab_size") + else module.base_layer.weight.shape[1]) + output_dim = module.base_layer.embedding_dim if hasattr( + module.base_layer, + "embedding_dim") else module.base_layer.weight.shape[0] + embeddings_tensor_dim = (module.base_layer.embedding_dim if + hasattr(module.base_layer, + "embedding_dim") else + module.base_layer.weight.shape[1]) + lora = _create_dummy_lora( + module_name, + input_dim, + output_dim, + rank, + module.base_layer.weight.dtype, + "cpu", + embeddings_tensor_dim=embeddings_tensor_dim) + else: + lora = _create_dummy_lora( + module_name, + module.base_layer.weight.shape[1], + module.base_layer.weight.shape[0], + rank, + module.base_layer.weight.dtype, + "cpu", + ) + lora.optimize() + else: + parts = module_name.split(".") + replacements = self.packed_modules_mapping[parts[-1]] + subloras = [] + for r in replacements: + lora = _create_dummy_lora( + module_name + "." + r, + module.base_layer.weight.shape[1], + module.base_layer.weight.shape[0] // len(replacements), + rank, + module.base_layer.weight.dtype, + "cpu", + ) + lora.optimize() + subloras.append(lora) + lora = LoRA.pack(subloras) + model.loras[module_name] = lora + return model + + def _match_target_modules(self, module_name: str): + return any( + re.match( + r".*\.{target_module}$".format(target_module=target_module), + module_name) or target_module == module_name + for target_module in self.lora_target_modules) + + def _register_packed_modules(self, module_full_name: str) -> None: + parts = module_full_name.split(".") + module_name = parts[-1] + replacements = self.packed_modules_mapping.get(module_name) + if not replacements: + return + prefix = ".".join(parts[:-1]) + self.packed_modules[module_full_name] = [ + prefix + "." + r if prefix else r for r in replacements + ] + + def _create_merged_loras_inplace(self, lora_model: LoRAModel) -> None: + for module_name, new_module_names in self.packed_modules.items(): + replacement_loras = [] + has_replacement = False + for r in new_module_names: + lora = lora_model.get_lora(r) + replacement_loras.append(lora) + if lora: + has_replacement = True + if not has_replacement: + continue + for i in range(len(replacement_loras)): + if replacement_loras[i]: + continue + replacement_loras[i] = None + lora_model.loras[module_name] = LoRA.pack(replacement_loras) + + +class LoRALRUCache(LRUCache): + + def __init__(self, capacity: int, deactivate_lora_fn: Callable[[Hashable], + None]): + super().__init__(capacity) + self.deactivate_lora_fn = deactivate_lora_fn + + def _on_remove(self, key: Hashable, value: Any): + logger.debug(f"Removing LoRA. int id: {key}") + self.deactivate_lora_fn(key) + return super()._on_remove(key, value) + + +class LRUCacheLoRAModelManager(LoRAModelManager): + """A model manager that manages multiple LoRAs with LRU cache.""" + + def __init__( + self, + model: nn.Module, + max_num_seqs: int, + max_num_batched_tokens: int, + vocab_size: int, + lora_config: LoRAConfig, + lora_target_modules: Union[str, List[str]] = TARGET_MODULES_QKV, + packed_modules_mapping: Dict[str, List[str]] = PACKED_MODULES_CFG, + ): + super().__init__(model, max_num_seqs, max_num_batched_tokens, + vocab_size, lora_config, lora_target_modules, + packed_modules_mapping) + self._registered_loras: LoRALRUCache = LoRALRUCache( + self.capacity, self.deactivate_lora) + self._active_loras: LoRALRUCache = LoRALRUCache( + self.max_num_seqs, self._deactivate_lora) + + def list_loras(self) -> Dict[int, LoRAModel]: + """List all registered LoRAModels.""" + return dict(self._registered_loras.cache) + + def add_lora(self, lora: LoRAModel) -> bool: + """Add a LoRAModel to the manager.""" + was_added = False + if lora.id not in self._registered_loras: + was_added = True + logger.debug(f"Adding LoRA. Model id: {lora.id}, " + f"int id: {lora.id}") + self._create_merged_loras_inplace(lora) + self._registered_loras[lora.id] = lora + else: + # We always touch to update the LRU cache order + self._registered_loras.touch(lora.id) + return was_added + + def activate_lora( + self, + lora_id: int, + ) -> bool: + if lora_id not in self._active_loras and len( + self._active_loras) >= self.max_num_seqs: + self._active_loras.remove_oldest() + result = super().activate_lora(lora_id) + # We always touch to update the LRU cache order + self._active_loras.touch(lora_id) + return result + + def remove_oldest_lora(self) -> bool: + if len(self._registered_loras) > 0: + self._registered_loras.remove_oldest() + return True + return False + + +def create_lora_adapter( + model: nn.Module, + max_num_seqs: int, + max_num_batched_tokens: int, + vocab_size: int, + lora_config:LoRAConfig, + target_modules: Union[str, + List[str]] = TARGET_MODULES_QKV, + lora_manager_cls:Type[LoRAModelManager] = LoRAModelManager, **kwargs)\ + -> LoRAModelManager: + """Create a LoRA adapter for a given model.""" + if not getattr(model, "supports_lora", False): + raise ValueError(f"Model {type(model)} is not supported for LoRA.") + lora_manager = lora_manager_cls( + model=model, + max_num_seqs=max_num_seqs, + max_num_batched_tokens=max_num_batched_tokens, + vocab_size=vocab_size, + lora_config=lora_config, + lora_target_modules=target_modules, + **kwargs) + return lora_manager diff --git a/vllm/lora/punica.py b/vllm/lora/punica.py new file mode 100644 index 0000000000000..ac96931b2d071 --- /dev/null +++ b/vllm/lora/punica.py @@ -0,0 +1,173 @@ +# Based on code from https://github.com/punica-ai/punica + +from typing import Optional + +import torch + +import_exc = None + +try: + import vllm._punica_C as punica_kernels +except ImportError as e: + import_exc = e + +if import_exc is None: + + def bgmv( + y: torch.Tensor, + x: torch.Tensor, + w_t_all: torch.Tensor, + indicies: torch.LongTensor, + layer_idx: int, + scale: float, + ): + """ + Semantics: + y[i] += ( + x[i].unsqueeze(0) + @ w_t_all[indices[i], layer_idx, :, :].transpose(-1, -2) + * scale + ).squeeze(0) + + Args: + y: Shape: `[B, H2]`. Output vectors. Will be changed in-place. + x: Shape: `[B, H1]`. Input vectors. + w_t_all: Shape: `[None, L, H2, H1]`. All of the transposed weight + matrices. + indicies: Shape: `[B]`. Indices of the weight matrices. + layer_idx: Layer index of the weight matrices. + scale: Scaling factor. + """ + punica_kernels.dispatch_bgmv(y, x, w_t_all, indicies, layer_idx, scale) + + def add_lora(y: torch.Tensor, + x: torch.Tensor, + wa_t_all: torch.Tensor, + wb_t_all: torch.Tensor, + indicies: torch.LongTensor, + layer_idx: int, + scale: float, + *, + buffer: Optional[torch.Tensor] = None): + """ + Semantics: + y[i] += ( + x[i].unsqueeze(0) + @ wa_t_all[indices[i], layer_idx, :, :].transpose(-1, -2) + @ wb_t_all[indices[i], layer_idx, :, :].transpose(-1, -2) + * scale + ).squeeze(0) + + Args: + y: Shape: `[B, H2]`. Output vectors. Will be changed in-place. + x: Shape: `[B, H1]`. Input vectors. + wa_t_all: Shape: `[None, L, R, H1]`. All of the transposed + LoRA A matrices. + wb_t_all: Shape: `[None, L, H2, R]`. All of the transposed + LoRA B matrices. + indicies: Shape: `[B]`. Indices of the LoRA weights. + layer_idx: Layer index of LoRA weights. + scale: Scaling factor. + buffer: Optional. Shape: `[B, R]`. Temporary buffer. + """ + r = wb_t_all.size(-1) + if buffer is None: + # We set the buffer to be float32 by default to avoid + # numerical innacuracies that would otherwise happen + # due to downcasting. + buffer = torch.zeros((x.size(0), r), + dtype=torch.float32, + device=x.device) + punica_kernels.dispatch_bgmv(buffer, x, wa_t_all, indicies, layer_idx, + 1.0) + punica_kernels.dispatch_bgmv(y, buffer, wb_t_all, indicies, layer_idx, + scale) + + def add_lora_slice(y: torch.Tensor, + x: torch.Tensor, + wa_t_all: torch.Tensor, + wb_t_all: torch.Tensor, + indicies: torch.LongTensor, + layer_idx: int, + scale: float, + y_offset: int, + y_slice_size: int, + *, + buffer: Optional[torch.Tensor] = None): + """ + Same as `add_lora` but you can operate on slices of y. + Pass whole y, define y_offset and y_slice_size. + + Semantics: + y[i] += ( + x[i].unsqueeze(0) + @ wa_t_all[indices[i], layer_idx, :, :].transpose(-1, -2) + @ wb_t_all[indices[i], layer_idx, :, :].transpose(-1, -2) + * scale + ).squeeze(0) + + Args: + y: Shape: `[B, H2]`. Output vectors. Will be changed in-place. + x: Shape: `[B, H1]`. Input vectors. + wa_t_all: Shape: `[None, L, R, H1]`. All of the transposed + LoRA A matrices. + wb_t_all: Shape: `[None, L, H2, R]`. All of the transposed + LoRA B matrices. + indicies: Shape: `[B]`. Indices of the LoRA weights. + layer_idx: Layer index of LoRA weights. + scale: Scaling factor. + y_offset: Offset to apply to the starting column of y. + y_slice_size: Size of the y column slice. + """ + r = wb_t_all.size(-1) + if buffer is None: + # We set the buffer to be float32 by default to avoid + # numerical inaccuracies that would otherwise happen + # due to downcasting. + buffer = torch.zeros((x.size(0), r), + dtype=torch.float32, + device=x.device) + punica_kernels.dispatch_bgmv_low_level( + buffer, + x, + wa_t_all, + indicies, + layer_idx, + 1.0, + x.size(1), + buffer.size(1), + 0, + ) + punica_kernels.dispatch_bgmv_low_level( + y, + buffer, + wb_t_all, + indicies, + layer_idx, + scale, + buffer.size(1), + y_slice_size, + y_offset, + ) + +else: + + def _raise_exc( + *args, # pylint: disable=unused-argument + **kwargs # pylint: disable=unused-argument + ): + if torch.cuda.get_device_capability() < (8, 0): + raise ImportError( + "LoRA kernels require compute capability>=8.0") from import_exc + else: + raise import_exc + + bgmv = _raise_exc + add_lora = _raise_exc + add_lora_slice = _raise_exc + +__all__ = [ + "bgmv", + "add_lora", + "add_lora_slice", +] diff --git a/vllm/lora/request.py b/vllm/lora/request.py new file mode 100644 index 0000000000000..3ae5be59b1b88 --- /dev/null +++ b/vllm/lora/request.py @@ -0,0 +1,19 @@ +from dataclasses import dataclass + + +@dataclass +class LoRARequest: + lora_id: str + lora_int_id: int + lora_local_path: str + + def __post_init__(self): + if self.lora_int_id < 1: + raise ValueError( + f"lora_int_id must be > 0, got {self.lora_int_id}") + + def __eq__(self, value: object) -> bool: + return isinstance(value, LoRARequest) and self.lora_id == value.lora_id + + def __hash__(self) -> int: + return self.lora_int_id diff --git a/vllm/lora/utils.py b/vllm/lora/utils.py new file mode 100644 index 0000000000000..f67a3812fb046 --- /dev/null +++ b/vllm/lora/utils.py @@ -0,0 +1,39 @@ +import logging +from typing import Tuple + +from torch import nn + +logger = logging.getLogger(__name__) + + +def replace_submodule(model: nn.Module, module_name: str, + new_module: nn.Module) -> nn.Module: + """Replace a submodule in a model with a new module.""" + parent = model.get_submodule(".".join(module_name.split(".")[:-1])) + target_name = module_name.split(".")[-1] + setattr(parent, target_name, new_module) + return new_module + + +def parse_fine_tuned_lora_name(name: str) -> Tuple[str, bool]: + """Parse the name of lora weights. + + args: + name: the name of the fine-tuned LoRA, e.g. + base_model.model.dense1.weight + return: + Tuple(module_name, is_lora_a): + module_name: the name of the module, e.g. model.dense1, + is_lora_a whether the tensor is lora_a or lora_b. + """ + parts = name.split(".") + assert parts[0] == "base_model" + assert parts[1] == "model" + if parts[-1] == "weight": + assert parts[-2] == "lora_A" or parts[-2] == "lora_B" + return ".".join(parts[2:-2]), parts[-2] == "lora_A" + + if parts[-1] == "lora_embedding_A" or parts[-1] == "lora_embedding_B": + return ".".join(parts[2:-1]), parts[-1] == "lora_embedding_A" + + raise ValueError(f"{name} is unsupported format") diff --git a/vllm/lora/worker_manager.py b/vllm/lora/worker_manager.py new file mode 100644 index 0000000000000..be6f4cf0589bd --- /dev/null +++ b/vllm/lora/worker_manager.py @@ -0,0 +1,266 @@ +import logging +from abc import ABC, abstractmethod, abstractproperty +from typing import Any, List, Optional, Set, Type, Union + +import torch + +from vllm.lora.models import (TARGET_MODULES_QKV, LoRAModel, LoRAModelManager, + LRUCacheLoRAModelManager, create_lora_adapter) +from vllm.lora.request import LoRARequest +from vllm.lora.layers import LoRAMapping +from vllm.config import LoRAConfig + +logger = logging.getLogger(__name__) + + +class AbstractWorkerLoRAManager(ABC): + """Abstract class for managing LoRA models on the worker side.""" + + def __init__(self, max_num_seqs: int, max_num_batched_tokens: int, + vocab_size: int, lora_config: LoRAConfig, + device: torch.device): + self.max_num_seqs = max_num_seqs + self.max_num_batched_tokens = max_num_batched_tokens + self.vocab_size = vocab_size + self.device = device + self.lora_config = lora_config + + @abstractproperty + def is_enabled(self) -> bool: + ... + + @abstractmethod + def create_lora_adapter( + self, + model: torch.nn.Module, + target_modules: Union[str, List[str]] = TARGET_MODULES_QKV, + ) -> Any: + ... + + @abstractmethod + def apply_loras(self, lora_requests: List[LoRARequest], + lora_mapping: LoRAMapping) -> None: + ... + + @abstractmethod + def add_lora(self, lora_request: LoRARequest) -> bool: + ... + + @abstractmethod + def add_dummy_lora(self, lora_request: LoRARequest, rank: int) -> bool: + ... + + @abstractmethod + def remove_lora(self, lora_id: int) -> bool: + ... + + @abstractmethod + def remove_all_loras(self) -> bool: + ... + + @abstractmethod + def list_loras(self) -> Set[int]: + ... + + +class DisabledWorkerLoRAManager(AbstractWorkerLoRAManager): + """WorkerLoRAManager that does nothing.""" + + @property + def is_enabled(self) -> bool: + return False + + def create_lora_adapter( + self, + model: torch.nn.Module, + target_modules: Union[str, List[str]] = TARGET_MODULES_QKV, + ) -> Any: + return model + + def apply_loras(self, lora_requests: List[LoRARequest], + lora_mapping: LoRAMapping) -> None: + return + + def add_lora(self, lora_request: LoRARequest) -> bool: + return False + + def add_dummy_lora(self, lora_request: LoRARequest, rank: int) -> bool: + return False + + def remove_lora(self, lora_id: int) -> bool: + return False + + def remove_all_loras(self) -> bool: + return + + def list_loras(self) -> Set[int]: + return set() + + +class WorkerLoRAManager(AbstractWorkerLoRAManager): + """WorkerLoRAManager that manages LoRA models on the worker side. + + Every request, the requested LoRAs will be loaded (unless they are already + loaded), and every other LoRA will be unloaded.""" + + _lora_manager_cls: Type[LoRAModelManager] = LoRAModelManager + + def __init__( + self, + max_num_seqs: int, + max_num_batched_tokens: int, + vocab_size: int, + lora_config: LoRAConfig, + device: torch.device, + lora_model_cls: Type[LoRAModel] = LoRAModel, + ): + self._lora_manager: Optional[LoRAModelManager] = None + self._lora_model_cls = lora_model_cls + super().__init__(max_num_seqs, max_num_batched_tokens, vocab_size, + lora_config, device) + + @property + def is_enabled(self) -> bool: + return True + + def create_lora_adapter( + self, + model: torch.nn.Module, + target_modules: Union[str, List[str]] = TARGET_MODULES_QKV, + ) -> Any: + lora_manager = create_lora_adapter( + model, + max_num_seqs=self.max_num_seqs, + max_num_batched_tokens=self.max_num_batched_tokens, + target_modules=target_modules, + vocab_size=self.vocab_size, + lora_config=self.lora_config, + lora_manager_cls=self._lora_manager_cls, + ) + self._lora_manager = lora_manager + return lora_manager.model + + def apply_loras(self, lora_requests: List[LoRARequest], + lora_mapping: LoRAMapping) -> None: + self._apply_loras(lora_requests) + self._lora_manager.set_row_lora_mapping(lora_mapping) + + def _apply_loras(self, lora_requests: List[LoRARequest]) -> None: + loras_that_exist = self.list_loras() + loras_map = { + lora_request.lora_int_id: lora_request + for lora_request in lora_requests if lora_request + } + if len(loras_map) > self._lora_manager.max_num_seqs: + raise RuntimeError( + f"Number of requested LoRAs ({len(loras_map)}) is greater " + "than the number of GPU LoRA slots " + f"({self._lora_manager.max_num_seqs}).") + + new_loras = set(loras_map) + loras_to_add = new_loras - loras_that_exist + loras_to_remove = loras_that_exist - new_loras + + for lora_id in loras_to_remove: + self.remove_lora(lora_id) + + for lora_id in loras_to_add: + self.add_lora(loras_map[lora_id]) + + def _load_lora(self, lora_request: LoRARequest) -> LoRAModel: + try: + lora = self._lora_model_cls.from_local_checkpoint( + lora_request.lora_local_path, + lora_model_id=lora_request.lora_int_id, + device="cpu", + dtype=self.lora_config.lora_dtype, + target_embedding_padding=self.vocab_size + + self.lora_config.lora_extra_vocab_size, + ) + except Exception as e: + raise RuntimeError( + f"Loading lora {lora_request.lora_local_path} failed") from e + if lora.rank > self.lora_config.max_lora_rank: + raise ValueError( + f"LoRA rank {lora.rank} is greater than max_lora_rank " + f"{self.lora_config.max_lora_rank}.") + return lora + + def add_dummy_lora(self, lora_request: LoRARequest, rank: int) -> bool: + if lora_request.lora_int_id in self.list_loras(): + return False + return self._lora_manager.add_lora( + self._lora_manager.create_dummy_lora(lora_request.lora_int_id, + rank)) + + def add_lora(self, lora_request: LoRARequest) -> bool: + if lora_request.lora_int_id in self.list_loras(): + return False + lora = self._load_lora(lora_request) + loaded = self._lora_manager.add_lora(lora) + self._lora_manager.activate_lora(lora.id) + return loaded + + def remove_lora(self, lora_id: int) -> bool: + return self._lora_manager.remove_lora(lora_id) + + def remove_all_loras(self) -> bool: + self._lora_manager.remove_all_loras() + + def list_loras(self) -> Set[int]: + return set(self._lora_manager.list_loras()) + + +class LRUCacheWorkerLoRAManager(WorkerLoRAManager): + """WorkerLoRAManager that manages LoRA models on the worker side. + + Uses an LRU Cache. Every request, the requested LoRAs will be loaded + (unless they are already loaded) and least recently used LoRAs will + be unloaded if the cache is above capacity.""" + + _lora_manager_cls: Type[ + LRUCacheLoRAModelManager] = LRUCacheLoRAModelManager + + def create_lora_adapter( + self, + model: torch.nn.Module, + target_modules: Union[str, List[str]] = TARGET_MODULES_QKV, + ) -> Any: + lora_manager = create_lora_adapter( + model, + target_modules=target_modules, + lora_manager_cls=self._lora_manager_cls, + max_num_seqs=self.max_num_seqs, + vocab_size=self.vocab_size, + lora_config=self.lora_config, + max_num_batched_tokens=self.max_num_batched_tokens, + ) + self._lora_manager = lora_manager + return lora_manager.model + + def _apply_loras(self, lora_requests: List[LoRARequest]) -> None: + loras_map = { + lora_request.lora_int_id: lora_request + for lora_request in lora_requests if lora_request + } + if len(loras_map) > self._lora_manager.max_num_seqs: + raise RuntimeError( + f"Number of requested LoRAs ({len(loras_map)}) is greater " + "than the number of GPU LoRA slots " + f"({self._lora_manager.max_num_seqs}).") + for lora in loras_map.values(): + self.add_lora(lora) + + def add_lora(self, lora_request: LoRARequest) -> bool: + if lora_request.lora_int_id not in self.list_loras(): + # Remove before we load the new lora to save memory + if len(self._lora_manager) + 1 > self._lora_manager.capacity: + self._lora_manager.remove_oldest_lora() + lora = self._load_lora(lora_request) + loaded = self._lora_manager.add_lora(lora) + else: + # If the lora is already loaded, just touch it to + # update its position in the caches + loaded = self._lora_manager.get_lora(lora_request.lora_int_id) + self._lora_manager.activate_lora(lora_request.lora_int_id) + return loaded diff --git a/vllm/model_executor/layers/sampler.py b/vllm/model_executor/layers/sampler.py index c874ec5921155..5bce287a92ae5 100644 --- a/vllm/model_executor/layers/sampler.py +++ b/vllm/model_executor/layers/sampler.py @@ -29,9 +29,24 @@ class Sampler(nn.Module): parameters (e.g., sampling method, temperature, top-p, top-k, etc.). """ - def __init__(self, vocab_size: int) -> None: + def __init__(self, + vocab_size: int, + org_vocab_size: Optional[int] = None) -> None: super().__init__() self.vocab_size = vocab_size + # original vocabulary size (without LoRA). + self.org_vocab_size = org_vocab_size or vocab_size + + def _get_logits(self, hidden_states: torch.Tensor, embedding: torch.Tensor, + embedding_bias: Optional[torch.Tensor]) -> torch.Tensor: + # Get the logits for the next tokens. + logits = torch.matmul(hidden_states, embedding.t()) + if embedding_bias is not None: + logits += embedding_bias + logits = tensor_model_parallel_all_gather(logits) + # Remove paddings in vocab (if any). + logits = logits[:, :self.org_vocab_size] + return logits def forward( self, @@ -44,8 +59,7 @@ def forward( hidden_states = _prune_hidden_states(hidden_states, input_metadata) # Get the logits for the next tokens. - logits = _get_logits(hidden_states, embedding, embedding_bias, - self.vocab_size) + logits = self._get_logits(hidden_states, embedding, embedding_bias) # Apply logits processors (if any). logits = _apply_logits_processors(logits, input_metadata) @@ -97,19 +111,6 @@ def forward( prompt_logprobs, sample_logprobs) -def _get_logits(hidden_states: torch.Tensor, embedding: torch.Tensor, - embedding_bias: Optional[torch.Tensor], - vocab_size: int) -> torch.Tensor: - # Get the logits for the next tokens. - logits = torch.matmul(hidden_states, embedding.t()) - if embedding_bias is not None: - logits += embedding_bias - logits = tensor_model_parallel_all_gather(logits) - # Remove paddings in vocab (if any). - logits = logits[:, :vocab_size] - return logits - - def _prune_hidden_states( hidden_states: torch.Tensor, input_metadata: InputMetadata, diff --git a/vllm/model_executor/layers/vocab_parallel_embedding.py b/vllm/model_executor/layers/vocab_parallel_embedding.py index b08d5555b0faa..9e4ac26e73d00 100644 --- a/vllm/model_executor/layers/vocab_parallel_embedding.py +++ b/vllm/model_executor/layers/vocab_parallel_embedding.py @@ -43,16 +43,19 @@ class VocabParallelEmbedding(torch.nn.Module): num_embeddings: vocabulary size. embedding_dim: size of hidden state. params_dtype: type of the parameters. + org_num_embeddings: original vocabulary size (without LoRA). """ def __init__(self, num_embeddings: int, embedding_dim: int, - params_dtype: Optional[torch.dtype] = None): + params_dtype: Optional[torch.dtype] = None, + org_num_embeddings: Optional[int] = None): super().__init__() # Keep the input dimensions. self.num_embeddings = num_embeddings + self.org_vocab_size = org_num_embeddings or num_embeddings self.num_embeddings_padded = pad_vocab_size(num_embeddings) self.embedding_dim = embedding_dim if params_dtype is None: @@ -77,7 +80,7 @@ def __init__(self, def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor): parallel_dim = param.parallel_dim - assert loaded_weight.shape[parallel_dim] == self.num_embeddings + assert loaded_weight.shape[parallel_dim] == self.org_vocab_size loaded_weight = loaded_weight[self.vocab_start_index:self. vocab_end_index] param[:loaded_weight.shape[0]].data.copy_(loaded_weight) @@ -114,14 +117,17 @@ class ParallelLMHead(VocabParallelEmbedding): embedding_dim: size of hidden state. bias: whether to use bias. params_dtype: type of the parameters. + org_num_embeddings: original vocabulary size (without LoRA). """ def __init__(self, num_embeddings: int, embedding_dim: int, bias: bool = False, - params_dtype: Optional[torch.dtype] = None): - super().__init__(num_embeddings, embedding_dim, params_dtype) + params_dtype: Optional[torch.dtype] = None, + org_num_embeddings: Optional[int] = None): + super().__init__(num_embeddings, embedding_dim, params_dtype, + org_num_embeddings) if bias: self.bias = Parameter( torch.empty(self.num_embeddings_per_partition, diff --git a/vllm/model_executor/model_loader.py b/vllm/model_executor/model_loader.py index 54b87c4b866e3..cf84b9810c575 100644 --- a/vllm/model_executor/model_loader.py +++ b/vllm/model_executor/model_loader.py @@ -1,12 +1,12 @@ """Utilities for selecting and loading models.""" import contextlib -from typing import Type +from typing import Optional, Type import torch import torch.nn as nn from transformers import PretrainedConfig -from vllm.config import ModelConfig +from vllm.config import ModelConfig, LoRAConfig from vllm.model_executor.models import * from vllm.model_executor.weight_utils import (get_quant_config, initialize_dummy_weights) @@ -58,7 +58,8 @@ def _get_model_architecture(config: PretrainedConfig) -> Type[nn.Module]: f"Supported architectures: {list(_MODEL_REGISTRY.keys())}") -def get_model(model_config: ModelConfig) -> nn.Module: +def get_model(model_config: ModelConfig, + lora_config: Optional[LoRAConfig] = None) -> nn.Module: model_class = _get_model_architecture(model_config.hf_config) # Get the (maybe quantized) linear method. @@ -87,7 +88,12 @@ def get_model(model_config: ModelConfig) -> nn.Module: with _set_default_torch_dtype(model_config.dtype): # Create a model instance. # The weights will be initialized as empty tensors. - model = model_class(model_config.hf_config, linear_method) + # TODO(yard1): Clean this up (lora_config) + try: + model = model_class(model_config.hf_config, linear_method, + lora_config) + except TypeError: + model = model_class(model_config.hf_config, linear_method) if model_config.load_format == "dummy": model = model.cuda() # NOTE(woosuk): For accurate performance evaluation, we assign diff --git a/vllm/model_executor/models/llama.py b/vllm/model_executor/models/llama.py index 8e7344da4888e..999c1097d0a42 100644 --- a/vllm/model_executor/models/llama.py +++ b/vllm/model_executor/models/llama.py @@ -43,6 +43,7 @@ from vllm.model_executor.weight_utils import (default_weight_loader, hf_model_weights_iterator) from vllm.sequence import SamplerOutput +from vllm.config import LoRAConfig KVCache = Tuple[torch.Tensor, torch.Tensor] @@ -223,14 +224,19 @@ def __init__( self, config: LlamaConfig, linear_method: Optional[LinearMethodBase] = None, + lora_config: Optional[LoRAConfig] = None, ) -> None: super().__init__() self.config = config self.padding_idx = config.pad_token_id - self.vocab_size = config.vocab_size + lora_vocab = (lora_config.lora_extra_vocab_size * + (lora_config.max_loras or 1)) if lora_config else 0 + self.vocab_size = config.vocab_size + lora_vocab + self.org_vocab_size = config.vocab_size self.embed_tokens = VocabParallelEmbedding( - config.vocab_size, + self.vocab_size, config.hidden_size, + org_num_embeddings=config.vocab_size, ) self.layers = nn.ModuleList([ LlamaDecoderLayer(config, linear_method) @@ -264,18 +270,25 @@ def forward( class LlamaForCausalLM(nn.Module): + supports_lora = True def __init__( self, config: LlamaConfig, linear_method: Optional[LinearMethodBase] = None, + lora_config: Optional[LoRAConfig] = None, ) -> None: super().__init__() self.config = config self.linear_method = linear_method - self.model = LlamaModel(config, linear_method) - self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size) - self.sampler = Sampler(config.vocab_size) + self.model = LlamaModel(config, linear_method, lora_config=lora_config) + unpadded_vocab_size = config.vocab_size + if lora_config: + unpadded_vocab_size += lora_config.lora_extra_vocab_size + self.lm_head = ParallelLMHead(unpadded_vocab_size, + config.hidden_size, + org_num_embeddings=config.vocab_size) + self.sampler = Sampler(unpadded_vocab_size, config.vocab_size) def forward( self, diff --git a/vllm/model_executor/models/mistral.py b/vllm/model_executor/models/mistral.py index d18572610741c..c67c3fae2028a 100644 --- a/vllm/model_executor/models/mistral.py +++ b/vllm/model_executor/models/mistral.py @@ -43,6 +43,7 @@ from vllm.model_executor.weight_utils import (default_weight_loader, hf_model_weights_iterator) from vllm.sequence import SamplerOutput +from vllm.config import LoRAConfig KVCache = Tuple[torch.Tensor, torch.Tensor] @@ -217,15 +218,20 @@ def __init__( self, config: MistralConfig, linear_method: Optional[LinearMethodBase] = None, + lora_config: Optional[LoRAConfig] = None, ) -> None: super().__init__() self.config = config self.padding_idx = config.pad_token_id - self.vocab_size = config.vocab_size + lora_vocab = (lora_config.lora_extra_vocab_size * + (lora_config.max_loras or 1)) if lora_config else 0 + self.vocab_size = config.vocab_size + lora_vocab + self.org_vocab_size = config.vocab_size self.embed_tokens = VocabParallelEmbedding( - config.vocab_size, + self.vocab_size, config.hidden_size, + org_num_embeddings=config.vocab_size, ) self.layers = nn.ModuleList([ MistralDecoderLayer(config, linear_method) @@ -259,18 +265,27 @@ def forward( class MistralForCausalLM(nn.Module): + supports_lora = True def __init__( self, config: MistralConfig, linear_method: Optional[LinearMethodBase] = None, + lora_config: Optional[LoRAConfig] = None, ) -> None: super().__init__() self.config = config self.linear_method = linear_method - self.model = MistralModel(config, linear_method) - self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size) - self.sampler = Sampler(config.vocab_size) + self.model = MistralModel(config, + linear_method, + lora_config=lora_config) + unpadded_vocab_size = config.vocab_size + if lora_config: + unpadded_vocab_size += lora_config.lora_extra_vocab_size + self.lm_head = ParallelLMHead(unpadded_vocab_size, + config.hidden_size, + org_num_embeddings=config.vocab_size) + self.sampler = Sampler(unpadded_vocab_size, config.vocab_size) def forward( self, diff --git a/vllm/outputs.py b/vllm/outputs.py index fe54926e06e64..534e9d5ea8a53 100644 --- a/vllm/outputs.py +++ b/vllm/outputs.py @@ -2,6 +2,7 @@ from vllm.sequence import (PromptLogprobs, SampleLogprobs, SequenceGroup, SequenceStatus) +from vllm.lora.request import LoRARequest class CompletionOutput: @@ -16,6 +17,7 @@ class CompletionOutput: logprobs: The log probabilities of the top probability words at each position if the logprobs are requested. finish_reason: The reason why the sequence is finished. + lora_request: The LoRA request that was used to generate the output. """ def __init__( @@ -26,6 +28,7 @@ def __init__( cumulative_logprob: float, logprobs: Optional[SampleLogprobs], finish_reason: Optional[str] = None, + lora_request: Optional[LoRARequest] = None, ) -> None: self.index = index self.text = text @@ -33,6 +36,7 @@ def __init__( self.cumulative_logprob = cumulative_logprob self.logprobs = logprobs self.finish_reason = finish_reason + self.lora_request = lora_request def finished(self) -> bool: return self.finish_reason is not None @@ -56,6 +60,7 @@ class RequestOutput: prompt_logprobs: The log probabilities to return per prompt token. outputs: The output sequences of the request. finished: Whether the whole request is finished. + lora_request: The LoRA request that was used to generate the output. """ def __init__( @@ -66,6 +71,7 @@ def __init__( prompt_logprobs: Optional[PromptLogprobs], outputs: List[CompletionOutput], finished: bool, + lora_request: Optional[LoRARequest] = None, ) -> None: self.request_id = request_id self.prompt = prompt @@ -73,6 +79,7 @@ def __init__( self.prompt_logprobs = prompt_logprobs self.outputs = outputs self.finished = finished + self.lora_request = lora_request @classmethod def from_seq_group(cls, seq_group: SequenceGroup) -> "RequestOutput": @@ -108,8 +115,13 @@ def from_seq_group(cls, seq_group: SequenceGroup) -> "RequestOutput": prompt_token_ids = seq_group.prompt_token_ids prompt_logprobs = seq_group.prompt_logprobs finished = seq_group.is_finished() - return cls(seq_group.request_id, prompt, prompt_token_ids, - prompt_logprobs, outputs, finished) + return cls(seq_group.request_id, + prompt, + prompt_token_ids, + prompt_logprobs, + outputs, + finished, + lora_request=seq_group.lora_request) def __repr__(self) -> str: return (f"RequestOutput(request_id={self.request_id}, " @@ -117,4 +129,5 @@ def __repr__(self) -> str: f"prompt_token_ids={self.prompt_token_ids}, " f"prompt_logprobs={self.prompt_logprobs}, " f"outputs={self.outputs}, " - f"finished={self.finished})") + f"finished={self.finished}, " + f"lora_request={self.lora_request})") diff --git a/vllm/sequence.py b/vllm/sequence.py index ecfaee6e8c3d6..06170ab79d69a 100644 --- a/vllm/sequence.py +++ b/vllm/sequence.py @@ -5,6 +5,7 @@ from vllm.block import LogicalTokenBlock from vllm.sampling_params import SamplingParams +from vllm.lora.request import LoRARequest PromptLogprobs = List[Optional[Dict[int, float]]] SampleLogprobs = List[Dict[int, float]] @@ -105,6 +106,7 @@ class Sequence: prompt_token_ids: The token IDs of the prompt. block_size: The block size of the sequence. Should be the same as the block size used by the block manager and cache engine. + lora_request: LoRA request. """ def __init__( @@ -113,10 +115,12 @@ def __init__( prompt: str, prompt_token_ids: List[int], block_size: int, + lora_request: Optional[LoRARequest] = None, ) -> None: self.seq_id = seq_id self.prompt = prompt self.block_size = block_size + self.lora_request = lora_request self.data = SequenceData(prompt_token_ids) self.output_logprobs: SampleLogprobs = [] @@ -228,6 +232,7 @@ class SequenceGroup: seqs: The list of sequences. sampling_params: The sampling parameters used to generate the outputs. arrival_time: The arrival time of the request. + lora_request: LoRA request. """ def __init__( @@ -236,11 +241,13 @@ def __init__( seqs: List[Sequence], sampling_params: SamplingParams, arrival_time: float, + lora_request: Optional[LoRARequest] = None, ) -> None: self.request_id = request_id self.seqs_dict = {seq.seq_id: seq for seq in seqs} self.sampling_params = sampling_params self.arrival_time = arrival_time + self.lora_request = lora_request self.prompt_logprobs: Optional[PromptLogprobs] = None @property @@ -335,6 +342,7 @@ class SequenceGroupMetadata: sampling_params: The sampling parameters used to generate the outputs. block_tables: The block tables. (Seq id -> list of physical block numbers) + lora_request: LoRA request. """ def __init__( @@ -344,12 +352,18 @@ def __init__( seq_data: Dict[int, SequenceData], sampling_params: SamplingParams, block_tables: Dict[int, List[int]], + lora_request: Optional[LoRARequest] = None, ) -> None: self.request_id = request_id self.is_prompt = is_prompt self.seq_data = seq_data self.sampling_params = sampling_params self.block_tables = block_tables + self.lora_request = lora_request + + @property + def lora_int_id(self) -> int: + return self.lora_request.lora_int_id if self.lora_request else 0 class SequenceOutputs: diff --git a/vllm/transformers_utils/tokenizer.py b/vllm/transformers_utils/tokenizer.py index 5b0481480a63b..b84f50c3bd5d7 100644 --- a/vllm/transformers_utils/tokenizer.py +++ b/vllm/transformers_utils/tokenizer.py @@ -4,6 +4,8 @@ PreTrainedTokenizerFast) from vllm.logger import init_logger +from vllm.lora.request import LoRARequest +from vllm.utils import make_async, LRUCache logger = init_logger(__name__) @@ -69,6 +71,86 @@ def get_tokenizer( return tokenizer +def get_lora_tokenizer(lora_request: LoRARequest, *args, + **kwargs) -> Optional[PreTrainedTokenizer]: + if lora_request is None: + return None + try: + tokenizer = get_tokenizer(lora_request.lora_local_path, *args, + **kwargs) + except OSError as e: + # No tokenizer was found in the LoRA folder, + # use base model tokenizer + logger.warning( + f"No tokenizer found in {lora_request.lora_local_path}, " + "using base model tokenizer instead. " + f"(Exception: {str(e)})") + tokenizer = None + return tokenizer + + +get_lora_tokenizer_async = make_async(get_lora_tokenizer) + + +class MultiLoRATokenizer: + + def __init__(self, tokenizer_id: str, enable_lora: bool, max_num_seqs: int, + max_input_length: Optional[int], **tokenizer_config): + self.tokenizer_id = tokenizer_id + self.tokenizer_config = tokenizer_config + self.enable_lora = enable_lora + self.max_input_length = max_input_length + self.tokenizer = get_tokenizer(self.tokenizer_id, **tokenizer_config) + if enable_lora: + self.lora_tokenizers = LRUCache(capacity=max_num_seqs) + else: + self.lora_tokenizers = None + + def ping(self): + return True + + def encode(self, + prompt: str, + request_id: Optional[str] = None, + lora_request: Optional[LoRARequest] = None) -> List[int]: + tokenizer = self.get_lora_tokenizer(lora_request) + return tokenizer.encode(prompt) + + async def encode_async( + self, + prompt: str, + request_id: Optional[str] = None, + lora_request: Optional[LoRARequest] = None) -> List[int]: + tokenizer = await self.get_lora_tokenizer_async(lora_request) + return tokenizer.encode(prompt) + + def get_lora_tokenizer( + self, + lora_request: Optional[LoRARequest]) -> "PreTrainedTokenizer": + if not lora_request or not self.enable_lora: + return self.tokenizer + if lora_request.lora_int_id not in self.lora_tokenizers: + tokenizer = (get_lora_tokenizer( + lora_request, **self.tokenizer_config) or self.tokenizer) + self.lora_tokenizers.put(lora_request.lora_int_id, tokenizer) + return tokenizer + else: + return self.lora_tokenizers.get(lora_request.lora_int_id) + + async def get_lora_tokenizer_async( + self, + lora_request: Optional[LoRARequest]) -> "PreTrainedTokenizer": + if not lora_request or not self.enable_lora: + return self.tokenizer + if lora_request.lora_int_id not in self.lora_tokenizers: + tokenizer = (await get_lora_tokenizer_async( + lora_request, **self.tokenizer_config) or self.tokenizer) + self.lora_tokenizers.put(lora_request.lora_int_id, tokenizer) + return tokenizer + else: + return self.lora_tokenizers.get(lora_request.lora_int_id) + + def _convert_tokens_to_string_with_added_encoders( tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast], output_tokens: List[str], diff --git a/vllm/utils.py b/vllm/utils.py index 47e51048fed45..9282db842c1d2 100644 --- a/vllm/utils.py +++ b/vllm/utils.py @@ -4,9 +4,20 @@ import psutil import torch +import asyncio +from functools import partial +from typing import ( + Awaitable, + Callable, + TypeVar, +) +from collections import OrderedDict +from typing import Any, Hashable, Optional from vllm._C import cuda_utils +T = TypeVar("T") + class Device(enum.Enum): GPU = enum.auto() @@ -27,6 +38,69 @@ def reset(self) -> None: self.counter = 0 +class LRUCache: + + def __init__(self, capacity: int): + self.cache = OrderedDict() + self.capacity = capacity + + def __contains__(self, key: Hashable) -> bool: + return key in self.cache + + def __len__(self) -> int: + return len(self.cache) + + def __getitem__(self, key: Hashable) -> Any: + return self.get(key) + + def __setitem__(self, key: Hashable, value: Any) -> None: + self.put(key, value) + + def __delitem__(self, key: Hashable) -> None: + self.pop(key) + + def touch(self, key: Hashable) -> None: + self.cache.move_to_end(key) + + def get(self, key: Hashable, default_value: Optional[Any] = None) -> int: + if key in self.cache: + value = self.cache[key] + self.cache.move_to_end(key) + else: + value = default_value + return value + + def put(self, key: Hashable, value: Any) -> None: + self.cache[key] = value + self.cache.move_to_end(key) + self._remove_old_if_needed() + + def _on_remove(self, key: Hashable, value: Any): + pass + + def remove_oldest(self): + if not self.cache: + return + key, value = self.cache.popitem(last=False) + self._on_remove(key, value) + + def _remove_old_if_needed(self) -> None: + while len(self.cache) > self.capacity: + self.remove_oldest() + + def pop(self, key: int, default_value: Optional[Any] = None) -> Any: + run_on_remove = key in self.cache + value = self.cache.pop(key, default_value) + if run_on_remove: + self._on_remove(key, value) + return value + + def clear(self): + while len(self.cache) > 0: + self.remove_oldest() + self.cache.clear() + + def get_max_shared_memory_bytes(gpu: int = 0) -> int: """Returns the maximum shared memory per thread block in bytes.""" # https://docs.nvidia.com/cuda/cuda-runtime-api/group__CUDART__TYPES.html @@ -53,3 +127,19 @@ def random_uuid() -> str: def in_wsl() -> bool: # Reference: https://github.com/microsoft/WSL/issues/4071 return "microsoft" in " ".join(uname()).lower() + + +def make_async(func: Callable[..., T]) -> Callable[..., Awaitable[T]]: + """Take a blocking function, and run it on in an executor thread. + + This function prevents the blocking function from blocking the + asyncio event loop. + The code in this function needs to be thread safe. + """ + + def _async_wrapper(*args, **kwargs) -> asyncio.Future: + loop = asyncio.get_event_loop() + p_func = partial(func, *args, **kwargs) + return loop.run_in_executor(executor=None, func=p_func) + + return _async_wrapper diff --git a/vllm/worker/worker.py b/vllm/worker/worker.py index 702767ebd8d09..d316b9588bf75 100644 --- a/vllm/worker/worker.py +++ b/vllm/worker/worker.py @@ -1,12 +1,13 @@ """A GPU worker class.""" +import gc import os -from typing import Dict, List, Tuple, Optional +from typing import Dict, List, Tuple, Set, Optional import torch import torch.distributed from vllm.config import (CacheConfig, ModelConfig, ParallelConfig, - SchedulerConfig) + SchedulerConfig, LoRAConfig) from vllm.model_executor import get_model, InputMetadata, set_random_seed from vllm.model_executor.parallel_utils.parallel_state import ( initialize_model_parallel) @@ -14,6 +15,14 @@ from vllm.sequence import SamplerOutput, SequenceData, SequenceGroupMetadata from vllm.worker.cache_engine import CacheEngine from vllm.utils import get_gpu_memory +from vllm.lora.request import LoRARequest +from vllm.lora.worker_manager import ( + DisabledWorkerLoRAManager, + LRUCacheWorkerLoRAManager, +) +from vllm.lora.layers import LoRAMapping + +LORA_WARMUP_RANK = 8 class Worker: @@ -31,12 +40,14 @@ def __init__( scheduler_config: SchedulerConfig, rank: Optional[int] = None, distributed_init_method: Optional[str] = None, + lora_config: Optional[LoRAConfig] = None, ) -> None: self.model_config = model_config self.parallel_config = parallel_config self.scheduler_config = scheduler_config self.rank = rank self.distributed_init_method = distributed_init_method + self.lora_config = lora_config # Uninitialized cache engine. Will be initialized by # self.init_cache_engine(). @@ -46,6 +57,7 @@ def __init__( self.cache_engine = None self.cache_events = None self.gpu_cache = None + self.lora_manager = None def init_model(self): # This env var set by Ray causes exceptions with graph building. @@ -69,7 +81,21 @@ def init_model(self): set_random_seed(self.model_config.seed) def load_model(self): - self.model = get_model(self.model_config) + self.model = get_model(self.model_config, self.lora_config) + + vocab_size = self.model.config.vocab_size + + if self.lora_config: + self.lora_manager = LRUCacheWorkerLoRAManager( + self.scheduler_config.max_num_seqs, + self.scheduler_config.max_num_batched_tokens, vocab_size, + self.lora_config, self.device) + self.model = self.lora_manager.create_lora_adapter(self.model) + else: + self.lora_manager = DisabledWorkerLoRAManager( + self.scheduler_config.max_num_seqs, + self.scheduler_config.max_num_batched_tokens, vocab_size, + self.lora_config, self.device) @torch.inference_mode() def profile_num_available_blocks( @@ -91,6 +117,24 @@ def profile_num_available_blocks( sampling_params = SamplingParams(top_p=0.99, top_k=vocab_size - 1) max_num_batched_tokens = self.scheduler_config.max_num_batched_tokens max_num_seqs = self.scheduler_config.max_num_seqs + + # This represents the maximum number of different requests + # that will have unique loras, an therefore the max amount of memory + # consumption create dummy lora request copies from the lora request + # passed in, which contains a lora from the lora warmup path. + dummy_lora_requests = [] + if self.lora_config: + for idx in range(max_num_seqs): + lora_id = idx + 1 + dummy_lora_request = LoRARequest( + lora_id=f"warmup_{lora_id}", + lora_int_id=lora_id, + lora_local_path="/not/a/real/path", + ) + self.lora_manager.add_dummy_lora(dummy_lora_request, + rank=LORA_WARMUP_RANK) + dummy_lora_requests.append(dummy_lora_request) + seqs = [] for group_id in range(max_num_seqs): seq_len = (max_num_batched_tokens // max_num_seqs + @@ -102,11 +146,21 @@ def profile_num_available_blocks( seq_data={group_id: seq_data}, sampling_params=sampling_params, block_tables=None, + lora_request=dummy_lora_requests[group_id] + if dummy_lora_requests else None, ) seqs.append(seq) - input_tokens, input_positions, input_metadata = self._prepare_inputs( - seqs) + ( + input_tokens, + input_positions, + input_metadata, + lora_mapping, + prepared_lora_requests, + ) = self._prepare_inputs(seqs) + + if dummy_lora_requests: + self.apply_loras(prepared_lora_requests, lora_mapping) # Execute the model. num_layers = self.model_config.get_num_layers(self.parallel_config) @@ -131,6 +185,8 @@ def profile_num_available_blocks( num_cpu_blocks = int(cpu_swap_space // cache_block_size) num_gpu_blocks = max(num_gpu_blocks, 0) num_cpu_blocks = max(num_cpu_blocks, 0) + self.lora_manager.remove_all_loras() + gc.collect() torch.cuda.empty_cache() # Reset the seed to ensure that the random state is not affected by @@ -151,7 +207,8 @@ def init_cache_engine(self, cache_config: CacheConfig) -> None: def _prepare_inputs( self, seq_group_metadata_list: List[SequenceGroupMetadata], - ) -> Tuple[torch.Tensor, torch.Tensor, InputMetadata]: + ) -> Tuple[torch.Tensor, torch.Tensor, InputMetadata, LoRAMapping, + Set[LoRARequest]]: seq_groups: List[Tuple[List[int], SamplingParams]] = [] input_tokens: List[List[int]] = [] input_positions: List[List[int]] = [] @@ -160,6 +217,9 @@ def _prepare_inputs( selected_token_start_idx = 0 categorized_sample_indices = {t: [] for t in SamplingType} categorized_sample_indices_start_idx = 0 + lora_requests: Set[LoRARequest] = set() + lora_index_mapping: List[int] = [] + lora_prompt_mapping: List[int] = [] # Add prompt tokens. prompt_lens: List[int] = [] @@ -170,6 +230,7 @@ def _prepare_inputs( seq_ids = list(seq_group_metadata.seq_data.keys()) sampling_params = seq_group_metadata.sampling_params seq_groups.append((seq_ids, sampling_params)) + lora_id = seq_group_metadata.lora_int_id # Use any sequence in the group. seq_id = seq_ids[0] @@ -187,6 +248,17 @@ def _prepare_inputs( categorized_sample_indices_start_idx) categorized_sample_indices_start_idx += 1 + if lora_id > 0: + # if we are preparing inputs for the warmup step, we want the + # lora computation to take up the maximum possible amount of + # memory that way we can get a tighter upper bound on the + # amount of memory we can use and therefore not oom. If + # for_warmup is true, we add the lora lora mapping that is used + # during generation. + lora_requests.add(seq_group_metadata.lora_request) + lora_index_mapping.append([lora_id] * prompt_len) + lora_prompt_mapping.append(lora_id) + input_tokens.append(prompt_tokens) # NOTE(woosuk): Here we assume that the first token in the prompt # is always the first token in the sequence. @@ -233,6 +305,7 @@ def _prepare_inputs( seq_ids = list(seq_group_metadata.seq_data.keys()) sampling_params = seq_group_metadata.sampling_params seq_groups.append((seq_ids, sampling_params)) + lora_id = seq_group_metadata.lora_int_id num_seqs = len(seq_ids) selected_token_indices.extend( @@ -255,6 +328,7 @@ def _prepare_inputs( if self.sliding_window is not None: context_len = min(context_len, self.sliding_window) input_positions.append([position]) + lora_index_mapping.append([lora_id]) block_table = seq_group_metadata.block_tables[seq_id] @@ -274,6 +348,11 @@ def _prepare_inputs( block_table = block_table[-sliding_window_blocks:] generation_block_tables.append(block_table) + # Update LoRA mapping. + if lora_id > 0: + lora_requests.add(seq_group_metadata.lora_request) + lora_prompt_mapping.append(lora_id) + padded_input_tokens = [ _pad_to_max(tokens, max_seq_len, pad=0) for tokens in input_tokens ] @@ -281,6 +360,10 @@ def _prepare_inputs( _pad_to_max(positions, max_seq_len, pad=0) for positions in input_positions ] + padded_lora_input_mapping = [ + _pad_to_max(mapping, max_seq_len, pad=0) + for mapping in lora_index_mapping + ] padded_slot_mapping = [ _pad_to_max(mapping, max_seq_len, pad=-1) for mapping in slot_mapping @@ -318,6 +401,14 @@ def _prepare_inputs( for seq_group_metadata in seq_group_metadata_list: seq_data.update(seq_group_metadata.seq_data) + flat_padded_lora_input_mapping = [ + item for sublist in padded_lora_input_mapping for item in sublist + ] + lora_mapping = LoRAMapping( + flat_padded_lora_input_mapping, + lora_prompt_mapping, + ) + input_metadata = InputMetadata( seq_groups=seq_groups, seq_data=seq_data, @@ -330,7 +421,7 @@ def _prepare_inputs( categorized_sample_indices=categorized_sample_indices, sliding_window=self.sliding_window, ) - return tokens_tensor, positions_tensor, input_metadata + return tokens_tensor, positions_tensor, input_metadata, lora_mapping, lora_requests @torch.inference_mode() def execute_model( @@ -362,8 +453,20 @@ def execute_model( return {} # Prepare input tensors. - input_tokens, input_positions, input_metadata = self._prepare_inputs( - seq_group_metadata_list) + ( + input_tokens, + input_positions, + input_metadata, + lora_mapping, + lora_requests, + ) = self._prepare_inputs(seq_group_metadata_list) + + if self.lora_config: + lora_requests = [ + seq_group_metadata.lora_request + for seq_group_metadata in seq_group_metadata_list + ] + self.apply_loras(lora_requests, lora_mapping) # Execute the model. output = self.model( @@ -375,6 +478,19 @@ def execute_model( ) return output + def apply_loras(self, lora_requests: List[LoRARequest], + lora_mapping: LoRAMapping) -> None: + self.lora_manager.apply_loras(lora_requests, lora_mapping) + + def add_lora(self, lora_request: LoRARequest) -> bool: + return self.lora_manager.add_lora(lora_request) + + def remove_lora(self, lora_id: int) -> bool: + return self.lora_manager.remove_lora(lora_id) + + def list_loras(self) -> Set[int]: + return self.lora_manager.list_loras() + def _init_distributed_environment( parallel_config: ParallelConfig, From dd1726f61d1b356fafca130458ae869fb05665e4 Mon Sep 17 00:00:00 2001 From: Antoni Baum Date: Mon, 27 Nov 2023 17:15:04 -0800 Subject: [PATCH 24/94] Lint --- vllm/model_executor/model_loader.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/model_executor/model_loader.py b/vllm/model_executor/model_loader.py index bfe72bbf8c1e5..dc3299dc36f29 100644 --- a/vllm/model_executor/model_loader.py +++ b/vllm/model_executor/model_loader.py @@ -102,5 +102,5 @@ def get_model(model_config: ModelConfig, else: # Load the weights from the cached or downloaded files. model.load_weights(model_config.model, model_config.download_dir, - model_config.load_format, model_config.revision) + model_config.load_format, model_config.revision) return model.eval() From 9a1781c7ee4cbb56dce5e8415e4b0958bf0df3de Mon Sep 17 00:00:00 2001 From: TJian Date: Wed, 29 Nov 2023 21:40:31 +0800 Subject: [PATCH 25/94] Checked CUDA ROCm Compatibility (#15) Features * Auto-code path selection * support llama2 * support squeezellm rocm * add documentation amd-installation.rst. Describing how to setup vllm ROCm version * format.sh all the code * add base amd.Dockerfile --------- Co-authored-by: tjtanaa Co-authored-by: kuanfu --- csrc/quantization.cpp | 17 ++- .../getting_started/amd-installation.rst | 104 +++++++++++++ patch_xformers-0.0.22.post7.rocm.sh | 22 +++ pyproject.toml | 2 +- requirements-rocm.txt | 16 ++ rocm.Dockerfile | 64 ++++++++ .../commonpy_xformers-0.0.22.post7.rocm.patch | 13 ++ .../flashpy_xformers-0.0.22.post7.rocm.patch | 134 +++++++++++++++++ setup.py | 138 +++++++++++++----- vllm/engine/arg_utils.py | 98 ++++++++----- vllm/engine/ray_utils.py | 4 +- vllm/model_executor/layers/attention.py | 4 + .../layers/quantization/__init__.py | 8 +- .../model_executor/layers/quantization/awq.py | 6 +- .../layers/quantization/squeezellm.py | 17 ++- 15 files changed, 560 insertions(+), 87 deletions(-) create mode 100644 docs/source/getting_started/amd-installation.rst create mode 100644 patch_xformers-0.0.22.post7.rocm.sh create mode 100644 requirements-rocm.txt create mode 100644 rocm.Dockerfile create mode 100644 rocm_patch/commonpy_xformers-0.0.22.post7.rocm.patch create mode 100644 rocm_patch/flashpy_xformers-0.0.22.post7.rocm.patch diff --git a/csrc/quantization.cpp b/csrc/quantization.cpp index dfe17a496c780..764188597fd30 100644 --- a/csrc/quantization.cpp +++ b/csrc/quantization.cpp @@ -1,11 +1,14 @@ #include -torch::Tensor awq_gemm( - torch::Tensor _in_feats, - torch::Tensor _kernel, - torch::Tensor _scaling_factors, - torch::Tensor _zeros, - int split_k_iters); + +#ifndef USE_ROCM + torch::Tensor awq_gemm( + torch::Tensor _in_feats, + torch::Tensor _kernel, + torch::Tensor _scaling_factors, + torch::Tensor _zeros, + int split_k_iters); +#endif void squeezellm_gemm( torch::Tensor vec, @@ -14,6 +17,8 @@ void squeezellm_gemm( torch::Tensor lookup_table); PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { +#ifndef USE_ROCM m.def("awq_gemm", &awq_gemm, "Quantized GEMM for AWQ"); +#endif m.def("squeezellm_gemm", &squeezellm_gemm, "Quantized GEMM for SqueezeLLM"); } diff --git a/docs/source/getting_started/amd-installation.rst b/docs/source/getting_started/amd-installation.rst new file mode 100644 index 0000000000000..862b8a1d6e89b --- /dev/null +++ b/docs/source/getting_started/amd-installation.rst @@ -0,0 +1,104 @@ +.. _installation: + +Installation with ROCm +============ + +vLLM-ROCm is here! Currently it is supporting llama-2. + +Requirements +------------ + +* OS: Linux +* Python: 3.8 -- 3.11 (Recommended 3.10 as this is the version that has been tested on.) +* GPU: MI210 +* Pytorch 2.0.1/2.1.1 +* ROCm 5.7 + + +Install with pip +---------------- + +You can install vLLM using pip: + +.. code-block:: console + + $ # (Optional) Create a new conda environment. + $ conda create -n myenv python=3.8 -y + $ conda activate myenv + + $ # Install vLLM with CUDA 12.1. + $ pip install vllm + +.. note:: + + As of now, vLLM's binaries are compiled on CUDA 12.1 by default. + However, you can install vLLM with CUDA 11.8 by running: + + .. code-block:: console + + $ # Install vLLM with CUDA 11.8. + $ # Replace `cp310` with your Python version (e.g., `cp38`, `cp39`, `cp311`). + $ pip install https://github.com/vllm-project/vllm/releases/download/v0.2.2/vllm-0.2.2+cu118-cp310-cp310-manylinux1_x86_64.whl + + $ # Re-install PyTorch with CUDA 11.8. + $ pip uninstall torch -y + $ pip install torch --upgrade --index-url https://download.pytorch.org/whl/cu118 + + +.. _build_from_source: + +Build from source with docker +----------------- + +You can also build and install vLLM from source: + +Build a docker image from `rocm.Dockerfile`, and launch a docker container. + +.. code-block:: console + + $ docker build -f rocm.Dockerfile -t vllm-rocm . + $ docker run -it \ + --network=host \ + --group-add=video \ + --ipc=host \ + --cap-add=SYS_PTRACE \ + --security-opt seccomp=unconfined \ + --shm-size 8G \ + --device /dev/kfd \ + --device /dev/dri \ + -v :/app/hf_model \ + vllm-rocm \ + bash + +If you are going to setup on new pytorch+rocm5.7 docker container, you can follow the following steps. + +1. Install flash-attention-2-rocm + + If you are using Pytorch-2.0.1+rocm5.7. + + Install flash-attention-2 (v2.0.4) following the instruction from [ROCmSoftwarePlatform/flash-attention](https://github.com/ROCmSoftwarePlatform/flash-attention/tree/flash_attention_for_rocm) + + + If you are using Pytorch-2.1.x+rocm5.7 or Pytorch-2.2.x+rocm5.7, you don't need to apply the `hipify_python.patch`. + You can directly build the flash-attention-2. + + .. code-block:: console + + $ bash patch_torch211_flash_attn2.rocm.sh + + .. note:: + - Flash-attention-2 (v2.0.4) does not support sliding windows attention. + - You might need to downgrade the "ninja" version to 1.10 it is not used when compiling flash-attention-2 (e.g. `pip install ninja==1.10.2.4`) + +2. Setup xformers==0.0.22.post7 without dependencies, and apply patches + + .. code-block:: console + + $ pip install xformers==0.0.22.post7 --no-deps + $ bash patch_xformers-0.0.22.post7.rocm.sh + +3. Build vllm. + + .. code-block:: console + $ cd vllm + $ python setup.py install # This may take 5-10 minutes. diff --git a/patch_xformers-0.0.22.post7.rocm.sh b/patch_xformers-0.0.22.post7.rocm.sh new file mode 100644 index 0000000000000..c8e58f721ae8c --- /dev/null +++ b/patch_xformers-0.0.22.post7.rocm.sh @@ -0,0 +1,22 @@ +#!/bin/bash +export XFORMERS_FMHA_FLASH_PATH=$(python -c 'from xformers import ops as xops; print(xops.fmha.flash.__file__)') +export XFORMERS_FMHA_COMMON_PATH=$(python -c 'from xformers import ops as xops; print(xops.fmha.common.__file__)') + +echo $XFORMERS_FMHA_FLASH_PATH +echo $XFORMERS_FMHA_COMMON_PATH + +if ! patch -R -p0 -s -f --dry-run $XFORMERS_FMHA_FLASH_PATH "./rocm_patch/flashpy_xformers-0.0.22.post7.rocm.patch"; then + echo "Applying patch to ${XFORMERS_FMHA_FLASH_PATH}" + patch -p0 $XFORMERS_FMHA_FLASH_PATH "./rocm_patch/flashpy_xformers-0.0.22.post7.rocm.patch" + echo "Successfully patch ${XFORMERS_FMHA_FLASH_PATH}" +else + echo "${XFORMERS_FMHA_FLASH_PATH} was patched before" +fi + +if ! patch -R -p0 -s -f --dry-run $XFORMERS_FMHA_COMMON_PATH "./rocm_patch/commonpy_xformers-0.0.22.post7.rocm.patch"; then + echo "Applying patch to ${XFORMERS_FMHA_COMMON_PATH}" + patch -p0 $XFORMERS_FMHA_COMMON_PATH "./rocm_patch/commonpy_xformers-0.0.22.post7.rocm.patch" + echo "Successfully patch ${XFORMERS_FMHA_COMMON_PATH}" +else + echo "${XFORMERS_FMHA_COMMON_PATH} was patched before" +fi \ No newline at end of file diff --git a/pyproject.toml b/pyproject.toml index e3e3e389f7897..f9390ed8c52a9 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -3,7 +3,7 @@ requires = [ "ninja", "packaging", "setuptools", - "torch >= 2.1.0", + # "torch >= 2.1.0", # commented out to accommodate ROCm "wheel", ] build-backend = "setuptools.build_meta" diff --git a/requirements-rocm.txt b/requirements-rocm.txt new file mode 100644 index 0000000000000..53fd3ea24d927 --- /dev/null +++ b/requirements-rocm.txt @@ -0,0 +1,16 @@ +ninja # For faster builds. +typing-extensions>=4.8.0 +starlette +psutil +ray >= 2.5.1 +pandas # Required for Ray data. +pyarrow # Required for Ray data. +sentencepiece # Required for LLaMA tokenizer. +numpy +tokenizers>=0.15.0 +huggingface_hub<0.18,>=0.16.4 +einops # Required for phi-1_5 +transformers >= 4.34.0 # Required for Mistral. +fastapi +uvicorn[standard] +pydantic == 1.10.13 # Required for OpenAI server. diff --git a/rocm.Dockerfile b/rocm.Dockerfile new file mode 100644 index 0000000000000..27b57097740d2 --- /dev/null +++ b/rocm.Dockerfile @@ -0,0 +1,64 @@ +FROM rocm/pytorch:rocm5.7_ubuntu22.04_py3.10_pytorch_2.0.1 + +# Install some basic utilities +RUN apt-get update && apt-get install python3 python3-pip -y + +# Install some basic utilities +RUN apt-get update && apt-get install -y \ + curl \ + ca-certificates \ + sudo \ + git \ + bzip2 \ + libx11-6 \ + build-essential \ + wget \ + unzip \ + nvidia-cuda-toolkit \ + tmux \ + && rm -rf /var/lib/apt/lists/* + +### Mount Point ### +# When launching the container, mount the code directory to /app +ARG APP_MOUNT=/app +VOLUME [ ${APP_MOUNT} ] +WORKDIR ${APP_MOUNT} + +RUN python3 -m pip install --upgrade pip +RUN python3 -m pip install --no-cache-dir fastapi ninja tokenizers + +ENV LLVM_SYMBOLIZER_PATH=/opt/rocm/llvm/bin/llvm-symbolizer +ENV PATH=$PATH:/opt/rocm/bin:/libtorch/bin: +ENV LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/opt/rocm/lib/:/libtorch/lib: +ENV CPLUS_INCLUDE_PATH=$CPLUS_INCLUDE_PATH:/libtorch/include:/libtorch/include/torch/csrc/api/include/:/opt/rocm/include/: +ENV PYTORCH_ROCM_ARCH=gfx900;gfx906;gfx908;gfx90a;gfx1030;gfx1101 + +# Install ROCm flash-attention +RUN mkdir libs \ + && cd libs \ + && git clone https://github.com/ROCmSoftwarePlatform/flash-attention.git \ + && cd flash-attention \ + && git submodule update --init \ + && sed -i -e "s/--offload-arch=native/--offload-arch=$(/opt/rocm/llvm/bin/amdgpu-offload-arch)/g" setup.py \ + && patch /opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/utils/hipify/hipify_python.py hipify_patch.patch \ + && python3 setup.py install \ + && cd .. + +COPY ./ /app/vllm-rocm/ + +# RUN cd /app \ +# && cd vllm-rocm \ +# && git checkout v0.2.1.post1-rocm \ +# && python3 setup.py install \ +# && cd .. + +# RUN cd /app \ +# && mkdir dataset \ +# && cd .. + +# COPY ./benchmark_throughput.sh /app/benchmark_throughput.sh + +RUN python3 -m pip install --upgrade pip +# RUN python3 -m pip install --no-cache-dir ray[all] + +CMD ["/bin/bash"] \ No newline at end of file diff --git a/rocm_patch/commonpy_xformers-0.0.22.post7.rocm.patch b/rocm_patch/commonpy_xformers-0.0.22.post7.rocm.patch new file mode 100644 index 0000000000000..4d7495cf13e1d --- /dev/null +++ b/rocm_patch/commonpy_xformers-0.0.22.post7.rocm.patch @@ -0,0 +1,13 @@ +--- /opt/conda/envs/py_3.10/lib/python3.10/site-packages/xformers/ops/fmha/common.py 2023-11-29 03:17:03.930103539 +0000 ++++ common.py 2023-11-28 16:14:19.846233146 +0000 +@@ -298,8 +298,8 @@ + dtype = d.query.dtype + if device_type not in cls.SUPPORTED_DEVICES: + reasons.append(f"device={device_type} (supported: {cls.SUPPORTED_DEVICES})") +- if device_type == "cuda" and not _built_with_cuda: +- reasons.append("xFormers wasn't build with CUDA support") ++ #if device_type == "cuda" and not _built_with_cuda: ++ # reasons.append("xFormers wasn't build with CUDA support") + if device_type == "cuda": + device_capability = torch.cuda.get_device_capability(d.device) + if device_capability < cls.CUDA_MINIMUM_COMPUTE_CAPABILITY: diff --git a/rocm_patch/flashpy_xformers-0.0.22.post7.rocm.patch b/rocm_patch/flashpy_xformers-0.0.22.post7.rocm.patch new file mode 100644 index 0000000000000..4798f1efd461c --- /dev/null +++ b/rocm_patch/flashpy_xformers-0.0.22.post7.rocm.patch @@ -0,0 +1,134 @@ +--- /opt/conda/envs/py_3.10/lib/python3.10/site-packages/xformers/ops/fmha/flash.py 2023-11-29 03:17:03.930103539 +0000 ++++ flash.py 2023-11-28 16:14:25.206128903 +0000 +@@ -31,39 +31,39 @@ + + FLASH_VERSION = "0.0.0" + try: +- try: +- from ... import _C_flashattention # type: ignore[attr-defined] +- from ..._cpp_lib import _build_metadata +- +- if _build_metadata is not None: +- FLASH_VERSION = _build_metadata.flash_version +- except ImportError: +- import flash_attn +- from flash_attn.flash_attn_interface import flash_attn_cuda as _C_flashattention +- +- FLASH_VERSION = flash_attn.__version__ +- flash_ver_parsed = tuple(int(s) for s in FLASH_VERSION.split(".")[:2]) +- if flash_ver_parsed < (2, 3): +- raise ImportError("Requires 2.3 for sliding window support") ++ #try: ++ # from ... import _C_flashattention # type: ignore[attr-defined] ++ # from ..._cpp_lib import _build_metadata ++ ++ # if _build_metadata is not None: ++ # FLASH_VERSION = _build_metadata.flash_version ++ #except ImportError: ++ import flash_attn ++ from flash_attn.flash_attn_interface import flash_attn_cuda as _C_flashattention ++ ++ FLASH_VERSION = flash_attn.__version__ ++ # flash_ver_parsed = tuple(int(s) for s in FLASH_VERSION.split(".")[:2]) ++ # if flash_ver_parsed < (2, 3): ++ # raise ImportError("Requires 2.3 for sliding window support") + + # create library so that flash-attn goes through the PyTorch Dispatcher +- _flash_lib = torch.library.Library("xformers_flash", "DEF") ++ #_flash_lib = torch.library.Library("xformers_flash", "DEF") + +- _flash_lib.define( +- "flash_fwd(Tensor query, Tensor key, Tensor value, " +- "Tensor? cu_seqlens_q, Tensor? cu_seqlens_k, " +- "int max_seqlen_q, int max_seqlen_k, " +- "float p, float softmax_scale, " +- "bool is_causal, int window_size, bool return_softmax) -> (Tensor, Tensor, Tensor)" +- ) +- +- _flash_lib.define( +- "flash_bwd(Tensor dout, Tensor query, Tensor key, Tensor value, " +- "Tensor out, Tensor softmax_lse_, Tensor dq, Tensor dk, Tensor dv, " +- "Tensor cu_seqlens_q, Tensor cu_seqlens_k, " +- "int max_seqlen_q, int max_seqlen_k, " +- "float p, float softmax_scale, bool is_causal, int window_size, Tensor rng_state) -> (Tensor, Tensor, Tensor)" +- ) ++ #_flash_lib.define( ++ # "flash_fwd(Tensor query, Tensor key, Tensor value, " ++ # "Tensor? cu_seqlens_q, Tensor? cu_seqlens_k, " ++ # "int max_seqlen_q, int max_seqlen_k, " ++ # "float p, float softmax_scale, " ++ # "bool is_causal, int window_size, bool return_softmax) -> (Tensor, Tensor, Tensor)" ++ #) ++ ++ #_flash_lib.define( ++ # "flash_bwd(Tensor dout, Tensor query, Tensor key, Tensor value, " ++ # "Tensor out, Tensor softmax_lse_, Tensor dq, Tensor dk, Tensor dv, " ++ # "Tensor cu_seqlens_q, Tensor cu_seqlens_k, " ++ # "int max_seqlen_q, int max_seqlen_k, " ++ # "float p, float softmax_scale, bool is_causal, int window_size, Tensor rng_state) -> (Tensor, Tensor, Tensor)" ++ #) + + def _flash_fwd( + query, +@@ -98,8 +98,8 @@ + p, + softmax_scale, + is_causal, +- window_size - 1, # window_size_left +- -1, # window_size_right ++ # window_size - 1, # window_size_left ++ # -1, # window_size_right + return_softmax, + None, # rng + ) +@@ -127,8 +127,8 @@ + softmax_scale, + False, + is_causal, +- window_size - 1, # window_size_left +- -1, # window_size_right ++ # window_size - 1, # window_size_left ++ # -1, # window_size_right + return_softmax, + None, + ) +@@ -169,8 +169,8 @@ + p, + softmax_scale, + is_causal, +- window_size - 1, # window_size_left +- -1, # window_size_right ++ # window_size - 1, # window_size_left ++ # -1, # window_size_right + None, + rng_state, + ) +@@ -193,15 +193,15 @@ + softmax_scale, + False, # zero_tensors + is_causal, +- window_size - 1, # window_size_left +- -1, # window_size_right ++ # window_size - 1, # window_size_left ++ # -1, # window_size_right + None, + rng_state, + ) + return dq, dk, dv + +- _flash_lib.impl("flash_fwd", _flash_fwd, "CUDA") +- _flash_lib.impl("flash_bwd", _flash_bwd, "CUDA") ++ #_flash_lib.impl("flash_fwd", _flash_fwd, "CUDA") ++ #_flash_lib.impl("flash_bwd", _flash_bwd, "CUDA") + except ImportError: + pass + +@@ -348,7 +348,7 @@ + implementation. + """ + +- OPERATOR = get_operator("xformers_flash", "flash_fwd") ++ OPERATOR = _flash_fwd # get_operator("xformers_flash", "flash_fwd") + SUPPORTED_DEVICES: Set[str] = {"cuda"} + CUDA_MINIMUM_COMPUTE_CAPABILITY = (8, 0) + SUPPORTED_DTYPES: Set[torch.dtype] = {torch.half, torch.bfloat16} diff --git a/setup.py b/setup.py index 0e28b93602772..b75d0912a9bb6 100644 --- a/setup.py +++ b/setup.py @@ -16,30 +16,64 @@ # Supported NVIDIA GPU architectures. NVIDIA_SUPPORTED_ARCHS = {"7.0", "7.5", "8.0", "8.6", "8.9", "9.0"} -ROCM_SUPPORTED_ARCHS = {"gfx90a", "gfx908", "gfx906", "gfx1030","gfx1100"} -SUPPORTED_ARCHS = NVIDIA_SUPPORTED_ARCHS.union(ROCM_SUPPORTED_ARCHS) +ROCM_SUPPORTED_ARCHS = {"gfx90a", "gfx908", "gfx906", "gfx1030", "gfx1100"} +# SUPPORTED_ARCHS = NVIDIA_SUPPORTED_ARCHS.union(ROCM_SUPPORTED_ARCHS) # Compiler flags. CXX_FLAGS = ["-g", "-O2", "-std=c++17"] # TODO(woosuk): Should we use -O3? NVCC_FLAGS = ["-O2", "-std=c++17"] -if torch.version.hip: - if ROCM_HOME is not None: - NVCC_FLAGS += [f"-DUSE_ROCM"] +if torch.cuda.is_available() and torch.version.hip and ROCM_HOME is not None: + NVCC_FLAGS += ["-DUSE_ROCM"] -if not torch.version.hip: - if CUDA_HOME is None: - raise RuntimeError( - "Cannot find CUDA_HOME. CUDA must be available to build the package.") +if torch.cuda.is_available() and torch.version.cuda and CUDA_HOME is None: + raise RuntimeError( + "Cannot find CUDA_HOME. CUDA must be available to build the package.") ABI = 1 if torch._C._GLIBCXX_USE_CXX11_ABI else 0 CXX_FLAGS += [f"-D_GLIBCXX_USE_CXX11_ABI={ABI}"] NVCC_FLAGS += [f"-D_GLIBCXX_USE_CXX11_ABI={ABI}"] -if CUDA_HOME is None: - raise RuntimeError( - "Cannot find CUDA_HOME. CUDA must be available to build the package.") + +def get_amdgpu_offload_arch(): + error_message = "" + command = "/opt/rocm/llvm/bin/amdgpu-offload-arch" + try: + output = subprocess.check_output([command]) + return output.decode('utf-8').strip() + except subprocess.CalledProcessError as e: + error_message = f"Error: {e}" + except FileNotFoundError: + # If the command is not found, print an error message + error_message = f"The command {command} was not found." + + if error_message: + raise RuntimeError(error_message) + + return None + + +def get_hipcc_rocm_version(): + # Run the hipcc --version command + result = subprocess.run(['hipcc', '--version'], + stdout=subprocess.PIPE, + stderr=subprocess.STDOUT, + text=True) + + # Check if the command was executed successfully + if result.returncode != 0: + print("Error running 'hipcc --version'") + return None + + # Extract the version using a regular expression + match = re.search(r'HIP version: (\S+)', result.stdout) + if match: + # Return the version string + return match.group(1) + else: + print("Could not find HIP version in the output") + return None def get_nvcc_cuda_version(cuda_dir: str) -> Version: @@ -72,7 +106,9 @@ def get_torch_arch_list() -> Set[str]: return set() # Filter out the invalid architectures and print a warning. - valid_archs = SUPPORTED_ARCHS.union({s + "+PTX" for s in NVIDIA_SUPPORTED_ARCHS}) + valid_archs = NVIDIA_SUPPORTED_ARCHS.union( + {s + "+PTX" + for s in NVIDIA_SUPPORTED_ARCHS}) arch_list = torch_arch_list.intersection(valid_archs) # If none of the specified architectures are valid, raise an error. if not arch_list: @@ -93,24 +129,24 @@ def get_torch_arch_list() -> Set[str]: # First, check the TORCH_CUDA_ARCH_LIST environment variable. compute_capabilities = get_torch_arch_list() -if not torch.version.hip: - if not compute_capabilities: - # If TORCH_CUDA_ARCH_LIST is not defined or empty, target all available - # GPUs on the current machine. - device_count = torch.cuda.device_count() - for i in range(device_count): - major, minor = torch.cuda.get_device_capability(i) - if major < 7: - raise RuntimeError( - "GPUs with compute capability below 7.0 are not supported.") - compute_capabilities.add(f"{major}.{minor}") - -if not torch.version.hip: +if torch.cuda.is_available( +) and torch.version.cuda and not compute_capabilities: + # If TORCH_CUDA_ARCH_LIST is not defined or empty, target all available + # GPUs on the current machine. + device_count = torch.cuda.device_count() + for i in range(device_count): + major, minor = torch.cuda.get_device_capability(i) + if major < 7: + raise RuntimeError( + "GPUs with compute capability below 7.0 are not supported.") + compute_capabilities.add(f"{major}.{minor}") + +if torch.cuda.is_available() and torch.version.cuda: nvcc_cuda_version = get_nvcc_cuda_version(CUDA_HOME) if not compute_capabilities: # If no GPU is specified nor available, add all supported architectures # based on the NVCC CUDA version. - compute_capabilities = SUPPORTED_ARCHS.copy() + compute_capabilities = NVIDIA_SUPPORTED_ARCHS.copy() if nvcc_cuda_version < Version("11.1"): compute_capabilities.remove("8.6") if nvcc_cuda_version < Version("11.8"): @@ -118,7 +154,8 @@ def get_torch_arch_list() -> Set[str]: compute_capabilities.remove("9.0") # Validate the NVCC CUDA version. if nvcc_cuda_version < Version("11.0"): - raise RuntimeError("CUDA 11.0 or higher is required to build the package.") + raise RuntimeError( + "CUDA 11.0 or higher is required to build the package.") if (nvcc_cuda_version < Version("11.1") and any(cc.startswith("8.6") for cc in compute_capabilities)): raise RuntimeError( @@ -135,7 +172,7 @@ def get_torch_arch_list() -> Set[str]: "Targeting compute capability 8.0 instead.", stacklevel=2) compute_capabilities = set(cc for cc in compute_capabilities - if not cc.startswith("8.9")) + if not cc.startswith("8.9")) compute_capabilities.add("8.0+PTX") if any(cc.startswith("9.0") for cc in compute_capabilities): raise RuntimeError( @@ -146,13 +183,22 @@ def get_torch_arch_list() -> Set[str]: num = capability[0] + capability[2] NVCC_FLAGS += ["-gencode", f"arch=compute_{num},code=sm_{num}"] if capability.endswith("+PTX"): - NVCC_FLAGS += ["-gencode", f"arch=compute_{num},code=compute_{num}"] + NVCC_FLAGS += [ + "-gencode", f"arch=compute_{num},code=compute_{num}" + ] # Use NVCC threads to parallelize the build. if nvcc_cuda_version >= Version("11.2"): num_threads = min(os.cpu_count(), 8) NVCC_FLAGS += ["--threads", str(num_threads)] +elif torch.cuda.is_available() and torch.version.hip: + amd_arch = get_amdgpu_offload_arch() + if amd_arch not in ROCM_SUPPORTED_ARCHS: + raise RuntimeError( + f"Only the following arch is supported: {ROCM_SUPPORTED_ARCHS}" + f"amdgpu_arch_found: {amd_arch}") + ext_modules = [] # Cache operations. @@ -211,7 +257,7 @@ def get_torch_arch_list() -> Set[str]: ext_modules.append(activation_extension) # Quantization kernels. -if not torch.version.hip: +if torch.cuda.is_available() and torch.version.cuda: quantization_extension = CUDAExtension( name="vllm.quantization_ops", sources=[ @@ -224,7 +270,7 @@ def get_torch_arch_list() -> Set[str]: "nvcc": NVCC_FLAGS, }, ) -else: +elif torch.cuda.is_available() and torch.version.hip: quantization_extension = CUDAExtension( name="vllm.quantization_ops", sources=[ @@ -269,10 +315,20 @@ def find_version(filepath: str) -> str: def get_vllm_version() -> str: version = find_version(get_path("vllm", "__init__.py")) - # cuda_version = str(nvcc_cuda_version) - # if cuda_version != MAIN_CUDA_VERSION: - # cuda_version_str = cuda_version.replace(".", "")[:3] - # version += f"+cu{cuda_version_str}" + + if torch.cuda.is_available() and torch.version.cuda: + cuda_version = str(nvcc_cuda_version) + if cuda_version != MAIN_CUDA_VERSION: + cuda_version_str = cuda_version.replace(".", "")[:3] + version += f"+cu{cuda_version_str}" + + elif torch.cuda.is_available() and torch.version.hip: + # Get the HIP version + hipcc_version = get_hipcc_rocm_version() + if hipcc_version != MAIN_CUDA_VERSION: + rocm_version_str = hipcc_version.replace(".", "")[:3] + version += f"+rocm{rocm_version_str}" + return version @@ -287,8 +343,14 @@ def read_readme() -> str: def get_requirements() -> List[str]: """Get Python package dependencies from requirements.txt.""" - with open(get_path("requirements.txt")) as f: - requirements = f.read().strip().split("\n") + if torch.cuda.is_available() and torch.version.hip: + with open(get_path("requirements-rocm.txt")) as f: + requirements = f.read().strip().split("\n") + elif torch.cuda.is_available() and torch.version.cuda: + with open(get_path("requirements.txt")) as f: + requirements = f.read().strip().split("\n") + print("requirements: ", requirements) + # exit() return requirements diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index c7e476c704740..c7612b3ac4072 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -3,6 +3,8 @@ from dataclasses import dataclass from typing import Optional, Tuple +import torch + from vllm.config import (CacheConfig, ModelConfig, ParallelConfig, SchedulerConfig) @@ -83,32 +85,52 @@ def add_cli_args( help='directory to download and load the weights, ' 'default to the default cache dir of ' 'huggingface') - parser.add_argument( - '--load-format', - type=str, - default=EngineArgs.load_format, - choices=['auto', 'pt', 'safetensors', 'npcache', 'dummy'], - help='The format of the model weights to load. ' - '"auto" will try to load the weights in the safetensors format ' - 'and fall back to the pytorch bin format if safetensors format ' - 'is not available. ' - '"pt" will load the weights in the pytorch bin format. ' - '"safetensors" will load the weights in the safetensors format. ' - '"npcache" will load the weights in pytorch format and store ' - 'a numpy cache to speed up the loading. ' - '"dummy" will initialize the weights with random values, ' - 'which is mainly for profiling.') - parser.add_argument( - '--dtype', - type=str, - default=EngineArgs.dtype, - choices=[ - 'auto', 'half', 'float16', 'bfloat16', 'float', 'float32' - ], - help='data type for model weights and activations. ' - 'The "auto" option will use FP16 precision ' - 'for FP32 and FP16 models, and BF16 precision ' - 'for BF16 models.') + if torch.cuda.is_available() and torch.version.hip: + # do something specific for HIP + parser.add_argument( + '--load-format', + type=str, + default='pt', + choices=['pt'], + help='The format of the model weights to load. ' + '"pt" will load the weights in the pytorch bin format. ') + parser.add_argument( + '--dtype', + type=str, + default='half', + choices=['half', 'float16', 'bfloat16'], + help='data type for model weights and activations. ' + 'The default option is FP16 precision ' + 'Supports FP16 and BF16 ') + elif torch.cuda.is_available() and torch.version.cuda: + # do something specific for CUDA + parser.add_argument( + '--load-format', + type=str, + default=EngineArgs.load_format, + choices=['auto', 'pt', 'safetensors', 'npcache', 'dummy'], + help='The format of the model weights to load. ' + '"auto" will try to load the weights in the safetensors format ' + 'and fall back to the pytorch bin format if safetensors format ' + 'is not available. ' + '"pt" will load the weights in the pytorch bin format. ' + '"safetensors" will load the weights in the safetensors format. ' + '"npcache" will load the weights in pytorch format and store ' + 'a numpy cache to speed up the loading. ' + '"dummy" will initialize the weights with random values, ' + 'which is mainly for profiling.') + parser.add_argument( + '--dtype', + type=str, + default=EngineArgs.dtype, + choices=[ + 'auto', 'half', 'float16', 'bfloat16', 'float', 'float32' + ], + help='data type for model weights and activations. ' + 'The "auto" option will use FP16 precision ' + 'for FP32 and FP16 models, and BF16 precision ' + 'for BF16 models.') + parser.add_argument('--max-model-len', type=int, default=None, @@ -171,13 +193,23 @@ def add_cli_args( parser.add_argument('--disable-log-stats', action='store_true', help='disable logging statistics') - # Quantization settings. - parser.add_argument('--quantization', - '-q', - type=str, - choices=['awq', 'squeezellm', None], - default=None, - help='Method used to quantize the weights') + if torch.cuda.is_available() and torch.version.hip: + # Quantization settings. + parser.add_argument('--quantization', + '-q', + type=str, + choices=['squeezellm', None], + default=None, + help='Method used to quantize the weights') + + elif torch.cuda.is_available() and torch.version.cuda: + # Quantization settings. + parser.add_argument('--quantization', + '-q', + type=str, + choices=['awq', 'squeezellm', None], + default=None, + help='Method used to quantize the weights') return parser @classmethod diff --git a/vllm/engine/ray_utils.py b/vllm/engine/ray_utils.py index ee58b8b9074a7..6bff8153e2a5c 100644 --- a/vllm/engine/ray_utils.py +++ b/vllm/engine/ray_utils.py @@ -73,7 +73,9 @@ def initialize_cluster( "Ray is not installed. Please install Ray to use distributed " "serving.") # Connect to a ray cluster. - ray.init(address=ray_address, ignore_reinit_error=True) + ray.init(address=ray_address, + ignore_reinit_error=True, + num_gpus=parallel_config.world_size) if not parallel_config.worker_use_ray: # Initialize cluster locally. diff --git a/vllm/model_executor/layers/attention.py b/vllm/model_executor/layers/attention.py index e51bb311decd9..2e042721d9a27 100644 --- a/vllm/model_executor/layers/attention.py +++ b/vllm/model_executor/layers/attention.py @@ -113,6 +113,8 @@ def multi_query_kv_attention( attn_bias=input_metadata.attn_bias, p=0.0, scale=self.scale, + op=xops.fmha.MemoryEfficientAttentionFlashAttentionOp[0] if + (torch.cuda.is_available() and torch.version.hip) else None, ) # TODO(woosuk): Unnecessary copy. Optimize. output.copy_(out.view_as(output)) @@ -451,6 +453,8 @@ def multi_query_kv_attention( attn_bias=input_metadata.attn_bias, p=0.0, scale=self.scale, + op=xops.fmha.MemoryEfficientAttentionFlashAttentionOp[0] if + (torch.cuda.is_available() and torch.version.hip) else None, ) # TODO(woosuk): Unnecessary copy. Optimize. output.copy_(out.view_as(output)) diff --git a/vllm/model_executor/layers/quantization/__init__.py b/vllm/model_executor/layers/quantization/__init__.py index 3d937ba64f9fa..f4d25566cf599 100644 --- a/vllm/model_executor/layers/quantization/__init__.py +++ b/vllm/model_executor/layers/quantization/__init__.py @@ -1,14 +1,16 @@ from typing import Type - -from vllm.model_executor.layers.quantization.awq import AWQConfig +import torch from vllm.model_executor.layers.quantization.squeezellm import SqueezeLLMConfig from vllm.model_executor.layers.quantization.base_config import QuantizationConfig _QUANTIZATION_CONFIG_REGISTRY = { - "awq": AWQConfig, "squeezellm": SqueezeLLMConfig, } +if torch.cuda.is_available() and torch.version.cuda: + from vllm.model_executor.layers.quantization.awq import AWQConfig + _QUANTIZATION_CONFIG_REGISTRY["awq"] = AWQConfig + def get_quantization_config(quantization: str) -> Type[QuantizationConfig]: if quantization not in _QUANTIZATION_CONFIG_REGISTRY: diff --git a/vllm/model_executor/layers/quantization/awq.py b/vllm/model_executor/layers/quantization/awq.py index 0ab5819d930aa..5561a9309fc32 100644 --- a/vllm/model_executor/layers/quantization/awq.py +++ b/vllm/model_executor/layers/quantization/awq.py @@ -2,8 +2,12 @@ import torch from torch.nn.parameter import Parameter +if torch.cuda.is_available() and torch.version.hip: + # do something specific for HIP + print("Warning: vLLM does not support AWQ on ROCm.") +elif torch.cuda.is_available() and torch.version.cuda: + from vllm import quantization_ops -from vllm import quantization_ops from vllm.model_executor.layers.linear import (LinearMethodBase, set_weight_attrs) from vllm.model_executor.layers.quantization.base_config import QuantizationConfig diff --git a/vllm/model_executor/layers/quantization/squeezellm.py b/vllm/model_executor/layers/quantization/squeezellm.py index 61ec8b79b6ddc..be318b2ef205b 100644 --- a/vllm/model_executor/layers/quantization/squeezellm.py +++ b/vllm/model_executor/layers/quantization/squeezellm.py @@ -114,10 +114,19 @@ def apply_weights(self, lookup_table = weights["lookup_table"] out_shape = x.shape[:-1] + (qweight.shape[-1], ) reshaped_x = x.reshape(-1, x.shape[-1]) - # NOTE: The output tensor should be zero-initialized. - out = torch.zeros(out_shape, device="cuda", dtype=torch.float16) - quantization_ops.squeezellm_gemm(reshaped_x, qweight, out, - lookup_table) + if torch.cuda.is_available() and torch.version.hip: + out_float = torch.zeros(out_shape, + device="cuda", + dtype=torch.float) + quantization_ops.squeezellm_gemm(reshaped_x, qweight, out_float, + lookup_table) + out = out_float.to(dtype=torch.float16) + # do something specific for HIP + elif torch.cuda.is_available() and torch.version.cuda: + # NOTE: The output tensor should be zero-initialized. + out = torch.zeros(out_shape, device="cuda", dtype=torch.float16) + quantization_ops.squeezellm_gemm(reshaped_x, qweight, out, + lookup_table) if bias is not None: out = out + bias From 7dbf2d409e6dec885c54607f6ded41408d84f9d5 Mon Sep 17 00:00:00 2001 From: miloice Date: Wed, 29 Nov 2023 16:10:48 +0000 Subject: [PATCH 26/94] format code --- setup.py | 5 ++--- vllm/model_executor/layers/quantization/squeezellm.py | 6 ++---- 2 files changed, 4 insertions(+), 7 deletions(-) diff --git a/setup.py b/setup.py index b8986b670f4ff..8f0b357e0b282 100644 --- a/setup.py +++ b/setup.py @@ -201,7 +201,7 @@ def get_torch_arch_list() -> Set[str]: ext_modules = [] -vllm_extension_sources=[ +vllm_extension_sources = [ "csrc/cache_kernels.cu", "csrc/attention/attention_kernels.cu", "csrc/pos_encoding_kernels.cu", @@ -212,10 +212,9 @@ def get_torch_arch_list() -> Set[str]: "csrc/pybind.cpp", ] - if torch.cuda.is_available() and torch.version.cuda: vllm_extension_sources.append("csrc/quantization/awq/gemm_kernels.cu") - + vllm_extension = CUDAExtension( name="vllm._C", sources=vllm_extension_sources, diff --git a/vllm/model_executor/layers/quantization/squeezellm.py b/vllm/model_executor/layers/quantization/squeezellm.py index e83dec8aabb1c..b38615cbdbb83 100644 --- a/vllm/model_executor/layers/quantization/squeezellm.py +++ b/vllm/model_executor/layers/quantization/squeezellm.py @@ -118,15 +118,13 @@ def apply_weights(self, out_float = torch.zeros(out_shape, device="cuda", dtype=torch.float) - ops.squeezellm_gemm(reshaped_x, qweight, out_float, - lookup_table) + ops.squeezellm_gemm(reshaped_x, qweight, out_float, lookup_table) out = out_float.to(dtype=torch.float16) # do something specific for HIP elif torch.cuda.is_available() and torch.version.cuda: # NOTE: The output tensor should be zero-initialized. out = torch.zeros(out_shape, device="cuda", dtype=torch.float16) - ops.squeezellm_gemm(reshaped_x, qweight, out, - lookup_table) + ops.squeezellm_gemm(reshaped_x, qweight, out, lookup_table) if bias is not None: out = out + bias From 52ffcf0b7df38c5183d7b11e7141c5a401e357e9 Mon Sep 17 00:00:00 2001 From: miloice Date: Wed, 29 Nov 2023 16:16:50 +0000 Subject: [PATCH 27/94] downgrade torch requirement in toml to torch 2.0.1 to accommodate ROCm support --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index f9390ed8c52a9..4548f53a2421c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -3,7 +3,7 @@ requires = [ "ninja", "packaging", "setuptools", - # "torch >= 2.1.0", # commented out to accommodate ROCm + "torch >= 2.0.1", # downgrade to accommodate ROCm supported version "wheel", ] build-backend = "setuptools.build_meta" From 6c66b6e7527a3f23a0af6b9777c6a71e52c52ac2 Mon Sep 17 00:00:00 2001 From: Antoni Baum Date: Wed, 29 Nov 2023 10:56:35 -0800 Subject: [PATCH 28/94] Add rank check --- csrc/punica/bgmv/bgmv_config.h | 6 ++++-- vllm/config.py | 8 ++++++++ 2 files changed, 12 insertions(+), 2 deletions(-) diff --git a/csrc/punica/bgmv/bgmv_config.h b/csrc/punica/bgmv/bgmv_config.h index 3fd56b685be13..2c77663c0c617 100644 --- a/csrc/punica/bgmv/bgmv_config.h +++ b/csrc/punica/bgmv/bgmv_config.h @@ -44,10 +44,12 @@ void bgmv_kernel(out_T *__restrict__ Y, const in_T *__restrict__ X, f(in_T, out_T, W_T, narrow, 36864) \ f(in_T, out_T, W_T, narrow, 49152) \ +// Keep this in sync with vllm/config::LoRAConfig #define FOR_BGMV_WIDE_NARROW(f, in_T, out_T, W_T) \ - FOR_BGMV_WIDE(f, in_T, out_T, W_T, 8) \ + FOR_BGMV_WIDE(f, in_T, out_T, W_T, 8) \ FOR_BGMV_WIDE(f, in_T, out_T, W_T, 16) \ FOR_BGMV_WIDE(f, in_T, out_T, W_T, 32) \ - FOR_BGMV_WIDE(f, in_T, out_T, W_T, 64) + FOR_BGMV_WIDE(f, in_T, out_T, W_T, 64) \ + FOR_BGMV_WIDE(f, in_T, out_T, W_T, 128) // clang-format on diff --git a/vllm/config.py b/vllm/config.py index eef6e53be2855..2b0a767ff08bf 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -359,6 +359,14 @@ class LoRAConfig: lora_extra_vocab_size: int = 256 max_loras: Optional[int] = None + def __post_init__(self): + # Keep this in sync with csrc/punica/bgmv/bgmv_config.h + possible_max_ranks = (8, 16, 32, 64, 128) + if self.max_lora_rank not in possible_max_ranks: + raise ValueError( + f"max_lora_rank ({self.max_lora_rank}) must be one of " + f"{possible_max_ranks}.") + def verify_with_model_config(self, model_config: ModelConfig): if self.lora_dtype in (None, "auto"): self.lora_dtype = model_config.dtype From 70eaca69ca4b4069c74d58f79d457f8d733b3d08 Mon Sep 17 00:00:00 2001 From: Antoni Baum Date: Wed, 29 Nov 2023 11:54:44 -0800 Subject: [PATCH 29/94] Add example, minor tweaks --- examples/multilora_inference.py | 91 +++++++++++++++++++++++++++++ vllm/engine/arg_utils.py | 12 ++-- vllm/engine/llm_engine.py | 18 ++++-- vllm/lora/models.py | 4 +- vllm/lora/request.py | 12 ++++ vllm/model_executor/model_loader.py | 5 +- vllm/worker/worker.py | 4 +- 7 files changed, 128 insertions(+), 18 deletions(-) create mode 100644 examples/multilora_inference.py diff --git a/examples/multilora_inference.py b/examples/multilora_inference.py new file mode 100644 index 0000000000000..65885e534b508 --- /dev/null +++ b/examples/multilora_inference.py @@ -0,0 +1,91 @@ +""" +This example shows how to use the multi-LoRA functionality for offline inference. + +Requires HuggingFace credentials for access to Llama2. +""" + +from typing import Optional, List, Tuple + +from huggingface_hub import snapshot_download + +from vllm import EngineArgs, LLMEngine, SamplingParams, RequestOutput +from vllm.lora.request import LoRARequest + + +def create_test_prompts(lora_path: str) -> List[Tuple[str, SamplingParams]]: + """Create a list of test prompts with their sampling parameters. + + 2 requests for base model, 2 requests for the LoRA. + + In this example, we only use one LoRA adapter. However, we could + specify multiple adapters and use them in the same way. + """ + return [ + ("A robot may not injure a human being", + SamplingParams(temperature=0.0, + logprobs=1, + prompt_logprobs=1, + max_tokens=128), None), + ("To be or not to be,", + SamplingParams(temperature=0.8, + top_k=5, + presence_penalty=0.2, + max_tokens=128), None), + ("[user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_name_74 (icao VARCHAR, airport VARCHAR)\n\n question: Name the ICAO for lilongwe international airport [/user] [assistant]", + SamplingParams(temperature=0.0, + logprobs=1, + prompt_logprobs=1, + max_tokens=128, + stop_token_ids=[32003]), + LoRARequest("sql-lora", 1, lora_path)), + ("[user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_name_11 (nationality VARCHAR, elector VARCHAR)\n\n question: When Anchero Pantaleone was the elector what is under nationality? [/user] [assistant]", + SamplingParams(n=3, + best_of=3, + use_beam_search=True, + temperature=0, + max_tokens=128, + stop_token_ids=[32003]), + LoRARequest("sql-lora", 1, lora_path)), + ] + + +def process_requests(engine: LLMEngine, + test_prompts: List[Tuple[str, SamplingParams, + Optional[LoRARequest]]]): + """Continuously process a list of prompts and handle the outputs.""" + request_id = 0 + + while test_prompts or engine.has_unfinished_requests(): + if test_prompts: + prompt, sampling_params, lora_request = test_prompts.pop(0) + engine.add_request(str(request_id), + prompt, + sampling_params, + lora_request=lora_request) + request_id += 1 + + request_outputs: List[RequestOutput] = engine.step() + + for request_output in request_outputs: + if request_output.finished: + print(request_output) + + +def initialize_engine() -> LLMEngine: + """Initialize the LLMEngine.""" + engine_args = EngineArgs(model="meta-llama/Llama-2-7b-hf", + enable_lora=True, + max_num_seqs=32) + return LLMEngine.from_engine_args(engine_args) + + +def main(): + """Main function that sets up and runs the prompt processing.""" + engine = initialize_engine() + lora_path = snapshot_download(repo_id="yard1/llama-2-7b-sql-lora-test") + test_prompts = create_test_prompts(lora_path) + process_requests(engine, test_prompts) + + +if __name__ == '__main__': + main() diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index 4d1233c473980..6dc695aaa554c 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -34,9 +34,9 @@ class EngineArgs: tokenizer_revision: Optional[str] = None quantization: Optional[str] = None enable_lora: bool = False - max_lora_rank: int = 8 + max_lora_rank: int = 16 lora_extra_vocab_size: int = 256 - lora_dtype = 'bfloat16' + lora_dtype = 'auto' lora_max_cpu_loras: int = -1 def __post_init__(self): @@ -193,21 +193,21 @@ def add_cli_args( help='enable lora adapters') parser.add_argument('--max-lora-rank', type=int, - default=16, + default=EngineArgs.max_lora_rank, help='max LoRA rank') parser.add_argument('--lora-extra-vocab-size', type=int, - default=256, + default=EngineArgs.lora_extra_vocab_size, help='LoRA extra vocab size') parser.add_argument('--lora-dtype', type=str, - default=EngineArgs.dtype, + default=EngineArgs.lora_dtype, choices=['auto', 'float16', 'bfloat16', 'float32'], help='data type for lora') parser.add_argument( '--lora-max-cpu-loras', type=int, - default=-1, + default=EngineArgs.lora_max_cpu_loras, help=('Maximum number of loras to store in CPU memory. ' 'Must be >= than max_num_seqs. ' 'Defaults to max_num_seqs.')) diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index c6e74b1d26586..74e18561e6401 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -369,11 +369,13 @@ def _check_beam_search_early_stopping( current_worst_score = (current_worst_seq.get_beam_search_score( length_penalty=length_penalty, - eos_token_id=self.tokenizer.eos_token_id)) + eos_token_id=self.tokenizer.get_lora_tokenizer( + current_worst_seq.lora_request).eos_token_id)) if early_stopping is False: highest_attainable_score = (best_running_seq.get_beam_search_score( length_penalty=length_penalty, - eos_token_id=self.tokenizer.eos_token_id)) + eos_token_id=self.tokenizer.get_lora_tokenizer( + best_running_seq.lora_request).eos_token_id)) else: assert early_stopping == "never" if length_penalty > 0.0: @@ -387,7 +389,8 @@ def _check_beam_search_early_stopping( highest_attainable_score = ( best_running_seq.get_beam_search_score( length_penalty=length_penalty, - eos_token_id=self.tokenizer.eos_token_id, + eos_token_id=self.tokenizer.get_lora_tokenizer( + best_running_seq.lora_request).eos_token_id, seq_len=max_possible_length)) else: # Otherwise, beam search will prefer shorter sequences. The @@ -396,7 +399,8 @@ def _check_beam_search_early_stopping( highest_attainable_score = ( best_running_seq.get_beam_search_score( length_penalty=length_penalty, - eos_token_id=self.tokenizer.eos_token_id)) + eos_token_id=self.tokenizer.get_lora_tokenizer( + best_running_seq.lora_request).eos_token_id)) return current_worst_score >= highest_attainable_score def _process_sequence_group_outputs(self, seq_group: SequenceGroup, @@ -487,7 +491,8 @@ def _process_sequence_group_outputs(self, seq_group: SequenceGroup, # Sort the finished sequences by their scores. all_finished_seqs.sort(key=lambda x: x[0].get_beam_search_score( length_penalty=length_penalty, - eos_token_id=self.tokenizer.eos_token_id), + eos_token_id=self.tokenizer.get_lora_tokenizer(x[0].lora_request + ).eos_token_id), reverse=True) for seq, parent, is_new in all_finished_seqs[:beam_width]: if is_new: @@ -515,7 +520,8 @@ def _process_sequence_group_outputs(self, seq_group: SequenceGroup, # Sort the running sequences by their scores. running_child_seqs.sort(key=lambda x: x[0].get_beam_search_score( length_penalty=length_penalty, - eos_token_id=self.tokenizer.eos_token_id), + eos_token_id=self.tokenizer.get_lora_tokenizer(x[0].lora_request + ).eos_token_id), reverse=True) # Check if we can stop the beam search. diff --git a/vllm/lora/models.py b/vllm/lora/models.py index 913234475b182..60034bdbb6e6b 100644 --- a/vllm/lora/models.py +++ b/vllm/lora/models.py @@ -321,10 +321,10 @@ def __init__( self.base_indices = torch.empty(self.max_num_batched_tokens, dtype=torch.long, device="cuda") - self.sampler_indices = torch.empty(self.max_num_seqs, + self.sampler_indices = torch.empty(self.max_num_batched_tokens, dtype=torch.long, device="cuda") - self.sampler_indices_padded = torch.empty(self.max_num_seqs, + self.sampler_indices_padded = torch.empty(self.max_num_batched_tokens, dtype=torch.long, device="cuda") self.embeddings_indices = torch.empty(2, diff --git a/vllm/lora/request.py b/vllm/lora/request.py index 3ae5be59b1b88..5d45f8a0f396d 100644 --- a/vllm/lora/request.py +++ b/vllm/lora/request.py @@ -3,6 +3,18 @@ @dataclass class LoRARequest: + """ + Request for a LoRA adapter. + + Note that this class should be be used internally. For online + serving, it is recommended to not allow users to use this class but + instead provide another layer of abstraction to prevent users from + accessing unauthorized LoRA adapters. + + lora_id and lora_int_id must be globally unique for a given adapter. + This is currently not enforced in vLLM. + """ + lora_id: str lora_int_id: int lora_local_path: str diff --git a/vllm/model_executor/model_loader.py b/vllm/model_executor/model_loader.py index dc3299dc36f29..0cd890615c918 100644 --- a/vllm/model_executor/model_loader.py +++ b/vllm/model_executor/model_loader.py @@ -89,11 +89,10 @@ def get_model(model_config: ModelConfig, # Create a model instance. # The weights will be initialized as empty tensors. with torch.device("cuda"): - # TODO(yard1): Clean this up (lora_config) - try: + if getattr(model_class, "supports_lora", True): model = model_class(model_config.hf_config, linear_method, lora_config) - except TypeError: + else: model = model_class(model_config.hf_config, linear_method) if model_config.load_format == "dummy": # NOTE(woosuk): For accurate performance evaluation, we assign diff --git a/vllm/worker/worker.py b/vllm/worker/worker.py index d316b9588bf75..b18668007028f 100644 --- a/vllm/worker/worker.py +++ b/vllm/worker/worker.py @@ -257,7 +257,9 @@ def _prepare_inputs( # during generation. lora_requests.add(seq_group_metadata.lora_request) lora_index_mapping.append([lora_id] * prompt_len) - lora_prompt_mapping.append(lora_id) + lora_prompt_mapping.extend( + [lora_id] * + (prompt_len if sampling_params.prompt_logprobs else 1)) input_tokens.append(prompt_tokens) # NOTE(woosuk): Here we assume that the first token in the prompt From a3f191ae554edbf7212c62321bf5ce8d5702e375 Mon Sep 17 00:00:00 2001 From: Antoni Baum Date: Wed, 29 Nov 2023 12:23:59 -0800 Subject: [PATCH 30/94] Fix dummy lora init for packed layers --- vllm/lora/models.py | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/vllm/lora/models.py b/vllm/lora/models.py index 60034bdbb6e6b..4c01748f52f95 100644 --- a/vllm/lora/models.py +++ b/vllm/lora/models.py @@ -503,16 +503,16 @@ def create_dummy_lora(self, lora_id: int, rank: int) -> LoRAModel: input_dim, output_dim, rank, - module.base_layer.weight.dtype, + module.lora_a_stacked.dtype, "cpu", embeddings_tensor_dim=embeddings_tensor_dim) else: lora = _create_dummy_lora( module_name, - module.base_layer.weight.shape[1], - module.base_layer.weight.shape[0], + module.lora_a_stacked.shape[-1], + module.lora_b_stacked.shape[-2], rank, - module.base_layer.weight.dtype, + module.lora_a_stacked.dtype, "cpu", ) lora.optimize() @@ -520,13 +520,13 @@ def create_dummy_lora(self, lora_id: int, rank: int) -> LoRAModel: parts = module_name.split(".") replacements = self.packed_modules_mapping[parts[-1]] subloras = [] - for r in replacements: + for i, r in enumerate(replacements): lora = _create_dummy_lora( module_name + "." + r, - module.base_layer.weight.shape[1], - module.base_layer.weight.shape[0] // len(replacements), + module.lora_a_stacked[i].shape[-1], + module.lora_b_stacked[i].shape[-2], rank, - module.base_layer.weight.dtype, + module.lora_a_stacked[i].dtype, "cpu", ) lora.optimize() From 240cee93b98655b8b6942d66063507c1ef18a967 Mon Sep 17 00:00:00 2001 From: Antoni Baum Date: Wed, 29 Nov 2023 13:09:26 -0800 Subject: [PATCH 31/94] Fix capacity --- vllm/lora/models.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/vllm/lora/models.py b/vllm/lora/models.py index 4c01748f52f95..3f3fe8b997677 100644 --- a/vllm/lora/models.py +++ b/vllm/lora/models.py @@ -458,14 +458,14 @@ def _create_lora_modules(self): new_module = replace_submodule( self.model, module_name, - from_layer(module, self.capacity, self.lora_config, + from_layer(module, self._lora_slots, self.lora_config, self.model.config)) # (yard1): TODO make this more robust if "lm_head" in module_name: sampler_module = self.model.get_submodule("sampler") new_module = replace_submodule( self.model, "sampler", - from_layer_sampler(sampler_module, module, self.capacity, + from_layer_sampler(sampler_module, module, self._lora_slots, self.lora_config, self.model.config)) self.register_module(module_name, new_module) self._register_packed_modules(module_name) From c4d57a531699c6e41926213f02a3008b8d6d5215 Mon Sep 17 00:00:00 2001 From: Antoni Baum Date: Wed, 29 Nov 2023 13:23:19 -0800 Subject: [PATCH 32/94] Lint --- vllm/lora/models.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/vllm/lora/models.py b/vllm/lora/models.py index 3f3fe8b997677..ecb29c94c9821 100644 --- a/vllm/lora/models.py +++ b/vllm/lora/models.py @@ -465,8 +465,9 @@ def _create_lora_modules(self): sampler_module = self.model.get_submodule("sampler") new_module = replace_submodule( self.model, "sampler", - from_layer_sampler(sampler_module, module, self._lora_slots, - self.lora_config, self.model.config)) + from_layer_sampler(sampler_module, module, + self._lora_slots, self.lora_config, + self.model.config)) self.register_module(module_name, new_module) self._register_packed_modules(module_name) new_module.set_mapping(self.base_indices, self.sampler_indices, From 471f25a6b51e87cec00568b17e294d07c9087603 Mon Sep 17 00:00:00 2001 From: Antoni Baum Date: Wed, 29 Nov 2023 13:39:07 -0800 Subject: [PATCH 33/94] Remove rank 128 for now --- csrc/punica/bgmv/bgmv_config.h | 3 +-- vllm/config.py | 2 +- 2 files changed, 2 insertions(+), 3 deletions(-) diff --git a/csrc/punica/bgmv/bgmv_config.h b/csrc/punica/bgmv/bgmv_config.h index 2c77663c0c617..da6e6a611ecaa 100644 --- a/csrc/punica/bgmv/bgmv_config.h +++ b/csrc/punica/bgmv/bgmv_config.h @@ -49,7 +49,6 @@ void bgmv_kernel(out_T *__restrict__ Y, const in_T *__restrict__ X, FOR_BGMV_WIDE(f, in_T, out_T, W_T, 8) \ FOR_BGMV_WIDE(f, in_T, out_T, W_T, 16) \ FOR_BGMV_WIDE(f, in_T, out_T, W_T, 32) \ - FOR_BGMV_WIDE(f, in_T, out_T, W_T, 64) \ - FOR_BGMV_WIDE(f, in_T, out_T, W_T, 128) + FOR_BGMV_WIDE(f, in_T, out_T, W_T, 64) // clang-format on diff --git a/vllm/config.py b/vllm/config.py index 2b0a767ff08bf..0b03565a5031a 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -361,7 +361,7 @@ class LoRAConfig: def __post_init__(self): # Keep this in sync with csrc/punica/bgmv/bgmv_config.h - possible_max_ranks = (8, 16, 32, 64, 128) + possible_max_ranks = (8, 16, 32, 64) if self.max_lora_rank not in possible_max_ranks: raise ValueError( f"max_lora_rank ({self.max_lora_rank}) must be one of " From ccbb4b7f88395fea4362d42688a46712160adb0b Mon Sep 17 00:00:00 2001 From: Antoni Baum Date: Wed, 29 Nov 2023 13:57:38 -0800 Subject: [PATCH 34/94] Pass to scheduler --- vllm/core/scheduler.py | 10 +++++++--- vllm/engine/llm_engine.py | 2 +- 2 files changed, 8 insertions(+), 4 deletions(-) diff --git a/vllm/core/scheduler.py b/vllm/core/scheduler.py index f8fb4c6ea1518..fce3f2acb65f5 100644 --- a/vllm/core/scheduler.py +++ b/vllm/core/scheduler.py @@ -2,7 +2,7 @@ import time from typing import Dict, Iterable, List, Optional, Tuple, Union, Set -from vllm.config import CacheConfig, SchedulerConfig +from vllm.config import CacheConfig, LoRAConfig, SchedulerConfig from vllm.core.block_manager import AllocStatus, BlockSpaceManager from vllm.core.policy import PolicyFactory from vllm.lora.request import LoRARequest @@ -73,11 +73,11 @@ def __init__( self, scheduler_config: SchedulerConfig, cache_config: CacheConfig, - lora_enabled: bool = False, + lora_config: Optional[LoRAConfig], ) -> None: self.scheduler_config = scheduler_config self.cache_config = cache_config - self.lora_enabled = lora_enabled + self.lora_config = lora_config self.prompt_limit = min(self.scheduler_config.max_model_len, self.scheduler_config.max_num_batched_tokens) @@ -99,6 +99,10 @@ def __init__( # Sequence groups in the SWAPPED state. self.swapped: List[SequenceGroup] = [] + @property + def lora_enabled(self) -> bool: + return bool(self.lora_config) + def add_seq_group(self, seq_group: SequenceGroup) -> None: # Add sequence groups to the waiting queue. self.waiting.append(seq_group) diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index 74e18561e6401..8ba04181ff47b 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -111,7 +111,7 @@ def __init__( self._init_cache() # Create the scheduler. - self.scheduler = Scheduler(scheduler_config, cache_config) + self.scheduler = Scheduler(scheduler_config, cache_config, lora_config) # Logging. self.last_logging_time = 0.0 From 5a1a0be6e06b1f155168023c28c8100fe6464143 Mon Sep 17 00:00:00 2001 From: Antoni Baum Date: Wed, 29 Nov 2023 15:52:13 -0800 Subject: [PATCH 35/94] Add simple scheduler support --- examples/multilora_inference.py | 20 ++++++++++-- vllm/config.py | 18 +++++------ vllm/core/scheduler.py | 55 +++++++++++++++++++++++++++------ vllm/engine/arg_utils.py | 6 ++++ vllm/lora/models.py | 23 +++++++------- vllm/lora/worker_manager.py | 8 ++--- vllm/sequence.py | 8 +++++ vllm/worker/worker.py | 13 +++++--- 8 files changed, 111 insertions(+), 40 deletions(-) diff --git a/examples/multilora_inference.py b/examples/multilora_inference.py index 65885e534b508..9aa0edc35d32c 100644 --- a/examples/multilora_inference.py +++ b/examples/multilora_inference.py @@ -46,6 +46,21 @@ def create_test_prompts(lora_path: str) -> List[Tuple[str, SamplingParams]]: max_tokens=128, stop_token_ids=[32003]), LoRARequest("sql-lora", 1, lora_path)), + ("[user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_name_74 (icao VARCHAR, airport VARCHAR)\n\n question: Name the ICAO for lilongwe international airport [/user] [assistant]", + SamplingParams(temperature=0.0, + logprobs=1, + prompt_logprobs=1, + max_tokens=128, + stop_token_ids=[32003]), + LoRARequest("sql-lora2", 2, lora_path)), + ("[user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_name_11 (nationality VARCHAR, elector VARCHAR)\n\n question: When Anchero Pantaleone was the elector what is under nationality? [/user] [assistant]", + SamplingParams(n=3, + best_of=3, + use_beam_search=True, + temperature=0, + max_tokens=128, + stop_token_ids=[32003]), + LoRARequest("sql-lora", 1, lora_path)), ] @@ -68,14 +83,15 @@ def process_requests(engine: LLMEngine, for request_output in request_outputs: if request_output.finished: - print(request_output) + print(request_output.lora_request) def initialize_engine() -> LLMEngine: """Initialize the LLMEngine.""" engine_args = EngineArgs(model="meta-llama/Llama-2-7b-hf", enable_lora=True, - max_num_seqs=32) + max_loras=1, + max_num_seqs=256) return LLMEngine.from_engine_args(engine_args) diff --git a/vllm/config.py b/vllm/config.py index 0b03565a5031a..007f1026affa0 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -354,10 +354,10 @@ def _verify_args(self) -> None: @dataclass class LoRAConfig: max_lora_rank: int + max_loras: int max_cpu_loras: Optional[int] = None lora_dtype: Optional[torch.dtype] = None lora_extra_vocab_size: int = 256 - max_loras: Optional[int] = None def __post_init__(self): # Keep this in sync with csrc/punica/bgmv/bgmv_config.h @@ -366,6 +366,14 @@ def __post_init__(self): raise ValueError( f"max_lora_rank ({self.max_lora_rank}) must be one of " f"{possible_max_ranks}.") + if self.max_loras < 1: + raise ValueError(f"max_loras ({self.max_loras}) must be >= 1.") + if self.max_cpu_loras is None: + self.max_cpu_loras = self.max_loras + elif self.max_cpu_loras < self.max_loras: + raise ValueError( + f"max_cpu_loras ({self.max_cpu_loras}) must be >= " + f"max_num_seqs ({self.max_loras})") def verify_with_model_config(self, model_config: ModelConfig): if self.lora_dtype in (None, "auto"): @@ -380,14 +388,6 @@ def verify_with_scheduler_config(self, scheduler_config: SchedulerConfig): "max_num_batched_tokens must be <= 65528 when " "LoRA is enabled.") - self.max_loras = scheduler_config.max_num_seqs - if self.max_cpu_loras is None: - self.max_cpu_loras = scheduler_config.max_num_seqs - elif self.max_cpu_loras < scheduler_config.max_num_seqs: - raise ValueError( - f"max_cpu_loras ({self.max_cpu_loras}) must be >= " - f"max_num_seqs ({scheduler_config.max_num_seqs})") - _STR_DTYPE_TO_TORCH_DTYPE = { "half": torch.float16, diff --git a/vllm/core/scheduler.py b/vllm/core/scheduler.py index fce3f2acb65f5..e9081f6b7d726 100644 --- a/vllm/core/scheduler.py +++ b/vllm/core/scheduler.py @@ -77,6 +77,9 @@ def __init__( ) -> None: self.scheduler_config = scheduler_config self.cache_config = cache_config + # Note for LoRA scheduling: the current policy is extremely + # simple and NOT fair. It can lead to starvation of some + # LoRAs. This should be improved in the future. self.lora_config = lora_config self.prompt_limit = min(self.scheduler_config.max_model_len, @@ -151,14 +154,16 @@ def _schedule(self) -> SchedulerOutputs: # requests in the generation phase. num_curr_seqs = sum(seq_group.get_max_num_running_seqs() for seq_group in self.running) + curr_loras = set( + seq_group.lora_int_id + for seq_group in self.running) if self.lora_enabled else None seq_lens: List[int] = [] # Optimization: We do not sort the waiting queue since the preempted # sequence groups are added to the front and the new sequence groups # are added to the back. - while self.waiting: - seq_group = self.waiting[0] - + waiting_indices_to_remove = [] + for i, seq_group in enumerate(self.waiting): assert seq_group.num_seqs() == 1, ( "Waiting sequence group should have only one prompt " "sequence.") @@ -170,7 +175,7 @@ def _schedule(self) -> SchedulerOutputs: for seq in seq_group.get_seqs(): seq.status = SequenceStatus.FINISHED_IGNORED ignored_seq_groups.append(seq_group) - self.waiting.pop(0) + waiting_indices_to_remove.append(i) continue # If the sequence group cannot be allocated, stop. @@ -184,9 +189,18 @@ def _schedule(self) -> SchedulerOutputs: for seq in seq_group.get_seqs(): seq.status = SequenceStatus.FINISHED_IGNORED ignored_seq_groups.append(seq_group) - self.waiting.pop(0) + waiting_indices_to_remove.append(i) continue + lora_int_id = 0 + if self.lora_enabled: + lora_int_id = seq_group.lora_int_id + if lora_int_id > 0 and lora_int_id not in curr_loras and len( + curr_loras) >= self.lora_config.max_loras: + # We don't have a space for another LoRA, so + # we ignore this request for now. + continue + # If the number of batched tokens exceeds the limit, stop. new_seq_lens = seq_lens + [num_prompt_tokens] num_batched_tokens = len(new_seq_lens) * max(new_seq_lens) @@ -206,12 +220,17 @@ def _schedule(self) -> SchedulerOutputs: break seq_lens = new_seq_lens - seq_group = self.waiting.pop(0) + waiting_indices_to_remove.append(i) + if lora_int_id > 0: + curr_loras.add(lora_int_id) self._allocate(seq_group) self.running.append(seq_group) num_curr_seqs += num_new_seqs scheduled.append(seq_group) + for i in reversed(waiting_indices_to_remove): + self.waiting.pop(i) + if scheduled or ignored_seq_groups: scheduler_outputs = SchedulerOutputs( scheduled_seq_groups=scheduled, @@ -260,9 +279,22 @@ def _schedule(self) -> SchedulerOutputs: if not preempted: num_curr_seqs = sum(seq_group.get_max_num_running_seqs() for seq_group in self.running) + curr_loras = set( + seq_group.lora_int_id + for seq_group in self.running) if self.lora_enabled else None + + swapped_indices_to_remove = [] + + for i, seq_group in enumerate(self.swapped): + lora_int_id = 0 + if self.lora_enabled: + lora_int_id = seq_group.lora_int_id + if lora_int_id > 0 and lora_int_id not in curr_loras and len( + curr_loras) >= self.lora_config.max_loras: + # We don't have a space for another LoRA, so + # we ignore this request for now. + continue - while self.swapped: - seq_group = self.swapped[0] # If the sequence group cannot be swapped in, stop. if not self.block_manager.can_swap_in(seq_group): break @@ -274,12 +306,17 @@ def _schedule(self) -> SchedulerOutputs: self.scheduler_config.max_num_seqs): break - seq_group = self.swapped.pop(0) + swapped_indices_to_remove.append(i) + if lora_int_id > 0: + curr_loras.add(lora_int_id) self._swap_in(seq_group, blocks_to_swap_in) self._append_slot(seq_group, blocks_to_copy) num_curr_seqs += num_new_seqs self.running.append(seq_group) + for i in reversed(swapped_indices_to_remove): + self.swapped.pop(i) + # Each sequence in the generation phase only takes one token slot. # Therefore, the number of batched tokens is equal to the number of # sequences in the RUNNING state. diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index 6dc695aaa554c..a8c9d87215737 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -34,6 +34,7 @@ class EngineArgs: tokenizer_revision: Optional[str] = None quantization: Optional[str] = None enable_lora: bool = False + max_loras: int = 1 max_lora_rank: int = 16 lora_extra_vocab_size: int = 256 lora_dtype = 'auto' @@ -191,6 +192,10 @@ def add_cli_args( parser.add_argument('--enable-lora', action='store_true', help='enable lora adapters') + parser.add_argument('--max-loras', + type=int, + default=EngineArgs.max_loras, + help='max number of LoRAs in a single batch') parser.add_argument('--max-lora-rank', type=int, default=EngineArgs.max_lora_rank, @@ -244,6 +249,7 @@ def create_engine_configs( self.max_paddings) lora_config = LoRAConfig( max_lora_rank=self.max_lora_rank, + max_loras=self.max_loras, lora_extra_vocab_size=self.lora_extra_vocab_size, lora_dtype=self.lora_dtype, max_cpu_loras=self.lora_max_cpu_loras if self.lora_max_cpu_loras > diff --git a/vllm/lora/models.py b/vllm/lora/models.py index ecb29c94c9821..bdf100f52b5c3 100644 --- a/vllm/lora/models.py +++ b/vllm/lora/models.py @@ -314,9 +314,9 @@ def __init__( """ self.lora_config = lora_config self.max_num_seqs = max_num_seqs - assert self.capacity >= self.max_num_seqs + assert self.capacity >= self.lora_slots self.max_num_batched_tokens = math.ceil(max_num_batched_tokens / 8) * 8 - self.lora_id_to_index: List[Optional[int]] = [None] * self._lora_slots + self.lora_id_to_index: List[Optional[int]] = [None] * self.lora_slots self.vocab_size = vocab_size self.base_indices = torch.empty(self.max_num_batched_tokens, dtype=torch.long, @@ -353,8 +353,8 @@ def capacity(self) -> int: return self.lora_config.max_cpu_loras @property - def _lora_slots(self) -> int: - return self.max_num_seqs + def lora_slots(self) -> int: + return self.lora_config.max_loras def __len__(self) -> int: return len(self._registered_loras) @@ -421,7 +421,7 @@ def convert_mapping(self, mapping: LoRAMapping) -> None: (base_indices, sampler_indices, sampler_indices_padded, embeddings_indices, indices_len) = convert_mapping(mapping, self.lora_id_to_index, - self._lora_slots + 1, self.vocab_size, + self.lora_slots + 1, self.vocab_size, self.lora_config.lora_extra_vocab_size) self.base_indices[:base_indices.shape[0]].copy_(base_indices) self.sampler_indices[:sampler_indices.shape[0]].copy_(sampler_indices) @@ -448,7 +448,7 @@ def get_lora(self, lora_id: int) -> Optional[LoRAModel]: def remove_all_loras(self) -> bool: """Remove all LoRAModels from the manager.""" self._registered_loras.clear() - self.lora_id_to_index = [None] * self._lora_slots + self.lora_id_to_index = [None] * self.lora_slots self._active_loras.clear() def _create_lora_modules(self): @@ -458,16 +458,15 @@ def _create_lora_modules(self): new_module = replace_submodule( self.model, module_name, - from_layer(module, self._lora_slots, self.lora_config, + from_layer(module, self.lora_slots, self.lora_config, self.model.config)) # (yard1): TODO make this more robust if "lm_head" in module_name: sampler_module = self.model.get_submodule("sampler") new_module = replace_submodule( self.model, "sampler", - from_layer_sampler(sampler_module, module, - self._lora_slots, self.lora_config, - self.model.config)) + from_layer_sampler(sampler_module, module, self.lora_slots, + self.lora_config, self.model.config)) self.register_module(module_name, new_module) self._register_packed_modules(module_name) new_module.set_mapping(self.base_indices, self.sampler_indices, @@ -604,7 +603,7 @@ def __init__( self._registered_loras: LoRALRUCache = LoRALRUCache( self.capacity, self.deactivate_lora) self._active_loras: LoRALRUCache = LoRALRUCache( - self.max_num_seqs, self._deactivate_lora) + self.lora_slots, self._deactivate_lora) def list_loras(self) -> Dict[int, LoRAModel]: """List all registered LoRAModels.""" @@ -629,7 +628,7 @@ def activate_lora( lora_id: int, ) -> bool: if lora_id not in self._active_loras and len( - self._active_loras) >= self.max_num_seqs: + self._active_loras) >= self.lora_slots: self._active_loras.remove_oldest() result = super().activate_lora(lora_id) # We always touch to update the LRU cache order diff --git a/vllm/lora/worker_manager.py b/vllm/lora/worker_manager.py index be6f4cf0589bd..4b90c6a556285 100644 --- a/vllm/lora/worker_manager.py +++ b/vllm/lora/worker_manager.py @@ -137,7 +137,7 @@ def create_lora_adapter( lora_config=self.lora_config, lora_manager_cls=self._lora_manager_cls, ) - self._lora_manager = lora_manager + self._lora_manager: LoRAModelManager = lora_manager return lora_manager.model def apply_loras(self, lora_requests: List[LoRARequest], @@ -155,7 +155,7 @@ def _apply_loras(self, lora_requests: List[LoRARequest]) -> None: raise RuntimeError( f"Number of requested LoRAs ({len(loras_map)}) is greater " "than the number of GPU LoRA slots " - f"({self._lora_manager.max_num_seqs}).") + f"({self._lora_manager.lora_slots}).") new_loras = set(loras_map) loras_to_add = new_loras - loras_that_exist @@ -235,7 +235,7 @@ def create_lora_adapter( lora_config=self.lora_config, max_num_batched_tokens=self.max_num_batched_tokens, ) - self._lora_manager = lora_manager + self._lora_manager: LRUCacheLoRAModelManager = lora_manager return lora_manager.model def _apply_loras(self, lora_requests: List[LoRARequest]) -> None: @@ -247,7 +247,7 @@ def _apply_loras(self, lora_requests: List[LoRARequest]) -> None: raise RuntimeError( f"Number of requested LoRAs ({len(loras_map)}) is greater " "than the number of GPU LoRA slots " - f"({self._lora_manager.max_num_seqs}).") + f"({self._lora_manager.lora_slots}).") for lora in loras_map.values(): self.add_lora(lora) diff --git a/vllm/sequence.py b/vllm/sequence.py index 06170ab79d69a..036a697ab3491 100644 --- a/vllm/sequence.py +++ b/vllm/sequence.py @@ -137,6 +137,10 @@ def __init__( # Input + output tokens self.tokens: Optional[List[str]] = None + @property + def lora_int_id(self) -> int: + return self.lora_request.lora_int_id if self.lora_request else 0 + def _append_logical_block(self) -> None: block = LogicalTokenBlock( block_number=len(self.logical_token_blocks), @@ -262,6 +266,10 @@ def prompt_token_ids(self) -> List[int]: # We use the prompt of an arbitrary sequence. return next(iter(self.seqs_dict.values())).data.prompt_token_ids + @property + def lora_int_id(self) -> int: + return self.lora_request.lora_int_id if self.lora_request else 0 + def get_max_num_running_seqs(self) -> int: """The maximum number of sequences running in parallel in the remaining lifetime of the request.""" diff --git a/vllm/worker/worker.py b/vllm/worker/worker.py index b18668007028f..4cf04babd299f 100644 --- a/vllm/worker/worker.py +++ b/vllm/worker/worker.py @@ -123,8 +123,9 @@ def profile_num_available_blocks( # consumption create dummy lora request copies from the lora request # passed in, which contains a lora from the lora warmup path. dummy_lora_requests = [] + dummy_lora_requests_per_seq = [] if self.lora_config: - for idx in range(max_num_seqs): + for idx in range(self.lora_config.max_loras): lora_id = idx + 1 dummy_lora_request = LoRARequest( lora_id=f"warmup_{lora_id}", @@ -134,6 +135,10 @@ def profile_num_available_blocks( self.lora_manager.add_dummy_lora(dummy_lora_request, rank=LORA_WARMUP_RANK) dummy_lora_requests.append(dummy_lora_request) + dummy_lora_requests_per_seq = [ + dummy_lora_requests[idx % len(dummy_lora_requests)] + for idx in range(max_num_seqs) + ] seqs = [] for group_id in range(max_num_seqs): @@ -146,8 +151,8 @@ def profile_num_available_blocks( seq_data={group_id: seq_data}, sampling_params=sampling_params, block_tables=None, - lora_request=dummy_lora_requests[group_id] - if dummy_lora_requests else None, + lora_request=dummy_lora_requests_per_seq[group_id] + if dummy_lora_requests_per_seq else None, ) seqs.append(seq) @@ -159,7 +164,7 @@ def profile_num_available_blocks( prepared_lora_requests, ) = self._prepare_inputs(seqs) - if dummy_lora_requests: + if self.lora_config: self.apply_loras(prepared_lora_requests, lora_mapping) # Execute the model. From 1b00e500f1a1500e25840f26ac5c814ce790d0cf Mon Sep 17 00:00:00 2001 From: Antoni Baum Date: Wed, 29 Nov 2023 15:58:38 -0800 Subject: [PATCH 36/94] Update example --- examples/multilora_inference.py | 18 ++++++++++++++---- 1 file changed, 14 insertions(+), 4 deletions(-) diff --git a/examples/multilora_inference.py b/examples/multilora_inference.py index 9aa0edc35d32c..91b8675827b93 100644 --- a/examples/multilora_inference.py +++ b/examples/multilora_inference.py @@ -15,10 +15,11 @@ def create_test_prompts(lora_path: str) -> List[Tuple[str, SamplingParams]]: """Create a list of test prompts with their sampling parameters. - 2 requests for base model, 2 requests for the LoRA. - - In this example, we only use one LoRA adapter. However, we could - specify multiple adapters and use them in the same way. + 2 requests for base model, 4 requests for the LoRA. We define 2 + different LoRA adapters (using the same model for demo purposes). + Since we also set `max_loras=1`, the expectation is that the requests + with the second LoRA adapter will be ran after all requests with the + first adapter have finished. """ return [ ("A robot may not injure a human being", @@ -88,9 +89,18 @@ def process_requests(engine: LLMEngine, def initialize_engine() -> LLMEngine: """Initialize the LLMEngine.""" + # max_loras: controls the number of LoRAs that can be used in the same + # batch. Larger numbers will cause higher memory usage, as each LoRA + # slot requires its own preallocated tensor. + # max_lora_rank: controls the maximum supported rank of all LoRAs. Larger + # numbers will cause higher memory usage. If you know that all LoRAs will + # use the same rank, it is recommended to set this as low as possible. + # max_cpu_loras: controls the size of the CPU LoRA cache. engine_args = EngineArgs(model="meta-llama/Llama-2-7b-hf", enable_lora=True, max_loras=1, + max_lora_rank=8, + max_cpu_loras=2, max_num_seqs=256) return LLMEngine.from_engine_args(engine_args) From 6bda3c369a867a2b6496e884a62a40519c42ca79 Mon Sep 17 00:00:00 2001 From: Antoni Baum Date: Wed, 29 Nov 2023 16:02:50 -0800 Subject: [PATCH 37/94] Fix --- vllm/lora/worker_manager.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/vllm/lora/worker_manager.py b/vllm/lora/worker_manager.py index 4b90c6a556285..43c016c32193b 100644 --- a/vllm/lora/worker_manager.py +++ b/vllm/lora/worker_manager.py @@ -151,7 +151,7 @@ def _apply_loras(self, lora_requests: List[LoRARequest]) -> None: lora_request.lora_int_id: lora_request for lora_request in lora_requests if lora_request } - if len(loras_map) > self._lora_manager.max_num_seqs: + if len(loras_map) > self._lora_manager.lora_slots: raise RuntimeError( f"Number of requested LoRAs ({len(loras_map)}) is greater " "than the number of GPU LoRA slots " @@ -243,7 +243,7 @@ def _apply_loras(self, lora_requests: List[LoRARequest]) -> None: lora_request.lora_int_id: lora_request for lora_request in lora_requests if lora_request } - if len(loras_map) > self._lora_manager.max_num_seqs: + if len(loras_map) > self._lora_manager.lora_slots: raise RuntimeError( f"Number of requested LoRAs ({len(loras_map)}) is greater " "than the number of GPU LoRA slots " From de029618448c2d68c3a1e3e4af2624f0d6710bc8 Mon Sep 17 00:00:00 2001 From: Antoni Baum Date: Wed, 29 Nov 2023 16:08:18 -0800 Subject: [PATCH 38/94] Update tests --- tests/lora/conftest.py | 3 ++- tests/lora/test_layers.py | 16 ++++++++++++---- tests/lora/test_llama.py | 4 ++++ 3 files changed, 18 insertions(+), 5 deletions(-) diff --git a/tests/lora/conftest.py b/tests/lora/conftest.py index 263a2bc9d8156..31803b741bd2e 100644 --- a/tests/lora/conftest.py +++ b/tests/lora/conftest.py @@ -124,7 +124,8 @@ def llama_2_7b_engine_extra_embeddings() -> nn.Module: get_model_old = get_model def get_model_patched(model_config, lora_config=None): - return get_model_old(model_config, LoRAConfig(max_lora_rank=8)) + return get_model_old(model_config, + LoRAConfig(max_loras=4, max_lora_rank=8)) with patch("vllm.worker.worker.get_model", get_model_patched): engine = vllm.LLM("meta-llama/Llama-2-7b-hf", enable_lora=False) diff --git a/tests/lora/test_layers.py b/tests/lora/test_layers.py index fa6a18e8d93d2..319b33652b61a 100644 --- a/tests/lora/test_layers.py +++ b/tests/lora/test_layers.py @@ -170,7 +170,9 @@ def create_random_inputs( @pytest.mark.parametrize("num_loras", [1, 2, 4, 8]) def test_embeddings(dist_init, num_loras) -> None: - lora_config = LoRAConfig(max_lora_rank=8, lora_dtype=torch.float16) + lora_config = LoRAConfig(max_loras=8, + max_lora_rank=8, + lora_dtype=torch.float16) max_loras = 8 def create_random_embedding_layer(): @@ -258,7 +260,9 @@ def create_random_embedding_layer(): @pytest.mark.parametrize("num_loras", [1, 2, 4, 8]) def test_embeddings_with_new_embeddings(dist_init, num_loras) -> None: - lora_config = LoRAConfig(max_lora_rank=8, lora_dtype=torch.float16) + lora_config = LoRAConfig(max_loras=8, + max_lora_rank=8, + lora_dtype=torch.float16) max_loras = 8 def create_random_embedding_layer(): @@ -495,7 +499,9 @@ def create_random_sampler_layer(): @pytest.mark.parametrize("orientation", ["row", "column"]) def test_linear_parallel(dist_init, num_loras, orientation) -> None: - lora_config = LoRAConfig(max_lora_rank=8, lora_dtype=torch.float16) + lora_config = LoRAConfig(max_loras=8, + max_lora_rank=8, + lora_dtype=torch.float16) max_loras = 8 def create_random_linear_parallel_layer(): @@ -589,7 +595,9 @@ def create_random_linear_parallel_layer(): @pytest.mark.parametrize("num_loras", [1, 2, 4, 8]) @pytest.mark.parametrize("repeats", [2, 3]) def test_column_parallel_packed(dist_init, num_loras, repeats) -> None: - lora_config = LoRAConfig(max_lora_rank=8, lora_dtype=torch.float16) + lora_config = LoRAConfig(max_loras=8, + max_lora_rank=8, + lora_dtype=torch.float16) max_loras = 8 def create_column_parallel_packed_layer(): diff --git a/tests/lora/test_llama.py b/tests/lora/test_llama.py index 756fc55246092..4760c5cc1e950 100644 --- a/tests/lora/test_llama.py +++ b/tests/lora/test_llama.py @@ -43,6 +43,7 @@ def test_llama_lora(sql_lora_files, tp_size): llm = vllm.LLM(MODEL_PATH, enable_lora=True, max_num_seqs=16, + max_loras=4, tensor_parallel_size=tp_size, worker_use_ray=True) @@ -85,6 +86,7 @@ def test_llama_tensor_parallel_equality(sql_lora_files): llm_tp1 = vllm.LLM(MODEL_PATH, enable_lora=True, max_num_seqs=16, + max_loras=4, tensor_parallel_size=1, worker_use_ray=True) output_tp1 = do_sample(llm_tp1, sql_lora_files, lora_id=1) @@ -95,6 +97,7 @@ def test_llama_tensor_parallel_equality(sql_lora_files): llm_tp2 = vllm.LLM(MODEL_PATH, enable_lora=True, max_num_seqs=16, + max_loras=4, tensor_parallel_size=2, worker_use_ray=True) output_tp2 = do_sample(llm_tp2, sql_lora_files, lora_id=1) @@ -107,6 +110,7 @@ def test_llama_tensor_parallel_equality(sql_lora_files): llm_tp4 = vllm.LLM(MODEL_PATH, enable_lora=True, max_num_seqs=16, + max_loras=4, tensor_parallel_size=4, worker_use_ray=True) output_tp4 = do_sample(llm_tp4, sql_lora_files, lora_id=1) From 849831e300417e4f2df96e3279727f5b6082e460 Mon Sep 17 00:00:00 2001 From: Antoni Baum Date: Thu, 30 Nov 2023 11:47:27 -0800 Subject: [PATCH 39/94] Cleanup --- examples/multilora_inference.py | 2 +- vllm/engine/arg_utils.py | 14 +++++++------- 2 files changed, 8 insertions(+), 8 deletions(-) diff --git a/examples/multilora_inference.py b/examples/multilora_inference.py index 91b8675827b93..8fdd243af69ff 100644 --- a/examples/multilora_inference.py +++ b/examples/multilora_inference.py @@ -84,7 +84,7 @@ def process_requests(engine: LLMEngine, for request_output in request_outputs: if request_output.finished: - print(request_output.lora_request) + print(request_output) def initialize_engine() -> LLMEngine: diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index cfeb1600ac8c6..0dae1613690e5 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -38,7 +38,7 @@ class EngineArgs: max_lora_rank: int = 16 lora_extra_vocab_size: int = 256 lora_dtype = 'auto' - lora_max_cpu_loras: int = -1 + max_cpu_loras: int = -1 def __post_init__(self): if self.tokenizer is None: @@ -208,12 +208,12 @@ def add_cli_args( type=str, default=EngineArgs.lora_dtype, choices=['auto', 'float16', 'bfloat16', 'float32'], - help='data type for lora') + help='data type for LoRA') parser.add_argument( - '--lora-max-cpu-loras', + '--max-cpu-loras', type=int, - default=EngineArgs.lora_max_cpu_loras, - help=('Maximum number of loras to store in CPU memory. ' + default=EngineArgs.max_cpu_loras, + help=('Maximum number of LoRAs to store in CPU memory. ' 'Must be >= than max_num_seqs. ' 'Defaults to max_num_seqs.')) return parser @@ -253,8 +253,8 @@ def create_engine_configs( max_loras=self.max_loras, lora_extra_vocab_size=self.lora_extra_vocab_size, lora_dtype=self.lora_dtype, - max_cpu_loras=self.lora_max_cpu_loras if self.lora_max_cpu_loras > - 0 else None) if self.enable_lora else None + max_cpu_loras=self.max_cpu_loras + if self.max_cpu_loras > 0 else None) if self.enable_lora else None return model_config, cache_config, parallel_config, scheduler_config, lora_config From 5cce6495ad54e6229196452285d92b104d6d06b3 Mon Sep 17 00:00:00 2001 From: miloice Date: Fri, 1 Dec 2023 08:14:10 +0000 Subject: [PATCH 40/94] Merged with changes in vllm main --- vllm/model_executor/layers/attention.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/vllm/model_executor/layers/attention.py b/vllm/model_executor/layers/attention.py index b84af362efca6..7ce763313937b 100644 --- a/vllm/model_executor/layers/attention.py +++ b/vllm/model_executor/layers/attention.py @@ -160,6 +160,8 @@ def forward( attn_bias=input_metadata.attn_bias, p=0.0, scale=self.scale, + op=xops.fmha.MemoryEfficientAttentionFlashAttentionOp[0] if + (torch.cuda.is_available() and torch.version.hip) else None, ) output = out.view_as(query) else: From 16d3ccc27467c1815e1b18abf45994430f964366 Mon Sep 17 00:00:00 2001 From: miloice Date: Fri, 1 Dec 2023 17:36:45 +0000 Subject: [PATCH 41/94] Updated Dockerfile, rocm installation guide and setuppy --- rocm.Dockerfile => Dockerfile.rocm | 23 +++---- .../getting_started/amd-installation.rst | 60 ++++++------------- setup.py | 31 +++++----- 3 files changed, 42 insertions(+), 72 deletions(-) rename rocm.Dockerfile => Dockerfile.rocm (78%) diff --git a/rocm.Dockerfile b/Dockerfile.rocm similarity index 78% rename from rocm.Dockerfile rename to Dockerfile.rocm index 27b57097740d2..2ef143d20f020 100644 --- a/rocm.Dockerfile +++ b/Dockerfile.rocm @@ -25,7 +25,7 @@ VOLUME [ ${APP_MOUNT} ] WORKDIR ${APP_MOUNT} RUN python3 -m pip install --upgrade pip -RUN python3 -m pip install --no-cache-dir fastapi ninja tokenizers +RUN python3 -m pip install --no-cache-dir fastapi ninja tokenizers pandas ENV LLVM_SYMBOLIZER_PATH=/opt/rocm/llvm/bin/llvm-symbolizer ENV PATH=$PATH:/opt/rocm/bin:/libtorch/bin: @@ -39,26 +39,21 @@ RUN mkdir libs \ && git clone https://github.com/ROCmSoftwarePlatform/flash-attention.git \ && cd flash-attention \ && git submodule update --init \ - && sed -i -e "s/--offload-arch=native/--offload-arch=$(/opt/rocm/llvm/bin/amdgpu-offload-arch)/g" setup.py \ + && export GPU_ARCHS=$(/opt/rocm/llvm/bin/amdgpu-offload-arch) \ + # && sed -i -e "s/--offload-arch=native/--offload-arch=$(/opt/rocm/llvm/bin/amdgpu-offload-arch)/g" setup.py \ && patch /opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/utils/hipify/hipify_python.py hipify_patch.patch \ && python3 setup.py install \ && cd .. COPY ./ /app/vllm-rocm/ -# RUN cd /app \ -# && cd vllm-rocm \ -# && git checkout v0.2.1.post1-rocm \ -# && python3 setup.py install \ -# && cd .. - -# RUN cd /app \ -# && mkdir dataset \ -# && cd .. - -# COPY ./benchmark_throughput.sh /app/benchmark_throughput.sh +RUN cd /app \ + && cd vllm-rocm \ + && git checkout vllm-cuda-rocm \ + && python3 setup.py install \ + && cd .. -RUN python3 -m pip install --upgrade pip +# RUN python3 -m pip install --upgrade pip # RUN python3 -m pip install --no-cache-dir ray[all] CMD ["/bin/bash"] \ No newline at end of file diff --git a/docs/source/getting_started/amd-installation.rst b/docs/source/getting_started/amd-installation.rst index 9c76edf7b872a..422b067b3dac4 100644 --- a/docs/source/getting_started/amd-installation.rst +++ b/docs/source/getting_started/amd-installation.rst @@ -3,46 +3,17 @@ Installation with ROCm ============ -vLLM-ROCm is here! Currently it is supporting llama-2. +vLLM 0.2.x onwards supports model inferencing and serving on AMD GPUs with ROCm. At the moment AWQ quantization is not supported, but SqueezeLLM has been incorporated. +The ROCm version is tested and verified on LLaMA-2 models. Requirements ------------ * OS: Linux -* Python: 3.8 -- 3.11 (Recommended 3.10 as this is the version that has been tested on.) -* GPU: MI210 +* Python: 3.8 -- 3.11 (Verified on 3.10) +* GPU: MI200s * Pytorch 2.0.1/2.1.1 -* ROCm 5.7 - - -Install with pip ----------------- - -You can install vLLM using pip: - -.. code-block:: console - - $ # (Optional) Create a new conda environment. - $ conda create -n myenv python=3.8 -y - $ conda activate myenv - - $ # Install vLLM with CUDA 12.1. - $ pip install vllm - -.. note:: - - As of now, vLLM's binaries are compiled on CUDA 12.1 by default. - However, you can install vLLM with CUDA 11.8 by running: - - .. code-block:: console - - $ # Install vLLM with CUDA 11.8. - $ # Replace `cp310` with your Python version (e.g., `cp38`, `cp39`, `cp311`). - $ pip install https://github.com/vllm-project/vllm/releases/download/v0.2.2/vllm-0.2.2+cu118-cp310-cp310-manylinux1_x86_64.whl - - $ # Re-install PyTorch with CUDA 11.8. - $ pip uninstall torch -y - $ pip install torch --upgrade --index-url https://download.pytorch.org/whl/cu118 +* ROCm >= 5.7.0 .. _build_from_source: @@ -52,11 +23,11 @@ Build from source with docker You can also build and install vLLM from source: -Build a docker image from `rocm.Dockerfile`, and launch a docker container. +Build a docker image from `Dockerfile.rocm`, and launch a docker container. .. code-block:: console - $ docker build -f rocm.Dockerfile -t vllm-rocm . + $ docker build -f Dockerfile.rocm -t vllm-rocm . $ docker run -it \ --network=host \ --group-add=video \ @@ -66,17 +37,21 @@ Build a docker image from `rocm.Dockerfile`, and launch a docker container. --shm-size 8G \ --device /dev/kfd \ --device /dev/dri \ - -v :/app/hf_model \ + -v :/app/model \ vllm-rocm \ bash -If you are going to setup on new pytorch+rocm5.7 docker container, you can follow the following steps. +If you plan to install vLLM-ROCm on a local machine or start from a fresh docker image (e.g. pytorch+rocm5.7), you can follow the steps below: + +0. Install prerequisites (skip if you are already in an environment/docker with the following installed): + - `ROCm `_ and + - `Pytorch `_ -1. Install flash-attention-2-rocm +1. Install flash attention for ROCm If you are using Pytorch-2.0.1+rocm5.7. - Install flash-attention-2 (v2.0.4) following the instruction from [ROCmSoftwarePlatform/flash-attention](https://github.com/ROCmSoftwarePlatform/flash-attention/tree/flash_attention_for_rocm) + Install flash-attention-2 (v2.0.4) following the instruction from `ROCmSoftwarePlatform/flash-attention `_ If you are using Pytorch-2.1.x+rocm5.7 or Pytorch-2.2.x+rocm5.7, you don't need to apply the `hipify_python.patch`. @@ -87,10 +62,10 @@ If you are going to setup on new pytorch+rocm5.7 docker container, you can follo $ bash patch_torch211_flash_attn2.rocm.sh .. note:: - - Flash-attention-2 (v2.0.4) does not support sliding windows attention. + - ROCm's Flash-attention-2 (v2.0.4) does not support sliding windows attention. - You might need to downgrade the "ninja" version to 1.10 it is not used when compiling flash-attention-2 (e.g. `pip install ninja==1.10.2.4`) -2. Setup xformers==0.0.22.post7 without dependencies, and apply patches +2. Setup xformers==0.0.22.post7 without dependencies, and apply patches to adapt for ROCm flash attention .. code-block:: console @@ -103,3 +78,4 @@ If you are going to setup on new pytorch+rocm5.7 docker container, you can follo $ cd vllm $ pip install -U -r requirements-rocm.txt $ python setup.py install # This may take 5-10 minutes. + diff --git a/setup.py b/setup.py index 8f0b357e0b282..bcfbc49b869db 100644 --- a/setup.py +++ b/setup.py @@ -24,10 +24,10 @@ # TODO(woosuk): Should we use -O3? NVCC_FLAGS = ["-O2", "-std=c++17"] -if torch.cuda.is_available() and torch.version.hip and ROCM_HOME is not None: +if torch.version.hip and ROCM_HOME is not None: NVCC_FLAGS += ["-DUSE_ROCM"] -if torch.cuda.is_available() and torch.version.cuda and CUDA_HOME is None: +if torch.version.cuda and CUDA_HOME is None: raise RuntimeError( "Cannot find CUDA_HOME. CUDA must be available to build the package.") @@ -129,8 +129,7 @@ def get_torch_arch_list() -> Set[str]: # First, check the TORCH_CUDA_ARCH_LIST environment variable. compute_capabilities = get_torch_arch_list() -if torch.cuda.is_available( -) and torch.version.cuda and not compute_capabilities: +if torch.version.cuda and not compute_capabilities: # If TORCH_CUDA_ARCH_LIST is not defined or empty, target all available # GPUs on the current machine. device_count = torch.cuda.device_count() @@ -141,7 +140,7 @@ def get_torch_arch_list() -> Set[str]: "GPUs with compute capability below 7.0 are not supported.") compute_capabilities.add(f"{major}.{minor}") -if torch.cuda.is_available() and torch.version.cuda: +if torch.version.cuda: nvcc_cuda_version = get_nvcc_cuda_version(CUDA_HOME) if not compute_capabilities: # If no GPU is specified nor available, add all supported architectures @@ -192,7 +191,7 @@ def get_torch_arch_list() -> Set[str]: num_threads = min(os.cpu_count(), 8) NVCC_FLAGS += ["--threads", str(num_threads)] -elif torch.cuda.is_available() and torch.version.hip: +elif torch.version.hip: amd_arch = get_amdgpu_offload_arch() if amd_arch not in ROCM_SUPPORTED_ARCHS: raise RuntimeError( @@ -212,7 +211,7 @@ def get_torch_arch_list() -> Set[str]: "csrc/pybind.cpp", ] -if torch.cuda.is_available() and torch.version.cuda: +if torch.version.cuda: vllm_extension_sources.append("csrc/quantization/awq/gemm_kernels.cu") vllm_extension = CUDAExtension( @@ -246,18 +245,18 @@ def find_version(filepath: str) -> str: def get_vllm_version() -> str: version = find_version(get_path("vllm", "__init__.py")) - if torch.cuda.is_available() and torch.version.cuda: - cuda_version = str(nvcc_cuda_version) - if cuda_version != MAIN_CUDA_VERSION: - cuda_version_str = cuda_version.replace(".", "")[:3] - version += f"+cu{cuda_version_str}" - - elif torch.cuda.is_available() and torch.version.hip: + if torch.version.hip: # Get the HIP version hipcc_version = get_hipcc_rocm_version() if hipcc_version != MAIN_CUDA_VERSION: rocm_version_str = hipcc_version.replace(".", "")[:3] version += f"+rocm{rocm_version_str}" + else: + cuda_version = str(nvcc_cuda_version) + if cuda_version != MAIN_CUDA_VERSION: + cuda_version_str = cuda_version.replace(".", "")[:3] + version += f"+cu{cuda_version_str}" + return version @@ -273,10 +272,10 @@ def read_readme() -> str: def get_requirements() -> List[str]: """Get Python package dependencies from requirements.txt.""" - if torch.cuda.is_available() and torch.version.hip: + if torch.version.hip: with open(get_path("requirements-rocm.txt")) as f: requirements = f.read().strip().split("\n") - elif torch.cuda.is_available() and torch.version.cuda: + else: with open(get_path("requirements.txt")) as f: requirements = f.read().strip().split("\n") print("requirements: ", requirements) From d764f9d7a45ff82e915e32a4c9928b8ffabfc0ea Mon Sep 17 00:00:00 2001 From: miloice Date: Sat, 2 Dec 2023 04:26:09 +0000 Subject: [PATCH 42/94] Updated amd installation guide and dockerfile --- Dockerfile.rocm | 10 ++++---- .../getting_started/amd-installation.rst | 24 +++++++++---------- 2 files changed, 15 insertions(+), 19 deletions(-) diff --git a/Dockerfile.rocm b/Dockerfile.rocm index 2ef143d20f020..79a1648d12cb4 100644 --- a/Dockerfile.rocm +++ b/Dockerfile.rocm @@ -40,20 +40,18 @@ RUN mkdir libs \ && cd flash-attention \ && git submodule update --init \ && export GPU_ARCHS=$(/opt/rocm/llvm/bin/amdgpu-offload-arch) \ - # && sed -i -e "s/--offload-arch=native/--offload-arch=$(/opt/rocm/llvm/bin/amdgpu-offload-arch)/g" setup.py \ && patch /opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/utils/hipify/hipify_python.py hipify_patch.patch \ && python3 setup.py install \ && cd .. -COPY ./ /app/vllm-rocm/ +COPY ./ /app/vllm RUN cd /app \ - && cd vllm-rocm \ - && git checkout vllm-cuda-rocm \ + && cd vllm \ && python3 setup.py install \ && cd .. -# RUN python3 -m pip install --upgrade pip -# RUN python3 -m pip install --no-cache-dir ray[all] +RUN python3 -m pip install --upgrade pip +RUN python3 -m pip install --no-cache-dir ray[all] CMD ["/bin/bash"] \ No newline at end of file diff --git a/docs/source/getting_started/amd-installation.rst b/docs/source/getting_started/amd-installation.rst index 422b067b3dac4..886724c6c0cfc 100644 --- a/docs/source/getting_started/amd-installation.rst +++ b/docs/source/getting_started/amd-installation.rst @@ -44,26 +44,23 @@ Build a docker image from `Dockerfile.rocm`, and launch a docker container. If you plan to install vLLM-ROCm on a local machine or start from a fresh docker image (e.g. pytorch+rocm5.7), you can follow the steps below: 0. Install prerequisites (skip if you are already in an environment/docker with the following installed): - - `ROCm `_ and - - `Pytorch `_ -1. Install flash attention for ROCm +- `ROCm `_ +- `Pytorch `_ - If you are using Pytorch-2.0.1+rocm5.7. +1. Install `flash attention for ROCm `_ - Install flash-attention-2 (v2.0.4) following the instruction from `ROCmSoftwarePlatform/flash-attention `_ + Install ROCm's flash attention (v2.0.4) following the instructions from `ROCmSoftwarePlatform/flash-attention `_ + Note: If you are using rocm5.7 with pytorch 2.1.0 onwards, you don't need to apply the `hipify_python.patch`. You can build the ROCm flash attention directly. - If you are using Pytorch-2.1.x+rocm5.7 or Pytorch-2.2.x+rocm5.7, you don't need to apply the `hipify_python.patch`. - You can directly build the flash-attention-2. - - .. code-block:: console +.. code-block:: console - $ bash patch_torch211_flash_attn2.rocm.sh + $ bash patch_torch211_flash_attn2.rocm.sh - .. note:: - - ROCm's Flash-attention-2 (v2.0.4) does not support sliding windows attention. - - You might need to downgrade the "ninja" version to 1.10 it is not used when compiling flash-attention-2 (e.g. `pip install ninja==1.10.2.4`) +.. note:: + - ROCm's Flash-attention-2 (v2.0.4) does not support sliding windows attention. + - You might need to downgrade the "ninja" version to 1.10 it is not used when compiling flash-attention-2 (e.g. `pip install ninja==1.10.2.4`) 2. Setup xformers==0.0.22.post7 without dependencies, and apply patches to adapt for ROCm flash attention @@ -75,6 +72,7 @@ If you plan to install vLLM-ROCm on a local machine or start from a fresh docker 3. Build vllm. .. code-block:: console + $ cd vllm $ pip install -U -r requirements-rocm.txt $ python setup.py install # This may take 5-10 minutes. From e798632f318f05a8bb28a530fd7d41a432c48a2b Mon Sep 17 00:00:00 2001 From: miloice Date: Sat, 2 Dec 2023 15:12:33 +0000 Subject: [PATCH 43/94] Added num_gpus for ray init in ROCm --- vllm/engine/ray_utils.py | 11 ++++++++--- 1 file changed, 8 insertions(+), 3 deletions(-) diff --git a/vllm/engine/ray_utils.py b/vllm/engine/ray_utils.py index 48564d964a6d4..d91cfbd75c4ea 100644 --- a/vllm/engine/ray_utils.py +++ b/vllm/engine/ray_utils.py @@ -4,6 +4,8 @@ from vllm.config import ParallelConfig from vllm.logger import init_logger +import torch + logger = init_logger(__name__) try: @@ -73,9 +75,12 @@ def initialize_cluster( "Ray is not installed. Please install Ray to use distributed " "serving.") # Connect to a ray cluster. - ray.init(address=ray_address, - ignore_reinit_error=True, - num_gpus=parallel_config.world_size) + if torch.version.hip: + ray.init(address=ray_address, + ignore_reinit_error=True, + num_gpus=parallel_config.world_size) + else: + ray.init(address=ray_address, ignore_reinit_error=True) if not parallel_config.worker_use_ray: # Initialize cluster locally. From 0e8129f89936ca0c99c27559b3055669dd7c8be8 Mon Sep 17 00:00:00 2001 From: miloice Date: Sat, 2 Dec 2023 15:14:47 +0000 Subject: [PATCH 44/94] Synced torch version with vllm main in pyproject.toml --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 9b5101f98a44f..045d925cb1131 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,7 +4,7 @@ requires = [ "ninja", "packaging", "setuptools >= 49.4.0", - "torch >= 2.0.1", # downgrade to accommodate ROCm supported version + "torch >= 2.1.0", "wheel", ] build-backend = "setuptools.build_meta" From 2b3821b68c2c189f833661c472ba2e0a8f6c6433 Mon Sep 17 00:00:00 2001 From: miloice Date: Sat, 2 Dec 2023 15:27:05 +0000 Subject: [PATCH 45/94] Format code --- setup.py | 1 - vllm/engine/ray_utils.py | 4 ++-- 2 files changed, 2 insertions(+), 3 deletions(-) diff --git a/setup.py b/setup.py index bcfbc49b869db..3be4fd4fc3950 100644 --- a/setup.py +++ b/setup.py @@ -257,7 +257,6 @@ def get_vllm_version() -> str: cuda_version_str = cuda_version.replace(".", "")[:3] version += f"+cu{cuda_version_str}" - return version diff --git a/vllm/engine/ray_utils.py b/vllm/engine/ray_utils.py index d91cfbd75c4ea..e67cc33774e71 100644 --- a/vllm/engine/ray_utils.py +++ b/vllm/engine/ray_utils.py @@ -77,8 +77,8 @@ def initialize_cluster( # Connect to a ray cluster. if torch.version.hip: ray.init(address=ray_address, - ignore_reinit_error=True, - num_gpus=parallel_config.world_size) + ignore_reinit_error=True, + num_gpus=parallel_config.world_size) else: ray.init(address=ray_address, ignore_reinit_error=True) From 5793f30615a2ac7b809e91cde8da918d585d4387 Mon Sep 17 00:00:00 2001 From: miloice Date: Mon, 4 Dec 2023 07:30:08 +0000 Subject: [PATCH 46/94] Updated dockerfile.rocm and requirements-rocm.txt --- Dockerfile.rocm | 6 ++++++ requirements-rocm.txt | 1 + 2 files changed, 7 insertions(+) diff --git a/Dockerfile.rocm b/Dockerfile.rocm index 79a1648d12cb4..11a16aff5c4d3 100644 --- a/Dockerfile.rocm +++ b/Dockerfile.rocm @@ -38,6 +38,7 @@ RUN mkdir libs \ && cd libs \ && git clone https://github.com/ROCmSoftwarePlatform/flash-attention.git \ && cd flash-attention \ + && git checkout 3d2b6f5 \ && git submodule update --init \ && export GPU_ARCHS=$(/opt/rocm/llvm/bin/amdgpu-offload-arch) \ && patch /opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/utils/hipify/hipify_python.py hipify_patch.patch \ @@ -46,8 +47,13 @@ RUN mkdir libs \ COPY ./ /app/vllm +RUN python3 -m pip install --upgrade pip +RUN pip install xformers==0.0.22.post7 --no-deps + RUN cd /app \ && cd vllm \ + && pip install -U -r requirements-rocm.txt \ + && bash patch_xformers-0.0.22.post7.rocm.sh \ && python3 setup.py install \ && cd .. diff --git a/requirements-rocm.txt b/requirements-rocm.txt index 53fd3ea24d927..42ab9d14450d7 100644 --- a/requirements-rocm.txt +++ b/requirements-rocm.txt @@ -14,3 +14,4 @@ transformers >= 4.34.0 # Required for Mistral. fastapi uvicorn[standard] pydantic == 1.10.13 # Required for OpenAI server. +aioprometheus[starlette] From b172cdd71d73f874fa6de0b744769002cce039f8 Mon Sep 17 00:00:00 2001 From: miloice Date: Mon, 4 Dec 2023 09:34:31 +0000 Subject: [PATCH 47/94] Disable mistral for ROCm --- vllm/model_executor/model_loader.py | 14 ++++++++++++++ 1 file changed, 14 insertions(+) diff --git a/vllm/model_executor/model_loader.py b/vllm/model_executor/model_loader.py index fe6ff36edb882..67ecb74f2d3e9 100644 --- a/vllm/model_executor/model_loader.py +++ b/vllm/model_executor/model_loader.py @@ -38,6 +38,16 @@ "YiForCausalLM": YiForCausalLM, } +# Models to be disabled in ROCm +_ROCM_DISABLED_MODELS = [ + # ROCm's flash attention does not support sliding window attention + # in models such as Mistral + "MistralForCausalLM", +] +if torch.version.hip: + for rocm_model in _ROCM_DISABLED_MODELS: + del _MODEL_REGISTRY[rocm_model] + @contextlib.contextmanager def _set_default_torch_dtype(dtype: torch.dtype): @@ -53,6 +63,10 @@ def _get_model_architecture(config: PretrainedConfig) -> Type[nn.Module]: for arch in architectures: if arch in _MODEL_REGISTRY: return _MODEL_REGISTRY[arch] + elif arch in _ROCM_DISABLED_MODELS: + raise ValueError( + f"Model architecture {arch} is not supported by ROCm for now. \n" + f"Supported architectures {list(_MODEL_REGISTRY.keys())}") raise ValueError( f"Model architectures {architectures} are not supported for now. " f"Supported architectures: {list(_MODEL_REGISTRY.keys())}") From 9cd5b1838f0e6e2bee55f078c86f5e69af56f828 Mon Sep 17 00:00:00 2001 From: miloice Date: Mon, 4 Dec 2023 09:38:34 +0000 Subject: [PATCH 48/94] Format code --- vllm/model_executor/model_loader.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/model_executor/model_loader.py b/vllm/model_executor/model_loader.py index 67ecb74f2d3e9..4132adf89a0bc 100644 --- a/vllm/model_executor/model_loader.py +++ b/vllm/model_executor/model_loader.py @@ -40,7 +40,7 @@ # Models to be disabled in ROCm _ROCM_DISABLED_MODELS = [ - # ROCm's flash attention does not support sliding window attention + # ROCm's flash attention does not support sliding window attention # in models such as Mistral "MistralForCausalLM", ] From 66540339898f4378f3572251099559042944c20d Mon Sep 17 00:00:00 2001 From: Antoni Baum Date: Mon, 4 Dec 2023 11:18:42 -0800 Subject: [PATCH 49/94] Do not pin memory in WSL --- vllm/lora/models.py | 32 +++++++++++++++++++------------- 1 file changed, 19 insertions(+), 13 deletions(-) diff --git a/vllm/lora/models.py b/vllm/lora/models.py index e821ab14cbf62..653d3b924cc3a 100644 --- a/vllm/lora/models.py +++ b/vllm/lora/models.py @@ -12,7 +12,7 @@ from torch import nn from vllm.config import LoRAConfig -from vllm.utils import LRUCache +from vllm.utils import LRUCache, in_wsl from vllm.lora.layers import LoRALayer, LoRAMapping, from_layer, from_layer_sampler from vllm.lora.lora import LoRA @@ -138,16 +138,21 @@ def _create_dummy_lora(module_name: str, dtype: torch.dtype, device: torch.device, embeddings_tensor_dim: Optional[int] = None) -> "LoRA": - lora_a = torch.zeros([input_dim, rank], dtype=dtype, device=device) - lora_b = torch.zeros([rank, output_dim], dtype=dtype, device=device) + pin_memory = str(device) == "cpu" and not in_wsl() + lora_a = torch.zeros([input_dim, rank], + dtype=dtype, + device=device, + pin_memory=pin_memory) + lora_b = torch.zeros([rank, output_dim], + dtype=dtype, + device=device, + pin_memory=pin_memory) embeddings_tensor = torch.rand( - 10, embeddings_tensor_dim, dtype=dtype, - device=device) if embeddings_tensor_dim else None - if str(device) == "cpu": - lora_a = lora_a.pin_memory() - lora_b = lora_b.pin_memory() - if embeddings_tensor is not None: - embeddings_tensor = embeddings_tensor.pin_memory() + 10, + embeddings_tensor_dim, + dtype=dtype, + device=device, + pin_memory=pin_memory) if embeddings_tensor_dim else None return LoRA( module_name, rank=rank, @@ -191,6 +196,7 @@ def from_lora_tensors( target_embedding_padding: Optional[int] = None, ) -> "LoRAModel": """Create a LoRAModel from a dictionary of tensors.""" + pin_memory = str(device) == "cpu" and not in_wsl() loras: Dict[str, LoRA] = {} for tensor_name, tensor in tensors.items(): module_name, is_lora_a = parse_fine_tuned_lora_name(tensor_name) @@ -204,7 +210,7 @@ def from_lora_tensors( lora_embeddings_tensor = embeddings[ EMBEDDING_MODULES[embeddings_module]].to( device=device, dtype=dtype) - if device == "cpu": + if pin_memory: lora_embeddings_tensor = ( lora_embeddings_tensor.pin_memory()) loras[module_name] = LoRA(module_name, rank, lora_alpha, None, @@ -212,7 +218,7 @@ def from_lora_tensors( if is_lora_a: loras[module_name].lora_a = tensor.to(device=device, dtype=dtype).t() - if device == "cpu": + if pin_memory: loras[module_name].lora_a = loras[ module_name].lora_a.pin_memory() else: @@ -226,7 +232,7 @@ def from_lora_tensors( addition = target_embedding_padding - lora_b.shape[1] loras[module_name].lora_b = torch.nn.functional.pad( lora_b, (0, addition)) - if device == "cpu": + if pin_memory: loras[module_name].lora_b = loras[ module_name].lora_b.pin_memory() From b86f88a5a4c21730f14585a3ba3ada0844764d30 Mon Sep 17 00:00:00 2001 From: miloice Date: Tue, 5 Dec 2023 07:15:42 +0000 Subject: [PATCH 50/94] Revert to cuda kernels --- csrc/activation_kernels.cu | 7 +- csrc/attention/attention_kernels.cu | 48 +++------- csrc/attention/attention_utils.cuh | 3 +- csrc/attention/dtype_bfloat16.cuh | 19 +--- csrc/attention/dtype_float16.cuh | 90 ++----------------- csrc/cache_kernels.cu | 13 ++- csrc/cuda_compat.h | 19 ---- csrc/cuda_utils_kernels.cu | 3 - csrc/ops.h | 2 - csrc/pos_encoding_kernels.cu | 9 +- csrc/pybind.cpp | 4 - .../squeezellm/quant_cuda_kernel.cu | 76 +--------------- csrc/reduction_utils.cuh | 4 +- 13 files changed, 37 insertions(+), 260 deletions(-) delete mode 100644 csrc/cuda_compat.h diff --git a/csrc/activation_kernels.cu b/csrc/activation_kernels.cu index 1cca2c5fccc18..89d1ba2d37dd8 100644 --- a/csrc/activation_kernels.cu +++ b/csrc/activation_kernels.cu @@ -1,7 +1,6 @@ #include #include -#include "cuda_compat.h" #include "dispatch_utils.h" namespace vllm { @@ -19,8 +18,8 @@ __global__ void silu_and_mul_kernel( const int d) { const int64_t token_idx = blockIdx.x; for (int64_t idx = threadIdx.x; idx < d; idx += blockDim.x) { - const scalar_t x = VLLM_LDG(&input[token_idx * 2 * d + idx]); - const scalar_t y = VLLM_LDG(&input[token_idx * 2 * d + d + idx]); + const scalar_t x = __ldg(&input[token_idx * 2 * d + idx]); + const scalar_t y = __ldg(&input[token_idx * 2 * d + d + idx]); out[token_idx * d + idx] = silu(x) * y; } } @@ -58,7 +57,7 @@ __global__ void activation_kernel( const int d) { const int64_t token_idx = blockIdx.x; for (int64_t idx = threadIdx.x; idx < d; idx += blockDim.x) { - const scalar_t x = VLLM_LDG(&input[token_idx * d + idx]); + const scalar_t x = __ldg(&input[token_idx * d + idx]); out[token_idx * d + idx] = ACT_FN(x); } } diff --git a/csrc/attention/attention_kernels.cu b/csrc/attention/attention_kernels.cu index 19ad43b554aff..78e8d8ecd6d41 100644 --- a/csrc/attention/attention_kernels.cu +++ b/csrc/attention/attention_kernels.cu @@ -23,11 +23,7 @@ #include -#ifndef USE_ROCM #define WARP_SIZE 32 -#else -#define WARP_SIZE 64 -#endif #define MAX(a, b) ((a) > (b) ? (a) : (b)) #define MIN(a, b) ((a) < (b) ? (a) : (b)) #define DIVIDE_ROUND_UP(a, b) (((a) + (b) - 1) / (b)) @@ -44,7 +40,7 @@ inline __device__ float block_sum(float* red_smem, float sum) { // Compute the sum per warp. #pragma unroll for (int mask = WARP_SIZE / 2; mask >= 1; mask /= 2) { - sum += VLLM_SHFL_XOR_SYNC(sum, mask); + sum += __shfl_xor_sync(uint32_t(-1), sum, mask); } // Warp leaders store the data to shared memory. @@ -63,11 +59,11 @@ inline __device__ float block_sum(float* red_smem, float sum) { // Parallel reduction inside the warp. #pragma unroll for (int mask = NUM_WARPS / 2; mask >= 1; mask /= 2) { - sum += VLLM_SHFL_XOR_SYNC(sum, mask); + sum += __shfl_xor_sync(uint32_t(-1), sum, mask); } // Broadcast to other threads. - return VLLM_SHFL_SYNC(sum, 0); + return __shfl_sync(uint32_t(-1), sum, 0); } // TODO(woosuk): Merge the last two dimensions of the grid. @@ -227,7 +223,7 @@ __device__ void paged_attention_kernel( // The 0-th thread of each thread group already has its max qk value. #pragma unroll for (int mask = WARP_SIZE / 2; mask >= THREAD_GROUP_SIZE; mask /= 2) { - qk_max = fmaxf(qk_max, VLLM_SHFL_XOR_SYNC(qk_max, mask)); + qk_max = fmaxf(qk_max, __shfl_xor_sync(uint32_t(-1), qk_max, mask)); } if (lane == 0) { red_smem[warp_idx] = qk_max; @@ -239,10 +235,10 @@ __device__ void paged_attention_kernel( qk_max = lane < NUM_WARPS ? red_smem[lane] : -FLT_MAX; #pragma unroll for (int mask = NUM_WARPS / 2; mask >= 1; mask /= 2) { - qk_max = fmaxf(qk_max, VLLM_SHFL_XOR_SYNC(qk_max, mask)); + qk_max = fmaxf(qk_max, __shfl_xor_sync(uint32_t(-1), qk_max, mask)); } // Broadcast the max qk value to all threads. - qk_max = VLLM_SHFL_SYNC(qk_max, 0); + qk_max = __shfl_sync(uint32_t(-1), qk_max, 0); // Get the sum of the exp values. float exp_sum = 0.f; @@ -330,7 +326,7 @@ __device__ void paged_attention_kernel( float acc = accs[i]; #pragma unroll for (int mask = NUM_V_VECS_PER_ROW / 2; mask >= 1; mask /= 2) { - acc += VLLM_SHFL_XOR_SYNC(acc, mask); + acc += __shfl_xor_sync(uint32_t(-1), acc, mask); } accs[i] = acc; } @@ -496,7 +492,7 @@ __global__ void paged_attention_v2_reduce_kernel( // Reduce within the warp. #pragma unroll for (int mask = WARP_SIZE / 2; mask >= 1; mask /= 2) { - max_logit = fmaxf(max_logit, VLLM_SHFL_XOR_SYNC(max_logit, mask)); + max_logit = fmaxf(max_logit, __shfl_xor_sync(uint32_t(-1), max_logit, mask)); } if (lane == 0) { red_smem[warp_idx] = max_logit; @@ -506,10 +502,10 @@ __global__ void paged_attention_v2_reduce_kernel( max_logit = lane < NUM_WARPS ? red_smem[lane] : -FLT_MAX; #pragma unroll for (int mask = NUM_WARPS / 2; mask >= 1; mask /= 2) { - max_logit = fmaxf(max_logit, VLLM_SHFL_XOR_SYNC(max_logit, mask)); + max_logit = fmaxf(max_logit, __shfl_xor_sync(uint32_t(-1), max_logit, mask)); } // Broadcast the max value to all threads. - max_logit = VLLM_SHFL_SYNC(max_logit, 0); + max_logit = __shfl_sync(uint32_t(-1), max_logit, 0); // Load rescaled exp sums to shared memory. float* shared_exp_sums = reinterpret_cast(shared_mem + sizeof(float) * num_partitions); @@ -542,10 +538,9 @@ __global__ void paged_attention_v2_reduce_kernel( } // namespace vllm -#ifndef USE_ROCM #define LAUNCH_PAGED_ATTENTION_V1(HEAD_SIZE) \ cudaFuncSetAttribute( \ - (void*)vllm::paged_attention_v1_kernel, \ + vllm::paged_attention_v1_kernel, \ cudaFuncAttributeMaxDynamicSharedMemorySize, shared_mem_size); \ vllm::paged_attention_v1_kernel \ <<>>( \ @@ -562,27 +557,6 @@ __global__ void paged_attention_v2_reduce_kernel( q_stride, \ kv_block_stride, \ kv_head_stride); -#else -#define LAUNCH_PAGED_ATTENTION_V1(HEAD_SIZE) \ - hipFuncSetAttribute( \ - (void*)vllm::paged_attention_v1_kernel, \ - hipFuncAttributeMaxDynamicSharedMemorySize, shared_mem_size); \ - vllm::paged_attention_v1_kernel \ - <<>>( \ - out_ptr, \ - query_ptr, \ - key_cache_ptr, \ - value_cache_ptr, \ - head_mapping_ptr, \ - scale, \ - block_tables_ptr, \ - context_lens_ptr, \ - max_num_blocks_per_seq, \ - alibi_slopes_ptr, \ - q_stride, \ - kv_block_stride, \ - kv_head_stride); -#endif // TODO(woosuk): Tune NUM_THREADS. template< diff --git a/csrc/attention/attention_utils.cuh b/csrc/attention/attention_utils.cuh index ff64c4bd8f80c..bb7df25b14f02 100644 --- a/csrc/attention/attention_utils.cuh +++ b/csrc/attention/attention_utils.cuh @@ -17,7 +17,6 @@ */ #pragma once -#include "../cuda_compat.h" #include "attention_dtypes.h" #include @@ -40,7 +39,7 @@ inline __device__ float qk_dot_(const Vec (&q)[N], const Vec (&k)[N]) { float qk = sum(qk_vec); #pragma unroll for (int mask = THREAD_GROUP_SIZE / 2; mask >= 1; mask /= 2) { - qk += VLLM_SHFL_XOR_SYNC(qk, mask); + qk += __shfl_xor_sync(uint32_t(-1), qk, mask); } return qk; } diff --git a/csrc/attention/dtype_bfloat16.cuh b/csrc/attention/dtype_bfloat16.cuh index 31e0cee01d2e1..5786f77f7bca6 100644 --- a/csrc/attention/dtype_bfloat16.cuh +++ b/csrc/attention/dtype_bfloat16.cuh @@ -21,17 +21,8 @@ #include "attention_generic.cuh" #include "dtype_float32.cuh" -#ifndef USE_ROCM - #include - #include -#else - #include - #include - - typedef __hip_bfloat162 __nv_bfloat162; - typedef __hip_bfloat16 __nv_bfloat16; -#endif - +#include +#include #include namespace vllm { @@ -107,11 +98,7 @@ inline __device__ __nv_bfloat16 add(__nv_bfloat16 a, __nv_bfloat16 b) { #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 assert(false); #else - #ifndef USE_ROCM - return a + b; - #else - return __hadd(a, b); - #endif + return a + b; #endif } diff --git a/csrc/attention/dtype_float16.cuh b/csrc/attention/dtype_float16.cuh index 1ec2009c74c87..e67921128d52b 100644 --- a/csrc/attention/dtype_float16.cuh +++ b/csrc/attention/dtype_float16.cuh @@ -21,10 +21,6 @@ #include "attention_generic.cuh" #include "dtype_float32.cuh" -#ifdef USE_ROCM - #include -#endif - #include namespace vllm { @@ -67,114 +63,58 @@ struct FloatVec { // Utility functions for type conversions. inline __device__ uint32_t h0_h0(uint16_t a) { -#ifndef USE_ROCM uint32_t b; asm volatile("mov.b32 %0, {%1, %1};" : "=r"(b) : "h"(a)); return b; -#else - union { - uint32_t u32; - uint16_t u16[2]; - } tmp; - tmp.u16[0] = a; - tmp.u16[1] = a; - return tmp.u32; -#endif } inline __device__ float half_to_float(uint16_t h) { -#ifndef USE_ROCM float f; asm volatile("cvt.f32.f16 %0, %1;\n" : "=f"(f) : "h"(h)); return f; -#else - float f; - asm volatile("v_cvt_f32_f16 %0, %1;" : "=v"(f) : "v"(h)); - return f; -#endif } inline __device__ float2 half2_to_float2(uint32_t v) { -#ifndef USE_ROCM uint16_t lo, hi; asm volatile("mov.b32 {%0, %1}, %2;\n" : "=h"(lo), "=h"(hi) : "r"(v)); return make_float2(half_to_float(lo), half_to_float(hi)); -#else - union { - uint32_t u32; - uint16_t u16[2]; - } tmp; - tmp.u32 = v; - float2 ret; - ret.x = half_to_float(tmp.u16[0]); - ret.y = half_to_float(tmp.u16[1]); - return ret; -#endif } inline __device__ uint16_t float_to_half(float f) { -#ifndef USE_ROCM union { uint32_t u32; uint16_t u16[2]; } tmp; asm volatile("cvt.rn.f16.f32 %0, %1;\n" : "=h"(tmp.u16[0]) : "f"(f)); return tmp.u16[0]; -#else - uint16_t ret; - asm volatile("v_cvt_f16_f32 %0, %1;\n" : "=v"(ret) : "v"(f)); - return ret; -#endif } inline __device__ uint32_t float2_to_half2(float2 f) { -#ifndef USE_ROCM union { uint32_t u32; uint16_t u16[2]; } tmp; - #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 - asm volatile("cvt.rn.f16x2.f32 %0, %1, %2;\n" : "=r"(tmp.u32) : "f"(f.y), "f"(f.x)); - #else - asm volatile("cvt.rn.f16.f32 %0, %1;\n" : "=h"(tmp.u16[0]) : "f"(f.x)); - asm volatile("cvt.rn.f16.f32 %0, %1;\n" : "=h"(tmp.u16[1]) : "f"(f.y)); - #endif - return tmp.u32; +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 + asm volatile("cvt.rn.f16x2.f32 %0, %1, %2;\n" : "=r"(tmp.u32) : "f"(f.y), "f"(f.x)); #else - union { - uint32_t u32; - uint16_t u16[2]; - } tmp; - tmp.u16[0] = float_to_half(f.x); - tmp.u16[1] = float_to_half(f.y); - return tmp.u32; + asm volatile("cvt.rn.f16.f32 %0, %1;\n" : "=h"(tmp.u16[0]) : "f"(f.x)); + asm volatile("cvt.rn.f16.f32 %0, %1;\n" : "=h"(tmp.u16[1]) : "f"(f.y)); #endif + return tmp.u32; } // Vector addition. inline __device__ uint16_t add(uint16_t a, uint16_t b) { -#ifndef USE_ROCM uint16_t c; asm volatile("add.f16 %0, %1, %2;\n" : "=h"(c) : "h"(a), "h"(b)); return c; -#else - uint16_t c; - asm volatile("v_add_f16 %0, %1, %2;\n" : "=v"(c) : "v"(a), "v"(b)); - return c; -#endif } inline __device__ uint32_t add(uint32_t a, uint32_t b) { -#ifndef USE_ROCM uint32_t c; asm volatile("add.f16x2 %0, %1, %2;\n" : "=r"(c) : "r"(a), "r"(b)); return c; -#else - uint32_t c; - asm volatile("v_pk_add_f16 %0, %1, %2;\n" : "=v"(c) : "v"(a), "v"(b)); - return c; -#endif } inline __device__ uint2 add(uint2 a, uint2 b) { @@ -217,28 +157,16 @@ inline __device__ Float8_ add(uint4 a, Float8_ fb) { // Vector multiplication. template<> inline __device__ uint16_t mul(uint16_t a, uint16_t b) { -#ifndef USE_ROCM uint16_t c; asm volatile("mul.f16 %0, %1, %2;\n" : "=h"(c) : "h"(a), "h"(b)); return c; -#else - uint16_t c; - asm volatile("v_mul_f16 %0, %1, %2;\n" : "=v"(c) : "v"(a), "v"(b)); - return c; -#endif } template<> inline __device__ uint32_t mul(uint32_t a, uint32_t b) { -#ifndef USE_ROCM uint32_t c; asm volatile("mul.f16x2 %0, %1, %2;\n" : "=r"(c) : "r"(a), "r"(b)); return c; -#else - uint32_t c; - asm volatile("v_pk_mul_f16 %0, %1, %2;\n" : "=v"(c) : "v"(a), "v"(b)); - return c; -#endif } template<> @@ -343,15 +271,9 @@ inline __device__ Float8_ mul(uint16_t a, uint4 b) { // Vector fused multiply-add. inline __device__ uint32_t fma(uint32_t a, uint32_t b, uint32_t c) { -#ifndef USE_ROCM uint32_t d; asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(d) : "r"(a), "r"(b), "r"(c)); return d; -#else - uint32_t d; - asm volatile("v_pk_fma_f16 %0, %1, %2, %3;\n" : "=v"(d) : "v"(a), "v"(b), "v"(c)); - return d; -#endif } inline __device__ uint32_t fma(uint16_t a, uint32_t b, uint32_t c) { @@ -519,4 +441,4 @@ inline __device__ void zero(uint16_t& dst) { dst = uint16_t(0); } -} // namespace vllm \ No newline at end of file +} // namespace vllm diff --git a/csrc/cache_kernels.cu b/csrc/cache_kernels.cu index 59bacffdf4642..3ad52b1681c0c 100644 --- a/csrc/cache_kernels.cu +++ b/csrc/cache_kernels.cu @@ -1,7 +1,6 @@ #include #include -#include "cuda_compat.h" #include "dispatch_utils.h" #include @@ -29,8 +28,8 @@ void swap_blocks( TORCH_CHECK(false, "Invalid device combination"); } - char *src_ptr = static_cast(src.data_ptr()); - char *dst_ptr = static_cast(dst.data_ptr()); + void *src_ptr = src.data_ptr(); + void *dst_ptr = dst.data_ptr(); const int64_t block_size_in_bytes = src.element_size() * src[0].numel(); const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); @@ -268,8 +267,8 @@ __global__ void gather_cached_kv_kernel( + head_offset * block_size + block_offset; - key[tgt_key_idx] = VLLM_LDG(&key_cache[src_key_idx]); - value[tgt_value_idx] = VLLM_LDG(&value_cache[src_value_idx]); + key[tgt_key_idx] = __ldg(&key_cache[src_key_idx]); + value[tgt_value_idx] = __ldg(&value_cache[src_value_idx]); } } @@ -334,8 +333,8 @@ __global__ void gather_cached_kv_kernel_optimized( src_key_indices[j] = src_key_idx; src_value_indices[j] = src_value_idx; - keys_to_store[j] = VLLM_LDG(&key_cache[src_key_idx]); - values_to_store[j] = VLLM_LDG(&value_cache[src_value_idx]); + keys_to_store[j] = __ldg(&key_cache[src_key_idx]); + values_to_store[j] = __ldg(&value_cache[src_value_idx]); } #pragma unroll diff --git a/csrc/cuda_compat.h b/csrc/cuda_compat.h deleted file mode 100644 index fcafc7e16687c..0000000000000 --- a/csrc/cuda_compat.h +++ /dev/null @@ -1,19 +0,0 @@ -#pragma once - -#ifndef USE_ROCM - #define VLLM_LDG(arg) __ldg(arg) -#else - #define VLLM_LDG(arg) *(arg) -#endif - -#ifndef USE_ROCM - #define VLLM_SHFL_XOR_SYNC(var, lane_mask) __shfl_xor_sync(uint32_t(-1), var, lane_mask) -#else - #define VLLM_SHFL_XOR_SYNC(var, lane_mask) __shfl_xor(var, lane_mask) -#endif - -#ifndef USE_ROCM - #define VLLM_SHFL_SYNC(var, src_lane) __shfl_sync(uint32_t(-1), var, src_lane); -#else - #define VLLM_SHFL_SYNC(var, src_lane) __shfl(var, src_lane) -#endif \ No newline at end of file diff --git a/csrc/cuda_utils_kernels.cu b/csrc/cuda_utils_kernels.cu index 6c844a7f6c6ed..f1c30fe7ea99e 100644 --- a/csrc/cuda_utils_kernels.cu +++ b/csrc/cuda_utils_kernels.cu @@ -1,6 +1,3 @@ -#ifdef USE_ROCM - #include -#endif int get_device_attribute( int attribute, int device_id) diff --git a/csrc/ops.h b/csrc/ops.h index e12c34f0aafa5..cfb18fbefd7a9 100644 --- a/csrc/ops.h +++ b/csrc/ops.h @@ -61,14 +61,12 @@ void gelu_fast( torch::Tensor& out, torch::Tensor& input); -#ifndef USE_ROCM torch::Tensor awq_gemm( torch::Tensor _in_feats, torch::Tensor _kernel, torch::Tensor _scaling_factors, torch::Tensor _zeros, int split_k_iters); -#endif void squeezellm_gemm( torch::Tensor vec, diff --git a/csrc/pos_encoding_kernels.cu b/csrc/pos_encoding_kernels.cu index e1dc711778ffb..0a5ec95f8c0d4 100644 --- a/csrc/pos_encoding_kernels.cu +++ b/csrc/pos_encoding_kernels.cu @@ -1,7 +1,6 @@ #include #include -#include "cuda_compat.h" #include "dispatch_utils.h" namespace vllm { @@ -20,14 +19,14 @@ inline __device__ void apply_rotary_embedding( // GPT-NeoX style rotary embedding. x_index = rot_offset; y_index = embed_dim + rot_offset; - cos = VLLM_LDG(cos_ptr + x_index); - sin = VLLM_LDG(sin_ptr + x_index); + cos = __ldg(cos_ptr + x_index); + sin = __ldg(sin_ptr + x_index); } else { // GPT-J style rotary embedding. x_index = 2 * rot_offset; y_index = 2 * rot_offset + 1; - cos = VLLM_LDG(cos_ptr + x_index / 2); - sin = VLLM_LDG(sin_ptr + x_index / 2); + cos = __ldg(cos_ptr + x_index / 2); + sin = __ldg(sin_ptr + x_index / 2); } const scalar_t x = arr[x_index]; diff --git a/csrc/pybind.cpp b/csrc/pybind.cpp index b0120e9689773..9e31429690021 100644 --- a/csrc/pybind.cpp +++ b/csrc/pybind.cpp @@ -48,12 +48,8 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { &rotary_embedding, "Apply GPT-NeoX or GPT-J style rotary embedding to query and key"); -#ifndef USE_ROCM // Quantization ops ops.def("awq_gemm", &awq_gemm, "Quantized GEMM for AWQ"); -#endif - - ops.def("squeezellm_gemm", &squeezellm_gemm, "Quantized GEMM for SqueezeLLM"); // Cache ops diff --git a/csrc/quantization/squeezellm/quant_cuda_kernel.cu b/csrc/quantization/squeezellm/quant_cuda_kernel.cu index c2f77da15977a..1392b877397be 100644 --- a/csrc/quantization/squeezellm/quant_cuda_kernel.cu +++ b/csrc/quantization/squeezellm/quant_cuda_kernel.cu @@ -20,17 +20,9 @@ __device__ inline unsigned int as_unsigned(int i) { // 4-bit matvec kernel (LUT-based) __global__ void NUQ4MatMulKernel( -#ifndef USE_ROCM const half2* __restrict__ vec, -#else - const __half2* __restrict__ vec, -#endif const int* __restrict__ mat, -#ifndef USE_ROCM half2* __restrict__ mul, -#else - float2* __restrict__ mul, -#endif const __half* __restrict__ lookup_table, int height, int width, @@ -43,11 +35,7 @@ __global__ void NUQ4MatMulKernel( int row = BLOCKHEIGHT4 * blockIdx.x; int col = BLOCKWIDTH * blockIdx.y + threadIdx.x; -#ifndef USE_ROCM __shared__ half2 blockvec[blockwidth2]; -#else - __shared__ __half2 blockvec[blockwidth2]; -#endif __shared__ __half deq2[16][BLOCKWIDTH]; int off = threadIdx.x; @@ -58,13 +46,8 @@ __global__ void NUQ4MatMulKernel( } __half res; -#ifndef USE_ROCM half2 res2; half2 tmp2; -#else - __half2 res2; - __half2 tmp2; -#endif int i; int k; @@ -85,96 +68,48 @@ __global__ void NUQ4MatMulKernel( while (k < blockwidth2) { tmp1 = as_unsigned(mat[i]); -#ifndef USE_ROCM res2 = {}; tmp2 = {}; -#else - res2.x = __half_as_ushort(__float2half(0)); - res2.y = __half_as_ushort(__float2half(0)); - tmp2.x = __half_as_ushort(__float2half(0)); - tmp2.y = __half_as_ushort(__float2half(0)); -#endif lut_index1 = tmp1 & 0xF; lut_index2 = (tmp1 >> 4) & 0xF; -#ifndef USE_ROCM tmp2.x = deq2[lut_index1][off]; tmp2.y = deq2[lut_index2][off]; -#else - tmp2.x = __half_as_ushort(deq2[lut_index1][off]); - tmp2.y = __half_as_ushort(deq2[lut_index2][off]); -#endif res2 = __hfma2(tmp2, blockvec[k + 0], res2); lut_index1 = (tmp1 >> 8) & 0xF; lut_index2 = (tmp1 >> 12) & 0xF; -#ifndef USE_ROCM tmp2.x = deq2[lut_index1][off]; tmp2.y = deq2[lut_index2][off]; -#else - tmp2.x = __half_as_ushort(deq2[lut_index1][off]); - tmp2.y = __half_as_ushort(deq2[lut_index2][off]); -#endif res2 = __hfma2(tmp2, blockvec[k + 1], res2); lut_index1 = (tmp1 >> 16) & 0xF; lut_index2 = (tmp1 >> 20) & 0xF; -#ifndef USE_ROCM tmp2.x = deq2[lut_index1][off]; tmp2.y = deq2[lut_index2][off]; -#else - tmp2.x = __half_as_ushort(deq2[lut_index1][off]); - tmp2.y = __half_as_ushort(deq2[lut_index2][off]); -#endif res2 = __hfma2(tmp2, blockvec[k + 2], res2); lut_index1 = (tmp1 >> 24) & 0xF; lut_index2 = (tmp1 >> 28) & 0xF; -#ifndef USE_ROCM tmp2.x = deq2[lut_index1][off]; tmp2.y = deq2[lut_index2][off]; -#else - tmp2.x = __half_as_ushort(deq2[lut_index1][off]); - tmp2.y = __half_as_ushort(deq2[lut_index2][off]); -#endif res2 = __hfma2(tmp2, blockvec[k + 3], res2); -#ifndef USE_ROCM res = __hadd(__hadd(res2.x, res2.y), res); -#else - res = __hadd(__hadd(__ushort_as_half(res2.x), __ushort_as_half(res2.y)), res); -#endif i += width; k += 4; } // col%2 -> only set one of the two values -#ifndef USE_ROCM half2 res3 = {}; if (col % 2 == 0) { res3.x = res; } else { res3.y = res; } -#else - __half2 res3; - res3.x = __half_as_ushort(__float2half(0)); - res3.y = __half_as_ushort(__float2half(0)); - if (col % 2 == 0) { - res3.x = __half_as_ushort(res); - } else { - res3.y = __half_as_ushort(res); - } -#endif -#ifndef USE_ROCM atomicAdd(&mul[b * width / 2 + col / 2], res3); -#else - int tmp_addr = b * width / 2 + col / 2; - atomicAdd(&(mul[tmp_addr].x), __half2float(__ushort_as_half(res3.x))); - atomicAdd(&(mul[tmp_addr].y), __half2float(__ushort_as_half(res3.y))); -#endif } } @@ -201,22 +136,13 @@ void squeezellm_gemm( dim3 threads(BLOCKWIDTH); vllm::squeezellm::NUQ4MatMulKernel<<>>( -#ifndef USE_ROCM (half2*) vec.data(), -#else - (__half2*) vec.data_ptr(), -#endif mat.data_ptr(), -#ifndef USE_ROCM (half2*) mul.data(), (__half*) lookup_table.data(), -#else - (float2*) mul.data_ptr(), - (__half*) lookup_table.data_ptr(), -#endif height, width, batch, vec_height ); } #undef BLOCKWIDTH -#undef BLOCKHEIGHT4 \ No newline at end of file +#undef BLOCKHEIGHT4 diff --git a/csrc/reduction_utils.cuh b/csrc/reduction_utils.cuh index 59ac565c585e8..bc35aa0424b56 100644 --- a/csrc/reduction_utils.cuh +++ b/csrc/reduction_utils.cuh @@ -16,14 +16,14 @@ * limitations under the License. */ #pragma once -#include "cuda_compat.h" + namespace vllm { template __inline__ __device__ T warpReduceSum(T val) { #pragma unroll for (int mask = 16; mask > 0; mask >>= 1) - val += VLLM_SHFL_XOR_SYNC(val, mask); + val += __shfl_xor_sync(0xffffffff, val, mask, 32); return val; } From c4aa2af80fa90bf5bd3ad5a4ae7e00abd5effb8f Mon Sep 17 00:00:00 2001 From: miloice Date: Tue, 5 Dec 2023 07:33:24 +0000 Subject: [PATCH 51/94] Port latest kernels to ROCm --- csrc/attention/attention_kernels.cu | 26 +++++++ csrc/attention/dtype_bfloat16.cuh | 8 +- csrc/attention/dtype_float16.cuh | 64 ++++++++++------ csrc/cuda_compat.h | 2 +- csrc/cuda_utils_kernels.cu | 1 - csrc/ops.h | 2 + csrc/pybind.cpp | 4 + .../squeezellm/quant_cuda_kernel.cu | 76 ++++++++++++++++++- csrc/reduction_utils.cuh | 2 - setup.py | 17 ----- 10 files changed, 151 insertions(+), 51 deletions(-) diff --git a/csrc/attention/attention_kernels.cu b/csrc/attention/attention_kernels.cu index 1b1ddee234160..19ad43b554aff 100644 --- a/csrc/attention/attention_kernels.cu +++ b/csrc/attention/attention_kernels.cu @@ -23,7 +23,11 @@ #include +#ifndef USE_ROCM #define WARP_SIZE 32 +#else +#define WARP_SIZE 64 +#endif #define MAX(a, b) ((a) > (b) ? (a) : (b)) #define MIN(a, b) ((a) < (b) ? (a) : (b)) #define DIVIDE_ROUND_UP(a, b) (((a) + (b) - 1) / (b)) @@ -538,6 +542,7 @@ __global__ void paged_attention_v2_reduce_kernel( } // namespace vllm +#ifndef USE_ROCM #define LAUNCH_PAGED_ATTENTION_V1(HEAD_SIZE) \ cudaFuncSetAttribute( \ (void*)vllm::paged_attention_v1_kernel, \ @@ -557,6 +562,27 @@ __global__ void paged_attention_v2_reduce_kernel( q_stride, \ kv_block_stride, \ kv_head_stride); +#else +#define LAUNCH_PAGED_ATTENTION_V1(HEAD_SIZE) \ + hipFuncSetAttribute( \ + (void*)vllm::paged_attention_v1_kernel, \ + hipFuncAttributeMaxDynamicSharedMemorySize, shared_mem_size); \ + vllm::paged_attention_v1_kernel \ + <<>>( \ + out_ptr, \ + query_ptr, \ + key_cache_ptr, \ + value_cache_ptr, \ + head_mapping_ptr, \ + scale, \ + block_tables_ptr, \ + context_lens_ptr, \ + max_num_blocks_per_seq, \ + alibi_slopes_ptr, \ + q_stride, \ + kv_block_stride, \ + kv_head_stride); +#endif // TODO(woosuk): Tune NUM_THREADS. template< diff --git a/csrc/attention/dtype_bfloat16.cuh b/csrc/attention/dtype_bfloat16.cuh index 7f2b29de0d932..31e0cee01d2e1 100644 --- a/csrc/attention/dtype_bfloat16.cuh +++ b/csrc/attention/dtype_bfloat16.cuh @@ -110,13 +110,7 @@ inline __device__ __nv_bfloat16 add(__nv_bfloat16 a, __nv_bfloat16 b) { #ifndef USE_ROCM return a + b; #else - // See https://github.com/RadeonOpenCompute/ROCm/issues/2534 - hip_bfloat16 A, B; - __hip_bfloat16 c; - A.data = a.data; - B.data = b.data; - c.data = (A + B).data; - return c; + return __hadd(a, b); #endif #endif } diff --git a/csrc/attention/dtype_float16.cuh b/csrc/attention/dtype_float16.cuh index b9c9275aae3f6..1ec2009c74c87 100644 --- a/csrc/attention/dtype_float16.cuh +++ b/csrc/attention/dtype_float16.cuh @@ -67,9 +67,10 @@ struct FloatVec { // Utility functions for type conversions. inline __device__ uint32_t h0_h0(uint16_t a) { - uint32_t b; #ifndef USE_ROCM + uint32_t b; asm volatile("mov.b32 %0, {%1, %1};" : "=r"(b) : "h"(a)); + return b; #else union { uint32_t u32; @@ -77,19 +78,20 @@ inline __device__ uint32_t h0_h0(uint16_t a) { } tmp; tmp.u16[0] = a; tmp.u16[1] = a; - b = tmp.u32; + return tmp.u32; #endif - return b; } inline __device__ float half_to_float(uint16_t h) { - float f; #ifndef USE_ROCM + float f; asm volatile("cvt.f32.f16 %0, %1;\n" : "=f"(f) : "h"(h)); + return f; #else + float f; asm volatile("v_cvt_f32_f16 %0, %1;" : "=v"(f) : "v"(h)); -#endif return f; +#endif } inline __device__ float2 half2_to_float2(uint32_t v) { @@ -111,56 +113,68 @@ inline __device__ float2 half2_to_float2(uint32_t v) { } inline __device__ uint16_t float_to_half(float f) { +#ifndef USE_ROCM union { uint32_t u32; uint16_t u16[2]; } tmp; -#ifndef USE_ROCM asm volatile("cvt.rn.f16.f32 %0, %1;\n" : "=h"(tmp.u16[0]) : "f"(f)); + return tmp.u16[0]; #else - asm volatile("v_cvt_f16_f32 %0, %1;\n" : "=v"(tmp.u32) : "v"(f)); + uint16_t ret; + asm volatile("v_cvt_f16_f32 %0, %1;\n" : "=v"(ret) : "v"(f)); + return ret; #endif - return tmp.u16[0]; } inline __device__ uint32_t float2_to_half2(float2 f) { +#ifndef USE_ROCM union { uint32_t u32; uint16_t u16[2]; } tmp; -#ifndef USE_ROCM + #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 asm volatile("cvt.rn.f16x2.f32 %0, %1, %2;\n" : "=r"(tmp.u32) : "f"(f.y), "f"(f.x)); #else asm volatile("cvt.rn.f16.f32 %0, %1;\n" : "=h"(tmp.u16[0]) : "f"(f.x)); asm volatile("cvt.rn.f16.f32 %0, %1;\n" : "=h"(tmp.u16[1]) : "f"(f.y)); #endif + return tmp.u32; #else + union { + uint32_t u32; + uint16_t u16[2]; + } tmp; tmp.u16[0] = float_to_half(f.x); tmp.u16[1] = float_to_half(f.y); -#endif return tmp.u32; +#endif } // Vector addition. inline __device__ uint16_t add(uint16_t a, uint16_t b) { - uint16_t c; #ifndef USE_ROCM + uint16_t c; asm volatile("add.f16 %0, %1, %2;\n" : "=h"(c) : "h"(a), "h"(b)); + return c; #else + uint16_t c; asm volatile("v_add_f16 %0, %1, %2;\n" : "=v"(c) : "v"(a), "v"(b)); -#endif return c; +#endif } inline __device__ uint32_t add(uint32_t a, uint32_t b) { - uint32_t c; #ifndef USE_ROCM + uint32_t c; asm volatile("add.f16x2 %0, %1, %2;\n" : "=r"(c) : "r"(a), "r"(b)); + return c; #else + uint32_t c; asm volatile("v_pk_add_f16 %0, %1, %2;\n" : "=v"(c) : "v"(a), "v"(b)); -#endif return c; +#endif } inline __device__ uint2 add(uint2 a, uint2 b) { @@ -203,24 +217,28 @@ inline __device__ Float8_ add(uint4 a, Float8_ fb) { // Vector multiplication. template<> inline __device__ uint16_t mul(uint16_t a, uint16_t b) { - uint16_t c; #ifndef USE_ROCM + uint16_t c; asm volatile("mul.f16 %0, %1, %2;\n" : "=h"(c) : "h"(a), "h"(b)); + return c; #else + uint16_t c; asm volatile("v_mul_f16 %0, %1, %2;\n" : "=v"(c) : "v"(a), "v"(b)); -#endif return c; +#endif } template<> inline __device__ uint32_t mul(uint32_t a, uint32_t b) { - uint32_t c; #ifndef USE_ROCM + uint32_t c; asm volatile("mul.f16x2 %0, %1, %2;\n" : "=r"(c) : "r"(a), "r"(b)); + return c; #else + uint32_t c; asm volatile("v_pk_mul_f16 %0, %1, %2;\n" : "=v"(c) : "v"(a), "v"(b)); -#endif return c; +#endif } template<> @@ -325,13 +343,15 @@ inline __device__ Float8_ mul(uint16_t a, uint4 b) { // Vector fused multiply-add. inline __device__ uint32_t fma(uint32_t a, uint32_t b, uint32_t c) { - uint32_t d; #ifndef USE_ROCM + uint32_t d; asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(d) : "r"(a), "r"(b), "r"(c)); + return d; #else - asm volatile("v_pk_mul_f16 %0, %1, %2;\n" : "=v"(c) : "v"(a), "v"(b)); -#endif + uint32_t d; + asm volatile("v_pk_fma_f16 %0, %1, %2, %3;\n" : "=v"(d) : "v"(a), "v"(b), "v"(c)); return d; +#endif } inline __device__ uint32_t fma(uint16_t a, uint32_t b, uint32_t c) { @@ -499,4 +519,4 @@ inline __device__ void zero(uint16_t& dst) { dst = uint16_t(0); } -} // namespace vllm +} // namespace vllm \ No newline at end of file diff --git a/csrc/cuda_compat.h b/csrc/cuda_compat.h index 8991462a862e8..fcafc7e16687c 100644 --- a/csrc/cuda_compat.h +++ b/csrc/cuda_compat.h @@ -16,4 +16,4 @@ #define VLLM_SHFL_SYNC(var, src_lane) __shfl_sync(uint32_t(-1), var, src_lane); #else #define VLLM_SHFL_SYNC(var, src_lane) __shfl(var, src_lane) -#endif +#endif \ No newline at end of file diff --git a/csrc/cuda_utils_kernels.cu b/csrc/cuda_utils_kernels.cu index 2439f5922a3f3..6c844a7f6c6ed 100644 --- a/csrc/cuda_utils_kernels.cu +++ b/csrc/cuda_utils_kernels.cu @@ -1,7 +1,6 @@ #ifdef USE_ROCM #include #endif - int get_device_attribute( int attribute, int device_id) diff --git a/csrc/ops.h b/csrc/ops.h index cfb18fbefd7a9..e12c34f0aafa5 100644 --- a/csrc/ops.h +++ b/csrc/ops.h @@ -61,12 +61,14 @@ void gelu_fast( torch::Tensor& out, torch::Tensor& input); +#ifndef USE_ROCM torch::Tensor awq_gemm( torch::Tensor _in_feats, torch::Tensor _kernel, torch::Tensor _scaling_factors, torch::Tensor _zeros, int split_k_iters); +#endif void squeezellm_gemm( torch::Tensor vec, diff --git a/csrc/pybind.cpp b/csrc/pybind.cpp index 9e31429690021..b0120e9689773 100644 --- a/csrc/pybind.cpp +++ b/csrc/pybind.cpp @@ -48,8 +48,12 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { &rotary_embedding, "Apply GPT-NeoX or GPT-J style rotary embedding to query and key"); +#ifndef USE_ROCM // Quantization ops ops.def("awq_gemm", &awq_gemm, "Quantized GEMM for AWQ"); +#endif + + ops.def("squeezellm_gemm", &squeezellm_gemm, "Quantized GEMM for SqueezeLLM"); // Cache ops diff --git a/csrc/quantization/squeezellm/quant_cuda_kernel.cu b/csrc/quantization/squeezellm/quant_cuda_kernel.cu index 1392b877397be..c2f77da15977a 100644 --- a/csrc/quantization/squeezellm/quant_cuda_kernel.cu +++ b/csrc/quantization/squeezellm/quant_cuda_kernel.cu @@ -20,9 +20,17 @@ __device__ inline unsigned int as_unsigned(int i) { // 4-bit matvec kernel (LUT-based) __global__ void NUQ4MatMulKernel( +#ifndef USE_ROCM const half2* __restrict__ vec, +#else + const __half2* __restrict__ vec, +#endif const int* __restrict__ mat, +#ifndef USE_ROCM half2* __restrict__ mul, +#else + float2* __restrict__ mul, +#endif const __half* __restrict__ lookup_table, int height, int width, @@ -35,7 +43,11 @@ __global__ void NUQ4MatMulKernel( int row = BLOCKHEIGHT4 * blockIdx.x; int col = BLOCKWIDTH * blockIdx.y + threadIdx.x; +#ifndef USE_ROCM __shared__ half2 blockvec[blockwidth2]; +#else + __shared__ __half2 blockvec[blockwidth2]; +#endif __shared__ __half deq2[16][BLOCKWIDTH]; int off = threadIdx.x; @@ -46,8 +58,13 @@ __global__ void NUQ4MatMulKernel( } __half res; +#ifndef USE_ROCM half2 res2; half2 tmp2; +#else + __half2 res2; + __half2 tmp2; +#endif int i; int k; @@ -68,48 +85,96 @@ __global__ void NUQ4MatMulKernel( while (k < blockwidth2) { tmp1 = as_unsigned(mat[i]); +#ifndef USE_ROCM res2 = {}; tmp2 = {}; +#else + res2.x = __half_as_ushort(__float2half(0)); + res2.y = __half_as_ushort(__float2half(0)); + tmp2.x = __half_as_ushort(__float2half(0)); + tmp2.y = __half_as_ushort(__float2half(0)); +#endif lut_index1 = tmp1 & 0xF; lut_index2 = (tmp1 >> 4) & 0xF; +#ifndef USE_ROCM tmp2.x = deq2[lut_index1][off]; tmp2.y = deq2[lut_index2][off]; +#else + tmp2.x = __half_as_ushort(deq2[lut_index1][off]); + tmp2.y = __half_as_ushort(deq2[lut_index2][off]); +#endif res2 = __hfma2(tmp2, blockvec[k + 0], res2); lut_index1 = (tmp1 >> 8) & 0xF; lut_index2 = (tmp1 >> 12) & 0xF; +#ifndef USE_ROCM tmp2.x = deq2[lut_index1][off]; tmp2.y = deq2[lut_index2][off]; +#else + tmp2.x = __half_as_ushort(deq2[lut_index1][off]); + tmp2.y = __half_as_ushort(deq2[lut_index2][off]); +#endif res2 = __hfma2(tmp2, blockvec[k + 1], res2); lut_index1 = (tmp1 >> 16) & 0xF; lut_index2 = (tmp1 >> 20) & 0xF; +#ifndef USE_ROCM tmp2.x = deq2[lut_index1][off]; tmp2.y = deq2[lut_index2][off]; +#else + tmp2.x = __half_as_ushort(deq2[lut_index1][off]); + tmp2.y = __half_as_ushort(deq2[lut_index2][off]); +#endif res2 = __hfma2(tmp2, blockvec[k + 2], res2); lut_index1 = (tmp1 >> 24) & 0xF; lut_index2 = (tmp1 >> 28) & 0xF; +#ifndef USE_ROCM tmp2.x = deq2[lut_index1][off]; tmp2.y = deq2[lut_index2][off]; +#else + tmp2.x = __half_as_ushort(deq2[lut_index1][off]); + tmp2.y = __half_as_ushort(deq2[lut_index2][off]); +#endif res2 = __hfma2(tmp2, blockvec[k + 3], res2); +#ifndef USE_ROCM res = __hadd(__hadd(res2.x, res2.y), res); +#else + res = __hadd(__hadd(__ushort_as_half(res2.x), __ushort_as_half(res2.y)), res); +#endif i += width; k += 4; } // col%2 -> only set one of the two values +#ifndef USE_ROCM half2 res3 = {}; if (col % 2 == 0) { res3.x = res; } else { res3.y = res; } +#else + __half2 res3; + res3.x = __half_as_ushort(__float2half(0)); + res3.y = __half_as_ushort(__float2half(0)); + if (col % 2 == 0) { + res3.x = __half_as_ushort(res); + } else { + res3.y = __half_as_ushort(res); + } +#endif +#ifndef USE_ROCM atomicAdd(&mul[b * width / 2 + col / 2], res3); +#else + int tmp_addr = b * width / 2 + col / 2; + atomicAdd(&(mul[tmp_addr].x), __half2float(__ushort_as_half(res3.x))); + atomicAdd(&(mul[tmp_addr].y), __half2float(__ushort_as_half(res3.y))); +#endif } } @@ -136,13 +201,22 @@ void squeezellm_gemm( dim3 threads(BLOCKWIDTH); vllm::squeezellm::NUQ4MatMulKernel<<>>( +#ifndef USE_ROCM (half2*) vec.data(), +#else + (__half2*) vec.data_ptr(), +#endif mat.data_ptr(), +#ifndef USE_ROCM (half2*) mul.data(), (__half*) lookup_table.data(), +#else + (float2*) mul.data_ptr(), + (__half*) lookup_table.data_ptr(), +#endif height, width, batch, vec_height ); } #undef BLOCKWIDTH -#undef BLOCKHEIGHT4 +#undef BLOCKHEIGHT4 \ No newline at end of file diff --git a/csrc/reduction_utils.cuh b/csrc/reduction_utils.cuh index b95ccef162072..59ac565c585e8 100644 --- a/csrc/reduction_utils.cuh +++ b/csrc/reduction_utils.cuh @@ -16,9 +16,7 @@ * limitations under the License. */ #pragma once - #include "cuda_compat.h" - namespace vllm { template diff --git a/setup.py b/setup.py index dfc8bf4c8459e..3be4fd4fc3950 100644 --- a/setup.py +++ b/setup.py @@ -75,10 +75,6 @@ def get_hipcc_rocm_version(): print("Could not find HIP version in the output") return None -if not torch.version.hip: - if CUDA_HOME is None: - raise RuntimeError( - "Cannot find CUDA_HOME. CUDA must be available to build the package.") def get_nvcc_cuda_version(cuda_dir: str) -> Version: """Get the CUDA version from nvcc. @@ -130,19 +126,6 @@ def get_torch_arch_list() -> Set[str]: stacklevel=2) return arch_list -def get_cuda_compute_capabilities(nvcc_cuda_version): - # First, check the TORCH_CUDA_ARCH_LIST environment variable. - compute_capabilities = get_torch_arch_list() - if not compute_capabilities: - # If TORCH_CUDA_ARCH_LIST is not defined or empty, target all available - # GPUs on the current machine. - device_count = torch.cuda.device_count() - for i in range(device_count): - major, minor = torch.cuda.get_device_capability(i) - if major < 7: - raise RuntimeError( - "GPUs with compute capability below 7.0 are not supported.") - compute_capabilities.add(f"{major}.{minor}") # First, check the TORCH_CUDA_ARCH_LIST environment variable. compute_capabilities = get_torch_arch_list() From f8c304ed45f136078eba4c5d1dd1f86328af639a Mon Sep 17 00:00:00 2001 From: miloice Date: Tue, 5 Dec 2023 07:41:34 +0000 Subject: [PATCH 52/94] Update readme --- docs/source/getting_started/amd-installation.rst | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/docs/source/getting_started/amd-installation.rst b/docs/source/getting_started/amd-installation.rst index 886724c6c0cfc..bc81a003d9fdd 100644 --- a/docs/source/getting_started/amd-installation.rst +++ b/docs/source/getting_started/amd-installation.rst @@ -3,8 +3,9 @@ Installation with ROCm ============ -vLLM 0.2.x onwards supports model inferencing and serving on AMD GPUs with ROCm. At the moment AWQ quantization is not supported, but SqueezeLLM has been incorporated. -The ROCm version is tested and verified on LLaMA-2 models. +vLLM 0.2.x onwards supports model inferencing and serving on AMD GPUs with ROCm. +At the moment AWQ quantization is not supported in ROCm, but SqueezeLLM quantization has been ported. +Datatypes currently supported in ROCm are FP16 and BF16. Requirements ------------ From e608c30d74ff325788ec78df03b0e18510a2c3e7 Mon Sep 17 00:00:00 2001 From: miloice Date: Tue, 5 Dec 2023 09:15:54 +0000 Subject: [PATCH 53/94] Cleaned up kernel code --- csrc/attention/dtype_float16.cuh | 54 ++++++++++---------------------- csrc/reduction_utils.cuh | 2 ++ 2 files changed, 19 insertions(+), 37 deletions(-) diff --git a/csrc/attention/dtype_float16.cuh b/csrc/attention/dtype_float16.cuh index 1ec2009c74c87..421419a96f137 100644 --- a/csrc/attention/dtype_float16.cuh +++ b/csrc/attention/dtype_float16.cuh @@ -83,15 +83,13 @@ inline __device__ uint32_t h0_h0(uint16_t a) { } inline __device__ float half_to_float(uint16_t h) { -#ifndef USE_ROCM float f; +#ifndef USE_ROCM asm volatile("cvt.f32.f16 %0, %1;\n" : "=f"(f) : "h"(h)); - return f; #else - float f; asm volatile("v_cvt_f32_f16 %0, %1;" : "=v"(f) : "v"(h)); - return f; #endif + return f; } inline __device__ float2 half2_to_float2(uint32_t v) { @@ -113,68 +111,56 @@ inline __device__ float2 half2_to_float2(uint32_t v) { } inline __device__ uint16_t float_to_half(float f) { -#ifndef USE_ROCM union { uint32_t u32; uint16_t u16[2]; } tmp; +#ifndef USE_ROCM asm volatile("cvt.rn.f16.f32 %0, %1;\n" : "=h"(tmp.u16[0]) : "f"(f)); - return tmp.u16[0]; #else - uint16_t ret; - asm volatile("v_cvt_f16_f32 %0, %1;\n" : "=v"(ret) : "v"(f)); - return ret; + asm volatile("v_cvt_f16_f32 %0, %1;\n" : "=v"(tmp.u32) : "v"(f)); #endif + return tmp.u16[0]; } inline __device__ uint32_t float2_to_half2(float2 f) { -#ifndef USE_ROCM union { uint32_t u32; uint16_t u16[2]; } tmp; - +#ifndef USE_ROCM #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 asm volatile("cvt.rn.f16x2.f32 %0, %1, %2;\n" : "=r"(tmp.u32) : "f"(f.y), "f"(f.x)); #else asm volatile("cvt.rn.f16.f32 %0, %1;\n" : "=h"(tmp.u16[0]) : "f"(f.x)); asm volatile("cvt.rn.f16.f32 %0, %1;\n" : "=h"(tmp.u16[1]) : "f"(f.y)); #endif - return tmp.u32; #else - union { - uint32_t u32; - uint16_t u16[2]; - } tmp; tmp.u16[0] = float_to_half(f.x); tmp.u16[1] = float_to_half(f.y); - return tmp.u32; #endif + return tmp.u32; } // Vector addition. inline __device__ uint16_t add(uint16_t a, uint16_t b) { -#ifndef USE_ROCM uint16_t c; +#ifndef USE_ROCM asm volatile("add.f16 %0, %1, %2;\n" : "=h"(c) : "h"(a), "h"(b)); - return c; #else - uint16_t c; asm volatile("v_add_f16 %0, %1, %2;\n" : "=v"(c) : "v"(a), "v"(b)); - return c; #endif + return c; } inline __device__ uint32_t add(uint32_t a, uint32_t b) { -#ifndef USE_ROCM uint32_t c; +#ifndef USE_ROCM asm volatile("add.f16x2 %0, %1, %2;\n" : "=r"(c) : "r"(a), "r"(b)); - return c; #else - uint32_t c; asm volatile("v_pk_add_f16 %0, %1, %2;\n" : "=v"(c) : "v"(a), "v"(b)); - return c; #endif + return c; } inline __device__ uint2 add(uint2 a, uint2 b) { @@ -217,28 +203,24 @@ inline __device__ Float8_ add(uint4 a, Float8_ fb) { // Vector multiplication. template<> inline __device__ uint16_t mul(uint16_t a, uint16_t b) { -#ifndef USE_ROCM uint16_t c; +#ifndef USE_ROCM asm volatile("mul.f16 %0, %1, %2;\n" : "=h"(c) : "h"(a), "h"(b)); - return c; #else - uint16_t c; asm volatile("v_mul_f16 %0, %1, %2;\n" : "=v"(c) : "v"(a), "v"(b)); - return c; #endif + return c; } template<> inline __device__ uint32_t mul(uint32_t a, uint32_t b) { -#ifndef USE_ROCM uint32_t c; +#ifndef USE_ROCM asm volatile("mul.f16x2 %0, %1, %2;\n" : "=r"(c) : "r"(a), "r"(b)); - return c; #else - uint32_t c; asm volatile("v_pk_mul_f16 %0, %1, %2;\n" : "=v"(c) : "v"(a), "v"(b)); - return c; #endif + return c; } template<> @@ -343,15 +325,13 @@ inline __device__ Float8_ mul(uint16_t a, uint4 b) { // Vector fused multiply-add. inline __device__ uint32_t fma(uint32_t a, uint32_t b, uint32_t c) { -#ifndef USE_ROCM uint32_t d; +#ifndef USE_ROCM asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(d) : "r"(a), "r"(b), "r"(c)); - return d; #else - uint32_t d; asm volatile("v_pk_fma_f16 %0, %1, %2, %3;\n" : "=v"(d) : "v"(a), "v"(b), "v"(c)); - return d; #endif + return d; } inline __device__ uint32_t fma(uint16_t a, uint32_t b, uint32_t c) { diff --git a/csrc/reduction_utils.cuh b/csrc/reduction_utils.cuh index 59ac565c585e8..b95ccef162072 100644 --- a/csrc/reduction_utils.cuh +++ b/csrc/reduction_utils.cuh @@ -16,7 +16,9 @@ * limitations under the License. */ #pragma once + #include "cuda_compat.h" + namespace vllm { template From cf633a7afda862d418578bba349fff6899228e03 Mon Sep 17 00:00:00 2001 From: Antoni Baum Date: Tue, 5 Dec 2023 11:48:11 -0800 Subject: [PATCH 54/94] Raise error on unsupported model --- vllm/model_executor/model_loader.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/vllm/model_executor/model_loader.py b/vllm/model_executor/model_loader.py index 0cd890615c918..85cfa9228c136 100644 --- a/vllm/model_executor/model_loader.py +++ b/vllm/model_executor/model_loader.py @@ -92,6 +92,12 @@ def get_model(model_config: ModelConfig, if getattr(model_class, "supports_lora", True): model = model_class(model_config.hf_config, linear_method, lora_config) + elif lora_config: + raise ValueError( + f"Model {model_class.__name__} does not support LoRA, " + "but LoRA is enabled. Support for this model may " + "be added in the future. If this is important to you, " + "please open an issue on github.") else: model = model_class(model_config.hf_config, linear_method) if model_config.load_format == "dummy": From 65d154282c44bcecda8919c800afc40e8d9ab25e Mon Sep 17 00:00:00 2001 From: Antoni Baum Date: Tue, 5 Dec 2023 12:20:41 -0800 Subject: [PATCH 55/94] Support more vocab sizes --- csrc/punica/bgmv/bgmv_config.h | 4 +++ tests/lora/test_punica.py | 27 ++------------ vllm/config.py | 9 ++++- vllm/lora/layers.py | 36 +++++++++++++------ vllm/lora/lora.py | 5 +++ vllm/lora/models.py | 5 +++ vllm/lora/worker_manager.py | 5 +++ .../layers/vocab_parallel_embedding.py | 18 +++++++--- vllm/model_executor/models/llama.py | 12 +++++-- vllm/model_executor/models/mistral.py | 12 +++++-- 10 files changed, 86 insertions(+), 47 deletions(-) diff --git a/csrc/punica/bgmv/bgmv_config.h b/csrc/punica/bgmv/bgmv_config.h index da6e6a611ecaa..ce2a2112a3d91 100644 --- a/csrc/punica/bgmv/bgmv_config.h +++ b/csrc/punica/bgmv/bgmv_config.h @@ -41,8 +41,12 @@ void bgmv_kernel(out_T *__restrict__ Y, const in_T *__restrict__ X, f(in_T, out_T, W_T, narrow, 28672) \ f(in_T, out_T, W_T, narrow, 32000) \ f(in_T, out_T, W_T, narrow, 32256) \ + f(in_T, out_T, W_T, narrow, 32512) \ + f(in_T, out_T, W_T, narrow, 32768) \ + f(in_T, out_T, W_T, narrow, 33024) \ f(in_T, out_T, W_T, narrow, 36864) \ f(in_T, out_T, W_T, narrow, 49152) \ +// Keep above in sync with vllm/lora/layers::LoRASampler // Keep this in sync with vllm/config::LoRAConfig #define FOR_BGMV_WIDE_NARROW(f, in_T, out_T, W_T) \ diff --git a/tests/lora/test_punica.py b/tests/lora/test_punica.py index 26a7d47933309..f603b06cdb565 100644 --- a/tests/lora/test_punica.py +++ b/tests/lora/test_punica.py @@ -43,30 +43,9 @@ def _lora_ref_impl( H1 = H2 = [ - 128, - 256, - 512, - 1024, - 1280, - 2048, - 2560, - 2752, - 3072, - 3456, - 3584, - 4096, - 5120, - 5504, - 6912, - 7168, - 8192, - 9216, - 10240, - 11008, - 13824, - 14336, - 32000, - 32256, + 128, 256, 512, 1024, 1280, 2048, 2560, 2752, 3072, 3456, 3584, 4096, 5120, + 5504, 6912, 7168, 8192, 9216, 10240, 11008, 13824, 14336, 32000, 32256, + 32512, 32768, 33024 ] SEED = [0xabcdabcd987] diff --git a/vllm/config.py b/vllm/config.py index b0a69edc900b1..7f4add75f96a2 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -1,4 +1,4 @@ -from typing import Optional, Union +from typing import Optional, Union, ClassVar from dataclasses import dataclass import os @@ -364,14 +364,21 @@ class LoRAConfig: max_cpu_loras: Optional[int] = None lora_dtype: Optional[torch.dtype] = None lora_extra_vocab_size: int = 256 + # This is a constant. + lora_vocab_padding_size: ClassVar[int] = 256 def __post_init__(self): # Keep this in sync with csrc/punica/bgmv/bgmv_config.h possible_max_ranks = (8, 16, 32, 64) + possible_lora_extra_vocab_size = (0, 256, 512) if self.max_lora_rank not in possible_max_ranks: raise ValueError( f"max_lora_rank ({self.max_lora_rank}) must be one of " f"{possible_max_ranks}.") + if self.lora_extra_vocab_size not in possible_lora_extra_vocab_size: + raise ValueError( + f"lora_extra_vocab_size ({self.lora_extra_vocab_size}) " + f"must be one of {possible_lora_extra_vocab_size}.") if self.max_loras < 1: raise ValueError(f"max_loras ({self.max_loras}) must be >= 1.") if self.max_cpu_loras is None: diff --git a/vllm/lora/layers.py b/vllm/lora/layers.py index 6ba8b0585847d..f03ae78ee8751 100644 --- a/vllm/lora/layers.py +++ b/vllm/lora/layers.py @@ -1,4 +1,5 @@ # pylint: disable=unused-argument +import math from dataclasses import dataclass from typing import TYPE_CHECKING, List, Optional, Tuple @@ -283,12 +284,12 @@ def set_lora( if self.embeddings_slice is not None: # TODO(yard1): Optimize this copy, we don't need to copy # everything, just the modified part - self.embeddings_weights.copy_( - self.embeddings_tensors.view( - self.embeddings_tensors.shape[0] * - self.embeddings_tensors.shape[1], - self.embeddings_tensors.shape[2]) - [self.embeddings_slice[0]:self.embeddings_slice[1]]) + embeddings = self.embeddings_tensors.view( + self.embeddings_tensors.shape[0] * + self.embeddings_tensors.shape[1], + self.embeddings_tensors.shape[2] + )[self.embeddings_slice[0]:self.embeddings_slice[1]] + self.embeddings_weights[:embeddings.shape[0]].copy_(embeddings) def set_mapping( self, @@ -856,6 +857,11 @@ def create_lora_weights( lora_config: LoRAConfig, model_config: Optional[PretrainedConfig] = None, ) -> None: + # Keep this in sync with csrc/punica/bgmv/bgmv_config.h + if 32000 < self.base_layer.vocab_size > 33024: + raise ValueError( + "When using LoRA, vocab size must be 32000 >= vocab_size <= 33024" + ) self.lora_a_stacked = torch.zeros( ( max_loras, @@ -870,7 +876,10 @@ def create_lora_weights( ( max_loras, 1, - self.base_layer.vocab_size, + # Pad for kernel compatibility + math.ceil(self.base_layer.vocab_size / + lora_config.lora_vocab_padding_size) * + lora_config.lora_vocab_padding_size, lora_config.max_lora_rank, ), dtype=lora_config.lora_dtype, @@ -933,8 +942,6 @@ def _get_logits( if embedding_bias is not None: logits += embedding_bias logits = tensor_model_parallel_all_gather(logits) - # Remove paddings in vocab (if any). - logits = logits[:, :self.base_layer.vocab_size] lora_logits = torch.empty( self.embeddings_tensors.shape[0] + 1, @@ -948,8 +955,7 @@ def _get_logits( out=lora_logits[:-1]) lora_logits[-1] = float("-inf") lora_logits = lora_logits.mT - - logits[:, self.base_layer.org_vocab_size:] = (lora_logits.reshape( + lora_logits = (lora_logits.reshape( lora_logits.shape[0] * lora_logits.shape[1], lora_logits.shape[2], ).index_select(0, @@ -957,6 +963,10 @@ def _get_logits( nan=float("-inf"), posinf=float("inf"), neginf=float("-inf"))) + logits[:, + self.base_layer.org_vocab_size:self.base_layer.org_vocab_size + + lora_logits.shape[1]] = lora_logits + _apply_lora( hidden_states, self.lora_a_stacked, @@ -964,6 +974,10 @@ def _get_logits( self.indices[:self.indices_len[1]], logits, ) + + # Remove paddings in vocab (if any). + logits = logits[:, :self.base_layer.vocab_size] + return logits def forward(self, *args, **kwargs): diff --git a/vllm/lora/lora.py b/vllm/lora/lora.py index 042a98597ab26..11006c4e1a1c9 100644 --- a/vllm/lora/lora.py +++ b/vllm/lora/lora.py @@ -70,6 +70,11 @@ def output_dim(self) -> int: def is_packed(self) -> bool: return False + @property + def extra_vocab_size(self) -> int: + return self.embeddings_tensor.shape[ + 0] if self.embeddings_tensor is not None else 0 + class PackedLoRA(LoRA): """LoRA used for packed layers (eg. qkv_proj).""" diff --git a/vllm/lora/models.py b/vllm/lora/models.py index 653d3b924cc3a..042b899763f10 100644 --- a/vllm/lora/models.py +++ b/vllm/lora/models.py @@ -178,6 +178,11 @@ def __init__( self.rank = rank self.loras: Dict[str, LoRA] = loras + @property + def extra_vocab_size(self) -> int: + return max(lora.extra_vocab_size + for lora in self.loras.values()) if self.loras else 0 + def get_lora(self, module_name: str) -> Optional[LoRA]: """Get LoRA for a given module by name""" return self.loras.get(module_name, None) diff --git a/vllm/lora/worker_manager.py b/vllm/lora/worker_manager.py index 43c016c32193b..363b7770be178 100644 --- a/vllm/lora/worker_manager.py +++ b/vllm/lora/worker_manager.py @@ -184,6 +184,11 @@ def _load_lora(self, lora_request: LoRARequest) -> LoRAModel: raise ValueError( f"LoRA rank {lora.rank} is greater than max_lora_rank " f"{self.lora_config.max_lora_rank}.") + if lora.extra_vocab_size > self.lora_config.lora_extra_vocab_size: + raise ValueError( + f"LoRA added vocab size {lora.extra_vocab_size} is greater than " + f"lora_extra_vocab_size {self.lora_config.lora_extra_vocab_size}." + ) return lora def add_dummy_lora(self, lora_request: LoRARequest, rank: int) -> bool: diff --git a/vllm/model_executor/layers/vocab_parallel_embedding.py b/vllm/model_executor/layers/vocab_parallel_embedding.py index 9e4ac26e73d00..9c5fb890251ed 100644 --- a/vllm/model_executor/layers/vocab_parallel_embedding.py +++ b/vllm/model_executor/layers/vocab_parallel_embedding.py @@ -13,8 +13,11 @@ tensor_model_parallel_all_reduce) from vllm.model_executor.utils import set_weight_attrs +DEFAULT_VOCAB_PADDING_SIZE = 64 -def pad_vocab_size(vocab_size: int, pad_to: int = 64) -> int: + +def pad_vocab_size(vocab_size: int, + pad_to: int = DEFAULT_VOCAB_PADDING_SIZE) -> int: """Pad the vocab size to the given value.""" return ((vocab_size + pad_to - 1) // pad_to) * pad_to @@ -44,19 +47,22 @@ class VocabParallelEmbedding(torch.nn.Module): embedding_dim: size of hidden state. params_dtype: type of the parameters. org_num_embeddings: original vocabulary size (without LoRA). + padding_size: padding size for the vocabulary. """ def __init__(self, num_embeddings: int, embedding_dim: int, params_dtype: Optional[torch.dtype] = None, - org_num_embeddings: Optional[int] = None): + org_num_embeddings: Optional[int] = None, + padding_size: int = DEFAULT_VOCAB_PADDING_SIZE): super().__init__() # Keep the input dimensions. self.num_embeddings = num_embeddings self.org_vocab_size = org_num_embeddings or num_embeddings - self.num_embeddings_padded = pad_vocab_size(num_embeddings) + self.num_embeddings_padded = pad_vocab_size(num_embeddings, + padding_size) self.embedding_dim = embedding_dim if params_dtype is None: params_dtype = torch.get_default_dtype() @@ -118,6 +124,7 @@ class ParallelLMHead(VocabParallelEmbedding): bias: whether to use bias. params_dtype: type of the parameters. org_num_embeddings: original vocabulary size (without LoRA). + padding_size: padding size for the vocabulary. """ def __init__(self, @@ -125,9 +132,10 @@ def __init__(self, embedding_dim: int, bias: bool = False, params_dtype: Optional[torch.dtype] = None, - org_num_embeddings: Optional[int] = None): + org_num_embeddings: Optional[int] = None, + padding_size: int = DEFAULT_VOCAB_PADDING_SIZE): super().__init__(num_embeddings, embedding_dim, params_dtype, - org_num_embeddings) + org_num_embeddings, padding_size) if bias: self.bias = Parameter( torch.empty(self.num_embeddings_per_partition, diff --git a/vllm/model_executor/models/llama.py b/vllm/model_executor/models/llama.py index 342f940dbb92c..240f3ad57f655 100644 --- a/vllm/model_executor/models/llama.py +++ b/vllm/model_executor/models/llama.py @@ -291,9 +291,15 @@ def __init__( unpadded_vocab_size = config.vocab_size if lora_config: unpadded_vocab_size += lora_config.lora_extra_vocab_size - self.lm_head = ParallelLMHead(unpadded_vocab_size, - config.hidden_size, - org_num_embeddings=config.vocab_size) + self.lm_head = ParallelLMHead( + unpadded_vocab_size, + config.hidden_size, + org_num_embeddings=config.vocab_size, + # We need bigger padding if using lora for kernel + # compatibility + padding_size=64 + if not lora_config else lora_config.lora_vocab_padding_size, + ) self.sampler = Sampler(unpadded_vocab_size, config.vocab_size) def forward( diff --git a/vllm/model_executor/models/mistral.py b/vllm/model_executor/models/mistral.py index 3547a72dc8558..7e2ee4a721e67 100644 --- a/vllm/model_executor/models/mistral.py +++ b/vllm/model_executor/models/mistral.py @@ -289,9 +289,15 @@ def __init__( unpadded_vocab_size = config.vocab_size if lora_config: unpadded_vocab_size += lora_config.lora_extra_vocab_size - self.lm_head = ParallelLMHead(unpadded_vocab_size, - config.hidden_size, - org_num_embeddings=config.vocab_size) + self.lm_head = ParallelLMHead( + unpadded_vocab_size, + config.hidden_size, + org_num_embeddings=config.vocab_size, + # We need bigger padding if using lora for kernel + # compatibility + padding_size=64 + if not lora_config else lora_config.lora_vocab_padding_size, + ) self.sampler = Sampler(unpadded_vocab_size, config.vocab_size) def forward( From 951e225e5d206594e61fd43f2f021187b4dc07df Mon Sep 17 00:00:00 2001 From: miloice Date: Wed, 6 Dec 2023 07:38:24 +0000 Subject: [PATCH 56/94] Added wrapper for setting devFuncAttributeMaxDynamicSharedMemorySize --- csrc/attention/attention_kernels.cu | 70 ++++++++++++++++++----------- csrc/cuda_compat.h | 12 ++++- 2 files changed, 55 insertions(+), 27 deletions(-) diff --git a/csrc/attention/attention_kernels.cu b/csrc/attention/attention_kernels.cu index 19ad43b554aff..d301e8c92489b 100644 --- a/csrc/attention/attention_kernels.cu +++ b/csrc/attention/attention_kernels.cu @@ -542,31 +542,52 @@ __global__ void paged_attention_v2_reduce_kernel( } // namespace vllm -#ifndef USE_ROCM -#define LAUNCH_PAGED_ATTENTION_V1(HEAD_SIZE) \ - cudaFuncSetAttribute( \ - (void*)vllm::paged_attention_v1_kernel, \ - cudaFuncAttributeMaxDynamicSharedMemorySize, shared_mem_size); \ - vllm::paged_attention_v1_kernel \ - <<>>( \ - out_ptr, \ - query_ptr, \ - key_cache_ptr, \ - value_cache_ptr, \ - head_mapping_ptr, \ - scale, \ - block_tables_ptr, \ - context_lens_ptr, \ - max_num_blocks_per_seq, \ - alibi_slopes_ptr, \ - q_stride, \ - kv_block_stride, \ - kv_head_stride); -#else +// #ifndef USE_ROCM +// #define LAUNCH_PAGED_ATTENTION_V1(HEAD_SIZE) \ +// cudaFuncSetAttribute( \ +// (void*)vllm::paged_attention_v1_kernel, \ +// cudaFuncAttributeMaxDynamicSharedMemorySize, shared_mem_size); \ +// vllm::paged_attention_v1_kernel \ +// <<>>( \ +// out_ptr, \ +// query_ptr, \ +// key_cache_ptr, \ +// value_cache_ptr, \ +// head_mapping_ptr, \ +// scale, \ +// block_tables_ptr, \ +// context_lens_ptr, \ +// max_num_blocks_per_seq, \ +// alibi_slopes_ptr, \ +// q_stride, \ +// kv_block_stride, \ +// kv_head_stride); +// #else +// #define LAUNCH_PAGED_ATTENTION_V1(HEAD_SIZE) \ +// hipFuncSetAttribute( \ +// (void*)vllm::paged_attention_v1_kernel, \ +// hipFuncAttributeMaxDynamicSharedMemorySize, shared_mem_size); \ +// vllm::paged_attention_v1_kernel \ +// <<>>( \ +// out_ptr, \ +// query_ptr, \ +// key_cache_ptr, \ +// value_cache_ptr, \ +// head_mapping_ptr, \ +// scale, \ +// block_tables_ptr, \ +// context_lens_ptr, \ +// max_num_blocks_per_seq, \ +// alibi_slopes_ptr, \ +// q_stride, \ +// kv_block_stride, \ +// kv_head_stride); +// #endif + #define LAUNCH_PAGED_ATTENTION_V1(HEAD_SIZE) \ - hipFuncSetAttribute( \ - (void*)vllm::paged_attention_v1_kernel, \ - hipFuncAttributeMaxDynamicSharedMemorySize, shared_mem_size); \ + VLLM_DevFuncAttribute_SET_MaxDynamicSharedMemorySize( \ + ((void*)vllm::paged_attention_v1_kernel), \ + shared_mem_size); \ vllm::paged_attention_v1_kernel \ <<>>( \ out_ptr, \ @@ -582,7 +603,6 @@ __global__ void paged_attention_v2_reduce_kernel( q_stride, \ kv_block_stride, \ kv_head_stride); -#endif // TODO(woosuk): Tune NUM_THREADS. template< diff --git a/csrc/cuda_compat.h b/csrc/cuda_compat.h index fcafc7e16687c..ed3ebe7dc47c8 100644 --- a/csrc/cuda_compat.h +++ b/csrc/cuda_compat.h @@ -13,7 +13,15 @@ #endif #ifndef USE_ROCM - #define VLLM_SHFL_SYNC(var, src_lane) __shfl_sync(uint32_t(-1), var, src_lane); + #define VLLM_SHFL_SYNC(var, src_lane) __shfl_sync(uint32_t(-1), var, src_lane) #else #define VLLM_SHFL_SYNC(var, src_lane) __shfl(var, src_lane) -#endif \ No newline at end of file +#endif + +#ifndef USE_ROCM + #define VLLM_DevFuncAttribute_SET_MaxDynamicSharedMemorySize(FUNC, VAL) \ + cudaFuncSetAttribute(FUNC, cudaFuncAttributeMaxDynamicSharedMemorySize, VAL) +#else + #define VLLM_DevFuncAttribute_SET_MaxDynamicSharedMemorySize(FUNC, VAL) \ + hipFuncSetAttribute(FUNC, hipFuncAttributeMaxDynamicSharedMemorySize, VAL) +#endif From 25f9a970438dfa1fd0477f2ca1c3087e41b69d93 Mon Sep 17 00:00:00 2001 From: miloice Date: Wed, 6 Dec 2023 07:39:08 +0000 Subject: [PATCH 57/94] Added wrapper for setting devFuncAttributeMaxDynamicSharedMemorySize --- csrc/attention/attention_kernels.cu | 42 ----------------------------- 1 file changed, 42 deletions(-) diff --git a/csrc/attention/attention_kernels.cu b/csrc/attention/attention_kernels.cu index d301e8c92489b..6e744376b8978 100644 --- a/csrc/attention/attention_kernels.cu +++ b/csrc/attention/attention_kernels.cu @@ -542,48 +542,6 @@ __global__ void paged_attention_v2_reduce_kernel( } // namespace vllm -// #ifndef USE_ROCM -// #define LAUNCH_PAGED_ATTENTION_V1(HEAD_SIZE) \ -// cudaFuncSetAttribute( \ -// (void*)vllm::paged_attention_v1_kernel, \ -// cudaFuncAttributeMaxDynamicSharedMemorySize, shared_mem_size); \ -// vllm::paged_attention_v1_kernel \ -// <<>>( \ -// out_ptr, \ -// query_ptr, \ -// key_cache_ptr, \ -// value_cache_ptr, \ -// head_mapping_ptr, \ -// scale, \ -// block_tables_ptr, \ -// context_lens_ptr, \ -// max_num_blocks_per_seq, \ -// alibi_slopes_ptr, \ -// q_stride, \ -// kv_block_stride, \ -// kv_head_stride); -// #else -// #define LAUNCH_PAGED_ATTENTION_V1(HEAD_SIZE) \ -// hipFuncSetAttribute( \ -// (void*)vllm::paged_attention_v1_kernel, \ -// hipFuncAttributeMaxDynamicSharedMemorySize, shared_mem_size); \ -// vllm::paged_attention_v1_kernel \ -// <<>>( \ -// out_ptr, \ -// query_ptr, \ -// key_cache_ptr, \ -// value_cache_ptr, \ -// head_mapping_ptr, \ -// scale, \ -// block_tables_ptr, \ -// context_lens_ptr, \ -// max_num_blocks_per_seq, \ -// alibi_slopes_ptr, \ -// q_stride, \ -// kv_block_stride, \ -// kv_head_stride); -// #endif - #define LAUNCH_PAGED_ATTENTION_V1(HEAD_SIZE) \ VLLM_DevFuncAttribute_SET_MaxDynamicSharedMemorySize( \ ((void*)vllm::paged_attention_v1_kernel), \ From e984ada6b211c82008de183af1e5a73c80924937 Mon Sep 17 00:00:00 2001 From: miloice Date: Wed, 6 Dec 2023 08:52:48 +0000 Subject: [PATCH 58/94] Updated ROCm warp size --- csrc/attention/attention_kernels.cu | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/csrc/attention/attention_kernels.cu b/csrc/attention/attention_kernels.cu index 6e744376b8978..e6ffc9e38e832 100644 --- a/csrc/attention/attention_kernels.cu +++ b/csrc/attention/attention_kernels.cu @@ -15,6 +15,10 @@ * See the License for the specific language governing permissions and * limitations under the License. */ +#ifdef USE_ROCM +#include +#endif + #include #include @@ -26,7 +30,7 @@ #ifndef USE_ROCM #define WARP_SIZE 32 #else -#define WARP_SIZE 64 +#define WARP_SIZE warpSize #endif #define MAX(a, b) ((a) > (b) ? (a) : (b)) #define MIN(a, b) ((a) < (b) ? (a) : (b)) From cc1195fe7ee8391baf380c803d0942e5ef2bda68 Mon Sep 17 00:00:00 2001 From: miloice Date: Wed, 6 Dec 2023 10:24:14 +0000 Subject: [PATCH 59/94] Format code --- Dockerfile.rocm | 2 +- csrc/attention/dtype_float16.cuh | 2 +- csrc/cuda_compat.h | 1 + csrc/quantization/squeezellm/quant_cuda_kernel.cu | 2 +- setup.py | 2 -- 5 files changed, 4 insertions(+), 5 deletions(-) diff --git a/Dockerfile.rocm b/Dockerfile.rocm index 11a16aff5c4d3..cebfdf968d7ab 100644 --- a/Dockerfile.rocm +++ b/Dockerfile.rocm @@ -60,4 +60,4 @@ RUN cd /app \ RUN python3 -m pip install --upgrade pip RUN python3 -m pip install --no-cache-dir ray[all] -CMD ["/bin/bash"] \ No newline at end of file +CMD ["/bin/bash"] diff --git a/csrc/attention/dtype_float16.cuh b/csrc/attention/dtype_float16.cuh index 421419a96f137..d3271e69cd69d 100644 --- a/csrc/attention/dtype_float16.cuh +++ b/csrc/attention/dtype_float16.cuh @@ -499,4 +499,4 @@ inline __device__ void zero(uint16_t& dst) { dst = uint16_t(0); } -} // namespace vllm \ No newline at end of file +} // namespace vllm diff --git a/csrc/cuda_compat.h b/csrc/cuda_compat.h index ed3ebe7dc47c8..aa58dd73c148a 100644 --- a/csrc/cuda_compat.h +++ b/csrc/cuda_compat.h @@ -25,3 +25,4 @@ #define VLLM_DevFuncAttribute_SET_MaxDynamicSharedMemorySize(FUNC, VAL) \ hipFuncSetAttribute(FUNC, hipFuncAttributeMaxDynamicSharedMemorySize, VAL) #endif + diff --git a/csrc/quantization/squeezellm/quant_cuda_kernel.cu b/csrc/quantization/squeezellm/quant_cuda_kernel.cu index c2f77da15977a..2c37d01e0ae5c 100644 --- a/csrc/quantization/squeezellm/quant_cuda_kernel.cu +++ b/csrc/quantization/squeezellm/quant_cuda_kernel.cu @@ -219,4 +219,4 @@ void squeezellm_gemm( } #undef BLOCKWIDTH -#undef BLOCKHEIGHT4 \ No newline at end of file +#undef BLOCKHEIGHT4 diff --git a/setup.py b/setup.py index 3be4fd4fc3950..e3d07cf9521ed 100644 --- a/setup.py +++ b/setup.py @@ -277,8 +277,6 @@ def get_requirements() -> List[str]: else: with open(get_path("requirements.txt")) as f: requirements = f.read().strip().split("\n") - print("requirements: ", requirements) - # exit() return requirements From f92980e357d7fc0691f6ab54df885a2a86ee7ce9 Mon Sep 17 00:00:00 2001 From: miloice Date: Wed, 6 Dec 2023 11:17:42 +0000 Subject: [PATCH 60/94] Check hip from wrapper --- setup.py | 25 +++++++++++++------ vllm/engine/arg_utils.py | 9 ++++--- vllm/engine/ray_utils.py | 3 ++- vllm/model_executor/layers/attention.py | 3 ++- .../layers/quantization/__init__.py | 3 ++- .../model_executor/layers/quantization/awq.py | 5 ++-- .../layers/quantization/squeezellm.py | 10 +++----- vllm/model_executor/model_loader.py | 3 ++- vllm/utils.py | 4 +++ 9 files changed, 41 insertions(+), 24 deletions(-) diff --git a/setup.py b/setup.py index e3d07cf9521ed..3fc19a912e630 100644 --- a/setup.py +++ b/setup.py @@ -19,15 +19,24 @@ ROCM_SUPPORTED_ARCHS = {"gfx90a", "gfx908", "gfx906", "gfx1030", "gfx1100"} # SUPPORTED_ARCHS = NVIDIA_SUPPORTED_ARCHS.union(ROCM_SUPPORTED_ARCHS) + +def _is_hip(): + return torch.version.hip + + +def _is_cuda(): + return torch.version.cuda + + # Compiler flags. CXX_FLAGS = ["-g", "-O2", "-std=c++17"] # TODO(woosuk): Should we use -O3? NVCC_FLAGS = ["-O2", "-std=c++17"] -if torch.version.hip and ROCM_HOME is not None: +if _is_hip() and ROCM_HOME is not None: NVCC_FLAGS += ["-DUSE_ROCM"] -if torch.version.cuda and CUDA_HOME is None: +if _is_cuda() and CUDA_HOME is None: raise RuntimeError( "Cannot find CUDA_HOME. CUDA must be available to build the package.") @@ -129,7 +138,7 @@ def get_torch_arch_list() -> Set[str]: # First, check the TORCH_CUDA_ARCH_LIST environment variable. compute_capabilities = get_torch_arch_list() -if torch.version.cuda and not compute_capabilities: +if _is_cuda() and not compute_capabilities: # If TORCH_CUDA_ARCH_LIST is not defined or empty, target all available # GPUs on the current machine. device_count = torch.cuda.device_count() @@ -140,7 +149,7 @@ def get_torch_arch_list() -> Set[str]: "GPUs with compute capability below 7.0 are not supported.") compute_capabilities.add(f"{major}.{minor}") -if torch.version.cuda: +if _is_cuda(): nvcc_cuda_version = get_nvcc_cuda_version(CUDA_HOME) if not compute_capabilities: # If no GPU is specified nor available, add all supported architectures @@ -191,7 +200,7 @@ def get_torch_arch_list() -> Set[str]: num_threads = min(os.cpu_count(), 8) NVCC_FLAGS += ["--threads", str(num_threads)] -elif torch.version.hip: +elif _is_hip(): amd_arch = get_amdgpu_offload_arch() if amd_arch not in ROCM_SUPPORTED_ARCHS: raise RuntimeError( @@ -211,7 +220,7 @@ def get_torch_arch_list() -> Set[str]: "csrc/pybind.cpp", ] -if torch.version.cuda: +if _is_cuda(): vllm_extension_sources.append("csrc/quantization/awq/gemm_kernels.cu") vllm_extension = CUDAExtension( @@ -245,7 +254,7 @@ def find_version(filepath: str) -> str: def get_vllm_version() -> str: version = find_version(get_path("vllm", "__init__.py")) - if torch.version.hip: + if _is_hip(): # Get the HIP version hipcc_version = get_hipcc_rocm_version() if hipcc_version != MAIN_CUDA_VERSION: @@ -271,7 +280,7 @@ def read_readme() -> str: def get_requirements() -> List[str]: """Get Python package dependencies from requirements.txt.""" - if torch.version.hip: + if _is_hip(): with open(get_path("requirements-rocm.txt")) as f: requirements = f.read().strip().split("\n") else: diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index 57e3398a3acb0..41577aa8304f3 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -7,6 +7,7 @@ from vllm.config import (CacheConfig, ModelConfig, ParallelConfig, SchedulerConfig) +from vllm.utils import is_hip @dataclass @@ -89,7 +90,7 @@ def add_cli_args( help='directory to download and load the weights, ' 'default to the default cache dir of ' 'huggingface') - if torch.cuda.is_available() and torch.version.hip: + if is_hip(): # do something specific for HIP parser.add_argument( '--load-format', @@ -106,7 +107,7 @@ def add_cli_args( help='data type for model weights and activations. ' 'The default option is FP16 precision ' 'Supports FP16 and BF16 ') - elif torch.cuda.is_available() and torch.version.cuda: + else: # do something specific for CUDA parser.add_argument( '--load-format', @@ -197,7 +198,7 @@ def add_cli_args( parser.add_argument('--disable-log-stats', action='store_true', help='disable logging statistics') - if torch.cuda.is_available() and torch.version.hip: + if is_hip(): # Quantization settings. parser.add_argument('--quantization', '-q', @@ -206,7 +207,7 @@ def add_cli_args( default=None, help='Method used to quantize the weights') - elif torch.cuda.is_available() and torch.version.cuda: + else: # Quantization settings. parser.add_argument('--quantization', '-q', diff --git a/vllm/engine/ray_utils.py b/vllm/engine/ray_utils.py index e67cc33774e71..18ba04eae733b 100644 --- a/vllm/engine/ray_utils.py +++ b/vllm/engine/ray_utils.py @@ -3,6 +3,7 @@ from vllm.config import ParallelConfig from vllm.logger import init_logger +from vllm.utils import is_hip import torch @@ -75,7 +76,7 @@ def initialize_cluster( "Ray is not installed. Please install Ray to use distributed " "serving.") # Connect to a ray cluster. - if torch.version.hip: + if is_hip(): ray.init(address=ray_address, ignore_reinit_error=True, num_gpus=parallel_config.world_size) diff --git a/vllm/model_executor/layers/attention.py b/vllm/model_executor/layers/attention.py index 7ce763313937b..3f4ecb5d2ae7b 100644 --- a/vllm/model_executor/layers/attention.py +++ b/vllm/model_executor/layers/attention.py @@ -10,6 +10,7 @@ from vllm._C import ops from vllm._C import cache_ops from vllm.model_executor.input_metadata import InputMetadata +from vllm.utils import is_hip _SUPPORTED_HEAD_SIZES = [64, 80, 96, 112, 128, 256] # Should be the same as PARTITION_SIZE in `paged_attention_v2_launcher`. @@ -161,7 +162,7 @@ def forward( p=0.0, scale=self.scale, op=xops.fmha.MemoryEfficientAttentionFlashAttentionOp[0] if - (torch.cuda.is_available() and torch.version.hip) else None, + (is_hip()) else None, ) output = out.view_as(query) else: diff --git a/vllm/model_executor/layers/quantization/__init__.py b/vllm/model_executor/layers/quantization/__init__.py index f4d25566cf599..61357e578ddf5 100644 --- a/vllm/model_executor/layers/quantization/__init__.py +++ b/vllm/model_executor/layers/quantization/__init__.py @@ -2,12 +2,13 @@ import torch from vllm.model_executor.layers.quantization.squeezellm import SqueezeLLMConfig from vllm.model_executor.layers.quantization.base_config import QuantizationConfig +from vllm.utils import is_hip _QUANTIZATION_CONFIG_REGISTRY = { "squeezellm": SqueezeLLMConfig, } -if torch.cuda.is_available() and torch.version.cuda: +if not is_hip(): from vllm.model_executor.layers.quantization.awq import AWQConfig _QUANTIZATION_CONFIG_REGISTRY["awq"] = AWQConfig diff --git a/vllm/model_executor/layers/quantization/awq.py b/vllm/model_executor/layers/quantization/awq.py index bcc0c092a6ac2..eb6bba0b2e91c 100644 --- a/vllm/model_executor/layers/quantization/awq.py +++ b/vllm/model_executor/layers/quantization/awq.py @@ -2,10 +2,11 @@ import torch from torch.nn.parameter import Parameter -if torch.cuda.is_available() and torch.version.hip: +from vllm.utils import is_hip +if is_hip(): # do something specific for HIP print("Warning: vLLM does not support AWQ on ROCm.") -elif torch.cuda.is_available() and torch.version.cuda: +else: from vllm._C import ops from vllm.model_executor.layers.linear import (LinearMethodBase, diff --git a/vllm/model_executor/layers/quantization/squeezellm.py b/vllm/model_executor/layers/quantization/squeezellm.py index b38615cbdbb83..7855a2abacae5 100644 --- a/vllm/model_executor/layers/quantization/squeezellm.py +++ b/vllm/model_executor/layers/quantization/squeezellm.py @@ -7,6 +7,7 @@ from vllm.model_executor.layers.linear import (LinearMethodBase, set_weight_attrs) from vllm.model_executor.layers.quantization.base_config import QuantizationConfig +from vllm.utils import is_hip class SqueezeLLMConfig(QuantizationConfig): @@ -114,14 +115,11 @@ def apply_weights(self, lookup_table = weights["lookup_table"] out_shape = x.shape[:-1] + (qweight.shape[-1], ) reshaped_x = x.reshape(-1, x.shape[-1]) - if torch.cuda.is_available() and torch.version.hip: - out_float = torch.zeros(out_shape, - device="cuda", - dtype=torch.float) + if is_hip(): + out_float = torch.zeros(out_shape, device="cuda", dtype=torch.float) ops.squeezellm_gemm(reshaped_x, qweight, out_float, lookup_table) out = out_float.to(dtype=torch.float16) - # do something specific for HIP - elif torch.cuda.is_available() and torch.version.cuda: + else: # NOTE: The output tensor should be zero-initialized. out = torch.zeros(out_shape, device="cuda", dtype=torch.float16) ops.squeezellm_gemm(reshaped_x, qweight, out, lookup_table) diff --git a/vllm/model_executor/model_loader.py b/vllm/model_executor/model_loader.py index 4132adf89a0bc..d1a81b0898cae 100644 --- a/vllm/model_executor/model_loader.py +++ b/vllm/model_executor/model_loader.py @@ -10,6 +10,7 @@ from vllm.model_executor.models import * from vllm.model_executor.weight_utils import (get_quant_config, initialize_dummy_weights) +from vllm.utils import is_hip # TODO(woosuk): Lazy-load the model classes. _MODEL_REGISTRY = { @@ -44,7 +45,7 @@ # in models such as Mistral "MistralForCausalLM", ] -if torch.version.hip: +if is_hip(): for rocm_model in _ROCM_DISABLED_MODELS: del _MODEL_REGISTRY[rocm_model] diff --git a/vllm/utils.py b/vllm/utils.py index 47e51048fed45..0bda28c046199 100644 --- a/vllm/utils.py +++ b/vllm/utils.py @@ -53,3 +53,7 @@ def random_uuid() -> str: def in_wsl() -> bool: # Reference: https://github.com/microsoft/WSL/issues/4071 return "microsoft" in " ".join(uname()).lower() + + +def is_hip(): + return torch.version.hip From 66b4aa10db033aa84c601b0dc8eed3aa028da6d0 Mon Sep 17 00:00:00 2001 From: miloice Date: Wed, 6 Dec 2023 14:31:12 +0000 Subject: [PATCH 61/94] Format code --- vllm/config.py | 39 +++++++- vllm/engine/arg_utils.py | 99 +++++++------------ vllm/engine/ray_utils.py | 2 - .../layers/quantization/__init__.py | 1 - .../layers/quantization/squeezellm.py | 4 +- 5 files changed, 71 insertions(+), 74 deletions(-) diff --git a/vllm/config.py b/vllm/config.py index cd92d361d33c5..a2739e5f7d438 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -6,7 +6,7 @@ from vllm.logger import init_logger from vllm.transformers_utils.config import get_config -from vllm.utils import get_cpu_memory +from vllm.utils import get_cpu_memory, is_hip logger = init_logger(__name__) @@ -98,12 +98,27 @@ def __init__( def _verify_load_format(self) -> None: load_format = self.load_format.lower() - if load_format not in [ - "auto", "pt", "safetensors", "npcache", "dummy" - ]: + supported_load_format = [ + "auto", "pt", "safetensors", "npcache", "dummy" + ] + rocm_not_supported_load_format = ["safetensors"] + if load_format not in supported_load_format: raise ValueError( f"Unknown load format: {self.load_format}. Must be one of " "'auto', 'pt', 'safetensors', 'npcache', or 'dummy'.") + if is_hip(): + if load_format in ["safetensors"]: + rocm_supported_load_format = [ + f for f in supported_load_format + if (f not in rocm_not_supported_load_format) + ] + raise ValueError( + f"load format \'{load_format}\' is not supported in ROCm. " + f"Supported load format are " + f"{rocm_supported_load_format}") + # Force ROCm to load from pt weights if nothing specific is set + if load_format == "auto": + load_format = "pt" self.load_format = load_format def _verify_tokenizer_mode(self) -> None: @@ -116,6 +131,7 @@ def _verify_tokenizer_mode(self) -> None: def _verify_quantization(self) -> None: supported_quantization = ["awq", "squeezellm"] + rocm_not_supported_quantization = ["awq"] if self.quantization is not None: self.quantization = self.quantization.lower() @@ -137,6 +153,11 @@ def _verify_quantization(self) -> None: raise ValueError( f"Unknown quantization method: {self.quantization}. Must " f"be one of {supported_quantization}.") + if is_hip( + ) and self.quantization in rocm_not_supported_quantization: + raise ValueError( + f"{self.quantization} quantization is currently not supported " + f"in ROCm.") logger.warning(f"{self.quantization} quantization is not fully " "optimized yet. The speed can be slower than " "non-quantized models.") @@ -364,6 +385,8 @@ def _verify_args(self) -> None: "bfloat16": torch.bfloat16, } +_ROCM_NOT_SUPPORTED_DTYPE = ["float", "float32"] + def _get_and_verify_dtype( config: PretrainedConfig, @@ -393,6 +416,14 @@ def _get_and_verify_dtype( else: raise ValueError(f"Unknown dtype: {dtype}") + if is_hip() and torch_dtype == torch.float32: + rocm_supported_dtypes = [ + k for k, v in _STR_DTYPE_TO_TORCH_DTYPE.items() + if (k not in _ROCM_NOT_SUPPORTED_DTYPE) + ] + raise ValueError(f"dtype \'{dtype}\' is not supported in ROCm. " + f"Supported dtypes are {rocm_supported_dtypes}") + # Verify the dtype. if torch_dtype != config_dtype: if torch_dtype == torch.float32: diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index 41577aa8304f3..8dec696e7fb6c 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -3,11 +3,8 @@ from dataclasses import dataclass from typing import Optional, Tuple -import torch - from vllm.config import (CacheConfig, ModelConfig, ParallelConfig, SchedulerConfig) -from vllm.utils import is_hip @dataclass @@ -90,52 +87,32 @@ def add_cli_args( help='directory to download and load the weights, ' 'default to the default cache dir of ' 'huggingface') - if is_hip(): - # do something specific for HIP - parser.add_argument( - '--load-format', - type=str, - default='pt', - choices=['pt'], - help='The format of the model weights to load. ' - '"pt" will load the weights in the pytorch bin format. ') - parser.add_argument( - '--dtype', - type=str, - default='half', - choices=['half', 'float16', 'bfloat16'], - help='data type for model weights and activations. ' - 'The default option is FP16 precision ' - 'Supports FP16 and BF16 ') - else: - # do something specific for CUDA - parser.add_argument( - '--load-format', - type=str, - default=EngineArgs.load_format, - choices=['auto', 'pt', 'safetensors', 'npcache', 'dummy'], - help='The format of the model weights to load. ' - '"auto" will try to load the weights in the safetensors format ' - 'and fall back to the pytorch bin format if safetensors format ' - 'is not available. ' - '"pt" will load the weights in the pytorch bin format. ' - '"safetensors" will load the weights in the safetensors format. ' - '"npcache" will load the weights in pytorch format and store ' - 'a numpy cache to speed up the loading. ' - '"dummy" will initialize the weights with random values, ' - 'which is mainly for profiling.') - parser.add_argument( - '--dtype', - type=str, - default=EngineArgs.dtype, - choices=[ - 'auto', 'half', 'float16', 'bfloat16', 'float', 'float32' - ], - help='data type for model weights and activations. ' - 'The "auto" option will use FP16 precision ' - 'for FP32 and FP16 models, and BF16 precision ' - 'for BF16 models.') - + parser.add_argument( + '--load-format', + type=str, + default=EngineArgs.load_format, + choices=['auto', 'pt', 'safetensors', 'npcache', 'dummy'], + help='The format of the model weights to load. ' + '"auto" will try to load the weights in the safetensors format ' + 'and fall back to the pytorch bin format if safetensors format ' + 'is not available. ' + '"pt" will load the weights in the pytorch bin format. ' + '"safetensors" will load the weights in the safetensors format. ' + '"npcache" will load the weights in pytorch format and store ' + 'a numpy cache to speed up the loading. ' + '"dummy" will initialize the weights with random values, ' + 'which is mainly for profiling.') + parser.add_argument( + '--dtype', + type=str, + default=EngineArgs.dtype, + choices=[ + 'auto', 'half', 'float16', 'bfloat16', 'float', 'float32' + ], + help='data type for model weights and activations. ' + 'The "auto" option will use FP16 precision ' + 'for FP32 and FP16 models, and BF16 precision ' + 'for BF16 models.') parser.add_argument('--max-model-len', type=int, default=None, @@ -198,23 +175,13 @@ def add_cli_args( parser.add_argument('--disable-log-stats', action='store_true', help='disable logging statistics') - if is_hip(): - # Quantization settings. - parser.add_argument('--quantization', - '-q', - type=str, - choices=['squeezellm', None], - default=None, - help='Method used to quantize the weights') - - else: - # Quantization settings. - parser.add_argument('--quantization', - '-q', - type=str, - choices=['awq', 'squeezellm', None], - default=None, - help='Method used to quantize the weights') + # Quantization settings. + parser.add_argument('--quantization', + '-q', + type=str, + choices=['awq', 'squeezellm', None], + default=None, + help='Method used to quantize the weights') return parser @classmethod diff --git a/vllm/engine/ray_utils.py b/vllm/engine/ray_utils.py index 18ba04eae733b..e101940f10369 100644 --- a/vllm/engine/ray_utils.py +++ b/vllm/engine/ray_utils.py @@ -5,8 +5,6 @@ from vllm.logger import init_logger from vllm.utils import is_hip -import torch - logger = init_logger(__name__) try: diff --git a/vllm/model_executor/layers/quantization/__init__.py b/vllm/model_executor/layers/quantization/__init__.py index 61357e578ddf5..a23e7ec3fdcbc 100644 --- a/vllm/model_executor/layers/quantization/__init__.py +++ b/vllm/model_executor/layers/quantization/__init__.py @@ -1,5 +1,4 @@ from typing import Type -import torch from vllm.model_executor.layers.quantization.squeezellm import SqueezeLLMConfig from vllm.model_executor.layers.quantization.base_config import QuantizationConfig from vllm.utils import is_hip diff --git a/vllm/model_executor/layers/quantization/squeezellm.py b/vllm/model_executor/layers/quantization/squeezellm.py index 7855a2abacae5..d41fea3807e0d 100644 --- a/vllm/model_executor/layers/quantization/squeezellm.py +++ b/vllm/model_executor/layers/quantization/squeezellm.py @@ -116,7 +116,9 @@ def apply_weights(self, out_shape = x.shape[:-1] + (qweight.shape[-1], ) reshaped_x = x.reshape(-1, x.shape[-1]) if is_hip(): - out_float = torch.zeros(out_shape, device="cuda", dtype=torch.float) + out_float = torch.zeros(out_shape, + device="cuda", + dtype=torch.float) ops.squeezellm_gemm(reshaped_x, qweight, out_float, lookup_table) out = out_float.to(dtype=torch.float16) else: From 4a0ecb85a9c16b54d1cf96942244d104e4a4ef22 Mon Sep 17 00:00:00 2001 From: miloice Date: Wed, 6 Dec 2023 14:54:03 +0000 Subject: [PATCH 62/94] Enable support for mistral models --- vllm/model_executor/model_loader.py | 23 ++++++++++++++++------- 1 file changed, 16 insertions(+), 7 deletions(-) diff --git a/vllm/model_executor/model_loader.py b/vllm/model_executor/model_loader.py index d1a81b0898cae..7353b5da6cf42 100644 --- a/vllm/model_executor/model_loader.py +++ b/vllm/model_executor/model_loader.py @@ -11,6 +11,9 @@ from vllm.model_executor.weight_utils import (get_quant_config, initialize_dummy_weights) from vllm.utils import is_hip +from vllm.logger import init_logger + +logger = init_logger(__name__) # TODO(woosuk): Lazy-load the model classes. _MODEL_REGISTRY = { @@ -40,15 +43,17 @@ } # Models to be disabled in ROCm -_ROCM_DISABLED_MODELS = [ - # ROCm's flash attention does not support sliding window attention - # in models such as Mistral - "MistralForCausalLM", -] +_ROCM_UNSUPPORTED_MODELS = [] if is_hip(): - for rocm_model in _ROCM_DISABLED_MODELS: + for rocm_model in _ROCM_UNSUPPORTED_MODELS: del _MODEL_REGISTRY[rocm_model] +# Models partially supported in ROCm +_ROCM_PARTIALLY_SUPPORTED_MODELS = { + "MistralForCausalLM": + "Sliding window attention is not supported in ROCm's flash attention", +} + @contextlib.contextmanager def _set_default_torch_dtype(dtype: torch.dtype): @@ -63,8 +68,12 @@ def _get_model_architecture(config: PretrainedConfig) -> Type[nn.Module]: architectures = getattr(config, "architectures", []) for arch in architectures: if arch in _MODEL_REGISTRY: + if is_hip() and arch in _ROCM_PARTIALLY_SUPPORTED_MODELS: + logger.warning( + f"{arch} is not fully supported in ROCm. Reason: " + f"{_ROCM_PARTIALLY_SUPPORTED_MODELS[arch]}") return _MODEL_REGISTRY[arch] - elif arch in _ROCM_DISABLED_MODELS: + elif arch in _ROCM_UNSUPPORTED_MODELS: raise ValueError( f"Model architecture {arch} is not supported by ROCm for now. \n" f"Supported architectures {list(_MODEL_REGISTRY.keys())}") From acf51a850c95748beac941a6a9797b65434eb0dc Mon Sep 17 00:00:00 2001 From: miloice Date: Wed, 6 Dec 2023 15:15:54 +0000 Subject: [PATCH 63/94] Fixed hip device attribute --- vllm/utils.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/vllm/utils.py b/vllm/utils.py index 0bda28c046199..98bce22429d63 100644 --- a/vllm/utils.py +++ b/vllm/utils.py @@ -27,10 +27,14 @@ def reset(self) -> None: self.counter = 0 +def is_hip(): + return torch.version.hip + + def get_max_shared_memory_bytes(gpu: int = 0) -> int: """Returns the maximum shared memory per thread block in bytes.""" # https://docs.nvidia.com/cuda/cuda-runtime-api/group__CUDART__TYPES.html - cudaDevAttrMaxSharedMemoryPerBlockOptin = 97 + cudaDevAttrMaxSharedMemoryPerBlockOptin = 97 if (not is_hip()) else 74 max_shared_mem = cuda_utils.get_device_attribute( cudaDevAttrMaxSharedMemoryPerBlockOptin, gpu) return int(max_shared_mem) @@ -53,7 +57,3 @@ def random_uuid() -> str: def in_wsl() -> bool: # Reference: https://github.com/microsoft/WSL/issues/4071 return "microsoft" in " ".join(uname()).lower() - - -def is_hip(): - return torch.version.hip From 4a5297774aa98e6cd3aaf0283861e8fef7a6b863 Mon Sep 17 00:00:00 2001 From: miloice Date: Wed, 6 Dec 2023 15:26:02 +0000 Subject: [PATCH 64/94] Format code --- vllm/model_executor/layers/quantization/squeezellm.py | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/vllm/model_executor/layers/quantization/squeezellm.py b/vllm/model_executor/layers/quantization/squeezellm.py index d41fea3807e0d..f2f9cac68fd81 100644 --- a/vllm/model_executor/layers/quantization/squeezellm.py +++ b/vllm/model_executor/layers/quantization/squeezellm.py @@ -116,11 +116,9 @@ def apply_weights(self, out_shape = x.shape[:-1] + (qweight.shape[-1], ) reshaped_x = x.reshape(-1, x.shape[-1]) if is_hip(): - out_float = torch.zeros(out_shape, - device="cuda", - dtype=torch.float) - ops.squeezellm_gemm(reshaped_x, qweight, out_float, lookup_table) - out = out_float.to(dtype=torch.float16) + out_f = torch.zeros(out_shape, device="cuda", dtype=torch.float) + ops.squeezellm_gemm(reshaped_x, qweight, out_f, lookup_table) + out = out_f.to(dtype=torch.float16) else: # NOTE: The output tensor should be zero-initialized. out = torch.zeros(out_shape, device="cuda", dtype=torch.float16) From 008e92d51044e978891dd92f842c11cc09771378 Mon Sep 17 00:00:00 2001 From: Antoni Baum Date: Wed, 6 Dec 2023 14:43:05 -0800 Subject: [PATCH 65/94] Update vllm/transformers_utils/tokenizer.py --- vllm/transformers_utils/tokenizer.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/vllm/transformers_utils/tokenizer.py b/vllm/transformers_utils/tokenizer.py index 983227acce41e..2cfa417679dda 100644 --- a/vllm/transformers_utils/tokenizer.py +++ b/vllm/transformers_utils/tokenizer.py @@ -118,10 +118,6 @@ def __init__(self, tokenizer_id: str, enable_lora: bool, max_num_seqs: int, self.lora_tokenizers = LRUCache(capacity=max_num_seqs) else: self.lora_tokenizers = None - - def ping(self): - return True - def encode(self, prompt: str, request_id: Optional[str] = None, From c328e587e8094a49f4b7abb0a797b9072ca03348 Mon Sep 17 00:00:00 2001 From: Antoni Baum Date: Wed, 6 Dec 2023 14:51:08 -0800 Subject: [PATCH 66/94] Update vllm/transformers_utils/tokenizer.py --- vllm/transformers_utils/tokenizer.py | 1 + 1 file changed, 1 insertion(+) diff --git a/vllm/transformers_utils/tokenizer.py b/vllm/transformers_utils/tokenizer.py index 2cfa417679dda..695cb893e430e 100644 --- a/vllm/transformers_utils/tokenizer.py +++ b/vllm/transformers_utils/tokenizer.py @@ -118,6 +118,7 @@ def __init__(self, tokenizer_id: str, enable_lora: bool, max_num_seqs: int, self.lora_tokenizers = LRUCache(capacity=max_num_seqs) else: self.lora_tokenizers = None + def encode(self, prompt: str, request_id: Optional[str] = None, From 23a987ac34b229c371b18bb626bf8447cc1e405a Mon Sep 17 00:00:00 2001 From: miloice Date: Thu, 7 Dec 2023 06:35:18 +0000 Subject: [PATCH 67/94] Restored awq file --- vllm/model_executor/layers/quantization/__init__.py | 8 +++----- vllm/model_executor/layers/quantization/awq.py | 10 +++++----- 2 files changed, 8 insertions(+), 10 deletions(-) diff --git a/vllm/model_executor/layers/quantization/__init__.py b/vllm/model_executor/layers/quantization/__init__.py index a23e7ec3fdcbc..3d937ba64f9fa 100644 --- a/vllm/model_executor/layers/quantization/__init__.py +++ b/vllm/model_executor/layers/quantization/__init__.py @@ -1,16 +1,14 @@ from typing import Type + +from vllm.model_executor.layers.quantization.awq import AWQConfig from vllm.model_executor.layers.quantization.squeezellm import SqueezeLLMConfig from vllm.model_executor.layers.quantization.base_config import QuantizationConfig -from vllm.utils import is_hip _QUANTIZATION_CONFIG_REGISTRY = { + "awq": AWQConfig, "squeezellm": SqueezeLLMConfig, } -if not is_hip(): - from vllm.model_executor.layers.quantization.awq import AWQConfig - _QUANTIZATION_CONFIG_REGISTRY["awq"] = AWQConfig - def get_quantization_config(quantization: str) -> Type[QuantizationConfig]: if quantization not in _QUANTIZATION_CONFIG_REGISTRY: diff --git a/vllm/model_executor/layers/quantization/awq.py b/vllm/model_executor/layers/quantization/awq.py index eb6bba0b2e91c..59172cbc354e0 100644 --- a/vllm/model_executor/layers/quantization/awq.py +++ b/vllm/model_executor/layers/quantization/awq.py @@ -3,12 +3,12 @@ import torch from torch.nn.parameter import Parameter from vllm.utils import is_hip -if is_hip(): - # do something specific for HIP - print("Warning: vLLM does not support AWQ on ROCm.") -else: - from vllm._C import ops +# if is_hip(): +# # do something specific for HIP +# print("Warning: vLLM does not support AWQ on ROCm.") +# else: +from vllm._C import ops from vllm.model_executor.layers.linear import (LinearMethodBase, set_weight_attrs) from vllm.model_executor.layers.quantization.base_config import QuantizationConfig From 8787a4eb08a6f03ced6a02cb3b859d4e515c0be9 Mon Sep 17 00:00:00 2001 From: miloice Date: Thu, 7 Dec 2023 06:37:11 +0000 Subject: [PATCH 68/94] Format code --- csrc/attention/attention_kernels.cu | 2 +- patch_xformers-0.0.22.post7.rocm.sh | 2 +- vllm/model_executor/layers/quantization/awq.py | 5 ----- 3 files changed, 2 insertions(+), 7 deletions(-) diff --git a/csrc/attention/attention_kernels.cu b/csrc/attention/attention_kernels.cu index e6ffc9e38e832..3676af1a378d2 100644 --- a/csrc/attention/attention_kernels.cu +++ b/csrc/attention/attention_kernels.cu @@ -548,7 +548,7 @@ __global__ void paged_attention_v2_reduce_kernel( #define LAUNCH_PAGED_ATTENTION_V1(HEAD_SIZE) \ VLLM_DevFuncAttribute_SET_MaxDynamicSharedMemorySize( \ - ((void*)vllm::paged_attention_v1_kernel), \ + ((void*)vllm::paged_attention_v1_kernel), \ shared_mem_size); \ vllm::paged_attention_v1_kernel \ <<>>( \ diff --git a/patch_xformers-0.0.22.post7.rocm.sh b/patch_xformers-0.0.22.post7.rocm.sh index c8e58f721ae8c..8dcb6240f0a44 100644 --- a/patch_xformers-0.0.22.post7.rocm.sh +++ b/patch_xformers-0.0.22.post7.rocm.sh @@ -19,4 +19,4 @@ if ! patch -R -p0 -s -f --dry-run $XFORMERS_FMHA_COMMON_PATH "./rocm_patch/commo echo "Successfully patch ${XFORMERS_FMHA_COMMON_PATH}" else echo "${XFORMERS_FMHA_COMMON_PATH} was patched before" -fi \ No newline at end of file +fi diff --git a/vllm/model_executor/layers/quantization/awq.py b/vllm/model_executor/layers/quantization/awq.py index 59172cbc354e0..95d419e64f049 100644 --- a/vllm/model_executor/layers/quantization/awq.py +++ b/vllm/model_executor/layers/quantization/awq.py @@ -2,11 +2,6 @@ import torch from torch.nn.parameter import Parameter -from vllm.utils import is_hip -# if is_hip(): -# # do something specific for HIP -# print("Warning: vLLM does not support AWQ on ROCm.") -# else: from vllm._C import ops from vllm.model_executor.layers.linear import (LinearMethodBase, From 9fa80751fefd014166cb0446b082221bfb81ffc1 Mon Sep 17 00:00:00 2001 From: miloice Date: Thu, 7 Dec 2023 10:12:19 +0000 Subject: [PATCH 69/94] Updated rocm dockerfile --- Dockerfile.rocm | 1 - 1 file changed, 1 deletion(-) diff --git a/Dockerfile.rocm b/Dockerfile.rocm index cebfdf968d7ab..7af53e4472e33 100644 --- a/Dockerfile.rocm +++ b/Dockerfile.rocm @@ -31,7 +31,6 @@ ENV LLVM_SYMBOLIZER_PATH=/opt/rocm/llvm/bin/llvm-symbolizer ENV PATH=$PATH:/opt/rocm/bin:/libtorch/bin: ENV LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/opt/rocm/lib/:/libtorch/lib: ENV CPLUS_INCLUDE_PATH=$CPLUS_INCLUDE_PATH:/libtorch/include:/libtorch/include/torch/csrc/api/include/:/opt/rocm/include/: -ENV PYTORCH_ROCM_ARCH=gfx900;gfx906;gfx908;gfx90a;gfx1030;gfx1101 # Install ROCm flash-attention RUN mkdir libs \ From 81e052dfa9872ba196dbbb6136294788d2fa4556 Mon Sep 17 00:00:00 2001 From: miloice Date: Thu, 7 Dec 2023 10:26:05 +0000 Subject: [PATCH 70/94] Update amd installation guide --- docs/source/getting_started/amd-installation.rst | 10 ++-------- 1 file changed, 2 insertions(+), 8 deletions(-) diff --git a/docs/source/getting_started/amd-installation.rst b/docs/source/getting_started/amd-installation.rst index bc81a003d9fdd..57f902bf0518c 100644 --- a/docs/source/getting_started/amd-installation.rst +++ b/docs/source/getting_started/amd-installation.rst @@ -22,7 +22,7 @@ Requirements Build from source with docker ----------------- -You can also build and install vLLM from source: +You can build and install vLLM from source: Build a docker image from `Dockerfile.rocm`, and launch a docker container. @@ -42,7 +42,7 @@ Build a docker image from `Dockerfile.rocm`, and launch a docker container. vllm-rocm \ bash -If you plan to install vLLM-ROCm on a local machine or start from a fresh docker image (e.g. pytorch+rocm5.7), you can follow the steps below: +Alternatively, if you plan to install vLLM-ROCm on a local machine or start from a fresh docker image (e.g. rocm/pytorch), you can follow the steps below: 0. Install prerequisites (skip if you are already in an environment/docker with the following installed): @@ -53,12 +53,6 @@ If you plan to install vLLM-ROCm on a local machine or start from a fresh docker Install ROCm's flash attention (v2.0.4) following the instructions from `ROCmSoftwarePlatform/flash-attention `_ - Note: If you are using rocm5.7 with pytorch 2.1.0 onwards, you don't need to apply the `hipify_python.patch`. You can build the ROCm flash attention directly. - -.. code-block:: console - - $ bash patch_torch211_flash_attn2.rocm.sh - .. note:: - ROCm's Flash-attention-2 (v2.0.4) does not support sliding windows attention. - You might need to downgrade the "ninja" version to 1.10 it is not used when compiling flash-attention-2 (e.g. `pip install ninja==1.10.2.4`) From fb8ac26504462fe077dc0c7366e15c23dff23f3e Mon Sep 17 00:00:00 2001 From: TJian Date: Thu, 7 Dec 2023 22:31:25 +0800 Subject: [PATCH 71/94] Update vLLM Documentations (#18) Update the vLLM installation procedures on AMD platform. Update vLLM documentations. --- README.md | 2 + .../getting_started/amd-installation.rst | 93 ++++++++++++++++--- docs/source/index.rst | 2 + 3 files changed, 86 insertions(+), 11 deletions(-) diff --git a/README.md b/README.md index 9cc325e924f77..477e61489d5a8 100644 --- a/README.md +++ b/README.md @@ -17,6 +17,7 @@ Easy, fast, and cheap LLM serving for everyone --- *Latest News* 🔥 +- [2023/12] Added ROCm support to vLLM. - [2023/10] We hosted [the first vLLM meetup](https://lu.ma/first-vllm-meetup) in SF! Please find the meetup slides [here](https://docs.google.com/presentation/d/1QL-XPFXiFpDBh86DbEegFXBXFXjix4v032GhShbKf3s/edit?usp=sharing). - [2023/09] We created our [Discord server](https://discord.gg/jz7wjKhh6g)! Join us to discuss vLLM and LLM serving! We will also post the latest announcements and updates there. - [2023/09] We released our [PagedAttention paper](https://arxiv.org/abs/2309.06180) on arXiv! @@ -43,6 +44,7 @@ vLLM is flexible and easy to use with: - Tensor parallelism support for distributed inference - Streaming outputs - OpenAI-compatible API server +- Support NVIDIA CUDA and AMD ROCm. vLLM seamlessly supports many Hugging Face models, including the following architectures: diff --git a/docs/source/getting_started/amd-installation.rst b/docs/source/getting_started/amd-installation.rst index 57f902bf0518c..7d8db353906b5 100644 --- a/docs/source/getting_started/amd-installation.rst +++ b/docs/source/getting_started/amd-installation.rst @@ -1,11 +1,11 @@ -.. _installation: +.. _installation_rocm: Installation with ROCm -============ +====================== vLLM 0.2.x onwards supports model inferencing and serving on AMD GPUs with ROCm. At the moment AWQ quantization is not supported in ROCm, but SqueezeLLM quantization has been ported. -Datatypes currently supported in ROCm are FP16 and BF16. +Data types currently supported in ROCm are FP16 and BF16. Requirements ------------ @@ -13,14 +13,84 @@ Requirements * OS: Linux * Python: 3.8 -- 3.11 (Verified on 3.10) * GPU: MI200s -* Pytorch 2.0.1/2.1.1 -* ROCm >= 5.7.0 +* Pytorch 2.0.1/2.1.1/2.2 +* ROCm 5.7 +Installation options: -.. _build_from_source: +#. :ref:`(Recommended) Quick start with vLLM pre-installed in Docker Image ` +#. :ref:`Build from source ` +#. :ref:`Build from source with docker ` -Build from source with docker ------------------ +.. _quick_start_docker_rocm: + +(Recommended) Option 1: Quick start with vLLM pre-installed in Docker Image +--------------------------------------------------------------------------- + +.. code-block:: console + + $ docker pull embeddedllminfo/vllm-rocm:vllm-v0.2.3 + $ docker run -it \ + --network=host \ + --group-add=video \ + --ipc=host \ + --cap-add=SYS_PTRACE \ + --security-opt seccomp=unconfined \ + --shm-size 8G \ + --device /dev/kfd \ + --device /dev/dri \ + -v :/app/model \ + embeddedllminfo/vllm-rocm \ + bash + + +.. _build_from_source_rocm: + +Option 2: Build from source +--------------------------- + +You can build and install vLLM from source: + +0. Install prerequisites (skip if you are already in an environment/docker with the following installed): + +- `ROCm `_ +- `Pytorch `_ + + .. code-block:: console + + $ pip install torch==2.2.0.dev20231206+rocm5.7 --index-url https://download.pytorch.org/whl/nightly/rocm5.7 # tested version + + +1. Install `flash attention for ROCm `_ + + Install ROCm's flash attention (v2.0.4) following the instructions from `ROCmSoftwarePlatform/flash-attention `_ + +.. note:: + - If you are using rocm5.7 with pytorch 2.1.0 onwards, you don't need to apply the `hipify_python.patch`. You can build the ROCm flash attention directly. + - If you fail to install `ROCmSoftwarePlatform/flash-attention`, try cloning from the commit `6fd2f8e572805681cd67ef8596c7e2ce521ed3c6`. + - ROCm's Flash-attention-2 (v2.0.4) does not support sliding windows attention. + - You might need to downgrade the "ninja" version to 1.10 it is not used when compiling flash-attention-2 (e.g. `pip install ninja==1.10.2.4`) + +2. Setup `xformers==0.0.22.post7` without dependencies, and apply patches to adapt for ROCm flash attention + + .. code-block:: console + + $ pip install xformers==0.0.22.post7 --no-deps + $ bash patch_xformers-0.0.22.post7.rocm.sh + +3. Build vLLM. + + .. code-block:: console + + $ cd vllm + $ pip install -U -r requirements-rocm.txt + $ python setup.py install # This may take 5-10 minutes. Currently, `pip install .`` does not work for ROCm installation + + +.. _build_from_source_docker_rocm: + +Option 3: Build from source with docker +----------------------------------------------------- You can build and install vLLM from source: @@ -54,21 +124,22 @@ Alternatively, if you plan to install vLLM-ROCm on a local machine or start from Install ROCm's flash attention (v2.0.4) following the instructions from `ROCmSoftwarePlatform/flash-attention `_ .. note:: + - If you are using rocm5.7 with pytorch 2.1.0 onwards, you don't need to apply the `hipify_python.patch`. You can build the ROCm flash attention directly. + - If you fail to install `ROCmSoftwarePlatform/flash-attention`, try cloning from the commit `6fd2f8e572805681cd67ef8596c7e2ce521ed3c6`. - ROCm's Flash-attention-2 (v2.0.4) does not support sliding windows attention. - You might need to downgrade the "ninja" version to 1.10 it is not used when compiling flash-attention-2 (e.g. `pip install ninja==1.10.2.4`) -2. Setup xformers==0.0.22.post7 without dependencies, and apply patches to adapt for ROCm flash attention +2. Setup `xformers==0.0.22.post7` without dependencies, and apply patches to adapt for ROCm flash attention .. code-block:: console $ pip install xformers==0.0.22.post7 --no-deps $ bash patch_xformers-0.0.22.post7.rocm.sh -3. Build vllm. +3. Build vLLM. .. code-block:: console $ cd vllm $ pip install -U -r requirements-rocm.txt $ python setup.py install # This may take 5-10 minutes. - diff --git a/docs/source/index.rst b/docs/source/index.rst index 300c22762df47..04af09073a447 100644 --- a/docs/source/index.rst +++ b/docs/source/index.rst @@ -39,6 +39,7 @@ vLLM is flexible and easy to use with: * Tensor parallelism support for distributed inference * Streaming outputs * OpenAI-compatible API server +* Support NVIDIA CUDA and AMD ROCm. For more information, check out the following: @@ -56,6 +57,7 @@ Documentation :caption: Getting Started getting_started/installation + getting_started/amd-installation getting_started/quickstart .. toctree:: From 98f54870322033bfae2a07b4bc8e20ba151a6d76 Mon Sep 17 00:00:00 2001 From: miloice <17350011+kliuae@users.noreply.github.com> Date: Fri, 8 Dec 2023 04:02:18 +0000 Subject: [PATCH 72/94] Updated setup.py, vllm/utils.py and amd-installation doc --- .../getting_started/amd-installation.rst | 2 -- setup.py | 19 ++++++++++--------- vllm/utils.py | 6 +++--- 3 files changed, 13 insertions(+), 14 deletions(-) diff --git a/docs/source/getting_started/amd-installation.rst b/docs/source/getting_started/amd-installation.rst index 7d8db353906b5..512aa051f16d9 100644 --- a/docs/source/getting_started/amd-installation.rst +++ b/docs/source/getting_started/amd-installation.rst @@ -36,7 +36,6 @@ Installation options: --ipc=host \ --cap-add=SYS_PTRACE \ --security-opt seccomp=unconfined \ - --shm-size 8G \ --device /dev/kfd \ --device /dev/dri \ -v :/app/model \ @@ -105,7 +104,6 @@ Build a docker image from `Dockerfile.rocm`, and launch a docker container. --ipc=host \ --cap-add=SYS_PTRACE \ --security-opt seccomp=unconfined \ - --shm-size 8G \ --device /dev/kfd \ --device /dev/dri \ -v :/app/model \ diff --git a/setup.py b/setup.py index 3fc19a912e630..d2c97a52e0ddf 100644 --- a/setup.py +++ b/setup.py @@ -20,12 +20,12 @@ # SUPPORTED_ARCHS = NVIDIA_SUPPORTED_ARCHS.union(ROCM_SUPPORTED_ARCHS) -def _is_hip(): - return torch.version.hip +def _is_hip() -> bool: + return torch.version.hip is not None -def _is_cuda(): - return torch.version.cuda +def _is_cuda() -> bool: + return torch.version.cuda is not None # Compiler flags. @@ -33,7 +33,10 @@ def _is_cuda(): # TODO(woosuk): Should we use -O3? NVCC_FLAGS = ["-O2", "-std=c++17"] -if _is_hip() and ROCM_HOME is not None: +if _is_hip(): + if ROCM_HOME is None: + raise RuntimeError( + "Cannot find ROCM_HOME. ROCm must be available to build the package.") NVCC_FLAGS += ["-DUSE_ROCM"] if _is_cuda() and CUDA_HOME is None: @@ -46,19 +49,17 @@ def _is_cuda(): def get_amdgpu_offload_arch(): - error_message = "" command = "/opt/rocm/llvm/bin/amdgpu-offload-arch" try: output = subprocess.check_output([command]) return output.decode('utf-8').strip() except subprocess.CalledProcessError as e: error_message = f"Error: {e}" + raise RuntimeError(error_message) from e except FileNotFoundError: # If the command is not found, print an error message error_message = f"The command {command} was not found." - - if error_message: - raise RuntimeError(error_message) + raise RuntimeError(error_message) from e return None diff --git a/vllm/utils.py b/vllm/utils.py index 98bce22429d63..d5d8d4efa95c0 100644 --- a/vllm/utils.py +++ b/vllm/utils.py @@ -27,14 +27,14 @@ def reset(self) -> None: self.counter = 0 -def is_hip(): - return torch.version.hip +def is_hip() -> bool: + return torch.version.hip is not None def get_max_shared_memory_bytes(gpu: int = 0) -> int: """Returns the maximum shared memory per thread block in bytes.""" # https://docs.nvidia.com/cuda/cuda-runtime-api/group__CUDART__TYPES.html - cudaDevAttrMaxSharedMemoryPerBlockOptin = 97 if (not is_hip()) else 74 + cudaDevAttrMaxSharedMemoryPerBlockOptin = 97 if not is_hip() else 74 max_shared_mem = cuda_utils.get_device_attribute( cudaDevAttrMaxSharedMemoryPerBlockOptin, gpu) return int(max_shared_mem) From d90187a294b8df0e3cc20da5dc65b3a032dee65c Mon Sep 17 00:00:00 2001 From: miloice <17350011+kliuae@users.noreply.github.com> Date: Fri, 8 Dec 2023 04:06:08 +0000 Subject: [PATCH 73/94] Updated setup.py --- setup.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/setup.py b/setup.py index d2c97a52e0ddf..46b0840d72464 100644 --- a/setup.py +++ b/setup.py @@ -36,7 +36,8 @@ def _is_cuda() -> bool: if _is_hip(): if ROCM_HOME is None: raise RuntimeError( - "Cannot find ROCM_HOME. ROCm must be available to build the package.") + "Cannot find ROCM_HOME. ROCm must be available to build the package." + ) NVCC_FLAGS += ["-DUSE_ROCM"] if _is_cuda() and CUDA_HOME is None: @@ -56,7 +57,7 @@ def get_amdgpu_offload_arch(): except subprocess.CalledProcessError as e: error_message = f"Error: {e}" raise RuntimeError(error_message) from e - except FileNotFoundError: + except FileNotFoundError as e: # If the command is not found, print an error message error_message = f"The command {command} was not found." raise RuntimeError(error_message) from e From c840531e12ae49b98eaf1a898304b758c3892115 Mon Sep 17 00:00:00 2001 From: miloice <17350011+kliuae@users.noreply.github.com> Date: Fri, 8 Dec 2023 04:08:54 +0000 Subject: [PATCH 74/94] Format code --- setup.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/setup.py b/setup.py index 46b0840d72464..3a263fcf28488 100644 --- a/setup.py +++ b/setup.py @@ -36,8 +36,7 @@ def _is_cuda() -> bool: if _is_hip(): if ROCM_HOME is None: raise RuntimeError( - "Cannot find ROCM_HOME. ROCm must be available to build the package." - ) + "Cannot find ROCM_HOME. ROCm must be available to build the package.") NVCC_FLAGS += ["-DUSE_ROCM"] if _is_cuda() and CUDA_HOME is None: From 8566144a01d5f572860f347a6fde8a627d5e3bc4 Mon Sep 17 00:00:00 2001 From: Antoni Baum Date: Wed, 13 Dec 2023 10:59:37 -0800 Subject: [PATCH 75/94] Reuse code --- vllm/lora/models.py | 10 ++-------- 1 file changed, 2 insertions(+), 8 deletions(-) diff --git a/vllm/lora/models.py b/vllm/lora/models.py index 042b899763f10..1bd10abaa778a 100644 --- a/vllm/lora/models.py +++ b/vllm/lora/models.py @@ -622,14 +622,8 @@ def list_loras(self) -> Dict[int, LoRAModel]: def add_lora(self, lora: LoRAModel) -> bool: """Add a LoRAModel to the manager.""" - was_added = False - if lora.id not in self._registered_loras: - was_added = True - logger.debug(f"Adding LoRA. Model id: {lora.id}, " - f"int id: {lora.id}") - self._create_merged_loras_inplace(lora) - self._registered_loras[lora.id] = lora - else: + was_added = super().add_lora(lora) + if not was_added: # We always touch to update the LRU cache order self._registered_loras.touch(lora.id) return was_added From 2d72ae560c46332987ee052cdfd7d96ee8fe6db1 Mon Sep 17 00:00:00 2001 From: Antoni Baum Date: Wed, 13 Dec 2023 11:13:19 -0800 Subject: [PATCH 76/94] Naming --- vllm/lora/models.py | 22 +++++++++++----------- 1 file changed, 11 insertions(+), 11 deletions(-) diff --git a/vllm/lora/models.py b/vllm/lora/models.py index 1bd10abaa778a..2f7c6154c7a24 100644 --- a/vllm/lora/models.py +++ b/vllm/lora/models.py @@ -52,14 +52,14 @@ def convert_mapping( - mapping: LoRAMapping, lora_id_to_index: List[Optional[int]], + mapping: LoRAMapping, lora_index_to_id: List[Optional[int]], max_loras: int, vocab_size: int, extra_vocab_size: int ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, List[int]]: """Converts LoRAMapping to index tensors. Args: mapping: LoRAMapping mapping rows in a batch to LoRA ids. - lora_id_to_index: List mapping LoRA ids to LoRA indices. + lora_index_to_id: List mapping LoRA ids to LoRA indices. max_loras: Maximum number of LoRAs. vocab_size: Model vocab size. extra_vocab_size: Extra vocab size each LoRA can have. @@ -86,13 +86,13 @@ def convert_mapping( embedding_indices = indices.copy() lora_indices = indices.copy() prompt_mapping = [ - lora_id_to_index.index(x) if x > 0 else -1 + lora_index_to_id.index(x) if x > 0 else -1 for x in mapping.prompt_mapping ] lora_idx = None for i in range(len(indices)): # TODO index can be slow. optimize - lora_idx = (lora_id_to_index.index(indices[i]) + lora_idx = (lora_index_to_id.index(indices[i]) if indices[i] > 0 else -1) embedding_indices[i] = lora_idx if indices[i] > 0 else 0 indices[i] = i @@ -327,7 +327,7 @@ def __init__( self.max_num_seqs = max_num_seqs assert self.capacity >= self.lora_slots self.max_num_batched_tokens = math.ceil(max_num_batched_tokens / 8) * 8 - self.lora_id_to_index: List[Optional[int]] = [None] * self.lora_slots + self.lora_index_to_id: List[Optional[int]] = [None] * self.lora_slots self.vocab_size = vocab_size self.base_indices = torch.empty(self.max_num_batched_tokens, dtype=torch.long, @@ -377,7 +377,7 @@ def activate_lora( if lora_id in self._active_loras: return False first_free_slot = next( - ((i, lora_id) for i, lora_id in enumerate(self.lora_id_to_index) + ((i, lora_id) for i, lora_id in enumerate(self.lora_index_to_id) if lora_id is None), None) if first_free_slot is None: raise ValueError("No free lora slots") @@ -386,7 +386,7 @@ def activate_lora( lora_model = self._registered_loras[lora_id] logger.debug( f"Activating LoRA. int id: {lora_model.id}, slot index: {index}") - self.lora_id_to_index[index] = lora_model.id + self.lora_index_to_id[index] = lora_model.id for module_name, module in self.modules.items(): module_lora = lora_model.get_lora(module_name) if module_lora: @@ -399,8 +399,8 @@ def activate_lora( def _deactivate_lora(self, lora_id: int): try: - index = self.lora_id_to_index.index(lora_id) - self.lora_id_to_index[index] = None + index = self.lora_index_to_id.index(lora_id) + self.lora_index_to_id[index] = None except ValueError: pass @@ -431,7 +431,7 @@ def remove_lora(self, lora_id: int) -> bool: def convert_mapping(self, mapping: LoRAMapping) -> None: (base_indices, sampler_indices, sampler_indices_padded, embeddings_indices, - indices_len) = convert_mapping(mapping, self.lora_id_to_index, + indices_len) = convert_mapping(mapping, self.lora_index_to_id, self.lora_slots + 1, self.vocab_size, self.lora_config.lora_extra_vocab_size) self.base_indices[:base_indices.shape[0]].copy_(base_indices) @@ -459,7 +459,7 @@ def get_lora(self, lora_id: int) -> Optional[LoRAModel]: def remove_all_loras(self) -> bool: """Remove all LoRAModels from the manager.""" self._registered_loras.clear() - self.lora_id_to_index = [None] * self.lora_slots + self.lora_index_to_id = [None] * self.lora_slots self._active_loras.clear() def _create_lora_modules(self): From dd0cbca2bb7f3d9dc896ea9b11b04a86c7bd7a94 Mon Sep 17 00:00:00 2001 From: miloice <17350011+kliuae@users.noreply.github.com> Date: Fri, 15 Dec 2023 07:41:51 +0000 Subject: [PATCH 77/94] Ported to rocm --- csrc/cuda_compat.h | 10 +- csrc/punica/bgmv/bgmv_all.cu | 2 + csrc/punica/bgmv/bgmv_impl.cuh | 85 ++++++++++++- csrc/punica/bgmv/vec_dtypes.cuh | 10 ++ csrc/punica/{punica_ops.cc => punica_ops.cu} | 25 ++-- csrc/punica/punica_ops.h | 11 ++ csrc/punica/punica_pybind.cpp | 13 ++ csrc/punica/type_convert.h | 124 +++++++++++++++++++ setup.py | 31 +++-- 9 files changed, 280 insertions(+), 31 deletions(-) rename csrc/punica/{punica_ops.cc => punica_ops.cu} (98%) create mode 100644 csrc/punica/punica_ops.h create mode 100644 csrc/punica/punica_pybind.cpp create mode 100644 csrc/punica/type_convert.h diff --git a/csrc/cuda_compat.h b/csrc/cuda_compat.h index aa58dd73c148a..14302b1d4684f 100644 --- a/csrc/cuda_compat.h +++ b/csrc/cuda_compat.h @@ -1,4 +1,5 @@ -#pragma once +#ifndef CSRC__CUDA_COMPAT_H__ +#define CSRC__CUDA_COMPAT_H__ #ifndef USE_ROCM #define VLLM_LDG(arg) __ldg(arg) @@ -18,6 +19,12 @@ #define VLLM_SHFL_SYNC(var, src_lane) __shfl(var, src_lane) #endif +#ifndef USE_ROCM + #define VLLM_SHFL_DOWN_SYNC(var, lane_delta) __shfl_down_sync(uint32_t(-1), var, lane_delta) +#else + #define VLLM_SHFL_DOWN_SYNC(var, lane_delta) __shfl_down(var, lane_delta) +#endif + #ifndef USE_ROCM #define VLLM_DevFuncAttribute_SET_MaxDynamicSharedMemorySize(FUNC, VAL) \ cudaFuncSetAttribute(FUNC, cudaFuncAttributeMaxDynamicSharedMemorySize, VAL) @@ -26,3 +33,4 @@ hipFuncSetAttribute(FUNC, hipFuncAttributeMaxDynamicSharedMemorySize, VAL) #endif +#endif // CSRC__CUDA_COMPAT_H__ diff --git a/csrc/punica/bgmv/bgmv_all.cu b/csrc/punica/bgmv/bgmv_all.cu index bc86416701f13..e567cb5c8da7e 100644 --- a/csrc/punica/bgmv/bgmv_all.cu +++ b/csrc/punica/bgmv/bgmv_all.cu @@ -1,6 +1,8 @@ #include "bgmv_config.h" #include "bgmv_impl.cuh" +#include "../type_convert.h" + FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, nv_half, nv_half, nv_half) FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, nv_half, nv_half, nv_bfloat16) FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, nv_bfloat16, nv_bfloat16, nv_bfloat16) diff --git a/csrc/punica/bgmv/bgmv_impl.cuh b/csrc/punica/bgmv/bgmv_impl.cuh index 995de26e8bada..7963342ede0fd 100644 --- a/csrc/punica/bgmv/bgmv_impl.cuh +++ b/csrc/punica/bgmv/bgmv_impl.cuh @@ -1,8 +1,14 @@ #pragma once #include +#ifndef USE_ROCM #include +#else +#include +#endif +#ifndef USE_ROCM #include +#endif #include #include #include @@ -11,6 +17,20 @@ namespace cg = cooperative_groups; +#ifdef USE_ROCM + __host__ __device__ + inline void* memcpy_blocking(void *dst, const void *src, size_t len) { + // Does not handle the case of long datatypes + char *d = reinterpret_cast(dst); + const char *s = reinterpret_cast(src); + size_t i = 0; + for (i = 0; i < len; ++i) { + d[i] = s[i]; + } + return dst; + } +#endif + // nthrs = (32, 4) template (X_copy_size), pipe); pipe.producer_commit(); +#else + memcpy_blocking(W_shared + (threadIdx.y * tx + threadIdx.x) * vec_size, + W + (idx * feat_out + j) * feat_in + + (threadIdx.y * tx + threadIdx.x) * vec_size, + W_copy_size); + memcpy_blocking(X_shared + (threadIdx.y * tx + threadIdx.x) * vec_size, + X + (batch_idx * feat_in) + + (threadIdx.y * tx + threadIdx.x) * vec_size, + X_copy_size); +#endif size_t copy_idx, compute_idx; float y = 0.f; vec_t x_vec; @@ -61,6 +92,7 @@ bgmv_shrink_kernel(out_T *__restrict__ Y, const in_T *__restrict__ X, ++tile_idx) { copy_idx = tile_idx % num_pipeline_stages; // pipeline stage: async copy W fragment +#ifndef USE_ROCM pipe.producer_acquire(); if (tile_idx * tile_size + threadIdx.y * tx * vec_size < feat_in) { cuda::memcpy_async(W_shared + W_shared_offset[copy_idx] + @@ -76,10 +108,27 @@ bgmv_shrink_kernel(out_T *__restrict__ Y, const in_T *__restrict__ X, cuda::aligned_size_t(X_copy_size), pipe); } pipe.producer_commit(); +#else + if (tile_idx * tile_size + threadIdx.y * tx * vec_size < feat_in) { + memcpy_blocking(W_shared + W_shared_offset[copy_idx] + + (threadIdx.y * tx + threadIdx.x) * vec_size, + W + (idx * feat_out + j) * feat_in + + tile_idx * tile_size + + (threadIdx.y * tx + threadIdx.x) * vec_size, + W_copy_size); + memcpy_blocking(X_shared + X_shared_offset[copy_idx] + + (threadIdx.y * tx + threadIdx.x) * vec_size, + X + (batch_idx * feat_in) + tile_idx * tile_size + + (threadIdx.y * tx + threadIdx.x) * vec_size, + X_copy_size); + } +#endif compute_idx = (tile_idx - 1) % num_pipeline_stages; // pipeline stage: compute WX +#ifndef USE_ROCM pipe.consumer_wait(); +#endif block.sync(); x_vec.load(X_shared + X_shared_offset[compute_idx] + (threadIdx.y * tx + threadIdx.x) * vec_size); @@ -88,11 +137,15 @@ bgmv_shrink_kernel(out_T *__restrict__ Y, const in_T *__restrict__ X, float sum = 0.f; #pragma unroll for (size_t i = 0; i < vec_size; ++i) { +#ifndef USE_ROCM sum += float(w_vec[i]) * float(x_vec[i]) * scale; +#else + sum += convert_type(w_vec[i]) * convert_type(x_vec[i]) * scale; +#endif } #pragma unroll for (size_t offset = tx / 2; offset > 0; offset /= 2) { - sum += __shfl_down_sync(0xffffffff, sum, offset); + sum += VLLM_SHFL_DOWN_SYNC(sum, offset); } y_warpwise[threadIdx.y] = sum; block.sync(); @@ -102,12 +155,16 @@ bgmv_shrink_kernel(out_T *__restrict__ Y, const in_T *__restrict__ X, } block.sync(); +#ifndef USE_ROCM pipe.consumer_release(); +#endif } compute_idx = (tile_idx - 1) % num_pipeline_stages; // final pipeline stage +#ifndef USE_ROCM pipe.consumer_wait(); +#endif block.sync(); x_vec.load(X_shared + X_shared_offset[compute_idx] + (threadIdx.y * tx + threadIdx.x) * vec_size); @@ -116,11 +173,15 @@ bgmv_shrink_kernel(out_T *__restrict__ Y, const in_T *__restrict__ X, float sum = 0.f; #pragma unroll for (size_t i = 0; i < vec_size; ++i) { +#ifndef USE_ROCM sum += float(w_vec[i]) * float(x_vec[i]) * scale; +#else + sum += convert_type(w_vec[i]) * convert_type(x_vec[i]) * scale; +#endif } #pragma unroll for (size_t offset = tx / 2; offset > 0; offset /= 2) { - sum += __shfl_down_sync(0xffffffff, sum, offset); + sum += VLLM_SHFL_DOWN_SYNC(sum, offset); } y_warpwise[threadIdx.y] = ((tile_idx - 1) * tile_size + threadIdx.y * tx * vec_size < feat_in) @@ -133,11 +194,18 @@ bgmv_shrink_kernel(out_T *__restrict__ Y, const in_T *__restrict__ X, } block.sync(); +#ifndef USE_ROCM pipe.consumer_release(); +#endif // write Y; if (block.thread_rank() == 0) { +#ifndef USE_ROCM Y[batch_idx * full_y_size + y_offset + j] += static_cast(y); +#else + size_t y_idx = batch_idx * full_y_size + y_offset + j; + Y[y_idx] = vllm_add(Y[y_idx], convert_type(y)); +#endif } } @@ -172,7 +240,11 @@ bgmv_expand_kernel(out_T *__restrict__ Y, const in_T *__restrict__ X, float sum = 0.f; #pragma unroll for (size_t i = 0; i < vec_size; ++i) { +#ifndef USE_ROCM sum += float(w_vec[i]) * float(x_vec[i]) * scale; +#else + sum += convert_type(w_vec[i]) * convert_type(x_vec[i]) * scale; +#endif } cg::thread_block_tile g = cg::tiled_partition(block); @@ -183,8 +255,15 @@ bgmv_expand_kernel(out_T *__restrict__ Y, const in_T *__restrict__ X, sum = g.shfl(sum, 0); if (threadIdx.x == 0) { +#ifndef USE_ROCM Y[batch_idx * full_y_size + y_offset + tile_idx * (tz * ty) + - threadIdx.z * ty + threadIdx.y] += static_cast(sum); + threadIdx.z * ty + threadIdx.y] + += static_cast(sum); +#else + size_t y_idx = batch_idx * full_y_size + y_offset + tile_idx * (tz * ty) + + threadIdx.z * ty + threadIdx.y; + Y[y_idx] = vllm_add(Y[y_idx], convert_type(sum)); +#endif } } diff --git a/csrc/punica/bgmv/vec_dtypes.cuh b/csrc/punica/bgmv/vec_dtypes.cuh index cf00d869cf635..86d2971731e4d 100644 --- a/csrc/punica/bgmv/vec_dtypes.cuh +++ b/csrc/punica/bgmv/vec_dtypes.cuh @@ -1,8 +1,13 @@ #ifndef VEC_DTYPES_CUH_ #define VEC_DTYPES_CUH_ +#ifndef USE_ROCM #include #include +#else +#include +#include +#endif #ifdef FLASHINFER_USE_FP8 #include #endif @@ -10,6 +15,11 @@ #include +#ifdef USE_ROCM +#include "../../cuda_compat.h" +#include "../type_convert.h" +#endif + #define FLASHINFER_INLINE \ inline __attribute__((always_inline)) __device__ __host__ diff --git a/csrc/punica/punica_ops.cc b/csrc/punica/punica_ops.cu similarity index 98% rename from csrc/punica/punica_ops.cc rename to csrc/punica/punica_ops.cu index 4ad46e5e1f726..935ce48891ab1 100644 --- a/csrc/punica/punica_ops.cc +++ b/csrc/punica/punica_ops.cu @@ -1,12 +1,19 @@ +#include +#include + +#ifndef USE_ROCM #include #include -#include +#else +#include +#include +#endif #include #include "bgmv/bgmv_config.h" - -namespace { +#include "../cuda_compat.h" +#include "type_convert.h" //====== utils ====== @@ -549,15 +556,3 @@ void dispatch_bgmv_low_level(torch::Tensor y, torch::Tensor x, torch::Tensor w, TORCH_CHECK(ok, "No suitable kernel.", " h_in=", h_in, " h_out=", h_out, " dtype=", x.scalar_type(), " out_dtype=", y.scalar_type()); } - -} // namespace - -//====== pybind ====== - -#define DEFINE_pybind(name) m.def(#name, &name, #name); - -PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { - m.def("dispatch_bgmv", &dispatch_bgmv, "dispatch_bgmv"); - m.def("dispatch_bgmv_low_level", &dispatch_bgmv_low_level, - "dispatch_bgmv_low_level"); -} diff --git a/csrc/punica/punica_ops.h b/csrc/punica/punica_ops.h new file mode 100644 index 0000000000000..937e2d1d25d4a --- /dev/null +++ b/csrc/punica/punica_ops.h @@ -0,0 +1,11 @@ +#pragma once + +#include + +void dispatch_bgmv(torch::Tensor y, torch::Tensor x, torch::Tensor w, + torch::Tensor indicies, int64_t layer_idx, float scale); + +void dispatch_bgmv_low_level(torch::Tensor y, torch::Tensor x, torch::Tensor w, + torch::Tensor indicies, int64_t layer_idx, + float scale, int64_t h_in, int64_t h_out, + int64_t y_offset); diff --git a/csrc/punica/punica_pybind.cpp b/csrc/punica/punica_pybind.cpp new file mode 100644 index 0000000000000..4435496619a29 --- /dev/null +++ b/csrc/punica/punica_pybind.cpp @@ -0,0 +1,13 @@ +#include + +#include "punica_ops.h" + +//====== pybind ====== + +#define DEFINE_pybind(name) m.def(#name, &name, #name); + +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { + m.def("dispatch_bgmv", &dispatch_bgmv, "dispatch_bgmv"); + m.def("dispatch_bgmv_low_level", &dispatch_bgmv_low_level, + "dispatch_bgmv_low_level"); +} \ No newline at end of file diff --git a/csrc/punica/type_convert.h b/csrc/punica/type_convert.h new file mode 100644 index 0000000000000..9180693e2057b --- /dev/null +++ b/csrc/punica/type_convert.h @@ -0,0 +1,124 @@ +#ifndef CSRC__PUNICA__TYPE_CONVERT_H__ +#define CSRC__PUNICA__TYPE_CONVERT_H__ + +#ifdef USE_ROCM + +#define __TYPE_CONVERT__HOST_DEVICE__ __host__ __device__ + +#include +#include + +typedef __half nv_half; +typedef __hip_bfloat16 nv_bfloat16; + +typedef __hip_bfloat162 nv_bfloat162; +//typedef __hip_bfloat16 nv_bfloat16; + +__TYPE_CONVERT__HOST_DEVICE__ +inline __hip_bfloat162 make_bfloat162(__hip_bfloat16 val) { +return __hip_bfloat162{val, val}; +} + +__TYPE_CONVERT__HOST_DEVICE__ +inline __hip_bfloat162 make_bfloat162(__hip_bfloat16 vall, __hip_bfloat16 valr) { +return __hip_bfloat162{vall, valr}; +} + +__TYPE_CONVERT__HOST_DEVICE__ +inline __half cast_to_half(__half val) { +return val; +} + +__TYPE_CONVERT__HOST_DEVICE__ +inline __half cast_to_half(__hip_bfloat16 val) { +return __float2half(__bfloat162float(val)); +} + +__TYPE_CONVERT__HOST_DEVICE__ +inline __half cast_to_half(float val) { +return __float2half(val); +} + +__TYPE_CONVERT__HOST_DEVICE__ +inline __hip_bfloat16 cast_to_bf16(__half val) { +return __float2bfloat16(__half2float(val)); +} + +__TYPE_CONVERT__HOST_DEVICE__ +inline __hip_bfloat16 cast_to_bf16(__hip_bfloat16 val) { +return val; +} + +__TYPE_CONVERT__HOST_DEVICE__ +inline __hip_bfloat16 cast_to_bf16(float val) { +return __float2bfloat16(val); +} + +__TYPE_CONVERT__HOST_DEVICE__ +inline float cast_to_float(__half val) { +return __half2float(val); +} + +__TYPE_CONVERT__HOST_DEVICE__ +inline float cast_to_float(__hip_bfloat16 val) { +return __bfloat162float(val); +} + +__TYPE_CONVERT__HOST_DEVICE__ +inline float cast_to_float(float val) { +return val; +} + +template +__TYPE_CONVERT__HOST_DEVICE__ +inline T_dst convert_type(T_src val) { +return static_cast(val); +} + +template <> +__TYPE_CONVERT__HOST_DEVICE__ +inline float convert_type<__half, float>(__half val) { +return __half2float(val); +} + +template <> +__TYPE_CONVERT__HOST_DEVICE__ +inline __half convert_type(float val) { +return __float2half(val); +} + +template <> +__TYPE_CONVERT__HOST_DEVICE__ +inline float convert_type<__hip_bfloat16, float>(__hip_bfloat16 val) { +return __bfloat162float(val); +} + +template <> +__TYPE_CONVERT__HOST_DEVICE__ +inline __hip_bfloat16 convert_type(float val) { +return __float2bfloat16(val); +} + +template +__TYPE_CONVERT__HOST_DEVICE__ +inline T vllm_add(T a, T b) { +return a + b; +} + +template <> +__TYPE_CONVERT__HOST_DEVICE__ +inline __half vllm_add<__half>(__half a, __half b) { +return __hadd(a, b); +} + +template <> +__TYPE_CONVERT__HOST_DEVICE__ +inline __hip_bfloat16 vllm_add<__hip_bfloat16>(__hip_bfloat16 a, __hip_bfloat16 b) { +return __hadd(a, b); +} + +#undef __TYPE_CONVERT__HOST_DEVICE__ + +#endif // USE_ROCM + +#endif // CSRC__PUNICA__TYPE_CONVERT_H__ diff --git a/setup.py b/setup.py index e611d6b2cb613..435be88c35169 100644 --- a/setup.py +++ b/setup.py @@ -50,7 +50,7 @@ def _is_cuda() -> bool: ABI = 1 if torch._C._GLIBCXX_USE_CXX11_ABI else 0 CXX_FLAGS += [f"-D_GLIBCXX_USE_CXX11_ABI={ABI}"] NVCC_FLAGS += [f"-D_GLIBCXX_USE_CXX11_ABI={ABI}"] - +NVCC_FLAGS_PUNICA = NVCC_FLAGS.copy() def get_amdgpu_offload_arch(): command = "/opt/rocm/llvm/bin/amdgpu-offload-arch" @@ -196,6 +196,8 @@ def get_torch_arch_list() -> Set[str]: raise RuntimeError( "CUDA 11.8 or higher is required for compute capability 9.0.") + NVCC_FLAGS_PUNICA = NVCC_FLAGS.copy() + # Add target compute capabilities to NVCC flags. for capability in compute_capabilities: num = capability[0] + capability[2] @@ -264,17 +266,9 @@ def get_torch_arch_list() -> Set[str]: elif _is_hip(): pass -if install_punica: - ext_modules.append( - CUDAExtension( - name="vllm._punica_C", - sources=["csrc/punica/punica_ops.cc"] + - glob("csrc/punica/bgmv/*.cu"), - extra_compile_args={ - "cxx": CXX_FLAGS, - "nvcc": NVCC_FLAGS_PUNICA, - }, - )) +print(NVCC_FLAGS) +print('=========================') +print(NVCC_FLAGS_PUNICA) vllm_extension = CUDAExtension( name="vllm._C", @@ -286,6 +280,19 @@ def get_torch_arch_list() -> Set[str]: ) ext_modules.append(vllm_extension) +if install_punica: + ext_modules.append( + CUDAExtension( + name="vllm._punica_C", + sources=["csrc/punica/punica_ops.cu", + "csrc/punica/punica_pybind.cpp"] + + glob("csrc/punica/bgmv/*.cu"), + extra_compile_args={ + "cxx": CXX_FLAGS, + "nvcc": NVCC_FLAGS_PUNICA, + }, + )) + def get_path(*filepath) -> str: return os.path.join(ROOT_DIR, *filepath) From 6e62551d6926f39d69eff23ab18bf880db6a2a21 Mon Sep 17 00:00:00 2001 From: miloice <17350011+kliuae@users.noreply.github.com> Date: Fri, 15 Dec 2023 09:56:03 +0000 Subject: [PATCH 78/94] Added notes in README --- README.md | 19 +++++++++++++++++++ setup.py | 4 ---- 2 files changed, 19 insertions(+), 4 deletions(-) diff --git a/README.md b/README.md index 477e61489d5a8..87ef08cb28872 100644 --- a/README.md +++ b/README.md @@ -1,3 +1,22 @@ +multi-lora rocm development + +Derived from [Yard1's multi-lora branch](https://github.com/Yard1/vllm/tree/multi_lora) + +[A note to whoever wants to run it on ROCm now] + +Starting from ROCm v5.7.0, some type conversion utils for bf16 are implemented in header files, but unfortunately some of them are not specified as inline or static functions. So building the project on ROCm directly would lead to failures in multiple definitions of functions. + +To circumvent this I have been manually adding the "inline" keyward to the related functions. In the container we're using it means adding the keyword "inline" to "/opt/rocm/include/hip/amd_detail/amd_hip_bf16.h#96" so that the line becomes + +```cpp +L96: #define __HOST_DEVICE__ __host__ __device__ inline +``` + +But of course this is far from pretty, if not considered nasty, as it is modifying include header in a potentially dangerous manner, and surely there should be better ways which I'm trying to find. + +This should be solved starting from [ROCm 6.0.0 though](https://github.com/ROCm/clr/commit/86bd518981b364c138f9901b28a529899d8654f3). + +

diff --git a/setup.py b/setup.py index 435be88c35169..2f4215d6cb74e 100644 --- a/setup.py +++ b/setup.py @@ -266,10 +266,6 @@ def get_torch_arch_list() -> Set[str]: elif _is_hip(): pass -print(NVCC_FLAGS) -print('=========================') -print(NVCC_FLAGS_PUNICA) - vllm_extension = CUDAExtension( name="vllm._C", sources=vllm_extension_sources, From 1097aa8b9469f96a7c2fafb04d0fd037c9e89747 Mon Sep 17 00:00:00 2001 From: miloice <17350011+kliuae@users.noreply.github.com> Date: Mon, 18 Dec 2023 07:02:21 +0000 Subject: [PATCH 79/94] Added loop unrolling --- csrc/punica/bgmv/bgmv_impl.cuh | 68 +++++++++++++++++++--------------- 1 file changed, 39 insertions(+), 29 deletions(-) diff --git a/csrc/punica/bgmv/bgmv_impl.cuh b/csrc/punica/bgmv/bgmv_impl.cuh index 7963342ede0fd..423e5c1cca627 100644 --- a/csrc/punica/bgmv/bgmv_impl.cuh +++ b/csrc/punica/bgmv/bgmv_impl.cuh @@ -18,17 +18,31 @@ namespace cg = cooperative_groups; #ifdef USE_ROCM - __host__ __device__ - inline void* memcpy_blocking(void *dst, const void *src, size_t len) { - // Does not handle the case of long datatypes - char *d = reinterpret_cast(dst); - const char *s = reinterpret_cast(src); - size_t i = 0; - for (i = 0; i < len; ++i) { - d[i] = s[i]; - } - return dst; +__host__ __device__ +inline void* memcpy_blocking(void *dst, const void *src, size_t len) { + // Does not handle the case of long datatypes + char *d = reinterpret_cast(dst); + const char *s = reinterpret_cast(src); + size_t i = 0; + for (i = 0; i < len; ++i) { + d[i] = s[i]; + } + return dst; +} + +template +__host__ __device__ +inline void* memcpy_blocking_unroll(void *dst, const void *src) { + // Does not handle the case of long datatypes + char *d = reinterpret_cast(dst); + const char *s = reinterpret_cast(src); + size_t i = 0; +#pragma unroll + for (i = 0; i < len; ++i) { + d[i] = s[i]; } + return dst; +} #endif // nthrs = (32, 4) @@ -72,14 +86,12 @@ bgmv_shrink_kernel(out_T *__restrict__ Y, const in_T *__restrict__ X, cuda::aligned_size_t(X_copy_size), pipe); pipe.producer_commit(); #else - memcpy_blocking(W_shared + (threadIdx.y * tx + threadIdx.x) * vec_size, - W + (idx * feat_out + j) * feat_in + - (threadIdx.y * tx + threadIdx.x) * vec_size, - W_copy_size); - memcpy_blocking(X_shared + (threadIdx.y * tx + threadIdx.x) * vec_size, - X + (batch_idx * feat_in) + - (threadIdx.y * tx + threadIdx.x) * vec_size, - X_copy_size); + memcpy_blocking_unroll(W_shared + (threadIdx.y * tx + threadIdx.x) * vec_size, + W + (idx * feat_out + j) * feat_in + + (threadIdx.y * tx + threadIdx.x) * vec_size); + memcpy_blocking_unroll(X_shared + (threadIdx.y * tx + threadIdx.x) * vec_size, + X + (batch_idx * feat_in) + + (threadIdx.y * tx + threadIdx.x) * vec_size); #endif size_t copy_idx, compute_idx; float y = 0.f; @@ -110,17 +122,15 @@ bgmv_shrink_kernel(out_T *__restrict__ Y, const in_T *__restrict__ X, pipe.producer_commit(); #else if (tile_idx * tile_size + threadIdx.y * tx * vec_size < feat_in) { - memcpy_blocking(W_shared + W_shared_offset[copy_idx] + - (threadIdx.y * tx + threadIdx.x) * vec_size, - W + (idx * feat_out + j) * feat_in + - tile_idx * tile_size + - (threadIdx.y * tx + threadIdx.x) * vec_size, - W_copy_size); - memcpy_blocking(X_shared + X_shared_offset[copy_idx] + - (threadIdx.y * tx + threadIdx.x) * vec_size, - X + (batch_idx * feat_in) + tile_idx * tile_size + - (threadIdx.y * tx + threadIdx.x) * vec_size, - X_copy_size); + memcpy_blocking_unroll(W_shared + W_shared_offset[copy_idx] + + (threadIdx.y * tx + threadIdx.x) * vec_size, + W + (idx * feat_out + j) * feat_in + + tile_idx * tile_size + + (threadIdx.y * tx + threadIdx.x) * vec_size); + memcpy_blocking_unroll(X_shared + X_shared_offset[copy_idx] + + (threadIdx.y * tx + threadIdx.x) * vec_size, + X + (batch_idx * feat_in) + tile_idx * tile_size + + (threadIdx.y * tx + threadIdx.x) * vec_size); } #endif From bc08640fcad23207cf12edf588faffa329dc39ba Mon Sep 17 00:00:00 2001 From: miloice <17350011+kliuae@users.noreply.github.com> Date: Wed, 20 Dec 2023 15:27:13 +0000 Subject: [PATCH 80/94] Merge 'upstream' --- .github/workflows/publish.yml | 2 +- Dockerfile | 8 +- Dockerfile.rocm | 4 +- README.md | 8 +- benchmarks/benchmark_latency.py | 8 +- benchmarks/benchmark_throughput.py | 11 +- .../kernels/benchmark_paged_attention.py | 8 +- csrc/attention/attention_kernels.cu | 31 +- csrc/cache.h | 2 + csrc/cuda_compat.h | 4 +- csrc/cuda_utils.h | 2 + csrc/dispatch_utils.h | 2 + csrc/ops.h | 18 +- csrc/pos_encoding_kernels.cu | 12 +- csrc/punica/bgmv/bgmv_all.cu | 2 - csrc/pybind.cpp | 4 +- csrc/quantization/gptq/compat.cuh | 64 ++ csrc/quantization/gptq/matrix_view.cuh | 151 +++ csrc/quantization/gptq/q_gemm.cu | 869 ++++++++++++++++++ csrc/quantization/gptq/qdq_4.cuh | 235 +++++ csrc/quantization/gptq/qdq_util.cuh | 60 ++ .../getting_started/amd-installation.rst | 17 +- docs/source/getting_started/installation.rst | 7 +- docs/source/index.rst | 4 +- docs/source/models/engine_args.rst | 6 +- docs/source/models/supported_models.rst | 22 +- docs/source/serving/deploying_with_docker.rst | 8 + docs/source/serving/run_on_sky.rst | 2 +- patch_xformers.rocm.sh | 33 + pyproject.toml | 2 +- requirements-build.txt | 2 +- requirements-dev.txt | 1 + requirements-rocm.txt | 4 +- requirements.txt | 7 +- ...ch => commonpy_xformers-0.0.23.rocm.patch} | 0 ...tch => flashpy_xformers-0.0.23.rocm.patch} | 96 +- setup.py | 1 + tests/conftest.py | 34 +- tests/kernels/test_attention.py | 7 +- tests/models/test_mistral.py | 37 + tests/models/test_models.py | 5 +- tests/prompts/example.txt | 8 + tests/prompts/summary.txt | 1 + vllm/__init__.py | 2 +- vllm/config.py | 56 +- vllm/engine/arg_utils.py | 50 +- vllm/engine/async_llm_engine.py | 14 +- vllm/engine/llm_engine.py | 19 +- vllm/engine/ray_utils.py | 9 +- vllm/entrypoints/api_server.py | 6 +- vllm/entrypoints/llm.py | 16 +- vllm/entrypoints/openai/api_server.py | 32 +- vllm/model_executor/input_metadata.py | 5 +- vllm/model_executor/layers/attention.py | 63 +- vllm/model_executor/layers/linear.py | 60 +- .../layers/quantization/__init__.py | 4 +- .../model_executor/layers/quantization/awq.py | 24 +- .../layers/quantization/gptq.py | 215 +++++ .../layers/quantization/squeezellm.py | 14 +- vllm/model_executor/layers/sampler.py | 316 +++---- vllm/model_executor/model_loader.py | 57 +- vllm/model_executor/models/__init__.py | 116 ++- vllm/model_executor/models/aquila.py | 21 +- vllm/model_executor/models/baichuan.py | 32 +- vllm/model_executor/models/bloom.py | 12 +- vllm/model_executor/models/chatglm.py | 18 +- vllm/model_executor/models/decilm.py | 123 +++ vllm/model_executor/models/falcon.py | 53 +- vllm/model_executor/models/gpt2.py | 14 +- vllm/model_executor/models/gpt_bigcode.py | 13 +- vllm/model_executor/models/gpt_j.py | 21 +- vllm/model_executor/models/gpt_neox.py | 13 +- vllm/model_executor/models/internlm.py | 21 +- vllm/model_executor/models/llama.py | 26 +- vllm/model_executor/models/mistral.py | 21 +- vllm/model_executor/models/mixtral.py | 412 +++++++++ vllm/model_executor/models/mpt.py | 37 +- vllm/model_executor/models/opt.py | 28 +- vllm/model_executor/models/phi_1_5.py | 15 +- vllm/model_executor/models/qwen.py | 22 +- vllm/model_executor/models/yi.py | 21 +- vllm/model_executor/sampling_metadata.py | 187 ++++ vllm/model_executor/weight_utils.py | 54 +- vllm/sampling_params.py | 49 +- vllm/transformers_utils/tokenizer.py | 17 - vllm/utils.py | 12 +- vllm/worker/model_runner.py | 277 +++++- vllm/worker/worker.py | 38 +- 88 files changed, 3569 insertions(+), 875 deletions(-) create mode 100644 csrc/quantization/gptq/compat.cuh create mode 100644 csrc/quantization/gptq/matrix_view.cuh create mode 100644 csrc/quantization/gptq/q_gemm.cu create mode 100644 csrc/quantization/gptq/qdq_4.cuh create mode 100644 csrc/quantization/gptq/qdq_util.cuh create mode 100644 patch_xformers.rocm.sh rename rocm_patch/{commonpy_xformers-0.0.22.post7.rocm.patch => commonpy_xformers-0.0.23.rocm.patch} (100%) rename rocm_patch/{flashpy_xformers-0.0.22.post7.rocm.patch => flashpy_xformers-0.0.23.rocm.patch} (61%) create mode 100644 tests/models/test_mistral.py create mode 100644 tests/prompts/example.txt create mode 100644 tests/prompts/summary.txt create mode 100644 vllm/model_executor/layers/quantization/gptq.py create mode 100644 vllm/model_executor/models/decilm.py create mode 100644 vllm/model_executor/models/mixtral.py diff --git a/.github/workflows/publish.yml b/.github/workflows/publish.yml index 58f9944b08aad..5211dc180798e 100644 --- a/.github/workflows/publish.yml +++ b/.github/workflows/publish.yml @@ -49,7 +49,7 @@ jobs: matrix: os: ['ubuntu-20.04'] python-version: ['3.8', '3.9', '3.10', '3.11'] - pytorch-version: ['2.1.0'] + pytorch-version: ['2.1.2'] # Must be the most recent version that meets requirements.txt. cuda-version: ['11.8', '12.1'] steps: diff --git a/Dockerfile b/Dockerfile index b1be5fb9e3b83..6ef03b843f457 100644 --- a/Dockerfile +++ b/Dockerfile @@ -30,11 +30,15 @@ COPY requirements.txt requirements.txt COPY pyproject.toml pyproject.toml COPY vllm/__init__.py vllm/__init__.py +ARG torch_cuda_arch_list='7.0 7.5 8.0 8.6 8.9 9.0+PTX' +ENV TORCH_CUDA_ARCH_LIST=${torch_cuda_arch_list} # max jobs used by Ninja to build extensions -ENV MAX_JOBS=$max_jobs +ARG max_jobs=2 +ENV MAX_JOBS=${max_jobs} # number of threads used by nvcc ARG nvcc_threads=8 ENV NVCC_THREADS=$nvcc_threads + RUN python3 setup.py build_ext --inplace # image to run unit testing suite @@ -71,7 +75,7 @@ ENTRYPOINT ["python3", "-m", "vllm.entrypoints.api_server"] FROM vllm-base AS vllm-openai # install additional dependencies for openai api server RUN --mount=type=cache,target=/root/.cache/pip \ - pip install accelerate fschat + pip install accelerate COPY --from=build /workspace/vllm/*.so /workspace/vllm/ COPY vllm vllm diff --git a/Dockerfile.rocm b/Dockerfile.rocm index 7af53e4472e33..36a7ee37fd228 100644 --- a/Dockerfile.rocm +++ b/Dockerfile.rocm @@ -47,12 +47,12 @@ RUN mkdir libs \ COPY ./ /app/vllm RUN python3 -m pip install --upgrade pip -RUN pip install xformers==0.0.22.post7 --no-deps +RUN pip install xformers==0.0.23 --no-deps RUN cd /app \ && cd vllm \ && pip install -U -r requirements-rocm.txt \ - && bash patch_xformers-0.0.22.post7.rocm.sh \ + && bash patch_xformers.rocm.sh \ && python3 setup.py install \ && cd .. diff --git a/README.md b/README.md index 87ef08cb28872..1863c31fcf2d2 100644 --- a/README.md +++ b/README.md @@ -54,6 +54,8 @@ vLLM is fast with: - State-of-the-art serving throughput - Efficient management of attention key and value memory with **PagedAttention** - Continuous batching of incoming requests +- Fast model execution with CUDA/HIP graph +- Quantization: [GPTQ](https://arxiv.org/abs/2210.17323), [AWQ](https://arxiv.org/abs/2306.00978), [SqueezeLLM](https://arxiv.org/abs/2306.07629) - Optimized CUDA kernels vLLM is flexible and easy to use with: @@ -63,7 +65,7 @@ vLLM is flexible and easy to use with: - Tensor parallelism support for distributed inference - Streaming outputs - OpenAI-compatible API server -- Support NVIDIA CUDA and AMD ROCm. +- Support NVIDIA GPUs and AMD GPUs vLLM seamlessly supports many Hugging Face models, including the following architectures: @@ -71,6 +73,7 @@ vLLM seamlessly supports many Hugging Face models, including the following archi - Baichuan & Baichuan2 (`baichuan-inc/Baichuan2-13B-Chat`, `baichuan-inc/Baichuan-7B`, etc.) - BLOOM (`bigscience/bloom`, `bigscience/bloomz`, etc.) - ChatGLM (`THUDM/chatglm2-6b`, `THUDM/chatglm3-6b`, etc.) +- DeciLM (`Deci/DeciLM-7B`, `Deci/DeciLM-7B-instruct`, etc.) - Falcon (`tiiuae/falcon-7b`, `tiiuae/falcon-40b`, `tiiuae/falcon-rw-7b`, etc.) - GPT-2 (`gpt2`, `gpt2-xl`, etc.) - GPT BigCode (`bigcode/starcoder`, `bigcode/gpt_bigcode-santacoder`, etc.) @@ -79,9 +82,10 @@ vLLM seamlessly supports many Hugging Face models, including the following archi - InternLM (`internlm/internlm-7b`, `internlm/internlm-chat-7b`, etc.) - LLaMA & LLaMA-2 (`meta-llama/Llama-2-70b-hf`, `lmsys/vicuna-13b-v1.3`, `young-geng/koala`, `openlm-research/open_llama_13b`, etc.) - Mistral (`mistralai/Mistral-7B-v0.1`, `mistralai/Mistral-7B-Instruct-v0.1`, etc.) +- Mixtral (`mistralai/Mixtral-8x7B-v0.1`, `mistralai/Mixtral-8x7B-Instruct-v0.1`, etc.) - MPT (`mosaicml/mpt-7b`, `mosaicml/mpt-30b`, etc.) - OPT (`facebook/opt-66b`, `facebook/opt-iml-max-30b`, etc.) -- Phi-1.5 (`microsoft/phi-1_5`, etc.) +- Phi (`microsoft/phi-1_5`, `microsoft/phi-2`, etc.) - Qwen (`Qwen/Qwen-7B`, `Qwen/Qwen-7B-Chat`, etc.) - Yi (`01-ai/Yi-6B`, `01-ai/Yi-34B`, etc.) diff --git a/benchmarks/benchmark_latency.py b/benchmarks/benchmark_latency.py index ae5583ffe0617..e33d5fb2dc247 100644 --- a/benchmarks/benchmark_latency.py +++ b/benchmarks/benchmark_latency.py @@ -23,6 +23,7 @@ def main(args: argparse.Namespace): tensor_parallel_size=args.tensor_parallel_size, trust_remote_code=args.trust_remote_code, dtype=args.dtype, + enforce_eager=args.enforce_eager, ) sampling_params = SamplingParams( @@ -72,7 +73,7 @@ def run_to_completion(profile_dir: Optional[str] = None): # Benchmark. latencies = [] for _ in tqdm(range(args.num_iters), desc="Profiling iterations"): - latencies.append(run_to_completion(profile=False)) + latencies.append(run_to_completion(profile_dir=None)) print(f'Avg latency: {np.mean(latencies)} seconds') @@ -84,7 +85,7 @@ def run_to_completion(profile_dir: Optional[str] = None): parser.add_argument('--tokenizer', type=str, default=None) parser.add_argument('--quantization', '-q', - choices=['awq', 'squeezellm', None], + choices=['awq', 'gptq', 'squeezellm', None], default=None) parser.add_argument('--tensor-parallel-size', '-tp', type=int, default=1) parser.add_argument('--input-len', type=int, default=32) @@ -111,6 +112,9 @@ def run_to_completion(profile_dir: Optional[str] = None): 'The "auto" option will use FP16 precision ' 'for FP32 and FP16 models, and BF16 precision ' 'for BF16 models.') + parser.add_argument('--enforce-eager', + action='store_true', + help='enforce eager mode and disable CUDA graph') parser.add_argument( '--profile', action='store_true', diff --git a/benchmarks/benchmark_throughput.py b/benchmarks/benchmark_throughput.py index 4540ed8061fa7..3aac479c01bd2 100644 --- a/benchmarks/benchmark_throughput.py +++ b/benchmarks/benchmark_throughput.py @@ -69,7 +69,8 @@ def run_vllm( use_beam_search: bool, trust_remote_code: bool, dtype: str, - max_model_len: Optional[int] = None, + max_model_len: Optional[int], + enforce_eager: bool, ) -> float: from vllm import LLM, SamplingParams llm = LLM( @@ -81,6 +82,7 @@ def run_vllm( trust_remote_code=trust_remote_code, dtype=dtype, max_model_len=max_model_len, + enforce_eager=enforce_eager, ) # Add the requests to the engine. @@ -204,7 +206,7 @@ def main(args: argparse.Namespace): args.quantization, args.tensor_parallel_size, args.seed, args.n, args.use_beam_search, args.trust_remote_code, args.dtype, - args.max_model_len) + args.max_model_len, args.enforce_eager) elif args.backend == "hf": assert args.tensor_parallel_size == 1 elapsed_time = run_hf(requests, args.model, tokenizer, args.n, @@ -244,7 +246,7 @@ def main(args: argparse.Namespace): parser.add_argument("--tokenizer", type=str, default=None) parser.add_argument('--quantization', '-q', - choices=['awq', 'squeezellm', None], + choices=['awq', 'gptq', 'squeezellm', None], default=None) parser.add_argument("--tensor-parallel-size", "-tp", type=int, default=1) parser.add_argument("--n", @@ -279,6 +281,9 @@ def main(args: argparse.Namespace): 'The "auto" option will use FP16 precision ' 'for FP32 and FP16 models, and BF16 precision ' 'for BF16 models.') + parser.add_argument("--enforce-eager", + action="store_true", + help="enforce eager execution") args = parser.parse_args() if args.tokenizer is None: args.tokenizer = args.model diff --git a/benchmarks/kernels/benchmark_paged_attention.py b/benchmarks/kernels/benchmark_paged_attention.py index 91fcf5340298a..935393e9942ce 100644 --- a/benchmarks/kernels/benchmark_paged_attention.py +++ b/benchmarks/kernels/benchmark_paged_attention.py @@ -37,10 +37,6 @@ def main( query.uniform_(-scale, scale) assert num_query_heads % num_kv_heads == 0 - num_queries_per_kv = num_query_heads // num_kv_heads - head_mapping = torch.repeat_interleave( - torch.arange(num_kv_heads, dtype=torch.int32, device="cuda"), - num_queries_per_kv) alibi_slopes = None if use_alibi: alibi_slopes = torch.randn(num_query_heads, @@ -103,7 +99,7 @@ def run_benchmark(num_iters: int, profile: bool = False) -> float: query, key_cache, value_cache, - head_mapping, + num_kv_heads, scale, block_tables, context_lens, @@ -120,7 +116,7 @@ def run_benchmark(num_iters: int, profile: bool = False) -> float: query, key_cache, value_cache, - head_mapping, + num_kv_heads, scale, block_tables, context_lens, diff --git a/csrc/attention/attention_kernels.cu b/csrc/attention/attention_kernels.cu index 3676af1a378d2..eff28d3dacd0e 100644 --- a/csrc/attention/attention_kernels.cu +++ b/csrc/attention/attention_kernels.cu @@ -89,7 +89,7 @@ __device__ void paged_attention_kernel( const scalar_t* __restrict__ q, // [num_seqs, num_heads, head_size] const scalar_t* __restrict__ k_cache, // [num_blocks, num_kv_heads, head_size/x, block_size, x] const scalar_t* __restrict__ v_cache, // [num_blocks, num_kv_heads, head_size, block_size] - const int* __restrict__ head_mapping, // [num_heads] + const int num_kv_heads, // [num_heads] const float scale, const int* __restrict__ block_tables, // [num_seqs, max_num_blocks_per_seq] const int* __restrict__ context_lens, // [num_seqs] @@ -132,7 +132,8 @@ __device__ void paged_attention_kernel( const int head_idx = blockIdx.x; const int num_heads = gridDim.x; - const int kv_head_idx = head_mapping[head_idx]; + const int num_queries_per_kv = num_heads / num_kv_heads; + const int kv_head_idx = head_idx / num_queries_per_kv; const float alibi_slope = alibi_slopes == nullptr ? 0.f : alibi_slopes[head_idx]; // A vector type to store a part of a key or a query. @@ -401,7 +402,7 @@ __global__ void paged_attention_v1_kernel( const scalar_t* __restrict__ q, // [num_seqs, num_heads, head_size] const scalar_t* __restrict__ k_cache, // [num_blocks, num_kv_heads, head_size/x, block_size, x] const scalar_t* __restrict__ v_cache, // [num_blocks, num_kv_heads, head_size, block_size] - const int* __restrict__ head_mapping, // [num_heads] + const int num_kv_heads, // [num_heads] const float scale, const int* __restrict__ block_tables, // [num_seqs, max_num_blocks_per_seq] const int* __restrict__ context_lens, // [num_seqs] @@ -412,7 +413,7 @@ __global__ void paged_attention_v1_kernel( const int kv_head_stride) { paged_attention_kernel( /* exp_sums */ nullptr, /* max_logits */ nullptr, - out, q, k_cache, v_cache, head_mapping, scale, block_tables, context_lens, + out, q, k_cache, v_cache, num_kv_heads, scale, block_tables, context_lens, max_num_blocks_per_seq, alibi_slopes, q_stride, kv_block_stride, kv_head_stride); } @@ -430,7 +431,7 @@ __global__ void paged_attention_v2_kernel( const scalar_t* __restrict__ q, // [num_seqs, num_heads, head_size] const scalar_t* __restrict__ k_cache, // [num_blocks, num_kv_heads, head_size/x, block_size, x] const scalar_t* __restrict__ v_cache, // [num_blocks, num_kv_heads, head_size, block_size] - const int* __restrict__ head_mapping, // [num_heads] + const int num_kv_heads, // [num_heads] const float scale, const int* __restrict__ block_tables, // [num_seqs, max_num_blocks_per_seq] const int* __restrict__ context_lens, // [num_seqs] @@ -440,7 +441,7 @@ __global__ void paged_attention_v2_kernel( const int kv_block_stride, const int kv_head_stride) { paged_attention_kernel( - exp_sums, max_logits, tmp_out, q, k_cache, v_cache, head_mapping, scale, + exp_sums, max_logits, tmp_out, q, k_cache, v_cache, num_kv_heads, scale, block_tables, context_lens, max_num_blocks_per_seq, alibi_slopes, q_stride, kv_block_stride, kv_head_stride); } @@ -556,7 +557,7 @@ __global__ void paged_attention_v2_reduce_kernel( query_ptr, \ key_cache_ptr, \ value_cache_ptr, \ - head_mapping_ptr, \ + num_kv_heads, \ scale, \ block_tables_ptr, \ context_lens_ptr, \ @@ -576,7 +577,7 @@ void paged_attention_v1_launcher( torch::Tensor& query, torch::Tensor& key_cache, torch::Tensor& value_cache, - torch::Tensor& head_mapping, + int num_kv_heads, float scale, torch::Tensor& block_tables, torch::Tensor& context_lens, @@ -602,7 +603,6 @@ void paged_attention_v1_launcher( T* query_ptr = reinterpret_cast(query.data_ptr()); T* key_cache_ptr = reinterpret_cast(key_cache.data_ptr()); T* value_cache_ptr = reinterpret_cast(value_cache.data_ptr()); - int* head_mapping_ptr = reinterpret_cast(head_mapping.data_ptr()); int* block_tables_ptr = block_tables.data_ptr(); int* context_lens_ptr = context_lens.data_ptr(); @@ -651,7 +651,7 @@ void paged_attention_v1_launcher( query, \ key_cache, \ value_cache, \ - head_mapping, \ + num_kv_heads, \ scale, \ block_tables, \ context_lens, \ @@ -681,7 +681,7 @@ void paged_attention_v1( torch::Tensor& query, // [num_seqs, num_heads, head_size] torch::Tensor& key_cache, // [num_blocks, num_heads, head_size/x, block_size, x] torch::Tensor& value_cache, // [num_blocks, num_heads, head_size, block_size] - torch::Tensor& head_mapping, // [num_heads] + int num_kv_heads, // [num_heads] float scale, torch::Tensor& block_tables, // [num_seqs, max_num_blocks_per_seq] torch::Tensor& context_lens, // [num_seqs] @@ -708,7 +708,7 @@ void paged_attention_v1( query_ptr, \ key_cache_ptr, \ value_cache_ptr, \ - head_mapping_ptr, \ + num_kv_heads, \ scale, \ block_tables_ptr, \ context_lens_ptr, \ @@ -739,7 +739,7 @@ void paged_attention_v2_launcher( torch::Tensor& query, torch::Tensor& key_cache, torch::Tensor& value_cache, - torch::Tensor& head_mapping, + int num_kv_heads, float scale, torch::Tensor& block_tables, torch::Tensor& context_lens, @@ -768,7 +768,6 @@ void paged_attention_v2_launcher( T* query_ptr = reinterpret_cast(query.data_ptr()); T* key_cache_ptr = reinterpret_cast(key_cache.data_ptr()); T* value_cache_ptr = reinterpret_cast(value_cache.data_ptr()); - int* head_mapping_ptr = reinterpret_cast(head_mapping.data_ptr()); int* block_tables_ptr = block_tables.data_ptr(); int* context_lens_ptr = context_lens.data_ptr(); @@ -823,7 +822,7 @@ void paged_attention_v2_launcher( query, \ key_cache, \ value_cache, \ - head_mapping, \ + num_kv_heads, \ scale, \ block_tables, \ context_lens, \ @@ -856,7 +855,7 @@ void paged_attention_v2( torch::Tensor& query, // [num_seqs, num_heads, head_size] torch::Tensor& key_cache, // [num_blocks, num_heads, head_size/x, block_size, x] torch::Tensor& value_cache, // [num_blocks, num_heads, head_size, block_size] - torch::Tensor& head_mapping, // [num_heads] + int num_kv_heads, // [num_heads] float scale, torch::Tensor& block_tables, // [num_seqs, max_num_blocks_per_seq] torch::Tensor& context_lens, // [num_seqs] diff --git a/csrc/cache.h b/csrc/cache.h index da49d9103214b..b26faad2ca814 100644 --- a/csrc/cache.h +++ b/csrc/cache.h @@ -1,3 +1,5 @@ +#pragma once + #include #include diff --git a/csrc/cuda_compat.h b/csrc/cuda_compat.h index 14302b1d4684f..1eef4c34607f0 100644 --- a/csrc/cuda_compat.h +++ b/csrc/cuda_compat.h @@ -1,5 +1,4 @@ -#ifndef CSRC__CUDA_COMPAT_H__ -#define CSRC__CUDA_COMPAT_H__ +#pragma once #ifndef USE_ROCM #define VLLM_LDG(arg) __ldg(arg) @@ -33,4 +32,3 @@ hipFuncSetAttribute(FUNC, hipFuncAttributeMaxDynamicSharedMemorySize, VAL) #endif -#endif // CSRC__CUDA_COMPAT_H__ diff --git a/csrc/cuda_utils.h b/csrc/cuda_utils.h index 85cb199b9aa0c..69c96cef0d17e 100644 --- a/csrc/cuda_utils.h +++ b/csrc/cuda_utils.h @@ -1,3 +1,5 @@ +#pragma once + #include int get_device_attribute( diff --git a/csrc/dispatch_utils.h b/csrc/dispatch_utils.h index 7c0c49d392a98..0ae9cd6415982 100644 --- a/csrc/dispatch_utils.h +++ b/csrc/dispatch_utils.h @@ -2,6 +2,8 @@ * Adapted from * https://github.com/pytorch/pytorch/blob/v2.0.1/aten/src/ATen/Dispatch.h */ +#pragma once + #include #define VLLM_DISPATCH_CASE_FLOATING_TYPES(...) \ diff --git a/csrc/ops.h b/csrc/ops.h index e12c34f0aafa5..9340a60da1417 100644 --- a/csrc/ops.h +++ b/csrc/ops.h @@ -1,3 +1,5 @@ +#pragma once + #include void paged_attention_v1( @@ -5,7 +7,7 @@ void paged_attention_v1( torch::Tensor& query, torch::Tensor& key_cache, torch::Tensor& value_cache, - torch::Tensor& head_mapping, + int num_kv_heads, float scale, torch::Tensor& block_tables, torch::Tensor& context_lens, @@ -21,7 +23,7 @@ void paged_attention_v2( torch::Tensor& query, torch::Tensor& key_cache, torch::Tensor& value_cache, - torch::Tensor& head_mapping, + int num_kv_heads, float scale, torch::Tensor& block_tables, torch::Tensor& context_lens, @@ -75,3 +77,15 @@ void squeezellm_gemm( torch::Tensor mat, torch::Tensor mul, torch::Tensor lookup_table); + +torch::Tensor gptq_gemm( + torch::Tensor a, + torch::Tensor b_q_weight, + torch::Tensor b_gptq_qzeros, + torch::Tensor b_gptq_scales, + torch::Tensor b_g_idx, + bool use_exllama); + +void gptq_shuffle( + torch::Tensor q_weight, + torch::Tensor q_perm); diff --git a/csrc/pos_encoding_kernels.cu b/csrc/pos_encoding_kernels.cu index e1dc711778ffb..486ebe1d464c8 100644 --- a/csrc/pos_encoding_kernels.cu +++ b/csrc/pos_encoding_kernels.cu @@ -43,8 +43,8 @@ __global__ void rotary_embedding_kernel( scalar_t* __restrict__ key, // [batch_size, seq_len, num_kv_heads, head_size] or [num_tokens, num_kv_heads, head_size] const scalar_t* __restrict__ cos_sin_cache, // [max_position, 2, rot_dim // 2] const int rot_dim, - const int query_stride, - const int key_stride, + const int64_t query_stride, + const int64_t key_stride, const int num_heads, const int num_kv_heads, const int head_size) { @@ -60,7 +60,7 @@ __global__ void rotary_embedding_kernel( const int nq = num_heads * embed_dim; for (int i = threadIdx.x; i < nq; i += blockDim.x) { const int head_idx = i / embed_dim; - const int token_head = token_idx * query_stride + head_idx * head_size; + const int64_t token_head = token_idx * query_stride + head_idx * head_size; const int rot_offset = i % embed_dim; apply_rotary_embedding(query + token_head, cos_ptr, sin_ptr, rot_offset, embed_dim); @@ -69,7 +69,7 @@ __global__ void rotary_embedding_kernel( const int nk = num_kv_heads * embed_dim; for (int i = threadIdx.x; i < nk; i += blockDim.x) { const int head_idx = i / embed_dim; - const int token_head = token_idx * key_stride + head_idx * head_size; + const int64_t token_head = token_idx * key_stride + head_idx * head_size; const int rot_offset = i % embed_dim; apply_rotary_embedding(key + token_head, cos_ptr, sin_ptr, rot_offset, embed_dim); @@ -89,8 +89,8 @@ void rotary_embedding( int rot_dim = cos_sin_cache.size(1); int num_heads = query.size(-1) / head_size; int num_kv_heads = key.size(-1) / head_size; - int query_stride = query.stride(-2); - int key_stride = key.stride(-2); + int64_t query_stride = query.stride(-2); + int64_t key_stride = key.stride(-2); dim3 grid(num_tokens); dim3 block(std::min(num_heads * rot_dim / 2, 512)); diff --git a/csrc/punica/bgmv/bgmv_all.cu b/csrc/punica/bgmv/bgmv_all.cu index e567cb5c8da7e..bc86416701f13 100644 --- a/csrc/punica/bgmv/bgmv_all.cu +++ b/csrc/punica/bgmv/bgmv_all.cu @@ -1,8 +1,6 @@ #include "bgmv_config.h" #include "bgmv_impl.cuh" -#include "../type_convert.h" - FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, nv_half, nv_half, nv_half) FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, nv_half, nv_half, nv_bfloat16) FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, nv_bfloat16, nv_bfloat16, nv_bfloat16) diff --git a/csrc/pybind.cpp b/csrc/pybind.cpp index b0120e9689773..95f557686f337 100644 --- a/csrc/pybind.cpp +++ b/csrc/pybind.cpp @@ -52,8 +52,8 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { // Quantization ops ops.def("awq_gemm", &awq_gemm, "Quantized GEMM for AWQ"); #endif - - + ops.def("gptq_gemm", &gptq_gemm, "Quantized GEMM for GPTQ"); + ops.def("gptq_shuffle", &gptq_shuffle, "Post processing for GPTQ"); ops.def("squeezellm_gemm", &squeezellm_gemm, "Quantized GEMM for SqueezeLLM"); // Cache ops diff --git a/csrc/quantization/gptq/compat.cuh b/csrc/quantization/gptq/compat.cuh new file mode 100644 index 0000000000000..4da0bc6e2df38 --- /dev/null +++ b/csrc/quantization/gptq/compat.cuh @@ -0,0 +1,64 @@ +/* +Copied from https://github.com/turboderp/exllamav2 +*/ + +#ifndef _compat_cuh +#define _compat_cuh + +namespace vllm { +namespace gptq { +// atomicAdd for half types, to support CC < 7.x + +__device__ __forceinline__ void atomicAdd_half(half* address, half val) +{ + unsigned int * address_as_ui = (unsigned int *) ((char *)address - ((size_t)address & 2)); + unsigned int old = *address_as_ui; + unsigned int assumed; + + do + { + assumed = old; + __half_raw hsum; + hsum.x = (size_t)address & 2 ? (old >> 16) : (old & 0xffff); + half tmpres = __hadd(hsum, val); + hsum = __half_raw(tmpres); + old = (size_t)address & 2 ? (old & 0xffff) | (hsum.x << 16) : (old & 0xffff0000) | hsum.x; + old = atomicCAS(address_as_ui, assumed, old); + } + while (assumed != old); +} + +// atomicAdd for half2 types + +__device__ __forceinline__ void atomicAdd_half2(half2* address, half2 val) +{ + unsigned int* address_as_ui = (unsigned int*)address; + unsigned int old = *address_as_ui; + unsigned int assumed; + do + { + assumed = old; + half2 old_val = *((half2*)&old); + half2 new_val = __hadd2(old_val, val); + old = atomicCAS(address_as_ui, assumed, *((unsigned int*)&new_val)); + } + while (assumed != old); +} + +// + +#if defined(__CUDA_ARCH__) || defined(USE_ROCM) +#if __CUDA_ARCH__ < 700 || defined(USE_ROCM) + +__device__ __forceinline__ void atomicAdd(half* address, half val) { atomicAdd_half(address, val); } + +#if __CUDA_ARCH__ < 600 || defined(USE_ROCM) +__device__ __forceinline__ void atomicAdd(half2* address, half2 val) { atomicAdd_half2(address, val); } +#endif + +#endif +#endif + +} // namespace gptq +} // namespace vllm +#endif diff --git a/csrc/quantization/gptq/matrix_view.cuh b/csrc/quantization/gptq/matrix_view.cuh new file mode 100644 index 0000000000000..1fdf019b29028 --- /dev/null +++ b/csrc/quantization/gptq/matrix_view.cuh @@ -0,0 +1,151 @@ +/* +Adapted from https://github.com/turboderp/exllamav2 and https://github.com/turboderp/exllama +*/ + +#ifndef _matrix_view_cuh +#define _matrix_view_cuh + +#include +#include + +#include "qdq_util.cuh" + +namespace vllm { +namespace gptq { + +class MatrixView_half +{ +public: + const half* data; + const int height; + const int width; + + __device__ __forceinline__ MatrixView_half(const half* data, const int height, const int width) + : data(data), height(height), width(width) + { } + + __device__ __forceinline__ half item(int row, int column) const { return data[row * width + column]; } + __device__ __forceinline__ half2 item_half2(int row, int column) const { return ((half2*)data)[(row * width + column) / 2]; } + __device__ __forceinline__ half2 item_half2half2(int row, int column) const { return __half2half2(data[row * width + column]); } + __device__ __forceinline__ const half* item_ptr(int row, int column) const { return &data[row * width + column]; } + + __device__ __forceinline__ void item4(half (&items)[4], int row, int column) const + { + half2* ptr = (half2*) item_ptr(row, column); + half2 i01 = ptr[0]; + half2 i23 = ptr[1]; + items[0] = __low2half(i01); + items[1] = __high2half(i01); + items[2] = __low2half(i23); + items[3] = __high2half(i23); + } + __device__ __forceinline__ void item4_f(float (&items)[4], int row, int column) const + { + half2* ptr = (half2*)item_ptr(row, column); + half2 i01 = ptr[0]; + half2 i23 = ptr[1]; + items[0] = __half2float(__low2half(i01)); + items[1] = __half2float(__high2half(i01)); + items[2] = __half2float(__low2half(i23)); + items[3] = __half2float(__high2half(i23)); + } + + __device__ __forceinline__ void item4_h2(half2 (&items)[4], int row, int column) const + { + half2* ptr = (half2*)item_ptr(row, column); + half2 i01 = ptr[0]; + half2 i23 = ptr[1]; + items[0] = __half2half2(__low2half(i01)); + items[1] = __half2half2(__high2half(i01)); + items[2] = __half2half2(__low2half(i23)); + items[3] = __half2half2(__high2half(i23)); + } +}; + +class MatrixView_half_rw +{ +public: + half* data; + const int height; + const int width; + + __device__ __forceinline__ MatrixView_half_rw(half* data, const int height, const int width) + : data(data), height(height), width(width) + { } + + __device__ __forceinline__ half item(int row, int column) const { return data[row * width + column]; } + __device__ __forceinline__ half2 item_half2(int row, int column) const { return ((half2*)data)[(row * width + column) / 2]; } + __device__ __forceinline__ half* item_ptr(int row, int column) { return &data[row * width + column]; } + __device__ __forceinline__ void set(int row, int column, half value) { data[row * width + column] = value; } + __device__ __forceinline__ void set_half2(int row, int column, half2 value) { ((half2*)data)[(row * width + column) / 2] = value; } + + __device__ __forceinline__ void set4(int row, int column, half v0, half v1, half v2, half v3) + { + half2 v01 = __halves2half2(v0, v1); + half2 v23 = __halves2half2(v2, v3); + half2* ptr = (half2*) item_ptr(row, column); + ptr[0] = v01; + ptr[1] = v23; + } +}; + +class MatrixView_q4_row +{ +public: + const uint32_t* data; + const int height; + const int width; + + __device__ __forceinline__ MatrixView_q4_row(const uint32_t* data, const int height, const int width) + : data(data), height(height), width(width) + { } + + __device__ __forceinline__ int item(int row, int column) const + { + int shift = (column & 0x07) * 4; + return (data[row * width / 8 + column / 8] >> shift) & 0x0f; + } + + __device__ __forceinline__ void item2(int (&items)[2], int row, int column) const + { + int shift = (column & 0x07) * 4; + uint32_t d = data[row * width / 8 + column / 8] >> shift; + items[0] = d & 0x0f; + items[1] = (d >> 4) & 0x0f; + } + + __device__ __forceinline__ void item4(int (&items)[4], int row, int column) const + { + int shift = (column & 0x07) * 4; + uint32_t d = data[row * width / 8 + column / 8] >> shift; + items[0] = d & 0x0f; + items[1] = (d >> 4) & 0x0f; + items[2] = (d >> 8) & 0x0f; + items[3] = (d >> 12) & 0x0f; + } +}; + +class MatrixView_q4_column +{ +public: + const uint32_t* data; + const int height; + const int width; + + __device__ __forceinline__ MatrixView_q4_column(const uint32_t* data, const int height, const int width) + : data(data), height(height), width(width) + { } + + __device__ __forceinline__ int item(int row, int column) const + { + int shift = (row & 0x07) * 4; + return (data[row / 8 * width + column] >> shift) & 0x0f; + } + + __device__ __forceinline__ uint32_t item_uint32_t(int row, int column) { return data[row / 8 * width + column]; } + __device__ __forceinline__ const uint32_t* item_uint32_ptr(int row, int column) { return &data[row / 8 * width + column]; } +}; + +} // namespace gptq +} // namespace vllm +#endif diff --git a/csrc/quantization/gptq/q_gemm.cu b/csrc/quantization/gptq/q_gemm.cu new file mode 100644 index 0000000000000..eb0d75f1293c4 --- /dev/null +++ b/csrc/quantization/gptq/q_gemm.cu @@ -0,0 +1,869 @@ +/* +Adapted from https://github.com/turboderp/exllamav2 and https://github.com/qwopqwop200/GPTQ-for-LLaMa +*/ + +#include +#include + +#include +#include +#include +#include +#include + +#include "compat.cuh" +#include "matrix_view.cuh" +#include "qdq_4.cuh" + +namespace vllm { +namespace gptq { + +#define BLOCK_KN_SIZE 128 +#define BLOCK_M_SIZE_MAX 8 +#define MAX_GROUPS_IN_BLOCK (BLOCK_KN_SIZE / 32) +#define MAX_Q_GEMM_ROWS 50 +#define MAX_ALT_GEMM_ROWS 8 +#define THREADS_X 32 +#define THREADS_Y 32 +#define DIVIDE(x, size) (((x) + (size) - 1) / (size)) + +#if defined(USE_ROCM) +#include +__host__ __forceinline__ hipblasStatus_t __compat_hipblasHgemm(hipblasHandle_t handle, + hipblasOperation_t transA, + hipblasOperation_t transB, + int m, + int n, + int k, + const half* alpha, + const half* AP, + int lda, + const half* BP, + int ldb, + const half* beta, + half* CP, + int ldc) { + return hipblasHgemm(handle, transA, transB, m, n, k, + reinterpret_cast(alpha), + reinterpret_cast(AP), lda, + reinterpret_cast(BP), ldb, + reinterpret_cast(beta), + reinterpret_cast(CP), ldc); +} +#define hipblasHgemm __compat_hipblasHgemm + +// Previous version of PyTorch were converting to rocBLAS instead of hipBLAS. +#define rocblas_operation_none HIPBLAS_OP_N +#define rocblas_hgemm __compat_hipblasHgemm +#endif + +__forceinline__ __device__ half2 dot22_8(half2(&dq)[4], const half* a_ptr, const half2 g_result) +{ + half2 result = {}; + const half2* a2_ptr = (const half2*)a_ptr; + #pragma unroll + for (int i = 0; i < 4; i++) result = __hfma2(dq[i], *a2_ptr++, result); + return __hadd2(result, g_result); +} + +__forceinline__ __device__ float dot22_8_f(half2(&dq)[4], const half* a_ptr) +{ + half2 result = {}; + const half2* a2_ptr = (const half2*)a_ptr; + #pragma unroll + for (int i = 0; i < 4; i++) result = __hfma2(dq[i], *a2_ptr++, result); + return __half2float(__low2half(result)) + __half2float(__high2half(result)); +} + +typedef void (*fp_gemm_half_q_half_gptq_kernel) +( + const half*, + const uint32_t*, + const uint32_t*, + const half*, + half*, + const int, + const int, + const int, + const int, + const int* +); + +template +__global__ void gemm_half_q_half_gptq_kernel +( + const half* __restrict__ a, + const uint32_t* __restrict__ b_q_weight, + const uint32_t* __restrict__ b_gptq_qzeros, + const half* __restrict__ b_gptq_scales, + half* __restrict__ c, + const int size_m, + const int size_n, + const int size_k, + const int groups, + const int* __restrict__ b_q_perm +) +{ + MatrixView_half a_(a, size_m, size_k); + MatrixView_half_rw c_(c, size_m, size_n); + MatrixView_q4_row b_gptq_qzeros_(b_gptq_qzeros, groups, size_n); + MatrixView_half b_gptq_scales_(b_gptq_scales, groups, size_n); + + int t = threadIdx.x; + + // Block + int offset_n = blockIdx.x * BLOCK_KN_SIZE * 4; + int offset_m = blockIdx.y * m_count; + int offset_k = blockIdx.z * BLOCK_KN_SIZE; + + int end_n = min(offset_n + BLOCK_KN_SIZE * 4, size_n); + int end_m = min(offset_m + m_count, size_m); + int end_k = min(offset_k + BLOCK_KN_SIZE, size_k); + + int n = offset_n + t * 4; + + // Preload block_a + __shared__ half block_a[m_count][BLOCK_KN_SIZE]; + + if (offset_k + t < end_k) + { + for (int m = 0; m < m_count; ++m) + { + const half* a_ptr = a_.item_ptr(offset_m + m, 0); + half* block_a_ptr = block_a[m]; + + half a0; + if (b_q_perm) a0 = a_ptr[b_q_perm[offset_k + t]]; + else a0 = a_ptr[offset_k + t]; + block_a_ptr[t] = a0; + } + } + + // Zero output + if (n >= size_n) return; + + if (blockIdx.z == 0) + { + for (int m = 0; m < m_count; m++) + *((uint64_t*)c_.item_ptr(offset_m + m, n)) = 0; + } + + __syncthreads(); + + // Find initial group + int groupsize = size_k / groups; + int group = offset_k / groupsize; + int nextgroup = offset_k + groupsize; + + // a, b offset + int qk = offset_k / (32 / 4); + + const uint32_t* b_ptr = b_q_weight + qk * size_n + n; + const half* a_ptr = &block_a[0][0]; + int a_stride = BLOCK_KN_SIZE; + + // Initial group + int zeros[4]; + float scales[4]; + half2 z1z16[4][2]; + half2 y1y16[4][2]; + b_gptq_qzeros_.item4(zeros, group, n); + b_gptq_scales_.item4_f(scales, group, n); + dequant_4bit_8_prep_zero(zeros[0] + 1, z1z16[0], y1y16[0]); + dequant_4bit_8_prep_zero(zeros[1] + 1, z1z16[1], y1y16[1]); + dequant_4bit_8_prep_zero(zeros[2] + 1, z1z16[2], y1y16[2]); + dequant_4bit_8_prep_zero(zeros[3] + 1, z1z16[3], y1y16[3]); + + // Column result + float block_c[m_count][4] = {}; + + // Dequantize and multiply + int k = offset_k; + while (k < end_k) + { + if (k == nextgroup) + { + group++; + nextgroup += groupsize; + b_gptq_qzeros_.item4(zeros, group, n); + b_gptq_scales_.item4_f(scales, group, n); + dequant_4bit_8_prep_zero(zeros[0] + 1, z1z16[0], y1y16[0]); + dequant_4bit_8_prep_zero(zeros[1] + 1, z1z16[1], y1y16[1]); + dequant_4bit_8_prep_zero(zeros[2] + 1, z1z16[2], y1y16[2]); + dequant_4bit_8_prep_zero(zeros[3] + 1, z1z16[3], y1y16[3]); + } + + #pragma unroll + for (int j = 0; j < 4; j++) + { + const int4* b_ptr4 = (int4*) b_ptr; + int4 load_int4 = *b_ptr4; + + half2 dq[4][4]; + dequant_4bit_8_gptq(load_int4.x, dq[0], z1z16[0], y1y16[0], size_n, false); + dequant_4bit_8_gptq(load_int4.y, dq[1], z1z16[1], y1y16[1], size_n, false); + dequant_4bit_8_gptq(load_int4.z, dq[2], z1z16[2], y1y16[2], size_n, false); + dequant_4bit_8_gptq(load_int4.w, dq[3], z1z16[3], y1y16[3], size_n, false); + + #pragma unroll + for (int m = 0; m < m_count; m++) + { + block_c[m][0] = fma(dot22_8_f(dq[0], a_ptr + m * a_stride), scales[0], block_c[m][0]); + block_c[m][1] = fma(dot22_8_f(dq[1], a_ptr + m * a_stride), scales[1], block_c[m][1]); + block_c[m][2] = fma(dot22_8_f(dq[2], a_ptr + m * a_stride), scales[2], block_c[m][2]); + block_c[m][3] = fma(dot22_8_f(dq[3], a_ptr + m * a_stride), scales[3], block_c[m][3]); + } + + b_ptr += size_n; + a_ptr += 8; + } + + k += 32; + } + + for (int m = 0; m < m_count; m++) + { + half2 *out = (half2*) c_.item_ptr(offset_m + m, n); + half2 result01 = __halves2half2(__float2half_rn(block_c[m][0]), __float2half_rn(block_c[m][1])); + half2 result23 = __halves2half2(__float2half_rn(block_c[m][2]), __float2half_rn(block_c[m][3])); + atomicAdd(out , result01); + atomicAdd(out + 1, result23); + } +} + + +fp_gemm_half_q_half_gptq_kernel pick_gemm_half_q_half_gptq_kernel(bool first_block, const int m_count) +{ + #if BLOCK_M_SIZE_MAX >= 1 + if (m_count == 1) return gemm_half_q_half_gptq_kernel; + #endif + #if BLOCK_M_SIZE_MAX >= 2 + if (m_count == 2) return gemm_half_q_half_gptq_kernel; + #endif + #if BLOCK_M_SIZE_MAX >= 3 + if (m_count == 3) return gemm_half_q_half_gptq_kernel; + #endif + #if BLOCK_M_SIZE_MAX >= 4 + if (m_count == 4) return gemm_half_q_half_gptq_kernel; + #endif + #if BLOCK_M_SIZE_MAX >= 5 + if (m_count == 5) return gemm_half_q_half_gptq_kernel; + #endif + #if BLOCK_M_SIZE_MAX >= 6 + if (m_count == 6) return gemm_half_q_half_gptq_kernel; + #endif + #if BLOCK_M_SIZE_MAX >= 7 + if (m_count == 7) return gemm_half_q_half_gptq_kernel; + #endif + #if BLOCK_M_SIZE_MAX >= 8 + if (m_count == 8) return gemm_half_q_half_gptq_kernel; + #endif + return NULL; +} + + +void gemm_half_q_half_cuda_part +( + const half* a, + const uint32_t* b_q_weight, + const uint32_t* b_gptq_qzeros, + const half* b_gptq_scales, + const int* b_q_perm, + half* c, + int size_m, + int size_n, + int size_k, + int m_count, + int groups +) +{ + dim3 blockDim, gridDim; + blockDim.x = BLOCK_KN_SIZE; + blockDim.y = 1; + blockDim.z = 1; + gridDim.x = DIVIDE(size_n, BLOCK_KN_SIZE * 4); + gridDim.y = DIVIDE(size_m, m_count); + gridDim.z = DIVIDE(size_k, BLOCK_KN_SIZE); + + fp_gemm_half_q_half_gptq_kernel kernel = pick_gemm_half_q_half_gptq_kernel(true, m_count); + + kernel<<>> + ( + a, + b_q_weight, + b_gptq_qzeros, + b_gptq_scales, + c, + size_m, + size_n, + size_k, + groups, + b_q_perm + ); +} + + +__global__ void reconstruct_exllama_kernel +( + const uint32_t* __restrict__ b_q_weight, + const int* __restrict__ b_q_perm, + const uint32_t* __restrict__ b_gptq_qzeros, + const half* __restrict__ b_gptq_scales, + const int size_k, + const int size_n, + const int groups, + half* __restrict__ b +) +{ + MatrixView_half_rw b_(b, size_k, size_n); + MatrixView_q4_row b_gptq_qzeros_(b_gptq_qzeros, groups, size_n); + MatrixView_half b_gptq_scales_(b_gptq_scales, groups, size_n); + + int offset_k = BLOCK_KN_SIZE * blockIdx.y; + int offset_n = BLOCK_KN_SIZE * blockIdx.x * 4; + + int end_k = min(offset_k + BLOCK_KN_SIZE, size_k); + + // Preload remapping table + __shared__ int perm[BLOCK_KN_SIZE]; + int t = threadIdx.x; + + if (b_q_perm) + { + if (offset_k + t < size_k) + perm[t] = b_q_perm[offset_k + t]; + } + + // Column + int n = offset_n + t * 4; + if (n >= size_n) return; + + // Find initial group + int groupsize = size_k / groups; + int group = offset_k / groupsize; + int nextgroup = offset_k + groupsize; + + // b offset + int qk = offset_k / (32 / 4); + + const uint32_t* b_ptr = b_q_weight + qk * size_n + n; + + // Initial zeros/scale + int zeros[4]; + half2 scales[4]; + half2 z1z16[4][2]; + half2 y1y16[4][2]; + b_gptq_qzeros_.item4(zeros, group, n); + b_gptq_scales_.item4_h2(scales, group, n); + dequant_4bit_8_prep_zero(zeros[0] + 1, z1z16[0], y1y16[0]); + dequant_4bit_8_prep_zero(zeros[1] + 1, z1z16[1], y1y16[1]); + dequant_4bit_8_prep_zero(zeros[2] + 1, z1z16[2], y1y16[2]); + dequant_4bit_8_prep_zero(zeros[3] + 1, z1z16[3], y1y16[3]); + + __syncthreads(); + + int k = offset_k; + int lk = 0; + + while (k < end_k) + { + if (k == nextgroup) + { + group++; + nextgroup += groupsize; + b_gptq_qzeros_.item4(zeros, group, n); + b_gptq_scales_.item4_h2(scales, group, n); + dequant_4bit_8_prep_zero(zeros[0] + 1, z1z16[0], y1y16[0]); + dequant_4bit_8_prep_zero(zeros[1] + 1, z1z16[1], y1y16[1]); + dequant_4bit_8_prep_zero(zeros[2] + 1, z1z16[2], y1y16[2]); + dequant_4bit_8_prep_zero(zeros[3] + 1, z1z16[3], y1y16[3]); + } + + for (int p = 0; p < 4; p++) + { + half2 dq[4][4]; + const int4* b_ptr4 = (int4*) b_ptr; + int4 load_int4 = *b_ptr4; + + dequant_4bit_8_gptq(load_int4.x, dq[0], z1z16[0], y1y16[0], size_n, false); + dequant_4bit_8_gptq(load_int4.y, dq[1], z1z16[1], y1y16[1], size_n, false); + dequant_4bit_8_gptq(load_int4.z, dq[2], z1z16[2], y1y16[2], size_n, false); + dequant_4bit_8_gptq(load_int4.w, dq[3], z1z16[3], y1y16[3], size_n, false); + + b_ptr += size_n; + //half* dqh = (half*)dq; + if (b_q_perm) + { + for (int j = 0; j < 4; j++) + { + for (int v = 0; v < 4; v++) dq[v][j] = __hmul2(scales[v], dq[v][j]); + b_.set4(perm[lk++], n, __low2half(dq[0][j]), __low2half(dq[1][j]), __low2half(dq[2][j]), __low2half(dq[3][j])); + b_.set4(perm[lk++], n, __high2half(dq[0][j]), __high2half(dq[1][j]), __high2half(dq[2][j]), __high2half(dq[3][j])); + } + } + else + { + for (int j = 0; j < 4; j++) + { + for (int v = 0; v < 4; v++) dq[v][j] = __hmul2(scales[v], dq[v][j]); + b_.set4(offset_k + lk++, n, __low2half(dq[0][j]), __low2half(dq[1][j]), __low2half(dq[2][j]), __low2half(dq[3][j])); + b_.set4(offset_k + lk++, n, __high2half(dq[0][j]), __high2half(dq[1][j]), __high2half(dq[2][j]), __high2half(dq[3][j])); + } + } + } + k += 32; + } +} + + +void reconstruct_exllama +( + const uint32_t* b_q_weight, + const uint32_t* b_gptq_qzeros, + const half* b_gptq_scales, + const int* b_q_perm, + half* out, + int height, + int width, + int groups +) +{ + dim3 blockDim, gridDim; + blockDim.x = BLOCK_KN_SIZE; + blockDim.y = 1; + gridDim.y = DIVIDE(height, BLOCK_KN_SIZE); + gridDim.x = DIVIDE(width, BLOCK_KN_SIZE); + + reconstruct_exllama_kernel<<>> + ( + b_q_weight, + b_q_perm, + b_gptq_qzeros, + b_gptq_scales, + height, + width, + groups, + out + ); +} + + +__global__ void gemm_half_q_half_alt_kernel( + const half2* __restrict__ vec, + const uint32_t* __restrict__ mat, + half* __restrict__ mul, + const half* __restrict__ scales, + const uint32_t* __restrict__ zeros, + const int* __restrict__ g_idx, + int batch, + int height, + int width +) +{ + int zero_width = width / 8; + int vec_height = height * 4; + const int blockwidth2 = BLOCK_KN_SIZE / 2; + int b = blockIdx.y * BLOCK_M_SIZE_MAX; + int b_end = min(BLOCK_M_SIZE_MAX, batch - b); + int h = BLOCK_KN_SIZE * blockIdx.z / 8; + int h_end = min(BLOCK_KN_SIZE / 8, height - h) * 4; + int w = BLOCK_KN_SIZE * blockIdx.x + threadIdx.x; + + __shared__ half2 blockvec[BLOCK_M_SIZE_MAX][blockwidth2]; + if (threadIdx.x < h_end) { + for (int m = 0; m < b_end; ++m) { + blockvec[m][threadIdx.x] = + vec[(m + b) * vec_height + blockIdx.z * BLOCK_KN_SIZE / 2 + + threadIdx.x]; + } + } + + __shared__ half2 deq2[256][8]; + int val = threadIdx.x / 8; + int off = threadIdx.x % 8; + for (; val < 256; val += BLOCK_KN_SIZE / 8) { + deq2[val][off] = __halves2half2( + __int2half_rn(val & 0xF), __int2half_rn(val >> 4) + ); + } + + if (blockIdx.z == 0) + { + for (int m = 0; m < b_end; m++) + mul[(b + m) * width + w] = __int2half_rn(0); + } + __syncthreads(); + + int i = width * h + w; + int g_h = h * 8; + int k = 0; + int z_w = w / 8; + int z_mod = (w % 8) * 4; + half2 res2; + half res[BLOCK_M_SIZE_MAX] = {}; + + unsigned int tmp; + while (k < h_end) { + tmp = mat[i]; + half2 scales_tmp[4]; + half2 zeros_tmp[4]; + for (int tmp_k = 0; tmp_k < 4; tmp_k++) { + int g = g_idx[g_h + (k + tmp_k) * 2]; + int g2 = g_idx[g_h + (k + tmp_k) * 2 + 1]; + half scale_f = scales[g * width + w]; + half scale_f2 = scales[g2 * width + w]; + half2 scale = __halves2half2(scale_f, scale_f2); + half2 zero = __halves2half2( + __hmul(scale_f, __int2half_rn(-((zeros[g * zero_width + z_w] >> z_mod) & 0xF) - 1)), + __hmul(scale_f2, __int2half_rn(-((zeros[g2 * zero_width + z_w] >> z_mod) & 0xF) - 1)) + ); + scales_tmp[tmp_k] = scale; + zeros_tmp[tmp_k] = zero; + } + for (int m = 0; m < b_end; m++) { +#ifndef USE_ROCM + res2 = {}; +#else + res2.x = __half_as_ushort(__float2half(0)); + res2.y = __half_as_ushort(__float2half(0)); +#endif + res2 = __hfma2(__hfma2(deq2[(tmp >> 0) & 0xff][off], scales_tmp[0], zeros_tmp[0]), blockvec[m][k + 0], res2); + res2 = __hfma2(__hfma2(deq2[(tmp >> 8) & 0xff][off], scales_tmp[1], zeros_tmp[1]), blockvec[m][k + 1], res2); + res2 = __hfma2(__hfma2(deq2[(tmp >> 16) & 0xff][off], scales_tmp[2], zeros_tmp[2]), blockvec[m][k + 2], res2); + res2 = __hfma2(__hfma2(deq2[(tmp >> 24) & 0xff][off], scales_tmp[3], zeros_tmp[3]), blockvec[m][k + 3], res2); +#ifndef USE_ROCM + res[m] = __hadd(res[m], __hadd(res2.x, res2.y)); +#else + res[m] = __hadd(res[m], __hadd(__ushort_as_half(res2.x), __ushort_as_half(res2.y))); +#endif + } + i += width; + k += 4; + } + for (int m = 0; m < b_end; m++) { + atomicAdd(&mul[(b + m) * width + w], res[m]); + } +} + + +void gemm_half_q_half_alt +( + const half* a, + const uint32_t* b_q_weight, + const uint32_t* b_gptq_qzeros, + const half* b_gptq_scales, + const int* b_g_idx, + half* c, + int size_m, + int size_n, + int size_k +) +{ + dim3 blockDim, gridDim; + blockDim.x = BLOCK_KN_SIZE; + blockDim.y = 1; + blockDim.z = 1; + gridDim.x = DIVIDE(size_n, BLOCK_KN_SIZE); + gridDim.y = DIVIDE(size_m, BLOCK_M_SIZE_MAX); + gridDim.z = DIVIDE(size_k, BLOCK_KN_SIZE); + + gemm_half_q_half_alt_kernel<<>> + ( + (const half2*) a, + b_q_weight, + c, + b_gptq_scales, + b_gptq_qzeros, + b_g_idx, + size_m, + size_k / 8, + size_n + ); +} + + +__global__ void reconstruct_gptq_kernel +( + const uint32_t* __restrict__ w, + const half* __restrict__ w_scales, + const uint32_t* __restrict__ w_zeros, + const int* __restrict__ g_idx, + const int height, + const int width, + const int group, + half* __restrict__ out +) +{ + // Start of block + + int column = BLOCK_KN_SIZE * blockIdx.x + threadIdx.x; + int row = blockIdx.y * 8; + if (column >= width) return; + + // Views + + MatrixView_q4_column w_(w, height, width); + MatrixView_half_rw out_(out, height, width); + MatrixView_half w_scales_(w_scales, group, width); + MatrixView_q4_row w_zeros_(w_zeros, group, width); + + uint32_t w_read = w_.item_uint32_t(row, column); + half* out_ptr = out_.item_ptr(row, column); + + #pragma unroll + for (int s = 0; s < 32; s += 4) + { + int group = g_idx[row + s / 4]; + half w_scale = w_scales_.item(group, column); + uint32_t w_zero = w_zeros_.item(group, column) + 1; + half w_item = __hmul(__int2half_rn((int)((w_read >> s) & 0x0f) - w_zero), w_scale); + *out_ptr = w_item; out_ptr += out_.width; + } +} + + +void reconstruct_gptq +( + const uint32_t* b_q_weight, + const uint32_t* b_gptq_qzeros, + const half* b_gptq_scales, + const int* b_g_idx, + half* out, + int height, + int width, + int groups +) +{ + dim3 blockDim, gridDim; + blockDim.x = BLOCK_KN_SIZE; + blockDim.y = 1; + gridDim.y = DIVIDE(height, 8); + gridDim.x = DIVIDE(width, BLOCK_KN_SIZE); + reconstruct_gptq_kernel<<>> + ( + b_q_weight, + b_gptq_scales, + b_gptq_qzeros, + b_g_idx, + height, + width, + groups, + out + ); +} + + +void gemm_half_q_half_cuda +( + cublasHandle_t cublas_handle, + const half* a, + const uint32_t* b_q_weight, + const uint32_t* b_gptq_qzeros, + const half* b_gptq_scales, + const int* b_g_idx, + half* c, + half* temp_dq, + int size_m, + int size_n, + int size_k, + int groups, + bool use_exllama +) +{ + if ((use_exllama && size_m > MAX_Q_GEMM_ROWS) || (!use_exllama && size_m > MAX_ALT_GEMM_ROWS)) { + // Reconstruct FP16 matrix, then cuBLAS + if (use_exllama) { + reconstruct_exllama(b_q_weight, b_gptq_qzeros, b_gptq_scales, b_g_idx, temp_dq, + size_k, size_n, groups); + } + else + { + reconstruct_gptq(b_q_weight, b_gptq_qzeros, b_gptq_scales, b_g_idx, + temp_dq, size_k, size_n, groups); + } + + const half alpha = __float2half(1.0f); + const half beta = __float2half(0.0f); + cublasHgemm(cublas_handle, + CUBLAS_OP_N, + CUBLAS_OP_N, + size_n, size_m, size_k, + &alpha, temp_dq, size_n, + a, size_k, + &beta, c, size_n); + } + else if (use_exllama) + { + // Quantized matmul + int max_chunks = size_m / BLOCK_M_SIZE_MAX; + int last_chunk = max_chunks * BLOCK_M_SIZE_MAX; + int last_chunk_size = size_m - last_chunk; + + if (max_chunks) + { + gemm_half_q_half_cuda_part(a, b_q_weight, b_gptq_qzeros, b_gptq_scales, b_g_idx, + c, last_chunk, size_n, size_k, BLOCK_M_SIZE_MAX, + groups); + } + + if (last_chunk_size) + { + gemm_half_q_half_cuda_part(a + last_chunk * size_k, b_q_weight, b_gptq_qzeros, + b_gptq_scales, b_g_idx, c + last_chunk * size_n, + last_chunk_size, size_n, size_k, last_chunk_size, + groups); + } + } + else + { + gemm_half_q_half_alt(a, b_q_weight, b_gptq_qzeros, b_gptq_scales, b_g_idx, + c, size_m, size_n, size_k); + } +} + + +__global__ void shuffle_kernel +( + uint32_t* __restrict__ b_q_weight, + const int size_k, + const int size_n +) +{ + int n = blockIdx.x * THREADS_X + threadIdx.x; + if (n >= size_n) return; + int k = 0; + uint32_t* b_ptr = b_q_weight + n; + while (k < size_k) { shuffle_4bit_8 (b_ptr, size_n); b_ptr += 1 * size_n; k += 8; } +} + + +__global__ void make_sequential_kernel +( + const uint32_t* __restrict__ w, + uint32_t* __restrict__ w_new, + const int* __restrict__ q_perm, + const int w_height, + const int w_width +) +{ + const uint64_t* w2 = (uint64_t*) w; + uint64_t* w_new2 = (uint64_t*) w_new; + int w2_stride = w_width >> 1; + int w2_column = THREADS_X * blockIdx.x + threadIdx.x; + if (w2_column >= w2_stride) return; + int w_new2_row = blockIdx.y; + int q_perm_idx = w_new2_row << 3; + uint64_t dst = 0; + + #pragma unroll + for (int i = 0; i < 8; i++) + { + int source_row = q_perm[q_perm_idx++]; + + int w2_row = source_row >> 3; + int w2_subrow = source_row & 0x07; + int w2_row_shift = w2_subrow << 2; + int wnew2_row_shift = i << 2; + + uint64_t src = w2[w2_row * w2_stride + w2_column]; + src >>= w2_row_shift; + src &= 0x0000000f0000000f; + src <<= wnew2_row_shift; + dst |= src; + } + w_new2[w_new2_row * w2_stride + w2_column] = dst; +} + + +void shuffle_exllama_weight +( + uint32_t* q_weight, + int* q_perm, + int height, + int width +) +{ + if (q_perm) + { + uint32_t* new_qweight = NULL; + cudaMalloc(&new_qweight, height / 8 * width * sizeof(uint32_t)); + + dim3 blockDim, gridDim; + blockDim.x = THREADS_X; + blockDim.y = 1; + gridDim.x = DIVIDE(width, THREADS_X); + gridDim.y = height / 8; + + make_sequential_kernel<<>> + ( + q_weight, + new_qweight, + q_perm, + height / 8, + width + ); + // Replace qweights + cudaMemcpyAsync(q_weight, new_qweight, height / 8 * width * sizeof(uint32_t), cudaMemcpyDeviceToDevice); + // Cleanup + cudaDeviceSynchronize(); + cudaFree(new_qweight); + } + dim3 blockDim, gridDim; + blockDim.x = THREADS_X; + blockDim.y = 1; + gridDim.x = DIVIDE(width, THREADS_X); + gridDim.y = 1; + shuffle_kernel<<>>(q_weight, height, width); +} + +} // namespace gptq +} // namespace vllm + +torch::Tensor gptq_gemm +( + torch::Tensor a, + torch::Tensor b_q_weight, + torch::Tensor b_gptq_qzeros, + torch::Tensor b_gptq_scales, + torch::Tensor b_g_idx, + bool use_exllama +) +{ + const at::cuda::OptionalCUDAGuard device_guard(device_of(a)); + auto options = torch::TensorOptions().dtype(a.dtype()).device(a.device()); + at::Tensor c = torch::empty({a.size(0), b_q_weight.size(1)}, options); + at::Tensor temp_dq = torch::empty({b_q_weight.size(0) * 8, b_q_weight.size(1)}, options); + + vllm::gptq::gemm_half_q_half_cuda + ( + at::cuda::getCurrentCUDABlasHandle(), + (const half*) a.data_ptr(), + (const uint32_t*) b_q_weight.data_ptr(), + (const uint32_t*)b_gptq_qzeros.data_ptr(), + (const half*) b_gptq_scales.data_ptr(), + b_g_idx.device().is_meta() ? NULL : (const int*) b_g_idx.data_ptr(), + (half*) c.data_ptr(), + (half*) temp_dq.data_ptr(), + c.size(0), // m + c.size(1), // n + a.size(1), // k + b_gptq_qzeros.size(0), // group number + use_exllama + ); + return c; +} + +void gptq_shuffle +( + torch::Tensor q_weight, + torch::Tensor q_perm +) +{ + const at::cuda::OptionalCUDAGuard device_guard(device_of(q_weight)); + vllm::gptq::shuffle_exllama_weight( + (uint32_t*) q_weight.data_ptr(), + q_perm.device().is_meta() ? NULL : (int*) q_perm.data_ptr(), + q_weight.size(0) * 8, + q_weight.size(1) + ); +} diff --git a/csrc/quantization/gptq/qdq_4.cuh b/csrc/quantization/gptq/qdq_4.cuh new file mode 100644 index 0000000000000..cfc4635a22c1d --- /dev/null +++ b/csrc/quantization/gptq/qdq_4.cuh @@ -0,0 +1,235 @@ +/* +Copied from https://github.com/turboderp/exllamav2 +*/ + +#ifndef _qdq_4_cuh +#define _qdq_4_cuh + +#include "qdq_util.cuh" + +namespace vllm { +namespace gptq { +// Permutation: +// +// 77775555 33331111 66664444 22220000 + +__forceinline__ __device__ void shuffle_4bit_8 +( + uint32_t* q, + int stride +) +{ + uint32_t qa = q[0]; + uint32_t qb = 0; + + #pragma unroll + for (int i = 0; i < 4; i++) + { + uint32_t qa0 = qa & 0x0f; + uint32_t qa1 = (qa & 0xf0) >> 4; + qa >>= 8; + qb |= (qa1 << (i * 4 + 16)); + qb |= (qa0 << (i * 4)); + } + q[0] = qb; +} + +__forceinline__ __device__ void dequant_4bit_8 +( + const uint32_t q_0, + half2 (&dq)[4], + int stride +) +{ + const uint32_t c0 = 0x64006400; + const half y16_ = __float2half_rn(1.0f / 16.0f); + const half2 y16 = __halves2half2(y16_, y16_); + const half z1_ = __float2half_rn(-1024.0f - 8.0f); + const half z16_ = __float2half_rn(-1024.0f / 16.0f - 8.0f); + const half2 z1 = __halves2half2(z1_, z1_); + const half2 z16 = __halves2half2(z16_, z16_); + + uint32_t qa = q_0; + half2_uint32 q0((qa & 0x000f000f) | c0); // half2(q[ 0], q[ 1]) + 1024 + half2_uint32 q1((qa & 0x00f000f0) | c0); // half2(q[ 2], q[ 3]) * 16 + 1024 + qa >>= 8; + half2_uint32 q2((qa & 0x000f000f) | c0); // half2(q[ 4], q[ 5]) + 1024 + half2_uint32 q3((qa & 0x00f000f0) | c0); // half2(q[ 6], q[ 7]) * 16 + 1024 + + dq[0] = __hadd2(q0.as_half2, z1); + dq[1] = __hfma2(q1.as_half2, y16, z16); + dq[2] = __hadd2(q2.as_half2, z1); + dq[3] = __hfma2(q3.as_half2, y16, z16); +} + +__forceinline__ __device__ void dequant_4bit_8_prep_zero_scale +( + const uint32_t zero, + const half scale, + half2 (&z1z16)[2], + half2 (&y1y16)[2] +) +{ + half_uint16 z1(0xe400 | zero); // half(-1024.0f - zero); + half z16 = __hsub(__int2half_rn(-64), __int2half_rn(zero)); + + half2 scale2 = __half2half2(scale); + + z1z16[0] = __hmul2(scale2, __half2half2(z1.as_half)); + z1z16[1] = __hmul2(scale2, __half2half2(z16)); + + const half y1 = __float2half_rn(1.0f); + const half y16 = __float2half_rn(1.0f / 16.0f); + + y1y16[0] = __hmul2(scale2, __half2half2(y1)); + y1y16[1] = __hmul2(scale2, __half2half2(y16)); +} + +__forceinline__ __device__ void dequant_4bit_8_prep_zero +( + const uint32_t zero, + half2(&z1z16)[2], + half2(&y1y16)[2] +) +{ + half_uint16 z1(0xe400 | zero); // half(-1024.0f - zero); + half z16 = __hsub(__int2half_rn(-64), __int2half_rn(zero)); + + z1z16[0] = __half2half2(z1.as_half); + z1z16[1] = __half2half2(z16); + + const half y1 = __float2half_rn(1.0f); + const half y16 = __float2half_rn(1.0f / 16.0f); + + y1y16[0] = __half2half2(y1); + y1y16[1] = __half2half2(y16); +} + + +__forceinline__ __device__ void dequant_4bit_8_gptq +( + const uint32_t q_0, + half2 (&dq)[4], + half2 (&z1z16)[2], + half2 (&y1y16)[2], + int stride, + bool scaled +) +{ + const uint32_t c0 = 0x64006400; + + uint32_t qa = q_0; + half2_uint32 q0((qa & 0x000f000f) | c0); // half2( q[0] + 1024, q[1] + 1024 ) + half2_uint32 q1((qa & 0x00f000f0) | c0); // half2( q[2] * 16 + 1024, q[3] * 16 + 1024 ) + qa >>= 8; + half2_uint32 q2((qa & 0x000f000f) | c0); // half2( q[4] + 1024, q[5] + 1024 ) + half2_uint32 q3((qa & 0x00f000f0) | c0); // half2( q[6] * 16 + 1024, q[7] * 16 + 1024 ) + + if (scaled) + { + dq[0] = __hfma2(q0.as_half2, y1y16[0], z1z16[0]); // half2( q[0] * s - z * s, q[1] * s - z * s) + dq[1] = __hfma2(q1.as_half2, y1y16[1], z1z16[1]); // half2( q[2] * s - z * s, q[3] * s - z * s) + dq[2] = __hfma2(q2.as_half2, y1y16[0], z1z16[0]); + dq[3] = __hfma2(q3.as_half2, y1y16[1], z1z16[1]); + } + else + { + dq[0] = __hadd2(q0.as_half2, z1z16[0]); // half2( q[0] - z, q[1] - z ) + dq[1] = __hfma2(q1.as_half2, y1y16[1], z1z16[1]); // half2( q[2] - z, q[3] - z ) + dq[2] = __hadd2(q2.as_half2, z1z16[0]); // half2( q[4] - z, q[5] - z ) + dq[3] = __hfma2(q3.as_half2, y1y16[1], z1z16[1]); // half2( q[6] - z, q[7] - z ) + } +} +} // namespace gptq +} // namespace vllm + +#else + +namespace vllm { +namespace gptq { +__forceinline__ __device__ void shuffle_4bit_8 +( + uint32_t* q, + int stride +) +{ +} + +__forceinline__ __device__ void dequant_4bit_8 +( + const uint32_t q_0, + half2 (&dq)[4], + int stride +) +{ + half dqh[8]; + for (int i = 0; i < 8; i++) dqh[i] = dq_ns(exb(q_0, i * 4, 0x0f), 8); + + for (int i = 0; i < 4; i++) dq[i] = __halves2half2(dqh[i * 2], dqh[i * 2 + 1]); +} + +__forceinline__ __device__ void dequant_4bit_8_prep_zero_scale +( + const uint32_t zero, + const half scale, + half2 (&z1)[2], + half2 (&y1)[2] +) +{ + half z = __int2half_rn(-((int)zero)); + z = __hmul(z, scale); + z1[0] = __half2half2(z); + y1[0] = __half2half2(scale); +} + +__forceinline__ __device__ void dequant_4bit_8_prep_zero +( + const uint32_t zero, + half2(&z1)[2], + half2(&y1)[2] +) +{ + half z = __int2half_rn(-((int)zero)); + z1[0] = __half2half2(z); +} + +__forceinline__ __device__ void dequant_4bit_8_gptq +( + const uint32_t q_0, + half2 (&dq)[4], + half2 (&z1)[2], + half2 (&y1)[2], + int stride, + bool scaled +) +{ + half2 dqh2[8]; + + uint32_t qa = q_0; + for (int i = 0; i < 4; i++) + { + half d0 = __int2half_rn(qa & 0x0f); qa >>= 4; + half d1 = __int2half_rn(qa & 0x0f); qa >>= 4; + dqh2[i] = __halves2half2(d0, d1); + } + + if (scaled) + { + dq[0] = __hfma2(dqh2[0], y1[0], z1[0]); + dq[1] = __hfma2(dqh2[1], y1[0], z1[0]); + dq[2] = __hfma2(dqh2[2], y1[0], z1[0]); + dq[3] = __hfma2(dqh2[3], y1[0], z1[0]); + } + else + { + dq[0] = __hadd2(dqh2[0], z1[0]); + dq[1] = __hadd2(dqh2[1], z1[0]); + dq[2] = __hadd2(dqh2[2], z1[0]); + dq[3] = __hadd2(dqh2[3], z1[0]); + } +} + +} // namespace gptq +} // namespace vllm + +#endif diff --git a/csrc/quantization/gptq/qdq_util.cuh b/csrc/quantization/gptq/qdq_util.cuh new file mode 100644 index 0000000000000..1722a9aa6cb34 --- /dev/null +++ b/csrc/quantization/gptq/qdq_util.cuh @@ -0,0 +1,60 @@ +/* +Copied from https://github.com/turboderp/exllamav2 +*/ + +#ifndef _qdq_util_cuh +#define _qdq_util_cuh + +namespace vllm { +namespace gptq { + +union half2_uint32 +{ + uint32_t as_uint32; + half2 as_half2; + __device__ half2_uint32(uint32_t val) : as_uint32(val) {} + __device__ half2_uint32(half2 val) : as_half2(val) {} +}; + +union half_uint16 +{ + uint16_t as_uint16; + half as_half; + __device__ half_uint16(uint16_t val) : as_uint16(val) {} + __device__ half_uint16(half val) : as_half(val) {} +}; + +// Max_scale premultiplied by 1/256 + +__forceinline__ __device__ half dq_scale(const int qs, const half max_scale) +{ + int qs_i = qs + 1; + half qs_h = __int2half_rn(qs_i * qs_i); + qs_h = __hmul(qs_h, max_scale); + return qs_h; +} + +__forceinline__ __device__ half dq(const int q, const int qzero, const half scale) +{ + return __hmul(__int2half_rn(q - qzero), scale); +} + +__forceinline__ __device__ half dq_ns(const int q, const int qzero) +{ + //return __hsub(__int2half_rn(q), __int2half_rn(qzero)); + return __int2half_rn(q - qzero); +} + +__forceinline__ __device__ int exb(const uint32_t q, const int shift, const int mask) +{ + return (int)((q >> shift) & mask); +} + +__forceinline__ __device__ int exb(const uint32_t q1, const uint32_t q0, const int shift, const int mask) +{ + return (int)(__funnelshift_rc(q0, q1, shift) & mask); +} + +} // namespace gptq +} // namespace vllm +#endif diff --git a/docs/source/getting_started/amd-installation.rst b/docs/source/getting_started/amd-installation.rst index 512aa051f16d9..181c970e0b2a7 100644 --- a/docs/source/getting_started/amd-installation.rst +++ b/docs/source/getting_started/amd-installation.rst @@ -3,7 +3,7 @@ Installation with ROCm ====================== -vLLM 0.2.x onwards supports model inferencing and serving on AMD GPUs with ROCm. +vLLM 0.2.4 onwards supports model inferencing and serving on AMD GPUs with ROCm. At the moment AWQ quantization is not supported in ROCm, but SqueezeLLM quantization has been ported. Data types currently supported in ROCm are FP16 and BF16. @@ -29,7 +29,7 @@ Installation options: .. code-block:: console - $ docker pull embeddedllminfo/vllm-rocm:vllm-v0.2.3 + $ docker pull embeddedllminfo/vllm-rocm:vllm-v0.2.4 $ docker run -it \ --network=host \ --group-add=video \ @@ -70,12 +70,12 @@ You can build and install vLLM from source: - ROCm's Flash-attention-2 (v2.0.4) does not support sliding windows attention. - You might need to downgrade the "ninja" version to 1.10 it is not used when compiling flash-attention-2 (e.g. `pip install ninja==1.10.2.4`) -2. Setup `xformers==0.0.22.post7` without dependencies, and apply patches to adapt for ROCm flash attention +2. Setup `xformers==0.0.23` without dependencies, and apply patches to adapt for ROCm flash attention .. code-block:: console - $ pip install xformers==0.0.22.post7 --no-deps - $ bash patch_xformers-0.0.22.post7.rocm.sh + $ pip install xformers==0.0.23 --no-deps + $ bash patch_xformers.rocm.sh 3. Build vLLM. @@ -116,6 +116,7 @@ Alternatively, if you plan to install vLLM-ROCm on a local machine or start from - `ROCm `_ - `Pytorch `_ +- `hipBLAS `_ 1. Install `flash attention for ROCm `_ @@ -127,12 +128,12 @@ Alternatively, if you plan to install vLLM-ROCm on a local machine or start from - ROCm's Flash-attention-2 (v2.0.4) does not support sliding windows attention. - You might need to downgrade the "ninja" version to 1.10 it is not used when compiling flash-attention-2 (e.g. `pip install ninja==1.10.2.4`) -2. Setup `xformers==0.0.22.post7` without dependencies, and apply patches to adapt for ROCm flash attention +2. Setup `xformers==0.0.23` without dependencies, and apply patches to adapt for ROCm flash attention .. code-block:: console - $ pip install xformers==0.0.22.post7 --no-deps - $ bash patch_xformers-0.0.22.post7.rocm.sh + $ pip install xformers==0.0.23 --no-deps + $ bash patch_xformers.rocm.sh 3. Build vLLM. diff --git a/docs/source/getting_started/installation.rst b/docs/source/getting_started/installation.rst index c189a83a2e6ce..e7a2d0a6f0d03 100644 --- a/docs/source/getting_started/installation.rst +++ b/docs/source/getting_started/installation.rst @@ -20,7 +20,7 @@ You can install vLLM using pip: .. code-block:: console $ # (Optional) Create a new conda environment. - $ conda create -n myenv python=3.8 -y + $ conda create -n myenv python=3.9 -y $ conda activate myenv $ # Install vLLM with CUDA 12.1. @@ -34,8 +34,9 @@ You can install vLLM using pip: .. code-block:: console $ # Install vLLM with CUDA 11.8. - $ # Replace `cp310` with your Python version (e.g., `cp38`, `cp39`, `cp311`). - $ pip install https://github.com/vllm-project/vllm/releases/download/v0.2.2/vllm-0.2.2+cu118-cp310-cp310-manylinux1_x86_64.whl + $ export VLLM_VERSION=0.2.4 + $ export PYTHON_VERSION=39 + $ pip install https://github.com/vllm-project/vllm/releases/download/v${VLLM_VERSION}/vllm-${VLLM_VERSION}+cu118-cp${PYTHON_VERSION}-cp${PYTHON_VERSION}-manylinux1_x86_64.whl $ # Re-install PyTorch with CUDA 11.8. $ pip uninstall torch -y diff --git a/docs/source/index.rst b/docs/source/index.rst index 04af09073a447..816f4f7e2015d 100644 --- a/docs/source/index.rst +++ b/docs/source/index.rst @@ -30,6 +30,8 @@ vLLM is fast with: * State-of-the-art serving throughput * Efficient management of attention key and value memory with **PagedAttention** * Continuous batching of incoming requests +* Fast model execution with CUDA/HIP graph +* Quantization: `GPTQ `_, `AWQ `_, `SqueezeLLM `_ * Optimized CUDA kernels vLLM is flexible and easy to use with: @@ -39,7 +41,7 @@ vLLM is flexible and easy to use with: * Tensor parallelism support for distributed inference * Streaming outputs * OpenAI-compatible API server -* Support NVIDIA CUDA and AMD ROCm. +* Support NVIDIA GPUs and AMD GPUs For more information, check out the following: diff --git a/docs/source/models/engine_args.rst b/docs/source/models/engine_args.rst index a70c22e9af11a..d89b795149501 100644 --- a/docs/source/models/engine_args.rst +++ b/docs/source/models/engine_args.rst @@ -89,9 +89,11 @@ Below, you can find an explanation of every engine argument for vLLM: CPU swap space size (GiB) per GPU. -.. option:: --gpu-memory-utilization +.. option:: --gpu-memory-utilization - The percentage of GPU memory to be used for the model executor. + The fraction of GPU memory to be used for the model executor, which can range from 0 to 1. + For example, a value of 0.5 would imply 50% GPU memory utilization. + If unspecified, will use the default value of 0.9. .. option:: --max-num-batched-tokens diff --git a/docs/source/models/supported_models.rst b/docs/source/models/supported_models.rst index f56d6eaccfddc..71808436dc114 100644 --- a/docs/source/models/supported_models.rst +++ b/docs/source/models/supported_models.rst @@ -23,6 +23,9 @@ Alongside each architecture, we include some popular models that use it. * - :code:`ChatGLMModel` - ChatGLM - :code:`THUDM/chatglm2-6b`, :code:`THUDM/chatglm3-6b`, etc. + * - :code:`DeciLMForCausalLM` + - DeciLM + - :code:`Deci/DeciLM-7B`, :code:`Deci/DeciLM-7B-instruct`, etc. * - :code:`BloomForCausalLM` - BLOOM, BLOOMZ, BLOOMChat - :code:`bigscience/bloom`, :code:`bigscience/bloomz`, etc. @@ -50,6 +53,9 @@ Alongside each architecture, we include some popular models that use it. * - :code:`MistralForCausalLM` - Mistral, Mistral-Instruct - :code:`mistralai/Mistral-7B-v0.1`, :code:`mistralai/Mistral-7B-Instruct-v0.1`, etc. + * - :code:`MixtralForCausalLM` + - Mixtral-8x7B, Mixtral-8x7B-Instruct + - :code:`mistralai/Mixtral-8x7B-v0.1`, :code:`mistralai/Mixtral-8x7B-Instruct-v0.1`, etc. * - :code:`MPTForCausalLM` - MPT, MPT-Instruct, MPT-Chat, MPT-StoryWriter - :code:`mosaicml/mpt-7b`, :code:`mosaicml/mpt-7b-storywriter`, :code:`mosaicml/mpt-30b`, etc. @@ -57,8 +63,8 @@ Alongside each architecture, we include some popular models that use it. - OPT, OPT-IML - :code:`facebook/opt-66b`, :code:`facebook/opt-iml-max-30b`, etc. * - :code:`PhiForCausalLM` - - Phi-1.5 - - :code:`microsoft/phi-1_5`, etc. + - Phi + - :code:`microsoft/phi-1_5`, :code:`microsoft/phi-2`, etc. * - :code:`QWenLMHeadModel` - Qwen - :code:`Qwen/Qwen-7B`, :code:`Qwen/Qwen-7B-Chat`, etc. @@ -70,6 +76,9 @@ If your model uses one of the above model architectures, you can seamlessly run Otherwise, please refer to :ref:`Adding a New Model ` for instructions on how to implement support for your model. Alternatively, you can raise an issue on our `GitHub `_ project. +.. note:: + Currently, the ROCm version of vLLM supports Mistral and Mixtral only for context lengths up to 4096. + .. tip:: The easiest way to check if your model is supported is to run the program below: @@ -81,12 +90,17 @@ Alternatively, you can raise an issue on our `GitHub `_ instead of HuggingFace Hub, set an environment variable: .. code-block:: shell $ export VLLM_USE_MODELSCOPE=True + And use with :code:`trust_remote_code=True`. + .. code-block:: python from vllm import LLM @@ -94,5 +108,3 @@ Alternatively, you can raise an issue on our `GitHub = 49.4.0", - "torch >= 2.1.0", + "torch == 2.1.2", "wheel", ] build-backend = "setuptools.build_meta" diff --git a/requirements-build.txt b/requirements-build.txt index 38514d273f843..7e7e48a1313e5 100644 --- a/requirements-build.txt +++ b/requirements-build.txt @@ -2,5 +2,5 @@ ninja packaging setuptools>=49.4.0 -torch>=2.1.0 +torch==2.1.2 wheel \ No newline at end of file diff --git a/requirements-dev.txt b/requirements-dev.txt index c9b212c923a42..cf15292749083 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -1,5 +1,6 @@ # formatting yapf==0.32.0 +toml==0.10.2 ruff==0.1.5 # type checking diff --git a/requirements-rocm.txt b/requirements-rocm.txt index 42ab9d14450d7..81bc19580274c 100644 --- a/requirements-rocm.txt +++ b/requirements-rocm.txt @@ -8,9 +8,7 @@ pyarrow # Required for Ray data. sentencepiece # Required for LLaMA tokenizer. numpy tokenizers>=0.15.0 -huggingface_hub<0.18,>=0.16.4 -einops # Required for phi-1_5 -transformers >= 4.34.0 # Required for Mistral. +transformers >= 4.36.0 # Required for Mixtral. fastapi uvicorn[standard] pydantic == 1.10.13 # Required for OpenAI server. diff --git a/requirements.txt b/requirements.txt index 9a27eae80973c..92ba0a716c45c 100644 --- a/requirements.txt +++ b/requirements.txt @@ -5,10 +5,9 @@ pandas # Required for Ray data. pyarrow # Required for Ray data. sentencepiece # Required for LLaMA tokenizer. numpy -einops # Required for phi-1_5 -torch >= 2.1.0 -transformers >= 4.34.0 # Required for Mistral. -xformers >= 0.0.22.post7 # Required for CUDA 12.1. +torch == 2.1.2 +transformers >= 4.36.0 # Required for Mixtral. +xformers == 0.0.23.post1 # Required for CUDA 12.1. fastapi uvicorn[standard] pydantic == 1.10.13 # Required for OpenAI server. diff --git a/rocm_patch/commonpy_xformers-0.0.22.post7.rocm.patch b/rocm_patch/commonpy_xformers-0.0.23.rocm.patch similarity index 100% rename from rocm_patch/commonpy_xformers-0.0.22.post7.rocm.patch rename to rocm_patch/commonpy_xformers-0.0.23.rocm.patch diff --git a/rocm_patch/flashpy_xformers-0.0.22.post7.rocm.patch b/rocm_patch/flashpy_xformers-0.0.23.rocm.patch similarity index 61% rename from rocm_patch/flashpy_xformers-0.0.22.post7.rocm.patch rename to rocm_patch/flashpy_xformers-0.0.23.rocm.patch index 4798f1efd461c..ac846728a7a91 100644 --- a/rocm_patch/flashpy_xformers-0.0.22.post7.rocm.patch +++ b/rocm_patch/flashpy_xformers-0.0.23.rocm.patch @@ -1,6 +1,6 @@ ---- /opt/conda/envs/py_3.10/lib/python3.10/site-packages/xformers/ops/fmha/flash.py 2023-11-29 03:17:03.930103539 +0000 -+++ flash.py 2023-11-28 16:14:25.206128903 +0000 -@@ -31,39 +31,39 @@ +--- flash_ori.py 2023-12-13 05:43:31.530752623 +0000 ++++ flash_patch.py 2023-12-13 06:00:45.962403104 +0000 +@@ -36,44 +36,44 @@ FLASH_VERSION = "0.0.0" try: @@ -15,9 +15,12 @@ - from flash_attn.flash_attn_interface import flash_attn_cuda as _C_flashattention - - FLASH_VERSION = flash_attn.__version__ -- flash_ver_parsed = tuple(int(s) for s in FLASH_VERSION.split(".")[:2]) -- if flash_ver_parsed < (2, 3): -- raise ImportError("Requires 2.3 for sliding window support") +- flash_ver_parsed = tuple(int(s) for s in FLASH_VERSION.split(".")[:3]) +- if ( +- flash_ver_parsed != (2, 3, 6) +- and os.environ.get("XFORMERS_IGNORE_FLASH_VERSION_CHECK", "0") != "1" +- ): +- raise ImportError("Requires Flash attention 2.3.6 for varlen_fwd api") + #try: + # from ... import _C_flashattention # type: ignore[attr-defined] + # from ..._cpp_lib import _build_metadata @@ -29,35 +32,41 @@ + from flash_attn.flash_attn_interface import flash_attn_cuda as _C_flashattention + + FLASH_VERSION = flash_attn.__version__ -+ # flash_ver_parsed = tuple(int(s) for s in FLASH_VERSION.split(".")[:2]) -+ # if flash_ver_parsed < (2, 3): -+ # raise ImportError("Requires 2.3 for sliding window support") ++ # flash_ver_parsed = tuple(int(s) for s in FLASH_VERSION.split(".")[:3]) ++ # if ( ++ # flash_ver_parsed != (2, 3, 6) ++ # and os.environ.get("XFORMERS_IGNORE_FLASH_VERSION_CHECK", "0") != "1" ++ # ): ++ # raise ImportError("Requires Flash attention 2.3.6 for varlen_fwd api") # create library so that flash-attn goes through the PyTorch Dispatcher - _flash_lib = torch.library.Library("xformers_flash", "DEF") -+ #_flash_lib = torch.library.Library("xformers_flash", "DEF") - +- - _flash_lib.define( - "flash_fwd(Tensor query, Tensor key, Tensor value, " -- "Tensor? cu_seqlens_q, Tensor? cu_seqlens_k, " +- "Tensor? cu_seqlens_q, Tensor? cu_seqlens_k, Tensor? seqused_k, " - "int max_seqlen_q, int max_seqlen_k, " - "float p, float softmax_scale, " -- "bool is_causal, int window_size, bool return_softmax) -> (Tensor, Tensor, Tensor)" +- "bool is_causal, int window_left, " +- "int window_right, bool return_softmax) -> (Tensor, Tensor, Tensor)" - ) -- ++ #_flash_lib = torch.library.Library("xformers_flash", "DEF") + - _flash_lib.define( - "flash_bwd(Tensor dout, Tensor query, Tensor key, Tensor value, " - "Tensor out, Tensor softmax_lse_, Tensor dq, Tensor dk, Tensor dv, " - "Tensor cu_seqlens_q, Tensor cu_seqlens_k, " - "int max_seqlen_q, int max_seqlen_k, " -- "float p, float softmax_scale, bool is_causal, int window_size, Tensor rng_state) -> (Tensor, Tensor, Tensor)" +- "float p, float softmax_scale, bool is_causal, " +- "int window_left, int window_right, Tensor rng_state) -> (Tensor, Tensor, Tensor)" - ) + #_flash_lib.define( + # "flash_fwd(Tensor query, Tensor key, Tensor value, " -+ # "Tensor? cu_seqlens_q, Tensor? cu_seqlens_k, " ++ # "Tensor? cu_seqlens_q, Tensor? cu_seqlens_k, Tensor? seqused_k, " + # "int max_seqlen_q, int max_seqlen_k, " + # "float p, float softmax_scale, " -+ # "bool is_causal, int window_size, bool return_softmax) -> (Tensor, Tensor, Tensor)" ++ # "bool is_causal, int window_left, " ++ # "int window_right, bool return_softmax) -> (Tensor, Tensor, Tensor)" + #) + + #_flash_lib.define( @@ -65,52 +74,61 @@ + # "Tensor out, Tensor softmax_lse_, Tensor dq, Tensor dk, Tensor dv, " + # "Tensor cu_seqlens_q, Tensor cu_seqlens_k, " + # "int max_seqlen_q, int max_seqlen_k, " -+ # "float p, float softmax_scale, bool is_causal, int window_size, Tensor rng_state) -> (Tensor, Tensor, Tensor)" ++ # "float p, float softmax_scale, bool is_causal, " ++ # "int window_left, int window_right, Tensor rng_state) -> (Tensor, Tensor, Tensor)" + #) def _flash_fwd( query, -@@ -98,8 +98,8 @@ +@@ -111,8 +111,8 @@ p, softmax_scale, is_causal, -- window_size - 1, # window_size_left -- -1, # window_size_right -+ # window_size - 1, # window_size_left -+ # -1, # window_size_right +- window_left, # window_size_left +- window_right, # window_size_right ++ # window_left, # window_size_left ++ # window_right, # window_size_right return_softmax, None, # rng ) -@@ -127,8 +127,8 @@ +@@ -134,15 +134,15 @@ + out, + cu_seq_lens_q, + cu_seq_lens_k, +- seqused_k, ++ # seqused_k, + max_seq_len_q, + max_seq_len_k, + p, softmax_scale, False, is_causal, -- window_size - 1, # window_size_left -- -1, # window_size_right -+ # window_size - 1, # window_size_left -+ # -1, # window_size_right +- window_left, +- window_right, ++ # window_left, ++ # window_right, return_softmax, None, ) -@@ -169,8 +169,8 @@ +@@ -184,8 +184,8 @@ p, softmax_scale, is_causal, -- window_size - 1, # window_size_left -- -1, # window_size_right -+ # window_size - 1, # window_size_left -+ # -1, # window_size_right +- window_left, +- window_right, ++ # window_left, ++ # window_right, None, rng_state, ) -@@ -193,15 +193,15 @@ +@@ -208,15 +208,15 @@ softmax_scale, False, # zero_tensors is_causal, -- window_size - 1, # window_size_left -- -1, # window_size_right -+ # window_size - 1, # window_size_left -+ # -1, # window_size_right +- window_left, +- window_right, ++ # window_left, ++ # window_right, None, rng_state, ) @@ -123,7 +141,7 @@ except ImportError: pass -@@ -348,7 +348,7 @@ +@@ -400,7 +400,7 @@ implementation. """ diff --git a/setup.py b/setup.py index 2f4215d6cb74e..22a6a251e5c23 100644 --- a/setup.py +++ b/setup.py @@ -247,6 +247,7 @@ def get_torch_arch_list() -> Set[str]: "csrc/activation_kernels.cu", "csrc/layernorm_kernels.cu", "csrc/quantization/squeezellm/quant_cuda_kernel.cu", + "csrc/quantization/gptq/q_gemm.cu", "csrc/cuda_utils_kernels.cu", "csrc/pybind.cpp", ] diff --git a/tests/conftest.py b/tests/conftest.py index 9f0cf5bdc9899..16c04e01d703c 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,3 +1,4 @@ +import os from typing import List, Optional, Tuple import pytest @@ -7,21 +8,32 @@ from vllm import LLM, SamplingParams from vllm.transformers_utils.tokenizer import get_tokenizer -_TEST_PROMPTS = [ - "vLLM is a high-throughput and memory-efficient inference and serving engine for LLMs.", - "Briefly describe the major milestones in the development of artificial intelligence from 1950 to 2020.", - "Compare and contrast artificial intelligence with human intelligence in terms of processing information.", - "Describe the basic components of a neural network and how it can be trained.", - "Write a short story about a robot that dreams for the first time.", - "Analyze the impact of the COVID-19 pandemic on global economic structures and future business models.", - "Explain the cultural significance of the Mona Lisa painting, and how its perception might vary in Western versus Eastern societies.", - "Translate the following English sentence into Japanese, French, and Swahili: 'The early bird catches the worm.'", -] +_TEST_PROMPTS = ["prompts/example.txt"] +_LONG_PROMPTS = ["prompts/summary.txt"] + + +def _read_prompts(filename: str) -> str: + prompts = [] + with open(filename, "r") as f: + prompt = f.readline() + prompts.append(prompt) + return prompts @pytest.fixture def example_prompts() -> List[str]: - return _TEST_PROMPTS + prompts = [] + for filename in _TEST_PROMPTS: + prompts += _read_prompts(os.path.join("tests", filename)) + return prompts + + +@pytest.fixture +def example_long_prompts() -> List[str]: + prompts = [] + for filename in _LONG_PROMPTS: + prompts += _read_prompts(os.path.join("tests", filename)) + return prompts _STR_DTYPE_TO_TORCH_DTYPE = { diff --git a/tests/kernels/test_attention.py b/tests/kernels/test_attention.py index a65d4d54d7c82..614b65f82ccbd 100644 --- a/tests/kernels/test_attention.py +++ b/tests/kernels/test_attention.py @@ -131,9 +131,6 @@ def test_paged_attention( assert num_query_heads % num_kv_heads == 0 num_queries_per_kv = num_query_heads // num_kv_heads - head_mapping = torch.repeat_interleave( - torch.arange(num_kv_heads, dtype=torch.int32, device="cuda"), - num_queries_per_kv) alibi_slopes = None if use_alibi: alibi_slopes = torch.randn(num_query_heads, @@ -170,7 +167,7 @@ def test_paged_attention( query, key_cache, value_cache, - head_mapping, + num_kv_heads, scale, block_tables, context_lens, @@ -202,7 +199,7 @@ def test_paged_attention( query, key_cache, value_cache, - head_mapping, + num_kv_heads, scale, block_tables, context_lens, diff --git a/tests/models/test_mistral.py b/tests/models/test_mistral.py new file mode 100644 index 0000000000000..83316fcb7469d --- /dev/null +++ b/tests/models/test_mistral.py @@ -0,0 +1,37 @@ +"""Compare the outputs of HF and vLLM for Mistral models using greedy sampling. + +Run `pytest tests/models/test_mistral.py --forked`. +""" +import pytest + +MODELS = [ + "mistralai/Mistral-7B-Instruct-v0.1", +] + + +@pytest.mark.parametrize("model", MODELS) +@pytest.mark.parametrize("dtype", ["bfloat16"]) +@pytest.mark.parametrize("max_tokens", [128]) +def test_models( + hf_runner, + vllm_runner, + example_long_prompts, + model: str, + dtype: str, + max_tokens: int, +) -> None: + hf_model = hf_runner(model, dtype=dtype) + hf_outputs = hf_model.generate_greedy(example_long_prompts, max_tokens) + del hf_model + + vllm_model = vllm_runner(model, dtype=dtype) + vllm_outputs = vllm_model.generate_greedy(example_long_prompts, max_tokens) + del vllm_model + + for i in range(len(example_long_prompts)): + hf_output_ids, hf_output_str = hf_outputs[i] + vllm_output_ids, vllm_output_str = vllm_outputs[i] + assert hf_output_str == vllm_output_str, ( + f"Test{i}:\nHF: {hf_output_str!r}\nvLLM: {vllm_output_str!r}") + assert hf_output_ids == vllm_output_ids, ( + f"Test{i}:\nHF: {hf_output_ids}\nvLLM: {vllm_output_ids}") diff --git a/tests/models/test_models.py b/tests/models/test_models.py index 95eabaafec811..518eae201ed32 100644 --- a/tests/models/test_models.py +++ b/tests/models/test_models.py @@ -8,6 +8,7 @@ "facebook/opt-125m", "meta-llama/Llama-2-7b-hf", "mistralai/Mistral-7B-v0.1", + "Deci/DeciLM-7b", "tiiuae/falcon-7b", "gpt2", "bigcode/tiny_starcoder_py", @@ -15,12 +16,12 @@ "EleutherAI/pythia-70m", "bigscience/bloom-560m", "mosaicml/mpt-7b", - "microsoft/phi-1_5", + "microsoft/phi-2", ] @pytest.mark.parametrize("model", MODELS) -@pytest.mark.parametrize("dtype", ["half"]) +@pytest.mark.parametrize("dtype", ["float"]) @pytest.mark.parametrize("max_tokens", [128]) def test_models( hf_runner, diff --git a/tests/prompts/example.txt b/tests/prompts/example.txt new file mode 100644 index 0000000000000..e1b97bc6eee75 --- /dev/null +++ b/tests/prompts/example.txt @@ -0,0 +1,8 @@ +vLLM is a high-throughput and memory-efficient inference and serving engine for LLMs. +Briefly describe the major milestones in the development of artificial intelligence from 1950 to 2020. +Compare and contrast artificial intelligence with human intelligence in terms of processing information. +Describe the basic components of a neural network and how it can be trained. +Write a short story about a robot that dreams for the first time. +Analyze the impact of the COVID-19 pandemic on global economic structures and future business models. +Explain the cultural significance of the Mona Lisa painting, and how its perception might vary in Western versus Eastern societies. +Translate the following English sentence into Japanese, French, and Swahili: 'The early bird catches the worm.' diff --git a/tests/prompts/summary.txt b/tests/prompts/summary.txt new file mode 100644 index 0000000000000..2f947a264ce93 --- /dev/null +++ b/tests/prompts/summary.txt @@ -0,0 +1 @@ +Subtitles: for our annual races at Knockhill Circuit.Today\'s racing comes from the Porsche Carrera Cup Great Britainand the Legends Cars Elite Cup with JLM.It\'s the latter who get us underway with their first race of the day,and joining me in the commentary box is Paul O\'Neill.First race of the day for the Legends.Jonty Norman has drawn pole position,with Matt Knight alongside.Marcus Pett on Row 2 with Daniel Pooley.Declan Burke is next up, and then Tyler Read, on Row 3.He\'s leading the rookie championship at the moment.Chris Needham on Row 4 with Luke Simmons.Andrew Rogerson and Gareth Sheridan on Row 5.Sixth row, Peter Barrable, with Charlie Budd.Row 7, Jack Parker, fourth in the championship right now.Nick Price is next to him.Will Gibson, who looks like he\'s out of the championship contention now,with Oli Schlup alongside.Then Ben McNeice and Flight Lieutenant Matt Isherwood.Robert Barrable, championship leader, he\'s on Row 10.Then Brent Bowie from Kieran Beattie and Nick Bridgeman.Mike Schlup on Row 12, followed by Ryan McLeish,who won the day overall yesterday.Mark Beaty, Row 13, with Andy Bird.Then it\'s Ben Higgins and Nathan Anthony.Connor Mills and Paul Musselle complete Row 15.And completing the grid is James Newbery.Here we go, with Race number 1 of the day,the final day of the first ever Legends Cars Elite Cup with JLM.And on the front row, it\'s Jonty Norman in grey,Matt Knight in black and gold.Coming from third place on the grid is Marcus Pett,who goes left of shot in the gunmetal carto challenge for the lead.Marcus Pett, the man from Boston in Lincolnshire,goes through into lead position.Very definitely a fancied championship runnerbut hasn\'t quite had the rub of the green this weekend.And they all pile into McIntyre\'s for the first time.And this is where we look for driving standards.James Newbery brakes at the back.He\'s got Paul Musselle immediately in front of him.Those two had an interesting battle yesterdayinvolving a little bit of contact, I think,but they\'re both all right at the moment, as they clear the chicane for the first time.Marcus Pett is away.The difference you\'ll see in Legends Cars racing todayis that for this meeting,the bump drafting that we\'ve seen in the pasthas been ruled out for this round,and it\'s under review for the future.But look at the battle for second position, three wide,as Marcus Pett comes in front of the crowds here.Matt Knight on the inside, Dan Pooley on the outside in 32.Dan Pooley challenging for third. He had a strong day yesterday -he was up in the top ten, which was great to see.The man from March.That third car there, eclipsed at the moment,comes out of the slipstream.Dan repaired his own car after Croft,and that of Kieran Beaty,so I know Kieran wanted to thank him for that. He\'s been working hard.And Pooley side by side with Matt Knight.We\'ve got the 13, Chris Needham car, up there in the mix as well.The three top guys in the...Ryan McLeish getting very sideways there,the Scot in the 71 car.The first time we\'ve seen him on our ITV coverage.He\'s not a guest driver this week.I suppose you could technically call him a guest,but he\'s fully championship registeredand took a splendid win yesterday - overall win and race win.Overall on points.Sorry, Paul, gets a chance to get you in.That\'s Jack Parker!Oh, what\'s happened there?So, this was the start. They\'re all still warming the tyres up,ready for the lights to go green,which they do... around about now.And they get going.And then there was a car, wasn\'t there?Oh, I tell you what, that could\'ve ended up really nastyas it snaked up the grass.Yeah, I\'ll tell you what, the moment when the lights went outwas when Marcus Pett broke ranks.That was a very, very meticulous start from Marcus Pett.The blue car here is Tyler Read, top rookie,who looks like he\'s going down the inside of Daniel Pooley,so he\'s gonna make a space here.So, Dan Pooley has lost second position.It\'s Marcus Pett still out front. Matt Knight...I was saying to the drivers,"Don\'t go away if you\'re in the lead because you won\'t get any coverage." Pett\'s down the road, isn\'t he? Look at the gap he\'s got. Yeah.He\'s got three seconds. It\'s gonna be more than that.What I was quite concerned about was the damp part of the circuitdown at the hairpin, where you need to be down the inside of peopleto get the braking done,but these guys seem to be all respecting...Not track limits, but they\'re respecting each other around usbecause I was quite concerned about coming here,but this is quite synonymous with Legends racing at Knockhill.And look at this now. Knight has got...Look at that. I remember Marcus getting his first race win,which was at Snetterton years ago.It\'s always fantastic to see a first-time winner.And Tyler Read is giving him a great workout.Matt Knight back in third.It\'s between the top two at the moment. Oh! Tyler goes wide.He\'s throwing the car around.Marcus Pett, looking a little bit smoother in the 79,was very frustrated yesterday, but Read\'s all over him.Yeah, but look at this now.You\'ve got third, fourth, fifth and sixth.This is gonna be absolutely spectacular!Tyler Read\'s gone! What\'s gone on?!Oh, has the Treherne engine gone pop? He\'s lost a lot of ground.Is he gonna come back into it?Now it\'s Knight having a go on the outside line again.Matt Knight can\'t do it. He runs out wide.Oli Schlup\'s coming through.Schlup hasn\'t had a win yet in Legends cars, so he\'s queueing up.They\'re coming onto the last lap.This could be a key moment for Oli Schlup,who\'s back in third in the K-Seal car.Across the line.Marcus Pett soaking up the pressure brilliantly so far.But does he need to be in front as they come onto the last lap?I don\'t know, but I think Read must have missed a gear,as someone\'s exited stage left.Look at that, back in the mix!It\'s now six for the lead. Can Pett hold on?Championship leader Robert Barrablehas come through from about three rows from the back,and he\'s at the back of the train.Barrable here is gonna extend his championship leadand start towards the front of the grid for Race 2.Barrable, the Irishman, he\'s there.The white car with the green and orange stripeson the nose cone of the car.But it\'s Marcus Pett out front at the moment... Oh!Matt Isherwood\'s rejoined at the back in the black and green.Isherwood\'s got back at them. Matt Knight\'s having a go.Along Railway Straight.Schlup would normally bump draft him. He can\'t do that on the rules.But look at Marcus Pett.Fairly wide-ish line in. Good defensive stuff from Pett.It\'s all about the run up to the hill now.And Marcus Pett is gonna take the win, I think.Here they come, up towards the line. Pett from Matt Knight.It\'s gonna be Matt\'s best resultin the Legends Cars National Championship.Third position goes to Oli Schlup, who is delighted with that.Then it was Tyler Read. Great race from him.Robert Barrable, though...Barrable, from 19th on the grid, without bump drafting,comes through into fifth placeahead of the excellent recovery from Flight Lieutenant Matt Isherwood.Dan Pooley seventh. Another great result for Dan Pooley.So much to take away from those last racing laps.Oh, and those last four lapsis exactly why we have these Legends on the TOCA package.That was exceptional.Marcus Pett looked like a dead cert not to finish first,but congratulations to you. That was brilliant.But Barrable, after exiting stage leftwhen he caught the back of everybody and got right up there...There\'s too much to talk about. Let\'s just talk about this guy.Pett, you are a legend, mate. Well done.Cracking. It is a lad and dad.Literally, Marcus and his dad, Robert, they look after the car.It is lad and dad. We hear that mentioned in other formulas,but genuinely, that is all it is.It is very difficult for drivers like that and teams like thatto come and race on this stage.It is a big thing. And he\'s such a smashing guy.And his dad as well. Really delighted with the win.Super stuff by Matt Knight. brilliant from Oli Schlup.Fantastic as well from Tyler Read.And on the front row,it\'s Jonty Norman in grey, Matt Knight in black and gold.Coming from third place on the grid is Marcus Pett.Bit of a shemozzle at the back.Two cars hooked up, which is not good to see.Oh, has the Treherne engine gone pop? He\'s lost a lot of ground.Now it\'s Knight having a go on the outside line again.Matt Knight can\'t do it. He runs out wide.Oli Schlup\'s coming through.And Marcus Pett is gonna take the win, I think. Pett from Matt Knight. It\'s gonna be Matt\'s best resultin the Legends Cars National Championship.Here\'s how they finished.Marcus Pett takes another win in the Legends Cars Elite Cup with JLM.READS INFOREADS INFOREADS INFOREADS INFOREADS INFOREADS INFOProblems in that race for Ryan McLeish, yesterday\'s winner.Charlie Budd in 30th.And the other driver having problems, obviously,from that first stoppage, Brent Bowie.Marcus, that was a tough racebecause there was a red flag in the middle of it.Actually, the first bit, you got away,but it was a full reset,and pressure throughout to the chequered flag.Yeah, definitely.We had an ideal start and managed to build up a lead early on,which was great, but when you\'re in that position,the last thing you want to see is a red flag. iming line at the end of lap one.So, Gus Burton leads the way.Big, big dive by Foster on the inside,to go back ahead of Wylie.He goes off the road and back on again.He\'s all sideways.And diving up on the outside line comes Ryan Ratcliffe.Wylie here battling with one of the Pro category cars,but behind him, all the Pro-Am opposition crawling all over him.Well, that was dramatic stuff, wasn\'t it?Round the outside of Turn 1, put Harry Foster in the wrong place.That was Max Bird going wide, number 44, the pink and blue car.So that\'s just haemorrhaged places in Pro-Am.And he\'s the... Oh, a puncture.There\'s somebody with a puncture. Is that Angus Whiteside? Possibly.Let\'s see.I think it is. And you\'ve got this damp patch on the inside,on the braking there, just at the final into the hairpin.This has been a dramatic start to this race for Porsches.Absolutely right.Coming up over the timing line, Gus Burton leads the way.Nine tenths of a second to the good.Big effort being made by Jason Lockwoodin the yellow and orange car in the background, look,to try to get up the inside line, then diving down towards Turn 1.Goes ahead of Oliver White, the very experienced Formula 4 champion.In the silver car, Oliver White, back into Carrera Cup.Remember, he did a full season last year.Good to have him back on the grid.As the cars clamber their way up over the kerb,through the chicane.But Gus Burton saying to everybody, "I\'m back." He leads.Yeah, a dramatic way for Gus Burton to come back to this championship.Remember, he started this year with Century Motorsport but then ducked out of the championship prior to Thruxton.He\'s still competing in the Supercup series with Fach Auto.As there in the pits, getting a new rear left tyre, is Angus Whiteside.But Gus Burton absolutely on it.Very quick in testing here during the week.They tested on Wednesday and on Friday.Gus Burton very quick in...And he\'s really enjoying life now.Back in the championship with the NAPA Racing UK supportand with a different team, Nick Tandy\'s JTR outfit.And he\'s done the fastest lap of the race, as he leads.He is not in the championship fight, but he wants to win races.Car off. It\'s Max Bird again.So, Max Bird, the Pro-Am championship leader,three times a winner in class this year,off the road and back on again.But that\'s gonna throw him way, way down the order.This race is going from bad to worse for him.It\'s just completely unfolded for poor Max Bird.That\'s the curse of having our camera on board, I think,but it\'s just unravelled after a great qualifying.Now, you were talking about Gus Burton\'s start,and it is going to be investigated after the race.OK. Well, it\'ll take a lot of camera action analysisto look at it. This is on board with Bird.Round Turn 1.All OK there. Very close... Goes to the outside.That\'s dangerous cos you can get knocked wide,and that\'s exactly what happens.The man he was trying to get past, Josh Stanton,who spent last night trackside at Cowdenbeath watching stock cars.I\'m not suggesting for a moment he\'s learnt how to defend,but he was enjoying himself, watching a different form of racing.I think all the best people were at Cowdenbeath, weren\'t they?Nick Tandy was, and others. Oh!As there, absolutely on the giddy limit, is Harry Foster,making his way in sixth place.Down towards the hairpin.He\'s dropped back from that leading quintet,but he\'s keeping Ross Wylie at bay.Ross Wylie, there, creeping into shot, leads now Pro-Amahead of Ryan Ratcliffe.And Josh Stanton is third in Pro-Am, last year\'s Am champion.Yeah, and Ross Wylie the only Scottish driver in the race. A lot of support for him,from local sponsors as well as the public.Buoyed by his recent run at the British Grand Prix at Supercup,and thoroughly loving racing at his home circuit, Ross Wylie.Track is nicely dry.There was some threats of possible rain.We had rain yesterday during qualifying.They actually only got one runon their slick tyres yesterday in qualifyingbefore the rain arrived, and that set the grid.So, Gus Burton\'s lead growing all the time.1.3 seconds now, that margin over Adam Smalley.As Max Bird tries to fight back in Pro-Am.Gets up the inside line there.So, that puts him ahead of David Stirling.So, he\'s split the second and third Am fightas he tries to recover.Yeah, but he\'s lost a lot of ground with that momenton the outside of McIntyre\'s.It\'s getting a lot darker overhead at Knockhill,even though there is a break in the cloud.A big effort there from the lapped car of Angus Whiteside.He\'s not fighting for position, he\'s trying to unlap himself.But just wonder whether we might get so f the right of McIntyre\'s,up towards Butcher\'s, then the chicane.And looking to try and maintain this 100% recordin the Team Parker Racing-run car in Am.Yeah. David Fairbrother in second place,but some 11 seconds behind in the Am category.But he will take another podium.His second in the championship, too, Justin Sherwood.The race leader 2.5 seconds to the good, Gus Burton.Other battles still to be resolved.What\'s going on in Pro-Am? Ross Wylie leads.He\'s fallen back behind Josh Malin overall. That was the move.Josh Malin through on the inside at the hairpin.Ross Wylie, in a sense, content to let that happen - gave him room -because that\'s not his battle, but what it does meanis that Ryan Ratcliffe, his class rival,is directly behind him.This is William Aspin versus Max Bird for sixth in Pro-Am.And a very determined Max Bird goes one side, get his nose chopped off.Will Aspin, the man from Florence, defends on the other side.They\'re absolutely together, almost touching.Here comes Max Bird.Oh, but he can\'t find a way through there.Angus Whiteside is now getting in on the act.Round the outside goes Max Bird, but they both take it wide,and through goes Angus Whiteside on the inside.Doesn\'t affect the race order.Whiteside unlaps himself from those two cars. Will Aspin stays ahead. Max Bird tries to fight back.Down towards Duffus Dip.Ignore the car in the lead of this battle packbecause it\'s not on the lead lap.But then Aspin under attack.Max Bird tries to get up alongside himfor the inside line coming into McIntyre\'s.He is on the inside, and he is ahead now.Yeah. And behind him, there was a car completely off on the grassafter Turn 1.So I do think that section of the track is a little slippery,for whatever reason. Maybe it just hasn\'t quite dried out.But this was a great battle between Max Bird and Will Aspin.So, drivers, in one or two cases,setting personal best lap times last time around,suggesting that the road is drying still.The cars are getting lighter on fuel anyway.Down at the hairpin comes the recovering Max Bird,as over the line goes Harry Foster, being chased by Josh Malin.Josh up into seventh overall.A top six could be on - he\'s only half a second back.Yeah, it\'s not far away, is it?And still plenty of laps left in this race.You probably noticed through that Turn 1the drivers are not riding the big kerb on the inside.That\'s because it\'s a new kerb that\'s been put in, actually,to raise the level of the kerbback to the level it was before the track got resurfaced twice.But with the resurfacing twice,it had raised the track surface by 80mm,and the drivers found they were, in previous years,able to use that kerb.Now? Not so much.So, there going through is Oliver Wight in the silver car,down towards the hairpin.Jason Lockwood ahead of him.Jason for EXCELR8, and he is running in 12 at the moment,which is potentially going to be his best finish of the year.It\'s been a tough season for Jason,but he could be on for his best results thus far.However, Gus Burton has rather dominated this,and look at the gap that he\'s pulled.Adam Smalley, as we suggested earlier,might be thinking about banking points,but it doesn\'t look as though he\'s been able to do anything at allabout that JTR car ahead.No. In terms of pure speed,he hasn\'t been able to threaten Gus Burton at all, has he? Gus Burton has led every race.As he\'s now passing David Fairbrotherat the back of the field.But he\'s had this race under control.But unfortunately, he\'s got this investigation after the racefor a possible false start hanging over him.And if, if, if anything is found, and it\'s a false start,normally that\'s a ten-second penalty,and he\'s not ten seconds ahead,so there is gonna be a postscript to this story, that\'s for sure.Now, this is Henry Dawes, Ollie Jacksoncoming through the chicane.Dawes goes wide, goes through the gravel,goes over the grass, loses a place,gets it all sideways, but just about saves it by the end of the straight.Yeah, nearly lost it on the wet grass.Oh. Harry Foster.This is passing David Fairbrother again, further back.So, this is Smalley versus Matty Graham for second place.So, this gap has come r. \n\n Your task is to create long detailed paragraph-by-paragraph summary. Detailed paragraph-by-paragraph summary of the text above: \ No newline at end of file diff --git a/vllm/__init__.py b/vllm/__init__.py index 67a662ced9aff..e5cd1c2f3334b 100644 --- a/vllm/__init__.py +++ b/vllm/__init__.py @@ -8,7 +8,7 @@ from vllm.outputs import CompletionOutput, RequestOutput from vllm.sampling_params import SamplingParams -__version__ = "0.2.3" +__version__ = "0.2.6" __all__ = [ "LLM", diff --git a/vllm/config.py b/vllm/config.py index 6d00049b6ab78..9facbd6664075 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -50,6 +50,12 @@ class ModelConfig: output). If None, will be derived from the model. quantization: Quantization method that was used to quantize the model weights. If None, we assume the model weights are not quantized. + enforce_eager: Whether to enforce eager execution. If True, we will + disable CUDA graph and always execute the model in eager mode. + If False, we will use CUDA graph and eager execution in hybrid. + max_context_len_to_capture: Maximum context len covered by CUDA graphs. + When a sequence has context length larger than this, we fall back + to eager mode. """ def __init__( @@ -66,6 +72,8 @@ def __init__( tokenizer_revision: Optional[str] = None, max_model_len: Optional[int] = None, quantization: Optional[str] = None, + enforce_eager: bool = False, + max_context_len_to_capture: Optional[int] = None, ) -> None: self.model = model self.tokenizer = tokenizer @@ -77,6 +85,8 @@ def __init__( self.revision = revision self.tokenizer_revision = tokenizer_revision self.quantization = quantization + self.enforce_eager = enforce_eager + self.max_context_len_to_capture = max_context_len_to_capture if os.environ.get("VLLM_USE_MODELSCOPE", "False").lower() == "true": # download model from ModelScope hub, @@ -96,30 +106,34 @@ def __init__( self._verify_load_format() self._verify_tokenizer_mode() self._verify_quantization() + self._verify_cuda_graph() def _verify_load_format(self) -> None: load_format = self.load_format.lower() supported_load_format = [ "auto", "pt", "safetensors", "npcache", "dummy" ] - rocm_not_supported_load_format = ["safetensors"] + rocm_not_supported_load_format = [] if load_format not in supported_load_format: raise ValueError( f"Unknown load format: {self.load_format}. Must be one of " "'auto', 'pt', 'safetensors', 'npcache', or 'dummy'.") - if is_hip(): - if load_format in ["safetensors"]: - rocm_supported_load_format = [ - f for f in supported_load_format - if (f not in rocm_not_supported_load_format) - ] - raise ValueError( - f"load format \'{load_format}\' is not supported in ROCm. " - f"Supported load format are " - f"{rocm_supported_load_format}") - # Force ROCm to load from pt weights if nothing specific is set - if load_format == "auto": - load_format = "pt" + if is_hip() and load_format in rocm_not_supported_load_format: + rocm_supported_load_format = [ + f for f in supported_load_format + if (f not in rocm_not_supported_load_format) + ] + raise ValueError( + f"load format \'{load_format}\' is not supported in ROCm. " + f"Supported load format are " + f"{rocm_supported_load_format}") + + # TODO: Remove this check once HF updates the pt weights of Mixtral. + architectures = getattr(self.hf_config, "architectures", []) + if "MixtralForCausalLM" in architectures and load_format == "pt": + raise ValueError( + "Currently, the 'pt' format is not supported for Mixtral. " + "Please use the 'safetensors' format instead. ") self.load_format = load_format def _verify_tokenizer_mode(self) -> None: @@ -131,7 +145,7 @@ def _verify_tokenizer_mode(self) -> None: self.tokenizer_mode = tokenizer_mode def _verify_quantization(self) -> None: - supported_quantization = ["awq", "squeezellm"] + supported_quantization = ["awq", "gptq", "squeezellm"] rocm_not_supported_quantization = ["awq"] if self.quantization is not None: self.quantization = self.quantization.lower() @@ -163,6 +177,18 @@ def _verify_quantization(self) -> None: "optimized yet. The speed can be slower than " "non-quantized models.") + def _verify_cuda_graph(self) -> None: + if self.max_context_len_to_capture is None: + self.max_context_len_to_capture = self.max_model_len + self.max_context_len_to_capture = min(self.max_context_len_to_capture, + self.max_model_len) + if (self.quantization in ["gptq", "squeezellm"] + and not self.enforce_eager): + # Related issue: https://github.com/vllm-project/vllm/issues/2147 + logger.warning(f"{self.quantization} does not support CUDA graph " + "yet. Disabling CUDA graph.") + self.enforce_eager = True + def verify_with_parallel_config( self, parallel_config: "ParallelConfig", diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index 0dae1613690e5..07e20a74d929b 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -33,6 +33,8 @@ class EngineArgs: revision: Optional[str] = None tokenizer_revision: Optional[str] = None quantization: Optional[str] = None + enforce_eager: bool = False + max_context_len_to_capture: int = 8192 enable_lora: bool = False max_loras: int = 1 max_lora_rank: int = 16 @@ -160,11 +162,13 @@ def add_cli_args( type=int, default=EngineArgs.swap_space, help='CPU swap space size (GiB) per GPU') - parser.add_argument('--gpu-memory-utilization', - type=float, - default=EngineArgs.gpu_memory_utilization, - help='the percentage of GPU memory to be used for' - 'the model executor') + parser.add_argument( + '--gpu-memory-utilization', + type=float, + default=EngineArgs.gpu_memory_utilization, + help='the fraction of GPU memory to be used for ' + 'the model executor, which can range from 0 to 1.' + 'If unspecified, will use the default value of 0.9.') parser.add_argument('--max-num-batched-tokens', type=int, default=EngineArgs.max_num_batched_tokens, @@ -185,9 +189,25 @@ def add_cli_args( parser.add_argument('--quantization', '-q', type=str, - choices=['awq', 'squeezellm', None], + choices=['awq', 'gptq', 'squeezellm', None], default=None, - help='Method used to quantize the weights') + help='Method used to quantize the weights. If ' + 'None, we first check the `quantization_config` ' + 'attribute in the model config file. If that is ' + 'None, we assume the model weights are not ' + 'quantized and use `dtype` to determine the data ' + 'type of the weights.') + parser.add_argument('--enforce-eager', + action='store_true', + help='Always use eager-mode PyTorch. If False, ' + 'will use eager mode and CUDA graph in hybrid ' + 'for maximal performance and flexibility.') + parser.add_argument('--max-context-len-to-capture', + type=int, + default=EngineArgs.max_context_len_to_capture, + help='maximum context length covered by CUDA ' + 'graphs. When a sequence has context length ' + 'larger than this, we fall back to eager mode.') # LoRA related configs parser.add_argument('--enable-lora', action='store_true', @@ -209,13 +229,12 @@ def add_cli_args( default=EngineArgs.lora_dtype, choices=['auto', 'float16', 'bfloat16', 'float32'], help='data type for LoRA') - parser.add_argument( - '--max-cpu-loras', - type=int, - default=EngineArgs.max_cpu_loras, - help=('Maximum number of LoRAs to store in CPU memory. ' - 'Must be >= than max_num_seqs. ' - 'Defaults to max_num_seqs.')) + parser.add_argument('--max-cpu-loras', + type=int, + default=EngineArgs.max_cpu_loras, + help=('Maximum number of LoRAs to store in CPU memory. ' + 'Must be >= than max_num_seqs. ' + 'Defaults to max_num_seqs.')) return parser @classmethod @@ -235,7 +254,8 @@ def create_engine_configs( self.download_dir, self.load_format, self.dtype, self.seed, self.revision, self.tokenizer_revision, self.max_model_len, - self.quantization) + self.quantization, self.enforce_eager, + self.max_context_len_to_capture) cache_config = CacheConfig(self.block_size, self.gpu_memory_utilization, self.swap_space, diff --git a/vllm/engine/async_llm_engine.py b/vllm/engine/async_llm_engine.py index 876e85138a65a..6752ee144f2ad 100644 --- a/vllm/engine/async_llm_engine.py +++ b/vllm/engine/async_llm_engine.py @@ -2,7 +2,7 @@ import time from functools import partial from typing import (Any, Dict, Iterable, List, Optional, Set, Tuple, Type, - Union) + Union, AsyncIterator) from vllm.lora.request import LoRARequest from vllm.config import ModelConfig @@ -458,12 +458,12 @@ async def add_request( return stream async def generate( - self, - prompt: Optional[str], - sampling_params: SamplingParams, - request_id: str, - prompt_token_ids: Optional[List[int]] = None, - lora_request: Optional[LoRARequest] = None) -> RequestOutput: + self, + prompt: Optional[str], + sampling_params: SamplingParams, + request_id: str, + prompt_token_ids: Optional[List[int]] = None + ) -> AsyncIterator[RequestOutput]: """Generate outputs for a request. Generate outputs for a request. This method is a coroutine. It adds the diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index 3b538d47d6a12..33c1fe68bcb73 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -86,6 +86,7 @@ def __init__( f"load_format={model_config.load_format}, " f"tensor_parallel_size={parallel_config.tensor_parallel_size}, " f"quantization={model_config.quantization}, " + f"enforce_eager={model_config.enforce_eager}, " f"seed={model_config.seed})") # TODO(woosuk): Print more configs in debug mode. @@ -242,12 +243,23 @@ def _init_cache(self) -> None: raise ValueError("No available memory for the cache blocks. " "Try increasing `gpu_memory_utilization` when " "initializing the engine.") + max_seq_len = self.cache_config.block_size * num_gpu_blocks + if self.model_config.max_model_len > max_seq_len: + raise ValueError( + f"The model's max seq len ({self.model_config.max_model_len}) " + "is larger than the maximum number of tokens that can be " + f"stored in KV cache ({max_seq_len}). Try increasing " + "`gpu_memory_utilization` or decreasing `max_model_len` when " + "initializing the engine.") self.cache_config.num_gpu_blocks = num_gpu_blocks self.cache_config.num_cpu_blocks = num_cpu_blocks # Initialize the cache. self._run_workers("init_cache_engine", cache_config=self.cache_config) + # Warm up the model. This includes capturing the model into CUDA graph + # if enforce_eager is False. + self._run_workers("warm_up_model") @classmethod def from_engine_args(cls, engine_args: EngineArgs) -> "LLMEngine": @@ -725,9 +737,10 @@ def _check_stop(self, seq: Sequence, """Stop the finished sequences.""" for stop_str in sampling_params.stop: if seq.output_text.endswith(stop_str): - # Truncate the output text so that the stop string is - # not included in the output. - seq.output_text = seq.output_text[:-len(stop_str)] + if not sampling_params.include_stop_str_in_output: + # Truncate the output text so that the stop string is + # not included in the output. + seq.output_text = seq.output_text[:-len(stop_str)] seq.status = SequenceStatus.FINISHED_STOPPED return if seq.get_last_token_id() in sampling_params.stop_token_ids: diff --git a/vllm/engine/ray_utils.py b/vllm/engine/ray_utils.py index e101940f10369..f402da4c621dd 100644 --- a/vllm/engine/ray_utils.py +++ b/vllm/engine/ray_utils.py @@ -1,9 +1,8 @@ -import socket from typing import Optional, Tuple, TYPE_CHECKING from vllm.config import ParallelConfig from vllm.logger import init_logger -from vllm.utils import is_hip +from vllm.utils import get_open_port, is_hip logger = init_logger(__name__) @@ -43,12 +42,6 @@ def execute_method(self, method, *args, **kwargs): from ray.util.placement_group import PlacementGroup -def get_open_port(): - with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: - s.bind(("", 0)) - return s.getsockname()[1] - - def initialize_cluster( parallel_config: ParallelConfig, engine_use_ray: bool = False, diff --git a/vllm/entrypoints/api_server.py b/vllm/entrypoints/api_server.py index fb29837da8cf0..6910b3265dfd2 100644 --- a/vllm/entrypoints/api_server.py +++ b/vllm/entrypoints/api_server.py @@ -73,6 +73,8 @@ async def stream_results() -> AsyncGenerator[bytes, None]: parser = argparse.ArgumentParser() parser.add_argument("--host", type=str, default=None) parser.add_argument("--port", type=int, default=8000) + parser.add_argument("--ssl-keyfile", type=str, default=None) + parser.add_argument("--ssl-certfile", type=str, default=None) parser = AsyncEngineArgs.add_cli_args(parser) args = parser.parse_args() @@ -83,4 +85,6 @@ async def stream_results() -> AsyncGenerator[bytes, None]: host=args.host, port=args.port, log_level="debug", - timeout_keep_alive=TIMEOUT_KEEP_ALIVE) + timeout_keep_alive=TIMEOUT_KEEP_ALIVE, + ssl_keyfile=args.ssl_keyfile, + ssl_certfile=args.ssl_certfile) diff --git a/vllm/entrypoints/llm.py b/vllm/entrypoints/llm.py index 9061909d72c33..a335c2e8e2e25 100644 --- a/vllm/entrypoints/llm.py +++ b/vllm/entrypoints/llm.py @@ -39,8 +39,10 @@ class LLM: However, if the `torch_dtype` in the config is `float32`, we will use `float16` instead. quantization: The method used to quantize the model weights. Currently, - we support "awq". If None, we assume the model weights are not - quantized and use `dtype` to determine the data type of the weights. + we support "awq", "gptq" and "squeezellm". If None, we first check + the `quantization_config` attribute in the model config file. If + that is None, we assume the model weights are not quantized and use + `dtype` to determine the data type of the weights. revision: The specific model version to use. It can be a branch name, a tag name, or a commit id. tokenizer_revision: The specific tokenizer version to use. It can be a @@ -56,6 +58,12 @@ class LLM: when their `best_of` sampling parameters are larger than 1. If all requests will have `best_of=1`, you can safely set this to 0. Otherwise, too small values may cause out-of-memory (OOM) errors. + enforce_eager: Whether to enforce eager execution. If True, we will + disable CUDA graph and always execute the model in eager mode. + If False, we will use CUDA graph and eager execution in hybrid. + max_context_len_to_capture: Maximum context len covered by CUDA graphs. + When a sequence has context length larger than this, we fall back + to eager mode. """ def __init__( @@ -72,6 +80,8 @@ def __init__( seed: int = 0, gpu_memory_utilization: float = 0.9, swap_space: int = 4, + enforce_eager: bool = False, + max_context_len_to_capture: int = 8192, **kwargs, ) -> None: if "disable_log_stats" not in kwargs: @@ -89,6 +99,8 @@ def __init__( seed=seed, gpu_memory_utilization=gpu_memory_utilization, swap_space=swap_space, + enforce_eager=enforce_eager, + max_context_len_to_capture=max_context_len_to_capture, **kwargs, ) self.llm_engine = LLMEngine.from_engine_args(engine_args) diff --git a/vllm/entrypoints/openai/api_server.py b/vllm/entrypoints/openai/api_server.py index 7b94e1b52a5fd..be5f4190e633f 100644 --- a/vllm/entrypoints/openai/api_server.py +++ b/vllm/entrypoints/openai/api_server.py @@ -80,6 +80,14 @@ def parse_args(): default="assistant", help="The role name to return if " "`request.add_generation_prompt=true`.") + parser.add_argument("--ssl-keyfile", + type=str, + default=None, + help="The file path to the SSL key file") + parser.add_argument("--ssl-certfile", + type=str, + default=None, + help="The file path to the SSL cert file") parser = AsyncEngineArgs.add_cli_args(parser) return parser.parse_args() @@ -332,8 +340,7 @@ async def completion_stream_generator() -> AsyncGenerator[str, None]: # Send token-by-token response for each request.n delta_text = output.text[len(previous_texts[i]):] previous_texts[i] = output.text - completion_tokens = len(output.token_ids) - previous_num_tokens[i] = completion_tokens + previous_num_tokens[i] = len(output.token_ids) choice_data = ChatCompletionResponseStreamChoice( index=i, delta=DeltaMessage(content=delta_text), @@ -351,8 +358,8 @@ async def completion_stream_generator() -> AsyncGenerator[str, None]: prompt_tokens = len(res.prompt_token_ids) final_usage = UsageInfo( prompt_tokens=prompt_tokens, - completion_tokens=completion_tokens, - total_tokens=prompt_tokens + completion_tokens, + completion_tokens=previous_num_tokens[i], + total_tokens=prompt_tokens + previous_num_tokens[i], ) choice_data = ChatCompletionResponseStreamChoice( index=i, delta=[], finish_reason=output.finish_reason) @@ -568,17 +575,22 @@ async def completion_stream_generator() -> AsyncGenerator[str, None]: i = output.index delta_text = output.text[len(previous_texts[i]):] token_ids = output.token_ids[previous_num_tokens[i]:] - top_logprobs = output.logprobs[previous_num_tokens[i]:] + if request.logprobs is not None: + top_logprobs = output.logprobs[previous_num_tokens[i]:] + else: + top_logprobs = None offsets = len(previous_texts[i]) if request.echo and not has_echoed[i]: if not echo_without_generation: delta_text = res.prompt + delta_text token_ids = res.prompt_token_ids + token_ids - top_logprobs = res.prompt_logprobs + top_logprobs - else: + if top_logprobs: + top_logprobs = res.prompt_logprobs + top_logprobs + else: # only just return the prompt delta_text = res.prompt token_ids = res.prompt_token_ids - top_logprobs = res.prompt_logprobs + if top_logprobs: + top_logprobs = res.prompt_logprobs has_echoed[i] = True if request.logprobs is not None: logprobs = create_logprobs( @@ -740,4 +752,6 @@ async def fake_stream_generator() -> AsyncGenerator[str, None]: host=args.host, port=args.port, log_level="info", - timeout_keep_alive=TIMEOUT_KEEP_ALIVE) + timeout_keep_alive=TIMEOUT_KEEP_ALIVE, + ssl_keyfile=args.ssl_keyfile, + ssl_certfile=args.ssl_certfile) diff --git a/vllm/model_executor/input_metadata.py b/vllm/model_executor/input_metadata.py index e4ddf08cd9a03..af6f4921856e1 100644 --- a/vllm/model_executor/input_metadata.py +++ b/vllm/model_executor/input_metadata.py @@ -21,12 +21,14 @@ def __init__( max_context_len: Optional[int], context_lens: Optional[torch.Tensor], block_tables: Optional[torch.Tensor], + use_cuda_graph: bool, ) -> None: self.prompt_lens = prompt_lens self.max_context_len = max_context_len self.slot_mapping = slot_mapping self.context_lens = context_lens self.block_tables = block_tables + self.use_cuda_graph = use_cuda_graph self.is_prompt = len(prompt_lens) > 0 # Set during the execution of the first attention op. @@ -39,4 +41,5 @@ def __repr__(self) -> str: f"max_context_len={self.max_context_len}, " f"slot_mapping={self.slot_mapping}, " f"context_lens={self.context_lens}, " - f"block_tables={self.block_tables})") + f"block_tables={self.block_tables}, " + f"use_cuda_graph={self.use_cuda_graph})") diff --git a/vllm/model_executor/layers/attention.py b/vllm/model_executor/layers/attention.py index 3f4ecb5d2ae7b..6482875d1c55b 100644 --- a/vllm/model_executor/layers/attention.py +++ b/vllm/model_executor/layers/attention.py @@ -24,13 +24,10 @@ class PagedAttention(nn.Module): can either contain prompt tokens or generation tokens. The class does the following: - 1. Wait for the cache operations (e.g., swap, copy) to finish. The cache - operations are issued by the cache engine before executing the forward - pass of the model, and they are executed asynchronously. - 2. Reshape and store the input key and value tensors in the KV cache. - 3. Perform (multi-head/multi-query/grouped-query) attention using either + 1. Reshape and store the input key and value tensors in the KV cache. + 2. Perform (multi-head/multi-query/grouped-query) attention using either xformers or the PagedAttention custom op. - 4. Return the output tensor. + 3. Return the output tensor. """ def __init__( @@ -54,9 +51,6 @@ def __init__( assert self.num_heads % self.num_kv_heads == 0 self.num_queries_per_kv = self.num_heads // self.num_kv_heads - self.head_mapping = torch.repeat_interleave( - torch.arange(self.num_kv_heads, dtype=torch.int32, device="cuda"), - self.num_queries_per_kv) if self.head_size not in _SUPPORTED_HEAD_SIZES: raise ValueError(f"head_size ({self.head_size}) is not supported. " @@ -70,20 +64,18 @@ def forward( key_cache: Optional[torch.Tensor], value_cache: Optional[torch.Tensor], input_metadata: InputMetadata, - cache_event: Optional[torch.cuda.Event], ) -> torch.Tensor: """PagedAttention forward pass. Args: query: shape = [batch_size, seq_len, num_heads * head_size] key: shape = [batch_size, seq_len, num_kv_heads * head_size] - value: shape = [batch_size, num_kv_heads * head_size] + value: shape = [batch_size, seq_len, num_kv_heads * head_size] key_cache: shape = [num_blocks, num_kv_heads, head_size/x, block_size, x] value_cache: shape = [num_blocks, num_kv_heads, head_size, block_size] input_metadata: metadata for the inputs. - cache_event: event to wait for the cache operations to finish. Returns: shape = [batch_size, seq_len, num_heads * head_size] """ @@ -92,10 +84,6 @@ def forward( query = query.view(-1, self.num_heads, self.head_size) key = key.view(-1, self.num_kv_heads, self.head_size) value = value.view(-1, self.num_kv_heads, self.head_size) - slot_mapping = input_metadata.slot_mapping.flatten() - - if cache_event is not None: - cache_event.wait() # Reshape the keys and values and store them in the cache. # If key_cache and value_cache are not provided, the new key and value @@ -107,7 +95,7 @@ def forward( value, key_cache, value_cache, - slot_mapping, + input_metadata.slot_mapping.flatten(), ) if input_metadata.is_prompt: @@ -141,7 +129,8 @@ def forward( input_metadata.attn_bias = attn_bias else: input_metadata.attn_bias = _make_alibi_bias( - self.alibi_slopes, batch_size, seq_len, query.dtype) + self.alibi_slopes, self.num_kv_heads, batch_size, + seq_len, query.dtype) # TODO(woosuk): Too many view operations. Let's try to reduce them # in the future for code readability. @@ -167,15 +156,20 @@ def forward( output = out.view_as(query) else: # Decoding run. - output = _paged_attention( - query, - key_cache, - value_cache, - input_metadata, - self.head_mapping, - self.scale, - self.alibi_slopes, - ) + if key_cache is not None and value_cache is not None: + output = _paged_attention( + query, + key_cache, + value_cache, + input_metadata, + self.num_kv_heads, + self.scale, + self.alibi_slopes, + ) + else: + # This happens during the initial memory profiling run for + # CUDA graphs. + output = torch.zeros_like(query) # Reshape the output tensor. return output.view(batch_size, seq_len, hidden_size) @@ -183,31 +177,34 @@ def forward( def _make_alibi_bias( alibi_slopes: torch.Tensor, + num_kv_heads: int, batch_size: int, seq_len: int, dtype: torch.dtype, ) -> LowerTriangularMaskWithTensorBias: - bias = torch.arange(seq_len, dtype=dtype) + bias = torch.arange(seq_len, dtype=dtype, device="cuda") # NOTE(zhuohan): HF uses # `bias = bias[None, :].repeat(prompt_len, 1)` # here. We find that both biases give the same results, but # the bias below more accurately follows the original ALiBi # paper. bias = bias[None, :] - bias[:, None] - bias = bias.to(alibi_slopes.device) # When using custom attention bias, xformers requires the bias to # be sliced from a tensor whose length is a multiple of 8. padded_len = (seq_len + 7) // 8 * 8 + num_heads = alibi_slopes.shape[0] bias = torch.empty( batch_size, - alibi_slopes.shape[0], + num_heads, seq_len, padded_len, device=alibi_slopes.device, dtype=dtype, )[:, :, :, :seq_len].copy_(bias) bias.mul_(alibi_slopes[:, None, None]) + if num_heads != num_kv_heads: + bias = bias.unflatten(1, (num_kv_heads, num_heads // num_kv_heads)) attn_bias = LowerTriangularMaskWithTensorBias(bias) return attn_bias @@ -217,7 +214,7 @@ def _paged_attention( key_cache: torch.Tensor, value_cache: torch.Tensor, input_metadata: InputMetadata, - head_mapping: torch.Tensor, + num_kv_heads: int, scale: float, alibi_slopes: Optional[torch.Tensor], ) -> torch.Tensor: @@ -244,7 +241,7 @@ def _paged_attention( query, key_cache, value_cache, - head_mapping, + num_kv_heads, scale, input_metadata.block_tables, input_metadata.context_lens, @@ -274,7 +271,7 @@ def _paged_attention( query, key_cache, value_cache, - head_mapping, + num_kv_heads, scale, input_metadata.block_tables, input_metadata.context_lens, diff --git a/vllm/model_executor/layers/linear.py b/vllm/model_executor/layers/linear.py index 810efb67df8d5..5190de65d7956 100644 --- a/vllm/model_executor/layers/linear.py +++ b/vllm/model_executor/layers/linear.py @@ -1,5 +1,5 @@ from abc import ABC, abstractmethod -from typing import Dict, List, Optional +from typing import Any, Dict, List, Optional import torch import torch.nn.functional as F @@ -21,8 +21,10 @@ class LinearMethodBase(ABC): """Base class for different (maybe quantized) linear methods.""" @abstractmethod - def create_weights(self, input_size: int, output_size: int, - params_dtype: torch.dtype) -> Dict[str, torch.Tensor]: + def create_weights(self, input_size_per_partition: int, + output_size_per_partition: int, input_size: int, + output_size: int, + params_dtype: torch.dtype) -> Dict[str, Any]: """Create weights for a linear layer.""" raise NotImplementedError @@ -46,10 +48,12 @@ class UnquantizedLinearMethod(LinearMethodBase): def __init__(self, separate_bias_add: bool = False): self.separate_bias_add = separate_bias_add - def create_weights(self, input_size: int, output_size: int, - params_dtype: torch.dtype) -> Dict[str, torch.Tensor]: - weight = Parameter(torch.empty(output_size, - input_size, + def create_weights(self, input_size_per_partition: int, + output_size_per_partition: int, input_size: int, + output_size: int, + params_dtype: torch.dtype) -> Dict[str, Any]: + weight = Parameter(torch.empty(output_size_per_partition, + input_size_per_partition, device=torch.cuda.current_device(), dtype=params_dtype), requires_grad=False) @@ -102,9 +106,11 @@ def __init__( linear_method = UnquantizedLinearMethod() self.linear_method = linear_method self.linear_weights = self.linear_method.create_weights( - self.input_size, self.output_size, self.params_dtype) + self.input_size, self.output_size, self.input_size, + self.output_size, self.params_dtype) for name, weight in self.linear_weights.items(): - self.register_parameter(name, weight) + if isinstance(weight, torch.Tensor): + self.register_parameter(name, weight) if bias: self.bias = Parameter( torch.empty(self.output_size, @@ -168,10 +174,12 @@ def __init__( linear_method = UnquantizedLinearMethod() self.linear_method = linear_method self.linear_weights = self.linear_method.create_weights( - self.input_size, self.output_size_per_partition, self.params_dtype) + self.input_size, self.output_size_per_partition, self.input_size, + self.output_size, self.params_dtype) for name, weight in self.linear_weights.items(): - self.register_parameter(name, weight) - set_weight_attrs(weight, {"weight_loader": self.weight_loader}) + if isinstance(weight, torch.Tensor): + self.register_parameter(name, weight) + set_weight_attrs(weight, {"weight_loader": self.weight_loader}) if bias: self.bias = Parameter( torch.empty(self.output_size_per_partition, @@ -295,10 +303,12 @@ def weight_loader(self, loaded_weight = loaded_weight.narrow(output_dim, start_idx, shard_size) else: - logger.warning( - "Loading a weight without `output_dim` attribute in " - "MergedColumnParallelLinear, assume the weight is " - "the same for all partitions.") + ignore_warning = getattr(param, "ignore_warning", False) + if not ignore_warning: + logger.warning( + "Loading a weight without `output_dim` attribute in " + "MergedColumnParallelLinear, assume the weight is " + "the same for all partitions.") assert param_data.shape == loaded_weight.shape param_data.copy_(loaded_weight) @@ -418,10 +428,12 @@ def weight_loader(self, loaded_weight = loaded_weight.narrow(output_dim, start_idx, shard_size) else: - logger.warning( - "Loading a weight without `output_dim` attribute in " - "QKVParallelLinear, assume the weight is the same " - "for all partitions.") + ignore_warning = getattr(param, "ignore_warning", False) + if not ignore_warning: + logger.warning( + "Loading a weight without `output_dim` attribute in " + "QKVParallelLinear, assume the weight is the same " + "for all partitions.") assert param_data.shape == loaded_weight.shape param_data.copy_(loaded_weight) @@ -481,10 +493,12 @@ def __init__( linear_method = UnquantizedLinearMethod() self.linear_method = linear_method self.linear_weights = self.linear_method.create_weights( - self.input_size_per_partition, self.output_size, self.params_dtype) + self.input_size_per_partition, self.output_size, self.input_size, + self.output_size, self.params_dtype) for name, weight in self.linear_weights.items(): - self.register_parameter(name, weight) - set_weight_attrs(weight, {"weight_loader": self.weight_loader}) + if isinstance(weight, torch.Tensor): + self.register_parameter(name, weight) + set_weight_attrs(weight, {"weight_loader": self.weight_loader}) if not reduce_results and (bias and not skip_bias_add): raise ValueError("When not reduce the results, adding bias to the " diff --git a/vllm/model_executor/layers/quantization/__init__.py b/vllm/model_executor/layers/quantization/__init__.py index 3d937ba64f9fa..b3449eaff0e35 100644 --- a/vllm/model_executor/layers/quantization/__init__.py +++ b/vllm/model_executor/layers/quantization/__init__.py @@ -1,11 +1,13 @@ from typing import Type +from vllm.model_executor.layers.quantization.base_config import QuantizationConfig from vllm.model_executor.layers.quantization.awq import AWQConfig +from vllm.model_executor.layers.quantization.gptq import GPTQConfig from vllm.model_executor.layers.quantization.squeezellm import SqueezeLLMConfig -from vllm.model_executor.layers.quantization.base_config import QuantizationConfig _QUANTIZATION_CONFIG_REGISTRY = { "awq": AWQConfig, + "gptq": GPTQConfig, "squeezellm": SqueezeLLMConfig, } diff --git a/vllm/model_executor/layers/quantization/awq.py b/vllm/model_executor/layers/quantization/awq.py index 95d419e64f049..831576b1d7cd7 100644 --- a/vllm/model_executor/layers/quantization/awq.py +++ b/vllm/model_executor/layers/quantization/awq.py @@ -77,14 +77,16 @@ class AWQLinearMethod(LinearMethodBase): def __init__(self, quant_config: AWQConfig): self.quant_config = quant_config - def create_weights(self, input_size: int, output_size: int, - params_dtype: torch.dtype) -> Dict[str, torch.Tensor]: - if input_size % self.quant_config.group_size != 0: + def create_weights(self, input_size_per_partition: int, + output_size_per_partition: int, input_size: int, + output_size: int, + params_dtype: torch.dtype) -> Dict[str, Any]: + if input_size_per_partition % self.quant_config.group_size != 0: raise ValueError( "The input size is not aligned with the quantized " "weight shape. This can be caused by too large " "tensor parallel size.") - if output_size % self.quant_config.pack_factor != 0: + if output_size_per_partition % self.quant_config.pack_factor != 0: raise ValueError( "The output size is not aligned with the quantized " "weight shape. This can be caused by too large " @@ -92,8 +94,8 @@ def create_weights(self, input_size: int, output_size: int, qweight = Parameter( torch.empty( - input_size, - output_size // self.quant_config.pack_factor, + input_size_per_partition, + output_size_per_partition // self.quant_config.pack_factor, device="cuda", dtype=torch.int32, ), @@ -108,8 +110,8 @@ def create_weights(self, input_size: int, output_size: int, }) qzeros = Parameter( torch.empty( - input_size // self.quant_config.group_size, - output_size // self.quant_config.pack_factor, + input_size_per_partition // self.quant_config.group_size, + output_size_per_partition // self.quant_config.pack_factor, device="cuda", dtype=torch.int32, ), @@ -124,8 +126,8 @@ def create_weights(self, input_size: int, output_size: int, }) scales = Parameter( torch.empty( - input_size // self.quant_config.group_size, - output_size, + input_size_per_partition // self.quant_config.group_size, + output_size_per_partition, device="cuda", dtype=params_dtype, ), @@ -142,7 +144,7 @@ def create_weights(self, input_size: int, output_size: int, } def apply_weights(self, - weights: Dict[str, torch.Tensor], + weights: Dict[str, Any], x: torch.Tensor, bias: Optional[torch.Tensor] = None) -> torch.Tensor: qweight = weights["qweight"] diff --git a/vllm/model_executor/layers/quantization/gptq.py b/vllm/model_executor/layers/quantization/gptq.py new file mode 100644 index 0000000000000..8fe96e7ddb98d --- /dev/null +++ b/vllm/model_executor/layers/quantization/gptq.py @@ -0,0 +1,215 @@ +import enum +from enum import Enum +from typing import Any, Dict, List, Optional + +import torch +from torch.nn.parameter import Parameter + +from vllm._C import ops +from vllm.model_executor.layers.linear import (LinearMethodBase, + set_weight_attrs) +from vllm.model_executor.layers.quantization.base_config import ( + QuantizationConfig) + + +class GPTQConfig(QuantizationConfig): + """Config class for GPTQ. + + Reference: https://arxiv.org/abs/2210.17323 + """ + + def __init__( + self, + weight_bits: int, + group_size: int, + desc_act: bool, + ) -> None: + self.weight_bits = weight_bits + self.group_size = group_size + self.desc_act = desc_act + self.pack_factor = 32 // self.weight_bits + # exllama kernel v1 only supports 4 bit + if self.weight_bits != 4: + raise ValueError( + "Currently, only 4-bit weight quantization is supported for " + f"GPTQ, but got {self.weight_bits} bits.") + + def __repr__(self) -> str: + return (f"GPTQConfig(weight_bits={self.weight_bits}, " + f"group_size={self.group_size}, " + f"desc_act={self.desc_act})") + + @classmethod + def get_name(cls) -> str: + return "gptq" + + @classmethod + def get_supported_act_dtypes(cls) -> List[torch.dtype]: + return [torch.half] + + @classmethod + # Need to figure it out + def get_min_capability(cls) -> int: + return 60 + + @classmethod + def get_config_filenames(cls) -> List[str]: + return ["quantize_config.json"] + + @classmethod + def from_config(cls, config: Dict[str, Any]) -> "GPTQConfig": + weight_bits = cls.get_from_keys(config, ["bits"]) + group_size = cls.get_from_keys(config, ["group_size"]) + desc_act = cls.get_from_keys(config, ["desc_act"]) + return cls(weight_bits, group_size, desc_act) + + def get_linear_method(self) -> "GPTQLinearMethod": + return GPTQLinearMethod(self) + + def get_scaled_act_names(self) -> List[str]: + return [] + + +class ExllamaState(Enum): + + UNUSED = enum.auto() + UNINITIALIZED = enum.auto() + READY = enum.auto() + + +class GPTQLinearMethod(LinearMethodBase): + """Linear method for GPTQ. + + Args: + quant_config: The GPTQ quantization config. + """ + + def __init__(self, quant_config: GPTQConfig): + self.quant_config = quant_config + + def create_weights( + self, + input_size_per_partition: int, + output_size_per_partition: int, + input_size: int, + output_size: int, + params_dtype: torch.dtype, + ) -> Dict[str, Any]: + del output_size # Unused. + if input_size_per_partition % self.quant_config.group_size != 0: + raise ValueError( + "The input size is not aligned with the quantized " + "weight shape. This can be caused by too large " + "tensor parallel size.") + if output_size_per_partition % self.quant_config.pack_factor != 0: + raise ValueError( + "The output size is not aligned with the quantized " + "weight shape. This can be caused by too large " + "tensor parallel size.") + + if self.quant_config.group_size != -1: + group_size = self.quant_config.group_size + else: + group_size = input_size + exllama_state = ExllamaState.UNINITIALIZED + scale_and_zero_size = input_size // group_size + scale_and_zero_input_dim = None + if input_size != input_size_per_partition and self.quant_config.group_size != -1: + # For act-order models, we cannot use Exllama for row parallel layer + if self.quant_config.desc_act: + exllama_state = ExllamaState.UNUSED + else: + # we need to partition qzeros and scales for exllama kernel + scale_and_zero_size = input_size_per_partition // group_size + scale_and_zero_input_dim = 0 + + qweight = Parameter( + torch.empty( + input_size_per_partition // self.quant_config.pack_factor, + output_size_per_partition, + device="cuda", + dtype=torch.int32, + ), + requires_grad=False, + ) + set_weight_attrs( + qweight, { + "input_dim": 0, + "output_dim": 1, + "packed_dim": 0, + "pack_factor": self.quant_config.pack_factor, + }) + g_idx = Parameter( + torch.tensor( + [ + i // self.quant_config.group_size + for i in range(input_size_per_partition) + ], + device="cuda", + dtype=torch.int32, + ), + requires_grad=False, + ) + # Ignore warning from fused linear layers such as QKVParallelLinear. + set_weight_attrs(g_idx, {"input_dim": 0, "ignore_warning": True}) + qzeros = Parameter( + torch.empty( + scale_and_zero_size, + output_size_per_partition // self.quant_config.pack_factor, + device="cuda", + dtype=torch.int32, + ), + requires_grad=False, + ) + set_weight_attrs( + qzeros, { + "input_dim": scale_and_zero_input_dim, + "output_dim": 1, + "packed_dim": 1, + "pack_factor": self.quant_config.pack_factor, + }) + scales = Parameter( + torch.empty( + scale_and_zero_size, + output_size_per_partition, + device="cuda", + dtype=params_dtype, + ), + requires_grad=False, + ) + set_weight_attrs(scales, { + "input_dim": scale_and_zero_input_dim, + "output_dim": 1, + }) + return { + "qweight": qweight, + "g_idx": g_idx, + "qzeros": qzeros, + "scales": scales, + "exllama_state": exllama_state, + } + + def apply_weights(self, + weights: Dict[str, Any], + x: torch.Tensor, + bias: Optional[torch.Tensor] = None) -> torch.Tensor: + qweight = weights["qweight"] + out_shape = x.shape[:-1] + (qweight.shape[-1], ) + reshaped_x = x.reshape(-1, x.shape[-1]) + # exllama needs to shuffle the weight after the weight is loaded + # here we do the shuffle on first forward pass + if weights["exllama_state"] == ExllamaState.UNINITIALIZED: + if self.quant_config.desc_act: + weights["g_idx"] = torch.argsort(weights["g_idx"]).to( + torch.int) + else: + weights["g_idx"] = torch.empty((1, 1), device="meta") + weights["exllama_state"] = ExllamaState.READY + ops.gptq_shuffle(weights["qweight"], weights["g_idx"]) + output = ops.gptq_gemm(reshaped_x, weights["qweight"], + weights["qzeros"], weights["scales"], + weights["g_idx"], + weights["exllama_state"] == ExllamaState.READY) + if bias is not None: + output = output + bias + return output.reshape(out_shape) diff --git a/vllm/model_executor/layers/quantization/squeezellm.py b/vllm/model_executor/layers/quantization/squeezellm.py index f2f9cac68fd81..1932bd145076b 100644 --- a/vllm/model_executor/layers/quantization/squeezellm.py +++ b/vllm/model_executor/layers/quantization/squeezellm.py @@ -67,17 +67,19 @@ class SqueezeLLMLinearMethod(LinearMethodBase): def __init__(self, quant_config: SqueezeLLMConfig): self.quant_config = quant_config - def create_weights(self, input_size: int, output_size: int, - params_dtype: torch.dtype) -> Dict[str, torch.Tensor]: - if input_size % self.quant_config.pack_factor != 0: + def create_weights(self, input_size_per_partition: int, + output_size_per_partition: int, input_size: int, + output_size: int, + params_dtype: torch.dtype) -> Dict[str, Any]: + if input_size_per_partition % self.quant_config.pack_factor != 0: raise ValueError( "The input size is not aligned with the quantized " "weight shape. This can be caused by too large " "tensor parallel size.") qweight = Parameter( torch.empty( - input_size // self.quant_config.pack_factor, - output_size, + input_size_per_partition // self.quant_config.pack_factor, + output_size_per_partition, device="cuda", dtype=torch.int32, ), @@ -108,7 +110,7 @@ def create_weights(self, input_size: int, output_size: int, } def apply_weights(self, - weights: Dict[str, torch.Tensor], + weights: Dict[str, Any], x: torch.Tensor, bias: Optional[torch.Tensor] = None) -> torch.Tensor: qweight = weights["qweight"] diff --git a/vllm/model_executor/layers/sampler.py b/vllm/model_executor/layers/sampler.py index 538b766db00c5..72cebdce8090d 100644 --- a/vllm/model_executor/layers/sampler.py +++ b/vllm/model_executor/layers/sampler.py @@ -6,13 +6,11 @@ from vllm.model_executor.parallel_utils.communication_op import ( tensor_model_parallel_all_gather) -from vllm.model_executor.sampling_metadata import SamplingMetadata +from vllm.model_executor.sampling_metadata import SamplingMetadata, SamplingTensors from vllm.sampling_params import SamplingParams, SamplingType from vllm.sequence import (PromptLogprobs, SampleLogprobs, SamplerOutput, SequenceData, SequenceGroupOutput, SequenceOutput) -_SAMPLING_EPS = 1e-5 - class Sampler(nn.Module): """Samples the next tokens from the model's outputs. @@ -61,40 +59,34 @@ def forward( # Get the logits for the next tokens. logits = self._get_logits(hidden_states, embedding, embedding_bias) + _, vocab_size = logits.shape + # Apply logits processors (if any). logits = _apply_logits_processors(logits, sampling_metadata) + + # Prepare sampling tensors with pinned memory to avoid blocking. + (sampling_tensors, do_penalties, do_top_p_top_k, + do_min_p) = SamplingTensors.from_sampling_metadata( + sampling_metadata, vocab_size, logits.device, logits.dtype) + # Apply presence and frequency penalties. - presence_penalties, frequency_penalties, repetition_penalties = ( - _get_penalties(sampling_metadata)) - assert len(presence_penalties) == logits.shape[0] - assert len(frequency_penalties) == logits.shape[0] - assert len(repetition_penalties) == logits.shape[0] - logits = _apply_penalties(logits, sampling_metadata, - presence_penalties, frequency_penalties, - repetition_penalties) + if do_penalties: + logits = _apply_penalties(logits, sampling_tensors.prompt_tokens, + sampling_tensors.output_tokens, + sampling_tensors.presence_penalties, + sampling_tensors.frequency_penalties, + sampling_tensors.repetition_penalties) # Apply temperature scaling. - temperatures = _get_temperatures(sampling_metadata) - assert len(temperatures) == logits.shape[0] - if any(t != 1.0 for t in temperatures): - t = torch.tensor(temperatures, - dtype=logits.dtype, - device=logits.device) - # Use in-place division to avoid creating a new tensor. - logits.div_(t.unsqueeze(dim=1)) - - # Apply top-p and top-k truncation. - top_ps, top_ks, min_ps = _get_top_p_top_k_min_p( - sampling_metadata, self.vocab_size) - assert len(top_ps) == len(top_ks) == logits.shape[0] - do_top_p = any(p < 1.0 - _SAMPLING_EPS for p in top_ps) - do_top_k = any(k != self.vocab_size for k in top_ks) - if do_top_p or do_top_k: - logits = _apply_top_p_top_k(logits, top_ps, top_ks) - - do_min_p = any(mp > _SAMPLING_EPS for mp in min_ps) + # Use in-place division to avoid creating a new tensor. + logits.div_(sampling_tensors.temperatures.unsqueeze_(dim=1)) + + if do_top_p_top_k: + logits = _apply_top_p_top_k(logits, sampling_tensors.top_ps, + sampling_tensors.top_ks) + if do_min_p: - logits = _apply_min_p(logits, min_ps) + logits = _apply_min_p(logits, sampling_tensors.min_ps) # We use float32 for probabilities and log probabilities. # Compute the probabilities. @@ -121,32 +113,6 @@ def _prune_hidden_states( sampling_metadata.selected_token_indices) -def _get_penalties( - sampling_metadata: SamplingMetadata -) -> Tuple[List[float], List[float], List[float]]: - # Collect the presence and frequency penalties. - presence_penalties: List[float] = [] - frequency_penalties: List[float] = [] - repetition_penalties: List[float] = [] - for i, seq_group in enumerate(sampling_metadata.seq_groups): - seq_ids, sampling_params = seq_group - p = sampling_params.presence_penalty - f = sampling_params.frequency_penalty - r = sampling_params.repetition_penalty - if (i < sampling_metadata.num_prompts - and sampling_params.prompt_logprobs is not None): - # NOTE: We do not apply presence and frequency penalties for the - # prompt token positions where we don't sample new tokens. - prompt_len = sampling_metadata.prompt_lens[i] - presence_penalties += [0] * (prompt_len - 1) - frequency_penalties += [0] * (prompt_len - 1) - repetition_penalties += [1] * (prompt_len - 1) - presence_penalties += [p] * len(seq_ids) - frequency_penalties += [f] * len(seq_ids) - repetition_penalties += [r] * len(seq_ids) - return presence_penalties, frequency_penalties, repetition_penalties - - def _get_prompt_and_output_tokens( sampling_metadata: SamplingMetadata, ) -> Tuple[List[List[int]], List[List[int]]]: @@ -169,25 +135,16 @@ def _get_prompt_and_output_tokens( def _get_bin_counts_and_mask( - logits: torch.Tensor, - tokens: List[List[int]], + tokens: torch.Tensor, vocab_size: int, num_seqs: int, ) -> Tuple[torch.Tensor, torch.Tensor]: - max_len = max(len(tokens) for tokens in tokens) - padded_tokens = [ - tokens + [vocab_size] * (max_len - len(tokens)) for tokens in tokens - ] - tokens_tensor = torch.tensor(padded_tokens, - dtype=torch.long, - device=logits.device) - # Compute the bin counts for the tokens. # vocab_size + 1 for padding. bin_counts = torch.zeros((num_seqs, vocab_size + 1), dtype=torch.long, - device=logits.device) - bin_counts.scatter_add_(1, tokens_tensor, torch.ones_like(tokens_tensor)) + device=tokens.device) + bin_counts.scatter_add_(1, tokens, torch.ones_like(tokens)) bin_counts = bin_counts[:, :vocab_size] mask = bin_counts > 0 @@ -218,45 +175,16 @@ def _apply_logits_processors( return logits -def _apply_penalties( - logits: torch.Tensor, - sampling_metadata: SamplingMetadata, - presence_penalties: List[float], - frequency_penalties: List[float], - repetition_penalties: List[float], -) -> torch.Tensor: +def _apply_penalties(logits: torch.Tensor, prompt_tokens_tensor: torch.Tensor, + output_tokens_tensor: torch.Tensor, + presence_penalties: torch.Tensor, + frequency_penalties: torch.Tensor, + repetition_penalties: torch.Tensor) -> torch.Tensor: num_seqs, vocab_size = logits.shape - for i in range(num_seqs): - p = presence_penalties[i] - f = frequency_penalties[i] - r = repetition_penalties[i] - if abs(p) < _SAMPLING_EPS and abs(f) < _SAMPLING_EPS and abs( - r - 1.0) < _SAMPLING_EPS: - continue - break - else: - # Return early if all sequences have zero penalties. - return logits - - prompt_tokens, output_tokens = ( - _get_prompt_and_output_tokens(sampling_metadata)) - assert len(prompt_tokens) == logits.shape[0] - assert len(output_tokens) == logits.shape[0] - - prompt_bin_counts, prompt_mask = _get_bin_counts_and_mask( - logits, prompt_tokens, vocab_size, num_seqs) + _, prompt_mask = _get_bin_counts_and_mask(prompt_tokens_tensor, vocab_size, + num_seqs) output_bin_counts, output_mask = _get_bin_counts_and_mask( - logits, output_tokens, vocab_size, num_seqs) - - repetition_penalties = torch.tensor(repetition_penalties, - dtype=logits.dtype, - device=logits.device) - frequency_penalties = torch.tensor(frequency_penalties, - dtype=logits.dtype, - device=logits.device) - presence_penalties = torch.tensor(presence_penalties, - dtype=logits.dtype, - device=logits.device) + output_tokens_tensor, vocab_size, num_seqs) repetition_penalties = repetition_penalties[:, None].repeat(1, vocab_size) repetition_penalties[~(prompt_mask | output_mask)] = 1.0 @@ -265,109 +193,65 @@ def _apply_penalties( # We follow the definition in OpenAI API. # Refer to https://platform.openai.com/docs/api-reference/parameter-details - logits -= frequency_penalties.unsqueeze(dim=1) * output_bin_counts - logits -= presence_penalties.unsqueeze(dim=1) * output_mask + logits -= frequency_penalties.unsqueeze_(dim=1) * output_bin_counts + logits -= presence_penalties.unsqueeze_(dim=1) * output_mask return logits -def _get_temperatures(sampling_metadata: SamplingMetadata) -> List[float]: - # Collect the temperatures for the logits. - temperatures: List[float] = [] - for i, seq_group in enumerate(sampling_metadata.seq_groups): - seq_ids, sampling_params = seq_group - temperature = sampling_params.temperature - if temperature < _SAMPLING_EPS: - # NOTE: Zero temperature means deterministic sampling - # (i.e., greedy sampling or beam search). - # Set the temperature to 1 to avoid division by zero. - temperature = 1.0 - if (i < sampling_metadata.num_prompts - and sampling_params.prompt_logprobs is not None): - prompt_len = sampling_metadata.prompt_lens[i] - temperatures += [temperature] * (prompt_len - 1) - temperatures += [temperature] * len(seq_ids) - return temperatures - - -def _get_top_p_top_k_min_p( - sampling_metadata: SamplingMetadata, - vocab_size: int, -) -> Tuple[List[float], List[int], List[float]]: - top_ps: List[float] = [] - top_ks: List[int] = [] - min_ps: List[float] = [] - for i, seq_group in enumerate(sampling_metadata.seq_groups): - seq_ids, sampling_params = seq_group - top_p = sampling_params.top_p - min_p = sampling_params.min_p - # k should not be greater than the vocab size. - top_k = min(sampling_params.top_k, vocab_size) - # k=-1 means no truncation. - top_k = vocab_size if top_k == -1 else top_k - if (i < sampling_metadata.num_prompts - and sampling_params.prompt_logprobs is not None): - prompt_len = sampling_metadata.prompt_lens[i] - top_ps += [top_p] * (prompt_len - 1) - top_ks += [top_k] * (prompt_len - 1) - min_ps += [min_p] * (prompt_len - 1) - top_ps += [top_p] * len(seq_ids) - top_ks += [top_k] * len(seq_ids) - min_ps += [min_p] * len(seq_ids) - return top_ps, top_ks, min_ps - - def _apply_top_p_top_k( logits: torch.Tensor, - top_ps: List[float], - top_ks: List[int], + p: torch.Tensor, + k: torch.Tensor, ) -> torch.Tensor: - p = torch.tensor(top_ps, dtype=logits.dtype, device=logits.device) - k = torch.tensor(top_ks, dtype=torch.int, device=logits.device) logits_sort, logits_idx = logits.sort(dim=-1, descending=True) # Apply top-p. probs_sort = logits_sort.softmax(dim=-1) - probs_sum = probs_sort.cumsum(dim=-1) - top_p_mask = (probs_sum - probs_sort) > p.unsqueeze(dim=1) - logits_sort[top_p_mask] = -float("inf") + probs_sum = probs_sort.cumsum(dim=-1).sub_(probs_sort) + top_p_mask = probs_sum > p.unsqueeze_(dim=1) # Apply top-k. # Create a mask for the top-k elements. top_k_mask = torch.arange(logits_idx.shape[-1], device=logits_idx.device) top_k_mask = top_k_mask.expand(logits_idx.shape[0], -1) - top_k_mask = top_k_mask >= k.unsqueeze(dim=1) - logits_sort[top_k_mask] = -float("inf") + top_k_mask = top_k_mask >= k.unsqueeze_(dim=1) + + # Final mask. + mask = (top_p_mask | top_k_mask) + logits_sort.masked_fill_(mask, -float("inf")) # Re-sort the probabilities. - logits = torch.gather(logits_sort, - dim=-1, - index=torch.argsort(logits_idx, dim=-1)) + src = torch.arange(logits_idx.shape[-1], + device=logits_idx.device).expand_as(logits_idx) + logits_idx_inv = torch.empty_like(logits_idx).scatter_(dim=-1, + index=logits_idx, + src=src) + logits = torch.gather(logits_sort, dim=-1, index=logits_idx_inv) return logits def _apply_min_p( logits: torch.Tensor, - min_ps: List[float], + min_p: torch.Tensor, ) -> torch.Tensor: """ Adapted from https://github.com/oobabooga/text-generation-webui/blob/3146124ec01f02c8fb1650a6517cf1b60b537aaf/modules/sampler_hijack.py#L16C17-L16C17 """ - min_p = torch.tensor(min_ps, dtype=logits.dtype, device=logits.device) probs = torch.softmax(logits, dim=-1) top_probs, _ = probs.max(dim=-1, keepdim=True) - scaled_min_p = min_p.unsqueeze(dim=1) * top_probs + scaled_min_p = min_p.unsqueeze_(dim=1) * top_probs tokens_to_remove = probs < scaled_min_p - logits = logits.masked_fill(tokens_to_remove, -float("inf")) + logits = logits.masked_fill_(tokens_to_remove, -float("inf")) return logits def _greedy_sample( selected_seq_groups: List[Tuple[List[int], SamplingParams]], - logprobs: torch.Tensor, + samples: torch.Tensor, ) -> List[Tuple[List[int], List[int]]]: - samples = torch.argmax(logprobs, dim=-1).cpu() + samples = samples.tolist() sample_idx = 0 results = [] for seq_group in selected_seq_groups: @@ -376,27 +260,19 @@ def _greedy_sample( assert num_parent_seqs == 1, ( "Greedy sampling should have only one seq.") parent_ids = list(range(num_parent_seqs)) - next_token_ids = [samples[sample_idx].item()] + next_token_ids = [samples[sample_idx]] results.append((next_token_ids, parent_ids)) sample_idx += num_parent_seqs - assert sample_idx == logprobs.size(0) return results def _random_sample( selected_seq_groups: List[Tuple[List[int], SamplingParams]], is_prompts: List[bool], - probs: torch.Tensor, + random_samples: torch.Tensor, ) -> List[Tuple[List[int], List[int]]]: # Find the maximum best_of value of the prompt phase requests. - max_best_of = 1 - for seq_group, is_prompt in zip(selected_seq_groups, is_prompts): - if is_prompt: - seq_ids, sampling_params = seq_group - max_best_of = max(max_best_of, sampling_params.best_of) - random_samples = torch.multinomial(probs, - num_samples=max_best_of, - replacement=True).cpu() + random_samples = random_samples.cpu() sample_idx = 0 results = [] for seq_group, is_prompt in zip(selected_seq_groups, is_prompts): @@ -404,8 +280,6 @@ def _random_sample( num_parent_seqs = len(seq_ids) if is_prompt: # Prompt phase. - assert num_parent_seqs == 1, ( - "Prompt input should have only one seq.") parent_ids = [0] * sampling_params.best_of next_token_ids = random_samples[ sample_idx, :sampling_params.best_of].tolist() @@ -416,7 +290,6 @@ def _random_sample( num_parent_seqs, 0].tolist() results.append((next_token_ids, parent_ids)) sample_idx += num_parent_seqs - assert sample_idx == probs.size(0) return results @@ -473,6 +346,28 @@ def _beam_search_sample( return results +# torch.multinomial forces a GPU<->CPU sync. +# Therefore, we use an optimized implementation instead. +# Note that we always sample with replacement. +# probs will be modified in place, but this is fine, as we pass +# in a copy already. +def _multinomial( + probs: torch.Tensor, + num_samples: int, +): + if num_samples > 1: + # This is equivalent to torch.repeat_interleaved (which also + # forces a GPU<->CPU sync). + # This allows us to do sampling with replacement by creating + # num_samples copies of each row in the tensor, and then + # batch sampling the resulting tensor. + probs = probs[:, None, :].expand(probs.shape[0], num_samples, + probs.shape[1]).contiguous().view( + -1, probs.shape[1]) + q = torch.empty_like(probs).exponential_(1) + return probs.div_(q).argmax(dim=1).view(-1, num_samples) + + def _sample( probs: torch.Tensor, logprobs: torch.Tensor, @@ -486,28 +381,51 @@ def _sample( categorized_seq_group_ids[sampling_type].append(i) sample_results_dict: Dict[int, Tuple[List[int], List[int]]] = {} + sample_metadata = {} + + # Counterintiutively, having two loops here is actually faster. + # The first loop can run without waiting on GPU<->CPU sync. for sampling_type in SamplingType: - seq_group_ids = categorized_seq_group_ids[sampling_type] - seq_groups = [sampling_metadata.seq_groups[i] for i in seq_group_ids] - is_prompts = [i < sampling_metadata.num_prompts for i in seq_group_ids] sample_indices = categorized_sample_indices[sampling_type] num_tokens = len(sample_indices) if num_tokens == 0: continue + seq_group_ids = categorized_seq_group_ids[sampling_type] + seq_groups = [sampling_metadata.seq_groups[i] for i in seq_group_ids] + is_prompts = [i < sampling_metadata.num_prompts for i in seq_group_ids] + sample_metadata[sampling_type] = (seq_group_ids, seq_groups, + is_prompts, sample_indices) if sampling_type == SamplingType.GREEDY: - category_logprobs = logprobs[sample_indices] - sample_results = _greedy_sample(seq_groups, category_logprobs) + greedy_samples = torch.argmax(logprobs[sample_indices], dim=-1) + elif sampling_type == SamplingType.RANDOM: + max_best_of = 1 + for seq_group, is_prompt in zip(seq_groups, is_prompts): + if is_prompt: + _, sampling_params = seq_group + max_best_of = max(max_best_of, sampling_params.best_of) + multinomial_samples = _multinomial(probs[sample_indices], + max_best_of) + elif sampling_type == SamplingType.BEAM: + beam_search_logprobs = logprobs[sample_indices] + else: + raise ValueError(f"Unsupported sampling type: {sampling_type}") + + # GPU<->CPU sync happens in the loop below. + + for sampling_type in SamplingType: + if sampling_type not in sample_metadata: + continue + seq_group_ids, seq_groups, is_prompts, sample_indices = sample_metadata[ + sampling_type] + if sampling_type == SamplingType.GREEDY: + sample_results = _greedy_sample(seq_groups, greedy_samples) elif sampling_type == SamplingType.RANDOM: - category_probs = probs[sample_indices] sample_results = _random_sample(seq_groups, is_prompts, - category_probs) + multinomial_samples) elif sampling_type == SamplingType.BEAM: - category_logprobs = logprobs[sample_indices] sample_results = _beam_search_sample(seq_groups, is_prompts, sampling_metadata.seq_data, - category_logprobs) - else: - raise ValueError(f"Unsupported sampling type: {sampling_type}") + beam_search_logprobs) sample_results_dict.update(zip(seq_group_ids, sample_results)) sample_results = [ @@ -558,7 +476,7 @@ def _get_logprobs( batched_logprobs_query_result = logprobs[[ batched_logprobs_query_seq_indices, batched_logprobs_query_token_indices - ]].cpu() + ]] # Batched query for logprobs of topk tokens if largest_num_logprobs > 0: @@ -570,6 +488,8 @@ def _get_logprobs( else: top_logprobs, top_token_ids = None, None + batched_logprobs_query_result = batched_logprobs_query_result.cpu() + # Gather results result_prompt_logprobs: List[Optional[PromptLogprobs]] = [] result_sample_logprobs: List[SampleLogprobs] = [] diff --git a/vllm/model_executor/model_loader.py b/vllm/model_executor/model_loader.py index fbf00b50049c1..e65b1bc6525f8 100644 --- a/vllm/model_executor/model_loader.py +++ b/vllm/model_executor/model_loader.py @@ -7,7 +7,7 @@ from transformers import PretrainedConfig from vllm.config import ModelConfig, LoRAConfig -from vllm.model_executor.models import * +from vllm.model_executor.models import ModelRegistry from vllm.model_executor.weight_utils import (get_quant_config, initialize_dummy_weights) from vllm.utils import is_hip @@ -15,46 +15,6 @@ logger = init_logger(__name__) -# TODO(woosuk): Lazy-load the model classes. -_MODEL_REGISTRY = { - "AquilaModel": AquilaForCausalLM, - "AquilaForCausalLM": AquilaForCausalLM, # AquilaChat2 - "BaiChuanForCausalLM": BaiChuanForCausalLM, # baichuan-7b - "BaichuanForCausalLM": BaichuanForCausalLM, # baichuan-13b - "BloomForCausalLM": BloomForCausalLM, - "ChatGLMModel": ChatGLMForCausalLM, - "ChatGLMForConditionalGeneration": ChatGLMForCausalLM, - "FalconForCausalLM": FalconForCausalLM, - "GPT2LMHeadModel": GPT2LMHeadModel, - "GPTBigCodeForCausalLM": GPTBigCodeForCausalLM, - "GPTJForCausalLM": GPTJForCausalLM, - "GPTNeoXForCausalLM": GPTNeoXForCausalLM, - "InternLMForCausalLM": InternLMForCausalLM, - "LlamaForCausalLM": LlamaForCausalLM, - "LLaMAForCausalLM": LlamaForCausalLM, # For decapoda-research/llama-* - "MistralForCausalLM": MistralForCausalLM, - # transformers's mpt class has lower case - "MptForCausalLM": MPTForCausalLM, - "MPTForCausalLM": MPTForCausalLM, - "OPTForCausalLM": OPTForCausalLM, - "PhiForCausalLM": PhiForCausalLM, - "QWenLMHeadModel": QWenLMHeadModel, - "RWForCausalLM": FalconForCausalLM, - "YiForCausalLM": YiForCausalLM, -} - -# Models to be disabled in ROCm -_ROCM_UNSUPPORTED_MODELS = [] -if is_hip(): - for rocm_model in _ROCM_UNSUPPORTED_MODELS: - del _MODEL_REGISTRY[rocm_model] - -# Models partially supported in ROCm -_ROCM_PARTIALLY_SUPPORTED_MODELS = { - "MistralForCausalLM": - "Sliding window attention is not supported in ROCm's flash attention", -} - @contextlib.contextmanager def _set_default_torch_dtype(dtype: torch.dtype): @@ -68,19 +28,12 @@ def _set_default_torch_dtype(dtype: torch.dtype): def _get_model_architecture(config: PretrainedConfig) -> Type[nn.Module]: architectures = getattr(config, "architectures", []) for arch in architectures: - if arch in _MODEL_REGISTRY: - if is_hip() and arch in _ROCM_PARTIALLY_SUPPORTED_MODELS: - logger.warning( - f"{arch} is not fully supported in ROCm. Reason: " - f"{_ROCM_PARTIALLY_SUPPORTED_MODELS[arch]}") - return _MODEL_REGISTRY[arch] - elif arch in _ROCM_UNSUPPORTED_MODELS: - raise ValueError( - f"Model architecture {arch} is not supported by ROCm for now. \n" - f"Supported architectures {list(_MODEL_REGISTRY.keys())}") + model_cls = ModelRegistry.load_model_cls(arch) + if model_cls is not None: + return model_cls raise ValueError( f"Model architectures {architectures} are not supported for now. " - f"Supported architectures: {list(_MODEL_REGISTRY.keys())}") + f"Supported architectures: {ModelRegistry.get_supported_archs()}") def get_model(model_config: ModelConfig, diff --git a/vllm/model_executor/models/__init__.py b/vllm/model_executor/models/__init__.py index 078d3d74719df..f60ea640359b3 100644 --- a/vllm/model_executor/models/__init__.py +++ b/vllm/model_executor/models/__init__.py @@ -1,39 +1,83 @@ -from vllm.model_executor.models.aquila import AquilaForCausalLM -from vllm.model_executor.models.baichuan import (BaiChuanForCausalLM, - BaichuanForCausalLM) -from vllm.model_executor.models.bloom import BloomForCausalLM -from vllm.model_executor.models.falcon import FalconForCausalLM -from vllm.model_executor.models.gpt2 import GPT2LMHeadModel -from vllm.model_executor.models.gpt_bigcode import GPTBigCodeForCausalLM -from vllm.model_executor.models.gpt_j import GPTJForCausalLM -from vllm.model_executor.models.gpt_neox import GPTNeoXForCausalLM -from vllm.model_executor.models.internlm import InternLMForCausalLM -from vllm.model_executor.models.llama import LlamaForCausalLM -from vllm.model_executor.models.mistral import MistralForCausalLM -from vllm.model_executor.models.mpt import MPTForCausalLM -from vllm.model_executor.models.opt import OPTForCausalLM -from vllm.model_executor.models.phi_1_5 import PhiForCausalLM -from vllm.model_executor.models.qwen import QWenLMHeadModel -from vllm.model_executor.models.chatglm import ChatGLMForCausalLM -from vllm.model_executor.models.yi import YiForCausalLM +import importlib +from typing import List, Optional, Type + +import torch.nn as nn + +from vllm.logger import init_logger +from vllm.utils import is_hip + +logger = init_logger(__name__) + +# Architecture -> (module, class). +_MODELS = { + "AquilaModel": ("aquila", "AquilaForCausalLM"), + "AquilaForCausalLM": ("aquila", "AquilaForCausalLM"), # AquilaChat2 + "BaiChuanForCausalLM": ("baichuan", "BaiChuanForCausalLM"), # baichuan-7b + "BaichuanForCausalLM": ("baichuan", "BaichuanForCausalLM"), # baichuan-13b + "BloomForCausalLM": ("bloom", "BloomForCausalLM"), + "ChatGLMModel": ("chatglm", "ChatGLMForCausalLM"), + "ChatGLMForConditionalGeneration": ("chatglm", "ChatGLMForCausalLM"), + "DeciLMForCausalLM": ("decilm", "DeciLMForCausalLM"), + "FalconForCausalLM": ("falcon", "FalconForCausalLM"), + "GPT2LMHeadModel": ("gpt2", "GPT2LMHeadModel"), + "GPTBigCodeForCausalLM": ("gpt_bigcode", "GPTBigCodeForCausalLM"), + "GPTJForCausalLM": ("gpt_j", "GPTJForCausalLM"), + "GPTNeoXForCausalLM": ("gpt_neox", "GPTNeoXForCausalLM"), + "InternLMForCausalLM": ("internlm", "InternLMForCausalLM"), + "LlamaForCausalLM": ("llama", "LlamaForCausalLM"), + # For decapoda-research/llama-* + "LLaMAForCausalLM": ("llama", "LlamaForCausalLM"), + "MistralForCausalLM": ("mistral", "MistralForCausalLM"), + "MixtralForCausalLM": ("mixtral", "MixtralForCausalLM"), + # transformers's mpt class has lower case + "MptForCausalLM": ("mpt", "MPTForCausalLM"), + "MPTForCausalLM": ("mpt", "MPTForCausalLM"), + "OPTForCausalLM": ("opt", "OPTForCausalLM"), + "PhiForCausalLM": ("phi_1_5", "PhiForCausalLM"), + "QWenLMHeadModel": ("qwen", "QWenLMHeadModel"), + "RWForCausalLM": ("falcon", "FalconForCausalLM"), + "YiForCausalLM": ("yi", "YiForCausalLM"), +} + +# Models not supported by ROCm. +_ROCM_UNSUPPORTED_MODELS = [] + +# Models partially supported by ROCm. +# Architecture -> Reason. +_ROCM_PARTIALLY_SUPPORTED_MODELS = { + "MistralForCausalLM": + "Sliding window attention is not yet supported in ROCm's flash attention", + "MixtralForCausalLM": + "Sliding window attention is not yet supported in ROCm's flash attention", +} + + +class ModelRegistry: + + @staticmethod + def load_model_cls(model_arch: str) -> Optional[Type[nn.Module]]: + if model_arch not in _MODELS: + return None + if is_hip(): + if model_arch in _ROCM_UNSUPPORTED_MODELS: + raise ValueError( + f"Model architecture {model_arch} is not supported by " + "ROCm for now.") + if model_arch in _ROCM_PARTIALLY_SUPPORTED_MODELS: + logger.warning( + f"Model architecture {model_arch} is partially supported " + "by ROCm: " + _ROCM_PARTIALLY_SUPPORTED_MODELS[model_arch]) + + module_name, model_cls_name = _MODELS[model_arch] + module = importlib.import_module( + f"vllm.model_executor.models.{module_name}") + return getattr(module, model_cls_name, None) + + @staticmethod + def get_supported_archs() -> List[str]: + return list(_MODELS.keys()) + __all__ = [ - "AquilaForCausalLM", - "BaiChuanForCausalLM", - "BaichuanForCausalLM", - "BloomForCausalLM", - "ChatGLMForCausalLM", - "FalconForCausalLM", - "GPT2LMHeadModel", - "GPTBigCodeForCausalLM", - "GPTJForCausalLM", - "GPTNeoXForCausalLM", - "InternLMForCausalLM", - "LlamaForCausalLM", - "MPTForCausalLM", - "OPTForCausalLM", - "PhiForCausalLM", - "QWenLMHeadModel", - "MistralForCausalLM", - "YiForCausalLM", + "ModelRegistry", ] diff --git a/vllm/model_executor/models/aquila.py b/vllm/model_executor/models/aquila.py index f8c4d643294b1..2a1a0d76801c8 100644 --- a/vllm/model_executor/models/aquila.py +++ b/vllm/model_executor/models/aquila.py @@ -158,14 +158,12 @@ def forward( hidden_states: torch.Tensor, kv_cache: KVCache, input_metadata: InputMetadata, - cache_event: Optional[torch.cuda.Event], ) -> torch.Tensor: qkv, _ = self.qkv_proj(hidden_states) q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) q, k = self.rotary_emb(positions, q, k) k_cache, v_cache = kv_cache - attn_output = self.attn(q, k, v, k_cache, v_cache, input_metadata, - cache_event) + attn_output = self.attn(q, k, v, k_cache, v_cache, input_metadata) output, _ = self.o_proj(attn_output) return output @@ -209,7 +207,6 @@ def forward( hidden_states: torch.Tensor, kv_cache: KVCache, input_metadata: InputMetadata, - cache_event: Optional[torch.cuda.Event], ) -> torch.Tensor: # Self Attention residual = hidden_states @@ -219,7 +216,6 @@ def forward( hidden_states=hidden_states, kv_cache=kv_cache, input_metadata=input_metadata, - cache_event=cache_event, ) hidden_states = residual + hidden_states @@ -258,18 +254,15 @@ def forward( positions: torch.Tensor, kv_caches: List[KVCache], input_metadata: InputMetadata, - cache_events: Optional[List[torch.cuda.Event]], ) -> torch.Tensor: hidden_states = self.embed_tokens(input_ids) for i in range(len(self.layers)): - cache_event = None if cache_events is None else cache_events[i] layer = self.layers[i] hidden_states = layer( positions, hidden_states, kv_caches[i], input_metadata, - cache_event, ) hidden_states = self.norm(hidden_states) @@ -296,10 +289,9 @@ def forward( positions: torch.Tensor, kv_caches: List[KVCache], input_metadata: InputMetadata, - cache_events: Optional[List[torch.cuda.Event]], ) -> torch.Tensor: hidden_states = self.model(input_ids, positions, kv_caches, - input_metadata, cache_events) + input_metadata) return hidden_states def sample( @@ -332,11 +324,18 @@ def load_weights(self, for (param_name, weight_name, shard_id) in stacked_params_mapping: if weight_name not in name: continue - param = params_dict[name.replace(weight_name, param_name)] + name = name.replace(weight_name, param_name) + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: + continue + param = params_dict[name] weight_loader = param.weight_loader weight_loader(param, loaded_weight, shard_id) break else: + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: + continue param = params_dict[name] weight_loader = getattr(param, "weight_loader", default_weight_loader) diff --git a/vllm/model_executor/models/baichuan.py b/vllm/model_executor/models/baichuan.py index 3b56b9e137021..cd8ab444677ea 100644 --- a/vllm/model_executor/models/baichuan.py +++ b/vllm/model_executor/models/baichuan.py @@ -172,15 +172,13 @@ def forward( hidden_states: torch.Tensor, kv_cache: KVCache, input_metadata: InputMetadata, - cache_event: Optional[torch.cuda.Event], ) -> torch.Tensor: qkv, _ = self.W_pack(hidden_states) q, k, v = qkv.chunk(chunks=3, dim=-1) if self.postion_embedding != "ALIBI": q, k = self.rotary_emb(positions, q, k) k_cache, v_cache = kv_cache - attn_output = self.attn(q, k, v, k_cache, v_cache, input_metadata, - cache_event) + attn_output = self.attn(q, k, v, k_cache, v_cache, input_metadata) output, _ = self.o_proj(attn_output) return output @@ -221,7 +219,6 @@ def forward( hidden_states: torch.Tensor, kv_cache: KVCache, input_metadata: InputMetadata, - cache_event: Optional[torch.cuda.Event], residual: Optional[torch.Tensor], ) -> Tuple[torch.Tensor, torch.Tensor]: # Self Attention @@ -236,7 +233,6 @@ def forward( hidden_states=hidden_states, kv_cache=kv_cache, input_metadata=input_metadata, - cache_event=cache_event, ) # Fully Connected @@ -273,19 +269,16 @@ def forward( positions: torch.Tensor, kv_caches: List[KVCache], input_metadata: InputMetadata, - cache_events: Optional[List[torch.cuda.Event]], ) -> torch.Tensor: hidden_states = self.embed_tokens(input_ids) residual = None for i in range(len(self.layers)): - cache_event = None if cache_events is None else cache_events[i] layer = self.layers[i] hidden_states, residual = layer( positions, hidden_states, kv_caches[i], input_metadata, - cache_event, residual, ) hidden_states, _ = self.norm(hidden_states, residual) @@ -311,10 +304,9 @@ def forward( positions: torch.Tensor, kv_caches: List[KVCache], input_metadata: InputMetadata, - cache_events: Optional[List[torch.cuda.Event]], ) -> torch.Tensor: hidden_states = self.model(input_ids, positions, kv_caches, - input_metadata, cache_events) + input_metadata) return hidden_states def sample( @@ -355,26 +347,38 @@ def load_weights(self, for (param_name, weight_name, shard_id) in stacked_params_mapping: if weight_name not in name: continue - param = params_dict[name.replace(weight_name, param_name)] + name = name.replace(weight_name, param_name) + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: + continue + param = params_dict[name] weight_loader = param.weight_loader weight_loader(param, loaded_weight, shard_id) break else: + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: + continue param = params_dict[name] weight_loader = getattr(param, "weight_loader", default_weight_loader) weight_loader(param, loaded_weight) -class BaichuanForCausalLM(BaiChuanBaseForCausalLM): # baichuan 13b +class BaichuanForCausalLM(BaiChuanBaseForCausalLM): + """Baichuan 13B and Baichuan2 7B/13B.""" def __init__(self, config, linear_method: Optional[LinearMethodBase] = None): - super().__init__(config, "ALIBI", linear_method) + if config.hidden_size == 4096: # baichuan2 7b + super().__init__(config, "ROPE", linear_method) + else: # baichuan 13b, baichuan2 13b + super().__init__(config, "ALIBI", linear_method) -class BaiChuanForCausalLM(BaiChuanBaseForCausalLM): # baichuan 7b +class BaiChuanForCausalLM(BaiChuanBaseForCausalLM): + """Baichuan 7B.""" def __init__(self, config, diff --git a/vllm/model_executor/models/bloom.py b/vllm/model_executor/models/bloom.py index 9da0490104b6a..6d1aeeed78e93 100644 --- a/vllm/model_executor/models/bloom.py +++ b/vllm/model_executor/models/bloom.py @@ -118,14 +118,12 @@ def forward( hidden_states: torch.Tensor, kv_cache: KVCache, input_metadata: InputMetadata, - cache_event: Optional[torch.cuda.Event], ) -> torch.Tensor: del position_ids # Unused. qkv, _ = self.query_key_value(hidden_states) q, k, v = qkv.chunk(chunks=3, dim=-1) k_cache, v_cache = kv_cache - attn_output = self.attn(q, k, v, k_cache, v_cache, input_metadata, - cache_event) + attn_output = self.attn(q, k, v, k_cache, v_cache, input_metadata) output, _ = self.dense(attn_output) return output @@ -184,7 +182,6 @@ def forward( hidden_states: torch.Tensor, kv_cache: KVCache, input_metadata: InputMetadata, - cache_event: Optional[torch.cuda.Event], ) -> torch.Tensor: # Layer norm at the beginning of the transformer layer. layernorm_output = self.input_layernorm(hidden_states) @@ -201,7 +198,6 @@ def forward( hidden_states=layernorm_output, kv_cache=kv_cache, input_metadata=input_metadata, - cache_event=cache_event, ) attention_output = attention_output + residual layernorm_output = self.post_attention_layernorm(attention_output) @@ -250,19 +246,16 @@ def forward( position_ids: torch.Tensor, kv_caches: List[KVCache], input_metadata: InputMetadata, - cache_events: Optional[List[torch.cuda.Event]], ) -> torch.Tensor: hidden_states = self.word_embeddings(input_ids) hidden_states = self.word_embeddings_layernorm(hidden_states) for i in range(len(self.h)): - cache_event = None if cache_events is None else cache_events[i] layer = self.h[i] hidden_states = layer( position_ids, hidden_states, kv_caches[i], input_metadata, - cache_event, ) hidden_states = self.ln_f(hidden_states) return hidden_states @@ -288,10 +281,9 @@ def forward( positions: torch.Tensor, kv_caches: List[KVCache], input_metadata: InputMetadata, - cache_events: Optional[List[torch.cuda.Event]], ) -> torch.Tensor: hidden_states = self.transformer(input_ids, positions, kv_caches, - input_metadata, cache_events) + input_metadata) return hidden_states def sample( diff --git a/vllm/model_executor/models/chatglm.py b/vllm/model_executor/models/chatglm.py index 60ec4d9b4018a..aa957b36b36d7 100644 --- a/vllm/model_executor/models/chatglm.py +++ b/vllm/model_executor/models/chatglm.py @@ -100,7 +100,6 @@ def forward( position_ids: torch.Tensor, kv_cache: KVCache, input_metadata: InputMetadata, - cache_event: Optional[torch.cuda.Event], ) -> torch.Tensor: qkv, _ = self.query_key_value(hidden_states) q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) @@ -113,7 +112,6 @@ def forward( key_cache, value_cache, input_metadata, - cache_event, ) attn_output, _ = self.dense(context_layer) return attn_output @@ -203,7 +201,6 @@ def forward( position_ids: torch.Tensor, kv_cache: KVCache, input_metadata: InputMetadata, - cache_event: Optional[torch.cuda.Event], ) -> torch.Tensor: # hidden_states: [num_tokens, h] # Layer norm at the beginning of the transformer layer. @@ -214,7 +211,6 @@ def forward( position_ids=position_ids, kv_cache=kv_cache, input_metadata=input_metadata, - cache_event=cache_event, ) # Residual connection. @@ -269,17 +265,14 @@ def forward( position_ids: torch.Tensor, kv_caches: List[KVCache], input_metadata: InputMetadata, - cache_events: Optional[List[torch.cuda.Event]], ) -> torch.Tensor: for i in range(self.num_layers): - cache_event = None if cache_events is None else cache_events[i] layer = self.layers[i] hidden_states = layer( hidden_states=hidden_states, position_ids=position_ids, kv_cache=kv_caches[i], input_metadata=input_metadata, - cache_event=cache_event, ) # Final layer norm. if self.post_layer_norm: @@ -314,8 +307,7 @@ def forward( position_ids: torch.Tensor, kv_caches: List[KVCache], input_metadata: InputMetadata, - cache_events: Optional[List[torch.cuda.Event]], - ): + ) -> torch.Tensor: inputs_embeds = self.embedding(input_ids) # Run encoder. @@ -324,9 +316,7 @@ def forward( position_ids=position_ids, kv_caches=kv_caches, input_metadata=input_metadata, - cache_events=cache_events, ) - return hidden_states @@ -350,10 +340,9 @@ def forward( positions: torch.Tensor, kv_caches: List[KVCache], input_metadata: InputMetadata, - cache_events: Optional[List[torch.cuda.Event]], ) -> torch.Tensor: hidden_states = self.transformer(input_ids, positions, kv_caches, - input_metadata, cache_events) + input_metadata) return hidden_states def sample( @@ -377,6 +366,9 @@ def load_weights(self, continue if "word_embeddings" in name: name = name.replace(".word_embeddings", "") + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: + continue param = params_dict[name] weight_loader = getattr(param, "weight_loader", default_weight_loader) diff --git a/vllm/model_executor/models/decilm.py b/vllm/model_executor/models/decilm.py new file mode 100644 index 0000000000000..984be0cccd16d --- /dev/null +++ b/vllm/model_executor/models/decilm.py @@ -0,0 +1,123 @@ +# coding=utf-8 +# Adapted from +# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/llama/modeling_llama.py +# Copyright 2023 DeciAI Research Team. All rights reserved. +# Copyright 2023 The vLLM team. +# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved. +# +# This code is based on MistralAI GPT-NeoX library and the GPT-NeoX +# and OPT implementations in this library. It has been modified from its +# original forms to accommodate minor architectural differences compared +# to GPT-NeoX and OPT used by the Meta AI team that trained the model. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Inference-only DeciLM model compatible with HuggingFace weights.""" + +from typing import Optional + +import torch +from transformers import PretrainedConfig + +from vllm.model_executor.layers.linear import LinearMethodBase +from vllm.model_executor.models.llama import LlamaForCausalLM +from vllm.model_executor.weight_utils import (default_weight_loader, + hf_model_weights_iterator) + + +class DeciLMForCausalLM(LlamaForCausalLM): + """ + Implementation for https://huggingface.co/Deci/DeciLM-7b-instruct. + Based on the llama executor. + + The main difference is that DeciLM uses Variable Grouped Query Attention. + The constant number of GQA heads in the decoder is overriden with a value + per layer. + + Usually, in the HuggingFace implementation, instead of + "config.num_key_value_heads", we use + "config.num_key_value_heads_per_layer[i]" which varies. + + Currently, PagedAttention does not work well with variable GQA, so we + normalize the weights upon loading, and use uniform GQA with the max value + instead. + """ + + def __init__( + self, + config: Optional[PretrainedConfig] = None, + linear_method: Optional[LinearMethodBase] = None, + ) -> None: + config.num_key_value_heads = max(config.num_key_value_heads_per_layer) + delattr(config, "num_key_value_heads_per_layer") + super().__init__(config=config, linear_method=linear_method) + + def load_weights(self, + model_name_or_path: str, + cache_dir: Optional[str] = None, + load_format: str = "auto", + revision: Optional[str] = None): + stacked_params_mapping = [ + # (param_name, shard_name, shard_id) + ("qkv_proj", "q_proj", "q"), + ("qkv_proj", "k_proj", "k"), + ("qkv_proj", "v_proj", "v"), + ("gate_up_proj", "gate_proj", 0), + ("gate_up_proj", "up_proj", 1), + ] + params_dict = dict(self.named_parameters()) + for name, loaded_weight in hf_model_weights_iterator( + model_name_or_path, cache_dir, load_format, revision): + if "rotary_emb.inv_freq" in name: + continue + + if "k_proj" in name or "v_proj" in name: + loaded_weight = self._degroup_weight(loaded_weight) + + for (param_name, weight_name, shard_id) in stacked_params_mapping: + if weight_name not in name: + continue + name = name.replace(weight_name, param_name) + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: + continue + param = params_dict[name] + weight_loader = param.weight_loader + weight_loader(param, loaded_weight, shard_id) + break + else: + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: + continue + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, loaded_weight) + + def _degroup_weight(self, loaded_weight: torch.Tensor) -> torch.Tensor: + hidden_size = self.config.hidden_size + head_size = self.config.hidden_size // self.config.num_attention_heads + target_num_kv_heads = self.config.num_key_value_heads + num_kv_heads = loaded_weight.shape[0] // head_size + n_repeats = target_num_kv_heads / num_kv_heads + assert n_repeats == int(n_repeats) + + n_repeats = int(n_repeats) + loaded_weight = loaded_weight.view(num_kv_heads, head_size, + hidden_size) + loaded_weight = torch.repeat_interleave(loaded_weight, + repeats=n_repeats, + dim=0) + loaded_weight = loaded_weight.reshape(target_num_kv_heads * head_size, + hidden_size) + + return loaded_weight diff --git a/vllm/model_executor/models/falcon.py b/vllm/model_executor/models/falcon.py index 8890d29b1267b..7055d08521c47 100644 --- a/vllm/model_executor/models/falcon.py +++ b/vllm/model_executor/models/falcon.py @@ -178,7 +178,6 @@ def forward( hidden_states: torch.Tensor, kv_cache: KVCache, input_metadata: InputMetadata, - cache_event: Optional[torch.cuda.Event], ) -> torch.Tensor: qkv, bias = self.query_key_value(hidden_states) if bias is not None: @@ -187,8 +186,7 @@ def forward( if self.use_rotary: q, k = self.rotary_emb(positions, q, k) k_cache, v_cache = kv_cache - attn_output = self.attn(q, k, v, k_cache, v_cache, input_metadata, - cache_event) + attn_output = self.attn(q, k, v, k_cache, v_cache, input_metadata) attn_output, bias = self.dense(attn_output) return attn_output, bias @@ -266,8 +264,7 @@ def forward( hidden_states: torch.Tensor, kv_cache: KVCache, input_metadata: InputMetadata, - cache_event: Optional[torch.cuda.Event], - ): + ) -> torch.Tensor: residual = hidden_states if self.config.new_decoder_architecture: @@ -282,7 +279,6 @@ def forward( hidden_states=attention_layernorm_out, kv_cache=kv_cache, input_metadata=input_metadata, - cache_event=cache_event, ) if self.reduce_row_parallel_results and attention_bias is not None: attention_output += attention_bias @@ -311,7 +307,6 @@ def forward( mlp_output += mlp_bias output = mlp_output + residual - return output @@ -349,18 +344,15 @@ def forward( positions: torch.Tensor, kv_caches: List[KVCache], input_metadata: InputMetadata, - cache_events: Optional[List[torch.cuda.Event]], ) -> torch.Tensor: hidden_states = self.word_embeddings(input_ids) for i in range(len(self.h)): - cache_event = None if cache_events is None else cache_events[i] layer = self.h[i] hidden_states = layer( positions, hidden_states, kv_caches[i], input_metadata, - cache_event, ) hidden_states = self.ln_f(hidden_states) return hidden_states @@ -389,14 +381,12 @@ def forward( positions: torch.Tensor, kv_caches: List[KVCache], input_metadata: InputMetadata, - cache_events: Optional[List[torch.cuda.Event]], ) -> torch.Tensor: hidden_states = self.transformer( input_ids, positions, kv_caches, input_metadata, - cache_events, ) return hidden_states @@ -425,27 +415,32 @@ def load_weights(self, params_dict = dict(self.named_parameters()) for name, loaded_weight in hf_model_weights_iterator( model_name_or_path, cache_dir, load_format, revision): + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: + continue param = params_dict[name] if "query_key_value" in name: output_dim = getattr(param, "output_dim", None) loaded_weight_shape = loaded_weight.shape - loaded_weight = loaded_weight.view( - loaded_weight_shape[:output_dim] + - (total_num_kv_heads, num_query_heads_per_kv_head + 2, -1) + - loaded_weight_shape[output_dim + 1:]) - wq = loaded_weight.narrow( - output_dim + 1, 0, num_query_heads_per_kv_head).reshape( - *loaded_weight_shape[:output_dim], -1, - *loaded_weight_shape[output_dim + 1:]) - wk = loaded_weight.narrow( - output_dim + 1, num_query_heads_per_kv_head, - 1).reshape(*loaded_weight_shape[:output_dim], -1, - *loaded_weight_shape[output_dim + 1:]) - wv = loaded_weight.narrow( - output_dim + 1, num_query_heads_per_kv_head + 1, - 1).reshape(*loaded_weight_shape[:output_dim], -1, - *loaded_weight_shape[output_dim + 1:]) - loaded_weight = torch.cat([wq, wk, wv], dim=output_dim) + if output_dim is not None: + loaded_weight = loaded_weight.view( + loaded_weight_shape[:output_dim] + + (total_num_kv_heads, num_query_heads_per_kv_head + 2, + -1) + loaded_weight_shape[output_dim + 1:]) + wq = loaded_weight.narrow( + output_dim + 1, 0, + num_query_heads_per_kv_head).reshape( + *loaded_weight_shape[:output_dim], -1, + *loaded_weight_shape[output_dim + 1:]) + wk = loaded_weight.narrow( + output_dim + 1, num_query_heads_per_kv_head, + 1).reshape(*loaded_weight_shape[:output_dim], -1, + *loaded_weight_shape[output_dim + 1:]) + wv = loaded_weight.narrow( + output_dim + 1, num_query_heads_per_kv_head + 1, + 1).reshape(*loaded_weight_shape[:output_dim], -1, + *loaded_weight_shape[output_dim + 1:]) + loaded_weight = torch.cat([wq, wk, wv], dim=output_dim) weight_loader = getattr(param, "weight_loader", default_weight_loader) diff --git a/vllm/model_executor/models/gpt2.py b/vllm/model_executor/models/gpt2.py index 5fe678ecc9d5d..d9b561cd8b225 100644 --- a/vllm/model_executor/models/gpt2.py +++ b/vllm/model_executor/models/gpt2.py @@ -82,13 +82,12 @@ def forward( hidden_states: torch.Tensor, kv_cache: KVCache, input_metadata: InputMetadata, - cache_event: Optional[torch.cuda.Event], ) -> torch.Tensor: qkv, _ = self.c_attn(hidden_states) q, k, v = qkv.chunk(chunks=3, dim=-1) key_cache, value_cache = kv_cache attn_output = self.attn(q, k, v, key_cache, value_cache, - input_metadata, cache_event) + input_metadata) attn_output, _ = self.c_proj(attn_output) return attn_output @@ -148,7 +147,6 @@ def forward( hidden_states: torch.Tensor, kv_cache: KVCache, input_metadata: InputMetadata, - cache_event: Optional[torch.cuda.Event], ) -> torch.Tensor: residual = hidden_states hidden_states = self.ln_1(hidden_states) @@ -156,7 +154,6 @@ def forward( hidden_states=hidden_states, kv_cache=kv_cache, input_metadata=input_metadata, - cache_event=cache_event, ) # residual connection hidden_states = attn_output + residual @@ -196,17 +193,14 @@ def forward( position_ids: torch.Tensor, kv_caches: List[KVCache], input_metadata: InputMetadata, - cache_events: Optional[List[torch.cuda.Event]], ) -> torch.Tensor: inputs_embeds = self.wte(input_ids) position_embeds = self.wpe(position_ids) hidden_states = inputs_embeds + position_embeds for i in range(len(self.h)): - cache_event = None if cache_events is None else cache_events[i] layer = self.h[i] - hidden_states = layer(hidden_states, kv_caches[i], input_metadata, - cache_event) + hidden_states = layer(hidden_states, kv_caches[i], input_metadata) hidden_states = self.ln_f(hidden_states) return hidden_states @@ -232,10 +226,9 @@ def forward( positions: torch.Tensor, kv_caches: List[KVCache], input_metadata: InputMetadata, - cache_events: Optional[List[torch.cuda.Event]], ) -> torch.Tensor: hidden_states = self.transformer(input_ids, positions, kv_caches, - input_metadata, cache_events) + input_metadata) return hidden_states def sample( @@ -275,7 +268,6 @@ def load_weights(self, if not name.endswith(".weight"): continue loaded_weight = loaded_weight.t() - weight_loader = getattr(param, "weight_loader", default_weight_loader) weight_loader(param, loaded_weight) diff --git a/vllm/model_executor/models/gpt_bigcode.py b/vllm/model_executor/models/gpt_bigcode.py index 2007c264f0cb9..4d8144bad351f 100644 --- a/vllm/model_executor/models/gpt_bigcode.py +++ b/vllm/model_executor/models/gpt_bigcode.py @@ -95,7 +95,6 @@ def forward( hidden_states: torch.Tensor, kv_cache: KVCache, input_metadata: InputMetadata, - cache_event: Optional[torch.cuda.Event], ) -> torch.Tensor: qkv, _ = self.c_attn(hidden_states) q, k, v = qkv.split( @@ -107,7 +106,7 @@ def forward( ) key_cache, value_cache = kv_cache attn_output = self.attn(q, k, v, key_cache, value_cache, - input_metadata, cache_event) + input_metadata) attn_output, _ = self.c_proj(attn_output) return attn_output @@ -167,7 +166,6 @@ def forward( hidden_states: torch.Tensor, kv_cache: KVCache, input_metadata: InputMetadata, - cache_event: Optional[torch.cuda.Event], ) -> torch.Tensor: residual = hidden_states hidden_states = self.ln_1(hidden_states) @@ -175,7 +173,6 @@ def forward( hidden_states=hidden_states, kv_cache=kv_cache, input_metadata=input_metadata, - cache_event=cache_event, ) # residual connection hidden_states = attn_output + residual @@ -215,17 +212,14 @@ def forward( position_ids: torch.Tensor, kv_caches: List[KVCache], input_metadata: InputMetadata, - cache_events: Optional[List[torch.cuda.Event]], ) -> torch.Tensor: inputs_embeds = self.wte(input_ids) position_embeds = self.wpe(position_ids) hidden_states = inputs_embeds + position_embeds for i in range(len(self.h)): - cache_event = None if cache_events is None else cache_events[i] layer = self.h[i] - hidden_states = layer(hidden_states, kv_caches[i], input_metadata, - cache_event) + hidden_states = layer(hidden_states, kv_caches[i], input_metadata) hidden_states = self.ln_f(hidden_states) return hidden_states @@ -251,10 +245,9 @@ def forward( positions: torch.Tensor, kv_caches: List[KVCache], input_metadata: InputMetadata, - cache_events: Optional[List[torch.cuda.Event]], ) -> torch.Tensor: hidden_states = self.transformer(input_ids, positions, kv_caches, - input_metadata, cache_events) + input_metadata) return hidden_states def sample( diff --git a/vllm/model_executor/models/gpt_j.py b/vllm/model_executor/models/gpt_j.py index 1ad344fd6cc0d..ab3480a77a43d 100644 --- a/vllm/model_executor/models/gpt_j.py +++ b/vllm/model_executor/models/gpt_j.py @@ -94,14 +94,12 @@ def forward( hidden_states: torch.Tensor, kv_cache: KVCache, input_metadata: InputMetadata, - cache_event: Optional[torch.cuda.Event], ) -> torch.Tensor: qkv, _ = self.qkv_proj(hidden_states) q, k, v = qkv.chunk(chunks=3, dim=-1) q, k = self.rotary_emb(position_ids, q, k) k_cache, v_cache = kv_cache - attn_output = self.attn(q, k, v, k_cache, v_cache, input_metadata, - cache_event) + attn_output = self.attn(q, k, v, k_cache, v_cache, input_metadata) attn_output, _ = self.out_proj(attn_output) return attn_output @@ -156,7 +154,6 @@ def forward( hidden_states: torch.Tensor, kv_cache: KVCache, input_metadata: InputMetadata, - cache_event: Optional[torch.cuda.Event], ) -> torch.Tensor: residual = hidden_states hidden_states = self.ln_1(hidden_states) @@ -165,7 +162,6 @@ def forward( hidden_states=hidden_states, kv_cache=kv_cache, input_metadata=input_metadata, - cache_event=cache_event, ) mlp_output = self.mlp(hidden_states) hidden_states = attn_output + mlp_output + residual @@ -196,18 +192,15 @@ def forward( position_ids: torch.Tensor, kv_caches: List[KVCache], input_metadata: InputMetadata, - cache_events: Optional[List[torch.cuda.Event]], ) -> torch.Tensor: hidden_states = self.wte(input_ids) for i in range(len(self.h)): - cache_event = None if cache_events is None else cache_events[i] layer = self.h[i] hidden_states = layer( position_ids, hidden_states, kv_caches[i], input_metadata, - cache_event, ) hidden_states = self.ln_f(hidden_states) return hidden_states @@ -238,10 +231,9 @@ def forward( positions: torch.Tensor, kv_caches: List[KVCache], input_metadata: InputMetadata, - cache_events: Optional[List[torch.cuda.Event]], ) -> torch.Tensor: hidden_states = self.transformer(input_ids, positions, kv_caches, - input_metadata, cache_events) + input_metadata) return hidden_states def sample( @@ -274,11 +266,18 @@ def load_weights(self, for (param_name, weight_name, shard_id) in stacked_params_mapping: if weight_name not in name: continue - param = params_dict[name.replace(weight_name, param_name)] + name = name.replace(weight_name, param_name) + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: + continue + param = params_dict[name] weight_loader = param.weight_loader weight_loader(param, loaded_weight, shard_id) break else: + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: + continue param = params_dict[name] weight_loader = getattr(param, "weight_loader", default_weight_loader) diff --git a/vllm/model_executor/models/gpt_neox.py b/vllm/model_executor/models/gpt_neox.py index df5c86bf103ad..1e340162ddb58 100644 --- a/vllm/model_executor/models/gpt_neox.py +++ b/vllm/model_executor/models/gpt_neox.py @@ -72,7 +72,6 @@ def __init__( config.hidden_size, linear_method=linear_method, ) - scaling = self.head_size**-0.5 rotary_dim = int(self.head_size * config.rotary_pct) assert rotary_dim % 2 == 0 @@ -93,14 +92,12 @@ def forward( hidden_states: torch.Tensor, kv_cache: KVCache, input_metadata: InputMetadata, - cache_event: Optional[torch.cuda.Event], ) -> torch.Tensor: qkv, _ = self.query_key_value(hidden_states) q, k, v = qkv.chunk(chunks=3, dim=-1) q, k = self.rotary_emb(position_ids, q, k) k_cache, v_cache = kv_cache - attn_output = self.attn(q, k, v, k_cache, v_cache, input_metadata, - cache_event) + attn_output = self.attn(q, k, v, k_cache, v_cache, input_metadata) output, _ = self.dense(attn_output) return output @@ -156,7 +153,6 @@ def forward( hidden_states: torch.Tensor, kv_cache: KVCache, input_metadata: InputMetadata, - cache_event: Optional[torch.cuda.Event], ) -> torch.Tensor: attn_input = self.input_layernorm(hidden_states) attn_output = self.attention( @@ -164,7 +160,6 @@ def forward( hidden_states=attn_input, kv_cache=kv_cache, input_metadata=input_metadata, - cache_event=cache_event, ) if self.use_parallel_residual: @@ -211,18 +206,15 @@ def forward( position_ids: torch.Tensor, kv_caches: List[KVCache], input_metadata: InputMetadata, - cache_events: Optional[List[torch.cuda.Event]], ) -> torch.Tensor: hidden_states = self.embed_in(input_ids) for i in range(len(self.layers)): - cache_event = None if cache_events is None else cache_events[i] layer = self.layers[i] hidden_states = layer( position_ids, hidden_states, kv_caches[i], input_metadata, - cache_event, ) hidden_states = self.final_layer_norm(hidden_states) return hidden_states @@ -251,10 +243,9 @@ def forward( positions: torch.Tensor, kv_caches: List[KVCache], input_metadata: InputMetadata, - cache_events: Optional[List[torch.cuda.Event]], ) -> torch.Tensor: hidden_states = self.gpt_neox(input_ids, positions, kv_caches, - input_metadata, cache_events) + input_metadata) return hidden_states def sample( diff --git a/vllm/model_executor/models/internlm.py b/vllm/model_executor/models/internlm.py index ba28ff8d140af..00bb70fc3f87f 100644 --- a/vllm/model_executor/models/internlm.py +++ b/vllm/model_executor/models/internlm.py @@ -110,14 +110,12 @@ def forward( hidden_states: torch.Tensor, kv_cache: KVCache, input_metadata: InputMetadata, - cache_event: Optional[torch.cuda.Event], ) -> torch.Tensor: qkv, _ = self.qkv_proj(hidden_states) q, k, v = qkv.chunk(chunks=3, dim=-1) q, k = self.rotary_emb(positions, q, k) k_cache, v_cache = kv_cache - attn_output = self.attn(q, k, v, k_cache, v_cache, input_metadata, - cache_event) + attn_output = self.attn(q, k, v, k_cache, v_cache, input_metadata) output, _ = self.o_proj(attn_output) return output @@ -160,7 +158,6 @@ def forward( hidden_states: torch.Tensor, kv_cache: KVCache, input_metadata: InputMetadata, - cache_event: Optional[torch.cuda.Event], residual: Optional[torch.Tensor], ) -> Tuple[torch.Tensor, torch.Tensor]: # Self Attention @@ -175,7 +172,6 @@ def forward( hidden_states=hidden_states, kv_cache=kv_cache, input_metadata=input_metadata, - cache_event=cache_event, ) # Fully Connected @@ -214,19 +210,16 @@ def forward( positions: torch.Tensor, kv_caches: List[KVCache], input_metadata: InputMetadata, - cache_events: Optional[List[torch.cuda.Event]], ) -> torch.Tensor: hidden_states = self.embed_tokens(input_ids) residual = None for i in range(len(self.layers)): - cache_event = None if cache_events is None else cache_events[i] layer = self.layers[i] hidden_states, residual = layer( positions, hidden_states, kv_caches[i], input_metadata, - cache_event, residual, ) hidden_states, _ = self.norm(hidden_states, residual) @@ -253,10 +246,9 @@ def forward( positions: torch.Tensor, kv_caches: List[KVCache], input_metadata: InputMetadata, - cache_events: Optional[List[torch.cuda.Event]], ) -> torch.Tensor: hidden_states = self.model(input_ids, positions, kv_caches, - input_metadata, cache_events) + input_metadata) return hidden_states def sample( @@ -289,11 +281,18 @@ def load_weights(self, for (param_name, weight_name, shard_id) in stacked_params_mapping: if weight_name not in name: continue - param = params_dict[name.replace(weight_name, param_name)] + name = name.replace(weight_name, param_name) + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: + continue + param = params_dict[name] weight_loader = param.weight_loader weight_loader(param, loaded_weight, shard_id) break else: + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: + continue param = params_dict[name] weight_loader = getattr(param, "weight_loader", default_weight_loader) diff --git a/vllm/model_executor/models/llama.py b/vllm/model_executor/models/llama.py index 240f3ad57f655..ddae87b07b978 100644 --- a/vllm/model_executor/models/llama.py +++ b/vllm/model_executor/models/llama.py @@ -148,14 +148,12 @@ def forward( hidden_states: torch.Tensor, kv_cache: KVCache, input_metadata: InputMetadata, - cache_event: Optional[torch.cuda.Event], ) -> torch.Tensor: qkv, _ = self.qkv_proj(hidden_states) q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) q, k = self.rotary_emb(positions, q, k) k_cache, v_cache = kv_cache - attn_output = self.attn(q, k, v, k_cache, v_cache, input_metadata, - cache_event) + attn_output = self.attn(q, k, v, k_cache, v_cache, input_metadata) output, _ = self.o_proj(attn_output) return output @@ -199,7 +197,6 @@ def forward( hidden_states: torch.Tensor, kv_cache: KVCache, input_metadata: InputMetadata, - cache_event: Optional[torch.cuda.Event], residual: Optional[torch.Tensor], ) -> Tuple[torch.Tensor, torch.Tensor]: # Self Attention @@ -214,7 +211,6 @@ def forward( hidden_states=hidden_states, kv_cache=kv_cache, input_metadata=input_metadata, - cache_event=cache_event, ) # Fully Connected @@ -256,19 +252,16 @@ def forward( positions: torch.Tensor, kv_caches: List[KVCache], input_metadata: InputMetadata, - cache_events: Optional[List[torch.cuda.Event]], ) -> torch.Tensor: hidden_states = self.embed_tokens(input_ids) residual = None for i in range(len(self.layers)): - cache_event = None if cache_events is None else cache_events[i] layer = self.layers[i] hidden_states, residual = layer( positions, hidden_states, kv_caches[i], input_metadata, - cache_event, residual, ) hidden_states, _ = self.norm(hidden_states, residual) @@ -308,10 +301,9 @@ def forward( positions: torch.Tensor, kv_caches: List[KVCache], input_metadata: InputMetadata, - cache_events: Optional[List[torch.cuda.Event]], ) -> torch.Tensor: hidden_states = self.model(input_ids, positions, kv_caches, - input_metadata, cache_events) + input_metadata) return hidden_states def sample( @@ -341,14 +333,26 @@ def load_weights(self, model_name_or_path, cache_dir, load_format, revision): if "rotary_emb.inv_freq" in name: continue + if ("rotary_emb.cos_cached" in name + or "rotary_emb.sin_cached" in name): + # Models trained using ColossalAI may include these tensors in + # the checkpoint. Skip them. + continue for (param_name, weight_name, shard_id) in stacked_params_mapping: if weight_name not in name: continue - param = params_dict[name.replace(weight_name, param_name)] + name = name.replace(weight_name, param_name) + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: + continue + param = params_dict[name] weight_loader = param.weight_loader weight_loader(param, loaded_weight, shard_id) break else: + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: + continue param = params_dict[name] weight_loader = getattr(param, "weight_loader", default_weight_loader) diff --git a/vllm/model_executor/models/mistral.py b/vllm/model_executor/models/mistral.py index 7e2ee4a721e67..058a219d3f5a6 100644 --- a/vllm/model_executor/models/mistral.py +++ b/vllm/model_executor/models/mistral.py @@ -146,14 +146,12 @@ def forward( hidden_states: torch.Tensor, kv_cache: KVCache, input_metadata: InputMetadata, - cache_event: Optional[torch.cuda.Event], ) -> torch.Tensor: qkv, _ = self.qkv_proj(hidden_states) q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) q, k = self.rotary_emb(positions, q, k) k_cache, v_cache = kv_cache - attn_output = self.attn(q, k, v, k_cache, v_cache, input_metadata, - cache_event) + attn_output = self.attn(q, k, v, k_cache, v_cache, input_metadata) output, _ = self.o_proj(attn_output) return output @@ -194,7 +192,6 @@ def forward( hidden_states: torch.Tensor, kv_cache: KVCache, input_metadata: InputMetadata, - cache_event: Optional[torch.cuda.Event], residual: Optional[torch.Tensor], ) -> Tuple[torch.Tensor, torch.Tensor]: # Self Attention @@ -209,7 +206,6 @@ def forward( hidden_states=hidden_states, kv_cache=kv_cache, input_metadata=input_metadata, - cache_event=cache_event, ) # Fully Connected @@ -252,19 +248,16 @@ def forward( positions: torch.Tensor, kv_caches: List[KVCache], input_metadata: InputMetadata, - cache_events: Optional[List[torch.cuda.Event]], ) -> torch.Tensor: hidden_states = self.embed_tokens(input_ids) residual = None for i in range(len(self.layers)): - cache_event = None if cache_events is None else cache_events[i] layer = self.layers[i] hidden_states, residual = layer( positions, hidden_states, kv_caches[i], input_metadata, - cache_event, residual, ) hidden_states, _ = self.norm(hidden_states, residual) @@ -306,10 +299,9 @@ def forward( positions: torch.Tensor, kv_caches: List[KVCache], input_metadata: InputMetadata, - cache_events: Optional[List[torch.cuda.Event]], ) -> torch.Tensor: hidden_states = self.model(input_ids, positions, kv_caches, - input_metadata, cache_events) + input_metadata) return hidden_states def sample( @@ -342,11 +334,18 @@ def load_weights(self, for (param_name, weight_name, shard_id) in stacked_params_mapping: if weight_name not in name: continue - param = params_dict[name.replace(weight_name, param_name)] + name = name.replace(weight_name, param_name) + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: + continue + param = params_dict[name] weight_loader = param.weight_loader weight_loader(param, loaded_weight, shard_id) break else: + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: + continue param = params_dict[name] weight_loader = getattr(param, "weight_loader", default_weight_loader) diff --git a/vllm/model_executor/models/mixtral.py b/vllm/model_executor/models/mixtral.py new file mode 100644 index 0000000000000..e61b401a78a2b --- /dev/null +++ b/vllm/model_executor/models/mixtral.py @@ -0,0 +1,412 @@ +# coding=utf-8 +# Adapted from +# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/llama/modeling_llama.py +# Copyright 2023 The vLLM team. +# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved. +# +# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX +# and OPT implementations in this library. It has been modified from its +# original forms to accommodate minor architectural differences compared +# to GPT-NeoX and OPT used by the Meta AI team that trained the model. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Inference-only Mixtral model.""" +from typing import List, Optional, Tuple + +import numpy as np + +import torch +import torch.nn.functional as F + +from torch import nn +from transformers import MixtralConfig + +from vllm.model_executor.input_metadata import InputMetadata +from vllm.model_executor.layers.attention import PagedAttention +from vllm.model_executor.layers.layernorm import RMSNorm +from vllm.model_executor.layers.linear import (LinearMethodBase, + ReplicatedLinear, + QKVParallelLinear, + RowParallelLinear) +from vllm.model_executor.layers.rotary_embedding import get_rope +from vllm.model_executor.layers.sampler import Sampler +from vllm.model_executor.layers.vocab_parallel_embedding import ( + VocabParallelEmbedding, ParallelLMHead) +from vllm.model_executor.parallel_utils.communication_op import ( + tensor_model_parallel_all_reduce) +from vllm.model_executor.parallel_utils.parallel_state import ( + get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size) +from vllm.model_executor.sampling_metadata import SamplingMetadata +from vllm.model_executor.weight_utils import (default_weight_loader, + hf_model_weights_iterator) +from vllm.sequence import SamplerOutput + +KVCache = Tuple[torch.Tensor, torch.Tensor] + + +class MixtralMLP(nn.Module): + + def __init__( + self, + num_experts: int, + hidden_size: int, + intermediate_size: int, + linear_method: Optional[LinearMethodBase] = None, + ) -> None: + super().__init__() + self.num_experts = num_experts + self.ffn_dim = intermediate_size + self.hidden_dim = hidden_size + + self.w1 = ReplicatedLinear(self.hidden_dim, + self.ffn_dim, + bias=False, + linear_method=linear_method) + self.w2 = ReplicatedLinear(self.ffn_dim, + self.hidden_dim, + bias=False, + linear_method=linear_method) + self.w3 = ReplicatedLinear(self.hidden_dim, + self.ffn_dim, + bias=False, + linear_method=linear_method) + + # TODO: Use vllm's SiluAndMul + self.act_fn = nn.SiLU() + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + w1_out, _ = self.w1(hidden_states) + w1_out = self.act_fn(w1_out) + w3_out, _ = self.w3(hidden_states) + current_hidden_states = w1_out * w3_out + current_hidden_states, _ = self.w2(current_hidden_states) + return current_hidden_states + + +class MixtralMoE(nn.Module): + + def __init__( + self, + config: MixtralConfig, + linear_method: Optional[LinearMethodBase] = None, + ): + super().__init__() + self.config = config + self.rank = get_tensor_model_parallel_rank() + self.tp_size = get_tensor_model_parallel_world_size() + self.num_total_experts = config.num_local_experts + self.top_k = config.num_experts_per_tok + if self.tp_size > self.num_total_experts: + raise ValueError( + f"Tensor parallel size {self.tp_size} is greater than " + f"the number of experts {self.num_total_experts}.") + # Split experts equally between ranks + self.expert_indicies = np.array_split(range( + self.num_total_experts), self.tp_size)[self.rank].tolist() + if not self.expert_indicies: + raise ValueError( + f"Rank {self.rank} has no experts assigned to it.") + + self.experts = nn.ModuleList([ + MixtralMLP(self.num_total_experts, + config.hidden_size, + config.intermediate_size, + linear_method=linear_method) + if idx in self.expert_indicies else None + for idx in range(self.num_total_experts) + ]) + self.gate = ReplicatedLinear(config.hidden_size, + self.num_total_experts, + bias=False, + linear_method=None) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + batch_size, sequence_length, hidden_dim = hidden_states.shape + hidden_states = hidden_states.view(-1, hidden_dim) + # router_logits: (batch * sequence_length, n_experts) + router_logits, _ = self.gate(hidden_states) + + routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float) + routing_weights, selected_experts = torch.topk(routing_weights, + self.top_k, + dim=-1) + routing_weights /= routing_weights.sum(dim=-1, keepdim=True) + + final_hidden_states = None + for expert_idx in self.expert_indicies: + expert_layer = self.experts[expert_idx] + expert_mask = (selected_experts == expert_idx) + expert_weights = (routing_weights * expert_mask).sum(dim=-1, + keepdim=True) + + current_hidden_states = expert_layer(hidden_states).mul_( + expert_weights) + if final_hidden_states is None: + final_hidden_states = current_hidden_states + else: + final_hidden_states.add_(current_hidden_states) + + return tensor_model_parallel_all_reduce(final_hidden_states).view( + batch_size, sequence_length, hidden_dim) + + +class MixtralAttention(nn.Module): + + def __init__(self, + hidden_size: int, + num_heads: int, + num_kv_heads: int, + max_position: int = 4096 * 32, + rope_theta: float = 10000, + linear_method: Optional[LinearMethodBase] = None, + sliding_window: Optional[int] = None) -> None: + super().__init__() + self.hidden_size = hidden_size + tp_size = get_tensor_model_parallel_world_size() + self.total_num_heads = num_heads + assert self.total_num_heads % tp_size == 0 + self.num_heads = self.total_num_heads // tp_size + self.total_num_kv_heads = num_kv_heads + if self.total_num_kv_heads >= tp_size: + # Number of KV heads is greater than TP size, so we partition + # the KV heads across multiple tensor parallel GPUs. + assert self.total_num_kv_heads % tp_size == 0 + else: + # Number of KV heads is less than TP size, so we replicate + # the KV heads across multiple tensor parallel GPUs. + assert tp_size % self.total_num_kv_heads == 0 + self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size) + self.head_dim = hidden_size // self.total_num_heads + self.q_size = self.num_heads * self.head_dim + self.kv_size = self.num_kv_heads * self.head_dim + self.scaling = self.head_dim**-0.5 + self.rope_theta = rope_theta + self.sliding_window = sliding_window + + self.qkv_proj = QKVParallelLinear( + hidden_size, + self.head_dim, + self.total_num_heads, + self.total_num_kv_heads, + bias=False, + linear_method=linear_method, + ) + self.o_proj = RowParallelLinear( + self.total_num_heads * self.head_dim, + hidden_size, + bias=False, + linear_method=linear_method, + ) + self.rotary_emb = get_rope( + self.head_dim, + rotary_dim=self.head_dim, + max_position=max_position, + base=int(self.rope_theta), + is_neox_style=True, + ) + self.attn = PagedAttention( + self.num_heads, + self.head_dim, + self.scaling, + num_kv_heads=self.num_kv_heads, + sliding_window=self.sliding_window, + ) + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + kv_cache: KVCache, + input_metadata: InputMetadata, + ) -> torch.Tensor: + qkv, _ = self.qkv_proj(hidden_states) + q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) + q, k = self.rotary_emb(positions, q, k) + k_cache, v_cache = kv_cache + attn_output = self.attn(q, k, v, k_cache, v_cache, input_metadata) + output, _ = self.o_proj(attn_output) + return output + + +class MixtralDecoderLayer(nn.Module): + + def __init__( + self, + config: MixtralConfig, + linear_method: Optional[LinearMethodBase] = None, + ) -> None: + super().__init__() + self.hidden_size = config.hidden_size + # Requires transformers > 4.32.0 + rope_theta = getattr(config, "rope_theta", 10000) + self.self_attn = MixtralAttention( + hidden_size=self.hidden_size, + num_heads=config.num_attention_heads, + max_position=config.max_position_embeddings, + num_kv_heads=config.num_key_value_heads, + rope_theta=rope_theta, + sliding_window=config.sliding_window, + linear_method=linear_method) + self.block_sparse_moe = MixtralMoE(config=config, + linear_method=linear_method) + self.input_layernorm = RMSNorm(config.hidden_size, + eps=config.rms_norm_eps) + self.post_attention_layernorm = RMSNorm(config.hidden_size, + eps=config.rms_norm_eps) + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + kv_cache: KVCache, + input_metadata: InputMetadata, + residual: Optional[torch.Tensor], + ) -> torch.Tensor: + # Self Attention + if residual is None: + residual = hidden_states + hidden_states = self.input_layernorm(hidden_states) + else: + hidden_states, residual = self.input_layernorm( + hidden_states, residual) + hidden_states = self.self_attn( + positions=positions, + hidden_states=hidden_states, + kv_cache=kv_cache, + input_metadata=input_metadata, + ) + + # Fully Connected + hidden_states, residual = self.post_attention_layernorm( + hidden_states, residual) + hidden_states = self.block_sparse_moe(hidden_states) + return hidden_states, residual + + +class MixtralModel(nn.Module): + + def __init__( + self, + config: MixtralConfig, + linear_method: Optional[LinearMethodBase] = None, + ) -> None: + super().__init__() + self.padding_idx = config.pad_token_id + self.vocab_size = config.vocab_size + + self.embed_tokens = VocabParallelEmbedding( + config.vocab_size, + config.hidden_size, + ) + self.layers = nn.ModuleList([ + MixtralDecoderLayer(config, linear_method=linear_method) + for _ in range(config.num_hidden_layers) + ]) + self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + kv_caches: List[KVCache], + input_metadata: InputMetadata, + ) -> SamplerOutput: + hidden_states = self.embed_tokens(input_ids) + residual = None + for i in range(len(self.layers)): + layer = self.layers[i] + hidden_states, residual = layer(positions, hidden_states, + kv_caches[i], input_metadata, + residual) + hidden_states, _ = self.norm(hidden_states, residual) + return hidden_states + + +class MixtralForCausalLM(nn.Module): + + def __init__( + self, + config: MixtralConfig, + linear_method: Optional[LinearMethodBase] = None, + ) -> None: + super().__init__() + self.config = config + self.linear_method = linear_method + self.model = MixtralModel(config, linear_method) + self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size) + self.sampler = Sampler(config.vocab_size) + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + kv_caches: List[KVCache], + input_metadata: InputMetadata, + ) -> torch.Tensor: + hidden_states = self.model(input_ids, positions, kv_caches, + input_metadata) + return hidden_states + + def sample( + self, + hidden_states: Optional[torch.Tensor], + sampling_metadata: SamplingMetadata, + ) -> SamplerOutput: + next_tokens = self.sampler(self.lm_head.weight, hidden_states, + sampling_metadata) + return next_tokens + + def load_weights(self, + model_name_or_path: str, + cache_dir: Optional[str] = None, + load_format: str = "auto", + revision: Optional[str] = None): + stacked_params_mapping = [ + # (param_name, shard_name, shard_id) + ("qkv_proj", "q_proj", "q"), + ("qkv_proj", "k_proj", "k"), + ("qkv_proj", "v_proj", "v"), + ] + + params_dict = dict(self.named_parameters()) + for name, loaded_weight in hf_model_weights_iterator( + model_name_or_path, + cache_dir, + load_format, + revision, + fall_back_to_pt=False): + if "rotary_emb.inv_freq" in name: + continue + for (param_name, weight_name, shard_id) in stacked_params_mapping: + if weight_name not in name: + continue + name = name.replace(weight_name, param_name) + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: + continue + param = params_dict[name] + weight_loader = param.weight_loader + weight_loader(param, loaded_weight, shard_id) + break + else: + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: + continue + # Skip experts that are not assigned to this worker. + if ("block_sparse_moe.experts." in name + and name not in params_dict): + continue + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, loaded_weight) diff --git a/vllm/model_executor/models/mpt.py b/vllm/model_executor/models/mpt.py index c7be7a922915f..d6e9a76d2ba42 100644 --- a/vllm/model_executor/models/mpt.py +++ b/vllm/model_executor/models/mpt.py @@ -50,9 +50,14 @@ def __init__( super().__init__() self.d_model = config.d_model self.total_num_heads = config.n_heads + self.head_dim = self.d_model // self.total_num_heads self.clip_qkv = config.attn_config["clip_qkv"] self.qk_ln = config.attn_config["qk_ln"] self.alibi_bias_max = config.attn_config["alibi_bias_max"] + if "kv_n_heads" in config.attn_config: + self.total_num_kv_heads = config.attn_config['kv_n_heads'] + else: + self.total_num_kv_heads = self.total_num_heads assert not config.attn_config["prefix_lm"] assert config.attn_config["alibi"] @@ -61,6 +66,7 @@ def __init__( self.d_model, self.d_model // self.total_num_heads, self.total_num_heads, + self.total_num_kv_heads, bias=not config.no_bias, linear_method=linear_method, ) @@ -78,6 +84,17 @@ def __init__( assert self.total_num_heads % tp_world_size == 0 self.num_heads = self.total_num_heads // tp_world_size + if self.total_num_kv_heads >= tp_world_size: + # Number of KV heads is greater than TP size, so we partition + # the KV heads across multiple tensor parallel GPUs. + assert self.total_num_kv_heads % tp_world_size == 0 + else: + # Number of KV heads is less than TP size, so we replicate + # the KV heads across multiple tensor parallel GPUs. + assert tp_world_size % self.total_num_kv_heads == 0 + self.num_kv_heads = max(1, self.total_num_kv_heads // tp_world_size) + self.q_size = self.num_heads * self.head_dim + self.kv_size = self.num_kv_heads * self.head_dim # Create the alibi slopes and slice them. tp_rank = get_tensor_model_parallel_rank() head_start = tp_rank * self.num_heads @@ -91,7 +108,8 @@ def __init__( self.attn = PagedAttention(self.num_heads, self.head_dim, scaling, - alibi_slopes=alibi_slopes) + alibi_slopes=alibi_slopes, + num_kv_heads=self.num_kv_heads) def forward( self, @@ -99,19 +117,17 @@ def forward( hidden_states: torch.Tensor, kv_cache: KVCache, input_metadata: InputMetadata, - cache_event: Optional[torch.cuda.Event], ) -> torch.Tensor: del position_ids # unused. qkv, _ = self.Wqkv(hidden_states) if self.clip_qkv is not None: qkv.clamp_(min=-self.clip_qkv, max=self.clip_qkv) - q, k, v = qkv.chunk(chunks=3, dim=-1) + q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) if self.qk_ln: q = self.q_ln(q) k = self.k_ln(k) k_cache, v_cache = kv_cache - attn_output = self.attn(q, k, v, k_cache, v_cache, input_metadata, - cache_event) + attn_output = self.attn(q, k, v, k_cache, v_cache, input_metadata) output, _ = self.out_proj(attn_output) return output @@ -169,7 +185,6 @@ def forward( hidden_states: torch.Tensor, kv_cache: KVCache, input_metadata: InputMetadata, - cache_event: Optional[torch.cuda.Event], ) -> torch.Tensor: x = self.norm_1(hidden_states) x = self.attn( @@ -177,7 +192,6 @@ def forward( hidden_states=x, kv_cache=kv_cache, input_metadata=input_metadata, - cache_event=cache_event, ) hidden_states = hidden_states + x x = self.norm_2(hidden_states) @@ -217,18 +231,15 @@ def forward( position_ids: torch.Tensor, kv_caches: List[KVCache], input_metadata: InputMetadata, - cache_events: Optional[List[torch.cuda.Event]], ) -> torch.Tensor: hidden_states = self.wte(input_ids) for i in range(len(self.blocks)): - cache_event = None if cache_events is None else cache_events[i] block = self.blocks[i] hidden_states = block( position_ids, hidden_states, kv_caches[i], input_metadata, - cache_event, ) hidden_states = self.norm_f(hidden_states) return hidden_states @@ -256,10 +267,9 @@ def forward( positions: torch.Tensor, kv_caches: List[KVCache], input_metadata: InputMetadata, - cache_events: Optional[List[torch.cuda.Event]], ) -> torch.Tensor: hidden_states = self.transformer(input_ids, positions, kv_caches, - input_metadata, cache_events) + input_metadata) return hidden_states def sample( @@ -279,6 +289,9 @@ def load_weights(self, params_dict = dict(self.named_parameters(remove_duplicate=False)) for name, loaded_weight in hf_model_weights_iterator( model_name_or_path, cache_dir, load_format, revision): + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: + continue param = params_dict[name] weight_loader = getattr(param, "weight_loader", default_weight_loader) diff --git a/vllm/model_executor/models/opt.py b/vllm/model_executor/models/opt.py index 1c698c20f35db..22d3b5ccadfde 100644 --- a/vllm/model_executor/models/opt.py +++ b/vllm/model_executor/models/opt.py @@ -98,13 +98,12 @@ def forward( hidden_states: torch.Tensor, kv_cache: KVCache, input_metadata: InputMetadata, - cache_event: Optional[torch.cuda.Event], ) -> torch.Tensor: qkv, _ = self.qkv_proj(hidden_states) q, k, v = qkv.chunk(chunks=3, dim=-1) key_cache, value_cache = kv_cache attn_output = self.attn(q, k, v, key_cache, value_cache, - input_metadata, cache_event) + input_metadata) output, _ = self.out_proj(attn_output) return output @@ -154,7 +153,6 @@ def forward( hidden_states: torch.Tensor, kv_cache: KVCache, input_metadata: InputMetadata, - cache_event: Optional[torch.cuda.Event], ) -> torch.Tensor: # Self Attention residual = hidden_states @@ -163,8 +161,7 @@ def forward( hidden_states = self.self_attn_layer_norm(hidden_states) hidden_states = self.self_attn(hidden_states=hidden_states, kv_cache=kv_cache, - input_metadata=input_metadata, - cache_event=cache_event) + input_metadata=input_metadata) hidden_states = residual + hidden_states # 350m applies layer norm AFTER attention if not self.do_layer_norm_before: @@ -245,7 +242,6 @@ def forward( positions: torch.Tensor, kv_caches: List[KVCache], input_metadata: InputMetadata, - cache_events: Optional[List[torch.cuda.Event]], ) -> torch.Tensor: inputs_embeds = self.embed_tokens(input_ids) pos_embeds = self.embed_positions(positions) @@ -254,10 +250,8 @@ def forward( hidden_states = inputs_embeds + pos_embeds for i in range(len(self.layers)): - cache_event = None if cache_events is None else cache_events[i] layer = self.layers[i] - hidden_states = layer(hidden_states, kv_caches[i], input_metadata, - cache_event) + hidden_states = layer(hidden_states, kv_caches[i], input_metadata) if self.final_layer_norm is not None: hidden_states = self.final_layer_norm(hidden_states) @@ -282,10 +276,8 @@ def forward( positions: torch.Tensor, kv_caches: List[KVCache], input_metadata: InputMetadata, - cache_events: Optional[List[torch.cuda.Event]], ) -> torch.Tensor: - return self.decoder(input_ids, positions, kv_caches, input_metadata, - cache_events) + return self.decoder(input_ids, positions, kv_caches, input_metadata) class OPTForCausalLM(nn.Module): @@ -308,10 +300,9 @@ def forward( positions: torch.Tensor, kv_caches: List[KVCache], input_metadata: InputMetadata, - cache_events: Optional[List[torch.cuda.Event]], ) -> torch.Tensor: hidden_states = self.model(input_ids, positions, kv_caches, - input_metadata, cache_events) + input_metadata) return hidden_states def sample( @@ -345,11 +336,18 @@ def load_weights(self, for (param_name, weight_name, shard_id) in stacked_params_mapping: if weight_name not in name: continue - param = params_dict[name.replace(weight_name, param_name)] + name = name.replace(weight_name, param_name) + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: + continue + param = params_dict[name] weight_loader = param.weight_loader weight_loader(param, loaded_weight, shard_id) break else: + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: + continue param = params_dict[name] weight_loader = getattr(param, "weight_loader", default_weight_loader) diff --git a/vllm/model_executor/models/phi_1_5.py b/vllm/model_executor/models/phi_1_5.py index ac441e476bb82..9f3c6f68d24e5 100644 --- a/vllm/model_executor/models/phi_1_5.py +++ b/vllm/model_executor/models/phi_1_5.py @@ -135,14 +135,12 @@ def forward( hidden_states: torch.Tensor, kv_cache: KVCache, input_metadata: InputMetadata, - cache_event: Optional[torch.cuda.Event], ) -> torch.Tensor: qkv, _ = self.Wqkv(hidden_states) q, k, v = qkv.chunk(chunks=3, dim=-1) q, k = self.rotary_emb(position_ids, q, k) k_cache, v_cache = kv_cache - attn_output = self.attn(q, k, v, k_cache, v_cache, input_metadata, - cache_event) + attn_output = self.attn(q, k, v, k_cache, v_cache, input_metadata) output, _ = self.out_proj(attn_output) return output @@ -195,7 +193,6 @@ def forward( hidden_states: torch.Tensor, kv_cache: KVCache, input_metadata: InputMetadata, - cache_event: Optional[torch.cuda.Event], ) -> torch.Tensor: residual = hidden_states hidden_states = self.ln(hidden_states) @@ -204,7 +201,6 @@ def forward( hidden_states=hidden_states, kv_cache=kv_cache, input_metadata=input_metadata, - cache_event=cache_event, ) feed_forward_hidden_states = self.mlp(hidden_states) hidden_states = attn_outputs + feed_forward_hidden_states + residual @@ -231,18 +227,15 @@ def forward( positions: torch.Tensor, kv_caches: List[KVCache], input_metadata: InputMetadata, - cache_events: Optional[List[torch.cuda.Event]], ) -> torch.Tensor: hidden_states = self.embd(input_ids) for i in range(self.config.num_hidden_layers): - cache_event = None if cache_events is None else cache_events[i] layer = self.h[i] hidden_states = layer( positions, hidden_states, kv_caches[i], input_metadata, - cache_event, ) return hidden_states @@ -277,10 +270,9 @@ def forward( positions: torch.Tensor, kv_caches: List[KVCache], input_metadata: InputMetadata, - cache_events: Optional[List[torch.cuda.Event]], ) -> torch.Tensor: hidden_states = self.transformer(input_ids, positions, kv_caches, - input_metadata, cache_events) + input_metadata) hidden_states = self.lm_head.ln(hidden_states) return hidden_states @@ -305,6 +297,9 @@ def load_weights(self, if "rotary_emb.inv_freq" in name: continue + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: + continue # pylint: disable=E1136 param = params_dict[name] weight_loader = getattr(param, "weight_loader", diff --git a/vllm/model_executor/models/qwen.py b/vllm/model_executor/models/qwen.py index 33bae61f6016d..2d394a6b914c5 100644 --- a/vllm/model_executor/models/qwen.py +++ b/vllm/model_executor/models/qwen.py @@ -82,7 +82,6 @@ def __init__( self.num_heads = (self.total_num_heads // tensor_model_parallel_world_size) self.head_dim = hidden_size // self.total_num_heads - self.c_attn = QKVParallelLinear( hidden_size, self.head_dim, @@ -113,14 +112,12 @@ def forward( hidden_states: torch.Tensor, kv_cache: KVCache, input_metadata: InputMetadata, - cache_event: Optional[torch.cuda.Event], ) -> torch.Tensor: qkv, _ = self.c_attn(hidden_states) q, k, v = qkv.chunk(chunks=3, dim=-1) q, k = self.rotary_emb(positions, q, k) k_cache, v_cache = kv_cache - attn_output = self.attn(q, k, v, k_cache, v_cache, input_metadata, - cache_event) + attn_output = self.attn(q, k, v, k_cache, v_cache, input_metadata) output, _ = self.c_proj(attn_output) return output @@ -157,7 +154,6 @@ def forward( hidden_states: torch.Tensor, kv_cache: KVCache, input_metadata: InputMetadata, - cache_event: Optional[torch.cuda.Event], residual: Optional[torch.Tensor], ) -> Tuple[torch.Tensor, torch.Tensor]: # Self Attention @@ -171,7 +167,6 @@ def forward( hidden_states=hidden_states, kv_cache=kv_cache, input_metadata=input_metadata, - cache_event=cache_event, ) # Fully Connected @@ -207,19 +202,16 @@ def forward( positions: torch.Tensor, kv_caches: List[KVCache], input_metadata: InputMetadata, - cache_events: Optional[List[torch.cuda.Event]], ) -> torch.Tensor: hidden_states = self.wte(input_ids) residual = None for i in range(len(self.h)): - cache_event = None if cache_events is None else cache_events[i] layer = self.h[i] hidden_states, residual = layer( positions, hidden_states, kv_caches[i], input_metadata, - cache_event, residual, ) hidden_states, _ = self.ln_f(hidden_states, residual) @@ -246,10 +238,9 @@ def forward( positions: torch.Tensor, kv_caches: List[KVCache], input_metadata: InputMetadata, - cache_events: Optional[List[torch.cuda.Event]], ) -> torch.Tensor: hidden_states = self.transformer(input_ids, positions, kv_caches, - input_metadata, cache_events) + input_metadata) return hidden_states def sample( @@ -279,11 +270,18 @@ def load_weights(self, for (param_name, weight_name, shard_id) in stacked_params_mapping: if weight_name not in name: continue - param = params_dict[name.replace(weight_name, param_name)] + name = name.replace(weight_name, param_name) + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: + continue + param = params_dict[name] weight_loader = param.weight_loader weight_loader(param, loaded_weight, shard_id) break else: + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: + continue param = params_dict[name] weight_loader = getattr(param, "weight_loader", default_weight_loader) diff --git a/vllm/model_executor/models/yi.py b/vllm/model_executor/models/yi.py index 889cc3f0b5fcb..78715a8873fce 100644 --- a/vllm/model_executor/models/yi.py +++ b/vllm/model_executor/models/yi.py @@ -146,14 +146,12 @@ def forward( hidden_states: torch.Tensor, kv_cache: KVCache, input_metadata: InputMetadata, - cache_event: Optional[torch.cuda.Event], ) -> torch.Tensor: qkv, _ = self.qkv_proj(hidden_states) q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) q, k = self.rotary_emb(positions, q, k) k_cache, v_cache = kv_cache - attn_output = self.attn(q, k, v, k_cache, v_cache, input_metadata, - cache_event) + attn_output = self.attn(q, k, v, k_cache, v_cache, input_metadata) output, _ = self.o_proj(attn_output) return output @@ -195,7 +193,6 @@ def forward( hidden_states: torch.Tensor, kv_cache: KVCache, input_metadata: InputMetadata, - cache_event: Optional[torch.cuda.Event], residual: Optional[torch.Tensor], ) -> Tuple[torch.Tensor, torch.Tensor]: # Self Attention @@ -209,7 +206,6 @@ def forward( hidden_states=hidden_states, kv_cache=kv_cache, input_metadata=input_metadata, - cache_event=cache_event, ) # Fully Connected @@ -245,19 +241,16 @@ def forward( positions: torch.Tensor, kv_caches: List[KVCache], input_metadata: InputMetadata, - cache_events: Optional[List[torch.cuda.Event]], ) -> torch.Tensor: hidden_states = self.embed_tokens(input_ids) residual = None for i in range(len(self.layers)): - cache_event = None if cache_events is None else cache_events[i] layer = self.layers[i] hidden_states, residual = layer( positions, hidden_states, kv_caches[i], input_metadata, - cache_event, residual, ) hidden_states, _ = self.norm(hidden_states, residual) @@ -284,10 +277,9 @@ def forward( positions: torch.Tensor, kv_caches: List[KVCache], input_metadata: InputMetadata, - cache_events: Optional[List[torch.cuda.Event]], ) -> torch.Tensor: hidden_states = self.model(input_ids, positions, kv_caches, - input_metadata, cache_events) + input_metadata) return hidden_states def sample( @@ -320,11 +312,18 @@ def load_weights(self, for (param_name, weight_name, shard_id) in stacked_params_mapping: if weight_name not in name: continue - param = params_dict[name.replace(weight_name, param_name)] + name = name.replace(weight_name, param_name) + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: + continue + param = params_dict[name] weight_loader = param.weight_loader weight_loader(param, loaded_weight, shard_id) break else: + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: + continue param = params_dict[name] weight_loader = getattr(param, "weight_loader", default_weight_loader) diff --git a/vllm/model_executor/sampling_metadata.py b/vllm/model_executor/sampling_metadata.py index deb779f537c69..49013ec273787 100644 --- a/vllm/model_executor/sampling_metadata.py +++ b/vllm/model_executor/sampling_metadata.py @@ -1,9 +1,13 @@ +from dataclasses import dataclass from typing import Dict, List, Tuple import torch from vllm.sampling_params import SamplingParams, SamplingType from vllm.sequence import SequenceData +from vllm.utils import in_wsl + +_SAMPLING_EPS = 1e-5 class SamplingMetadata: @@ -41,3 +45,186 @@ def __repr__(self) -> str: f"prompt_lens={self.prompt_lens}, " f"selected_token_indices={self.selected_token_indices}, " f"categorized_sample_indices={self.categorized_sample_indices})") + + +@dataclass +class SamplingTensors: + """Tensors for sampling.""" + + temperatures: torch.Tensor + top_ps: torch.Tensor + top_ks: torch.Tensor + min_ps: torch.Tensor + presence_penalties: torch.Tensor + frequency_penalties: torch.Tensor + repetition_penalties: torch.Tensor + prompt_tokens: torch.Tensor + output_tokens: torch.Tensor + + @classmethod + def from_sampling_metadata( + cls, sampling_metadata: "SamplingMetadata", vocab_size: int, + device: torch.device, + dtype: torch.dtype) -> Tuple["SamplingTensors", bool, bool, bool]: + prompt_tokens: List[List[int]] = [] + output_tokens: List[List[int]] = [] + top_ks: List[int] = [] + temperatures: List[float] = [] + top_ps: List[float] = [] + min_ps: List[float] = [] + presence_penalties: List[float] = [] + frequency_penalties: List[float] = [] + repetition_penalties: List[float] = [] + do_penalties = False + do_top_p_top_k = False + do_min_p = False + for i, seq_group in enumerate(sampling_metadata.seq_groups): + seq_ids, sampling_params = seq_group + temperature = sampling_params.temperature + p = sampling_params.presence_penalty + f = sampling_params.frequency_penalty + r = sampling_params.repetition_penalty + top_p = sampling_params.top_p + min_p = sampling_params.min_p + # k should not be greater than the vocab size. + top_k = min(sampling_params.top_k, vocab_size) + top_k = vocab_size if top_k == -1 else top_k + if temperature < _SAMPLING_EPS: + # NOTE: Zero temperature means deterministic sampling + # (i.e., greedy sampling or beam search). + # Set the temperature to 1 to avoid division by zero. + temperature = 1.0 + if not do_top_p_top_k and (top_p < 1.0 - _SAMPLING_EPS + or top_k != vocab_size): + do_top_p_top_k = True + if not do_min_p and min_p > _SAMPLING_EPS: + do_min_p = True + if not do_penalties and (abs(p) >= _SAMPLING_EPS + or abs(f) >= _SAMPLING_EPS + or abs(r - 1.0) >= _SAMPLING_EPS): + do_penalties = True + if (i < sampling_metadata.num_prompts + and sampling_params.prompt_logprobs is not None): + # For tokens in the prompt that we only need to get their logprobs + prompt_len = sampling_metadata.prompt_lens[i] + temperatures += [temperature] * (prompt_len - 1) + top_ps += [top_p] * (prompt_len - 1) + top_ks += [top_k] * (prompt_len - 1) + min_ps += [min_p] * (prompt_len - 1) + presence_penalties += [0] * (prompt_len - 1) + frequency_penalties += [0] * (prompt_len - 1) + repetition_penalties += [1] * (prompt_len - 1) + prompt_tokens.extend([] for _ in range(prompt_len - 1)) + output_tokens.extend([] for _ in range(prompt_len - 1)) + for seq_id in seq_ids: + seq_data = sampling_metadata.seq_data[seq_id] + prompt_tokens.append(seq_data.prompt_token_ids) + output_tokens.append(seq_data.output_token_ids) + temperatures += [temperature] * len(seq_ids) + top_ps += [top_p] * len(seq_ids) + top_ks += [top_k] * len(seq_ids) + min_ps += [min_p] * len(seq_ids) + presence_penalties += [p] * len(seq_ids) + frequency_penalties += [f] * len(seq_ids) + repetition_penalties += [r] * len(seq_ids) + + sampling_tensors = SamplingTensors.from_lists( + temperatures, top_ps, top_ks, min_ps, presence_penalties, + frequency_penalties, repetition_penalties, prompt_tokens, + output_tokens, vocab_size, device, dtype) + return (sampling_tensors, do_penalties, do_top_p_top_k, do_min_p) + + @classmethod + def from_lists(cls, temperatures: List[float], top_ps: List[float], + top_ks: List[int], min_ps: List[float], + presence_penalties: List[float], + frequency_penalties: List[float], + repetition_penalties: List[float], + prompt_tokens: List[List[int]], + output_tokens: List[List[int]], vocab_size: int, + device: torch.device, + dtype: torch.dtype) -> "SamplingTensors": + # Note that the performance will be very bad without + # pinned memory. + pin_memory = not in_wsl() + prompt_max_len = max(len(tokens) for tokens in prompt_tokens) + prompt_padded_tokens = [ + tokens + [vocab_size] * (prompt_max_len - len(tokens)) + for tokens in prompt_tokens + ] + output_max_len = max(len(tokens) for tokens in output_tokens) + output_padded_tokens = [ + tokens + [vocab_size] * (output_max_len - len(tokens)) + for tokens in output_tokens + ] + + temperatures_t = torch.tensor( + temperatures, + device="cpu", + dtype=dtype, + pin_memory=pin_memory, + ) + top_ps_t = torch.tensor( + top_ps, + device="cpu", + dtype=dtype, + pin_memory=pin_memory, + ) + min_ps_t = torch.tensor( + min_ps, + device="cpu", + dtype=dtype, + pin_memory=pin_memory, + ) + presence_penalties_t = torch.tensor( + presence_penalties, + device="cpu", + dtype=dtype, + pin_memory=pin_memory, + ) + frequency_penalties_t = torch.tensor( + frequency_penalties, + device="cpu", + dtype=dtype, + pin_memory=pin_memory, + ) + repetition_penalties_t = torch.tensor( + repetition_penalties, + device="cpu", + dtype=dtype, + pin_memory=pin_memory, + ) + top_ks_t = torch.tensor( + top_ks, + device="cpu", + dtype=torch.int, + pin_memory=pin_memory, + ) + prompt_tensor = torch.tensor( + prompt_padded_tokens, + device="cpu", + dtype=torch.long, + pin_memory=pin_memory, + ) + output_tensor = torch.tensor( + output_padded_tokens, + device="cpu", + dtype=torch.long, + pin_memory=pin_memory, + ) + # Because the memory is pinned, we can do non-blocking + # transfer to device. + return cls( + temperatures=temperatures_t.to(device=device, non_blocking=True), + top_ps=top_ps_t.to(device=device, non_blocking=True), + top_ks=top_ks_t.to(device=device, non_blocking=True), + min_ps=min_ps_t.to(device=device, non_blocking=True), + presence_penalties=presence_penalties_t.to(device=device, + non_blocking=True), + frequency_penalties=frequency_penalties_t.to(device=device, + non_blocking=True), + repetition_penalties=repetition_penalties_t.to(device=device, + non_blocking=True), + prompt_tokens=prompt_tensor.to(device=device, non_blocking=True), + output_tokens=output_tensor.to(device=device, non_blocking=True), + ) diff --git a/vllm/model_executor/weight_utils.py b/vllm/model_executor/weight_utils.py index 825b9016288c6..365c847a435fe 100644 --- a/vllm/model_executor/weight_utils.py +++ b/vllm/model_executor/weight_utils.py @@ -125,15 +125,29 @@ def get_quant_config( def prepare_hf_model_weights( model_name_or_path: str, cache_dir: Optional[str] = None, - use_safetensors: bool = False, + load_format: str = "auto", fall_back_to_pt: bool = True, revision: Optional[str] = None, ) -> Tuple[str, List[str], bool]: # Download model weights from huggingface. is_local = os.path.isdir(model_name_or_path) + use_safetensors = False # Some quantized models use .pt files for storing the weights. - allow_patterns = ["*.safetensors" - ] if use_safetensors else ["*.bin", "*.pt"] + if load_format == "auto": + allow_patterns = ["*.safetensors", "*.bin"] + elif load_format == "safetensors": + use_safetensors = True + allow_patterns = ["*.safetensors"] + elif load_format == "pt": + allow_patterns = ["*.pt"] + elif load_format == "npcache": + allow_patterns = ["*.bin"] + else: + raise ValueError(f"Unknown load_format: {load_format}") + + if fall_back_to_pt: + allow_patterns += ["*.pt"] + if not is_local: # Use file lock to prevent multiple processes from # downloading the same model weights at the same time. @@ -148,6 +162,10 @@ def prepare_hf_model_weights( hf_weights_files: List[str] = [] for pattern in allow_patterns: hf_weights_files += glob.glob(os.path.join(hf_folder, pattern)) + if len(hf_weights_files) > 0: + if pattern == "*.safetensors": + use_safetensors = True + break if not use_safetensors: # Exclude files that are not needed for inference. # https://github.com/huggingface/transformers/blob/v4.34.0/src/transformers/trainer.py#L227-L233 @@ -163,13 +181,6 @@ def prepare_hf_model_weights( if not any(f.endswith(x) for x in blacklist) ] - if len(hf_weights_files) == 0 and use_safetensors and fall_back_to_pt: - return prepare_hf_model_weights(model_name_or_path, - cache_dir=cache_dir, - use_safetensors=False, - fall_back_to_pt=False, - revision=revision) - if len(hf_weights_files) == 0: raise RuntimeError( f"Cannot find any model weights with `{model_name_or_path}`") @@ -182,30 +193,16 @@ def hf_model_weights_iterator( cache_dir: Optional[str] = None, load_format: str = "auto", revision: Optional[str] = None, + fall_back_to_pt: Optional[bool] = True, ) -> Iterator[Tuple[str, torch.Tensor]]: - use_safetensors = False - use_np_cache = False - fall_back_to_pt = False - if load_format == "auto": - use_safetensors = True - fall_back_to_pt = True - elif load_format == "safetensors": - use_safetensors = True - elif load_format == "pt": - pass - elif load_format == "npcache": - use_np_cache = True - else: - raise ValueError(f"Unknown load_format: {load_format}") - hf_folder, hf_weights_files, use_safetensors = prepare_hf_model_weights( model_name_or_path, cache_dir=cache_dir, - use_safetensors=use_safetensors, + load_format=load_format, fall_back_to_pt=fall_back_to_pt, revision=revision) - if use_np_cache: + if load_format == "npcache": # Currently np_cache only support *.bin checkpoints assert use_safetensors is False @@ -287,4 +284,5 @@ def initialize_dummy_weights( values between -1e-3 and 1e-3 works well for most models. """ for param in model.state_dict().values(): - param.data.uniform_(low, high) + if torch.is_floating_point(param): + param.data.uniform_(low, high) diff --git a/vllm/sampling_params.py b/vllm/sampling_params.py index 38b7c0b531bd2..30a8036a63fc9 100644 --- a/vllm/sampling_params.py +++ b/vllm/sampling_params.py @@ -2,6 +2,7 @@ from enum import IntEnum from functools import cached_property from typing import Callable, List, Optional, Union + import torch _SAMPLING_EPS = 1e-5 @@ -70,6 +71,8 @@ class SamplingParams: stop_token_ids: List of tokens that stop the generation when they are generated. The returned output will contain the stop tokens unless the stop tokens are special tokens. + include_stop_str_in_output: Whether to include the stop strings in output + text. Defaults to False. ignore_eos: Whether to ignore the EOS token and continue generating tokens after the EOS token is generated. max_tokens: Maximum number of tokens to generate per output sequence. @@ -103,6 +106,7 @@ def __init__( early_stopping: Union[bool, str] = False, stop: Optional[Union[str, List[str]]] = None, stop_token_ids: Optional[List[int]] = None, + include_stop_str_in_output: bool = False, ignore_eos: bool = False, max_tokens: int = 16, logprobs: Optional[int] = None, @@ -140,6 +144,7 @@ def __init__( self.skip_special_tokens = skip_special_tokens self.spaces_between_special_tokens = spaces_between_special_tokens self.logits_processors = logits_processors + self.include_stop_str_in_output = include_stop_str_in_output self._verify_args() if self.use_beam_search: self._verify_beam_search() @@ -227,24 +232,26 @@ def sampling_type(self) -> SamplingType: return SamplingType.RANDOM def __repr__(self) -> str: - return (f"SamplingParams(n={self.n}, " - f"best_of={self.best_of}, " - f"presence_penalty={self.presence_penalty}, " - f"frequency_penalty={self.frequency_penalty}, " - f"repetition_penalty={self.repetition_penalty}, " - f"temperature={self.temperature}, " - f"top_p={self.top_p}, " - f"top_k={self.top_k}, " - f"min_p={self.min_p}, " - f"use_beam_search={self.use_beam_search}, " - f"length_penalty={self.length_penalty}, " - f"early_stopping={self.early_stopping}, " - f"stop={self.stop}, " - f"stop_token_ids={self.stop_token_ids}, " - f"ignore_eos={self.ignore_eos}, " - f"max_tokens={self.max_tokens}, " - f"logprobs={self.logprobs}, " - f"prompt_logprobs={self.prompt_logprobs}, " - f"skip_special_tokens={self.skip_special_tokens}, " - "spaces_between_special_tokens=" - f"{self.spaces_between_special_tokens})") + return ( + f"SamplingParams(n={self.n}, " + f"best_of={self.best_of}, " + f"presence_penalty={self.presence_penalty}, " + f"frequency_penalty={self.frequency_penalty}, " + f"repetition_penalty={self.repetition_penalty}, " + f"temperature={self.temperature}, " + f"top_p={self.top_p}, " + f"top_k={self.top_k}, " + f"min_p={self.min_p}, " + f"use_beam_search={self.use_beam_search}, " + f"length_penalty={self.length_penalty}, " + f"early_stopping={self.early_stopping}, " + f"stop={self.stop}, " + f"stop_token_ids={self.stop_token_ids}, " + f"include_stop_str_in_output={self.include_stop_str_in_output}, " + f"ignore_eos={self.ignore_eos}, " + f"max_tokens={self.max_tokens}, " + f"logprobs={self.logprobs}, " + f"prompt_logprobs={self.prompt_logprobs}, " + f"skip_special_tokens={self.skip_special_tokens}, " + "spaces_between_special_tokens=" + f"{self.spaces_between_special_tokens})") diff --git a/vllm/transformers_utils/tokenizer.py b/vllm/transformers_utils/tokenizer.py index 695cb893e430e..532c7a4e6c1dc 100644 --- a/vllm/transformers_utils/tokenizer.py +++ b/vllm/transformers_utils/tokenizer.py @@ -10,9 +10,6 @@ logger = init_logger(__name__) -# A fast LLaMA tokenizer with the pre-processed `tokenizer.json` file. -_FAST_LLAMA_TOKENIZER = "hf-internal-testing/llama-tokenizer" - def get_tokenizer( tokenizer_name: str, @@ -29,13 +26,6 @@ def get_tokenizer( "Cannot use the fast tokenizer in slow tokenizer mode.") kwargs["use_fast"] = False - if ("llama" in tokenizer_name.lower() and kwargs.get("use_fast", True) - and tokenizer_name != _FAST_LLAMA_TOKENIZER): - logger.info( - "For some LLaMA V1 models, initializing the fast tokenizer may " - "take a long time. To reduce the initialization time, consider " - f"using '{_FAST_LLAMA_TOKENIZER}' instead of the original " - "tokenizer.") try: tokenizer = AutoTokenizer.from_pretrained( tokenizer_name, @@ -43,13 +33,6 @@ def get_tokenizer( trust_remote_code=trust_remote_code, tokenizer_revision=tokenizer_revision, **kwargs) - except TypeError as e: - # The LLaMA tokenizer causes a protobuf error in some environments. - err_msg = ( - "Failed to load the tokenizer. If you are using a LLaMA V1 model " - f"consider using '{_FAST_LLAMA_TOKENIZER}' instead of the " - "original tokenizer.") - raise RuntimeError(err_msg) from e except ValueError as e: # If the error pertains to the tokenizer class not existing or not # currently being imported, suggest using the --trust-remote-code flag. diff --git a/vllm/utils.py b/vllm/utils.py index 8f88e66ecbac0..2d8f3c86bafc7 100644 --- a/vllm/utils.py +++ b/vllm/utils.py @@ -1,4 +1,5 @@ import enum +import socket import uuid from platform import uname @@ -114,11 +115,6 @@ def get_max_shared_memory_bytes(gpu: int = 0) -> int: return int(max_shared_mem) -def get_gpu_memory(gpu: int = 0) -> int: - """Returns the total memory of the GPU in bytes.""" - return torch.cuda.get_device_properties(gpu).total_memory - - def get_cpu_memory() -> int: """Returns the total CPU memory of the node in bytes.""" return psutil.virtual_memory().total @@ -147,3 +143,9 @@ def _async_wrapper(*args, **kwargs) -> asyncio.Future: return loop.run_in_executor(executor=None, func=p_func) return _async_wrapper + + +def get_open_port(): + with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: + s.bind(("", 0)) + return s.getsockname()[1] diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index 76daaac15904a..662b3a809d33f 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -1,6 +1,9 @@ -from typing import Dict, List, Optional, Tuple, Set +import time +from typing import Dict, List, Tuple, Union, Optional, Set +import numpy as np import torch +import torch.nn as nn from vllm.config import ModelConfig, LoRAConfig, ParallelConfig, SchedulerConfig from vllm.logger import init_logger @@ -13,10 +16,16 @@ ) from vllm.lora.layers import LoRAMapping from vllm.lora.request import LoRARequest +from vllm.utils import in_wsl logger = init_logger(__name__) +KVCache = Tuple[torch.Tensor, torch.Tensor] _PAD_SLOT_ID = -1 +# Capture graphs for batch size 1, 2, 4, 8, 16, 24, 32, 40, ..., 256. +# NOTE: _get_graph_batch_size needs to be updated if this list is changed. +_BATCH_SIZES_TO_CAPTURE = [1, 2, 4] + [8 * i for i in range(1, 33)] + LORA_WARMUP_RANK = 8 @@ -43,6 +52,22 @@ def __init__( self.block_size = None # Set after initial profiling. self.lora_manager = None + self.graph_runners: Dict[int, CUDAGraphRunner] = {} + self.graph_memory_pool = None # Set during graph capture. + + self.max_context_len_to_capture = ( + self.model_config.max_context_len_to_capture + if self.model_config is not None else 0) + # When using CUDA graph, the input block tables must be padded to + # max_context_len_to_capture. However, creating the block table in + # Python can be expensive. To optimize this, we cache the block table + # in numpy and only copy the actual input content at every iteration. + # The shape of the cached block table will be + # (max batch size to capture, max context len to capture / block size). + self.graph_block_tables = None # Set after initial profiling. + # cache in_wsl result + self.in_wsl = in_wsl() + def load_model(self) -> None: self.model = get_model(self.model_config, self.lora_config) @@ -63,6 +88,11 @@ def load_model(self) -> None: def set_block_size(self, block_size: int) -> None: self.block_size = block_size + max_num_blocks = (self.max_context_len_to_capture + block_size - + 1) // block_size + self.graph_block_tables = np.zeros( + (max(_BATCH_SIZES_TO_CAPTURE), max_num_blocks), dtype=np.int32) + def _prepare_prompt( self, seq_group_metadata_list: List[SequenceGroupMetadata], @@ -149,6 +179,7 @@ def _prepare_prompt( max_context_len=None, context_lens=None, block_tables=None, + use_cuda_graph=False, ) return input_tokens, input_positions, input_metadata, lora_index_mapping, lora_prompt_mapping @@ -176,14 +207,14 @@ def _prepare_decode( generation_token = seq_data.get_last_token_id() input_tokens.append([generation_token]) - context_len = seq_data.get_len() - if self.sliding_window is not None: - context_len = min(context_len, self.sliding_window) - context_lens.append(context_len) - - position = context_len - 1 + seq_len = seq_data.get_len() + position = seq_len - 1 input_positions.append([position]) + context_len = seq_len if self.sliding_window is None else min( + seq_len, self.sliding_window) + context_lens.append(context_len) + block_table = seq_group_metadata.block_tables[seq_id] block_number = block_table[position // self.block_size] block_offset = position % self.block_size @@ -198,22 +229,52 @@ def _prepare_decode( block_table = block_table[-sliding_window_blocks:] block_tables.append(block_table) + batch_size = len(input_tokens) + max_context_len = max(context_lens) + use_captured_graph = ( + not self.model_config.enforce_eager + and batch_size <= _BATCH_SIZES_TO_CAPTURE[-1] + and max_context_len <= self.max_context_len_to_capture) + if use_captured_graph: + # Pad the input tokens, positions, and slot mapping to match the + # batch size of the captured graph. + graph_batch_size = _get_graph_batch_size(batch_size) + assert graph_batch_size >= batch_size + for _ in range(graph_batch_size - batch_size): + input_tokens.append([]) + input_positions.append([]) + slot_mapping.append([]) + context_lens.append(1) + block_tables.append([]) + batch_size = graph_batch_size + + # When using CUDA graph, we don't need to make the tensors on the GPU + # because they will be eventually copied to the designated GPU buffer. + device = "cpu" if use_captured_graph else "cuda" + pin_memory = use_captured_graph and not self.in_wsl input_tokens = _make_tensor_with_pad(input_tokens, max_len=1, pad=0, - dtype=torch.long) + dtype=torch.long, + device=device, + pin_memory=pin_memory) input_positions = _make_tensor_with_pad(input_positions, max_len=1, pad=0, - dtype=torch.long) + dtype=torch.long, + device=device, + pin_memory=pin_memory) slot_mapping = _make_tensor_with_pad(slot_mapping, max_len=1, pad=_PAD_SLOT_ID, - dtype=torch.long) - max_context_len = max(context_lens) + dtype=torch.long, + device=device, + pin_memory=pin_memory) context_lens = torch.tensor(context_lens, dtype=torch.int, - device="cuda") + device=device, + pin_memory=pin_memory) + max_block_table_len = max([len(t) for t in block_tables]) block_tables = _make_tensor_with_pad(block_tables, max_len=max_block_table_len, @@ -222,12 +283,30 @@ def _prepare_decode( lora_index_mapping = [ _pad_to_max(mapping, 1, pad=0) for mapping in lora_index_mapping ] + + if use_captured_graph: + # The shape of graph_block_tables is + # [max batch size, max context len // block size]. + input_block_tables = self.graph_block_tables[:batch_size] + for i, block_table in enumerate(block_tables): + if block_table: + input_block_tables[i, :len(block_table)] = block_table + block_tables = torch.tensor(input_block_tables, device=device) + else: + block_tables = _make_tensor_with_pad( + block_tables, + max_len=max_context_len, + pad=0, + dtype=torch.int, + ) + input_metadata = InputMetadata( prompt_lens=[], slot_mapping=slot_mapping, max_context_len=max_context_len, context_lens=context_lens, block_tables=block_tables, + use_cuda_graph=use_captured_graph, ) return input_tokens, input_positions, input_metadata, lora_index_mapping, lora_prompt_mapping @@ -284,11 +363,11 @@ def _prepare_sample( categorized_sample_indices_start_idx + num_seqs)) categorized_sample_indices_start_idx += num_seqs - selected_token_indices = torch.tensor(selected_token_indices, - dtype=torch.long, - device="cuda") + selected_token_indices = _async_h2d(selected_token_indices, + dtype=torch.long, + pin_memory=not self.in_wsl) categorized_sample_indices = { - t: torch.tensor(seq_ids, dtype=torch.int, device="cuda") + t: _async_h2d(seq_ids, dtype=torch.int, pin_memory=not self.in_wsl) for t, seq_ids in categorized_sample_indices.items() } @@ -310,18 +389,18 @@ def execute_model( self, seq_group_metadata_list: List[SequenceGroupMetadata], kv_caches: List[Tuple[torch.Tensor, torch.Tensor]], - cache_events: Optional[List[torch.cuda.Event]] = None, ) -> SamplerOutput: # NOTE: We assume that all sequences in the group are all prompts or # all decodes. - # Prepare input tensors. is_prompt = seq_group_metadata_list[0].is_prompt + # Prepare input tensors. if is_prompt: inputs = self._prepare_prompt(seq_group_metadata_list) input_tokens, input_positions, input_metadata, lora_index_mapping, lora_prompt_mapping = inputs else: inputs = self._prepare_decode(seq_group_metadata_list) input_tokens, input_positions, input_metadata, lora_index_mapping, lora_prompt_mapping = inputs + sampling_metadata, lora_requests = self._prepare_sample( seq_group_metadata_list, input_metadata.prompt_lens) @@ -336,12 +415,16 @@ def execute_model( self.apply_loras(lora_requests, lora_mapping) # Execute the model. - hidden_states = self.model( + if input_metadata.use_cuda_graph: + graph_batch_size = input_tokens.shape[0] + model_executable = self.graph_runners[graph_batch_size] + else: + model_executable = self.model + hidden_states = model_executable( input_ids=input_tokens, positions=input_positions, kv_caches=kv_caches, input_metadata=input_metadata, - cache_events=cache_events, ) # Sample the next token. @@ -403,7 +486,59 @@ def profile_run(self) -> None: num_layers = self.model_config.get_num_layers(self.parallel_config) kv_caches = [(None, None)] * num_layers self.execute_model(seqs, kv_caches) + torch.cuda.synchronize() return + + @torch.inference_mode() + def capture_model(self, kv_caches: List[KVCache]) -> None: + assert not self.model_config.enforce_eager + logger.info("Capturing the model for CUDA graphs. This may lead to " + "unexpected consequences if the model is not static. To " + "run the model in eager mode, set 'enforce_eager=True' or " + "use '--enforce-eager' in the CLI.") + logger.info("CUDA graphs can take additional 1~3 GiB memory per GPU. " + "If you are running out of memory, consider decreasing " + "`gpu_memory_utilization` or enforcing eager mode.") + start_time = time.perf_counter() + + # Prepare dummy inputs. These will be reused for all batch sizes. + max_batch_size = max(_BATCH_SIZES_TO_CAPTURE) + input_tokens = torch.zeros(max_batch_size, 1, dtype=torch.long).cuda() + input_positions = torch.zeros(max_batch_size, 1, + dtype=torch.long).cuda() + slot_mapping = torch.empty(max_batch_size, 1, dtype=torch.long).cuda() + slot_mapping.fill_(_PAD_SLOT_ID) + context_lens = torch.ones(max_batch_size, dtype=torch.int32).cuda() + block_tables = torch.from_numpy(self.graph_block_tables).cuda() + + # NOTE: Capturing the largest batch size first may help reduce the + # memory usage of CUDA graph. + for batch_size in reversed(_BATCH_SIZES_TO_CAPTURE): + # Create dummy input_metadata. + input_metadata = InputMetadata( + prompt_lens=[], + slot_mapping=slot_mapping[:batch_size], + max_context_len=self.max_context_len_to_capture, + context_lens=context_lens[:batch_size], + block_tables=block_tables[:batch_size], + use_cuda_graph=True, + ) + + graph_runner = CUDAGraphRunner(self.model) + graph_runner.capture( + input_tokens[:batch_size], + input_positions[:batch_size], + kv_caches, + input_metadata, + memory_pool=self.graph_memory_pool, + ) + self.graph_memory_pool = graph_runner.graph.pool() + self.graph_runners[batch_size] = graph_runner + + end_time = time.perf_counter() + elapsed_time = end_time - start_time + # This usually takes < 10 seconds. + logger.info(f"Graph capturing finished in {elapsed_time:.0f} secs.") def remove_all_loras(self) -> bool: return self.lora_manager.remove_all_loras() @@ -422,6 +557,87 @@ def list_loras(self) -> Set[int]: return self.lora_manager.list_loras() +class CUDAGraphRunner: + + def __init__(self, model: nn.Module): + self.model = model + self.graph = None + self.input_buffers: Dict[str, torch.Tensor] = {} + self.output_buffers: Dict[str, torch.Tensor] = {} + + def capture( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + kv_caches: List[KVCache], + input_metadata: InputMetadata, + memory_pool, + ) -> None: + assert self.graph is None + # Run the model once without capturing the graph. + # This is to make sure that the captured graph does not include the + # kernel launches for initial benchmarking (e.g., Triton autotune). + self.model( + input_ids, + positions, + kv_caches, + input_metadata, + ) + torch.cuda.synchronize() + + # Capture the graph. + self.graph = torch.cuda.CUDAGraph() + with torch.cuda.graph(self.graph, pool=memory_pool): + hidden_states = self.model( + input_ids, + positions, + kv_caches, + input_metadata, + ) + torch.cuda.synchronize() + + # Save the input and output buffers. + self.input_buffers = { + "input_ids": input_ids, + "positions": positions, + "kv_caches": kv_caches, + "slot_mapping": input_metadata.slot_mapping, + "context_lens": input_metadata.context_lens, + "block_tables": input_metadata.block_tables, + } + self.output_buffers = {"hidden_states": hidden_states} + return + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + kv_caches: List[Tuple[torch.Tensor, torch.Tensor]], + input_metadata: InputMetadata, + ) -> torch.Tensor: + # KV caches are fixed tensors, so we don't need to copy them. + del kv_caches + + # Copy the input tensors to the input buffers. + self.input_buffers["input_ids"].copy_(input_ids, non_blocking=True) + self.input_buffers["positions"].copy_(positions, non_blocking=True) + self.input_buffers["slot_mapping"].copy_(input_metadata.slot_mapping, + non_blocking=True) + self.input_buffers["context_lens"].copy_(input_metadata.context_lens, + non_blocking=True) + self.input_buffers["block_tables"].copy_(input_metadata.block_tables, + non_blocking=True) + + # Run the graph. + self.graph.replay() + + # Return the output tensor. + return self.output_buffers["hidden_states"] + + def __call__(self, *args, **kwargs): + return self.forward(*args, **kwargs) + + def _pad_to_max(x: List[int], max_len: int, pad: int) -> List[int]: assert len(x) <= max_len return x + [pad] * (max_len - len(x)) @@ -432,6 +648,25 @@ def _make_tensor_with_pad( max_len: int, pad: int, dtype: torch.dtype, + device: Union[str, torch.device] = "cuda", + pin_memory: bool = False, ) -> torch.Tensor: padded_x = [_pad_to_max(x_i, max_len, pad) for x_i in x] - return torch.tensor(padded_x, dtype=dtype, device="cuda") + return torch.tensor(padded_x, + dtype=dtype, + device=device, + pin_memory=pin_memory and str(device) == "cpu") + + +def _get_graph_batch_size(batch_size: int) -> int: + if batch_size <= 2: + return batch_size + elif batch_size <= 4: + return 4 + else: + return (batch_size + 7) // 8 * 8 + + +def _async_h2d(data: list, dtype, pin_memory): + t = torch.tensor(data, dtype=dtype, pin_memory=pin_memory) + return t.to(device="cuda", non_blocking=True) diff --git a/vllm/worker/worker.py b/vllm/worker/worker.py index d58849e76828e..49d7fdbb32c71 100644 --- a/vllm/worker/worker.py +++ b/vllm/worker/worker.py @@ -14,7 +14,6 @@ from vllm.sequence import SamplerOutput, SequenceGroupMetadata from vllm.worker.cache_engine import CacheEngine from vllm.worker.model_runner import ModelRunner -from vllm.utils import get_gpu_memory from vllm.lora.request import LoRARequest @@ -52,7 +51,15 @@ def __init__( self.gpu_cache = None self.lora_manager = None - def init_model(self): + def init_model(self) -> None: + # torch.distributed.all_reduce does not free the input tensor until + # the synchronization point. This causes the memory usage to grow + # as the number of all_reduce calls increases. This env var disables + # this behavior. + # Related issue: + # https://discuss.pytorch.org/t/cuda-allocation-lifetime-for-inputs-to-distributed-all-reduce/191573 + os.environ["TORCH_NCCL_AVOID_RECORD_STREAMS"] = "1" + # This env var set by Ray causes exceptions with graph building. os.environ.pop("NCCL_ASYNC_ERROR_HANDLING", None) # Env vars will be set by Ray. @@ -86,7 +93,6 @@ def profile_num_available_blocks( # Profile the memory usage of the model and get the maximum number of # cache blocks that can be allocated with the remaining free memory. torch.cuda.empty_cache() - torch.cuda.reset_peak_memory_stats() # Execute a forward pass with dummy inputs to profile the memory usage # of the model. @@ -95,8 +101,9 @@ def profile_num_available_blocks( # Calculate the number of blocks that can be allocated with the # profiled peak memory. torch.cuda.synchronize() - peak_memory = torch.cuda.max_memory_allocated() - total_gpu_memory = get_gpu_memory() + free_gpu_memory, total_gpu_memory = torch.cuda.mem_get_info() + peak_memory = total_gpu_memory - free_gpu_memory + cache_block_size = CacheEngine.get_cache_block_size( block_size, self.model_config, self.parallel_config) num_gpu_blocks = int( @@ -108,10 +115,6 @@ def profile_num_available_blocks( self.model_runner.remove_all_loras() gc.collect() torch.cuda.empty_cache() - - # Reset the seed to ensure that the random state is not affected by - # the model initialization and profiling. - set_random_seed(self.model_config.seed) return num_gpu_blocks, num_cpu_blocks def init_cache_engine(self, cache_config: CacheConfig) -> None: @@ -122,6 +125,13 @@ def init_cache_engine(self, cache_config: CacheConfig) -> None: self.gpu_cache = self.cache_engine.gpu_cache self.model_runner.set_block_size(self.cache_engine.block_size) + def warm_up_model(self) -> None: + if not self.model_config.enforce_eager: + self.model_runner.capture_model(self.gpu_cache) + # Reset the seed to ensure that the random state is not affected by + # the model initialization and profiling. + set_random_seed(self.model_config.seed) + @torch.inference_mode() def execute_model( self, @@ -144,15 +154,17 @@ def execute_model( cache_events = self.cache_events if issued_cache_op else None + # Wait for cache operations to finish. + # TODO(woosuk): Profile swapping overhead and optimize if needed. + if cache_events is not None: + for event in cache_events: + event.wait() # If there is no input, we don't need to execute the model. if not seq_group_metadata_list: - if cache_events is not None: - for event in cache_events: - event.wait() return {} output = self.model_runner.execute_model(seq_group_metadata_list, - self.gpu_cache, cache_events) + self.gpu_cache) return output def add_lora(self, lora_request: LoRARequest) -> bool: From 4415c92a8ed35d507582c16a1995b22c3c247db9 Mon Sep 17 00:00:00 2001 From: miloice <17350011+kliuae@users.noreply.github.com> Date: Wed, 20 Dec 2023 15:33:22 +0000 Subject: [PATCH 81/94] Remove redundant include --- csrc/punica/bgmv/bgmv_all.cu | 2 -- 1 file changed, 2 deletions(-) diff --git a/csrc/punica/bgmv/bgmv_all.cu b/csrc/punica/bgmv/bgmv_all.cu index e567cb5c8da7e..bc86416701f13 100644 --- a/csrc/punica/bgmv/bgmv_all.cu +++ b/csrc/punica/bgmv/bgmv_all.cu @@ -1,8 +1,6 @@ #include "bgmv_config.h" #include "bgmv_impl.cuh" -#include "../type_convert.h" - FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, nv_half, nv_half, nv_half) FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, nv_half, nv_half, nv_bfloat16) FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, nv_bfloat16, nv_bfloat16, nv_bfloat16) From 46678e1946852b6e60d0e7ba73b72858321a66e5 Mon Sep 17 00:00:00 2001 From: miloice <17350011+kliuae@users.noreply.github.com> Date: Thu, 21 Dec 2023 10:10:20 +0000 Subject: [PATCH 82/94] Fix merge bugs --- vllm/engine/async_llm_engine.py | 3 ++- vllm/worker/model_runner.py | 24 +++++++++--------------- 2 files changed, 11 insertions(+), 16 deletions(-) diff --git a/vllm/engine/async_llm_engine.py b/vllm/engine/async_llm_engine.py index 6752ee144f2ad..910912df73e9a 100644 --- a/vllm/engine/async_llm_engine.py +++ b/vllm/engine/async_llm_engine.py @@ -462,7 +462,8 @@ async def generate( prompt: Optional[str], sampling_params: SamplingParams, request_id: str, - prompt_token_ids: Optional[List[int]] = None + prompt_token_ids: Optional[List[int]] = None, + lora_request: Optional[LoRARequest] = None ) -> AsyncIterator[RequestOutput]: """Generate outputs for a request. diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index 662b3a809d33f..56f1c5963b6dc 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -275,15 +275,6 @@ def _prepare_decode( device=device, pin_memory=pin_memory) - max_block_table_len = max([len(t) for t in block_tables]) - block_tables = _make_tensor_with_pad(block_tables, - max_len=max_block_table_len, - pad=0, - dtype=torch.int) - lora_index_mapping = [ - _pad_to_max(mapping, 1, pad=0) for mapping in lora_index_mapping - ] - if use_captured_graph: # The shape of graph_block_tables is # [max batch size, max context len // block size]. @@ -293,12 +284,15 @@ def _prepare_decode( input_block_tables[i, :len(block_table)] = block_table block_tables = torch.tensor(input_block_tables, device=device) else: - block_tables = _make_tensor_with_pad( - block_tables, - max_len=max_context_len, - pad=0, - dtype=torch.int, - ) + max_block_table_len = max([len(t) for t in block_tables]) + block_tables = _make_tensor_with_pad(block_tables, + max_len=max_block_table_len, + pad=0, + dtype=torch.int) + + lora_index_mapping = [ + _pad_to_max(mapping, 1, pad=0) for mapping in lora_index_mapping + ] input_metadata = InputMetadata( prompt_lens=[], From 6640a2e65d9a685be637707e1f58351c6012616b Mon Sep 17 00:00:00 2001 From: Antoni Baum Date: Thu, 21 Dec 2023 17:01:12 -0800 Subject: [PATCH 83/94] Apply feedback from code review --- csrc/punica/bgmv/bgmv_all.cu | 2 +- csrc/punica/bgmv/bgmv_config.h | 2 +- tests/lora/test_layers.py | 43 +++++----- tests/lora/test_lora.py | 14 ++-- tests/lora/test_lora_manager.py | 52 ++++++------ tests/lora/test_worker.py | 8 +- tests/lora/utils.py | 12 +-- vllm/core/scheduler.py | 7 +- vllm/engine/arg_utils.py | 31 ++++--- vllm/engine/llm_engine.py | 34 ++++---- vllm/lora/layers.py | 121 +++++++++------------------- vllm/lora/lora.py | 89 +++++++++++++------- vllm/lora/models.py | 99 +++++++++-------------- vllm/lora/worker_manager.py | 60 +++----------- vllm/model_executor/model_loader.py | 2 +- vllm/worker/model_runner.py | 20 ++--- vllm/worker/worker.py | 1 - 17 files changed, 261 insertions(+), 336 deletions(-) diff --git a/csrc/punica/bgmv/bgmv_all.cu b/csrc/punica/bgmv/bgmv_all.cu index bc86416701f13..2502a67e3c813 100644 --- a/csrc/punica/bgmv/bgmv_all.cu +++ b/csrc/punica/bgmv/bgmv_all.cu @@ -18,4 +18,4 @@ FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, nv_half, float, nv_bfloat16) FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, nv_bfloat16, float, nv_bfloat16) FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, nv_bfloat16, float, nv_half) FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, float, float, nv_half) -FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, float, float, nv_bfloat16) \ No newline at end of file +FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, float, float, nv_bfloat16) diff --git a/csrc/punica/bgmv/bgmv_config.h b/csrc/punica/bgmv/bgmv_config.h index ce2a2112a3d91..ced0397dab216 100644 --- a/csrc/punica/bgmv/bgmv_config.h +++ b/csrc/punica/bgmv/bgmv_config.h @@ -46,7 +46,7 @@ void bgmv_kernel(out_T *__restrict__ Y, const in_T *__restrict__ X, f(in_T, out_T, W_T, narrow, 33024) \ f(in_T, out_T, W_T, narrow, 36864) \ f(in_T, out_T, W_T, narrow, 49152) \ -// Keep above in sync with vllm/lora/layers::LoRASampler +// Keep above in sync with vllm/lora/layers::SamplerWithLoRA // Keep this in sync with vllm/config::LoRAConfig #define FOR_BGMV_WIDE_NARROW(f, in_T, out_T, W_T) \ diff --git a/tests/lora/test_layers.py b/tests/lora/test_layers.py index 1020d2cd684f1..71c671132205a 100644 --- a/tests/lora/test_layers.py +++ b/tests/lora/test_layers.py @@ -8,16 +8,16 @@ import torch.nn.functional as F from vllm.lora.layers import ( - LoRAColumnParallelLinear, - LoRAMergedColumnParallelLinear2Slice, - LoRAQKVParallelLinear, - LoRAVocabParallelEmbedding, - LoRARowParallelLinear, - LoRASampler, + ColumnParallelLinearWithLoRA, + MergedColumnParallelLinearWithLoRA, + QKVParallelLinearWithLora, + VocabParallelEmbeddingWithLoRA, + RowParallelLinearWithLoRA, + SamplerWithLoRA, LoRAMapping, - LoRALayer, + BaseLayerWithLoRA, ) -from vllm.lora.models import LoRA, convert_mapping +from vllm.lora.models import LoRALayerWeights, convert_mapping, PackedLoRALayerWeights from vllm.config import LoRAConfig from vllm.model_executor.layers.sampler import Sampler from vllm.model_executor.layers.linear import (ColumnParallelLinear, @@ -66,11 +66,11 @@ def get_random_id_to_index(num_loras: int, def populate_loras( id_to_index: List[Optional[int]], - layer: LoRALayer, + layer: BaseLayerWithLoRA, layer_weights: torch.Tensor, generate_embeddings_tensor: int = 0, repeats: int = 1, -) -> Tuple[Dict[int, LoRA], Dict[int, List[LoRA]]]: +) -> Tuple[Dict[int, LoRALayerWeights], Dict[int, List[LoRALayerWeights]]]: """This method populates the lora layers with lora weights. Args: @@ -89,12 +89,12 @@ def populate_loras( # Dictionary that maps the lora ID to the # corresponding lora weights. - lora_dict: Dict[int, LoRA] = dict() + lora_dict: Dict[int, LoRALayerWeights] = dict() # Dictionary that maps the lora ID to the # corresponding subloras. Only useful when # repeats > 1. - sublora_dict: Dict[int, List[LoRA]] = dict() + sublora_dict: Dict[int, List[LoRALayerWeights]] = dict() for slot_idx, lora_id in enumerate(id_to_index): if lora_id is not None: @@ -111,7 +111,8 @@ def populate_loras( sublora.optimize() subloras.append(sublora) - lora = LoRA.pack(subloras) if repeats > 1 else subloras[0] + lora = PackedLoRALayerWeights.pack( + subloras) if repeats > 1 else subloras[0] layer.set_lora( slot_idx, @@ -179,7 +180,7 @@ def create_random_embedding_layer(): embedding = VocabParallelEmbedding(512, 256) embedding.weight.data = torch.rand_like(embedding.weight.data) embedding.weight.data[512:, :] = 0 - lora_embedding = LoRAVocabParallelEmbedding(embedding) + lora_embedding = VocabParallelEmbeddingWithLoRA(embedding) lora_embedding.create_lora_weights(max_loras, lora_config) return embedding, lora_embedding @@ -277,7 +278,7 @@ def create_random_embedding_layer(): expanded_embedding.weight.data[:512, :] = embedding_data # We need to deepcopy the embedding as it will be modifed # in place - lora_embedding = LoRAVocabParallelEmbedding( + lora_embedding = VocabParallelEmbeddingWithLoRA( deepcopy(expanded_embedding)) lora_embedding.create_lora_weights(max_loras, lora_config) @@ -400,8 +401,8 @@ def create_random_sampler_layer(): linear.weight.data = torch.rand_like(linear.weight.data) linear.weight.data[:, 32000:] = 0 sampler = Sampler(32000 + lora_config.lora_extra_vocab_size, 32000) - lora_sampler = LoRASampler(sampler, 1024, linear.weight.dtype, - linear.weight.device) + lora_sampler = SamplerWithLoRA(sampler, 1024, linear.weight.dtype, + linear.weight.device) lora_sampler.create_lora_weights(max_loras, lora_config) return linear, sampler, lora_sampler @@ -510,11 +511,11 @@ def create_random_linear_parallel_layer(): if orientation == "row": linear = RowParallelLinear(4096, 4096, bias=False) linear.weight.data = torch.rand_like(linear.weight.data) - lora_linear = LoRARowParallelLinear(linear) + lora_linear = RowParallelLinearWithLoRA(linear) else: linear = ColumnParallelLinear(4096, 4096, bias=False) linear.weight.data = torch.rand_like(linear.weight.data) - lora_linear = LoRAColumnParallelLinear(linear) + lora_linear = ColumnParallelLinearWithLoRA(linear) lora_linear.create_lora_weights(max_loras, lora_config) return linear, lora_linear @@ -608,11 +609,11 @@ def create_column_parallel_packed_layer(): linear = MergedColumnParallelLinear(4096, [4096] * repeats, bias=False) linear.weight.data = torch.rand_like(linear.weight.data) - lora_linear = LoRAMergedColumnParallelLinear2Slice(linear) + lora_linear = MergedColumnParallelLinearWithLoRA(linear) else: linear = QKVParallelLinear(4096, 64, 32, bias=False) linear.weight.data = torch.rand_like(linear.weight.data) - lora_linear = LoRAQKVParallelLinear(linear) + lora_linear = QKVParallelLinearWithLora(linear) @dataclass class FakeConfig: diff --git a/tests/lora/test_lora.py b/tests/lora/test_lora.py index b86f7a480e749..1b972cc53f24d 100644 --- a/tests/lora/test_lora.py +++ b/tests/lora/test_lora.py @@ -1,7 +1,7 @@ import pytest import torch -from vllm.lora.layers import _apply_lora, _apply_lora_packed_2slice, _apply_lora_packed_3slice +from vllm.lora.layers import _apply_lora, _apply_lora_packed_nslice from .utils import DummyLoRAManager @@ -122,19 +122,19 @@ def test_apply_lora_packed_2slice(m, n, k, rank, dtype) -> None: lora_b_stacks[1][i][0] = (lora_2.lora_b * lora_2.scaling).T output = torch.zeros(k, m, device="cuda", dtype=dtype) - _apply_lora_packed_2slice( + _apply_lora_packed_nslice( input, lora_a_stacks, lora_b_stacks, torch.randint(0, lora_a_stacks[0].shape[0], (len(input), ), - device="cuda"), output, m // 2) + device="cuda"), output, (m // 2, )) rtol, atol = TOLERANCES[dtype] assert torch.allclose(expected, output, rtol=rtol, atol=atol) output[:] = 0 - _apply_lora_packed_2slice(input, lora_a_stacks, lora_b_stacks, + _apply_lora_packed_nslice(input, lora_a_stacks, lora_b_stacks, torch.full((len(input), ), -1, device="cuda"), - output, m // 2) + output, (m // 2, )) assert torch.allclose(torch.zeros_like(output), output) manager.reset_lora() @@ -206,7 +206,7 @@ def test_apply_lora_packed_3slice(qkv, n, k, rank, dtype) -> None: lora_b_stacks[2][i][0] = (lora_v.lora_b * lora_v.scaling).T output = torch.zeros(k, sum(qkv), device="cuda", dtype=dtype) - _apply_lora_packed_3slice( + _apply_lora_packed_nslice( input, lora_a_stacks, lora_b_stacks, torch.randint(0, lora_a_stacks[0].shape[0], (len(input), ), @@ -216,7 +216,7 @@ def test_apply_lora_packed_3slice(qkv, n, k, rank, dtype) -> None: assert torch.allclose(expected, output, rtol=rtol, atol=atol) output[:] = 0 - _apply_lora_packed_3slice(input, lora_a_stacks, lora_b_stacks, + _apply_lora_packed_nslice(input, lora_a_stacks, lora_b_stacks, torch.full((len(input), ), -1, device="cuda"), output, (qkv[0], qkv[1])) assert torch.allclose(torch.zeros_like(output), output) diff --git a/tests/lora/test_lora_manager.py b/tests/lora/test_lora_manager.py index de7b245ad4e79..9c52058ff9a51 100644 --- a/tests/lora/test_lora_manager.py +++ b/tests/lora/test_lora_manager.py @@ -7,9 +7,10 @@ from torch import nn from vllm.config import LoRAConfig -from vllm.lora.layers import (LoRAColumnParallelLinear, LoRARowParallelLinear, - LoRAMergedColumnParallelLinear2Slice) -from vllm.lora.lora import LoRA, PackedLoRA +from vllm.lora.layers import (ColumnParallelLinearWithLoRA, + RowParallelLinearWithLoRA, + MergedColumnParallelLinearWithLoRA) +from vllm.lora.lora import LoRALayerWeights, PackedLoRALayerWeights from vllm.lora.models import (EMBEDDING_MODULES, LoRAModel, LoRAModelManager, LRUCacheLoRAModelManager, LoRAMapping) from vllm.lora.request import LoRARequest @@ -54,7 +55,7 @@ def create_lora(lora_id: int, model: nn.Module, loras = {} for name in sub_modules: w = model.get_submodule(name).weight - loras[name] = LoRA( + loras[name] = LoRALayerWeights( name, 8, 16, @@ -76,7 +77,7 @@ def create_packed_lora( for replaced_module_name in replaced_module_names: if replaced_module_name == empty_replaced_module_name: continue - loras[replaced_module_name] = LoRA( + loras[replaced_module_name] = LoRALayerWeights( replaced_module_name, 8, 16, @@ -99,12 +100,13 @@ def test_replace_submodules(dist_init, dummy_model): lora_target_modules=["dense1", "layer1.dense2"]) model = manager.model - assert isinstance(model.get_submodule("dense1"), LoRAColumnParallelLinear) + assert isinstance(model.get_submodule("dense1"), + ColumnParallelLinearWithLoRA) assert isinstance(model.get_submodule("layer1.dense1"), - LoRAColumnParallelLinear) + ColumnParallelLinearWithLoRA) assert isinstance(model.get_submodule("dense2"), RowParallelLinear) assert isinstance(model.get_submodule("layer1.dense2"), - LoRARowParallelLinear) + RowParallelLinearWithLoRA) def test_lora_model_manager(dist_init, dummy_model): @@ -289,10 +291,10 @@ def test_lru_cache_worker_lora_manager(llama_2_7b_model_extra_embeddings, worker_lora_manager = LRUCacheWorkerLoRAManager( 4, 2, llama_2_7b_model_extra_embeddings.config.vocab_size, lora_config, torch.device("cuda")) - worker_lora_manager.create_lora_adapter(llama_2_7b_model_extra_embeddings) + worker_lora_manager.create_lora_manager(llama_2_7b_model_extra_embeddings) mapping = LoRAMapping([], []) - worker_lora_manager.apply_loras([ + worker_lora_manager.set_active_loras([ LoRARequest("1", 1, sql_lora_files), LoRARequest("2", 2, sql_lora_files) ], mapping) @@ -300,7 +302,7 @@ def test_lru_cache_worker_lora_manager(llama_2_7b_model_extra_embeddings, assert worker_lora_manager._lora_manager.lora_id_to_index[0] == 1 assert worker_lora_manager._lora_manager.lora_id_to_index[1] == 2 - worker_lora_manager.apply_loras([ + worker_lora_manager.set_active_loras([ LoRARequest("1", 1, sql_lora_files), LoRARequest("3", 3, sql_lora_files), LoRARequest("4", 4, sql_lora_files) @@ -311,7 +313,7 @@ def test_lru_cache_worker_lora_manager(llama_2_7b_model_extra_embeddings, assert worker_lora_manager._lora_manager.lora_id_to_index[2] == 3 assert worker_lora_manager._lora_manager.lora_id_to_index[3] == 4 - worker_lora_manager.apply_loras([ + worker_lora_manager.set_active_loras([ LoRARequest("1", 1, sql_lora_files), LoRARequest("2", 2, sql_lora_files), LoRARequest("5", 5, sql_lora_files) @@ -322,7 +324,7 @@ def test_lru_cache_worker_lora_manager(llama_2_7b_model_extra_embeddings, assert worker_lora_manager._lora_manager.lora_id_to_index[2] == 5 assert worker_lora_manager._lora_manager.lora_id_to_index[3] == 4 - worker_lora_manager.apply_loras([ + worker_lora_manager.set_active_loras([ LoRARequest("1", 1, sql_lora_files), LoRARequest("1", 1, sql_lora_files), LoRARequest("1", 1, sql_lora_files) @@ -333,7 +335,7 @@ def test_lru_cache_worker_lora_manager(llama_2_7b_model_extra_embeddings, assert worker_lora_manager._lora_manager.lora_id_to_index[2] == 5 assert worker_lora_manager._lora_manager.lora_id_to_index[3] == 4 - worker_lora_manager.apply_loras([ + worker_lora_manager.set_active_loras([ LoRARequest("6", 6, sql_lora_files), LoRARequest("7", 7, sql_lora_files), LoRARequest("8", 8, sql_lora_files) @@ -346,7 +348,7 @@ def test_lru_cache_worker_lora_manager(llama_2_7b_model_extra_embeddings, # Over capacity with pytest.raises(RuntimeError): - worker_lora_manager.apply_loras([ + worker_lora_manager.set_active_loras([ LoRARequest("10", 10, sql_lora_files), LoRARequest("11", 11, sql_lora_files), LoRARequest("12", 12, sql_lora_files), @@ -362,10 +364,10 @@ def test_worker_lora_manager(llama_2_7b_model_extra_embeddings, worker_lora_manager = WorkerLoRAManager( 4, 2, llama_2_7b_model_extra_embeddings.config.vocab_size, lora_config, torch.device("cuda")) - worker_lora_manager.create_lora_adapter(llama_2_7b_model_extra_embeddings) + worker_lora_manager.create_lora_manager(llama_2_7b_model_extra_embeddings) mapping = LoRAMapping([], []) - worker_lora_manager.apply_loras([ + worker_lora_manager.set_active_loras([ LoRARequest("1", 1, sql_lora_files), LoRARequest("2", 2, sql_lora_files) ], mapping) @@ -373,7 +375,7 @@ def test_worker_lora_manager(llama_2_7b_model_extra_embeddings, assert worker_lora_manager._lora_manager.lora_id_to_index[0] == 1 assert worker_lora_manager._lora_manager.lora_id_to_index[1] == 2 - worker_lora_manager.apply_loras([ + worker_lora_manager.set_active_loras([ LoRARequest("1", 1, sql_lora_files), LoRARequest("3", 3, sql_lora_files), LoRARequest("4", 4, sql_lora_files) @@ -383,7 +385,7 @@ def test_worker_lora_manager(llama_2_7b_model_extra_embeddings, assert worker_lora_manager._lora_manager.lora_id_to_index[1] == 3 assert worker_lora_manager._lora_manager.lora_id_to_index[2] == 4 - worker_lora_manager.apply_loras([ + worker_lora_manager.set_active_loras([ LoRARequest("1", 1, sql_lora_files), LoRARequest("2", 2, sql_lora_files), LoRARequest("5", 5, sql_lora_files) @@ -393,7 +395,7 @@ def test_worker_lora_manager(llama_2_7b_model_extra_embeddings, assert worker_lora_manager._lora_manager.lora_id_to_index[1] == 2 assert worker_lora_manager._lora_manager.lora_id_to_index[2] == 5 - worker_lora_manager.apply_loras([ + worker_lora_manager.set_active_loras([ LoRARequest("1", 1, sql_lora_files), LoRARequest("1", 1, sql_lora_files), LoRARequest("1", 1, sql_lora_files) @@ -403,7 +405,7 @@ def test_worker_lora_manager(llama_2_7b_model_extra_embeddings, assert worker_lora_manager._lora_manager.lora_id_to_index[1] is None assert worker_lora_manager._lora_manager.lora_id_to_index[2] is None - worker_lora_manager.apply_loras([ + worker_lora_manager.set_active_loras([ LoRARequest("6", 6, sql_lora_files), LoRARequest("7", 7, sql_lora_files), LoRARequest("8", 8, sql_lora_files) @@ -415,7 +417,7 @@ def test_worker_lora_manager(llama_2_7b_model_extra_embeddings, # Over capacity with pytest.raises(RuntimeError): - worker_lora_manager.apply_loras([ + worker_lora_manager.set_active_loras([ LoRARequest("10", 10, sql_lora_files), LoRARequest("11", 11, sql_lora_files), LoRARequest("12", 12, sql_lora_files), @@ -446,12 +448,12 @@ def test_packed_loras(dist_init, dummy_model_gate_up): model = manager.model assert isinstance(model.get_submodule("gate_up_proj"), - LoRAMergedColumnParallelLinear2Slice) + MergedColumnParallelLinearWithLoRA) assert manager.add_lora(model_lora) assert manager.add_lora(model_lora1) packed_lora = model_lora.get_lora("gate_up_proj") - assert packed_lora and isinstance(packed_lora, PackedLoRA) + assert packed_lora and isinstance(packed_lora, PackedLoRALayerWeights) assert torch.allclose(packed_lora.lora_a[0], model_lora.get_lora("gate_proj").lora_a) @@ -463,7 +465,7 @@ def test_packed_loras(dist_init, dummy_model_gate_up): model_lora.get_lora("up_proj").lora_b) packed_lora1 = model_lora1.get_lora("gate_up_proj") - assert packed_lora1 and isinstance(packed_lora1, PackedLoRA) + assert packed_lora1 and isinstance(packed_lora1, PackedLoRALayerWeights) assert packed_lora1.lora_a[0] is None assert packed_lora1.lora_b[0] is None diff --git a/tests/lora/test_worker.py b/tests/lora/test_worker.py index abc8babd55e93..126d910f53ab3 100644 --- a/tests/lora/test_worker.py +++ b/tests/lora/test_worker.py @@ -30,7 +30,7 @@ def test_worker_apply_lora(sql_lora_files): worker.init_model() worker.load_model() - worker.model_runner.apply_loras([], LoRAMapping([], [])) + worker.model_runner.set_active_loras([], LoRAMapping([], [])) assert worker.list_loras() == set() n_loras = 32 @@ -38,7 +38,7 @@ def test_worker_apply_lora(sql_lora_files): LoRARequest(str(i + 1), i + 1, sql_lora_files) for i in range(n_loras) ] - worker.model_runner.apply_loras(lora_requests, LoRAMapping([], [])) + worker.model_runner.set_active_loras(lora_requests, LoRAMapping([], [])) assert worker.list_loras() == { lora_request.lora_int_id for lora_request in lora_requests @@ -50,8 +50,8 @@ def test_worker_apply_lora(sql_lora_files): k=random.randint(1, n_loras)) random.shuffle(iter_lora_requests) iter_lora_requests = iter_lora_requests[:-random.randint(0, n_loras)] - worker.model_runner.apply_loras(iter_lora_requests, LoRAMapping([], - [])) + worker.model_runner.set_active_loras(iter_lora_requests, + LoRAMapping([], [])) assert worker.list_loras().issuperset( {lora_request.lora_int_id for lora_request in iter_lora_requests}) diff --git a/tests/lora/utils.py b/tests/lora/utils.py index 072a0d957758b..280e0f2043e68 100644 --- a/tests/lora/utils.py +++ b/tests/lora/utils.py @@ -2,7 +2,7 @@ import torch -from vllm.lora.lora import LoRA +from vllm.lora.lora import LoRALayerWeights, PackedLoRALayerWeights class DummyLoRAManager: @@ -11,10 +11,10 @@ def __init__(self): super().__init__() self._loras = {} - def set_module_lora(self, module_name: str, lora: LoRA): + def set_module_lora(self, module_name: str, lora: LoRALayerWeights): self._loras[module_name] = lora - def get_module_lora(self, module_name: str) -> Optional[LoRA]: + def get_module_lora(self, module_name: str) -> Optional[LoRALayerWeights]: return self._loras.get(module_name, None) def init_random_lora(self, @@ -22,7 +22,7 @@ def init_random_lora(self, weight: torch.Tensor, rank: int = 8, generate_embeddings_tensor: int = 0): - lora = LoRA( + lora = LoRALayerWeights( module_name, rank=rank, lora_alpha=1, @@ -49,7 +49,7 @@ def init_lora(self, rank=8, noop=False, embeddings_tensor=None): - lora = LoRA( + lora = LoRALayerWeights( module_name, rank=rank, lora_alpha=1, @@ -83,6 +83,6 @@ def init_packed_lora( noop=i in noop_lora_index, ) base_loras.append(base_lora) - packed_lora = LoRA.pack(base_loras) + packed_lora = PackedLoRALayerWeights.pack(base_loras) self.set_module_lora(module_name, packed_lora) return packed_lora diff --git a/vllm/core/scheduler.py b/vllm/core/scheduler.py index 14647958cbd29..fc5ee185c4045 100644 --- a/vllm/core/scheduler.py +++ b/vllm/core/scheduler.py @@ -37,7 +37,6 @@ def __init__( blocks_to_swap_out: Dict[int, int], blocks_to_copy: Dict[int, List[int]], ignored_seq_groups: List[SequenceGroup], - lora_enabled: bool = False, ) -> None: self.scheduled_seq_groups = scheduled_seq_groups self.prompt_run = prompt_run @@ -49,8 +48,8 @@ def __init__( assert not (blocks_to_swap_in and blocks_to_swap_out) self.ignored_seq_groups = ignored_seq_groups - if lora_enabled: - self.num_loras = len(set(self.lora_requests)) + self.num_loras = len(self.lora_requests) + if self.num_loras > 0: self._sort_by_lora_ids() def is_empty(self) -> bool: @@ -243,7 +242,6 @@ def _schedule(self) -> SchedulerOutputs: blocks_to_swap_out=blocks_to_swap_out, blocks_to_copy=blocks_to_copy, ignored_seq_groups=ignored_seq_groups, - lora_enabled=self.lora_enabled, ) return scheduler_outputs @@ -334,7 +332,6 @@ def _schedule(self) -> SchedulerOutputs: blocks_to_swap_out=blocks_to_swap_out, blocks_to_copy=blocks_to_copy, ignored_seq_groups=[], - lora_enabled=self.lora_enabled, ) return scheduler_outputs diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index 23f86e7637037..62e5aa5257914 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -40,7 +40,7 @@ class EngineArgs: max_lora_rank: int = 16 lora_extra_vocab_size: int = 256 lora_dtype = 'auto' - max_cpu_loras: int = -1 + max_cpu_loras: Optional[int] = None def __post_init__(self): if self.tokenizer is None: @@ -211,24 +211,29 @@ def add_cli_args( # LoRA related configs parser.add_argument('--enable-lora', action='store_true', - help='enable lora adapters') + help='If True, enable handling of LoRA adapters.') parser.add_argument('--max-loras', type=int, default=EngineArgs.max_loras, - help='max number of LoRAs in a single batch') + help='Max number of LoRAs in a single batch.') parser.add_argument('--max-lora-rank', type=int, default=EngineArgs.max_lora_rank, - help='max LoRA rank') - parser.add_argument('--lora-extra-vocab-size', - type=int, - default=EngineArgs.lora_extra_vocab_size, - help='LoRA extra vocab size') - parser.add_argument('--lora-dtype', - type=str, - default=EngineArgs.lora_dtype, - choices=['auto', 'float16', 'bfloat16', 'float32'], - help='data type for LoRA') + help='Max LoRA rank.') + parser.add_argument( + '--lora-extra-vocab-size', + type=int, + default=EngineArgs.lora_extra_vocab_size, + help=('Maximum size of extra vocabulary that can be ' + 'present in a LoRA adapter (added to the base ' + 'model vocabulary).')) + parser.add_argument( + '--lora-dtype', + type=str, + default=EngineArgs.lora_dtype, + choices=['auto', 'float16', 'bfloat16', 'float32'], + help=('Data type for LoRA. If auto, will default to ' + 'base model dtype.')) parser.add_argument( '--max-cpu-loras', type=int, diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index 0bfd3276d8de6..12dea4e842b44 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -125,6 +125,9 @@ def __init__( # List of (timestamp, num_tokens) self.num_generation_tokens: List[Tuple[float, int]] = [] + def get_tokenizer_for_seq(self, sequence: Sequence): + return self.tokenizer.get_lora_tokenizer(sequence.lora_request) + def _init_workers(self, distributed_init_method: str): # Lazy import the Worker to avoid importing torch.cuda/xformers # before CUDA_VISIBLE_DEVICES is set in the Worker @@ -154,7 +157,7 @@ def _init_workers(self, distributed_init_method: str): max_parallel_loading_workers, ) - def _init_tokenizer(self, **kwargs): + def _init_tokenizer(self, **tokenizer_init_kwargs): init_kwargs = dict( enable_lora=bool(self.lora_config), max_num_seqs=self.scheduler_config.max_num_seqs, @@ -162,7 +165,7 @@ def _init_tokenizer(self, **kwargs): tokenizer_mode=self.model_config.tokenizer_mode, trust_remote_code=self.model_config.trust_remote_code, revision=self.model_config.tokenizer_revision) - init_kwargs.update(kwargs) + init_kwargs.update(tokenizer_init_kwargs) self.tokenizer: MultiLoRATokenizer = MultiLoRATokenizer( self.model_config.tokenizer, **init_kwargs) @@ -389,13 +392,13 @@ def _check_beam_search_early_stopping( current_worst_score = (current_worst_seq.get_beam_search_score( length_penalty=length_penalty, - eos_token_id=self.tokenizer.get_lora_tokenizer( - current_worst_seq.lora_request).eos_token_id)) + eos_token_id=self.get_tokenizer_for_seq( + current_worst_seq).eos_token_id)) if early_stopping is False: highest_attainable_score = (best_running_seq.get_beam_search_score( length_penalty=length_penalty, - eos_token_id=self.tokenizer.get_lora_tokenizer( - best_running_seq.lora_request).eos_token_id)) + eos_token_id=self.get_tokenizer_for_seq( + best_running_seq).eos_token_id)) else: assert early_stopping == "never" if length_penalty > 0.0: @@ -409,8 +412,8 @@ def _check_beam_search_early_stopping( highest_attainable_score = ( best_running_seq.get_beam_search_score( length_penalty=length_penalty, - eos_token_id=self.tokenizer.get_lora_tokenizer( - best_running_seq.lora_request).eos_token_id, + eos_token_id=self.get_tokenizer_for_seq( + best_running_seq).eos_token_id, seq_len=max_possible_length)) else: # Otherwise, beam search will prefer shorter sequences. The @@ -419,8 +422,8 @@ def _check_beam_search_early_stopping( highest_attainable_score = ( best_running_seq.get_beam_search_score( length_penalty=length_penalty, - eos_token_id=self.tokenizer.get_lora_tokenizer( - best_running_seq.lora_request).eos_token_id)) + eos_token_id=self.get_tokenizer_for_seq( + best_running_seq).eos_token_id)) return current_worst_score >= highest_attainable_score def _process_sequence_group_outputs(self, seq_group: SequenceGroup, @@ -511,8 +514,7 @@ def _process_sequence_group_outputs(self, seq_group: SequenceGroup, # Sort the finished sequences by their scores. all_finished_seqs.sort(key=lambda x: x[0].get_beam_search_score( length_penalty=length_penalty, - eos_token_id=self.tokenizer.get_lora_tokenizer(x[0].lora_request - ).eos_token_id), + eos_token_id=self.get_tokenizer_for_seq(x[0]).eos_token_id), reverse=True) for seq, parent, is_new in all_finished_seqs[:beam_width]: if is_new: @@ -540,8 +542,7 @@ def _process_sequence_group_outputs(self, seq_group: SequenceGroup, # Sort the running sequences by their scores. running_child_seqs.sort(key=lambda x: x[0].get_beam_search_score( length_penalty=length_penalty, - eos_token_id=self.tokenizer.get_lora_tokenizer(x[0].lora_request - ).eos_token_id), + eos_token_id=self.get_tokenizer_for_seq(x[0]).eos_token_id), reverse=True) # Check if we can stop the beam search. @@ -721,7 +722,7 @@ def _decode_sequence(self, seq: Sequence, prms: SamplingParams) -> None: """Decodes the new token for a sequence.""" (new_tokens, new_output_text, prefix_offset, read_offset) = detokenize_incrementally( - self.tokenizer.get_lora_tokenizer(seq.lora_request), + self.get_tokenizer_for_seq(seq), all_input_ids=seq.get_token_ids(), prev_tokens=seq.tokens, prefix_offset=seq.prefix_offset, @@ -764,8 +765,7 @@ def _check_stop(self, seq: Sequence, # Check if the sequence has generated the EOS token. if ((not sampling_params.ignore_eos) and seq.get_last_token_id() - == self.tokenizer.get_lora_tokenizer( - seq.lora_request).eos_token_id): + == self.get_tokenizer_for_seq(seq).eos_token_id): seq.status = SequenceStatus.FINISHED_STOPPED return diff --git a/vllm/lora/layers.py b/vllm/lora/layers.py index f03ae78ee8751..252909c859628 100644 --- a/vllm/lora/layers.py +++ b/vllm/lora/layers.py @@ -51,64 +51,20 @@ def _apply_lora( output: (batch_size, output_dim) """ org_output = output - if x.ndim == 3: - x = x.view(x.shape[0] * x.shape[1], -1) - if output.ndim == 3: - output = output.view(output.shape[0] * output.shape[1], -1) + x = x.view(-1, x.shape[-1]) + output = output.view(-1, output.shape[-1]) + indices = indices.view(-1) add_lora(output, x, lora_a_stacked, lora_b_stacked, indices, 0, 1.0) return output.view_as(org_output) -def _apply_lora_packed_2slice( - x: torch.Tensor, - lora_a_stacked: Tuple[torch.Tensor, torch.Tensor], - lora_b_stacked: Tuple[torch.Tensor, torch.Tensor], - indices: torch.Tensor, - output: torch.Tensor, - output_dim: int, -): - """Applies lora to each input. - - This method applies all loras to each input. It uses the - indices vector to determine which lora yields the - correct output. An index of -1 means no lora should be - applied. This method adds the final lora results to the - output. - - This method is used for layers that are composed of 2 sublayers - (slices) packed together (eg. gate_proj + up_proj -> - gate_up_proj). - - Both slices must have the same size (output_dim), meaning the output - tensor will have size output_dim*2. - - Input shapes: - x: (batch_size, hidden_dim) - lora_a_stacked: 2 element tuple of (num_loras, lora_rank, hidden_dim) - lora_b_stacked: 2 element tuple of (num_loras, output_dim, lora_rank) - indices: (batch_size) - output: (batch_size, output_dim*2) - output_dim: scalar - """ - org_output = output - if x.ndim == 3: - x = x.view(x.shape[0] * x.shape[1], -1) - if output.ndim == 3: - output = output.view(output.shape[0] * output.shape[1], -1) - add_lora_slice(output, x, lora_a_stacked[0], lora_b_stacked[0], indices, 0, - 1.0, 0, output_dim) - add_lora_slice(output, x, lora_a_stacked[1], lora_b_stacked[1], indices, 0, - 1.0, output_dim, output_dim) - return output.view_as(org_output) - - -def _apply_lora_packed_3slice( +def _apply_lora_packed_nslice( x: torch.Tensor, lora_a_stacked: Tuple[torch.Tensor, torch.Tensor, torch.Tensor], lora_b_stacked: Tuple[torch.Tensor, torch.Tensor, torch.Tensor], indices: torch.Tensor, output: torch.Tensor, - output_slices: Tuple[int, int], + output_slices: Tuple[int, ...], ): """Applies lora to each input. @@ -118,10 +74,8 @@ def _apply_lora_packed_3slice( applied. This method adds the final lora results to the output. - This method is used for layers that are composed of 3 sublayers - (slices) packed together (attention projection). The - first slice (Q) may have different size from the two subsequent - slices (K, V). + This method is used for layers that are composed of multiple sublayers + (slices) packed together. Input shapes: x: (batch_size, hidden_dim) @@ -129,13 +83,12 @@ def _apply_lora_packed_3slice( lora_b_stacked: 3 element tuple of (num_loras, output_dim, lora_rank) indices: (batch_size) output: (batch_size, q_slice_size + 2*kv_slice_size) - output_slices: 2 element tuple of (q_slice_size, kv_slice_size) + output_slices: n-1 element tuple of (slice_size...), where n is number of slices """ org_output = output - if x.ndim == 3: - x = x.view(x.shape[0] * x.shape[1], -1) - if output.ndim == 3: - output = output.view(output.shape[0] * output.shape[1], -1) + x = x.view(-1, x.shape[-1]) + output = output.view(-1, output.shape[-1]) + indices = indices.view(-1) add_lora_slice(output, x, lora_a_stacked[0], lora_b_stacked[0], indices, 0, 1.0, 0, output_slices[0]) add_lora_slice(output, x, lora_a_stacked[1], lora_b_stacked[1], indices, 0, @@ -147,20 +100,17 @@ def _apply_lora_packed_3slice( @dataclass class LoRAMapping: + # Per every token in input_ids: index_mapping: Tuple[int, ...] + # Per sampled token: prompt_mapping: Tuple[int, ...] - def __eq__(self, __value: object) -> bool: - return (isinstance(__value, self.__class__) - and self.prompt_mapping == __value.prompt_mapping - and self.index_mapping == __value.index_mapping) - def __post_init__(self): self.index_mapping = tuple(self.index_mapping) self.prompt_mapping = tuple(self.prompt_mapping) -class LoRALayer(nn.Module): +class BaseLayerWithLoRA(nn.Module): def create_lora_weights(self, max_loras: int, lora_config: LoRAConfig, model_config: PretrainedConfig) -> None: @@ -193,7 +143,7 @@ def set_mapping( ... -class LoRAVocabParallelEmbedding(LoRALayer): +class VocabParallelEmbeddingWithLoRA(BaseLayerWithLoRA): def __init__(self, base_layer: VocabParallelEmbedding) -> None: super().__init__() @@ -327,7 +277,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: return full_output.view_as(full_output_org) -class LoRAColumnParallelLinear(LoRALayer): +class ColumnParallelLinearWithLoRA(BaseLayerWithLoRA): def __init__(self, base_layer: ColumnParallelLinear) -> None: super().__init__() @@ -432,7 +382,7 @@ def linear_weights(self): return self.base_layer.linear_weights -class LoRAMergedColumnParallelLinear2Slice(LoRAColumnParallelLinear): +class MergedColumnParallelLinearWithLoRA(ColumnParallelLinearWithLoRA): """ColumnParallelLinear layer that is composed of 2 sublayers (slices) packed together (eg. gate_proj + up_proj -> gate_up_proj). @@ -523,18 +473,18 @@ def apply_weights(self, x: torch.Tensor, bias: Optional[torch.Tensor]) -> torch.Tensor: output = self.base_layer.linear_method.apply_weights( self.base_layer.linear_weights, x, bias) - _apply_lora_packed_2slice( + _apply_lora_packed_nslice( x, self.lora_a_stacked, self.lora_b_stacked, self.indices[:self.indices_len[0]], output, - self.output_dim, + (self.output_dim, ), ) return output -class LoRAQKVParallelLinear(LoRAColumnParallelLinear): +class QKVParallelLinearWithLora(ColumnParallelLinearWithLoRA): """ColumnParallelLinear layer that is composed of 3 sublayers (slices) packed together in qkv proj fashion (q_proj + k_proj + v_proj -> qkv_proj). @@ -687,7 +637,7 @@ def apply_weights(self, x: torch.Tensor, bias: Optional[torch.Tensor]) -> torch.Tensor: output = self.base_layer.linear_method.apply_weights( self.base_layer.linear_weights, x, bias) - _apply_lora_packed_3slice( + _apply_lora_packed_nslice( x, self.lora_a_stacked, self.lora_b_stacked, @@ -698,7 +648,7 @@ def apply_weights(self, x: torch.Tensor, return output -class LoRARowParallelLinear(LoRALayer): +class RowParallelLinearWithLoRA(BaseLayerWithLoRA): def __init__(self, base_layer: RowParallelLinear) -> None: super().__init__() @@ -824,7 +774,7 @@ def weight(self): return self.base_layer.weight -class LoRASampler(LoRALayer): +class SamplerWithLoRA(BaseLayerWithLoRA): def __init__( self, @@ -984,16 +934,17 @@ def forward(self, *args, **kwargs): return type(self.base_layer).forward(self, *args, **kwargs) -def from_layer(layer: nn.Module, - max_loras: int, - lora_config: LoRAConfig, - model_config: Optional[PretrainedConfig] = None) -> LoRALayer: +def from_layer( + layer: nn.Module, + max_loras: int, + lora_config: LoRAConfig, + model_config: Optional[PretrainedConfig] = None) -> BaseLayerWithLoRA: supported_layer_types = { - VocabParallelEmbedding: LoRAVocabParallelEmbedding, - ColumnParallelLinear: LoRAColumnParallelLinear, - QKVParallelLinear: LoRAQKVParallelLinear, - MergedColumnParallelLinear: LoRAMergedColumnParallelLinear2Slice, - RowParallelLinear: LoRARowParallelLinear, + VocabParallelEmbedding: VocabParallelEmbeddingWithLoRA, + ColumnParallelLinear: ColumnParallelLinearWithLoRA, + QKVParallelLinear: QKVParallelLinearWithLora, + MergedColumnParallelLinear: MergedColumnParallelLinearWithLoRA, + RowParallelLinear: RowParallelLinearWithLoRA, } for src_layer_type, lora_layer_type in supported_layer_types.items(): if type(layer) is src_layer_type: # pylint: disable=unidiomatic-typecheck @@ -1009,8 +960,8 @@ def from_layer_sampler( max_loras: int, lora_config: LoRAConfig, model_config: Optional[PretrainedConfig] = None, -) -> LoRASampler: - ret = LoRASampler(layer, lm_head.embedding_dim, lm_head.weight.dtype, - lm_head.weight.device) +) -> SamplerWithLoRA: + ret = SamplerWithLoRA(layer, lm_head.embedding_dim, lm_head.weight.dtype, + lm_head.weight.device) ret.create_lora_weights(max_loras, lora_config, model_config) return ret diff --git a/vllm/lora/lora.py b/vllm/lora/lora.py index 11006c4e1a1c9..fbb228c9582d4 100644 --- a/vllm/lora/lora.py +++ b/vllm/lora/lora.py @@ -1,10 +1,11 @@ from typing import List, Optional import torch +from vllm.utils import in_wsl -class LoRA: - """A LoRA that is composed of two low rank matrixes.""" +class LoRALayerWeights: + """LoRA weights for a layer composed of two low rank matrixes.""" def __init__( self, @@ -28,29 +29,7 @@ def __init__( else: self.scaling = scaling - @classmethod - def pack(cls, loras: List["LoRA"]) -> "PackedLoRA": - """Pack a list of LoRAs into a single LoRA. - - If LoRA is None, it signifies that the submodule does not have a LoRA. - """ - first_lora = next(lora for lora in loras if lora is not None) - for lora in loras: - if lora is None: - continue - lora.optimize() - rank = first_lora.rank - module_name = first_lora.module_name - obj = PackedLoRA( - module_name, - rank, - [lora.lora_alpha if lora is not None else None for lora in loras], - [lora.lora_a if lora is not None else None for lora in loras], - [lora.lora_b if lora is not None else None for lora in loras], - scaling=[1 if lora is not None else None for lora in loras]) - return obj - - def optimize(self) -> "LoRA": + def optimize(self) -> "LoRALayerWeights": """Optimize the LoRA by merging the scaling into lora_b.""" if self.scaling == 1: return @@ -75,8 +54,42 @@ def extra_vocab_size(self) -> int: return self.embeddings_tensor.shape[ 0] if self.embeddings_tensor is not None else 0 + @classmethod + def create_dummy_lora_weights( + cls, + module_name: str, + input_dim: int, + output_dim: int, + rank: int, + dtype: torch.dtype, + device: torch.device, + embeddings_tensor_dim: Optional[int] = None) -> "LoRALayerWeights": + pin_memory = str(device) == "cpu" and not in_wsl() + lora_a = torch.zeros([input_dim, rank], + dtype=dtype, + device=device, + pin_memory=pin_memory) + lora_b = torch.zeros([rank, output_dim], + dtype=dtype, + device=device, + pin_memory=pin_memory) + embeddings_tensor = torch.rand( + 10, + embeddings_tensor_dim, + dtype=dtype, + device=device, + pin_memory=pin_memory) if embeddings_tensor_dim else None + return cls( + module_name, + rank=rank, + lora_alpha=1, + lora_a=lora_a, + lora_b=lora_b, + embeddings_tensor=embeddings_tensor, + ) + -class PackedLoRA(LoRA): +class PackedLoRALayerWeights(LoRALayerWeights): """LoRA used for packed layers (eg. qkv_proj).""" def __init__( @@ -103,7 +116,29 @@ def __init__( lora_alpha / self.rank for lora_alpha in self.lora_alphas ] - def optimize(self) -> "PackedLoRA": + @classmethod + def pack(cls, loras: List["LoRALayerWeights"]) -> "PackedLoRALayerWeights": + """Pack a list of LoRAs into a single LoRA. + + If LoRA is None, it signifies that the submodule does not have a LoRA. + """ + first_lora = next(lora for lora in loras if lora is not None) + for lora in loras: + if lora is None: + continue + lora.optimize() + rank = first_lora.rank + module_name = first_lora.module_name + obj = cls( + module_name, + rank, + [lora.lora_alpha if lora is not None else None for lora in loras], + [lora.lora_a if lora is not None else None for lora in loras], + [lora.lora_b if lora is not None else None for lora in loras], + scaling=[1 if lora is not None else None for lora in loras]) + return obj + + def optimize(self) -> "PackedLoRALayerWeights": """Optimize the LoRA by merging the scaling into lora_b.""" for i in range(len(self.lora_b)): if self.scaling[i] == 1 or self.lora_b[i] is None: diff --git a/vllm/lora/models.py b/vllm/lora/models.py index 2f7c6154c7a24..df3d92aa3eef2 100644 --- a/vllm/lora/models.py +++ b/vllm/lora/models.py @@ -14,12 +14,14 @@ from vllm.config import LoRAConfig from vllm.utils import LRUCache, in_wsl -from vllm.lora.layers import LoRALayer, LoRAMapping, from_layer, from_layer_sampler -from vllm.lora.lora import LoRA +from vllm.lora.layers import BaseLayerWithLoRA, LoRAMapping, from_layer, from_layer_sampler +from vllm.lora.lora import LoRALayerWeights, PackedLoRALayerWeights from vllm.lora.utils import parse_fine_tuned_lora_name, replace_submodule logger = logging.getLogger(__name__) +# TODO: The mappings below should be moved to individual model classes. + PACKED_MODULES_CFG = { "qkv_proj": [ "q_proj", @@ -131,38 +133,6 @@ def get_lora_id(): return _GLOBAL_LORA_ID -def _create_dummy_lora(module_name: str, - input_dim: int, - output_dim: int, - rank: int, - dtype: torch.dtype, - device: torch.device, - embeddings_tensor_dim: Optional[int] = None) -> "LoRA": - pin_memory = str(device) == "cpu" and not in_wsl() - lora_a = torch.zeros([input_dim, rank], - dtype=dtype, - device=device, - pin_memory=pin_memory) - lora_b = torch.zeros([rank, output_dim], - dtype=dtype, - device=device, - pin_memory=pin_memory) - embeddings_tensor = torch.rand( - 10, - embeddings_tensor_dim, - dtype=dtype, - device=device, - pin_memory=pin_memory) if embeddings_tensor_dim else None - return LoRA( - module_name, - rank=rank, - lora_alpha=1, - lora_a=lora_a, - lora_b=lora_b, - embeddings_tensor=embeddings_tensor, - ) - - class LoRAModel: """A LoRA fine-tuned model.""" @@ -170,20 +140,20 @@ def __init__( self, lora_model_id: int, rank: int, - loras: Dict[str, LoRA], + loras: Dict[str, LoRALayerWeights], ) -> None: self.id = lora_model_id assert (lora_model_id > 0), f"a valid lora id should be greater than 0, got {self.id}" self.rank = rank - self.loras: Dict[str, LoRA] = loras + self.loras: Dict[str, LoRALayerWeights] = loras @property def extra_vocab_size(self) -> int: return max(lora.extra_vocab_size for lora in self.loras.values()) if self.loras else 0 - def get_lora(self, module_name: str) -> Optional[LoRA]: + def get_lora(self, module_name: str) -> Optional[LoRALayerWeights]: """Get LoRA for a given module by name""" return self.loras.get(module_name, None) @@ -202,7 +172,7 @@ def from_lora_tensors( ) -> "LoRAModel": """Create a LoRAModel from a dictionary of tensors.""" pin_memory = str(device) == "cpu" and not in_wsl() - loras: Dict[str, LoRA] = {} + loras: Dict[str, LoRALayerWeights] = {} for tensor_name, tensor in tensors.items(): module_name, is_lora_a = parse_fine_tuned_lora_name(tensor_name) if module_name not in loras: @@ -218,8 +188,9 @@ def from_lora_tensors( if pin_memory: lora_embeddings_tensor = ( lora_embeddings_tensor.pin_memory()) - loras[module_name] = LoRA(module_name, rank, lora_alpha, None, - None, lora_embeddings_tensor) + loras[module_name] = LoRALayerWeights(module_name, rank, + lora_alpha, None, None, + lora_embeddings_tensor) if is_lora_a: loras[module_name].lora_a = tensor.to(device=device, dtype=dtype).t() @@ -241,7 +212,7 @@ def from_lora_tensors( loras[module_name].lora_b = loras[ module_name].lora_b.pin_memory() - for _, lora in loras.items(): + for lora in loras.values(): lora.optimize() return cls(lora_model_id, rank, loras) @@ -343,6 +314,9 @@ def __init__( dtype=torch.long, device="cuda") self.offsets = [] + # 4 is the number of indicies tensors defined above + # base_indices, sampler_indices, sampler_indices_padded, + # embeddings_indices self.indices_len = [None] * 4 self.model: nn.Module = model @@ -352,8 +326,9 @@ def __init__( self.lora_target_modules = copy.deepcopy(lora_target_modules) self.packed_modules_mapping = copy.deepcopy(packed_modules_mapping) self.packed_modules: Dict[str, List[str]] = {} - self.modules: Dict[str, "LoRALayer"] = {} + self.modules: Dict[str, "BaseLayerWithLoRA"] = {} self._registered_loras: Dict[int, LoRAModel] = {} + # Dict instead of a Set for compatibility with LRUCache. self._active_loras: Dict[int, None] = {} self._last_mapping = None self._create_lora_modules() @@ -374,6 +349,7 @@ def activate_lora( self, lora_id: int, ) -> bool: + """Move LoRA into a GPU buffer to be used in the forward pass.""" if lora_id in self._active_loras: return False first_free_slot = next( @@ -405,6 +381,7 @@ def _deactivate_lora(self, lora_id: int): pass def deactivate_lora(self, lora_id: int) -> bool: + """Remove a LoRA from a GPU buffer.""" if lora_id in self._active_loras: self._deactivate_lora(lora_id) self._active_loras.pop(lora_id) @@ -412,7 +389,7 @@ def deactivate_lora(self, lora_id: int) -> bool: return False def add_lora(self, lora: LoRAModel) -> bool: - """Add a LoRAModel to the manager.""" + """Add a LoRAModel to the manager CPU cache.""" if lora.id not in self._registered_loras: if len(self._registered_loras) >= self.capacity: raise RuntimeError("No free LoRA slots.") @@ -422,13 +399,13 @@ def add_lora(self, lora: LoRAModel) -> bool: return False def remove_lora(self, lora_id: int) -> bool: - """Remove a LoRAModel from the manager.""" + """Remove a LoRAModel from the manager CPU cache.""" # TODO: should we check active lora? self.deactivate_lora(lora_id) return bool(self._registered_loras.pop(lora_id, None)) # TODO see if this can be vectorized - def convert_mapping(self, mapping: LoRAMapping) -> None: + def _set_lora_mapping(self, mapping: LoRAMapping) -> None: (base_indices, sampler_indices, sampler_indices_padded, embeddings_indices, indices_len) = convert_mapping(mapping, self.lora_index_to_id, @@ -444,9 +421,9 @@ def convert_mapping(self, mapping: LoRAMapping) -> None: # Maintain the reference self.indices_len[:] = indices_len - def set_row_lora_mapping(self, lora_mapping: LoRAMapping) -> None: + def set_lora_mapping(self, lora_mapping: LoRAMapping) -> None: if self._last_mapping != lora_mapping: - self.convert_mapping(lora_mapping) + self._set_lora_mapping(lora_mapping) self._last_mapping = lora_mapping def list_loras(self) -> Dict[int, LoRAModel]: @@ -484,8 +461,8 @@ def _create_lora_modules(self): self.sampler_indices_padded, self.embeddings_indices, self.indices_len) - def register_module(self, module_name: str, module: "LoRALayer"): - assert isinstance(module, LoRALayer) + def register_module(self, module_name: str, module: "BaseLayerWithLoRA"): + assert isinstance(module, BaseLayerWithLoRA) self.modules[module_name] = module def create_dummy_lora(self, lora_id: int, rank: int) -> LoRAModel: @@ -493,7 +470,7 @@ def create_dummy_lora(self, lora_id: int, rank: int) -> LoRAModel: model = LoRAModel(lora_id, rank, {}) for module_name, module in self.model.named_modules(): if not self._match_target_modules(module_name) or not isinstance( - module, LoRALayer): + module, BaseLayerWithLoRA): continue parts = module_name.split(".") if module_name not in self.packed_modules: @@ -509,7 +486,7 @@ def create_dummy_lora(self, lora_id: int, rank: int) -> LoRAModel: hasattr(module.base_layer, "embedding_dim") else module.base_layer.weight.shape[1]) - lora = _create_dummy_lora( + lora = LoRALayerWeights.create_dummy_lora_weights( module_name, input_dim, output_dim, @@ -518,7 +495,7 @@ def create_dummy_lora(self, lora_id: int, rank: int) -> LoRAModel: "cpu", embeddings_tensor_dim=embeddings_tensor_dim) else: - lora = _create_dummy_lora( + lora = LoRALayerWeights.create_dummy_lora_weights( module_name, module.lora_a_stacked.shape[-1], module.lora_b_stacked.shape[-2], @@ -532,7 +509,7 @@ def create_dummy_lora(self, lora_id: int, rank: int) -> LoRAModel: replacements = self.packed_modules_mapping[parts[-1]] subloras = [] for i, r in enumerate(replacements): - lora = _create_dummy_lora( + lora = LoRALayerWeights.create_dummy_lora_weights( module_name + "." + r, module.lora_a_stacked[i].shape[-1], module.lora_b_stacked[i].shape[-2], @@ -542,7 +519,7 @@ def create_dummy_lora(self, lora_id: int, rank: int) -> LoRAModel: ) lora.optimize() subloras.append(lora) - lora = LoRA.pack(subloras) + lora = PackedLoRALayerWeights.pack(subloras) model.loras[module_name] = lora return model @@ -579,7 +556,8 @@ def _create_merged_loras_inplace(self, lora_model: LoRAModel) -> None: if replacement_loras[i]: continue replacement_loras[i] = None - lora_model.loras[module_name] = LoRA.pack(replacement_loras) + lora_model.loras[module_name] = PackedLoRALayerWeights.pack( + replacement_loras) class LoRALRUCache(LRUCache): @@ -647,16 +625,15 @@ def remove_oldest_lora(self) -> bool: return False -def create_lora_adapter( +def create_lora_manager( model: nn.Module, max_num_seqs: int, max_num_batched_tokens: int, vocab_size: int, - lora_config:LoRAConfig, - target_modules: Union[str, - List[str]] = TARGET_MODULES_QKV, - lora_manager_cls:Type[LoRAModelManager] = LoRAModelManager, **kwargs)\ - -> LoRAModelManager: + lora_config: LoRAConfig, + target_modules: Union[str, List[str]] = TARGET_MODULES_QKV, + lora_manager_cls: Type[LoRAModelManager] = LoRAModelManager, + **kwargs) -> LoRAModelManager: """Create a LoRA adapter for a given model.""" if not getattr(model, "supports_lora", False): raise ValueError(f"Model {type(model)} is not supported for LoRA.") diff --git a/vllm/lora/worker_manager.py b/vllm/lora/worker_manager.py index 363b7770be178..a507c08588dad 100644 --- a/vllm/lora/worker_manager.py +++ b/vllm/lora/worker_manager.py @@ -5,7 +5,7 @@ import torch from vllm.lora.models import (TARGET_MODULES_QKV, LoRAModel, LoRAModelManager, - LRUCacheLoRAModelManager, create_lora_adapter) + LRUCacheLoRAModelManager, create_lora_manager) from vllm.lora.request import LoRARequest from vllm.lora.layers import LoRAMapping from vllm.config import LoRAConfig @@ -13,7 +13,7 @@ logger = logging.getLogger(__name__) -class AbstractWorkerLoRAManager(ABC): +class WorkerLoRAManager(ABC): """Abstract class for managing LoRA models on the worker side.""" def __init__(self, max_num_seqs: int, max_num_batched_tokens: int, @@ -30,7 +30,7 @@ def is_enabled(self) -> bool: ... @abstractmethod - def create_lora_adapter( + def create_lora_manager( self, model: torch.nn.Module, target_modules: Union[str, List[str]] = TARGET_MODULES_QKV, @@ -38,8 +38,8 @@ def create_lora_adapter( ... @abstractmethod - def apply_loras(self, lora_requests: List[LoRARequest], - lora_mapping: LoRAMapping) -> None: + def set_active_loras(self, lora_requests: List[LoRARequest], + lora_mapping: LoRAMapping) -> None: ... @abstractmethod @@ -63,41 +63,7 @@ def list_loras(self) -> Set[int]: ... -class DisabledWorkerLoRAManager(AbstractWorkerLoRAManager): - """WorkerLoRAManager that does nothing.""" - - @property - def is_enabled(self) -> bool: - return False - - def create_lora_adapter( - self, - model: torch.nn.Module, - target_modules: Union[str, List[str]] = TARGET_MODULES_QKV, - ) -> Any: - return model - - def apply_loras(self, lora_requests: List[LoRARequest], - lora_mapping: LoRAMapping) -> None: - return - - def add_lora(self, lora_request: LoRARequest) -> bool: - return False - - def add_dummy_lora(self, lora_request: LoRARequest, rank: int) -> bool: - return False - - def remove_lora(self, lora_id: int) -> bool: - return False - - def remove_all_loras(self) -> bool: - return - - def list_loras(self) -> Set[int]: - return set() - - -class WorkerLoRAManager(AbstractWorkerLoRAManager): +class WorkerLoRAManager(WorkerLoRAManager): """WorkerLoRAManager that manages LoRA models on the worker side. Every request, the requested LoRAs will be loaded (unless they are already @@ -123,12 +89,12 @@ def __init__( def is_enabled(self) -> bool: return True - def create_lora_adapter( + def create_lora_manager( self, model: torch.nn.Module, target_modules: Union[str, List[str]] = TARGET_MODULES_QKV, ) -> Any: - lora_manager = create_lora_adapter( + lora_manager = create_lora_manager( model, max_num_seqs=self.max_num_seqs, max_num_batched_tokens=self.max_num_batched_tokens, @@ -140,10 +106,10 @@ def create_lora_adapter( self._lora_manager: LoRAModelManager = lora_manager return lora_manager.model - def apply_loras(self, lora_requests: List[LoRARequest], - lora_mapping: LoRAMapping) -> None: + def set_active_loras(self, lora_requests: List[LoRARequest], + lora_mapping: LoRAMapping) -> None: self._apply_loras(lora_requests) - self._lora_manager.set_row_lora_mapping(lora_mapping) + self._lora_manager.set_lora_mapping(lora_mapping) def _apply_loras(self, lora_requests: List[LoRARequest]) -> None: loras_that_exist = self.list_loras() @@ -226,12 +192,12 @@ class LRUCacheWorkerLoRAManager(WorkerLoRAManager): _lora_manager_cls: Type[ LRUCacheLoRAModelManager] = LRUCacheLoRAModelManager - def create_lora_adapter( + def create_lora_manager( self, model: torch.nn.Module, target_modules: Union[str, List[str]] = TARGET_MODULES_QKV, ) -> Any: - lora_manager = create_lora_adapter( + lora_manager = create_lora_manager( model, target_modules=target_modules, lora_manager_cls=self._lora_manager_cls, diff --git a/vllm/model_executor/model_loader.py b/vllm/model_executor/model_loader.py index 9903cb540d438..0f1125e5c8e3e 100644 --- a/vllm/model_executor/model_loader.py +++ b/vllm/model_executor/model_loader.py @@ -63,7 +63,7 @@ def get_model(model_config: ModelConfig, # Create a model instance. # The weights will be initialized as empty tensors. with torch.device("cuda"): - if getattr(model_class, "supports_lora", True): + if getattr(model_class, "supports_lora", False): model = model_class(model_config.hf_config, linear_method, lora_config) elif lora_config: diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index e20da8b5a2ad5..3552c5d665668 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -10,10 +10,7 @@ from vllm.model_executor import get_model, InputMetadata, SamplingMetadata from vllm.sampling_params import SamplingParams, SamplingType from vllm.sequence import SamplerOutput, SequenceData, SequenceGroupMetadata -from vllm.lora.worker_manager import ( - DisabledWorkerLoRAManager, - LRUCacheWorkerLoRAManager, -) +from vllm.lora.worker_manager import LRUCacheWorkerLoRAManager from vllm.lora.layers import LoRAMapping from vllm.lora.request import LoRARequest from vllm.utils import in_wsl @@ -77,12 +74,7 @@ def load_model(self) -> None: self.scheduler_config.max_num_seqs, self.scheduler_config.max_num_batched_tokens, vocab_size, self.lora_config, self.device) - self.model = self.lora_manager.create_lora_adapter(self.model) - else: - self.lora_manager = DisabledWorkerLoRAManager( - self.scheduler_config.max_num_seqs, - self.scheduler_config.max_num_batched_tokens, vocab_size, - self.lora_config, self.device) + self.model = self.lora_manager.create_lora_manager(self.model) def set_block_size(self, block_size: int) -> None: self.block_size = block_size @@ -409,7 +401,7 @@ def execute_model( flat_lora_index_mapping, lora_prompt_mapping, ) - self.apply_loras(lora_requests, lora_mapping) + self.set_active_loras(lora_requests, lora_mapping) # Execute the model. if input_metadata.use_cuda_graph: @@ -492,9 +484,9 @@ def profile_run(self) -> None: def remove_all_loras(self) -> bool: return self.lora_manager.remove_all_loras() - def apply_loras(self, lora_requests: List[LoRARequest], - lora_mapping: LoRAMapping) -> None: - self.lora_manager.apply_loras(lora_requests, lora_mapping) + def set_active_loras(self, lora_requests: List[LoRARequest], + lora_mapping: LoRAMapping) -> None: + self.lora_manager.set_active_loras(lora_requests, lora_mapping) def add_lora(self, lora_request: LoRARequest) -> bool: return self.lora_manager.add_lora(lora_request) diff --git a/vllm/worker/worker.py b/vllm/worker/worker.py index 49d7fdbb32c71..f1cce3f83527f 100644 --- a/vllm/worker/worker.py +++ b/vllm/worker/worker.py @@ -49,7 +49,6 @@ def __init__( self.cache_engine = None self.cache_events = None self.gpu_cache = None - self.lora_manager = None def init_model(self) -> None: # torch.distributed.all_reduce does not free the input tensor until From a9f2ad5d9dcf0f42e2c32d9dbf021b51747348bf Mon Sep 17 00:00:00 2001 From: miloice <17350011+kliuae@users.noreply.github.com> Date: Fri, 22 Dec 2023 06:23:11 +0000 Subject: [PATCH 84/94] format --- vllm/worker/model_runner.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index 56f1c5963b6dc..8d1778954bc33 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -194,7 +194,7 @@ def _prepare_decode( slot_mapping: List[List[int]] = [] context_lens: List[int] = [] block_tables: List[List[int]] = [] - lora_index_mapping: List[int] = [] + lora_index_mapping: List[List[int]] = [] lora_prompt_mapping: List[int] = [] for seq_group_metadata in seq_group_metadata_list: From 1d2ddb7400d0f46557da50f2b801a8d4b798fd72 Mon Sep 17 00:00:00 2001 From: miloice <17350011+kliuae@users.noreply.github.com> Date: Fri, 22 Dec 2023 06:44:59 +0000 Subject: [PATCH 85/94] Fix --- setup.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/setup.py b/setup.py index 5405c85109371..22a6a251e5c23 100644 --- a/setup.py +++ b/setup.py @@ -52,10 +52,6 @@ def _is_cuda() -> bool: NVCC_FLAGS += [f"-D_GLIBCXX_USE_CXX11_ABI={ABI}"] NVCC_FLAGS_PUNICA = NVCC_FLAGS.copy() -<<<<<<< HEAD -======= - ->>>>>>> multi_lora def get_amdgpu_offload_arch(): command = "/opt/rocm/llvm/bin/amdgpu-offload-arch" try: From 4512cd158a51eeb18f1c6236a5fc09eb625041af Mon Sep 17 00:00:00 2001 From: miloice <17350011+kliuae@users.noreply.github.com> Date: Fri, 22 Dec 2023 09:48:53 +0000 Subject: [PATCH 86/94] Add temporary fixes --- vllm/lora/layers.py | 12 ++++++++---- vllm/worker/model_runner.py | 8 ++++---- 2 files changed, 12 insertions(+), 8 deletions(-) diff --git a/vllm/lora/layers.py b/vllm/lora/layers.py index 252909c859628..7e1882adb9f3b 100644 --- a/vllm/lora/layers.py +++ b/vllm/lora/layers.py @@ -91,10 +91,14 @@ def _apply_lora_packed_nslice( indices = indices.view(-1) add_lora_slice(output, x, lora_a_stacked[0], lora_b_stacked[0], indices, 0, 1.0, 0, output_slices[0]) - add_lora_slice(output, x, lora_a_stacked[1], lora_b_stacked[1], indices, 0, - 1.0, output_slices[0], output_slices[1]) - add_lora_slice(output, x, lora_a_stacked[2], lora_b_stacked[2], indices, 0, - 1.0, output_slices[0] + output_slices[1], output_slices[1]) + if len(output_slices) == 1: + add_lora_slice(output, x, lora_a_stacked[1], lora_b_stacked[1], indices, 0, + 1.0, output_slices[0], output_slices[0]) + else: + add_lora_slice(output, x, lora_a_stacked[1], lora_b_stacked[1], indices, 0, + 1.0, output_slices[0], output_slices[1]) + add_lora_slice(output, x, lora_a_stacked[2], lora_b_stacked[2], indices, 0, + 1.0, output_slices[0] + output_slices[1], output_slices[1]) return output.view_as(org_output) diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index bbf4119014145..996364d3b5836 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -536,20 +536,20 @@ def capture_model(self, kv_caches: List[KVCache]) -> None: logger.info(f"Graph capturing finished in {elapsed_time:.0f} secs.") def remove_all_loras(self) -> bool: - return self.lora_manager.remove_all_loras() + return self.lora_manager is not None and self.lora_manager.remove_all_loras() def set_active_loras(self, lora_requests: List[LoRARequest], lora_mapping: LoRAMapping) -> None: self.lora_manager.set_active_loras(lora_requests, lora_mapping) def add_lora(self, lora_request: LoRARequest) -> bool: - return self.lora_manager.add_lora(lora_request) + return self.lora_manager is not None and self.lora_manager.add_lora(lora_request) def remove_lora(self, lora_id: int) -> bool: - return self.lora_manager.remove_lora(lora_id) + return self.lora_manager is not None and self.lora_manager.remove_lora(lora_id) def list_loras(self) -> Set[int]: - return self.lora_manager.list_loras() + return self.lora_manager is not None and self.lora_manager.list_loras() @torch.inference_mode() def capture_model(self, kv_caches: List[KVCache]) -> None: From 43bcf2ccdd4b5fa8449873768f58c36035c05303 Mon Sep 17 00:00:00 2001 From: miloice <17350011+kliuae@users.noreply.github.com> Date: Fri, 22 Dec 2023 16:31:40 +0000 Subject: [PATCH 87/94] Add temporary fixes --- vllm/worker/model_runner.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index 996364d3b5836..c922e91bad0ef 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -540,7 +540,8 @@ def remove_all_loras(self) -> bool: def set_active_loras(self, lora_requests: List[LoRARequest], lora_mapping: LoRAMapping) -> None: - self.lora_manager.set_active_loras(lora_requests, lora_mapping) + if self.lora_manager is not None: + self.lora_manager.set_active_loras(lora_requests, lora_mapping) def add_lora(self, lora_request: LoRARequest) -> bool: return self.lora_manager is not None and self.lora_manager.add_lora(lora_request) From 6b2e6a51ec5d3ceaec66c8c62dd081148d65426a Mon Sep 17 00:00:00 2001 From: Antoni Baum Date: Fri, 22 Dec 2023 16:06:15 -0800 Subject: [PATCH 88/94] Fixes --- vllm/lora/layers.py | 13 +++++++------ vllm/worker/model_runner.py | 10 +++++++++- 2 files changed, 16 insertions(+), 7 deletions(-) diff --git a/vllm/lora/layers.py b/vllm/lora/layers.py index 252909c859628..8f16c1ecb330e 100644 --- a/vllm/lora/layers.py +++ b/vllm/lora/layers.py @@ -89,12 +89,13 @@ def _apply_lora_packed_nslice( x = x.view(-1, x.shape[-1]) output = output.view(-1, output.shape[-1]) indices = indices.view(-1) - add_lora_slice(output, x, lora_a_stacked[0], lora_b_stacked[0], indices, 0, - 1.0, 0, output_slices[0]) - add_lora_slice(output, x, lora_a_stacked[1], lora_b_stacked[1], indices, 0, - 1.0, output_slices[0], output_slices[1]) - add_lora_slice(output, x, lora_a_stacked[2], lora_b_stacked[2], indices, 0, - 1.0, output_slices[0] + output_slices[1], output_slices[1]) + output_slices = (0, ) + output_slices + for slice_idx_right in range(1, len(output_slices)): + slice_idx_left = slice_idx_right - 1 + add_lora_slice(output, x, lora_a_stacked[slice_idx_left], + lora_b_stacked[slice_idx_left], indices, 0, 1.0, + output_slices[slice_idx_left], + output_slices[slice_idx_right]) return output.view_as(org_output) diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index 3552c5d665668..2ad56365c97ad 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -72,7 +72,8 @@ def load_model(self) -> None: if self.lora_config: self.lora_manager = LRUCacheWorkerLoRAManager( self.scheduler_config.max_num_seqs, - self.scheduler_config.max_num_batched_tokens, vocab_size, + self.scheduler_config.max_num_batched_tokens + + self.scheduler_config.max_paddings, vocab_size, self.lora_config, self.device) self.model = self.lora_manager.create_lora_manager(self.model) @@ -532,6 +533,13 @@ def capture_model(self, kv_caches: List[KVCache]) -> None: use_cuda_graph=True, ) + if self.lora_config: + lora_mapping = LoRAMapping( + [0] * batch_size, + [0] * batch_size, + ) + self.set_active_loras(set(), lora_mapping) + graph_runner = CUDAGraphRunner(self.model) graph_runner.capture( input_tokens[:batch_size], From 4b2224eaac2d80f5eeed1b934dd1473d278e32d8 Mon Sep 17 00:00:00 2001 From: Antoni Baum Date: Fri, 22 Dec 2023 17:13:49 -0800 Subject: [PATCH 89/94] Fixes --- tests/lora/test_lora.py | 8 +- tests/lora/test_lora_manager.py | 158 ++++++++++++++++---------------- vllm/engine/arg_utils.py | 4 +- vllm/lora/layers.py | 18 ++-- vllm/lora/models.py | 14 ++- vllm/worker/model_runner.py | 10 ++ vllm/worker/worker.py | 3 +- 7 files changed, 116 insertions(+), 99 deletions(-) diff --git a/tests/lora/test_lora.py b/tests/lora/test_lora.py index 1b972cc53f24d..3415d36b7e341 100644 --- a/tests/lora/test_lora.py +++ b/tests/lora/test_lora.py @@ -126,7 +126,7 @@ def test_apply_lora_packed_2slice(m, n, k, rank, dtype) -> None: input, lora_a_stacks, lora_b_stacks, torch.randint(0, lora_a_stacks[0].shape[0], (len(input), ), - device="cuda"), output, (m // 2, )) + device="cuda"), output, (m // 2, m // 2)) rtol, atol = TOLERANCES[dtype] assert torch.allclose(expected, output, rtol=rtol, atol=atol) @@ -134,7 +134,7 @@ def test_apply_lora_packed_2slice(m, n, k, rank, dtype) -> None: output[:] = 0 _apply_lora_packed_nslice(input, lora_a_stacks, lora_b_stacks, torch.full((len(input), ), -1, device="cuda"), - output, (m // 2, )) + output, (m // 2, m // 2)) assert torch.allclose(torch.zeros_like(output), output) manager.reset_lora() @@ -210,7 +210,7 @@ def test_apply_lora_packed_3slice(qkv, n, k, rank, dtype) -> None: input, lora_a_stacks, lora_b_stacks, torch.randint(0, lora_a_stacks[0].shape[0], (len(input), ), - device="cuda"), output, (qkv[0], qkv[1])) + device="cuda"), output, (qkv[0], qkv[1], qkv[2])) rtol, atol = TOLERANCES[dtype] assert torch.allclose(expected, output, rtol=rtol, atol=atol) @@ -218,7 +218,7 @@ def test_apply_lora_packed_3slice(qkv, n, k, rank, dtype) -> None: output[:] = 0 _apply_lora_packed_nslice(input, lora_a_stacks, lora_b_stacks, torch.full((len(input), ), -1, device="cuda"), - output, (qkv[0], qkv[1])) + output, (qkv[0], qkv[1], qkv[2])) assert torch.allclose(torch.zeros_like(output), output) manager.reset_lora() diff --git a/tests/lora/test_lora_manager.py b/tests/lora/test_lora_manager.py index 9c52058ff9a51..78a4a5bc5ecd2 100644 --- a/tests/lora/test_lora_manager.py +++ b/tests/lora/test_lora_manager.py @@ -121,40 +121,40 @@ def test_lora_model_manager(dist_init, dummy_model): 2, LoRAConfig(max_lora_rank=8, max_cpu_loras=3, max_loras=2), lora_target_modules=["dense1", "dense2", "lm_head"]) - assert all(x is None for x in manager.lora_id_to_index) + assert all(x is None for x in manager.lora_index_to_id) assert manager.add_lora(model_lora1) assert manager.activate_lora(1) - assert manager.lora_id_to_index[0] == 1 + assert manager.lora_index_to_id[0] == 1 assert not manager.add_lora(model_lora1) assert not manager.activate_lora(1) assert manager.add_lora(model_lora2) assert manager.activate_lora(2) - assert manager.lora_id_to_index[0] == 1 - assert manager.lora_id_to_index[1] == 2 + assert manager.lora_index_to_id[0] == 1 + assert manager.lora_index_to_id[1] == 2 assert not manager.add_lora(model_lora2) assert not manager.activate_lora(2) assert manager.add_lora(model_lora3) - assert manager.lora_id_to_index[0] == 1 - assert manager.lora_id_to_index[1] == 2 + assert manager.lora_index_to_id[0] == 1 + assert manager.lora_index_to_id[1] == 2 with pytest.raises(ValueError): assert manager.activate_lora(3) - assert manager.lora_id_to_index[0] == 1 - assert manager.lora_id_to_index[1] == 2 + assert manager.lora_index_to_id[0] == 1 + assert manager.lora_index_to_id[1] == 2 assert manager.remove_lora(model_lora2.id) - assert manager.lora_id_to_index[1] is None + assert manager.lora_index_to_id[1] is None assert not manager.remove_lora(model_lora2.id) assert manager.remove_lora(model_lora1.id) assert not manager.remove_lora(model_lora1.id) assert manager.add_lora(model_lora1) - assert manager.lora_id_to_index[0] is None - assert manager.lora_id_to_index[1] is None + assert manager.lora_index_to_id[0] is None + assert manager.lora_index_to_id[1] is None assert manager.add_lora(model_lora2) assert manager.activate_lora(3) - assert manager.lora_id_to_index[0] == 3 - assert manager.lora_id_to_index[1] is None + assert manager.lora_index_to_id[0] == 3 + assert manager.lora_index_to_id[1] is None assert manager.activate_lora(2) - assert manager.lora_id_to_index[0] == 3 - assert manager.lora_id_to_index[1] == 2 + assert manager.lora_index_to_id[0] == 3 + assert manager.lora_index_to_id[1] == 2 def test_lora_lru_cache_model_manager(dist_init, dummy_model): @@ -169,43 +169,43 @@ def test_lora_lru_cache_model_manager(dist_init, dummy_model): 2, LoRAConfig(max_lora_rank=8, max_cpu_loras=3, max_loras=2), lora_target_modules=["dense1", "dense2", "lm_head"]) - assert all(x is None for x in manager.lora_id_to_index) + assert all(x is None for x in manager.lora_index_to_id) assert manager.add_lora(model_lora1) assert manager.activate_lora(1) - assert manager.lora_id_to_index[0] == 1 + assert manager.lora_index_to_id[0] == 1 assert not manager.add_lora(model_lora1) assert not manager.activate_lora(1) assert manager.add_lora(model_lora2) assert manager.activate_lora(2) - assert manager.lora_id_to_index[0] == 1 - assert manager.lora_id_to_index[1] == 2 + assert manager.lora_index_to_id[0] == 1 + assert manager.lora_index_to_id[1] == 2 assert not manager.add_lora(model_lora2) assert not manager.activate_lora(2) assert manager.add_lora(model_lora3) - assert manager.lora_id_to_index[0] == 1 - assert manager.lora_id_to_index[1] == 2 + assert manager.lora_index_to_id[0] == 1 + assert manager.lora_index_to_id[1] == 2 assert manager.activate_lora(3) - assert manager.lora_id_to_index[0] == 3 - assert manager.lora_id_to_index[1] == 2 + assert manager.lora_index_to_id[0] == 3 + assert manager.lora_index_to_id[1] == 2 assert manager.remove_lora(model_lora2.id) - assert manager.lora_id_to_index[1] is None + assert manager.lora_index_to_id[1] is None assert not manager.remove_lora(model_lora2.id) assert manager.remove_lora(model_lora1.id) assert not manager.remove_lora(model_lora1.id) assert manager.add_lora(model_lora1) assert manager.activate_lora(1) - assert manager.lora_id_to_index[0] == 3 - assert manager.lora_id_to_index[1] == 1 + assert manager.lora_index_to_id[0] == 3 + assert manager.lora_index_to_id[1] == 1 assert manager.add_lora(model_lora2) assert manager.deactivate_lora(3) - assert manager.lora_id_to_index[0] is None - assert manager.lora_id_to_index[1] == 1 + assert manager.lora_index_to_id[0] is None + assert manager.lora_index_to_id[1] == 1 assert manager.activate_lora(2) - assert manager.lora_id_to_index[0] == 2 - assert manager.lora_id_to_index[1] == 1 + assert manager.lora_index_to_id[0] == 2 + assert manager.lora_index_to_id[1] == 1 assert manager.activate_lora(3) - assert manager.lora_id_to_index[0] == 2 - assert manager.lora_id_to_index[1] == 3 + assert manager.lora_index_to_id[0] == 2 + assert manager.lora_index_to_id[1] == 3 def test_lru_lora_model_manager(dist_init, dummy_model): @@ -221,7 +221,7 @@ def test_lru_lora_model_manager(dist_init, dummy_model): LoRAConfig(max_lora_rank=8, max_cpu_loras=2, max_loras=2), ["dense1", "dense2", "lm_head"]) - assert all(x is None for x in manager.lora_id_to_index) + assert all(x is None for x in manager.lora_index_to_id) # Add up to capacity assert manager.add_lora(model_lora1) @@ -230,8 +230,8 @@ def test_lru_lora_model_manager(dist_init, dummy_model): assert manager.activate_lora(2) assert set(manager.list_loras()) == {1, 2} - assert manager.lora_id_to_index[0] == 1 - assert manager.lora_id_to_index[1] == 2 + assert manager.lora_index_to_id[0] == 1 + assert manager.lora_index_to_id[1] == 2 # Add over capacity assert manager.add_lora(model_lora3) @@ -240,8 +240,8 @@ def test_lru_lora_model_manager(dist_init, dummy_model): assert manager.activate_lora(4) assert set(manager.list_loras()) == {3, 4} - assert manager.lora_id_to_index[0] == 3 - assert manager.lora_id_to_index[1] == 4 + assert manager.lora_index_to_id[0] == 3 + assert manager.lora_index_to_id[1] == 4 # Add 3 again to move it to the top and then add 2 # should return false since it's in already @@ -251,16 +251,16 @@ def test_lru_lora_model_manager(dist_init, dummy_model): assert manager.activate_lora(2) assert set(manager.list_loras()) == {3, 2} - assert manager.lora_id_to_index[0] == 3 - assert manager.lora_id_to_index[1] == 2 + assert manager.lora_index_to_id[0] == 3 + assert manager.lora_index_to_id[1] == 2 # Remove manually assert manager.remove_lora(3) assert not manager.remove_lora(3) assert set(manager.list_loras()) == {2} - assert manager.lora_id_to_index[0] is None - assert manager.lora_id_to_index[1] == 2 + assert manager.lora_index_to_id[0] is None + assert manager.lora_index_to_id[1] == 2 assert manager.add_lora(model_lora3) assert manager.activate_lora(3) @@ -268,21 +268,21 @@ def test_lru_lora_model_manager(dist_init, dummy_model): assert manager.activate_lora(4) assert set(manager.list_loras()) == {3, 4} - assert manager.lora_id_to_index[0] == 3 - assert manager.lora_id_to_index[1] == 4 + assert manager.lora_index_to_id[0] == 3 + assert manager.lora_index_to_id[1] == 4 assert manager.remove_oldest_lora() assert set(manager.list_loras()) == {4} - assert manager.lora_id_to_index[0] is None - assert manager.lora_id_to_index[1] == 4 + assert manager.lora_index_to_id[0] is None + assert manager.lora_index_to_id[1] == 4 assert manager.remove_oldest_lora() assert set(manager.list_loras()) == set() - assert all(x is None for x in manager.lora_id_to_index) + assert all(x is None for x in manager.lora_index_to_id) assert not manager.remove_oldest_lora() assert set(manager.list_loras()) == set() - assert all(x is None for x in manager.lora_id_to_index) + assert all(x is None for x in manager.lora_index_to_id) def test_lru_cache_worker_lora_manager(llama_2_7b_model_extra_embeddings, @@ -299,8 +299,8 @@ def test_lru_cache_worker_lora_manager(llama_2_7b_model_extra_embeddings, LoRARequest("2", 2, sql_lora_files) ], mapping) assert worker_lora_manager.list_loras() == {1, 2} - assert worker_lora_manager._lora_manager.lora_id_to_index[0] == 1 - assert worker_lora_manager._lora_manager.lora_id_to_index[1] == 2 + assert worker_lora_manager._lora_manager.lora_index_to_id[0] == 1 + assert worker_lora_manager._lora_manager.lora_index_to_id[1] == 2 worker_lora_manager.set_active_loras([ LoRARequest("1", 1, sql_lora_files), @@ -308,10 +308,10 @@ def test_lru_cache_worker_lora_manager(llama_2_7b_model_extra_embeddings, LoRARequest("4", 4, sql_lora_files) ], mapping) assert worker_lora_manager.list_loras() == {1, 2, 3, 4} - assert worker_lora_manager._lora_manager.lora_id_to_index[0] == 1 - assert worker_lora_manager._lora_manager.lora_id_to_index[1] == 2 - assert worker_lora_manager._lora_manager.lora_id_to_index[2] == 3 - assert worker_lora_manager._lora_manager.lora_id_to_index[3] == 4 + assert worker_lora_manager._lora_manager.lora_index_to_id[0] == 1 + assert worker_lora_manager._lora_manager.lora_index_to_id[1] == 2 + assert worker_lora_manager._lora_manager.lora_index_to_id[2] == 3 + assert worker_lora_manager._lora_manager.lora_index_to_id[3] == 4 worker_lora_manager.set_active_loras([ LoRARequest("1", 1, sql_lora_files), @@ -319,10 +319,10 @@ def test_lru_cache_worker_lora_manager(llama_2_7b_model_extra_embeddings, LoRARequest("5", 5, sql_lora_files) ], mapping) assert worker_lora_manager.list_loras() == {1, 2, 4, 5} - assert worker_lora_manager._lora_manager.lora_id_to_index[0] == 1 - assert worker_lora_manager._lora_manager.lora_id_to_index[1] == 2 - assert worker_lora_manager._lora_manager.lora_id_to_index[2] == 5 - assert worker_lora_manager._lora_manager.lora_id_to_index[3] == 4 + assert worker_lora_manager._lora_manager.lora_index_to_id[0] == 1 + assert worker_lora_manager._lora_manager.lora_index_to_id[1] == 2 + assert worker_lora_manager._lora_manager.lora_index_to_id[2] == 5 + assert worker_lora_manager._lora_manager.lora_index_to_id[3] == 4 worker_lora_manager.set_active_loras([ LoRARequest("1", 1, sql_lora_files), @@ -330,10 +330,10 @@ def test_lru_cache_worker_lora_manager(llama_2_7b_model_extra_embeddings, LoRARequest("1", 1, sql_lora_files) ], mapping) assert worker_lora_manager.list_loras() == {1, 2, 4, 5} - assert worker_lora_manager._lora_manager.lora_id_to_index[0] == 1 - assert worker_lora_manager._lora_manager.lora_id_to_index[1] == 2 - assert worker_lora_manager._lora_manager.lora_id_to_index[2] == 5 - assert worker_lora_manager._lora_manager.lora_id_to_index[3] == 4 + assert worker_lora_manager._lora_manager.lora_index_to_id[0] == 1 + assert worker_lora_manager._lora_manager.lora_index_to_id[1] == 2 + assert worker_lora_manager._lora_manager.lora_index_to_id[2] == 5 + assert worker_lora_manager._lora_manager.lora_index_to_id[3] == 4 worker_lora_manager.set_active_loras([ LoRARequest("6", 6, sql_lora_files), @@ -341,10 +341,10 @@ def test_lru_cache_worker_lora_manager(llama_2_7b_model_extra_embeddings, LoRARequest("8", 8, sql_lora_files) ], mapping) assert worker_lora_manager.list_loras() == {1, 6, 7, 8} - assert worker_lora_manager._lora_manager.lora_id_to_index[0] == 1 - assert worker_lora_manager._lora_manager.lora_id_to_index[1] == 7 - assert worker_lora_manager._lora_manager.lora_id_to_index[2] == 8 - assert worker_lora_manager._lora_manager.lora_id_to_index[3] == 6 + assert worker_lora_manager._lora_manager.lora_index_to_id[0] == 1 + assert worker_lora_manager._lora_manager.lora_index_to_id[1] == 7 + assert worker_lora_manager._lora_manager.lora_index_to_id[2] == 8 + assert worker_lora_manager._lora_manager.lora_index_to_id[3] == 6 # Over capacity with pytest.raises(RuntimeError): @@ -372,8 +372,8 @@ def test_worker_lora_manager(llama_2_7b_model_extra_embeddings, LoRARequest("2", 2, sql_lora_files) ], mapping) assert worker_lora_manager.list_loras() == {1, 2} - assert worker_lora_manager._lora_manager.lora_id_to_index[0] == 1 - assert worker_lora_manager._lora_manager.lora_id_to_index[1] == 2 + assert worker_lora_manager._lora_manager.lora_index_to_id[0] == 1 + assert worker_lora_manager._lora_manager.lora_index_to_id[1] == 2 worker_lora_manager.set_active_loras([ LoRARequest("1", 1, sql_lora_files), @@ -381,9 +381,9 @@ def test_worker_lora_manager(llama_2_7b_model_extra_embeddings, LoRARequest("4", 4, sql_lora_files) ], mapping) assert worker_lora_manager.list_loras() == {1, 3, 4} - assert worker_lora_manager._lora_manager.lora_id_to_index[0] == 1 - assert worker_lora_manager._lora_manager.lora_id_to_index[1] == 3 - assert worker_lora_manager._lora_manager.lora_id_to_index[2] == 4 + assert worker_lora_manager._lora_manager.lora_index_to_id[0] == 1 + assert worker_lora_manager._lora_manager.lora_index_to_id[1] == 3 + assert worker_lora_manager._lora_manager.lora_index_to_id[2] == 4 worker_lora_manager.set_active_loras([ LoRARequest("1", 1, sql_lora_files), @@ -391,9 +391,9 @@ def test_worker_lora_manager(llama_2_7b_model_extra_embeddings, LoRARequest("5", 5, sql_lora_files) ], mapping) assert worker_lora_manager.list_loras() == {1, 2, 5} - assert worker_lora_manager._lora_manager.lora_id_to_index[0] == 1 - assert worker_lora_manager._lora_manager.lora_id_to_index[1] == 2 - assert worker_lora_manager._lora_manager.lora_id_to_index[2] == 5 + assert worker_lora_manager._lora_manager.lora_index_to_id[0] == 1 + assert worker_lora_manager._lora_manager.lora_index_to_id[1] == 2 + assert worker_lora_manager._lora_manager.lora_index_to_id[2] == 5 worker_lora_manager.set_active_loras([ LoRARequest("1", 1, sql_lora_files), @@ -401,9 +401,9 @@ def test_worker_lora_manager(llama_2_7b_model_extra_embeddings, LoRARequest("1", 1, sql_lora_files) ], mapping) assert worker_lora_manager.list_loras() == {1} - assert worker_lora_manager._lora_manager.lora_id_to_index[0] == 1 - assert worker_lora_manager._lora_manager.lora_id_to_index[1] is None - assert worker_lora_manager._lora_manager.lora_id_to_index[2] is None + assert worker_lora_manager._lora_manager.lora_index_to_id[0] == 1 + assert worker_lora_manager._lora_manager.lora_index_to_id[1] is None + assert worker_lora_manager._lora_manager.lora_index_to_id[2] is None worker_lora_manager.set_active_loras([ LoRARequest("6", 6, sql_lora_files), @@ -411,9 +411,9 @@ def test_worker_lora_manager(llama_2_7b_model_extra_embeddings, LoRARequest("8", 8, sql_lora_files) ], mapping) assert worker_lora_manager.list_loras() == {6, 7, 8} - assert worker_lora_manager._lora_manager.lora_id_to_index[0] == 8 - assert worker_lora_manager._lora_manager.lora_id_to_index[1] == 6 - assert worker_lora_manager._lora_manager.lora_id_to_index[2] == 7 + assert worker_lora_manager._lora_manager.lora_index_to_id[0] == 8 + assert worker_lora_manager._lora_manager.lora_index_to_id[1] == 6 + assert worker_lora_manager._lora_manager.lora_index_to_id[2] == 7 # Over capacity with pytest.raises(RuntimeError): diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index 62e5aa5257914..090fa95bcac02 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -279,8 +279,8 @@ def create_engine_configs( max_loras=self.max_loras, lora_extra_vocab_size=self.lora_extra_vocab_size, lora_dtype=self.lora_dtype, - max_cpu_loras=self.max_cpu_loras - if self.max_cpu_loras > 0 else None) if self.enable_lora else None + max_cpu_loras=self.max_cpu_loras if self.max_cpu_loras + and self.max_cpu_loras > 0 else None) if self.enable_lora else None return model_config, cache_config, parallel_config, scheduler_config, lora_config diff --git a/vllm/lora/layers.py b/vllm/lora/layers.py index 8f16c1ecb330e..5c26ce37bbf8d 100644 --- a/vllm/lora/layers.py +++ b/vllm/lora/layers.py @@ -89,13 +89,12 @@ def _apply_lora_packed_nslice( x = x.view(-1, x.shape[-1]) output = output.view(-1, output.shape[-1]) indices = indices.view(-1) - output_slices = (0, ) + output_slices - for slice_idx_right in range(1, len(output_slices)): - slice_idx_left = slice_idx_right - 1 - add_lora_slice(output, x, lora_a_stacked[slice_idx_left], - lora_b_stacked[slice_idx_left], indices, 0, 1.0, - output_slices[slice_idx_left], - output_slices[slice_idx_right]) + offset_left = 0 + for slice_idx in range(len(output_slices)): + add_lora_slice(output, x, lora_a_stacked[slice_idx], + lora_b_stacked[slice_idx], indices, 0, 1.0, offset_left, + output_slices[slice_idx]) + offset_left += output_slices[slice_idx] return output.view_as(org_output) @@ -480,7 +479,7 @@ def apply_weights(self, x: torch.Tensor, self.lora_b_stacked, self.indices[:self.indices_len[0]], output, - (self.output_dim, ), + (self.output_dim, self.output_dim), ) return output @@ -563,7 +562,8 @@ def create_lora_weights( device=self.base_layer.weight.device, )) - self.output_slices = (self.q_proj_shard_size, self.kv_proj_shard_size) + self.output_slices = (self.q_proj_shard_size, self.kv_proj_shard_size, + self.kv_proj_shard_size) self.packed_indices: Optional[torch.Tensor] = None self.standard_indices: Optional[torch.Tensor] = None self.indices_len: Optional[List[int]] = None diff --git a/vllm/lora/models.py b/vllm/lora/models.py index df3d92aa3eef2..6c78c4a2c7771 100644 --- a/vllm/lora/models.py +++ b/vllm/lora/models.py @@ -388,13 +388,16 @@ def deactivate_lora(self, lora_id: int) -> bool: return True return False + def _add_lora(self, lora: LoRAModel) -> bool: + self._create_merged_loras_inplace(lora) + self._registered_loras[lora.id] = lora + def add_lora(self, lora: LoRAModel) -> bool: """Add a LoRAModel to the manager CPU cache.""" if lora.id not in self._registered_loras: if len(self._registered_loras) >= self.capacity: raise RuntimeError("No free LoRA slots.") - self._create_merged_loras_inplace(lora) - self._registered_loras[lora.id] = lora + self._add_lora(lora) return True return False @@ -600,10 +603,13 @@ def list_loras(self) -> Dict[int, LoRAModel]: def add_lora(self, lora: LoRAModel) -> bool: """Add a LoRAModel to the manager.""" - was_added = super().add_lora(lora) - if not was_added: + if lora.id not in self._registered_loras: + self._add_lora(lora) + was_added = True + else: # We always touch to update the LRU cache order self._registered_loras.touch(lora.id) + was_added = False return was_added def activate_lora( diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index 2ad56365c97ad..efbc09be1830d 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -483,19 +483,29 @@ def profile_run(self) -> None: return def remove_all_loras(self) -> bool: + if not self.lora_manager: + raise RuntimeError("LoRA is not enabled.") return self.lora_manager.remove_all_loras() def set_active_loras(self, lora_requests: List[LoRARequest], lora_mapping: LoRAMapping) -> None: + if not self.lora_manager: + raise RuntimeError("LoRA is not enabled.") self.lora_manager.set_active_loras(lora_requests, lora_mapping) def add_lora(self, lora_request: LoRARequest) -> bool: + if not self.lora_manager: + raise RuntimeError("LoRA is not enabled.") return self.lora_manager.add_lora(lora_request) def remove_lora(self, lora_id: int) -> bool: + if not self.lora_manager: + raise RuntimeError("LoRA is not enabled.") return self.lora_manager.remove_lora(lora_id) def list_loras(self) -> Set[int]: + if not self.lora_manager: + raise RuntimeError("LoRA is not enabled.") return self.lora_manager.list_loras() @torch.inference_mode() diff --git a/vllm/worker/worker.py b/vllm/worker/worker.py index f1cce3f83527f..bb8e7fd6cf86e 100644 --- a/vllm/worker/worker.py +++ b/vllm/worker/worker.py @@ -111,7 +111,8 @@ def profile_num_available_blocks( num_cpu_blocks = int(cpu_swap_space // cache_block_size) num_gpu_blocks = max(num_gpu_blocks, 0) num_cpu_blocks = max(num_cpu_blocks, 0) - self.model_runner.remove_all_loras() + if self.model_runner.lora_manager: + self.model_runner.remove_all_loras() gc.collect() torch.cuda.empty_cache() return num_gpu_blocks, num_cpu_blocks From dc436ef0f4dcc610da4779e2512338e041d08bd3 Mon Sep 17 00:00:00 2001 From: miloice <17350011+kliuae@users.noreply.github.com> Date: Sat, 23 Dec 2023 09:30:23 +0000 Subject: [PATCH 90/94] Simplify includes --- README.md | 10 +++---- csrc/punica/bgmv/vec_dtypes.cuh | 11 +------- csrc/punica/punica_ops.cu | 12 ++------- csrc/punica/type_convert.h | 47 ++++++++++++++++++--------------- 4 files changed, 33 insertions(+), 47 deletions(-) diff --git a/README.md b/README.md index 1863c31fcf2d2..051cf3c77ec99 100644 --- a/README.md +++ b/README.md @@ -2,19 +2,17 @@ multi-lora rocm development Derived from [Yard1's multi-lora branch](https://github.com/Yard1/vllm/tree/multi_lora) -[A note to whoever wants to run it on ROCm now] +[Important note] -Starting from ROCm v5.7.0, some type conversion utils for bf16 are implemented in header files, but unfortunately some of them are not specified as inline or static functions. So building the project on ROCm directly would lead to failures in multiple definitions of functions. +Starting from ROCm v5.7, some type conversion functions on bfloat16 are implemented in header files. Unfortunately a few of them are not specified as inline or static functions, so building the project on ROCm directly would result in ODR violations when linking the translation units. -To circumvent this I have been manually adding the "inline" keyward to the related functions. In the container we're using it means adding the keyword "inline" to "/opt/rocm/include/hip/amd_detail/amd_hip_bf16.h#96" so that the line becomes +A way to circumvent this is to manually add the `inline` or `static` keywards to the related functions. In the container that `Dockerfile.rocm` builds from, it means adding the keyword `inline` to `/opt/rocm/include/hip/amd_detail/amd_hip_bf16.h:96` so that the line becomes ```cpp L96: #define __HOST_DEVICE__ __host__ __device__ inline ``` -But of course this is far from pretty, if not considered nasty, as it is modifying include header in a potentially dangerous manner, and surely there should be better ways which I'm trying to find. - -This should be solved starting from [ROCm 6.0.0 though](https://github.com/ROCm/clr/commit/86bd518981b364c138f9901b28a529899d8654f3). +This is far from a pretty solution though. Even though it appears that [ROCm may be fixing this](https://github.com/ROCm/clr/commit/86bd518981b364c138f9901b28a529899d8654f3), it appears to not be included in ROCm v6.0.0. Fixes like this may need to stay around until better solutions come out.

diff --git a/csrc/punica/bgmv/vec_dtypes.cuh b/csrc/punica/bgmv/vec_dtypes.cuh index 86d2971731e4d..2738892e6dc4a 100644 --- a/csrc/punica/bgmv/vec_dtypes.cuh +++ b/csrc/punica/bgmv/vec_dtypes.cuh @@ -1,13 +1,6 @@ #ifndef VEC_DTYPES_CUH_ #define VEC_DTYPES_CUH_ -#ifndef USE_ROCM -#include -#include -#else -#include -#include -#endif #ifdef FLASHINFER_USE_FP8 #include #endif @@ -15,10 +8,8 @@ #include -#ifdef USE_ROCM -#include "../../cuda_compat.h" #include "../type_convert.h" -#endif +#include "../../cuda_compat.h" #define FLASHINFER_INLINE \ inline __attribute__((always_inline)) __device__ __host__ diff --git a/csrc/punica/punica_ops.cu b/csrc/punica/punica_ops.cu index 935ce48891ab1..644740d9c49b0 100644 --- a/csrc/punica/punica_ops.cu +++ b/csrc/punica/punica_ops.cu @@ -1,19 +1,11 @@ #include #include -#ifndef USE_ROCM -#include -#include -#else -#include -#include -#endif - #include -#include "bgmv/bgmv_config.h" -#include "../cuda_compat.h" #include "type_convert.h" +#include "../cuda_compat.h" +#include "bgmv/bgmv_config.h" //====== utils ====== diff --git a/csrc/punica/type_convert.h b/csrc/punica/type_convert.h index 9180693e2057b..13d5b411a3c0a 100644 --- a/csrc/punica/type_convert.h +++ b/csrc/punica/type_convert.h @@ -1,13 +1,18 @@ #ifndef CSRC__PUNICA__TYPE_CONVERT_H__ #define CSRC__PUNICA__TYPE_CONVERT_H__ -#ifdef USE_ROCM +#ifndef USE_ROCM -#define __TYPE_CONVERT__HOST_DEVICE__ __host__ __device__ +#include +#include + +#else #include #include +#define __TYPE_CONVERT__HOST_DEVICE__ __host__ __device__ + typedef __half nv_half; typedef __hip_bfloat16 nv_bfloat16; @@ -16,105 +21,105 @@ typedef __hip_bfloat162 nv_bfloat162; __TYPE_CONVERT__HOST_DEVICE__ inline __hip_bfloat162 make_bfloat162(__hip_bfloat16 val) { -return __hip_bfloat162{val, val}; + return __hip_bfloat162{val, val}; } __TYPE_CONVERT__HOST_DEVICE__ inline __hip_bfloat162 make_bfloat162(__hip_bfloat16 vall, __hip_bfloat16 valr) { -return __hip_bfloat162{vall, valr}; + return __hip_bfloat162{vall, valr}; } __TYPE_CONVERT__HOST_DEVICE__ inline __half cast_to_half(__half val) { -return val; + return val; } __TYPE_CONVERT__HOST_DEVICE__ inline __half cast_to_half(__hip_bfloat16 val) { -return __float2half(__bfloat162float(val)); + return __float2half(__bfloat162float(val)); } __TYPE_CONVERT__HOST_DEVICE__ inline __half cast_to_half(float val) { -return __float2half(val); + return __float2half(val); } __TYPE_CONVERT__HOST_DEVICE__ inline __hip_bfloat16 cast_to_bf16(__half val) { -return __float2bfloat16(__half2float(val)); + return __float2bfloat16(__half2float(val)); } __TYPE_CONVERT__HOST_DEVICE__ inline __hip_bfloat16 cast_to_bf16(__hip_bfloat16 val) { -return val; + return val; } __TYPE_CONVERT__HOST_DEVICE__ inline __hip_bfloat16 cast_to_bf16(float val) { -return __float2bfloat16(val); + return __float2bfloat16(val); } __TYPE_CONVERT__HOST_DEVICE__ inline float cast_to_float(__half val) { -return __half2float(val); + return __half2float(val); } __TYPE_CONVERT__HOST_DEVICE__ inline float cast_to_float(__hip_bfloat16 val) { -return __bfloat162float(val); + return __bfloat162float(val); } __TYPE_CONVERT__HOST_DEVICE__ inline float cast_to_float(float val) { -return val; + return val; } template __TYPE_CONVERT__HOST_DEVICE__ inline T_dst convert_type(T_src val) { -return static_cast(val); + return static_cast(val); } template <> __TYPE_CONVERT__HOST_DEVICE__ inline float convert_type<__half, float>(__half val) { -return __half2float(val); + return __half2float(val); } template <> __TYPE_CONVERT__HOST_DEVICE__ inline __half convert_type(float val) { -return __float2half(val); + return __float2half(val); } template <> __TYPE_CONVERT__HOST_DEVICE__ inline float convert_type<__hip_bfloat16, float>(__hip_bfloat16 val) { -return __bfloat162float(val); + return __bfloat162float(val); } template <> __TYPE_CONVERT__HOST_DEVICE__ inline __hip_bfloat16 convert_type(float val) { -return __float2bfloat16(val); + return __float2bfloat16(val); } template __TYPE_CONVERT__HOST_DEVICE__ inline T vllm_add(T a, T b) { -return a + b; + return a + b; } template <> __TYPE_CONVERT__HOST_DEVICE__ inline __half vllm_add<__half>(__half a, __half b) { -return __hadd(a, b); + return __hadd(a, b); } template <> __TYPE_CONVERT__HOST_DEVICE__ inline __hip_bfloat16 vllm_add<__hip_bfloat16>(__hip_bfloat16 a, __hip_bfloat16 b) { -return __hadd(a, b); + return __hadd(a, b); } #undef __TYPE_CONVERT__HOST_DEVICE__ From 6cc1b42c7763cf474e03af7713bbf0021dbaf0f3 Mon Sep 17 00:00:00 2001 From: miloice <17350011+kliuae@users.noreply.github.com> Date: Sat, 23 Dec 2023 15:23:50 +0000 Subject: [PATCH 91/94] update readme --- README.md | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/README.md b/README.md index 051cf3c77ec99..478e19894b20a 100644 --- a/README.md +++ b/README.md @@ -4,15 +4,15 @@ Derived from [Yard1's multi-lora branch](https://github.com/Yard1/vllm/tree/mult [Important note] -Starting from ROCm v5.7, some type conversion functions on bfloat16 are implemented in header files. Unfortunately a few of them are not specified as inline or static functions, so building the project on ROCm directly would result in ODR violations when linking the translation units. +Starting from ROCm v5.7, some type conversion functions on bfloat16 are included and implemented in header files. Unfortunately a few of the host functions are not specified as inline or static, so building the project on ROCm directly would result in ODR violations when the translation units are being linked. -A way to circumvent this is to manually add the `inline` or `static` keywards to the related functions. In the container that `Dockerfile.rocm` builds from, it means adding the keyword `inline` to `/opt/rocm/include/hip/amd_detail/amd_hip_bf16.h:96` so that the line becomes +A way to circumvent this is to manually add the `inline` or `static` keywards to the related functions. In the `rocm/pytorch` container that `Dockerfile.rocm` builds from, it means adding the keyword `inline` to `/opt/rocm/include/hip/amd_detail/amd_hip_bf16.h:96` so that the line becomes ```cpp L96: #define __HOST_DEVICE__ __host__ __device__ inline ``` -This is far from a pretty solution though. Even though it appears that [ROCm may be fixing this](https://github.com/ROCm/clr/commit/86bd518981b364c138f9901b28a529899d8654f3), it appears to not be included in ROCm v6.0.0. Fixes like this may need to stay around until better solutions come out. +This is far from a pretty solution though. Even though it appears that [ROCm is fixing this](https://github.com/ROCm/clr/commit/86bd518981b364c138f9901b28a529899d8654f3), it doesn't seem to be included in ROCm 6.0.0. Fixes like this may need to stay around until better solutions come out.

From 52d25ced5158b8dba3246ac69d2c48a8f474b2be Mon Sep 17 00:00:00 2001 From: miloice <17350011+kliuae@users.noreply.github.com> Date: Tue, 26 Dec 2023 08:11:31 +0000 Subject: [PATCH 92/94] Remove unused conversions --- csrc/punica/bgmv/bgmv_impl.cuh | 21 ++++----------- csrc/punica/type_convert.h | 47 ---------------------------------- 2 files changed, 5 insertions(+), 63 deletions(-) diff --git a/csrc/punica/bgmv/bgmv_impl.cuh b/csrc/punica/bgmv/bgmv_impl.cuh index 423e5c1cca627..2e72394647c6a 100644 --- a/csrc/punica/bgmv/bgmv_impl.cuh +++ b/csrc/punica/bgmv/bgmv_impl.cuh @@ -18,21 +18,10 @@ namespace cg = cooperative_groups; #ifdef USE_ROCM -__host__ __device__ -inline void* memcpy_blocking(void *dst, const void *src, size_t len) { - // Does not handle the case of long datatypes - char *d = reinterpret_cast(dst); - const char *s = reinterpret_cast(src); - size_t i = 0; - for (i = 0; i < len; ++i) { - d[i] = s[i]; - } - return dst; -} template __host__ __device__ -inline void* memcpy_blocking_unroll(void *dst, const void *src) { +inline void* memcpy_blocking(void *dst, const void *src) { // Does not handle the case of long datatypes char *d = reinterpret_cast(dst); const char *s = reinterpret_cast(src); @@ -86,10 +75,10 @@ bgmv_shrink_kernel(out_T *__restrict__ Y, const in_T *__restrict__ X, cuda::aligned_size_t(X_copy_size), pipe); pipe.producer_commit(); #else - memcpy_blocking_unroll(W_shared + (threadIdx.y * tx + threadIdx.x) * vec_size, + memcpy_blocking(W_shared + (threadIdx.y * tx + threadIdx.x) * vec_size, W + (idx * feat_out + j) * feat_in + (threadIdx.y * tx + threadIdx.x) * vec_size); - memcpy_blocking_unroll(X_shared + (threadIdx.y * tx + threadIdx.x) * vec_size, + memcpy_blocking(X_shared + (threadIdx.y * tx + threadIdx.x) * vec_size, X + (batch_idx * feat_in) + (threadIdx.y * tx + threadIdx.x) * vec_size); #endif @@ -122,12 +111,12 @@ bgmv_shrink_kernel(out_T *__restrict__ Y, const in_T *__restrict__ X, pipe.producer_commit(); #else if (tile_idx * tile_size + threadIdx.y * tx * vec_size < feat_in) { - memcpy_blocking_unroll(W_shared + W_shared_offset[copy_idx] + + memcpy_blocking(W_shared + W_shared_offset[copy_idx] + (threadIdx.y * tx + threadIdx.x) * vec_size, W + (idx * feat_out + j) * feat_in + tile_idx * tile_size + (threadIdx.y * tx + threadIdx.x) * vec_size); - memcpy_blocking_unroll(X_shared + X_shared_offset[copy_idx] + + memcpy_blocking(X_shared + X_shared_offset[copy_idx] + (threadIdx.y * tx + threadIdx.x) * vec_size, X + (batch_idx * feat_in) + tile_idx * tile_size + (threadIdx.y * tx + threadIdx.x) * vec_size); diff --git a/csrc/punica/type_convert.h b/csrc/punica/type_convert.h index 13d5b411a3c0a..dff7ce49283d7 100644 --- a/csrc/punica/type_convert.h +++ b/csrc/punica/type_convert.h @@ -15,9 +15,7 @@ typedef __half nv_half; typedef __hip_bfloat16 nv_bfloat16; - typedef __hip_bfloat162 nv_bfloat162; -//typedef __hip_bfloat16 nv_bfloat16; __TYPE_CONVERT__HOST_DEVICE__ inline __hip_bfloat162 make_bfloat162(__hip_bfloat16 val) { @@ -29,51 +27,6 @@ inline __hip_bfloat162 make_bfloat162(__hip_bfloat16 vall, __hip_bfloat16 valr) return __hip_bfloat162{vall, valr}; } -__TYPE_CONVERT__HOST_DEVICE__ -inline __half cast_to_half(__half val) { - return val; -} - -__TYPE_CONVERT__HOST_DEVICE__ -inline __half cast_to_half(__hip_bfloat16 val) { - return __float2half(__bfloat162float(val)); -} - -__TYPE_CONVERT__HOST_DEVICE__ -inline __half cast_to_half(float val) { - return __float2half(val); -} - -__TYPE_CONVERT__HOST_DEVICE__ -inline __hip_bfloat16 cast_to_bf16(__half val) { - return __float2bfloat16(__half2float(val)); -} - -__TYPE_CONVERT__HOST_DEVICE__ -inline __hip_bfloat16 cast_to_bf16(__hip_bfloat16 val) { - return val; -} - -__TYPE_CONVERT__HOST_DEVICE__ -inline __hip_bfloat16 cast_to_bf16(float val) { - return __float2bfloat16(val); -} - -__TYPE_CONVERT__HOST_DEVICE__ -inline float cast_to_float(__half val) { - return __half2float(val); -} - -__TYPE_CONVERT__HOST_DEVICE__ -inline float cast_to_float(__hip_bfloat16 val) { - return __bfloat162float(val); -} - -__TYPE_CONVERT__HOST_DEVICE__ -inline float cast_to_float(float val) { - return val; -} - template __TYPE_CONVERT__HOST_DEVICE__ inline T_dst convert_type(T_src val) { From 7ddd34485bdd627f5be1b835ad4ea9cfbcfb7df0 Mon Sep 17 00:00:00 2001 From: miloice <17350011+kliuae@users.noreply.github.com> Date: Fri, 5 Jan 2024 07:28:35 +0000 Subject: [PATCH 93/94] Add patches for ROCm --- Dockerfile.rocm | 1 + patch_rocm.rocm.sh | 27 + patch_xformers-0.0.22.post7.rocm.sh | 22 - rocm_patch/rocm__amd_bf16.patch | 975 ++++++++++++++++++++++++++++ 4 files changed, 1003 insertions(+), 22 deletions(-) create mode 100644 patch_rocm.rocm.sh delete mode 100644 patch_xformers-0.0.22.post7.rocm.sh create mode 100644 rocm_patch/rocm__amd_bf16.patch diff --git a/Dockerfile.rocm b/Dockerfile.rocm index 36a7ee37fd228..23c962a3d7075 100644 --- a/Dockerfile.rocm +++ b/Dockerfile.rocm @@ -52,6 +52,7 @@ RUN pip install xformers==0.0.23 --no-deps RUN cd /app \ && cd vllm \ && pip install -U -r requirements-rocm.txt \ + && bash patch_rocm.rocm.sh \ && bash patch_xformers.rocm.sh \ && python3 setup.py install \ && cd .. diff --git a/patch_rocm.rocm.sh b/patch_rocm.rocm.sh new file mode 100644 index 0000000000000..fc7570e6086c7 --- /dev/null +++ b/patch_rocm.rocm.sh @@ -0,0 +1,27 @@ +#!/bin/bash +set -e + +if [ -z "$ROCM_PATH" ]; then + echo "Could not determine ROCm installation path by ROCM_PATH. Abort HIP patching" + exit 1 +fi + +export __HIP_FILE_TO_PATCH="$ROCM_PATH/include/hip/amd_detail/amd_hip_bf16.h" +export __HIP_PATCH_FILE="./rocm_patch/rocm__amd_bf16.patch" + +if [ ! -f "$__HIP_FILE_TO_PATCH" ]; then + echo "Could not find the file to be patched in $__HIP_FILE_TO_PATCH. Abort HIP patching" + exit 2 +fi + +echo "File to be patched: $__HIP_FILE_TO_PATCH" + +if ! patch -R -p0 -s -f --dry-run $__HIP_FILE_TO_PATCH $__HIP_PATCH_FILE; then + echo "Applying patch to ${__HIP_FILE_TO_PATCH}" + patch -p0 $__HIP_FILE_TO_PATCH $__HIP_PATCH_FILE + echo "Successfully patched ${__HIP_FILE_TO_PATCH}" +else + echo "${__HIP_FILE_TO_PATCH} has been patched before" +fi + +exit 0 \ No newline at end of file diff --git a/patch_xformers-0.0.22.post7.rocm.sh b/patch_xformers-0.0.22.post7.rocm.sh deleted file mode 100644 index 8dcb6240f0a44..0000000000000 --- a/patch_xformers-0.0.22.post7.rocm.sh +++ /dev/null @@ -1,22 +0,0 @@ -#!/bin/bash -export XFORMERS_FMHA_FLASH_PATH=$(python -c 'from xformers import ops as xops; print(xops.fmha.flash.__file__)') -export XFORMERS_FMHA_COMMON_PATH=$(python -c 'from xformers import ops as xops; print(xops.fmha.common.__file__)') - -echo $XFORMERS_FMHA_FLASH_PATH -echo $XFORMERS_FMHA_COMMON_PATH - -if ! patch -R -p0 -s -f --dry-run $XFORMERS_FMHA_FLASH_PATH "./rocm_patch/flashpy_xformers-0.0.22.post7.rocm.patch"; then - echo "Applying patch to ${XFORMERS_FMHA_FLASH_PATH}" - patch -p0 $XFORMERS_FMHA_FLASH_PATH "./rocm_patch/flashpy_xformers-0.0.22.post7.rocm.patch" - echo "Successfully patch ${XFORMERS_FMHA_FLASH_PATH}" -else - echo "${XFORMERS_FMHA_FLASH_PATH} was patched before" -fi - -if ! patch -R -p0 -s -f --dry-run $XFORMERS_FMHA_COMMON_PATH "./rocm_patch/commonpy_xformers-0.0.22.post7.rocm.patch"; then - echo "Applying patch to ${XFORMERS_FMHA_COMMON_PATH}" - patch -p0 $XFORMERS_FMHA_COMMON_PATH "./rocm_patch/commonpy_xformers-0.0.22.post7.rocm.patch" - echo "Successfully patch ${XFORMERS_FMHA_COMMON_PATH}" -else - echo "${XFORMERS_FMHA_COMMON_PATH} was patched before" -fi diff --git a/rocm_patch/rocm__amd_bf16.patch b/rocm_patch/rocm__amd_bf16.patch new file mode 100644 index 0000000000000..dbdb8002169fb --- /dev/null +++ b/rocm_patch/rocm__amd_bf16.patch @@ -0,0 +1,975 @@ +--- amd_hip_bf16_ori.h 2024-01-05 06:45:45.451392253 +0000 ++++ amd_hip_bf16.h 2024-01-05 06:44:22.164682921 +0000 +@@ -85,17 +85,31 @@ + #ifndef _HIP_INCLUDE_HIP_AMD_DETAIL_HIP_BF16_H_ + #define _HIP_INCLUDE_HIP_AMD_DETAIL_HIP_BF16_H_ + ++#if !defined(__HIPCC_RTC__) ++#include ++#endif // !defined(__HIPCC_RTC__) ++ + #include "amd_hip_vector_types.h" // float2 etc + #include "device_library_decls.h" // ocml conversion functions + #include "math_fwd.h" // ocml device functions + + #if defined(__HIPCC_RTC__) +-#define __HOST_DEVICE__ __device__ ++#define __HOST_DEVICE__ __device__ static + #else ++#include + #include +-#define __HOST_DEVICE__ __host__ __device__ ++#include ++#define __HOST_DEVICE__ __host__ __device__ static inline + #endif + ++#define HIPRT_ONE_BF16 __float2bfloat16(1.0f) ++#define HIPRT_ZERO_BF16 __float2bfloat16(0.0f) ++#define HIPRT_INF_BF16 __ushort_as_bfloat16((unsigned short)0x7F80U) ++#define HIPRT_MAX_NORMAL_BF16 __ushort_as_bfloat16((unsigned short)0x7F7FU) ++#define HIPRT_MIN_DENORM_BF16 __ushort_as_bfloat16((unsigned short)0x0001U) ++#define HIPRT_NAN_BF16 __ushort_as_bfloat16((unsigned short)0x7FFFU) ++#define HIPRT_NEG_ZERO_BF16 __ushort_as_bfloat16((unsigned short)0x8000U) ++ + // Since we are using unsigned short to represent data in bfloat16, it can be of different sizes on + // different machines. These naive checks should prevent some undefined behavior on systems which + // have different sizes for basic types. +@@ -185,7 +199,7 @@ + * \ingroup HIP_INTRINSIC_BFLOAT162_CONV + * \brief Moves bfloat16 value to bfloat162 + */ +-__device__ __hip_bfloat162 __bfloat162bfloat162(const __hip_bfloat16 a) { ++__HOST_DEVICE__ __hip_bfloat162 __bfloat162bfloat162(const __hip_bfloat16 a) { + return __hip_bfloat162{a, a}; + } + +@@ -193,13 +207,13 @@ + * \ingroup HIP_INTRINSIC_BFLOAT162_CONV + * \brief Reinterprets bits in a __hip_bfloat16 as a signed short integer + */ +-__device__ short int __bfloat16_as_short(const __hip_bfloat16 h) { return (short)h.data; } ++__HOST_DEVICE__ short int __bfloat16_as_short(const __hip_bfloat16 h) { return (short)h.data; } + + /** + * \ingroup HIP_INTRINSIC_BFLOAT162_CONV + * \brief Reinterprets bits in a __hip_bfloat16 as an unsigned signed short integer + */ +-__device__ unsigned short int __bfloat16_as_ushort(const __hip_bfloat16 h) { return h.data; } ++__HOST_DEVICE__ unsigned short int __bfloat16_as_ushort(const __hip_bfloat16 h) { return h.data; } + + /** + * \ingroup HIP_INTRINSIC_BFLOAT162_CONV +@@ -221,7 +235,7 @@ + * \ingroup HIP_INTRINSIC_BFLOAT162_CONV + * \brief Combine two __hip_bfloat16 to __hip_bfloat162 + */ +-__device__ __hip_bfloat162 __halves2bfloat162(const __hip_bfloat16 a, const __hip_bfloat16 b) { ++__HOST_DEVICE__ __hip_bfloat162 __halves2bfloat162(const __hip_bfloat16 a, const __hip_bfloat16 b) { + return __hip_bfloat162{a, b}; + } + +@@ -229,13 +243,13 @@ + * \ingroup HIP_INTRINSIC_BFLOAT162_CONV + * \brief Returns high 16 bits of __hip_bfloat162 + */ +-__device__ __hip_bfloat16 __high2bfloat16(const __hip_bfloat162 a) { return a.y; } ++__HOST_DEVICE__ __hip_bfloat16 __high2bfloat16(const __hip_bfloat162 a) { return a.y; } + + /** + * \ingroup HIP_INTRINSIC_BFLOAT162_CONV + * \brief Returns high 16 bits of __hip_bfloat162 + */ +-__device__ __hip_bfloat162 __high2bfloat162(const __hip_bfloat162 a) { ++__HOST_DEVICE__ __hip_bfloat162 __high2bfloat162(const __hip_bfloat162 a) { + return __hip_bfloat162{a.y, a.y}; + } + +@@ -249,7 +263,8 @@ + * \ingroup HIP_INTRINSIC_BFLOAT162_CONV + * \brief Extracts high 16 bits from each and combines them + */ +-__device__ __hip_bfloat162 __highs2bfloat162(const __hip_bfloat162 a, const __hip_bfloat162 b) { ++__HOST_DEVICE__ __hip_bfloat162 __highs2bfloat162(const __hip_bfloat162 a, ++ const __hip_bfloat162 b) { + return __hip_bfloat162{a.y, b.y}; + } + +@@ -257,13 +272,13 @@ + * \ingroup HIP_INTRINSIC_BFLOAT162_CONV + * \brief Returns low 16 bits of __hip_bfloat162 + */ +-__device__ __hip_bfloat16 __low2bfloat16(const __hip_bfloat162 a) { return a.x; } ++__HOST_DEVICE__ __hip_bfloat16 __low2bfloat16(const __hip_bfloat162 a) { return a.x; } + + /** + * \ingroup HIP_INTRINSIC_BFLOAT162_CONV + * \brief Returns low 16 bits of __hip_bfloat162 + */ +-__device__ __hip_bfloat162 __low2bfloat162(const __hip_bfloat162 a) { ++__HOST_DEVICE__ __hip_bfloat162 __low2bfloat162(const __hip_bfloat162 a) { + return __hip_bfloat162{a.x, a.x}; + } + +@@ -277,7 +292,7 @@ + * \ingroup HIP_INTRINSIC_BFLOAT162_CONV + * \brief Swaps both halves + */ +-__device__ __hip_bfloat162 __lowhigh2highlow(const __hip_bfloat162 a) { ++__HOST_DEVICE__ __hip_bfloat162 __lowhigh2highlow(const __hip_bfloat162 a) { + return __hip_bfloat162{a.y, a.x}; + } + +@@ -285,7 +300,7 @@ + * \ingroup HIP_INTRINSIC_BFLOAT162_CONV + * \brief Extracts low 16 bits from each and combines them + */ +-__device__ __hip_bfloat162 __lows2bfloat162(const __hip_bfloat162 a, const __hip_bfloat162 b) { ++__HOST_DEVICE__ __hip_bfloat162 __lows2bfloat162(const __hip_bfloat162 a, const __hip_bfloat162 b) { + return __hip_bfloat162{a.x, b.x}; + } + +@@ -293,7 +308,7 @@ + * \ingroup HIP_INTRINSIC_BFLOAT162_CONV + * \brief Reinterprets short int into a bfloat16 + */ +-__device__ __hip_bfloat16 __short_as_bfloat16(const short int a) { ++__HOST_DEVICE__ __hip_bfloat16 __short_as_bfloat16(const short int a) { + return __hip_bfloat16{(unsigned short)a}; + } + +@@ -301,7 +316,7 @@ + * \ingroup HIP_INTRINSIC_BFLOAT162_CONV + * \brief Reinterprets unsigned short int into a bfloat16 + */ +-__device__ __hip_bfloat16 __ushort_as_bfloat16(const unsigned short int a) { ++__HOST_DEVICE__ __hip_bfloat16 __ushort_as_bfloat16(const unsigned short int a) { + return __hip_bfloat16{a}; + } + +@@ -310,7 +325,7 @@ + * \ingroup HIP_INTRINSIC_BFLOAT16_ARITH + * \brief Adds two bfloat16 values + */ +-__device__ __hip_bfloat16 __hadd(const __hip_bfloat16 a, const __hip_bfloat16 b) { ++__HOST_DEVICE__ __hip_bfloat16 __hadd(const __hip_bfloat16 a, const __hip_bfloat16 b) { + return __float2bfloat16(__bfloat162float(a) + __bfloat162float(b)); + } + +@@ -318,7 +333,7 @@ + * \ingroup HIP_INTRINSIC_BFLOAT16_ARITH + * \brief Subtracts two bfloat16 values + */ +-__device__ __hip_bfloat16 __hsub(const __hip_bfloat16 a, const __hip_bfloat16 b) { ++__HOST_DEVICE__ __hip_bfloat16 __hsub(const __hip_bfloat16 a, const __hip_bfloat16 b) { + return __float2bfloat16(__bfloat162float(a) - __bfloat162float(b)); + } + +@@ -326,7 +341,7 @@ + * \ingroup HIP_INTRINSIC_BFLOAT16_ARITH + * \brief Divides two bfloat16 values + */ +-__device__ __hip_bfloat16 __hdiv(const __hip_bfloat16 a, const __hip_bfloat16 b) { ++__HOST_DEVICE__ __hip_bfloat16 __hdiv(const __hip_bfloat16 a, const __hip_bfloat16 b) { + return __float2bfloat16(__bfloat162float(a) / __bfloat162float(b)); + } + +@@ -344,7 +359,7 @@ + * \ingroup HIP_INTRINSIC_BFLOAT16_ARITH + * \brief Multiplies two bfloat16 values + */ +-__device__ __hip_bfloat16 __hmul(const __hip_bfloat16 a, const __hip_bfloat16 b) { ++__HOST_DEVICE__ __hip_bfloat16 __hmul(const __hip_bfloat16 a, const __hip_bfloat16 b) { + return __float2bfloat16(__bfloat162float(a) * __bfloat162float(b)); + } + +@@ -352,7 +367,7 @@ + * \ingroup HIP_INTRINSIC_BFLOAT16_ARITH + * \brief Negate a bfloat16 value + */ +-__device__ __hip_bfloat16 __hneg(const __hip_bfloat16 a) { ++__HOST_DEVICE__ __hip_bfloat16 __hneg(const __hip_bfloat16 a) { + auto ret = a; + ret.data ^= 0x8000; + return ret; +@@ -362,7 +377,7 @@ + * \ingroup HIP_INTRINSIC_BFLOAT16_ARITH + * \brief Returns absolute of a bfloat16 + */ +-__device__ __hip_bfloat16 __habs(const __hip_bfloat16 a) { ++__HOST_DEVICE__ __hip_bfloat16 __habs(const __hip_bfloat16 a) { + auto ret = a; + ret.data &= 0x7FFF; + return ret; +@@ -372,7 +387,7 @@ + * \ingroup HIP_INTRINSIC_BFLOAT162_ARITH + * \brief Divides bfloat162 values + */ +-__device__ __hip_bfloat162 __h2div(const __hip_bfloat162 a, const __hip_bfloat162 b) { ++__HOST_DEVICE__ __hip_bfloat162 __h2div(const __hip_bfloat162 a, const __hip_bfloat162 b) { + return __hip_bfloat162{__float2bfloat16(__bfloat162float(a.x) / __bfloat162float(b.x)), + __float2bfloat16(__bfloat162float(a.y) / __bfloat162float(b.y))}; + } +@@ -381,7 +396,7 @@ + * \ingroup HIP_INTRINSIC_BFLOAT162_ARITH + * \brief Returns absolute of a bfloat162 + */ +-__device__ __hip_bfloat162 __habs2(const __hip_bfloat162 a) { ++__HOST_DEVICE__ __hip_bfloat162 __habs2(const __hip_bfloat162 a) { + return __hip_bfloat162{__habs(a.x), __habs(a.y)}; + } + +@@ -389,7 +404,7 @@ + * \ingroup HIP_INTRINSIC_BFLOAT162_ARITH + * \brief Adds two bfloat162 values + */ +-__device__ __hip_bfloat162 __hadd2(const __hip_bfloat162 a, const __hip_bfloat162 b) { ++__HOST_DEVICE__ __hip_bfloat162 __hadd2(const __hip_bfloat162 a, const __hip_bfloat162 b) { + return __hip_bfloat162{__hadd(a.x, b.x), __hadd(a.y, b.y)}; + } + +@@ -406,7 +421,7 @@ + * \ingroup HIP_INTRINSIC_BFLOAT162_ARITH + * \brief Multiplies two bfloat162 values + */ +-__device__ __hip_bfloat162 __hmul2(const __hip_bfloat162 a, const __hip_bfloat162 b) { ++__HOST_DEVICE__ __hip_bfloat162 __hmul2(const __hip_bfloat162 a, const __hip_bfloat162 b) { + return __hip_bfloat162{__hmul(a.x, b.x), __hmul(a.y, b.y)}; + } + +@@ -414,7 +429,7 @@ + * \ingroup HIP_INTRINSIC_BFLOAT162_ARITH + * \brief Converts a bfloat162 into negative + */ +-__device__ __hip_bfloat162 __hneg2(const __hip_bfloat162 a) { ++__HOST_DEVICE__ __hip_bfloat162 __hneg2(const __hip_bfloat162 a) { + return __hip_bfloat162{__hneg(a.x), __hneg(a.y)}; + } + +@@ -422,15 +437,251 @@ + * \ingroup HIP_INTRINSIC_BFLOAT162_ARITH + * \brief Subtracts two bfloat162 values + */ +-__device__ __hip_bfloat162 __hsub2(const __hip_bfloat162 a, const __hip_bfloat162 b) { ++__HOST_DEVICE__ __hip_bfloat162 __hsub2(const __hip_bfloat162 a, const __hip_bfloat162 b) { + return __hip_bfloat162{__hsub(a.x, b.x), __hsub(a.y, b.y)}; + } + + /** ++ * \ingroup HIP_INTRINSIC_BFLOAT16_ARITH ++ * \brief Operator to multiply two __hip_bfloat16 numbers ++ */ ++__HOST_DEVICE__ __hip_bfloat16 operator*(const __hip_bfloat16& l, const __hip_bfloat16& r) { ++ return __hmul(l, r); ++} ++ ++/** ++ * \ingroup HIP_INTRINSIC_BFLOAT16_ARITH ++ * \brief Operator to multiply-assign two __hip_bfloat16 numbers ++ */ ++__HOST_DEVICE__ __hip_bfloat16& operator*=(__hip_bfloat16& l, const __hip_bfloat16& r) { ++ l = __hmul(l, r); ++ return l; ++} ++ ++/** ++ * \ingroup HIP_INTRINSIC_BFLOAT16_ARITH ++ * \brief Operator to unary+ on a __hip_bfloat16 number ++ */ ++__HOST_DEVICE__ __hip_bfloat16 operator+(const __hip_bfloat16& l) { return l; } ++ ++/** ++ * \ingroup HIP_INTRINSIC_BFLOAT16_ARITH ++ * \brief Operator to add two __hip_bfloat16 numbers ++ */ ++__HOST_DEVICE__ __hip_bfloat16 operator+(const __hip_bfloat16& l, const __hip_bfloat16& r) { ++ return __hadd(l, r); ++} ++ ++/** ++ * \ingroup HIP_INTRINSIC_BFLOAT16_ARITH ++ * \brief Operator to negate a __hip_bfloat16 number ++ */ ++__HOST_DEVICE__ __hip_bfloat16 operator-(const __hip_bfloat16& l) { return __hneg(l); } ++ ++/** ++ * \ingroup HIP_INTRINSIC_BFLOAT16_ARITH ++ * \brief Operator to subtract two __hip_bfloat16 numbers ++ */ ++__HOST_DEVICE__ __hip_bfloat16 operator-(const __hip_bfloat16& l, const __hip_bfloat16& r) { ++ return __hsub(l, r); ++} ++ ++/** ++ * \ingroup HIP_INTRINSIC_BFLOAT16_ARITH ++ * \brief Operator to post increment a __hip_bfloat16 number ++ */ ++__HOST_DEVICE__ __hip_bfloat16 operator++(__hip_bfloat16& l, const int) { ++ auto ret = l; ++ l = __hadd(l, HIPRT_ONE_BF16); ++ return ret; ++} ++ ++/** ++ * \ingroup HIP_INTRINSIC_BFLOAT16_ARITH ++ * \brief Operator to pre increment a __hip_bfloat16 number ++ */ ++__HOST_DEVICE__ __hip_bfloat16& operator++(__hip_bfloat16& l) { ++ l = __hadd(l, HIPRT_ONE_BF16); ++ return l; ++} ++ ++/** ++ * \ingroup HIP_INTRINSIC_BFLOAT16_ARITH ++ * \brief Operator to post decrement a __hip_bfloat16 number ++ */ ++__HOST_DEVICE__ __hip_bfloat16 operator--(__hip_bfloat16& l, const int) { ++ auto ret = l; ++ l = __hsub(l, HIPRT_ONE_BF16); ++ return ret; ++} ++ ++/** ++ * \ingroup HIP_INTRINSIC_BFLOAT16_ARITH ++ * \brief Operator to pre decrement a __hip_bfloat16 number ++ */ ++__HOST_DEVICE__ __hip_bfloat16& operator--(__hip_bfloat16& l) { ++ l = __hsub(l, HIPRT_ONE_BF16); ++ return l; ++} ++ ++/** ++ * \ingroup HIP_INTRINSIC_BFLOAT16_ARITH ++ * \brief Operator to add-assign two __hip_bfloat16 numbers ++ */ ++__HOST_DEVICE__ __hip_bfloat16& operator+=(__hip_bfloat16& l, const __hip_bfloat16& r) { ++ l = __hadd(l, r); ++ return l; ++} ++ ++/** ++ * \ingroup HIP_INTRINSIC_BFLOAT16_ARITH ++ * \brief Operator to subtract-assign two __hip_bfloat16 numbers ++ */ ++__HOST_DEVICE__ __hip_bfloat16& operator-=(__hip_bfloat16& l, const __hip_bfloat16& r) { ++ l = __hsub(l, r); ++ return l; ++} ++ ++/** ++ * \ingroup HIP_INTRINSIC_BFLOAT16_ARITH ++ * \brief Operator to divide two __hip_bfloat16 numbers ++ */ ++__HOST_DEVICE__ __hip_bfloat16 operator/(const __hip_bfloat16& l, const __hip_bfloat16& r) { ++ return __hdiv(l, r); ++} ++ ++/** ++ * \ingroup HIP_INTRINSIC_BFLOAT16_ARITH ++ * \brief Operator to divide-assign two __hip_bfloat16 numbers ++ */ ++__HOST_DEVICE__ __hip_bfloat16& operator/=(__hip_bfloat16& l, const __hip_bfloat16& r) { ++ l = __hdiv(l, r); ++ return l; ++} ++ ++/** ++ * \ingroup HIP_INTRINSIC_BFLOAT162_ARITH ++ * \brief Operator to multiply two __hip_bfloat162 numbers ++ */ ++__HOST_DEVICE__ __hip_bfloat162 operator*(const __hip_bfloat162& l, const __hip_bfloat162& r) { ++ return __hmul2(l, r); ++} ++ ++/** ++ * \ingroup HIP_INTRINSIC_BFLOAT162_ARITH ++ * \brief Operator to multiply-assign two __hip_bfloat162 numbers ++ */ ++__HOST_DEVICE__ __hip_bfloat162& operator*=(__hip_bfloat162& l, const __hip_bfloat162& r) { ++ l = __hmul2(l, r); ++ return l; ++} ++ ++/** ++ * \ingroup HIP_INTRINSIC_BFLOAT162_ARITH ++ * \brief Operator to unary+ on a __hip_bfloat162 number ++ */ ++__HOST_DEVICE__ __hip_bfloat162 operator+(const __hip_bfloat162& l) { return l; } ++ ++/** ++ * \ingroup HIP_INTRINSIC_BFLOAT162_ARITH ++ * \brief Operator to add two __hip_bfloat162 numbers ++ */ ++__HOST_DEVICE__ __hip_bfloat162 operator+(const __hip_bfloat162& l, const __hip_bfloat162& r) { ++ return __hadd2(l, r); ++} ++ ++/** ++ * \ingroup HIP_INTRINSIC_BFLOAT162_ARITH ++ * \brief Operator to negate a __hip_bfloat162 number ++ */ ++__HOST_DEVICE__ __hip_bfloat162 operator-(const __hip_bfloat162& l) { return __hneg2(l); } ++ ++/** ++ * \ingroup HIP_INTRINSIC_BFLOAT162_ARITH ++ * \brief Operator to subtract two __hip_bfloat162 numbers ++ */ ++__HOST_DEVICE__ __hip_bfloat162 operator-(const __hip_bfloat162& l, const __hip_bfloat162& r) { ++ return __hsub2(l, r); ++} ++ ++/** ++ * \ingroup HIP_INTRINSIC_BFLOAT162_ARITH ++ * \brief Operator to post increment a __hip_bfloat162 number ++ */ ++__HOST_DEVICE__ __hip_bfloat162 operator++(__hip_bfloat162& l, const int) { ++ auto ret = l; ++ l = __hadd2(l, {HIPRT_ONE_BF16, HIPRT_ONE_BF16}); ++ return ret; ++} ++ ++/** ++ * \ingroup HIP_INTRINSIC_BFLOAT162_ARITH ++ * \brief Operator to pre increment a __hip_bfloat162 number ++ */ ++__HOST_DEVICE__ __hip_bfloat162& operator++(__hip_bfloat162& l) { ++ l = __hadd2(l, {HIPRT_ONE_BF16, HIPRT_ONE_BF16}); ++ return l; ++} ++ ++/** ++ * \ingroup HIP_INTRINSIC_BFLOAT162_ARITH ++ * \brief Operator to post decrement a __hip_bfloat162 number ++ */ ++__HOST_DEVICE__ __hip_bfloat162 operator--(__hip_bfloat162& l, const int) { ++ auto ret = l; ++ l = __hsub2(l, {HIPRT_ONE_BF16, HIPRT_ONE_BF16}); ++ return ret; ++} ++ ++/** ++ * \ingroup HIP_INTRINSIC_BFLOAT162_ARITH ++ * \brief Operator to pre decrement a __hip_bfloat162 number ++ */ ++__HOST_DEVICE__ __hip_bfloat162& operator--(__hip_bfloat162& l) { ++ l = __hsub2(l, {HIPRT_ONE_BF16, HIPRT_ONE_BF16}); ++ return l; ++} ++ ++/** ++ * \ingroup HIP_INTRINSIC_BFLOAT162_ARITH ++ * \brief Operator to add-assign two __hip_bfloat162 numbers ++ */ ++__HOST_DEVICE__ __hip_bfloat162& operator+=(__hip_bfloat162& l, const __hip_bfloat162& r) { ++ l = __hadd2(l, r); ++ return l; ++} ++ ++/** ++ * \ingroup HIP_INTRINSIC_BFLOAT162_ARITH ++ * \brief Operator to subtract-assign two __hip_bfloat162 numbers ++ */ ++__HOST_DEVICE__ __hip_bfloat162& operator-=(__hip_bfloat162& l, const __hip_bfloat162& r) { ++ l = __hsub2(l, r); ++ return l; ++} ++ ++/** ++ * \ingroup HIP_INTRINSIC_BFLOAT162_ARITH ++ * \brief Operator to divide two __hip_bfloat162 numbers ++ */ ++__HOST_DEVICE__ __hip_bfloat162 operator/(const __hip_bfloat162& l, const __hip_bfloat162& r) { ++ return __h2div(l, r); ++} ++ ++/** ++ * \ingroup HIP_INTRINSIC_BFLOAT162_ARITH ++ * \brief Operator to divide-assign two __hip_bfloat162 numbers ++ */ ++__HOST_DEVICE__ __hip_bfloat162& operator/=(__hip_bfloat162& l, const __hip_bfloat162& r) { ++ l = __h2div(l, r); ++ return l; ++} ++ ++/** + * \ingroup HIP_INTRINSIC_BFLOAT16_COMP + * \brief Compare two bfloat162 values + */ +-__device__ bool __heq(const __hip_bfloat16 a, const __hip_bfloat16 b) { ++__HOST_DEVICE__ bool __heq(const __hip_bfloat16 a, const __hip_bfloat16 b) { + return __bfloat162float(a) == __bfloat162float(b); + } + +@@ -438,7 +689,7 @@ + * \ingroup HIP_INTRINSIC_BFLOAT16_COMP + * \brief Compare two bfloat162 values - unordered equal + */ +-__device__ bool __hequ(const __hip_bfloat16 a, const __hip_bfloat16 b) { ++__HOST_DEVICE__ bool __hequ(const __hip_bfloat16 a, const __hip_bfloat16 b) { + return !(__bfloat162float(a) < __bfloat162float(b)) && + !(__bfloat162float(a) > __bfloat162float(b)); + } +@@ -447,7 +698,7 @@ + * \ingroup HIP_INTRINSIC_BFLOAT16_COMP + * \brief Compare two bfloat162 values - greater than + */ +-__device__ bool __hgt(const __hip_bfloat16 a, const __hip_bfloat16 b) { ++__HOST_DEVICE__ bool __hgt(const __hip_bfloat16 a, const __hip_bfloat16 b) { + return __bfloat162float(a) > __bfloat162float(b); + } + +@@ -455,7 +706,7 @@ + * \ingroup HIP_INTRINSIC_BFLOAT16_COMP + * \brief Compare two bfloat162 values - unordered greater than + */ +-__device__ bool __hgtu(const __hip_bfloat16 a, const __hip_bfloat16 b) { ++__HOST_DEVICE__ bool __hgtu(const __hip_bfloat16 a, const __hip_bfloat16 b) { + return !(__bfloat162float(a) <= __bfloat162float(b)); + } + +@@ -463,7 +714,7 @@ + * \ingroup HIP_INTRINSIC_BFLOAT16_COMP + * \brief Compare two bfloat162 values - greater than equal + */ +-__device__ bool __hge(const __hip_bfloat16 a, const __hip_bfloat16 b) { ++__HOST_DEVICE__ bool __hge(const __hip_bfloat16 a, const __hip_bfloat16 b) { + return __bfloat162float(a) >= __bfloat162float(b); + } + +@@ -471,7 +722,7 @@ + * \ingroup HIP_INTRINSIC_BFLOAT16_COMP + * \brief Compare two bfloat162 values - unordered greater than equal + */ +-__device__ bool __hgeu(const __hip_bfloat16 a, const __hip_bfloat16 b) { ++__HOST_DEVICE__ bool __hgeu(const __hip_bfloat16 a, const __hip_bfloat16 b) { + return !(__bfloat162float(a) < __bfloat162float(b)); + } + +@@ -479,7 +730,7 @@ + * \ingroup HIP_INTRINSIC_BFLOAT16_COMP + * \brief Compare two bfloat162 values - not equal + */ +-__device__ bool __hne(const __hip_bfloat16 a, const __hip_bfloat16 b) { ++__HOST_DEVICE__ bool __hne(const __hip_bfloat16 a, const __hip_bfloat16 b) { + return __bfloat162float(a) != __bfloat162float(b); + } + +@@ -487,7 +738,7 @@ + * \ingroup HIP_INTRINSIC_BFLOAT16_COMP + * \brief Compare two bfloat162 values - unordered not equal + */ +-__device__ bool __hneu(const __hip_bfloat16 a, const __hip_bfloat16 b) { ++__HOST_DEVICE__ bool __hneu(const __hip_bfloat16 a, const __hip_bfloat16 b) { + return !(__bfloat162float(a) == __bfloat162float(b)); + } + +@@ -495,23 +746,31 @@ + * \ingroup HIP_INTRINSIC_BFLOAT16_COMP + * \brief Compare two bfloat162 values - return max + */ +-__device__ __hip_bfloat16 __hmax(const __hip_bfloat16 a, const __hip_bfloat16 b) { ++__HOST_DEVICE__ __hip_bfloat16 __hmax(const __hip_bfloat16 a, const __hip_bfloat16 b) { ++#if __HIP_DEVICE_COMPILE__ + return __float2bfloat16(__ocml_fmax_f32(__bfloat162float(a), __bfloat162float(b))); ++#else ++ return __float2bfloat16(std::max(__bfloat162float(a), __bfloat162float(b))); ++#endif + } + + /** + * \ingroup HIP_INTRINSIC_BFLOAT16_COMP + * \brief Compare two bfloat162 values - return min + */ +-__device__ __hip_bfloat16 __hmin(const __hip_bfloat16 a, const __hip_bfloat16 b) { ++__HOST_DEVICE__ __hip_bfloat16 __hmin(const __hip_bfloat16 a, const __hip_bfloat16 b) { ++#if __HIP_DEVICE_COMPILE__ + return __float2bfloat16(__ocml_fmin_f32(__bfloat162float(a), __bfloat162float(b))); ++#else ++ return __float2bfloat16(std::min(__bfloat162float(a), __bfloat162float(b))); ++#endif + } + + /** + * \ingroup HIP_INTRINSIC_BFLOAT16_COMP + * \brief Compare two bfloat162 values - less than operator + */ +-__device__ bool __hlt(const __hip_bfloat16 a, const __hip_bfloat16 b) { ++__HOST_DEVICE__ bool __hlt(const __hip_bfloat16 a, const __hip_bfloat16 b) { + return __bfloat162float(a) < __bfloat162float(b); + } + +@@ -519,15 +778,15 @@ + * \ingroup HIP_INTRINSIC_BFLOAT16_COMP + * \brief Compare two bfloat162 values - unordered less than + */ +-__device__ bool __hltu(const __hip_bfloat16 a, const __hip_bfloat16 b) { ++__HOST_DEVICE__ bool __hltu(const __hip_bfloat16 a, const __hip_bfloat16 b) { + return !(__bfloat162float(a) >= __bfloat162float(b)); + } + + /** + * \ingroup HIP_INTRINSIC_BFLOAT16_COMP +- * \brief Compare two bfloat162 values - less than ++ * \brief Compare two bfloat162 values - less than equal + */ +-__device__ bool __hle(const __hip_bfloat16 a, const __hip_bfloat16 b) { ++__HOST_DEVICE__ bool __hle(const __hip_bfloat16 a, const __hip_bfloat16 b) { + return __bfloat162float(a) <= __bfloat162float(b); + } + +@@ -535,7 +794,7 @@ + * \ingroup HIP_INTRINSIC_BFLOAT16_COMP + * \brief Compare two bfloat162 values - unordered less than equal + */ +-__device__ bool __hleu(const __hip_bfloat16 a, const __hip_bfloat16 b) { ++__HOST_DEVICE__ bool __hleu(const __hip_bfloat16 a, const __hip_bfloat16 b) { + return !(__bfloat162float(a) > __bfloat162float(b)); + } + +@@ -543,19 +802,33 @@ + * \ingroup HIP_INTRINSIC_BFLOAT16_COMP + * \brief Checks if number is inf + */ +-__device__ int __hisinf(const __hip_bfloat16 a) { return __ocml_isinf_f32(__bfloat162float(a)); } ++__HOST_DEVICE__ int __hisinf(const __hip_bfloat16 a) { ++ unsigned short sign = a.data & 0x8000U; ++#if __HIP_DEVICE_COMPILE__ ++ int res = __ocml_isinf_f32(__bfloat162float(a)); ++#else ++ int res = std::isinf(__bfloat162float(a)) ? 1 : 0; ++#endif ++ return (res == 0) ? res : ((sign != 0U) ? -res : res); ++} + + /** + * \ingroup HIP_INTRINSIC_BFLOAT16_COMP + * \brief Checks if number is nan + */ +-__device__ bool __hisnan(const __hip_bfloat16 a) { return __ocml_isnan_f32(__bfloat162float(a)); } ++__HOST_DEVICE__ bool __hisnan(const __hip_bfloat16 a) { ++#if __HIP_DEVICE_COMPILE__ ++ return __ocml_isnan_f32(__bfloat162float(a)); ++#else ++ return std::isnan(__bfloat162float(a)); ++#endif ++} + + /** + * \ingroup HIP_INTRINSIC_BFLOAT162_COMP + * \brief Checks if two numbers are equal + */ +-__device__ bool __hbeq2(const __hip_bfloat162 a, const __hip_bfloat162 b) { ++__HOST_DEVICE__ bool __hbeq2(const __hip_bfloat162 a, const __hip_bfloat162 b) { + return __heq(a.x, b.x) && __heq(a.y, b.y); + } + +@@ -563,7 +836,7 @@ + * \ingroup HIP_INTRINSIC_BFLOAT162_COMP + * \brief Checks if two numbers are equal - unordered + */ +-__device__ bool __hbequ2(const __hip_bfloat162 a, const __hip_bfloat162 b) { ++__HOST_DEVICE__ bool __hbequ2(const __hip_bfloat162 a, const __hip_bfloat162 b) { + return __hequ(a.x, b.x) && __hequ(a.y, b.y); + } + +@@ -571,7 +844,7 @@ + * \ingroup HIP_INTRINSIC_BFLOAT162_COMP + * \brief Check for a >= b + */ +-__device__ bool __hbge2(const __hip_bfloat162 a, const __hip_bfloat162 b) { ++__HOST_DEVICE__ bool __hbge2(const __hip_bfloat162 a, const __hip_bfloat162 b) { + return __hge(a.x, b.x) && __hge(a.y, b.y); + } + +@@ -579,7 +852,7 @@ + * \ingroup HIP_INTRINSIC_BFLOAT162_COMP + * \brief Check for a >= b - unordered + */ +-__device__ bool __hbgeu2(const __hip_bfloat162 a, const __hip_bfloat162 b) { ++__HOST_DEVICE__ bool __hbgeu2(const __hip_bfloat162 a, const __hip_bfloat162 b) { + return __hgeu(a.x, b.x) && __hgeu(a.y, b.y); + } + +@@ -587,7 +860,7 @@ + * \ingroup HIP_INTRINSIC_BFLOAT162_COMP + * \brief Check for a > b + */ +-__device__ bool __hbgt2(const __hip_bfloat162 a, const __hip_bfloat162 b) { ++__HOST_DEVICE__ bool __hbgt2(const __hip_bfloat162 a, const __hip_bfloat162 b) { + return __hgt(a.x, b.x) && __hgt(a.y, b.y); + } + +@@ -595,7 +868,7 @@ + * \ingroup HIP_INTRINSIC_BFLOAT162_COMP + * \brief Check for a > b - unordered + */ +-__device__ bool __hbgtu2(const __hip_bfloat162 a, const __hip_bfloat162 b) { ++__HOST_DEVICE__ bool __hbgtu2(const __hip_bfloat162 a, const __hip_bfloat162 b) { + return __hgtu(a.x, b.x) && __hgtu(a.y, b.y); + } + +@@ -603,7 +876,7 @@ + * \ingroup HIP_INTRINSIC_BFLOAT162_COMP + * \brief Check for a <= b + */ +-__device__ bool __hble2(const __hip_bfloat162 a, const __hip_bfloat162 b) { ++__HOST_DEVICE__ bool __hble2(const __hip_bfloat162 a, const __hip_bfloat162 b) { + return __hle(a.x, b.x) && __hle(a.y, b.y); + } + +@@ -611,7 +884,7 @@ + * \ingroup HIP_INTRINSIC_BFLOAT162_COMP + * \brief Check for a <= b - unordered + */ +-__device__ bool __hbleu2(const __hip_bfloat162 a, const __hip_bfloat162 b) { ++__HOST_DEVICE__ bool __hbleu2(const __hip_bfloat162 a, const __hip_bfloat162 b) { + return __hleu(a.x, b.x) && __hleu(a.y, b.y); + } + +@@ -619,7 +892,7 @@ + * \ingroup HIP_INTRINSIC_BFLOAT162_COMP + * \brief Check for a < b + */ +-__device__ bool __hblt2(const __hip_bfloat162 a, const __hip_bfloat162 b) { ++__HOST_DEVICE__ bool __hblt2(const __hip_bfloat162 a, const __hip_bfloat162 b) { + return __hlt(a.x, b.x) && __hlt(a.y, b.y); + } + +@@ -627,7 +900,7 @@ + * \ingroup HIP_INTRINSIC_BFLOAT162_COMP + * \brief Check for a < b - unordered + */ +-__device__ bool __hbltu2(const __hip_bfloat162 a, const __hip_bfloat162 b) { ++__HOST_DEVICE__ bool __hbltu2(const __hip_bfloat162 a, const __hip_bfloat162 b) { + return __hltu(a.x, b.x) && __hltu(a.y, b.y); + } + +@@ -635,7 +908,7 @@ + * \ingroup HIP_INTRINSIC_BFLOAT162_COMP + * \brief Check for a != b + */ +-__device__ bool __hbne2(const __hip_bfloat162 a, const __hip_bfloat162 b) { ++__HOST_DEVICE__ bool __hbne2(const __hip_bfloat162 a, const __hip_bfloat162 b) { + return __hne(a.x, b.x) && __hne(a.y, b.y); + } + +@@ -643,7 +916,7 @@ + * \ingroup HIP_INTRINSIC_BFLOAT162_COMP + * \brief Check for a != b + */ +-__device__ bool __hbneu2(const __hip_bfloat162 a, const __hip_bfloat162 b) { ++__HOST_DEVICE__ bool __hbneu2(const __hip_bfloat162 a, const __hip_bfloat162 b) { + return __hneu(a.x, b.x) && __hneu(a.y, b.y); + } + +@@ -651,84 +924,175 @@ + * \ingroup HIP_INTRINSIC_BFLOAT162_COMP + * \brief Check for a != b, returns 1.0 if equal, otherwise 0.0 + */ +-__device__ __hip_bfloat162 __heq2(const __hip_bfloat162 a, const __hip_bfloat162 b) { +- return __hip_bfloat162{{__heq(a.x, b.x) ? __float2bfloat16(1.0f) : __float2bfloat16(0.0f)}, +- {__heq(a.y, b.y) ? __float2bfloat16(1.0f) : __float2bfloat16(0.0f)}}; ++__HOST_DEVICE__ __hip_bfloat162 __heq2(const __hip_bfloat162 a, const __hip_bfloat162 b) { ++ return __hip_bfloat162{{__heq(a.x, b.x) ? HIPRT_ONE_BF16 : HIPRT_ZERO_BF16}, ++ {__heq(a.y, b.y) ? HIPRT_ONE_BF16 : HIPRT_ZERO_BF16}}; + } + + /** + * \ingroup HIP_INTRINSIC_BFLOAT162_COMP + * \brief Check for a >= b, returns 1.0 if greater than equal, otherwise 0.0 + */ +-__device__ __hip_bfloat162 __hge2(const __hip_bfloat162 a, const __hip_bfloat162 b) { +- return __hip_bfloat162{{__hge(a.x, b.x) ? __float2bfloat16(1.0f) : __float2bfloat16(0.0f)}, +- {__hge(a.y, b.y) ? __float2bfloat16(1.0f) : __float2bfloat16(0.0f)}}; ++__HOST_DEVICE__ __hip_bfloat162 __hge2(const __hip_bfloat162 a, const __hip_bfloat162 b) { ++ return __hip_bfloat162{{__hge(a.x, b.x) ? HIPRT_ONE_BF16 : HIPRT_ZERO_BF16}, ++ {__hge(a.y, b.y) ? HIPRT_ONE_BF16 : HIPRT_ZERO_BF16}}; + } + + /** + * \ingroup HIP_INTRINSIC_BFLOAT162_COMP + * \brief Check for a > b, returns 1.0 if greater than equal, otherwise 0.0 + */ +-__device__ __hip_bfloat162 __hgt2(const __hip_bfloat162 a, const __hip_bfloat162 b) { +- return __hip_bfloat162{{__hgt(a.x, b.x) ? __float2bfloat16(1.0f) : __float2bfloat16(0.0f)}, +- {__hgt(a.y, b.y) ? __float2bfloat16(1.0f) : __float2bfloat16(0.0f)}}; ++__HOST_DEVICE__ __hip_bfloat162 __hgt2(const __hip_bfloat162 a, const __hip_bfloat162 b) { ++ return __hip_bfloat162{{__hgt(a.x, b.x) ? HIPRT_ONE_BF16 : HIPRT_ZERO_BF16}, ++ {__hgt(a.y, b.y) ? HIPRT_ONE_BF16 : HIPRT_ONE_BF16}}; + } + + /** + * \ingroup HIP_INTRINSIC_BFLOAT162_COMP + * \brief Check for a is NaN, returns 1.0 if NaN, otherwise 0.0 + */ +-__device__ __hip_bfloat162 __hisnan2(const __hip_bfloat162 a) { +- return __hip_bfloat162{ +- {__ocml_isnan_f32(__bfloat162float(a.x)) ? __float2bfloat16(1.0f) : __float2bfloat16(0.0f)}, +- {__ocml_isnan_f32(__bfloat162float(a.y)) ? __float2bfloat16(1.0f) : __float2bfloat16(0.0f)}}; ++__HOST_DEVICE__ __hip_bfloat162 __hisnan2(const __hip_bfloat162 a) { ++ return __hip_bfloat162{{__hisnan(a.x) ? HIPRT_ONE_BF16 : HIPRT_ZERO_BF16}, ++ {__hisnan(a.y) ? HIPRT_ONE_BF16 : HIPRT_ONE_BF16}}; + } + + /** + * \ingroup HIP_INTRINSIC_BFLOAT162_COMP + * \brief Check for a <= b, returns 1.0 if greater than equal, otherwise 0.0 + */ +-__device__ __hip_bfloat162 __hle2(const __hip_bfloat162 a, const __hip_bfloat162 b) { +- return __hip_bfloat162{{__hle(a.x, b.x) ? __float2bfloat16(1.0f) : __float2bfloat16(0.0f)}, +- {__hle(a.y, b.y) ? __float2bfloat16(1.0f) : __float2bfloat16(0.0f)}}; ++__HOST_DEVICE__ __hip_bfloat162 __hle2(const __hip_bfloat162 a, const __hip_bfloat162 b) { ++ return __hip_bfloat162{{__hle(a.x, b.x) ? HIPRT_ONE_BF16 : HIPRT_ZERO_BF16}, ++ {__hle(a.y, b.y) ? HIPRT_ONE_BF16 : HIPRT_ZERO_BF16}}; + } + + /** + * \ingroup HIP_INTRINSIC_BFLOAT162_COMP + * \brief Check for a < b, returns 1.0 if greater than equal, otherwise 0.0 + */ +-__device__ __hip_bfloat162 __hlt2(const __hip_bfloat162 a, const __hip_bfloat162 b) { +- return __hip_bfloat162{{__hlt(a.x, b.x) ? __float2bfloat16(1.0f) : __float2bfloat16(0.0f)}, +- {__hlt(a.y, b.y) ? __float2bfloat16(1.0f) : __float2bfloat16(0.0f)}}; ++__HOST_DEVICE__ __hip_bfloat162 __hlt2(const __hip_bfloat162 a, const __hip_bfloat162 b) { ++ return __hip_bfloat162{{__hlt(a.x, b.x) ? HIPRT_ONE_BF16 : HIPRT_ZERO_BF16}, ++ {__hlt(a.y, b.y) ? HIPRT_ONE_BF16 : HIPRT_ZERO_BF16}}; + } + + /** + * \ingroup HIP_INTRINSIC_BFLOAT162_COMP + * \brief Returns max of two elements + */ +-__device__ __hip_bfloat162 __hmax2(const __hip_bfloat162 a, const __hip_bfloat162 b) { +- return __hip_bfloat162{ +- __float2bfloat16(__ocml_fmax_f32(__bfloat162float(a.x), __bfloat162float(b.x))), +- __float2bfloat16(__ocml_fmax_f32(__bfloat162float(a.y), __bfloat162float(b.y)))}; ++__HOST_DEVICE__ __hip_bfloat162 __hmax2(const __hip_bfloat162 a, const __hip_bfloat162 b) { ++ return __hip_bfloat162{__hmax(a.x, b.x), __hmax(a.y, b.y)}; + } + + /** + * \ingroup HIP_INTRINSIC_BFLOAT162_COMP + * \brief Returns min of two elements + */ +-__device__ __hip_bfloat162 __hmin2(const __hip_bfloat162 a, const __hip_bfloat162 b) { +- return __hip_bfloat162{ +- __float2bfloat16(__ocml_fmin_f32(__bfloat162float(a.x), __bfloat162float(b.x))), +- __float2bfloat16(__ocml_fmin_f32(__bfloat162float(a.y), __bfloat162float(b.y)))}; ++__HOST_DEVICE__ __hip_bfloat162 __hmin2(const __hip_bfloat162 a, const __hip_bfloat162 b) { ++ return __hip_bfloat162{__hmin(a.x, b.x), __hmin(a.y, b.y)}; + } + + /** + * \ingroup HIP_INTRINSIC_BFLOAT162_COMP + * \brief Checks for not equal to + */ +-__device__ __hip_bfloat162 __hne2(const __hip_bfloat162 a, const __hip_bfloat162 b) { +- return __hip_bfloat162{{__hne(a.x, b.x) ? __float2bfloat16(1.0f) : __float2bfloat16(0.0f)}, +- {__hne(a.y, b.y) ? __float2bfloat16(1.0f) : __float2bfloat16(0.0f)}}; ++__HOST_DEVICE__ __hip_bfloat162 __hne2(const __hip_bfloat162 a, const __hip_bfloat162 b) { ++ return __hip_bfloat162{{__hne(a.x, b.x) ? HIPRT_ONE_BF16 : HIPRT_ZERO_BF16}, ++ {__hne(a.y, b.y) ? HIPRT_ONE_BF16 : HIPRT_ZERO_BF16}}; ++} ++ ++/** ++ * \ingroup HIP_INTRINSIC_BFLOAT16_COMP ++ * \brief Operator to perform an equal compare on two __hip_bfloat16 numbers ++ */ ++__HOST_DEVICE__ bool operator==(const __hip_bfloat16& l, const __hip_bfloat16& r) { ++ return __heq(l, r); ++} ++ ++/** ++ * \ingroup HIP_INTRINSIC_BFLOAT16_COMP ++ * \brief Operator to perform a not equal on two __hip_bfloat16 numbers ++ */ ++__HOST_DEVICE__ bool operator!=(const __hip_bfloat16& l, const __hip_bfloat16& r) { ++ return __hne(l, r); ++} ++ ++/** ++ * \ingroup HIP_INTRINSIC_BFLOAT16_COMP ++ * \brief Operator to perform a less than on two __hip_bfloat16 numbers ++ */ ++__HOST_DEVICE__ bool operator<(const __hip_bfloat16& l, const __hip_bfloat16& r) { ++ return __hlt(l, r); ++} ++ ++/** ++ * \ingroup HIP_INTRINSIC_BFLOAT16_COMP ++ * \brief Operator to perform a less than equal on two __hip_bfloat16 numbers ++ */ ++__HOST_DEVICE__ bool operator<=(const __hip_bfloat16& l, const __hip_bfloat16& r) { ++ return __hle(l, r); ++} ++ ++/** ++ * \ingroup HIP_INTRINSIC_BFLOAT16_COMP ++ * \brief Operator to perform a greater than on two __hip_bfloat16 numbers ++ */ ++__HOST_DEVICE__ bool operator>(const __hip_bfloat16& l, const __hip_bfloat16& r) { ++ return __hgt(l, r); ++} ++ ++/** ++ * \ingroup HIP_INTRINSIC_BFLOAT16_COMP ++ * \brief Operator to perform a greater than equal on two __hip_bfloat16 numbers ++ */ ++__HOST_DEVICE__ bool operator>=(const __hip_bfloat16& l, const __hip_bfloat16& r) { ++ return __hge(l, r); ++} ++ ++/** ++ * \ingroup HIP_INTRINSIC_BFLOAT162_COMP ++ * \brief Operator to perform an equal compare on two __hip_bfloat16 numbers ++ */ ++__HOST_DEVICE__ bool operator==(const __hip_bfloat162& l, const __hip_bfloat162& r) { ++ return __heq(l.x, r.x) && __heq(l.y, r.y); ++} ++ ++/** ++ * \ingroup HIP_INTRINSIC_BFLOAT162_COMP ++ * \brief Operator to perform a not equal on two __hip_bfloat16 numbers ++ */ ++__HOST_DEVICE__ bool operator!=(const __hip_bfloat162& l, const __hip_bfloat162& r) { ++ return __hne(l.x, r.x) || __hne(l.y, r.y); ++} ++ ++/** ++ * \ingroup HIP_INTRINSIC_BFLOAT162_COMP ++ * \brief Operator to perform a less than on two __hip_bfloat16 numbers ++ */ ++__HOST_DEVICE__ bool operator<(const __hip_bfloat162& l, const __hip_bfloat162& r) { ++ return __hlt(l.x, r.x) && __hlt(l.y, r.y); ++} ++ ++/** ++ * \ingroup HIP_INTRINSIC_BFLOAT162_COMP ++ * \brief Operator to perform a less than equal on two __hip_bfloat16 numbers ++ */ ++__HOST_DEVICE__ bool operator<=(const __hip_bfloat162& l, const __hip_bfloat162& r) { ++ return __hle(l.x, r.x) && __hle(l.y, r.y); ++} ++ ++/** ++ * \ingroup HIP_INTRINSIC_BFLOAT162_COMP ++ * \brief Operator to perform a greater than on two __hip_bfloat16 numbers ++ */ ++__HOST_DEVICE__ bool operator>(const __hip_bfloat162& l, const __hip_bfloat162& r) { ++ return __hgt(l.x, r.x) && __hgt(l.y, r.y); ++} ++ ++/** ++ * \ingroup HIP_INTRINSIC_BFLOAT16_COMP ++ * \brief Operator to perform a greater than equal on two __hip_bfloat16 numbers ++ */ ++__HOST_DEVICE__ bool operator>=(const __hip_bfloat162& l, const __hip_bfloat162& r) { ++ return __hge(l.x, r.x) && __hge(l.y, r.y); + } + + /** +@@ -970,5 +1334,4 @@ + __device__ __hip_bfloat162 h2trunc(const __hip_bfloat162 h) { + return __hip_bfloat162{htrunc(h.x), htrunc(h.y)}; + } +- + #endif From bf2dea1fb705d668f9be9bb51372e1fa85bd498a Mon Sep 17 00:00:00 2001 From: miloice <17350011+kliuae@users.noreply.github.com> Date: Sat, 6 Jan 2024 13:19:33 +0000 Subject: [PATCH 94/94] Add rank 128 --- csrc/punica/bgmv/bgmv_config.h | 11 ++++++++--- vllm/config.py | 2 +- 2 files changed, 9 insertions(+), 4 deletions(-) diff --git a/csrc/punica/bgmv/bgmv_config.h b/csrc/punica/bgmv/bgmv_config.h index ced0397dab216..664dddc680ab6 100644 --- a/csrc/punica/bgmv/bgmv_config.h +++ b/csrc/punica/bgmv/bgmv_config.h @@ -10,8 +10,7 @@ void bgmv_kernel(out_T *__restrict__ Y, const in_T *__restrict__ X, // clang-format off -#define FOR_BGMV_WIDE(f, in_T, out_T, W_T, narrow) \ - f(in_T, out_T, W_T, narrow, 128) \ +#define FOR_BGMV_WIDE_exc128(f, in_T, out_T, W_T, narrow) \ f(in_T, out_T, W_T, narrow, 256) \ f(in_T, out_T, W_T, narrow, 512) \ f(in_T, out_T, W_T, narrow, 1024) \ @@ -48,11 +47,17 @@ void bgmv_kernel(out_T *__restrict__ Y, const in_T *__restrict__ X, f(in_T, out_T, W_T, narrow, 49152) \ // Keep above in sync with vllm/lora/layers::SamplerWithLoRA +#define FOR_BGMV_WIDE(f, in_T, out_T, W_T, narrow) \ + f(in_T, out_T, W_T, narrow, 128) \ + FOR_BGMV_WIDE_exc128(f, in_T, out_T, W_T, narrow) \ +// Keep above in sync with vllm/lora/layers::SamplerWithLoRA + // Keep this in sync with vllm/config::LoRAConfig #define FOR_BGMV_WIDE_NARROW(f, in_T, out_T, W_T) \ FOR_BGMV_WIDE(f, in_T, out_T, W_T, 8) \ FOR_BGMV_WIDE(f, in_T, out_T, W_T, 16) \ FOR_BGMV_WIDE(f, in_T, out_T, W_T, 32) \ - FOR_BGMV_WIDE(f, in_T, out_T, W_T, 64) + FOR_BGMV_WIDE(f, in_T, out_T, W_T, 64) \ + FOR_BGMV_WIDE_exc128(f, in_T, out_T, W_T, 128) // clang-format on diff --git a/vllm/config.py b/vllm/config.py index 9facbd6664075..b279ec548c511 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -416,7 +416,7 @@ class LoRAConfig: def __post_init__(self): # Keep this in sync with csrc/punica/bgmv/bgmv_config.h - possible_max_ranks = (8, 16, 32, 64) + possible_max_ranks = (8, 16, 32, 64, 128) possible_lora_extra_vocab_size = (0, 256, 512) if self.max_lora_rank not in possible_max_ranks: raise ValueError(