diff --git a/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise/fp8_rowwise_gemm.hip b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise/fp8_rowwise_gemm.hip index 05c90f139..23381213c 100644 --- a/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise/fp8_rowwise_gemm.hip +++ b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise/fp8_rowwise_gemm.hip @@ -67,11 +67,11 @@ static const std::unordered_map, RowwiseKernel, IntTup {{128, 7168, 8192}, fp8_rowwise_256x64x64x128_32x32_1x1_8x32x1_8x32x1_1x32x1x8_8x8x1_1x1_intrawave_v3}, {{1024, 7168, 8192}, - fp8_rowwise_256x128x128x128_32x32_2x2_8x32x1_8x32x1_1x32x1x8_8x8x1_1x1_intrawave_v5}, + fp8_rowwise_256x256x96x128_32x32_2x3_8x32x1_8x32x1_1x64x1x4_8x8x1_2x1_intrawave_v3}, {{2048, 7168, 8192}, - fp8_rowwise_256x256x224x128_16x16_8x7_8x32x1_8x32x1_1x64x1x4_8x8x1_2x1_intrawave_v3}, + fp8_rowwise_256x256x192x128_32x32_4x3_8x32x1_8x32x1_1x32x1x8_8x8x1_1x1_intrawave_v3}, {{4096, 7168, 8192}, - fp8_rowwise_256x224x256x128_16x16_7x8_8x32x1_8x32x1_1x32x1x8_8x8x1_1x2_intrawave_v3}, + fp8_rowwise_256x256x192x128_32x32_4x3_8x32x1_8x32x1_1x32x1x8_8x8x1_1x1_intrawave_v3}, {{8192, 7168, 8192}, fp8_rowwise_256x256x256x128_16x16_8x8_8x32x1_8x32x1_1x32x1x8_8x8x1_1x2_intrawave_v3}, // Support for decode across batch sizes for [8192, 3584] @@ -84,7 +84,7 @@ static const std::unordered_map, RowwiseKernel, IntTup {{128, 8192, 3584}, fp8_rowwise_256x64x64x128_32x32_1x1_8x32x1_8x32x1_1x32x1x8_8x8x1_1x1_intrawave_v3}, {{1024, 8192, 3584}, - fp8_rowwise_256x256x224x128_16x16_8x7_8x32x1_8x32x1_1x64x1x4_8x8x1_2x1_intrawave_v3}, + fp8_rowwise_256x256x128x128_32x32_4x2_8x32x1_8x32x1_1x32x1x8_8x8x1_1x1_intrawave_v3}, {{2048, 8192, 3584}, fp8_rowwise_256x256x224x128_16x16_8x7_8x32x1_8x32x1_1x64x1x4_8x8x1_2x1_intrawave_v3}, {{4096, 8192, 3584}, diff --git a/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise/kernels/fp8_rowwise_256x256x128x128_32x32_4x2_8x32x1_8x32x1_1x32x1x8_8x8x1_1x1_intrawave_v3.hip b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise/kernels/fp8_rowwise_256x256x128x128_32x32_4x2_8x32x1_8x32x1_1x32x1x8_8x8x1_1x1_intrawave_v3.hip new file mode 100644 index 000000000..9becb97bd --- /dev/null +++ b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise/kernels/fp8_rowwise_256x256x128x128_32x32_4x2_8x32x1_8x32x1_1x32x1x8_8x8x1_1x1_intrawave_v3.hip @@ -0,0 +1,70 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include "fp8_rowwise_common.h" + +at::Tensor +fp8_rowwise_256x256x128x128_32x32_4x2_8x32x1_8x32x1_1x32x1x8_8x8x1_1x1_intrawave_v3( + at::Tensor XQ, + at::Tensor WQ, + at::Tensor x_scale, + at::Tensor w_scale, + at::Tensor Y) { + // Check if this input needs to be padded. + int M = size_to_dim_(XQ.dim() - 1, XQ.sizes()); + int N = WQ.size(0); + int K = WQ.size(1); + bool pad = (K % 128 != 0); + + // Dispatch based on whether padding is needed or not. + if (pad) { + using DeviceGemmInstance = DeviceGemmHelper< + 256, + 256, + 128, + 128, + 32, + 32, + 4, + 2, + S<8, 32, 1>, + S<8, 32, 1>, + S<1, 32, 1, 8>, + S<8, 8, 1>, + 1, + 1, + ck::BlockGemmPipelineScheduler::Intrawave, + ck::BlockGemmPipelineVersion::v3, + ck::tensor_operation::device::GemmSpecialization::KPadding>; + // Run kernel instance. + return f8f8bf16_rowwise_impl( + XQ, WQ, x_scale, w_scale, Y); + } else { + using DeviceGemmInstance = DeviceGemmHelper< + 256, + 256, + 128, + 128, + 32, + 32, + 4, + 2, + S<8, 32, 1>, + S<8, 32, 1>, + S<1, 32, 1, 8>, + S<8, 8, 1>, + 1, + 1, + ck::BlockGemmPipelineScheduler::Intrawave, + ck::BlockGemmPipelineVersion::v3, + ck::tensor_operation::device::GemmSpecialization::Default>; + // Run kernel instance. + return f8f8bf16_rowwise_impl( + XQ, WQ, x_scale, w_scale, Y); + } +} diff --git a/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise/kernels/fp8_rowwise_256x256x96x128_32x32_2x3_8x32x1_8x32x1_1x64x1x4_8x8x1_2x1_intrawave_v3.hip b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise/kernels/fp8_rowwise_256x256x96x128_32x32_2x3_8x32x1_8x32x1_1x64x1x4_8x8x1_2x1_intrawave_v3.hip new file mode 100644 index 000000000..87fe51852 --- /dev/null +++ b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise/kernels/fp8_rowwise_256x256x96x128_32x32_2x3_8x32x1_8x32x1_1x64x1x4_8x8x1_2x1_intrawave_v3.hip @@ -0,0 +1,70 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include "fp8_rowwise_common.h" + +at::Tensor +fp8_rowwise_256x256x96x128_32x32_2x3_8x32x1_8x32x1_1x64x1x4_8x8x1_2x1_intrawave_v3( + at::Tensor XQ, + at::Tensor WQ, + at::Tensor x_scale, + at::Tensor w_scale, + at::Tensor Y) { + // Check if this input needs to be padded. + int M = size_to_dim_(XQ.dim() - 1, XQ.sizes()); + int N = WQ.size(0); + int K = WQ.size(1); + bool kpad = (K % 128 != 0); + + // Dispatch based on whether padding is needed or not. + if (kpad) { + using DeviceGemmInstance = DeviceGemmHelper< + 256, + 256, + 96, + 128, + 32, + 32, + 2, + 3, + S<8, 32, 1>, + S<8, 32, 1>, + S<1, 64, 1, 4>, + S<8, 8, 1>, + 2, + 1, + ck::BlockGemmPipelineScheduler::Intrawave, + ck::BlockGemmPipelineVersion::v3, + ck::tensor_operation::device::GemmSpecialization::KPadding>; + // Run kernel instance. + return f8f8bf16_rowwise_impl( + XQ, WQ, x_scale, w_scale, Y); + } else { + using DeviceGemmInstance = DeviceGemmHelper< + 256, + 256, + 96, + 128, + 32, + 32, + 2, + 3, + S<8, 32, 1>, + S<8, 32, 1>, + S<1, 64, 1, 4>, + S<8, 8, 1>, + 2, + 1, + ck::BlockGemmPipelineScheduler::Intrawave, + ck::BlockGemmPipelineVersion::v3, + ck::tensor_operation::device::GemmSpecialization::Default>; + // Run kernel instance. + return f8f8bf16_rowwise_impl( + XQ, WQ, x_scale, w_scale, Y); + } +} diff --git a/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise/kernels/fp8_rowwise_kernel_manifest.h b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise/kernels/fp8_rowwise_kernel_manifest.h index ebfdccd00..7fb17dd17 100644 --- a/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise/kernels/fp8_rowwise_kernel_manifest.h +++ b/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise/kernels/fp8_rowwise_kernel_manifest.h @@ -289,3 +289,19 @@ fp8_rowwise_256x192x256x128_16x16_6x8_8x32x1_8x32x1_1x32x1x8_8x8x1_2x2_intrawave at::Tensor x_scale, at::Tensor w_scale, at::Tensor Y); + +at::Tensor +fp8_rowwise_256x256x96x128_32x32_2x3_8x32x1_8x32x1_1x64x1x4_8x8x1_2x1_intrawave_v3( + at::Tensor XQ, + at::Tensor WQ, + at::Tensor x_scale, + at::Tensor w_scale, + at::Tensor Y); + +at::Tensor +fp8_rowwise_256x256x128x128_32x32_4x2_8x32x1_8x32x1_1x32x1x8_8x8x1_1x1_intrawave_v3( + at::Tensor XQ, + at::Tensor WQ, + at::Tensor x_scale, + at::Tensor w_scale, + at::Tensor Y);