diff --git a/.buildkite/run-amd-test.sh b/.buildkite/run-amd-test.sh index 35d2ba1f8bab4..96fcafc9dc1c1 100755 --- a/.buildkite/run-amd-test.sh +++ b/.buildkite/run-amd-test.sh @@ -77,7 +77,6 @@ echo "Commands:$commands" #ignore certain kernels tests if [[ $commands == *" kernels "* ]]; then commands="${commands} \ - --ignore=kernels/test_attention.py \ --ignore=kernels/test_attention_selector.py \ --ignore=kernels/test_blocksparse_attention.py \ --ignore=kernels/test_causal_conv1d.py \ diff --git a/CMakeLists.txt b/CMakeLists.txt index c5fc2f3c1aaff..f7e329294ce30 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -31,7 +31,7 @@ set(ignoreMe "${VLLM_PYTHON_PATH}") set(PYTHON_SUPPORTED_VERSIONS "3.9" "3.10" "3.11" "3.12") # Supported NVIDIA architectures. -set(CUDA_SUPPORTED_ARCHS "7.0;7.2;7.5;8.0;8.6;8.7;8.9;9.0") +set(CUDA_SUPPORTED_ARCHS "7.0;7.2;7.5;8.0;8.6;8.7;8.9;9.0;10.0;10.1;12.0") # Supported AMD GPU architectures. set(HIP_SUPPORTED_ARCHS "gfx906;gfx908;gfx90a;gfx942;gfx1030;gfx1100;gfx1101") @@ -297,7 +297,7 @@ if(VLLM_GPU_LANG STREQUAL "CUDA") # Only build Marlin kernels if we are building for at least some compatible archs. # Keep building Marlin for 9.0 as there are some group sizes and shapes that # are not supported by Machete yet. - cuda_archs_loose_intersection(MARLIN_ARCHS "8.0;8.6;8.7;8.9;9.0" "${CUDA_ARCHS}") + cuda_archs_loose_intersection(MARLIN_ARCHS "8.0;8.6;8.7;8.9;9.0;10.0;10.1;12.0" "${CUDA_ARCHS}") if (MARLIN_ARCHS) set(MARLIN_SRCS "csrc/quantization/fp8/fp8_marlin.cu" @@ -335,7 +335,7 @@ if(VLLM_GPU_LANG STREQUAL "CUDA") # The cutlass_scaled_mm kernels for Hopper (c3x, i.e. CUTLASS 3.x) require # CUDA 12.0 or later (and only work on Hopper, 9.0a for now). - cuda_archs_loose_intersection(SCALED_MM_3X_ARCHS "9.0a" "${CUDA_ARCHS}") + cuda_archs_loose_intersection(SCALED_MM_3X_ARCHS "9.0a;10.0a;10.1a;12.0a" "${CUDA_ARCHS}") if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER 12.0 AND SCALED_MM_3X_ARCHS) set(SRCS "csrc/quantization/cutlass_w8a8/scaled_mm_c3x.cu" @@ -369,7 +369,7 @@ if(VLLM_GPU_LANG STREQUAL "CUDA") # For the cutlass_scaled_mm kernels we want to build the c2x (CUTLASS 2.x) # kernels for the remaining archs that are not already built for 3x. cuda_archs_loose_intersection(SCALED_MM_2X_ARCHS - "7.5;8.0;8.6;8.7;8.9;9.0" "${CUDA_ARCHS}") + "7.5;8.0;8.6;8.7;8.9;9.0;10.0;10.1;12.0" "${CUDA_ARCHS}") # subtract out the archs that are already built for 3x list(REMOVE_ITEM SCALED_MM_2X_ARCHS ${SCALED_MM_3X_ARCHS}) if (SCALED_MM_2X_ARCHS) @@ -394,7 +394,7 @@ if(VLLM_GPU_LANG STREQUAL "CUDA") # 2:4 Sparse Kernels # The 2:4 sparse kernels cutlass_scaled_sparse_mm and cutlass_compressor - # require CUDA 12.2 or later (and only work on Hopper, 9.0a for now). + # require CUDA 12.2 or later (and only work on Hopper and Blackwell). 
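Note on the arch gating used throughout this CMake hunk: cuda_archs_loose_intersection narrows a kernel's supported architecture list against the user-requested CUDA_ARCHS, and the result decides whether the corresponding sources get compiled. The snippet below is only a rough Python model of that matching under the assumption that arch-specific "a" suffixes (e.g. 9.0a) are ignored for comparison but preserved in the result; the real helper is a CMake function in the repo's cmake utilities and may differ in detail.

def loose_intersection(src_archs, tgt_archs):
    """Hypothetical model of cuda_archs_loose_intersection: keep src entries
    whose base compute capability also appears in the requested target list,
    ignoring the arch-specific 'a' suffix when matching."""
    base = lambda arch: arch[:-1] if arch.endswith("a") else arch
    tgt_bases = {base(arch) for arch in tgt_archs}
    return [arch for arch in src_archs if base(arch) in tgt_bases]

# Building with CUDA_ARCHS="9.0;10.0;12.0" would then enable the SM90/SM100/SM120
# scaled_mm variants but skip 10.1:
print(loose_intersection("9.0a;10.0a;10.1a;12.0a".split(";"),
                         "9.0;10.0;12.0".split(";")))
# -> ['9.0a', '10.0a', '12.0a']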
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER 12.2 AND SCALED_MM_3X_ARCHS) set(SRCS "csrc/sparse/cutlass/sparse_scaled_mm_c3x.cu") set_gencode_flags_for_srcs( @@ -419,8 +419,7 @@ if(VLLM_GPU_LANG STREQUAL "CUDA") if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER 12.8 AND FP4_ARCHS) set(SRCS "csrc/quantization/fp4/nvfp4_quant_kernels.cu" - "csrc/quantization/fp4/nvfp4_scaled_mm_kernels.cu" - ) + "csrc/quantization/fp4/nvfp4_scaled_mm_kernels.cu") set_gencode_flags_for_srcs( SRCS "${SRCS}" CUDA_ARCHS "${FP4_ARCHS}") @@ -433,6 +432,22 @@ if(VLLM_GPU_LANG STREQUAL "CUDA") set(FP4_ARCHS) endif() + # FP8 Blackwell Archs + cuda_archs_loose_intersection(BLACKWELL_ARCHS "10.0;10.1;12.0" "${CUDA_ARCHS}") + if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER 12.8 AND BLACKWELL_ARCHS) + set(SRCS + "csrc/quantization/cutlass_w8a8/c3x/scaled_mm_sm100_fp8.cu" + ) + set_gencode_flags_for_srcs( + SRCS "${SRCS}" + CUDA_ARCHS "${BLACKWELL_ARCHS}") + list(APPEND VLLM_EXT_SRC "${SRCS}") + message(STATUS "Building FP8 for archs: ${BLACKWELL_ARCHS}") + else() + # clear BLACKWELL_ARCHS + set(BLACKWELL_ARCHS) + endif() + # # Machete kernels @@ -514,6 +529,7 @@ define_gpu_extension_target( COMPILE_FLAGS ${VLLM_GPU_FLAGS} ARCHITECTURES ${VLLM_GPU_ARCHES} INCLUDE_DIRECTORIES ${CUTLASS_INCLUDE_DIR} + INCLUDE_DIRECTORIES ${CUTLASS_TOOLS_UTIL_INCLUDE_DIR} USE_SABI 3 WITH_SOABI) @@ -537,7 +553,7 @@ set_gencode_flags_for_srcs( CUDA_ARCHS "${CUDA_ARCHS}") if(VLLM_GPU_LANG STREQUAL "CUDA") - cuda_archs_loose_intersection(MARLIN_MOE_ARCHS "8.0;8.6;8.7;8.9;9.0" "${CUDA_ARCHS}") + cuda_archs_loose_intersection(MARLIN_MOE_ARCHS "8.0;8.6;8.7;8.9;9.0;10.0;10.1;12.0" "${CUDA_ARCHS}") if (MARLIN_MOE_ARCHS) set(MARLIN_MOE_SRC "csrc/moe/marlin_kernels/marlin_moe_kernel.h" diff --git a/benchmarks/kernels/benchmark_moe.py b/benchmarks/kernels/benchmark_moe.py index c862dec81fccd..bb28c32798e2c 100644 --- a/benchmarks/kernels/benchmark_moe.py +++ b/benchmarks/kernels/benchmark_moe.py @@ -2,6 +2,7 @@ import argparse import time +from contextlib import nullcontext from datetime import datetime from itertools import product from typing import Any, TypedDict @@ -40,6 +41,7 @@ def benchmark_config( use_fp8_w8a8: bool, use_int8_w8a16: bool, num_iters: int = 100, + block_quant_shape: List[int] = None, ) -> float: init_dtype = torch.float16 if use_fp8_w8a8 else dtype x = torch.randn(num_tokens, hidden_size, dtype=dtype) @@ -81,8 +83,24 @@ def benchmark_config( dtype=torch.float32) w2_scale = torch.randn((hidden_size, num_experts), dtype=torch.float32) if use_fp8_w8a8: - w1_scale = torch.randn(num_experts, dtype=torch.float32) - w2_scale = torch.randn(num_experts, dtype=torch.float32) + if block_quant_shape: + block_n, block_k = block_quant_shape[0], block_quant_shape[1] + E = num_experts + N = shard_intermediate_size // 2 + K = hidden_size + factor_for_scale = 1e-2 + n_tiles_w1 = (2 * N + block_n - 1) // block_n + n_tiles_w2 = (K + block_n - 1) // block_n + k_tiles_w1 = (K + block_k - 1) // block_k + k_tiles_w2 = (N + block_k - 1) // block_k + w1_scale = torch.rand((E, n_tiles_w1, k_tiles_w1), + dtype=torch.float32) * factor_for_scale + w2_scale = torch.rand((E, n_tiles_w2, k_tiles_w2), + dtype=torch.float32) * factor_for_scale + else: + w1_scale = torch.randn(num_experts, dtype=torch.float32) + w2_scale = torch.randn(num_experts, dtype=torch.float32) + a1_scale = torch.randn(1, dtype=torch.float32) a2_scale = torch.randn(1, dtype=torch.float32) @@ -111,6 +129,7 @@ def run(): w2_scale=w2_scale, a1_scale=a1_scale, a2_scale=a2_scale, + 
block_shape=block_quant_shape, ) # JIT compilation & warmup @@ -175,7 +194,8 @@ def get_rocm_tuning_space(use_fp16): return param_ranges -def get_configs_compute_bound(use_fp16) -> list[dict[str, int]]: +def get_configs_compute_bound(use_fp16, + block_quant_shape) -> list[dict[str, int]]: configs: list[BenchmarkConfig] = [] if current_platform.is_rocm(): @@ -204,17 +224,27 @@ def get_configs_compute_bound(use_fp16) -> list[dict[str, int]]: for config_values in product(*values): config = dict(zip(keys, config_values)) configs.append(config) + + # Remove configs that are not compatible with fp8 block quantization + # BLOCK_SIZE_K must be a multiple of block_k + # BLOCK_SIZE_N must be a multiple of block_n + if block_quant_shape is not None and not use_fp16: + block_n, block_k = block_quant_shape[0], block_quant_shape[1] + for config in configs[:]: + if config["BLOCK_SIZE_K"] % block_k != 0 or config[ + "BLOCK_SIZE_N"] % block_n != 0: + configs.remove(config) return configs def prune_rocm_search_space(num_tokens, shard_intermediate_size, hidden_size, - search_space, is_fp16): + search_space, is_fp16, topk): N1, K1 = shard_intermediate_size, hidden_size N2, K2 = hidden_size, shard_intermediate_size // 2 - pruned_space_1 = prune_rocm_configs(num_tokens * 2, N1, K1, search_space, - is_fp16) - pruned_space_2 = prune_rocm_configs(num_tokens * 2, N2, K2, search_space, - is_fp16) + pruned_space_1 = prune_rocm_configs(num_tokens * topk, N1, K1, + search_space, is_fp16) + pruned_space_2 = prune_rocm_configs(num_tokens * topk, N2, K2, + search_space, is_fp16) search_space = merge_unique_dicts(pruned_space_1, pruned_space_2) return search_space @@ -372,6 +402,7 @@ def tune( use_fp8_w8a8: bool, use_int8_w8a16: bool, search_space: list[dict[str, int]], + block_quant_shape: list[int], ) -> dict[str, int]: best_config = None best_time = float("inf") @@ -380,21 +411,24 @@ def tune( search_space = prune_rocm_search_space(num_tokens, shard_intermediate_size, hidden_size, search_space, - is_fp16) + is_fp16, topk) - with torch.cuda.device(self.device_id): + with torch.cuda.device(self.device_id) if current_platform.is_rocm( + ) else nullcontext(): for config in tqdm(search_space): try: - kernel_time = benchmark_config(config, - num_tokens, - num_experts, - shard_intermediate_size, - hidden_size, - topk, - dtype, - use_fp8_w8a8, - use_int8_w8a16, - num_iters=20) + kernel_time = benchmark_config( + config, + num_tokens, + num_experts, + shard_intermediate_size, + hidden_size, + topk, + dtype, + use_fp8_w8a8, + use_int8_w8a16, + num_iters=20, + block_quant_shape=block_quant_shape) except triton.runtime.autotuner.OutOfResources: # Some configurations may be invalid and fail to compile. continue @@ -436,8 +470,8 @@ def sort_config(config: BenchmarkConfig) -> BenchmarkConfig: def save_configs(configs: dict[int, BenchmarkConfig], num_experts: int, shard_intermediate_size: int, hidden_size: int, topk: int, - dtype: torch.dtype, use_fp8_w8a8: bool, - use_int8_w8a16: bool) -> None: + dtype: torch.dtype, use_fp8_w8a8: bool, use_int8_w8a16: bool, + block_quant_shape: List[int]) -> None: dtype_str = get_config_dtype_str(dtype, use_int8_w8a16=use_int8_w8a16, use_fp8_w8a8=use_fp8_w8a8) @@ -445,7 +479,7 @@ def save_configs(configs: dict[int, BenchmarkConfig], num_experts: int, # NOTE(woosuk): The current naming convention uses w2.shape[2], which # is the intermediate size after silu_and_mul. 
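To make the block-quantization handling above concrete, here is a small self-contained sketch (not part of the benchmark script) that mirrors two pieces of the code in this file: the per-block FP8 scale shapes built in benchmark_config, and the config pruning in get_configs_compute_bound. The concrete sizes are illustrative only.

from itertools import product

def block_quant_scale_shapes(E, N, K, block_quant_shape):
    """Per-block FP8 scale shapes, mirroring benchmark_config above
    (E = num_experts, N = shard_intermediate_size // 2, K = hidden_size)."""
    block_n, block_k = block_quant_shape
    ceil_div = lambda x, y: (x + y - 1) // y
    w1_scale = (E, ceil_div(2 * N, block_n), ceil_div(K, block_k))
    w2_scale = (E, ceil_div(K, block_n), ceil_div(N, block_k))
    return w1_scale, w2_scale

def is_block_compatible(config, block_quant_shape):
    """Mirrors get_configs_compute_bound: tile sizes must be multiples of the
    quantization block, otherwise the config is dropped."""
    block_n, block_k = block_quant_shape
    return (config["BLOCK_SIZE_K"] % block_k == 0
            and config["BLOCK_SIZE_N"] % block_n == 0)

# Illustrative numbers only (DeepSeek-style weight_block_size = [128, 128]):
print(block_quant_scale_shapes(E=64, N=1024, K=4096, block_quant_shape=[128, 128]))
grid = [dict(zip(("BLOCK_SIZE_N", "BLOCK_SIZE_K"), v))
        for v in product((32, 64, 128, 256), repeat=2)]
print(sum(is_block_compatible(c, [128, 128]) for c in grid), "of", len(grid), "configs kept")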
filename = get_config_file_name(num_experts, shard_intermediate_size // 2, - dtype_str) + dtype_str, block_quant_shape) print(f"Writing best config to {filename}...") with open(filename, "w") as f: @@ -455,7 +489,7 @@ def save_configs(configs: dict[int, BenchmarkConfig], num_experts: int, def main(args: argparse.Namespace): print(args) - + block_quant_shape = None config = AutoConfig.from_pretrained( args.model, trust_remote_code=args.trust_remote_code) if config.architectures[0] == "DbrxForCausalLM": @@ -474,6 +508,7 @@ def main(args: argparse.Namespace): topk = config.num_experts_per_tok intermediate_size = config.moe_intermediate_size shard_intermediate_size = 2 * intermediate_size // args.tp_size + block_quant_shape = config.quantization_config['weight_block_size'] else: # Default: Mixtral. E = config.num_local_experts @@ -511,27 +546,30 @@ def _distribute(method: str, inputs: list[Any]) -> list[Any]: if args.tune: is_fp16 = not (use_fp8_w8a8 or use_int8_w8a16) - search_space = get_configs_compute_bound(is_fp16) + search_space = get_configs_compute_bound(is_fp16, block_quant_shape) print(f"Start tuning over {len(search_space)} configurations...") start = time.time() configs = _distribute( - "tune", [(batch_size, E, shard_intermediate_size, hidden_size, - topk, dtype, use_fp8_w8a8, use_int8_w8a16, search_space) - for batch_size in batch_sizes]) + "tune", + [(batch_size, E, shard_intermediate_size, hidden_size, topk, dtype, + use_fp8_w8a8, use_int8_w8a16, search_space, block_quant_shape) + for batch_size in batch_sizes]) best_configs = { M: sort_config(config) for M, config in zip(batch_sizes, configs) } save_configs(best_configs, E, shard_intermediate_size, hidden_size, - topk, dtype, use_fp8_w8a8, use_int8_w8a16) + topk, dtype, use_fp8_w8a8, use_int8_w8a16, + block_quant_shape) end = time.time() print(f"Tuning took {end - start:.2f} seconds") else: outputs = _distribute( - "benchmark", [(batch_size, E, shard_intermediate_size, hidden_size, - topk, dtype, use_fp8_w8a8, use_int8_w8a16) - for batch_size in batch_sizes]) + "benchmark", + [(batch_size, E, shard_intermediate_size, hidden_size, topk, dtype, + use_fp8_w8a8, use_int8_w8a16, block_quant_shape) + for batch_size in batch_sizes]) for batch_size, (config, kernel_time) in zip(batch_sizes, outputs): print(f"Batch size: {batch_size}, config: {config}") diff --git a/benchmarks/kernels/benchmark_paged_attention.py b/benchmarks/kernels/benchmark_paged_attention.py index d00e848243611..221d7b7d5d91b 100644 --- a/benchmarks/kernels/benchmark_paged_attention.py +++ b/benchmarks/kernels/benchmark_paged_attention.py @@ -11,8 +11,9 @@ from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, FlexibleArgumentParser, create_kv_caches_with_random) -NUM_BLOCKS = 1024 +NUM_BLOCKS = 128 * 1024 PARTITION_SIZE = 512 +PARTITION_SIZE_ROCM = 256 @torch.inference_mode() @@ -80,6 +81,12 @@ def main( # Prepare for the paged attention kernel. 
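For context on the PARTITION_SIZE changes above and the v2 branch just below: the benchmark now picks a partition size per platform, and that choice drives how large the paged_attention_v2 scratch tensors are. The sketch below restates that arithmetic with illustrative numbers only; it is not part of the benchmark.

def pick_partition_size(is_rocm, use_custom_paged_attn):
    """Mirrors the v2 branch below: ROCm uses PARTITION_SIZE_ROCM (256) for the
    custom paged-attention kernel and 1024 otherwise; CUDA keeps the default 512."""
    if is_rocm:
        return 256 if use_custom_paged_attn else 1024
    return 512

def v2_tmp_output_shape(num_seqs, num_query_heads, head_size, max_seq_len,
                        partition_size):
    """Scratch tensor shape allocated for paged_attention_v2 in this benchmark."""
    num_partitions = (max_seq_len + partition_size - 1) // partition_size
    return (num_seqs, num_query_heads, num_partitions, head_size)

# Illustrative: 8 sequences of 4096 tokens, 32 query heads, head_size 128.
for is_rocm, custom in ((False, False), (True, False), (True, True)):
    ps = pick_partition_size(is_rocm, custom)
    print(is_rocm, custom, ps, v2_tmp_output_shape(8, 32, 128, 4096, ps))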
output = torch.empty_like(query) if version == "v2": + if current_platform.is_rocm(): + global PARTITION_SIZE + if not args.custom_paged_attn: + PARTITION_SIZE = 1024 + else: + PARTITION_SIZE = PARTITION_SIZE_ROCM num_partitions = ((max_seq_len + PARTITION_SIZE - 1) // PARTITION_SIZE) tmp_output = torch.empty( size=(num_seqs, num_query_heads, num_partitions, head_size), @@ -123,25 +130,46 @@ def run_cuda_benchmark(num_iters: int, profile: bool = False) -> float: v_scale, ) elif version == "v2": - ops.paged_attention_v2( - output, - exp_sums, - max_logits, - tmp_output, - query, - key_cache, - value_cache, - num_kv_heads, - scale, - block_tables, - seq_lens, - block_size, - max_seq_len, - alibi_slopes, - kv_cache_dtype, - k_scale, - v_scale, - ) + if not args.custom_paged_attn: + ops.paged_attention_v2( + output, + exp_sums, + max_logits, + tmp_output, + query, + key_cache, + value_cache, + num_kv_heads, + scale, + block_tables, + seq_lens, + block_size, + max_seq_len, + alibi_slopes, + kv_cache_dtype, + k_scale, + v_scale, + ) + else: + ops.paged_attention_rocm( + output, + exp_sums, + max_logits, + tmp_output, + query, + key_cache, + value_cache, + num_kv_heads, + scale, + block_tables, + seq_lens, + block_size, + max_seq_len, + alibi_slopes, + kv_cache_dtype, + k_scale, + v_scale, + ) else: raise ValueError(f"Invalid version: {version}") torch.cuda.synchronize() @@ -195,6 +223,9 @@ def run_cuda_benchmark(num_iters: int, profile: bool = False) -> float: help="Data type for kv cache storage. If 'auto', will use model " "data type. CUDA 11.8+ supports fp8 (=fp8_e4m3) and fp8_e5m2. " "ROCm (AMD GPU) supports fp8 (=fp8_e4m3)") + parser.add_argument("--custom-paged-attn", + action="store_true", + help="Use custom paged attention") args = parser.parse_args() print(args) diff --git a/csrc/cutlass_extensions/epilogue/scaled_mm_epilogues_c3x.hpp b/csrc/cutlass_extensions/epilogue/scaled_mm_epilogues_c3x.hpp index 1a0cd45f4e20b..0a812dc56a994 100644 --- a/csrc/cutlass_extensions/epilogue/scaled_mm_epilogues_c3x.hpp +++ b/csrc/cutlass_extensions/epilogue/scaled_mm_epilogues_c3x.hpp @@ -22,7 +22,7 @@ struct identity { T operator()(T lhs) const { return lhs; } }; -template +template struct TrivialEpilogue { private: using Accum = cutlass::epilogue::fusion::Sm90AccFetch; @@ -44,32 +44,30 @@ struct TrivialEpilogue { * This class provides the common load descriptors for the * ScaledEpilogue[...] 
classes */ -template +template struct ScaledEpilogueBase { protected: using Accum = cutlass::epilogue::fusion::Sm90AccFetch; template using ColOrScalarLoad = cutlass::epilogue::fusion::Sm90ColOrScalarBroadcast< - 0 /*Stages*/, typename EpilogueDescriptor::TileShape, T, - Stride, Int<0>, Int<0>>>; + 0 /*Stages*/, TileShape, T, Stride, Int<0>, Int<0>>>; template using RowOrScalarLoad = cutlass::epilogue::fusion::Sm90RowOrScalarBroadcast< - 0 /*Stages*/, typename EpilogueDescriptor::TileShape, T, - Stride, Int<1>, Int<0>>>; + 0 /*Stages*/, TileShape, T, Stride, Int<1>, Int<0>>>; // Don't want to support nullptr by default template using ColLoad = cutlass::epilogue::fusion::Sm90ColBroadcast< - 0 /*Stages*/, typename EpilogueDescriptor::TileShape, T, T, - Stride, Int<0>, Int<0>>, 128 / sizeof_bits_v, EnableNullPtr>; + 0 /*Stages*/, TileShape, T, T, Stride, Int<0>, Int<0>>, + 128 / sizeof_bits_v, EnableNullPtr>; // Don't want to support nullptr by default template using RowLoad = cutlass::epilogue::fusion::Sm90RowBroadcast< - 0 /*Stages*/, typename EpilogueDescriptor::TileShape, T, T, - Stride, Int<1>, Int<0>>, 128 / sizeof_bits_v, EnableNullPtr>; + 0 /*Stages*/, TileShape, T, T, Stride, Int<1>, Int<0>>, + 128 / sizeof_bits_v, EnableNullPtr>; // This utility function constructs the arguments for the load descriptors // from a tensor. It can handle both row and column, as well as row/column or @@ -116,11 +114,11 @@ struct ScaledEpilogueBase { the A and B operands respectively. These scales may be either per-tensor or per row or column. */ -template +template struct ScaledEpilogue - : private ScaledEpilogueBase { + : private ScaledEpilogueBase { private: - using SUPER = ScaledEpilogueBase; + using SUPER = ScaledEpilogueBase; using Accum = typename SUPER::Accum; using ScaleA = typename SUPER::template ColOrScalarLoad; using ScaleB = typename SUPER::template RowOrScalarLoad; @@ -160,11 +158,11 @@ struct ScaledEpilogue * The bias tensor must be per-output channel. * ScaleA and ScaleB can be per-tensor or per-token/per-channel. */ -template +template struct ScaledEpilogueBias - : private ScaledEpilogueBase { + : private ScaledEpilogueBase { private: - using SUPER = ScaledEpilogueBase; + using SUPER = ScaledEpilogueBase; using Accum = typename SUPER::Accum; using ScaleA = typename SUPER::template ColOrScalarLoad; using ScaleB = typename SUPER::template RowOrScalarLoad; @@ -203,11 +201,11 @@ struct ScaledEpilogueBias * bias is a column vector instead of a row vector. Useful e.g. if we are * computing a GEMM via C^T += B^T A^T. This happens in the 2:4 sparse kernels. */ -template +template struct ScaledEpilogueColumnBias - : private ScaledEpilogueBase { + : private ScaledEpilogueBase { private: - using SUPER = ScaledEpilogueBase; + using SUPER = ScaledEpilogueBase; using Accum = typename SUPER::Accum; using ScaleA = typename SUPER::template ColOrScalarLoad; using ScaleB = typename SUPER::template RowOrScalarLoad; @@ -249,11 +247,11 @@ struct ScaledEpilogueColumnBias * * This epilogue also supports bias, which remains per-channel. */ -template +template struct ScaledEpilogueBiasAzp - : private ScaledEpilogueBase { + : private ScaledEpilogueBase { private: - using SUPER = ScaledEpilogueBase; + using SUPER = ScaledEpilogueBase; using Accum = typename SUPER::Accum; using ScaleA = typename SUPER::template ColOrScalarLoad; using ScaleB = typename SUPER::template RowOrScalarLoad; @@ -314,11 +312,11 @@ struct ScaledEpilogueBiasAzp * * This epilogue also supports bias, which remains per-channel. 
*/ -template +template struct ScaledEpilogueBiasAzpToken - : private ScaledEpilogueBase { + : private ScaledEpilogueBase { private: - using SUPER = ScaledEpilogueBase; + using SUPER = ScaledEpilogueBase; using Accum = typename SUPER::Accum; using ScaleA = typename SUPER::template ColOrScalarLoad; using ScaleB = typename SUPER::template RowOrScalarLoad; diff --git a/csrc/quantization/cutlass_w8a8/c3x/cutlass_gemm_caller.cuh b/csrc/quantization/cutlass_w8a8/c3x/cutlass_gemm_caller.cuh index 69a3f64cb0b04..26de32ce2b16a 100644 --- a/csrc/quantization/cutlass_w8a8/c3x/cutlass_gemm_caller.cuh +++ b/csrc/quantization/cutlass_w8a8/c3x/cutlass_gemm_caller.cuh @@ -16,6 +16,7 @@ #include "cutlass/gemm/kernel/gemm_universal.hpp" #include "cutlass/epilogue/collective/collective_builder.hpp" #include "cutlass/gemm/collective/collective_builder.hpp" +#include "cutlass/util/packed_stride.hpp" #include "core/math.hpp" #include "cutlass_extensions/common.hpp" @@ -64,22 +65,28 @@ void cutlass_gemm_caller(torch::Tensor& out, torch::Tensor const& a, torch::Tensor const& b, EpilogueArgs&&... epilogue_params) { using ElementAB = typename Gemm::ElementAB; + using ElementC = typename Gemm::ElementC; using ElementD = typename Gemm::ElementD; using GemmKernel = typename Gemm::GemmKernel; - int64_t lda = a.stride(0); - int64_t ldb = b.stride(1); - int64_t ldc = out.stride(0); - - using StrideA = cute::Stride, int64_t>; - using StrideB = cute::Stride, int64_t>; - using StrideC = typename Gemm::StrideC; - - StrideA a_stride{lda, cute::Int<1>{}, 0}; - StrideB b_stride{ldb, cute::Int<1>{}, 0}; - StrideC c_stride{ldc, cute::Int<1>{}, cute::Int<0>{}}; + using StrideA = typename Gemm::GemmKernel::StrideA; + using StrideB = typename Gemm::GemmKernel::StrideB; + using StrideC = typename Gemm::GemmKernel::StrideC; + using StrideD = StrideC; + using StrideAux = StrideC; typename GemmKernel::ProblemShape prob_shape = get_problem_shape(a, b); + auto [M, N, K, L] = prob_shape; + + StrideA a_stride = + cutlass::make_cute_packed_stride(StrideA{}, cute::make_shape(M, K, L)); + StrideB b_stride = + cutlass::make_cute_packed_stride(StrideB{}, cute::make_shape(N, K, L)); + StrideC c_stride = + cutlass::make_cute_packed_stride(StrideC{}, cute::make_shape(M, N, L)); + StrideD d_stride = + cutlass::make_cute_packed_stride(StrideD{}, cute::make_shape(M, N, L)); + StrideAux aux_stride = d_stride; auto a_ptr = static_cast(a.data_ptr()); auto b_ptr = static_cast(b.data_ptr()); @@ -87,10 +94,11 @@ void cutlass_gemm_caller(torch::Tensor& out, torch::Tensor const& a, b_stride}; auto c_ptr = static_cast(out.data_ptr()); + // auto d_ptr = static_cast(out.data_ptr()); typename GemmKernel::EpilogueArguments epilogue_args{ Gemm::Epilogue::prepare_args( std::forward(epilogue_params)...), - c_ptr, c_stride, c_ptr, c_stride}; + c_ptr, c_stride, c_ptr, d_stride}; cutlass_gemm_caller(a.device(), prob_shape, mainloop_args, epilogue_args); diff --git a/csrc/quantization/cutlass_w8a8/c3x/scaled_mm.cuh b/csrc/quantization/cutlass_w8a8/c3x/scaled_mm.cuh index d2f43e2b7a89d..8f4df836bcc8d 100644 --- a/csrc/quantization/cutlass_w8a8/c3x/scaled_mm.cuh +++ b/csrc/quantization/cutlass_w8a8/c3x/scaled_mm.cuh @@ -40,12 +40,7 @@ struct cutlass_3x_gemm { typename std::conditional, int32_t, float>::type; - using EpilogueDescriptor = - cutlass::epilogue::collective::detail::EpilogueDescriptor< - TileShape, cutlass::epilogue::collective::EpilogueTileAuto, ElementD, - ElementD, EpilogueSchedule>; - - using Epilogue = Epilogue_; + using Epilogue = Epilogue_; using StrideD = 
Stride, Int<0>>; using ElementC = void; @@ -88,4 +83,65 @@ struct cutlass_3x_gemm { struct GemmKernel : public KernelType {}; }; +template typename Epilogue_, + typename TileShape, typename ClusterShape, typename KernelSchedule, + typename EpilogueSchedule> +struct cutlass_3x_gemm_sm100 { + using ElementAB = ElementAB_; + using LayoutA = cutlass::layout::RowMajor; + static constexpr int AlignmentA = + 128 / cutlass::sizeof_bits::value; + + using LayoutB = cutlass::layout::ColumnMajor; + static constexpr int AlignmentB = + 128 / cutlass::sizeof_bits::value; + + using ElementC = void; + using LayoutC = cutlass::layout::RowMajor; + static constexpr int AlignmentC = + 128 / cutlass::sizeof_bits::value; + + using ElementD = ElementD_; + using LayoutD = cutlass::layout::RowMajor; + static constexpr int AlignmentD = AlignmentC; + + using ElementAcc = + typename std::conditional, int32_t, + float>::type; + using Epilogue = Epilogue_; + + // MMA type + using ElementAccumulator = float; + + // Epilogue types + using ElementBias = cutlass::half_t; + using ElementCompute = float; + using ElementAux = ElementD; + using LayoutAux = LayoutD; + using ElementAmax = float; + + using EVTCompute = typename Epilogue::EVTCompute; + + using CollectiveEpilogue = + typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, TileShape, + ClusterShape, cutlass::epilogue::collective::EpilogueTileAuto, + ElementAccumulator, ElementCompute, ElementC, LayoutC, AlignmentC, + ElementD, LayoutD, AlignmentD, EpilogueSchedule, + EVTCompute>::CollectiveOp; + + using CollectiveMainloop = + typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, ElementAB, + LayoutA, AlignmentA, ElementAB, LayoutB, AlignmentB, + ElementAccumulator, TileShape, ClusterShape, + cutlass::gemm::collective::StageCountAutoCarveout( + sizeof(typename CollectiveEpilogue::SharedStorage))>, + KernelSchedule>::CollectiveOp; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, CollectiveMainloop, CollectiveEpilogue, void>; +}; + } // namespace vllm diff --git a/csrc/quantization/cutlass_w8a8/c3x/scaled_mm_kernels.hpp b/csrc/quantization/cutlass_w8a8/c3x/scaled_mm_kernels.hpp index 7ede9e067477b..85272804774db 100644 --- a/csrc/quantization/cutlass_w8a8/c3x/scaled_mm_kernels.hpp +++ b/csrc/quantization/cutlass_w8a8/c3x/scaled_mm_kernels.hpp @@ -30,4 +30,10 @@ void cutlass_scaled_mm_blockwise_sm90_fp8(torch::Tensor& out, torch::Tensor const& a_scales, torch::Tensor const& b_scales); +void cutlass_scaled_mm_sm100_fp8(torch::Tensor& out, torch::Tensor const& a, + torch::Tensor const& b, + torch::Tensor const& a_scales, + torch::Tensor const& b_scales, + std::optional const& bias); + } // namespace vllm diff --git a/csrc/quantization/cutlass_w8a8/c3x/scaled_mm_sm100_fp8.cu b/csrc/quantization/cutlass_w8a8/c3x/scaled_mm_sm100_fp8.cu new file mode 100644 index 0000000000000..cf2cccc913f62 --- /dev/null +++ b/csrc/quantization/cutlass_w8a8/c3x/scaled_mm_sm100_fp8.cu @@ -0,0 +1,24 @@ +#include "scaled_mm_kernels.hpp" +#include "scaled_mm_sm100_fp8_dispatch.cuh" +#include "cutlass_extensions/epilogue/scaled_mm_epilogues_c3x.hpp" + +namespace vllm { + +void cutlass_scaled_mm_sm100_fp8(torch::Tensor& out, torch::Tensor const& a, + torch::Tensor const& b, + torch::Tensor const& a_scales, + torch::Tensor const& b_scales, + std::optional const& bias) { + TORCH_CHECK(a_scales.is_contiguous() && b_scales.is_contiguous()); + if (bias) { + 
TORCH_CHECK(bias->dtype() == out.dtype(), + "currently bias dtype must match output dtype ", out.dtype()); + return cutlass_scaled_mm_sm100_fp8_epilogue( + out, a, b, a_scales, b_scales, *bias); + } else { + return cutlass_scaled_mm_sm100_fp8_epilogue( + out, a, b, a_scales, b_scales); + } +} + +} // namespace vllm diff --git a/csrc/quantization/cutlass_w8a8/c3x/scaled_mm_sm100_fp8_dispatch.cuh b/csrc/quantization/cutlass_w8a8/c3x/scaled_mm_sm100_fp8_dispatch.cuh new file mode 100644 index 0000000000000..468b77d9593bc --- /dev/null +++ b/csrc/quantization/cutlass_w8a8/c3x/scaled_mm_sm100_fp8_dispatch.cuh @@ -0,0 +1,67 @@ +#pragma once + +#include "scaled_mm.cuh" +#include "cutlass_gemm_caller.cuh" + +/** + * This file defines Gemm kernel configurations for SM100 (fp8) based on the + * Gemm shape. + */ + +namespace vllm { + +using c3x::cutlass_gemm_caller; + +template typename Epilogue> +struct sm100_fp8_config_default { + static_assert(std::is_same()); + using KernelSchedule = cutlass::gemm::collective::KernelScheduleAuto; + using EpilogueSchedule = cutlass::epilogue::collective::EpilogueScheduleAuto; + using TileShape = Shape<_256, _128, _64>; + using ClusterShape = Shape<_2, _2, _1>; + using Cutlass3xGemm = + cutlass_3x_gemm_sm100; +}; + +template typename Epilogue, + typename... EpilogueArgs> +inline void cutlass_gemm_sm100_fp8_dispatch(torch::Tensor& out, + torch::Tensor const& a, + torch::Tensor const& b, + EpilogueArgs&&... args) { + static_assert(std::is_same()); + TORCH_CHECK(a.dtype() == torch::kFloat8_e4m3fn); + TORCH_CHECK(b.dtype() == torch::kFloat8_e4m3fn); + + using Cutlass3xGemmDefault = + typename sm100_fp8_config_default::Cutlass3xGemm; + return cutlass_gemm_caller( + out, a, b, std::forward(args)...); +} + +template
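Functionally, the new cutlass_scaled_mm_sm100_fp8 entry point is expected to match the numerics of the existing SM90 FP8 ScaledEpilogue/ScaledEpilogueBias path: accumulate in fp32, apply the A and B scales, add the optional per-channel bias, and cast to the output dtype. The torch-only sketch below is an assumption about those intended numerics (it is not part of this patch) and can serve as a reference when sanity-checking the SM100 kernel's output.

import torch
from typing import Optional

def scaled_mm_reference(a_fp8: torch.Tensor, b_fp8: torch.Tensor,
                        scale_a: torch.Tensor, scale_b: torch.Tensor,
                        out_dtype: torch.dtype,
                        bias: Optional[torch.Tensor] = None) -> torch.Tensor:
    """Assumed FP8 scaled-GEMM numerics: fp32 accumulation, per-tensor or
    broadcastable per-row scale_a / per-column scale_b, optional per-channel
    bias, final cast to out_dtype."""
    acc = a_fp8.to(torch.float32) @ b_fp8.to(torch.float32)
    out = scale_a * acc * scale_b
    if bias is not None:
        out = out + bias.to(torch.float32)
    return out.to(out_dtype)

# Example with per-tensor scales and a bf16 output (shapes illustrative only).
a = torch.randn(16, 256).to(torch.float8_e4m3fn)
b = torch.randn(256, 32).to(torch.float8_e4m3fn)
ref = scaled_mm_reference(a, b, torch.tensor(0.02), torch.tensor(0.05), torch.bfloat16)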