diff --git a/fbgemm_gpu/experimental/gen_ai/bench/quantize_ops.py b/fbgemm_gpu/experimental/gen_ai/bench/quantize_ops.py
index c8e46e9f88..ab1d28e026 100644
--- a/fbgemm_gpu/experimental/gen_ai/bench/quantize_ops.py
+++ b/fbgemm_gpu/experimental/gen_ai/bench/quantize_ops.py
@@ -716,40 +716,6 @@ def cuda(self) -> bool:
         return True
 
 
-@register_quantize_op
-class FP8LiteGemm(QuantizeOpBase):
-    """
-    FP8 lite matmul for memory bound.
-    """
-
-    def quantize(self, x, w):
-        # Quantize both input tensors.
-        xq, x_scale = torch.ops.fbgemm.quantize_fp8_per_tensor(x)
-        wq, w_scale = torch.ops.fbgemm.quantize_fp8_per_tensor(w)
-        return xq, wq, x_scale, w_scale
-
-    def compute(self, xq, wq, x_scale, w_scale):
-        return torch.ops.fbgemm.f8f8bf16_lite(xq, wq, x_scale * w_scale)
-
-    def quantize_and_compute(self, x, w):
-        xq, wq, x_scale, w_scale = self.quantize(x, w)
-        return self.compute(xq, wq, x_scale * w_scale)
-
-    @property
-    def name(self) -> str:
-        return "cuda_lite"
-
-    @property
-    def hip(self) -> bool:
-        # Need to add support for better quantize kernel.
-        # Also may have an issue with cuda graphs.
-        return False
-
-    @property
-    def cuda(self) -> bool:
-        return True
-
-
 @register_quantize_op
 class TritonFP8RowwiseGemm(QuantizeOpBase):
     """
diff --git a/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/f8f8bf16_lite.cu b/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/f8f8bf16_lite.cu
deleted file mode 100644
index 95935217a8..0000000000
--- a/fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/f8f8bf16_lite.cu
+++ /dev/null
@@ -1,263 +0,0 @@
-/*
- * Copyright (c) 2022-2024, NVIDIA CORPORATION. All rights reserved.
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- *     http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#include <ATen/ATen.h>
-#include <ATen/cuda/CUDAContext.h>
-#include <cub/cub.cuh>
-#include <cutlass/numeric_conversion.h>
-
-namespace fbgemm_gpu {
-
-#if CUDART_VERSION >= 12000
-
-using SizeType32 = std::size_t;
-
-struct Params {
-  void const* act;
-  void const* weight;
-  void const* alpha;
-  void* output;
-  SizeType32 m, n, k;
-
-  Params(
-      void const* _act,
-      void const* _weight,
-      void const* _alpha,
-      void* _output,
-      SizeType32 _m,
-      SizeType32 _n,
-      SizeType32 _k)
-      : act(_act),
-        weight(_weight),
-        alpha(_alpha),
-        output(_output),
-        m(_m),
-        n(_n),
-        k(_k) {}
-};
-
-template <
-    typename InputType,
-    typename OutputType,
-    SizeType32 TILE_M,
-    SizeType32 TILE_N,
-    SizeType32 BLOCK_SIZE>
-__global__ void cudaCoreGemm(
-    InputType const* __restrict__ act,
-    InputType const* __restrict__ weight,
-    float const* alpha,
-    OutputType* __restrict__ output,
-    SizeType32 m,
-    SizeType32 n,
-    SizeType32 k) {
-  using VecType = int4;
-  static constexpr SizeType32 kStepK =
-      static_cast<SizeType32>(128 / (8 * sizeof(InputType)));
-  static constexpr SizeType32 kTileK = kStepK * BLOCK_SIZE;
-  auto tileIdM = static_cast<SizeType32>(blockIdx.x * TILE_M);
-  auto tileIdN = static_cast<SizeType32>(blockIdx.y * TILE_N);
-  auto tid = static_cast<SizeType32>(threadIdx.x);
-  float tile_a[kStepK], tile_w[TILE_N * kStepK];
-  float acc[TILE_M * TILE_N];
-
-  static_assert(kStepK % 4 == 0);
-  using CvtInputType = cutlass::float_e4m3_t;
-  using Converter = cutlass::NumericArrayConverter<float, CvtInputType, 4>;
-  using CvtSrcType = typename Converter::source_type;
-  using CvtResType = typename Converter::result_type;
-  static constexpr SizeType32 kCvtCount =
-      static_cast<SizeType32>(sizeof(VecType) / sizeof(CvtSrcType));
-
-#pragma unroll
-  for (SizeType32 i = 0; i < TILE_M * TILE_N; ++i) {
-    acc[i] = 0;
-  }
-  act += tileIdM * k;
-  weight += tileIdN * k;
-  output += tileIdM * n + tileIdN;
-
-#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
-  cudaGridDependencySynchronize();
-#endif
-
-  for (SizeType32 idxK = tid * kStepK; idxK < k; idxK += kTileK) {
-    for (SizeType32 i = 0; i < TILE_N; ++i) {
-      auto tile_w_quantized =
-          reinterpret_cast<VecType const*>(weight + i * k + idxK)[0];
-#pragma unroll
-      for (SizeType32 cvtIdx = 0; cvtIdx < kCvtCount; ++cvtIdx) {
-        reinterpret_cast<CvtResType*>(tile_w)[i * kCvtCount + cvtIdx] =
-            Converter::convert(
-                reinterpret_cast<CvtSrcType*>(&tile_w_quantized)[cvtIdx]);
-      }
-    }
-#pragma unroll
-    for (SizeType32 i = 0; i < TILE_M; ++i) {
-      auto tile_a_quantized =
-          reinterpret_cast<VecType const*>(act + i * k + idxK)[0];
-#pragma unroll
-      for (SizeType32 cvtIdx = 0; cvtIdx < kCvtCount; ++cvtIdx) {
-        reinterpret_cast<CvtResType*>(tile_a)[cvtIdx] = Converter::convert(
-            reinterpret_cast<CvtSrcType*>(&tile_a_quantized)[cvtIdx]);
-      }
-#pragma unroll
-      for (SizeType32 j = 0; j < TILE_N; ++j) {
-#pragma unroll
-        for (SizeType32 l = 0; l < kStepK; ++l) {
-          acc[i * TILE_N + j] =
-              fma(tile_a[l], tile_w[j * kStepK + l], acc[i * TILE_N + j]);
-        }
-      }
-    }
-  }
-
-  typedef cub::WarpReduce<float> WarpReduce;
-
-  static constexpr SizeType32 kWarpSize = 32;
-  static constexpr SizeType32 kWarpNum = BLOCK_SIZE / kWarpSize;
-  SizeType32 warpId = tid / kWarpSize, laneId = tid % kWarpSize;
-  __shared__ float shmem[TILE_M * TILE_N * kWarpNum];
-  __shared__ typename WarpReduce::TempStorage tempStorage[kWarpNum];
-#pragma unroll
-  for (SizeType32 mi = 0; mi < TILE_M; ++mi) {
-#pragma unroll
-    for (SizeType32 ni = 0; ni < TILE_N; ++ni) {
-      float val = WarpReduce(tempStorage[warpId]).Sum(acc[mi * TILE_N + ni]);
-      if (laneId == 0) {
-        shmem[mi * TILE_N + ni + warpId * TILE_M * TILE_N] = val;
-      }
-    }
-  }
-  __syncthreads();
-  for (SizeType32 ii = tid; ii < TILE_M * TILE_N; ii += BLOCK_SIZE) {
-    SizeType32 mid = ii / TILE_N, nid = ii % TILE_N;
-    float val = 0;
-#pragma unroll
-    for (SizeType32 jj = 0; jj < kWarpNum; ++jj) {
-      val += shmem[jj * TILE_M * TILE_N + ii];
-    }
-    output[mid * n + nid] = static_cast<OutputType>(val * *alpha);
-  }
-
-#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
-  cudaTriggerProgrammaticLaunchCompletion();
-#endif
-}
-
-template <
-    typename InputType,
-    typename OutputType,
-    SizeType32 TILE_M,
-    SizeType32 TILE_N,
-    SizeType32 BLOCK_SIZE>
-void cudaCoreGemmKernel(Params const& params, cudaStream_t stream) {
-  dim3 block(BLOCK_SIZE);
-  dim3 grid(params.m / TILE_M, params.n / TILE_N);
-
-  cudaCoreGemm<InputType, OutputType, TILE_M, TILE_N, BLOCK_SIZE>
-      <<<grid, block, 0, stream>>>(
-          reinterpret_cast<InputType const*>(params.act),
-          reinterpret_cast<InputType const*>(params.weight),
-          reinterpret_cast<float const*>(params.alpha),
-          reinterpret_cast<OutputType*>(params.output),
-          params.m,
-          params.n,
-          params.k);
-  C10_CUDA_KERNEL_LAUNCH_CHECK();
-}
-
-template <
-    typename InputType,
-    typename OutputType,
-    int TILE_M,
-    int TILE_N,
-    int BLOCK_SIZE>
-bool cudaCoreGemmTemplateCaller(Params const& params, cudaStream_t stream) {
-  constexpr int cudaCoreGemmTemplateMaxM = 128;
-  if (params.m == TILE_M) {
-    cudaCoreGemmKernel<InputType, OutputType, TILE_M, TILE_N, BLOCK_SIZE>(
-        params, stream);
-    return true;
-  }
-  if constexpr (TILE_M < cudaCoreGemmTemplateMaxM) {
-    return cudaCoreGemmTemplateCaller<
-        InputType,
-        OutputType,
-        TILE_M + 1,
-        TILE_N,
-        BLOCK_SIZE>(params, stream);
-  }
-  return false;
-}
-
-template <typename InputType, typename OutputType>
-bool cudaCoreGemmLauncher(Params const& params, cudaStream_t stream) {
-  return cudaCoreGemmTemplateCaller<InputType, OutputType, 1, 2, 128>(
-      params, stream);
-}
-
-at::Tensor f8f8bf16_lite(
-    at::Tensor XQ, // FP8
-    at::Tensor WQ, // FP8
-    at::Tensor scale) {
-  bool dispatched = true;
-  int M = size_to_dim_(XQ.dim() - 1, XQ.sizes());
-  int N = WQ.size(0);
-  int K = WQ.size(1);
-  TORCH_CHECK(XQ.size(-1) == K);
-
-  if (M > 128) {
-    throw std::runtime_error("f8f8bf16_lite cannot run when M > 128");
-  } else if (N % 2 != 0) {
-    throw std::runtime_error("f8f8bf16_lite cannot run when N % 2 != 0");
-  } else if (K % 16 != 0) {
-    throw std::runtime_error("f8f8bf16_lite cannot run when K % 16 != 0");
-  }
-
-  auto out_sizes = XQ.sizes().vec();
-  out_sizes.back() = N;
-  at::Tensor Y = at::empty(out_sizes, XQ.options().dtype(at::kBFloat16));
-
-  Params params{
-      XQ.data_ptr(),
-      WQ.data_ptr(),
-      scale.data_ptr(),
-      Y.data_ptr(),
-      (SizeType32)M,
-      (SizeType32)N,
-      (SizeType32)K};
-  cudaStream_t stream = at::cuda::getCurrentCUDAStream();
-  dispatched = cudaCoreGemmLauncher<cutlass::float_e4m3_t, __nv_bfloat16>(
-      params, stream);
-  if (!dispatched) {
-    throw std::runtime_error("f8f8bf16_lite cannot run");
-  }
-  return Y;
-}
-
-#else
-
-at::Tensor f8f8bf16_lite(
-    at::Tensor XQ, // FP8
-    at::Tensor WQ, // FP8
-    at::Tensor scale) {
-  throw std::runtime_error(
-      "CUDA version is older than 12.0"); // requires CUDA>=12
-}
-
-#endif
-
-} // namespace fbgemm_gpu
diff --git a/fbgemm_gpu/experimental/gen_ai/src/quantize/quantize.cpp b/fbgemm_gpu/experimental/gen_ai/src/quantize/quantize.cpp
index a2c7429532..34ea4ec427 100644
--- a/fbgemm_gpu/experimental/gen_ai/src/quantize/quantize.cpp
+++ b/fbgemm_gpu/experimental/gen_ai/src/quantize/quantize.cpp
@@ -55,7 +55,6 @@ at::Tensor f8f8bf16_tensorwise(
     at::Tensor WQ,
     double scale,
     bool use_fast_accum = true);
-at::Tensor f8f8bf16_lite(at::Tensor XQ, at::Tensor WQ, at::Tensor scale);
 std::vector<at::Tensor> f8f8bf16_grouped(
     at::TensorList XQ,
     at::TensorList WQ,
@@ -188,7 +187,6 @@ TORCH_LIBRARY_FRAGMENT(fbgemm, m) {
       "f8i4bf16_rowwise(Tensor XQ, Tensor WQ, Tensor x_scale, Tensor w_scale, Tensor w_zp) -> Tensor");
   m.def(
       "f8f8bf16_grouped(Tensor[] XQ, Tensor[] WQ, Tensor[] scale, Tensor? zero_start_index_M=None, bool use_fast_accum=True) -> Tensor[]");
-  m.def("f8f8bf16_lite(Tensor XQ, Tensor WQ, Tensor scale) -> Tensor");
   m.def(
       "bf16i4bf16_rowwise(Tensor X, Tensor WQ, Tensor w_scale, Tensor w_zp) -> Tensor");
   m.def(
@@ -270,7 +268,6 @@ TORCH_LIBRARY_IMPL(fbgemm, CUDA, m) {
   m.impl("f8f8bf16", f8f8bf16);
   m.impl("f8f8bf16_cublas", f8f8bf16_cublas);
   m.impl("f8f8bf16_grouped", f8f8bf16_grouped);
-  m.impl("f8f8bf16_lite", f8f8bf16_lite);
   m.impl("f8i4bf16_rowwise", f8i4bf16_rowwise);
   m.impl("bf16i4bf16_rowwise_batched", bf16i4bf16_rowwise_batched);
   m.impl("bf16i4bf16_rowwise", bf16i4bf16_rowwise);
@@ -298,7 +295,6 @@ TORCH_LIBRARY_IMPL(fbgemm, CPU, m) {
   m.impl("f8f8bf16", f8f8bf16);
   m.impl("f8f8bf16_cublas", f8f8bf16_cublas);
   m.impl("f8f8bf16_grouped", f8f8bf16_grouped);
-  m.impl("f8f8bf16_lite", f8f8bf16_lite);
   m.impl("f8i4bf16_rowwise", f8i4bf16_rowwise);
   m.impl("bf16i4bf16_rowwise_batched", bf16i4bf16_rowwise_batched);
   m.impl("bf16i4bf16_rowwise", bf16i4bf16_rowwise);
@@ -419,13 +415,6 @@ at::Tensor f8f8bf16_tensorwise_meta(
   return Y;
 }
 
-at::Tensor f8f8bf16_lite_meta(at::Tensor X, at::Tensor W, at::Tensor scale) {
-  const at::SymInt M = X.sym_size(0);
-  const at::SymInt N = W.sym_size(0);
-  auto Y = at::empty_symint({M, N}, X.options().dtype(at::kBFloat16));
-  return Y;
-}
-
 at::Tensor f8i4bf16_rowwise_meta(
     at::Tensor XQ, // FP8
     at::Tensor WQ, // INT4
@@ -544,7 +533,6 @@ TORCH_LIBRARY_IMPL(fbgemm, Meta, m) {
   m.impl("bf16i4bf16_rowwise", bf16i4bf16_rowwise_meta);
   m.impl("bf16i4bf16_rowwise_batched", bf16i4bf16_rowwise_batched_meta);
   m.impl("f8f8bf16_grouped", f8f8bf16_grouped_meta);
-  m.impl("f8f8bf16_lite", f8f8bf16_lite_meta);
 #endif
 }
 
diff --git a/fbgemm_gpu/experimental/gen_ai/test/quantize/quantize_test.py b/fbgemm_gpu/experimental/gen_ai/test/quantize/quantize_test.py
index f7197e1710..96d59339eb 100644
--- a/fbgemm_gpu/experimental/gen_ai/test/quantize/quantize_test.py
+++ b/fbgemm_gpu/experimental/gen_ai/test/quantize/quantize_test.py
@@ -1110,30 +1110,6 @@ def test_quantize_zero_input(self, K) -> None:
         torch.testing.assert_close(w.shape, wq.shape)
         torch.testing.assert_close(w_scale.shape, w_scale_ref.shape)
 
-    @unittest.skipIf(torch.version.hip, "Skip on AMD: fp8 lite op is yet suported.")
-    @settings(deadline=None)
-    @given(
-        M=st.sampled_from([1, 5, 16]),
-        N=st.sampled_from([1024, 6144]),
-        K=st.sampled_from([512, 3584]),
-        CudaGraph=st.sampled_from([True, False]),
-    )
-    def test_fp8_lite_matmul(self, M: int, N: int, K: int, CudaGraph: bool) -> None:
-        x = torch.randn(size=(M, K), dtype=torch.bfloat16, device="cuda") * 0.1
-        w = torch.randn(size=(N, K), dtype=torch.bfloat16, device="cuda") * 0.01
-        xq, x_scale = torch.ops.fbgemm.quantize_fp8_per_tensor(x)
-        wq, w_scale = torch.ops.fbgemm.quantize_fp8_per_tensor(w)
-        if CudaGraph:
-            zq = torch.ops.fbgemm.f8f8bf16_lite(xq, wq, x_scale * w_scale)
-            g = torch.cuda.CUDAGraph()
-            with torch.cuda.graph(g):
-                zq = torch.ops.fbgemm.f8f8bf16_lite(xq, wq, x_scale * w_scale)
-            g.replay()
-        else:
-            zq = torch.ops.fbgemm.f8f8bf16_lite(xq, wq, x_scale * w_scale)
-        zq_ref = (x @ w.T).to(torch.bfloat16)
-        torch.testing.assert_close(zq, zq_ref, atol=9.0e-2, rtol=9.0e-2)
-
 
 if __name__ == "__main__":
     unittest.main()
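Note: for callers still depending on the removed torch.ops.fbgemm.f8f8bf16_lite op, the sketch below restates the computation the deleted kernel performed (convert per-tensor FP8 inputs to fp32, accumulate in fp32, apply the single combined scale, cast to bf16), mirroring the usage in the deleted test_fp8_lite_matmul. The helper name f8f8bf16_lite_reference is hypothetical, not part of fbgemm_gpu, and the dequantize-then-matmul is an illustrative equivalent only, not a performance substitute for the deleted CUDA kernel.

    # Hypothetical reference helper (not part of fbgemm_gpu); assumes a PyTorch
    # build with float8_e4m3fn support so that .to(torch.float32) works on the
    # tensors returned by torch.ops.fbgemm.quantize_fp8_per_tensor.
    import torch

    def f8f8bf16_lite_reference(
        xq: torch.Tensor,  # [M, K], FP8 e4m3
        wq: torch.Tensor,  # [N, K], FP8 e4m3
        scale: torch.Tensor,  # scalar tensor, x_scale * w_scale
    ) -> torch.Tensor:
        # Dequantize to fp32, matmul, apply the combined scale, cast to bf16.
        return (xq.to(torch.float32) @ wq.to(torch.float32).T * scale).to(
            torch.bfloat16
        )

    # Usage, mirroring the deleted test:
    #   xq, x_scale = torch.ops.fbgemm.quantize_fp8_per_tensor(x)
    #   wq, w_scale = torch.ops.fbgemm.quantize_fp8_per_tensor(w)
    #   y = f8f8bf16_lite_reference(xq, wq, x_scale * w_scale)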