From c79bbc37e76e4240ae1bf4d687a6a20711e53a67 Mon Sep 17 00:00:00 2001
From: Rachel Guo
Date: Thu, 30 Jan 2025 14:38:53 -0800
Subject: [PATCH] Port oss f16_fast_gemv into fbcode (#3610)

Summary:
X-link: https://github.com/facebookresearch/FBGEMM/pull/688

This diff includes:
1. Port the OSS FastGEMV `fp16` kernel into fbcode and expose it to Python as step 1: `torch.ops.fbgemm.fp16_fast_gemv`
   https://github.com/wangsiping97/FastGEMV/blob/1fdff6f74aade033c02727a419afd6a4b4bfbc3f/fast_gemv.cu#L14
2. Add `fp16_oss_fast_gemv` to the quantize ops benchmark script
3. Add two simple tests for the custom op `torch.ops.fbgemm.fp16_fast_gemv`, covering:
   - `torch.compile()` compatibility
   - correctness

Perf numbers: P1722119058, compared with `bf16_baseline, fp16_oss_fast_gemv, cuda_lite, marlin_bf16i4, machete_bf16i4`

======================

### Benchmark Results

| **M** | **N** | **K** | **Method** | **Elapsed Time (ms)** | **TFLOPS** | **GB/s** |
| --- | --- | --- | --- | --- | --- | --- |
| 1 | 1280 | 8192 | bf16_baseline | 0.024 | 0.860 | 861.042 |
| 1 | 1280 | 8192 | fp16_oss_fast_gemv | 0.019 | 1.126 | 1127.391 |
| 1 | 1280 | 8192 | cuda_lite | 0.015 | 1.357 | 679.032 |
| 1 | 1280 | 8192 | marlin_bf16i4 | 0.027 | 0.768 | 192.612 |
| 1 | 1280 | 8192 | machete_bf16i4 | 0.026 | 0.810 | 203.219 |
| 1 | 8192 | 1024 | bf16_baseline | 0.018 | 0.952 | 953.176 |
| 1 | 8192 | 1024 | fp16_oss_fast_gemv | 0.010 | 1.763 | 1765.033 |
| 1 | 8192 | 1024 | cuda_lite | 0.014 | 1.198 | 600.054 |
| 1 | 8192 | 1024 | marlin_bf16i4 | 0.015 | 1.144 | 287.150 |
| 1 | 8192 | 1024 | machete_bf16i4 | 0.014 | 1.187 | 298.096 |
| 1 | 7168 | 8192 | bf16_baseline | 0.073 | 1.609 | 1608.983 |
| 1 | 7168 | 8192 | fp16_oss_fast_gemv | 0.069 | 1.697 | 1697.308 |
| 1 | 7168 | 8192 | cuda_lite | 0.044 | 2.679 | 1340.093 |
| 1 | 7168 | 8192 | marlin_bf16i4 | 0.033 | 3.590 | 898.436 |
| 1 | 7168 | 8192 | machete_bf16i4 | 0.039 | 3.017 | 755.147 |
| 1 | 8192 | 3584 | bf16_baseline | 0.045 | 1.312 | 1312.239 |
| 1 | 8192 | 3584 | fp16_oss_fast_gemv | 0.026 | 2.268 | 1134.843 |
| 1 | 8192 | 3584 | cuda_lite | 0.026 | 2.271 | 1136.151 |
| 1 | 8192 | 3584 | marlin_bf16i4 | 0.021 | 2.808 | 703.164 |
| 1 | 8192 | 3584 | machete_bf16i4 | 0.024 | 2.460 | 615.990 |

Heuristic sweep results for the 4 problem sizes we care about: P1722043272

**Next step:** Add fp8 mixed-precision support to the fast GEMV kernel, which is the end goal.

Differential Revision: D68470488
---
 fbgemm_gpu/experimental/gen_ai/CMakeLists.txt | 5 +-
 fbgemm_gpu/experimental/gen_ai/LICENSE | 21 ++
 .../experimental/gen_ai/bench/quantize_ops.py | 32 ++
 .../src/quantize/fast_gemv/fp16_fast_gemv.cu | 78 +++++
 .../quantize/fast_gemv/include/fast_gemv.cu | 315 ++++++++++++++++++
 .../quantize/fast_gemv/include/fast_gemv.cuh | 58 ++++
 .../quantize/fast_gemv/include/utility.cuh | 78 +++++
 .../gen_ai/src/quantize/quantize.cpp | 12 +
 .../gen_ai/test/quantize/quantize_test.py | 19 ++
 9 files changed, 617 insertions(+), 1 deletion(-)
 create mode 100644 fbgemm_gpu/experimental/gen_ai/LICENSE
 create mode 100644 fbgemm_gpu/experimental/gen_ai/src/quantize/fast_gemv/fp16_fast_gemv.cu
 create mode 100644 fbgemm_gpu/experimental/gen_ai/src/quantize/fast_gemv/include/fast_gemv.cu
 create mode 100644 fbgemm_gpu/experimental/gen_ai/src/quantize/fast_gemv/include/fast_gemv.cuh
 create mode 100644 fbgemm_gpu/experimental/gen_ai/src/quantize/fast_gemv/include/utility.cuh
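Example usage of the new op, for reviewers: a minimal sketch mirroring the correctness test added in `quantize_test.py`. It assumes a CUDA build with this diff applied; the `import fbgemm_gpu.experimental.gen_ai` line used to register the ops is an assumption about the local setup.

```python
# Minimal usage sketch (assumes a CUDA build of fbgemm_gpu with this diff;
# the module imported below to register torch.ops.fbgemm.* is an assumption).
import torch
import fbgemm_gpu.experimental.gen_ai  # noqa: F401

M, N, K = 1, 128, 256
x = torch.randn(M, K, dtype=torch.half, device="cuda")  # activation, M x K
w = torch.randn(N, K, dtype=torch.half, device="cuda")  # weight, N x K

# The kernel takes the activation as (K, M) and the weight as (N, K),
# and returns the result as (N, M), hence the transposes below.
y = torch.ops.fbgemm.fp16_fast_gemv(x.T, w)
y_ref = (x @ w.T).to(torch.half)
torch.testing.assert_close(y.T, y_ref, atol=1e-3, rtol=1e-3)
```

diff --git a/fbgemm_gpu/experimental/gen_ai/CMakeLists.txt b/fbgemm_gpu/experimental/gen_ai/CMakeLists.txt
index 25b42c5d62..8a51af2e34 100644
--- 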
a/fbgemm_gpu/experimental/gen_ai/CMakeLists.txt +++ b/fbgemm_gpu/experimental/gen_ai/CMakeLists.txt @@ -33,7 +33,10 @@ endif() # CUDA-specific sources file(GLOB_RECURSE experimental_gen_ai_cpp_source_files_cuda src/quantize/cutlass_extensions/*.cu - src/quantize/cutlass_extensions/**/*.cu) + src/quantize/cutlass_extensions/**/*.cu + src/quantize/fast_gemv/*.cu + src/quantize/fast_gemv/**/*.cu + src/quantize/fast_gemv/**/*.cuh) # HIP-specific sources file(GLOB_RECURSE experimental_gen_ai_cpp_source_files_hip diff --git a/fbgemm_gpu/experimental/gen_ai/LICENSE b/fbgemm_gpu/experimental/gen_ai/LICENSE new file mode 100644 index 0000000000..d6cc4287a8 --- /dev/null +++ b/fbgemm_gpu/experimental/gen_ai/LICENSE @@ -0,0 +1,21 @@ +MIT License + +Copyright (c) 2023 Siping Wang + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. diff --git a/fbgemm_gpu/experimental/gen_ai/bench/quantize_ops.py b/fbgemm_gpu/experimental/gen_ai/bench/quantize_ops.py index c8e46e9f88..4661817dd6 100644 --- a/fbgemm_gpu/experimental/gen_ai/bench/quantize_ops.py +++ b/fbgemm_gpu/experimental/gen_ai/bench/quantize_ops.py @@ -361,6 +361,38 @@ def cuda(self) -> bool: return True +@register_quantize_op +class Fp16OSSFastGemv(QuantizeOpBase): + """ + FP16 OSS fast gemv kernel. + """ + + def quantize(self, x, w): + # dummy quantize + return x, w + + def compute(self, x, w): + out = torch.ops.fbgemm.fp16_fast_gemv(x, w) + return out + + def quantize_and_compute(self, x, w): + x, w = self.quantize(x, w) + return self.compute(x, w) + + @property + def name(self) -> str: + return "fp16_oss_fast_gemv" + + @property + def hip(self) -> bool: + # This implementation is specific to cublas. + return False + + @property + def cuda(self) -> bool: + return True + + @register_quantize_op class FP8CublasRowwiseGemm(QuantizeOpBase): """ diff --git a/fbgemm_gpu/experimental/gen_ai/src/quantize/fast_gemv/fp16_fast_gemv.cu b/fbgemm_gpu/experimental/gen_ai/src/quantize/fast_gemv/fp16_fast_gemv.cu new file mode 100644 index 0000000000..2d0c115a0a --- /dev/null +++ b/fbgemm_gpu/experimental/gen_ai/src/quantize/fast_gemv/fp16_fast_gemv.cu @@ -0,0 +1,78 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. 
+ */ + +#include +#include +#include +#include + +#include "include/fast_gemv.cuh" + +namespace fbgemm_gpu { + +#if CUDART_VERSION >= 12000 + +dim3 get_best_block_dim(int m, int n, int k) { + if (m == 1 && n == 1280 && k == 8192) { + return dim3(128, 4); + } else if (m == 1 && n == 8192 && k == 1024) { + return dim3(256, 1); + } else if (m == 1 && n == 7168 && k == 8192) { + return dim3(256, 1); + } else if (m == 1 && n == 8192 && k == 3584) { + return dim3(256, 1); + } else { + // Default block dimensions + return dim3(32, 4); + } +} + +at::Tensor fp16_fast_gemv(at::Tensor X, at::Tensor W) { + // note: oss fast gemv implementation accepts vector shape as (size, 1) i.e. + // (K, M) + // X: K x M + // W: N x K + auto m = X.size(1); + auto n = W.size(0); + auto k = W.size(1); + + TORCH_CHECK(X.is_cuda() && X.is_contiguous()); + TORCH_CHECK(W.is_cuda() && W.is_contiguous()); + + // the block_dim values are sweeped results + dim3 block_dim = get_best_block_dim(m, n, k); + dim3 grid_dim(1, n / block_dim.y); + unsigned int num_per_thread = k / block_dim.x; + + auto stream = at::cuda::getCurrentCUDAStream(); + + auto out_sizes = X.sizes().vec(); + out_sizes.front() = n; + auto Y = at::empty(out_sizes, X.options().dtype(at::kHalf)); + + gemv_fp16<<>>( + (half*)W.data_ptr(), // mat + (half*)X.data_ptr(), // vec + (half*)Y.data_ptr(), // res + k, + num_per_thread); + + C10_CUDA_KERNEL_LAUNCH_CHECK(); + + return Y; +} + +#else + +at::Tensor fp16_fast_gemv(at::Tensor X, at::Tensor W) { + throw std::runtime_error( + "CUDA version is older than 12.0"); // requires CUDA>=12 +} +#endif + +} // namespace fbgemm_gpu diff --git a/fbgemm_gpu/experimental/gen_ai/src/quantize/fast_gemv/include/fast_gemv.cu b/fbgemm_gpu/experimental/gen_ai/src/quantize/fast_gemv/include/fast_gemv.cu new file mode 100644 index 0000000000..9d332cef12 --- /dev/null +++ b/fbgemm_gpu/experimental/gen_ai/src/quantize/fast_gemv/include/fast_gemv.cu @@ -0,0 +1,315 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +/* + * This source code is pulled from the original github repo: + * https://github.com/wangsiping97/FastGEMV. 
+ */ + +#include +#include +#include +#include +#include +#include + +#include "fast_gemv.cuh" +#include "utility.cuh" + +///////////////////////////// NORMAL ////////////////////////////// +// thread_per_block = blockDim.x +// blockDim.y <= SHARED_MEM_MAX_ROWS +__global__ void gemv_fp16( + half* mat, + half* vec, + half* res, + unsigned int n, + unsigned int num_per_thread) { + float sum = 0; + // each thread load num_per_thread elements from global + unsigned int tid = threadIdx.x; + unsigned int row = blockIdx.y * blockDim.y + threadIdx.y; + unsigned int start_idx = threadIdx.x; + float4* mat4 = reinterpret_cast(mat); + float4* vec4 = reinterpret_cast(vec); + +#pragma unroll + for (int iter = 0; iter < num_per_thread >> 3; iter++) { + unsigned int j = start_idx + iter * blockDim.x; + if (j < n >> 3) { + float4 vec_val = vec4[j]; + float4 mat_val = mat4[row * (n >> 3) + j]; + const half2* vec_h1 = (half2*)&vec_val.x; + const half2* vec_h2 = (half2*)&vec_val.y; + const half2* vec_h3 = (half2*)&vec_val.z; + const half2* vec_h4 = (half2*)&vec_val.w; + const half2* mat_h1 = (half2*)&mat_val.x; + const half2* mat_h2 = (half2*)&mat_val.y; + const half2* mat_h3 = (half2*)&mat_val.z; + const half2* mat_h4 = (half2*)&mat_val.w; + sum += cutlass::NumericConverter::convert(vec_h1->x) * + cutlass::NumericConverter::convert(mat_h1->x); + sum += cutlass::NumericConverter::convert(vec_h1->y) * + cutlass::NumericConverter::convert(mat_h1->y); + sum += cutlass::NumericConverter::convert(vec_h2->x) * + cutlass::NumericConverter::convert(mat_h2->x); + sum += cutlass::NumericConverter::convert(vec_h2->y) * + cutlass::NumericConverter::convert(mat_h2->y); + sum += cutlass::NumericConverter::convert(vec_h3->x) * + cutlass::NumericConverter::convert(mat_h3->x); + sum += cutlass::NumericConverter::convert(vec_h3->y) * + cutlass::NumericConverter::convert(mat_h3->y); + sum += cutlass::NumericConverter::convert(vec_h4->x) * + cutlass::NumericConverter::convert(mat_h4->x); + sum += cutlass::NumericConverter::convert(vec_h4->y) * + cutlass::NumericConverter::convert(mat_h4->y); + } + } + + sum = warpReduceSum(sum, blockDim.x); + + if (blockDim.x <= WARP_SIZE) { + if (tid == 0) { + res[row] = __float2half(sum); + } + return; + } + + // Shared mem for partial sums (one per warp in the block) + static __shared__ float warpLevelSums[SHARED_MEM_MAX_ROWS][WARP_SIZE]; + const int laneId = threadIdx.x % WARP_SIZE; + const int warpId = threadIdx.x / WARP_SIZE; + if (laneId == 0) + warpLevelSums[threadIdx.y][warpId] = sum; + __syncthreads(); + // read from shared memory only if that warp existed + sum = (threadIdx.x < blockDim.x / WARP_SIZE) + ? 
warpLevelSums[threadIdx.y][laneId] + : 0.0; + // Final reduce using first warp + if (warpId == 0) + sum = warpReduceSum(sum, blockDim.x / WARP_SIZE); + if (tid == 0) { + res[row] = __float2half(sum); + } +} + +///////////////////////////// QUANTIZED-INT8 ////////////////////////////// + +__global__ void gemv_quantized_int8( + int8_t* mat, + half* vec, + half* res, + unsigned int n, + half scale, + half zero_point, + unsigned int num_per_thread) { + float sum = 0; + // each thread load num_per_thread elements from global + unsigned int tid = threadIdx.x; + unsigned int row = blockIdx.y * blockDim.y + threadIdx.y; + unsigned int start_idx = threadIdx.x; + half4* mat4 = reinterpret_cast(mat); + float4* vec4 = reinterpret_cast(vec); + + float zero_point_f = static_cast(zero_point); + float scale_f = static_cast(scale); + +#pragma unroll + for (int iter = 0; iter < num_per_thread >> 3; iter++) { + unsigned int j = start_idx + iter * blockDim.x; + if (j < n >> 3) { + float4 vec_val = vec4[j]; + half4 mat_val = mat4[row * (n >> 3) + j]; + const half2* vec_h1 = (half2*)&vec_val.x; + const half2* vec_h2 = (half2*)&vec_val.y; + const half2* vec_h3 = (half2*)&vec_val.z; + const half2* vec_h4 = (half2*)&vec_val.w; + const int8_2* mat_h1 = (int8_2*)&mat_val.x; + const int8_2* mat_h2 = (int8_2*)&mat_val.y; + const int8_2* mat_h3 = (int8_2*)&mat_val.z; + const int8_2* mat_h4 = (int8_2*)&mat_val.w; + sum += static_cast(vec_h1->x) * + (static_cast(mat_h1->x) - zero_point_f); + sum += static_cast(vec_h1->y) * + (static_cast(mat_h1->y) - zero_point_f); + sum += static_cast(vec_h2->x) * + (static_cast(mat_h2->x) - zero_point_f); + sum += static_cast(vec_h2->y) * + (static_cast(mat_h2->y) - zero_point_f); + sum += static_cast(vec_h3->x) * + (static_cast(mat_h3->x) - zero_point_f); + sum += static_cast(vec_h3->y) * + (static_cast(mat_h3->y) - zero_point_f); + sum += static_cast(vec_h4->x) * + (static_cast(mat_h4->x) - zero_point_f); + sum += static_cast(vec_h4->y) * + (static_cast(mat_h4->y) - zero_point_f); + } + } + + sum *= scale_f; + + sum = warpReduceSum(sum, blockDim.x); + + if (blockDim.x <= WARP_SIZE) { + if (tid == 0) { + res[row] = __float2half(sum); + } + return; + } + + // Shared mem for partial sums (one per warp in the block) + static __shared__ float warpLevelSums[SHARED_MEM_MAX_ROWS][WARP_SIZE]; + const int laneId = threadIdx.x % WARP_SIZE; + const int warpId = threadIdx.x / WARP_SIZE; + if (laneId == 0) + warpLevelSums[threadIdx.y][warpId] = sum; + __syncthreads(); + // read from shared memory only if that warp existed + sum = (threadIdx.x < blockDim.x / WARP_SIZE) + ? 
warpLevelSums[threadIdx.y][laneId] + : 0.0; + // Final reduce using first warp + if (warpId == 0) + sum = warpReduceSum(sum, blockDim.x / WARP_SIZE); + if (tid == 0) { + res[row] = __float2half(sum); + } +} + +///////////////////////////// QUANTIZED-INT4 ////////////////////////////// + +// based on previous experiments, num_per_thread can >= 16 +__global__ void gemv_quantized_int4( + uint4_2* mat, + half* vec, + half* res, + unsigned int n, + half scale, + half zero_point, + unsigned int num_per_thread) { + float sum = 0; + // each thread load num_per_thread elements from global + unsigned int tid = threadIdx.x; + unsigned int row = blockIdx.y * blockDim.y + threadIdx.y; + unsigned int start_idx = threadIdx.x; + uint4_2_4* mat4 = reinterpret_cast(mat); + float4* vec4 = reinterpret_cast(vec); + + float zero_point_f = static_cast(zero_point); + float scale_f = static_cast(scale); + +#pragma unroll + for (int iter = 0; iter < num_per_thread >> 4; iter++) { + unsigned int j = 2 * (start_idx + iter * blockDim.x); + if (j < n >> 3) { + float4 vec_val_1 = vec4[j]; // 8 half + float4 vec_val_2 = vec4[j + 1]; + const half2* vec_h1 = (half2*)&vec_val_1.x; + const half2* vec_h2 = (half2*)&vec_val_1.y; + const half2* vec_h3 = (half2*)&vec_val_1.z; + const half2* vec_h4 = (half2*)&vec_val_1.w; + const half2* vec_h5 = (half2*)&vec_val_2.x; + const half2* vec_h6 = (half2*)&vec_val_2.y; + const half2* vec_h7 = (half2*)&vec_val_2.z; + const half2* vec_h8 = (half2*)&vec_val_2.w; + + uint4_2_4 mat_val_1 = mat4[row * (n >> 3) + j]; + uint4_2_4 mat_val_2 = mat4[row * (n >> 3) + j + 1]; + const uint4_2* mat_h1 = (uint4_2*)&mat_val_1.x; + const uint4_2* mat_h2 = (uint4_2*)&mat_val_1.y; + const uint4_2* mat_h3 = (uint4_2*)&mat_val_1.z; + const uint4_2* mat_h4 = (uint4_2*)&mat_val_1.w; + const uint4_2* mat_h5 = (uint4_2*)&mat_val_2.x; + const uint4_2* mat_h6 = (uint4_2*)&mat_val_2.y; + const uint4_2* mat_h7 = (uint4_2*)&mat_val_2.z; + const uint4_2* mat_h8 = (uint4_2*)&mat_val_2.w; + + sum += static_cast(vec_h1->x) * + (static_cast(mat_h1->getX()) - zero_point_f); + sum += static_cast(vec_h1->y) * + (static_cast(mat_h1->getY()) - zero_point_f); + sum += static_cast(vec_h2->x) * + (static_cast(mat_h2->getX()) - zero_point_f); + sum += static_cast(vec_h2->y) * + (static_cast(mat_h2->getY()) - zero_point_f); + sum += static_cast(vec_h3->x) * + (static_cast(mat_h3->getX()) - zero_point_f); + sum += static_cast(vec_h3->y) * + (static_cast(mat_h3->getY()) - zero_point_f); + sum += static_cast(vec_h4->x) * + (static_cast(mat_h4->getX()) - zero_point_f); + sum += static_cast(vec_h4->y) * + (static_cast(mat_h4->getY()) - zero_point_f); + sum += static_cast(vec_h5->x) * + (static_cast(mat_h5->getX()) - zero_point_f); + sum += static_cast(vec_h5->y) * + (static_cast(mat_h5->getY()) - zero_point_f); + sum += static_cast(vec_h6->x) * + (static_cast(mat_h6->getX()) - zero_point_f); + sum += static_cast(vec_h6->y) * + (static_cast(mat_h6->getY()) - zero_point_f); + sum += static_cast(vec_h7->x) * + (static_cast(mat_h7->getX()) - zero_point_f); + sum += static_cast(vec_h7->y) * + (static_cast(mat_h7->getY()) - zero_point_f); + sum += static_cast(vec_h8->x) * + (static_cast(mat_h8->getX()) - zero_point_f); + sum += static_cast(vec_h8->y) * + (static_cast(mat_h8->getY()) - zero_point_f); + } + } + + sum *= scale_f; + + sum = warpReduceSum(sum, blockDim.x); + + if (blockDim.x <= WARP_SIZE) { + if (tid == 0) { + res[row] = __float2half(sum); + } + return; + } + + // Shared mem for partial sums (one per warp in the block) + static 
__shared__ float warpLevelSums[SHARED_MEM_MAX_ROWS][WARP_SIZE]; + const int laneId = threadIdx.x % WARP_SIZE; + const int warpId = threadIdx.x / WARP_SIZE; + if (laneId == 0) + warpLevelSums[threadIdx.y][warpId] = sum; + __syncthreads(); + // read from shared memory only if that warp existed + sum = (threadIdx.x < blockDim.x / WARP_SIZE) + ? warpLevelSums[threadIdx.y][laneId] + : 0.0; + // Final reduce using first warp + if (warpId == 0) + sum = warpReduceSum(sum, blockDim.x / WARP_SIZE); + if (tid == 0) { + res[row] = __float2half(sum); + } +} + +///////////////////////////// REDUCE SUM ////////////////////////////// + +__device__ __forceinline__ float warpReduceSum( + float sum, + unsigned int threadNum) { + if (threadNum >= 32) + sum += __shfl_down_sync(0xffffffff, sum, 16); // 0-16, 1-17, 2-18, etc. + if (threadNum >= 16) + sum += __shfl_down_sync(0xffffffff, sum, 8); // 0-8, 1-9, 2-10, etc. + if (threadNum >= 8) + sum += __shfl_down_sync(0xffffffff, sum, 4); // 0-4, 1-5, 2-6, etc. + if (threadNum >= 4) + sum += __shfl_down_sync(0xffffffff, sum, 2); // 0-2, 1-3, 4-6, 5-7, etc. + if (threadNum >= 2) + sum += __shfl_down_sync(0xffffffff, sum, 1); // 0-1, 2-3, 4-5, etc. + return sum; +} diff --git a/fbgemm_gpu/experimental/gen_ai/src/quantize/fast_gemv/include/fast_gemv.cuh b/fbgemm_gpu/experimental/gen_ai/src/quantize/fast_gemv/include/fast_gemv.cuh new file mode 100644 index 0000000000..f8aabcb52c --- /dev/null +++ b/fbgemm_gpu/experimental/gen_ai/src/quantize/fast_gemv/include/fast_gemv.cuh @@ -0,0 +1,58 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +/* + * This source code is pulled from the original github repo: + * https://github.com/wangsiping97/FastGEMV. + */ + +#ifndef FAST_GEMV_CUH_ +#define FAST_GEMV_CUH_ + +#include +#include +#include + +#include "utility.cuh" + +#define WARP_SIZE 32 +#define SHARED_MEM_MAX_ROWS 64 +#define MAX_THREADS_PER_BLOCK 1024 + +///////////////////////////// GEMV ////////////////////////////// +__global__ void gemv_fp16( + half* mat, + half* vec, + half* res, + unsigned int n, + unsigned int num_per_thread); + +__global__ void gemv_quantized_int8( + int8_t* mat, + half* vec, + half* res, + unsigned int n, + half scale, + half zero_point, + unsigned int num_per_thread); + +__global__ void gemv_quantized_int4( + uint4_2* mat, + half* vec, + half* res, + unsigned int n, + half scale, + half zero_point, + unsigned int num_per_thread); + +///////////////////////////// REDUCE SUM ////////////////////////////// +__device__ __forceinline__ float warpReduceSum( + float sum, + unsigned int threadNum); + +#endif // FAST_GEMV_CUH_ diff --git a/fbgemm_gpu/experimental/gen_ai/src/quantize/fast_gemv/include/utility.cuh b/fbgemm_gpu/experimental/gen_ai/src/quantize/fast_gemv/include/utility.cuh new file mode 100644 index 0000000000..2649621816 --- /dev/null +++ b/fbgemm_gpu/experimental/gen_ai/src/quantize/fast_gemv/include/utility.cuh @@ -0,0 +1,78 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +/* + * This source code is pulled from the original github repo: + * https://github.com/wangsiping97/FastGEMV. 
+ */ + +#ifndef UTILITY_H_ +#define UTILITY_H_ + +#include +#include +#include +#include + +#include + +///////////////////////////// DATA TYPES ////////////////////////////// + +struct uint4_2 { + uint8_t data; + + uint4_2(uint8_t x = 0, uint8_t y = 0) { + setX(x); + setY(y); + } + + __host__ __device__ uint8_t getX() const { + return data & 0x0F; // get the lower 4 bits + } + + __host__ __device__ uint8_t getY() const { + return (data >> 4) & 0x0F; // get the upper 4 bits + } + + __host__ __device__ void setX(uint8_t x) { + data = (data & 0xF0) | (x & 0x0F); // set the lower 4 bits + } + + __host__ __device__ void setY(uint8_t y) { + data = (data & 0x0F) | ((y & 0x0F) << 4); // set the upper 4 bits + } +}; + +struct half4 { + half x, y, z, w; +}; +struct int8_2 { + int8_t x, y; +}; +struct uint4_2_4 { + uint4_2 x, y, z, w; +}; + +///////////////////////////// CUDA UTILITIES ////////////////////////////// + +void print_cuda_info(); + +// Define the error checking function +#define checkCudaErrors(val) check((val), #val, __FILE__, __LINE__) + +void check( + cudaError_t result, + char const* const func, + const char* const file, + int const line); + +__global__ void generate_random_numbers(half* numbers, int Np); +__global__ void generate_random_int8_numbers(int8_t* numbers, int Np); +__global__ void generate_random_int4_numbers(uint4_2* numbers, int Np); + +#endif // UTILITY_H_ diff --git a/fbgemm_gpu/experimental/gen_ai/src/quantize/quantize.cpp b/fbgemm_gpu/experimental/gen_ai/src/quantize/quantize.cpp index da8dbc475c..85108bc9c2 100644 --- a/fbgemm_gpu/experimental/gen_ai/src/quantize/quantize.cpp +++ b/fbgemm_gpu/experimental/gen_ai/src/quantize/quantize.cpp @@ -123,6 +123,7 @@ at::Tensor f8f8bf16_cublas( std::optional Binvs = std::nullopt, bool use_fast_accum = true, std::optional output = std::nullopt); +at::Tensor fp16_fast_gemv(at::Tensor X, at::Tensor W); at::Tensor f8i4bf16_rowwise( at::Tensor XQ, at::Tensor WQ, @@ -186,6 +187,7 @@ TORCH_LIBRARY_FRAGMENT(fbgemm, m) { "f8f8bf16_cublas(Tensor A, Tensor B, Tensor? Ainvs=None, Tensor? Binvs=None, bool use_fast_accum=True, Tensor(a!)? output=None) -> Tensor"); m.def( "f8i4bf16_rowwise(Tensor XQ, Tensor WQ, Tensor x_scale, Tensor w_scale, Tensor w_zp) -> Tensor"); + m.def("fp16_fast_gemv(Tensor X, Tensor W) -> Tensor"); m.def( "f8f8bf16_grouped(Tensor[] XQ, Tensor[] WQ, Tensor[] scale, Tensor? 
zero_start_index_M=None, bool use_fast_accum=True) -> Tensor[]"); m.def("f8f8bf16_lite(Tensor XQ, Tensor WQ, Tensor scale) -> Tensor"); @@ -269,6 +271,7 @@ TORCH_LIBRARY_IMPL(fbgemm, CUDA, m) { m.impl("i8i8bf16", i8i8bf16); m.impl("f8f8bf16", f8f8bf16); m.impl("f8f8bf16_cublas", f8f8bf16_cublas); + m.impl("fp16_fast_gemv", fp16_fast_gemv); m.impl("f8f8bf16_grouped", f8f8bf16_grouped); m.impl("f8f8bf16_lite", f8f8bf16_lite); m.impl("f8i4bf16_rowwise", f8i4bf16_rowwise); @@ -297,6 +300,7 @@ TORCH_LIBRARY_IMPL(fbgemm, CPU, m) { m.impl("i8i8bf16", i8i8bf16); m.impl("f8f8bf16", f8f8bf16); m.impl("f8f8bf16_cublas", f8f8bf16_cublas); + m.impl("fp16_fast_gemv", fp16_fast_gemv); m.impl("f8f8bf16_grouped", f8f8bf16_grouped); m.impl("f8f8bf16_lite", f8f8bf16_lite); m.impl("f8i4bf16_rowwise", f8i4bf16_rowwise); @@ -408,6 +412,13 @@ at::Tensor f8f8bf16_meta( return Y; } +at::Tensor fp16_fast_gemv_meta(at::Tensor X, at::Tensor W) { + const at::SymInt M = X.sym_size(0); + const at::SymInt N = W.sym_size(0); + auto Y = at::empty_symint({M, N}, X.options().dtype(at::kHalf)); + return Y; +} + at::Tensor f8f8bf16_tensorwise_meta( at::Tensor X, at::Tensor W, @@ -539,6 +550,7 @@ TORCH_LIBRARY_IMPL(fbgemm, Meta, m) { m.impl("i8i8bf16", i8i8bf16_meta); m.impl("f8f8bf16", f8f8bf16_meta); m.impl("f8f8bf16_cublas", f8f8bf16_cublas_meta); + m.impl("fp16_fast_gemv", fp16_fast_gemv_meta); m.impl("f8f8bf16_rowwise_batched", f8f8bf16_rowwise_batched_meta); m.impl("f8i4bf16_rowwise", f8i4bf16_rowwise_meta); m.impl("bf16i4bf16_rowwise", bf16i4bf16_rowwise_meta); diff --git a/fbgemm_gpu/experimental/gen_ai/test/quantize/quantize_test.py b/fbgemm_gpu/experimental/gen_ai/test/quantize/quantize_test.py index a31ef126d5..6fd08ee3a6 100644 --- a/fbgemm_gpu/experimental/gen_ai/test/quantize/quantize_test.py +++ b/fbgemm_gpu/experimental/gen_ai/test/quantize/quantize_test.py @@ -1095,6 +1095,25 @@ def test_quantize_compile(self) -> None: block_scale[0], block_scale[0], ) + # test f16_fast_gemv is torch compileable + X_f16 = torch.randn(K, M, device="cuda", dtype=torch.float16) + W_f16 = torch.randn(N, K, device="cuda", dtype=torch.float16) + torch.compile(torch.ops.fbgemm.fp16_fast_gemv)(X_f16, W_f16) + + @unittest.skipIf( + not torch.version.cuda, "Skip on AMD: fast gemv op is not yet supported." + ) + def test_fp16_gemv(self) -> None: + M = 1 + N = 128 + K = 256 + x = torch.randn(size=(M, K), dtype=torch.half, device="cuda") + w = torch.randn(size=(N, K), dtype=torch.half, device="cuda") + + z = torch.ops.fbgemm.fp16_fast_gemv(x.T, w) + z_ref = (x @ w.T).to(torch.half).to("cuda") + + torch.testing.assert_close(z.T, z_ref, atol=1.0e-3, rtol=1.0e-3) @unittest.skipIf( torch.version.hip, "Skip on AMD: cuda quantize op is yet suported."