diff --git a/csrc/ops.h b/csrc/ops.h
index 346898964010d..e39d4ef3188a3 100644
--- a/csrc/ops.h
+++ b/csrc/ops.h
@@ -153,6 +153,7 @@ torch::Tensor ggml_mul_mat_a8(torch::Tensor W, torch::Tensor X, int64_t type,
 
 #ifndef USE_ROCM
 bool cutlass_scaled_mm_supports_fp8(int64_t cuda_device_capability);
+bool cutlass_scaled_mm_supports_block_fp8(int64_t cuda_device_capability);
 
 void cutlass_scaled_mm(torch::Tensor& out, torch::Tensor const& a,
                        torch::Tensor const& b, torch::Tensor const& a_scales,
diff --git a/csrc/quantization/cutlass_w8a8/scaled_mm_c3x.cu b/csrc/quantization/cutlass_w8a8/scaled_mm_c3x.cu
index e6f06d72fbfd4..72d549e597df5 100644
--- a/csrc/quantization/cutlass_w8a8/scaled_mm_c3x.cu
+++ b/csrc/quantization/cutlass_w8a8/scaled_mm_c3x.cu
@@ -58,7 +58,13 @@ void cutlass_scaled_mm_sm90(torch::Tensor& c, torch::Tensor const& a,
 
     vllm::cutlass_scaled_mm_blockwise_sm90_fp8(c, a, b, a_scales, b_scales);
   } else {
-    TORCH_CHECK(false, "Unsupported scale group shapes for CUTLASS 3.x GEMM");
+    TORCH_CHECK(false,
+                "Unsupported scale group shapes for CUTLASS 3.x GEMM.\n "
+                "a_scale_group_shape must be [1, 128], got: [",
+                a_scale_group_shape[0], ", ", a_scale_group_shape[1],
+                "]\n"
+                "b_scale_group_shape must be [128, 128], got: [",
+                b_scale_group_shape[0], ", ", b_scale_group_shape[1], "]");
   }
 }
 
diff --git a/csrc/quantization/cutlass_w8a8/scaled_mm_entry.cu b/csrc/quantization/cutlass_w8a8/scaled_mm_entry.cu
index da77312bc4b98..6bef55088682a 100644
--- a/csrc/quantization/cutlass_w8a8/scaled_mm_entry.cu
+++ b/csrc/quantization/cutlass_w8a8/scaled_mm_entry.cu
@@ -81,6 +81,19 @@ bool cutlass_scaled_mm_supports_fp8(int64_t cuda_device_capability) {
   return false;
 }
 
+bool cutlass_scaled_mm_supports_block_fp8(int64_t cuda_device_capability) {
+  // CUTLASS block-quantized FP8 kernels need at least CUDA 12.0
+  // and at least SM90 (Hopper)
+
+#if defined CUDA_VERSION
+  if (cuda_device_capability >= 90) {
+    return CUDA_VERSION >= 12000;
+  }
+#endif
+
+  return false;
+}
+
 void cutlass_scaled_mm(torch::Tensor& c, torch::Tensor const& a,
                        torch::Tensor const& b, torch::Tensor const& a_scales,
                        torch::Tensor const& b_scales,
@@ -212,4 +225,4 @@ void cutlass_scaled_mm_azp(torch::Tensor& c, torch::Tensor const& a,
       "No compiled cutlass_scaled_mm_azp for a compute capability less than "
       "CUDA device capability: ",
       version_num);
-}
\ No newline at end of file
+}
diff --git a/csrc/torch_bindings.cpp b/csrc/torch_bindings.cpp
index 1846d9ac29943..186e9c0e81b77 100644
--- a/csrc/torch_bindings.cpp
+++ b/csrc/torch_bindings.cpp
@@ -324,6 +324,13 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
   ops.def("cutlass_scaled_mm_supports_fp8(int cuda_device_capability) -> bool");
   ops.impl("cutlass_scaled_mm_supports_fp8", &cutlass_scaled_mm_supports_fp8);
 
+  // Check if cutlass scaled_mm supports block quantization (used by DeepSeekV3)
+  ops.def(
+      "cutlass_scaled_mm_supports_block_fp8(int cuda_device_capability) -> "
+      "bool");
+  ops.impl("cutlass_scaled_mm_supports_block_fp8",
+           &cutlass_scaled_mm_supports_block_fp8);
+
   // Check if cutlass sparse scaled_mm is supported for CUDA devices of the
   // given capability
   ops.def(
diff --git a/vllm/_custom_ops.py b/vllm/_custom_ops.py
index fd94134de0219..da237da2eccac 100644
--- a/vllm/_custom_ops.py
+++ b/vllm/_custom_ops.py
@@ -435,6 +435,11 @@ def cutlass_scaled_mm_supports_fp8(cuda_device_capability: int) -> bool:
     return torch.ops._C.cutlass_scaled_mm_supports_fp8(cuda_device_capability)
 
 
+def cutlass_scaled_mm_supports_block_fp8(cuda_device_capability: int) -> bool:
+    return torch.ops._C.cutlass_scaled_mm_supports_block_fp8(
+        cuda_device_capability)
+
+
 def cutlass_scaled_mm(a: torch.Tensor,
                       b: torch.Tensor,
                       scale_a: torch.Tensor,
diff --git a/vllm/model_executor/layers/quantization/fp8.py b/vllm/model_executor/layers/quantization/fp8.py
index 57dd6e310297d..adab1973b40ee 100644
--- a/vllm/model_executor/layers/quantization/fp8.py
+++ b/vllm/model_executor/layers/quantization/fp8.py
@@ -21,7 +21,8 @@
     is_layer_skipped)
 from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
     all_close_1d, apply_fp8_linear, convert_to_channelwise,
-    cutlass_fp8_supported, normalize_e4m3fn_to_e4m3fnuz, per_tensor_dequantize,
+    cutlass_block_fp8_supported, cutlass_fp8_supported,
+    normalize_e4m3fn_to_e4m3fnuz, per_tensor_dequantize,
     requantize_with_max_scale)
 from vllm.model_executor.parameter import (BlockQuantScaleParameter,
                                            ModelWeightParameter,
@@ -133,6 +134,7 @@ class Fp8LinearMethod(LinearMethodBase):
     def __init__(self, quant_config: Fp8Config):
         self.quant_config = quant_config
         self.cutlass_fp8_supported = cutlass_fp8_supported()
+        self.cutlass_block_fp8_supported = cutlass_block_fp8_supported()
 
         # For GPUs that lack FP8 hardware support, we can leverage the Marlin
         # kernel for fast weight-only FP8 quantization
@@ -359,6 +361,7 @@ def apply(self,
                 weight_scale=layer.weight_scale_inv,
                 input_scale=layer.input_scale,
                 bias=bias,
+                cutlass_block_fp8_supported=self.cutlass_block_fp8_supported,
             )
 
         return apply_fp8_linear(
diff --git a/vllm/model_executor/layers/quantization/utils/fp8_utils.py b/vllm/model_executor/layers/quantization/utils/fp8_utils.py
index a7a3fa6601639..ccebff341a7ed 100644
--- a/vllm/model_executor/layers/quantization/utils/fp8_utils.py
+++ b/vllm/model_executor/layers/quantization/utils/fp8_utils.py
@@ -8,6 +8,7 @@
 import triton
 import triton.language as tl
 
+from vllm import _custom_ops as ops
 from vllm.logger import init_logger
 from vllm.platforms import current_platform
 
@@ -21,20 +22,34 @@ def apply_w8a8_block_fp8_linear(
     weight_scale: torch.Tensor,
     input_scale: Optional[torch.Tensor] = None,
     bias: Optional[torch.Tensor] = None,
+    cutlass_block_fp8_supported: bool = True,
 ) -> torch.Tensor:
     assert input_scale is None
     # View input as 2D matrix for fp8 methods
     input_2d = input.view(-1, input.shape[-1])
     output_shape = [*input.shape[:-1], weight.shape[0]]
 
-    q_input, x_scale = per_token_group_quant_fp8(input_2d, block_size[1])
-    output = w8a8_block_fp8_matmul(q_input,
-                                   weight,
-                                   x_scale,
-                                   weight_scale,
-                                   block_size,
-                                   output_dtype=input.dtype)
-
+    shape_supported_by_cutlass = (weight.shape[0] % 128 == 0
+                                  and weight.shape[1] % 128 == 0)
+    if cutlass_block_fp8_supported and shape_supported_by_cutlass:
+        q_input, x_scale = per_token_group_quant_fp8(input_2d,
+                                                     block_size[1],
+                                                     column_major_scales=True)
+        output = ops.cutlass_scaled_mm(q_input,
+                                       weight.T,
+                                       out_dtype=input.dtype,
+                                       scale_a=x_scale,
+                                       scale_b=weight_scale.T)
+    else:
+        q_input, x_scale = per_token_group_quant_fp8(input_2d,
+                                                     block_size[1],
+                                                     column_major_scales=False)
+        output = w8a8_block_fp8_matmul(q_input,
+                                       weight,
+                                       x_scale,
+                                       weight_scale,
+                                       block_size,
+                                       output_dtype=input.dtype)
     if bias is not None:
         output = output + bias
     return output.to(dtype=input.dtype).view(*output_shape)
@@ -98,10 +113,7 @@ def _per_token_group_quant_fp8(
     y_ptr,
     y_q_ptr,
     y_s_ptr,
-    # Stride of input
-    y_stride,
-    # Columns of input
-    N,
+    group_size,
     # Avoid to divide zero
     eps,
     # Information for float8
@@ -116,12 +128,60 @@ def _per_token_group_quant_fp8(
     """
     # Map the program id to the row of X and Y it should compute.
     g_id = tl.program_id(0)
-    y_ptr += g_id * y_stride
-    y_q_ptr += g_id * y_stride
+    y_ptr += g_id * group_size
+    y_q_ptr += g_id * group_size
     y_s_ptr += g_id
 
     cols = tl.arange(0, BLOCK)  # N <= BLOCK
-    mask = cols < N
+    mask = cols < group_size
+
+    y = tl.load(y_ptr + cols, mask=mask, other=0.0).to(tl.float32)
+    # Quant
+    _absmax = tl.maximum(tl.max(tl.abs(y)), eps)
+    y_s = _absmax / fp8_max
+    y_q = tl.clamp(y / y_s, fp8_min, fp8_max).to(y_q_ptr.dtype.element_ty)
+
+    tl.store(y_q_ptr + cols, y_q, mask=mask)
+    tl.store(y_s_ptr, y_s)
+
+
+@triton.jit
+def _per_token_group_quant_fp8_colmajor(
+    # Pointers to inputs and output
+    y_ptr,
+    y_q_ptr,
+    y_s_ptr,
+    group_size,
+    # Num columns of y
+    y_num_columns,
+    # Stride from one column to the next of y_s
+    y_s_col_stride,
+    # Avoid to divide zero
+    eps,
+    # Information for float8
+    fp8_min,
+    fp8_max,
+    # Meta-parameters
+    BLOCK: tl.constexpr,
+):
+    """A Triton-accelerated function to perform per-token-group
+    quantization on a tensor.
+    This function converts the tensor values into float8 values.
+    """
+    # Map the program id to the row of X and Y it should compute.
+    g_id = tl.program_id(0)
+    y_ptr += g_id * group_size
+    y_q_ptr += g_id * group_size
+
+    # Convert g_id, the flattened block coordinate, to 2D so we can index
+    # into the output y_scales matrix
+    blocks_per_row = y_num_columns // group_size
+    scale_col = g_id % blocks_per_row
+    scale_row = g_id // blocks_per_row
+    y_s_ptr += scale_col * y_s_col_stride + scale_row
+
+    cols = tl.arange(0, BLOCK)  # group_size <= BLOCK
+    mask = cols < group_size
 
     y = tl.load(y_ptr + cols, mask=mask, other=0.0).to(tl.float32)
     # Quant
@@ -138,12 +198,13 @@ def per_token_group_quant_fp8(
     group_size: int,
     eps: float = 1e-10,
     dtype: Optional[torch.dtype] = None,
+    column_major_scales: bool = False,
 ) -> Tuple[torch.Tensor, torch.Tensor]:
     """Function to perform per-token-group quantization on an input tensor `x`.
     It converts the tensor values into signed float8 values and returns the
     quantized tensor along with the scaling factor used for quantization.
     Args:
-        x: The input tenosr with ndim >= 2.
+        x: The input tensor with ndim >= 2.
         group_size: The group size used for quantization.
         eps: The minimum to avoid dividing zero.
         dtype: The dype of output tensor. Note that only `torch.float8_e4m3fn`
@@ -167,29 +228,46 @@ def per_token_group_quant_fp8(
     x_q = torch.empty_like(x, device=x.device, dtype=dtype)
     M = x.numel() // group_size
     N = group_size
-    x_s = torch.empty(
-        x.shape[:-1] + (x.shape[-1] // group_size, ),
-        device=x.device,
-        dtype=torch.float32,
-    )
+    if column_major_scales:
+        shape = (x.shape[-1] // group_size, ) + x.shape[:-1]
+        x_s = torch.empty(shape, device=x.device,
+                          dtype=torch.float32).permute(-1, -2)
+    else:
+        shape = x.shape[:-1] + (x.shape[-1] // group_size, )
+        x_s = torch.empty(shape, device=x.device, dtype=torch.float32)
 
     BLOCK = triton.next_power_of_2(N)
     # heuristics for number of warps
     num_warps = min(max(BLOCK // 256, 1), 8)
     num_stages = 1
-    _per_token_group_quant_fp8[(M, )](
-        x,
-        x_q,
-        x_s,
-        group_size,
-        N,
-        eps,
-        fp8_min=fp8_min,
-        fp8_max=fp8_max,
-        BLOCK=BLOCK,
-        num_warps=num_warps,
-        num_stages=num_stages,
-    )
+    if column_major_scales:
+        _per_token_group_quant_fp8_colmajor[(M, )](
+            x,
+            x_q,
+            x_s,
+            group_size,
+            x.shape[1],
+            x_s.stride(1),
+            eps,
+            fp8_min=fp8_min,
+            fp8_max=fp8_max,
+            BLOCK=BLOCK,
+            num_warps=num_warps,
+            num_stages=num_stages,
+        )
+    else:
+        _per_token_group_quant_fp8[(M, )](
+            x,
+            x_q,
+            x_s,
+            group_size,
+            eps,
+            fp8_min=fp8_min,
+            fp8_max=fp8_max,
+            BLOCK=BLOCK,
+            num_warps=num_warps,
+            num_stages=num_stages,
+        )
 
     return x_q, x_s
 
diff --git a/vllm/model_executor/layers/quantization/utils/w8a8_utils.py b/vllm/model_executor/layers/quantization/utils/w8a8_utils.py
index 9977804188a50..3af3b3e0ea942 100644
--- a/vllm/model_executor/layers/quantization/utils/w8a8_utils.py
+++ b/vllm/model_executor/layers/quantization/utils/w8a8_utils.py
@@ -30,6 +30,16 @@ def cutlass_fp8_supported() -> bool:
     return ops.cutlass_scaled_mm_supports_fp8(capability)
 
 
+def cutlass_block_fp8_supported() -> bool:
+    if not current_platform.is_cuda():
+        return False
+
+    capability_tuple = current_platform.get_device_capability()
+    capability = -1 if capability_tuple is None else capability_tuple.to_int()
+
+    return ops.cutlass_scaled_mm_supports_block_fp8(capability)
+
+
 def per_tensor_dequantize(
         tensor: torch.Tensor,
         inv_scale: Union[float, torch.Tensor]) -> torch.Tensor:
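For readers tracing the Triton kernels above, here is a small PyTorch-only sketch of what per-token-group FP8 quantization computes, including the column-major scale layout the CUTLASS path consumes. It is illustrative only: the helper name `per_token_group_quant_fp8_ref` and the standalone-script framing are not part of this change.

```python
import torch


def per_token_group_quant_fp8_ref(x: torch.Tensor,
                                  group_size: int,
                                  eps: float = 1e-10,
                                  column_major_scales: bool = False):
    """Eager-mode reference for the per-token-group quantization kernels."""
    assert x.dim() == 2 and x.shape[-1] % group_size == 0
    finfo = torch.finfo(torch.float8_e4m3fn)
    # Split each row into contiguous groups of `group_size` elements.
    x_grouped = x.float().reshape(x.shape[0], -1, group_size)
    # One scale per group: absmax / fp8_max, floored at eps to avoid a
    # division by zero, mirroring tl.maximum(tl.max(tl.abs(y)), eps).
    absmax = x_grouped.abs().amax(dim=-1, keepdim=True).clamp_min(eps)
    x_s = absmax / finfo.max
    x_q = (x_grouped / x_s).clamp(finfo.min, finfo.max)
    x_q = x_q.reshape_as(x).to(torch.float8_e4m3fn)
    x_s = x_s.squeeze(-1)  # shape: (num_tokens, num_cols // group_size)
    if column_major_scales:
        # Same logical shape, but column-major in memory, which is the layout
        # _per_token_group_quant_fp8_colmajor writes for the CUTLASS path.
        x_s = x_s.t().contiguous().t()
    return x_q, x_s


if __name__ == "__main__":
    x = torch.randn(4, 256)
    x_q, x_s = per_token_group_quant_fp8_ref(x, group_size=128,
                                             column_major_scales=True)
    # x_q: (4, 256) float8_e4m3fn; x_s: (4, 2) with column-major strides (1, 4)
    print(x_q.shape, x_s.shape, x_s.stride())
```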
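The dispatch added to `apply_w8a8_block_fp8_linear` only takes the CUTLASS path when both the hardware query and the 128-divisibility shape check pass; everything else still falls back to the Triton `w8a8_block_fp8_matmul`. Below is a rough standalone sketch of that gate, using the names introduced in this diff; the helper `can_use_cutlass_block_fp8` itself is hypothetical.

```python
import torch

from vllm import _custom_ops as ops
from vllm.platforms import current_platform


def can_use_cutlass_block_fp8(weight: torch.Tensor) -> bool:
    """Hypothetical helper combining the two conditions gated in this diff."""
    if not current_platform.is_cuda():
        return False
    # Hardware/toolkit check: SM90+ and a build against CUDA >= 12.0.
    capability_tuple = current_platform.get_device_capability()
    capability = -1 if capability_tuple is None else capability_tuple.to_int()
    if not ops.cutlass_scaled_mm_supports_block_fp8(capability):
        return False
    # Shape check: both weight dims must be multiples of the 128x128
    # weight-scale block; otherwise the Triton fallback is used.
    return weight.shape[0] % 128 == 0 and weight.shape[1] % 128 == 0
```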