From 90a3e0f2b41afea18b2ed3b5e213e5bfb5645a60 Mon Sep 17 00:00:00 2001
From: Randall Smith
Date: Thu, 17 Oct 2024 19:45:44 -0500
Subject: [PATCH 1/5] basic support for symmetric smoothquant model

---
 .../compressed_tensors/int8_quant_kernels.cu | 42 ++++++++++++-------
 vllm/_custom_ops.py                          | 11 +++--
 2 files changed, 36 insertions(+), 17 deletions(-)

diff --git a/csrc/quantization/compressed_tensors/int8_quant_kernels.cu b/csrc/quantization/compressed_tensors/int8_quant_kernels.cu
index aec9fa002f96e..e9987535bd3ea 100644
--- a/csrc/quantization/compressed_tensors/int8_quant_kernels.cu
+++ b/csrc/quantization/compressed_tensors/int8_quant_kernels.cu
@@ -96,12 +96,15 @@ __global__ void static_scaled_int8_quant_kernel(
     scalar_t const* __restrict__ input, int8_t* __restrict__ out,
     scale_type const* scale_ptr, const int hidden_size) {
   int const tid = threadIdx.x;
-  int const token_idx = blockIdx.x;
+  int64_t const token_idx = blockIdx.x;
   scale_type const scale = *scale_ptr;
 
+  // Must be performed using 64-bit math to avoid integer overflow.
+  out += token_idx * hidden_size;
+  input += token_idx * hidden_size;
+
   for (int i = tid; i < hidden_size; i += blockDim.x) {
-    out[token_idx * hidden_size + i] = float_to_int8_rn(
-        static_cast<float>(input[token_idx * hidden_size + i]) / scale);
+    out[i] = float_to_int8_rn(static_cast<float>(input[i]) / scale);
   }
 }
 
@@ -111,14 +114,18 @@ __global__ void static_scaled_int8_azp_quant_kernel(
     scalar_t const* __restrict__ input, int8_t* __restrict__ out,
     scale_type const* scale_ptr, azp_type const* azp_ptr,
     const int hidden_size) {
   int const tid = threadIdx.x;
-  int const token_idx = blockIdx.x;
+  int64_t const token_idx = blockIdx.x;
   scale_type const scale = *scale_ptr;
   azp_type const azp = *azp_ptr;
 
+  // Must be performed using 64-bit math to avoid integer overflow.
+  out += token_idx * hidden_size;
+  input += token_idx * hidden_size;
+
   for (int i = tid; i < hidden_size; i += blockDim.x) {
-    auto const val = static_cast<float>(input[token_idx * hidden_size + i]);
+    auto const val = static_cast<float>(input[i]);
     auto const quant_val = int32_to_int8(float_to_int32_rn(val / scale) + azp);
-    out[token_idx * hidden_size + i] = quant_val;
+    out[i] = quant_val;
   }
 }
 
@@ -127,12 +134,16 @@ __global__ void dynamic_scaled_int8_quant_kernel(
     scalar_t const* __restrict__ input, int8_t* __restrict__ out,
     scale_type* scale, const int hidden_size) {
   int const tid = threadIdx.x;
-  int const token_idx = blockIdx.x;
+  int64_t const token_idx = blockIdx.x;
   float absmax_val = 0.0f;
   float const zero = 0.0f;
 
+  // Must be performed using 64-bit math to avoid integer overflow.
+  out += token_idx * hidden_size;
+  input += token_idx * hidden_size;
+
   for (int i = tid; i < hidden_size; i += blockDim.x) {
-    float val = static_cast<float>(input[token_idx * hidden_size + i]);
+    float val = static_cast<float>(input[i]);
     val = val > zero ? val : -val;
     absmax_val = val > absmax_val ? val : absmax_val;
   }
@@ -150,8 +161,7 @@ __global__ void dynamic_scaled_int8_quant_kernel(
 
   float const tmp_scale = 127.0f / block_absmax_val;
   for (int i = tid; i < hidden_size; i += blockDim.x) {
-    out[token_idx * hidden_size + i] = float_to_int8_rn(
-        static_cast<float>(input[token_idx * hidden_size + i]) * tmp_scale);
+    out[i] = float_to_int8_rn(static_cast<float>(input[i]) * tmp_scale);
   }
 }
 
@@ -159,13 +169,17 @@ template <typename scalar_t, typename scale_type, typename azp_type>
 __global__ void dynamic_scaled_int8_azp_quant_kernel(
     scalar_t const* __restrict__ input, int8_t* __restrict__ out,
     scale_type* scale, azp_type* azp, const int hidden_size) {
-  int const token_idx = blockIdx.x;
+  int64_t const token_idx = blockIdx.x;
+
+  // Must be performed using 64-bit math to avoid integer overflow.
+  out += token_idx * hidden_size;
+  input += token_idx * hidden_size;
 
   // Scan for the min and max value for this token
   float max_val = std::numeric_limits<float>::min();
   float min_val = std::numeric_limits<float>::max();
   for (int i = threadIdx.x; i < hidden_size; i += blockDim.x) {
-    auto val = static_cast<float>(input[token_idx * hidden_size + i]);
+    auto val = static_cast<float>(input[i]);
     max_val = std::max(max_val, val);
     min_val = std::min(min_val, val);
   }
@@ -200,10 +214,10 @@ __global__ void dynamic_scaled_int8_azp_quant_kernel(
 
   // Quantize the values
   for (int i = threadIdx.x; i < hidden_size; i += blockDim.x) {
-    auto const val = static_cast<float>(input[token_idx * hidden_size + i]);
+    auto const val = static_cast<float>(input[i]);
     auto const quant_val =
         int32_to_int8(float_to_int32_rn(val / scale_val) + azp_val);
-    out[token_idx * hidden_size + i] = quant_val;
+    out[i] = quant_val;
   }
 }
 
diff --git a/vllm/_custom_ops.py b/vllm/_custom_ops.py
index c6e5bed5ad9a3..f09e33c5f5815 100644
--- a/vllm/_custom_ops.py
+++ b/vllm/_custom_ops.py
@@ -517,7 +517,6 @@ def selective_scan_fwd_fake(u: torch.Tensor, delta: torch.Tensor,
 def cutlass_scaled_mm_supports_fp8(cuda_device_capability: int) -> bool:
     return torch.ops._C.cutlass_scaled_mm_supports_fp8(cuda_device_capability)
 
-
 def cutlass_scaled_mm(a: torch.Tensor,
                       b: torch.Tensor,
                       scale_a: torch.Tensor,
@@ -531,9 +530,15 @@ def cutlass_scaled_mm(a: torch.Tensor,
     m = a.shape[0]
     n = b.shape[1]
-    out = torch.empty((m, n), dtype=out_dtype, device=a.device)
-    torch.ops._C.cutlass_scaled_mm(out, a, b, scale_a, scale_b, bias)
+    if is_hip():
+        out = torch.mm(a.to(torch.float32), b.to(torch.float32))
+        out = scale_a * out
+        out = scale_b.T * out
+        out = out.to(out_dtype)
+    else:
+        out = torch.empty((m, n), dtype=out_dtype, device=a.device)
+        torch.ops._C.cutlass_scaled_mm(out, a, b, scale_a, scale_b, bias)
 
     return out
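The `int64_t` promotion in the kernels above matters because `token_idx * hidden_size` would otherwise be computed in 32-bit arithmetic and silently wrap once the product exceeds 2^31 - 1, turning the per-token pointer offset into garbage. A minimal sketch of the failure mode, using hypothetical sizes that are not taken from the patch or from vLLM:

# Illustration only: how a 32-bit token offset wraps for large inputs.
hidden_size = 16384     # hypothetical hidden dimension
token_idx = 200_000     # hypothetical flattened token index (blockIdx.x)

exact = token_idx * hidden_size               # 3_276_800_000, the intended element offset
wrapped = (exact + 2**31) % 2**32 - 2**31     # two's-complement wrap to 32 bits

print(exact)    # 3276800000
print(wrapped)  # -1018167296 -- a negative offset, i.e. an out-of-bounds access

Promoting `token_idx` to `int64_t` before the multiply keeps the added `out += token_idx * hidden_size` pointer arithmetic in 64 bits.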
From dfdc1300f3a0c101485021cf83d3a2ebfd592468 Mon Sep 17 00:00:00 2001
From: Randall Smith
Date: Thu, 17 Oct 2024 20:05:05 -0500
Subject: [PATCH 2/5] yap

---
 vllm/_custom_ops.py | 5 +++--
 1 file changed, 3 insertions(+), 2 deletions(-)

diff --git a/vllm/_custom_ops.py b/vllm/_custom_ops.py
index f09e33c5f5815..6a016a69083dd 100644
--- a/vllm/_custom_ops.py
+++ b/vllm/_custom_ops.py
@@ -517,6 +517,7 @@ def selective_scan_fwd_fake(u: torch.Tensor, delta: torch.Tensor,
 def cutlass_scaled_mm_supports_fp8(cuda_device_capability: int) -> bool:
     return torch.ops._C.cutlass_scaled_mm_supports_fp8(cuda_device_capability)
 
+
 def cutlass_scaled_mm(a: torch.Tensor,
                       b: torch.Tensor,
                       scale_a: torch.Tensor,
@@ -773,8 +774,8 @@ def scaled_int8_quant(
     if scale is not None:
         # static-per-tensor quantization.
         assert symmetric == (
-            azp is
-            None), "azp must only be provided for asymmetric quantization."
+            azp
+            is None), "azp must only be provided for asymmetric quantization."
         torch.ops._C.static_scaled_int8_quant(output, input, scale, azp)
         return output, scale, None

From 6b789f83088bde82164f0ae8b533406c8412f301 Mon Sep 17 00:00:00 2001
From: Randall Smith
Date: Thu, 17 Oct 2024 20:11:32 -0500
Subject: [PATCH 3/5] yapf

---
 vllm/_custom_ops.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/vllm/_custom_ops.py b/vllm/_custom_ops.py
index 6a016a69083dd..adacfe82e4373 100644
--- a/vllm/_custom_ops.py
+++ b/vllm/_custom_ops.py
@@ -774,8 +774,8 @@ def scaled_int8_quant(
     if scale is not None:
         # static-per-tensor quantization.
         assert symmetric == (
-            azp
-            is None), "azp must only be provided for asymmetric quantization."
+            azp is
+            None), "azp must only be provided for asymmetric quantization."
         torch.ops._C.static_scaled_int8_quant(output, input, scale, azp)
         return output, scale, None

From 1a9d5b0a54fb89ae7fd52288787485d01d17d624 Mon Sep 17 00:00:00 2001
From: Randall Smith
Date: Mon, 21 Oct 2024 12:51:30 -0500
Subject: [PATCH 4/5] Put implementation into torch function

---
 vllm/_custom_ops.py | 27 ++++++++++++++++++++-------
 1 file changed, 20 insertions(+), 7 deletions(-)

diff --git a/vllm/_custom_ops.py b/vllm/_custom_ops.py
index adacfe82e4373..7a3cec00fa8e9 100644
--- a/vllm/_custom_ops.py
+++ b/vllm/_custom_ops.py
@@ -1,6 +1,6 @@
 import contextlib
 import functools
-from typing import TYPE_CHECKING, List, Optional, Tuple, Union
+from typing import TYPE_CHECKING, List, Optional, Tuple, Type, Union
 
 import torch
 import torch.library
@@ -518,6 +518,22 @@ def cutlass_scaled_mm_supports_fp8(cuda_device_capability: int) -> bool:
     return torch.ops._C.cutlass_scaled_mm_supports_fp8(cuda_device_capability)
 
 
+def scaled_mm_torch(a: torch.Tensor,
+                    b: torch.Tensor,
+                    scale_a: torch.Tensor,
+                    scale_b: torch.Tensor,
+                    out_dtype: Type[torch.dtype],
+                    bias: Optional[torch.Tensor] = None) -> torch.Tensor:
+    out = torch.mm(a.to(torch.float32), b.to(torch.float32))
+    out = scale_a * out
+    out = scale_b.T * out
+    out = out.to(out_dtype)
+    if bias is not None:
+        out = out + bias
+
+    return out
+
+
 def cutlass_scaled_mm(a: torch.Tensor,
                       b: torch.Tensor,
                       scale_a: torch.Tensor,
@@ -533,10 +549,7 @@ def cutlass_scaled_mm(a: torch.Tensor,
     n = b.shape[1]
 
     if is_hip():
-        out = torch.mm(a.to(torch.float32), b.to(torch.float32))
-        out = scale_a * out
-        out = scale_b.T * out
-        out = out.to(out_dtype)
+        return scaled_mm_torch(a, b, scale_a, scale_b, out_dtype, bias)
     else:
         out = torch.empty((m, n), dtype=out_dtype, device=a.device)
         torch.ops._C.cutlass_scaled_mm(out, a, b, scale_a, scale_b, bias)
@@ -774,8 +787,8 @@ def scaled_int8_quant(
     if scale is not None:
         # static-per-tensor quantization.
         assert symmetric == (
-            azp is
-            None), "azp must only be provided for asymmetric quantization."
+            azp
+            is None), "azp must only be provided for asymmetric quantization."
         torch.ops._C.static_scaled_int8_quant(output, input, scale, azp)
         return output, scale, None

From dc6ad4fc23e61d44abfdc305556f9905158c0e0b Mon Sep 17 00:00:00 2001
From: Randall Smith
Date: Mon, 21 Oct 2024 13:20:31 -0500
Subject: [PATCH 5/5] yapf

---
 vllm/_custom_ops.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/vllm/_custom_ops.py b/vllm/_custom_ops.py
index 7a3cec00fa8e9..d985da18d200e 100644
--- a/vllm/_custom_ops.py
+++ b/vllm/_custom_ops.py
@@ -787,8 +787,8 @@ def scaled_int8_quant(
     if scale is not None:
         # static-per-tensor quantization.
         assert symmetric == (
-            azp
-            is None), "azp must only be provided for asymmetric quantization."
+            azp is
+            None), "azp must only be provided for asymmetric quantization."
         torch.ops._C.static_scaled_int8_quant(output, input, scale, azp)
         return output, scale, None
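A quick way to sanity-check the `scaled_mm_torch` fallback factored out in patch 4: it should match dequantize-then-matmul. The sketch below assumes per-token scales for `a` (shape `[M, 1]`) and per-output-channel scales for `b` (shape `[N, 1]`, hence the `.T`); the shapes, sizes, and tolerance are illustrative assumptions rather than values taken from vLLM's tests.

import torch

# Hypothetical int8 operands and scales (per-token for a, per-channel for b).
M, K, N = 4, 8, 16
a = torch.randint(-128, 128, (M, K), dtype=torch.int8)
b = torch.randint(-128, 128, (K, N), dtype=torch.int8)
scale_a = torch.rand(M, 1) * 0.01   # per-token scales, broadcast over rows of the product
scale_b = torch.rand(N, 1) * 0.01   # per-channel scales; .T broadcasts over columns

# Same sequence of operations as scaled_mm_torch in the patch.
out = torch.mm(a.to(torch.float32), b.to(torch.float32))
out = scale_a * out
out = scale_b.T * out
out = out.to(torch.float16)

# Reference: dequantize each operand first, then multiply in float32.
ref = (a.to(torch.float32) * scale_a) @ (b.to(torch.float32) * scale_b.T)
assert torch.allclose(out.to(torch.float32), ref, rtol=1e-2, atol=1e-2)

On ROCm this path keeps `cutlass_scaled_mm` usable for symmetric SmoothQuant checkpoints without the CUTLASS kernel, at the cost of running the matmul in float32.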