Skip to content

Commit

Permalink
Cherry-pick CK PR #1636 for fp8 GEMM rowwise for 70B Prefill (#3517)
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: #3517

X-link: facebookresearch/FBGEMM#598

- Cherry-pick from CK PR ROCm/composable_kernel#1636
- Improve fp8 GEMM rowwise for 70B Prefill with seqlen = 1k/2k

Reviewed By: xw285cornell, jwfromm

Differential Revision: D67418190

fbshipit-source-id: b6d38715b26d91d6047d03941610fa7e20e54cb7
  • Loading branch information
zjing14 authored and facebook-github-bot committed Jan 4, 2025
1 parent 7c26339 commit dd70f61
Show file tree
Hide file tree
Showing 4 changed files with 160 additions and 4 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -67,11 +67,11 @@ static const std::unordered_map<std::tuple<int, int, int>, RowwiseKernel, IntTup
{{128, 7168, 8192},
fp8_rowwise_256x64x64x128_32x32_1x1_8x32x1_8x32x1_1x32x1x8_8x8x1_1x1_intrawave_v3},
{{1024, 7168, 8192},
fp8_rowwise_256x128x128x128_32x32_2x2_8x32x1_8x32x1_1x32x1x8_8x8x1_1x1_intrawave_v5},
fp8_rowwise_256x256x96x128_32x32_2x3_8x32x1_8x32x1_1x64x1x4_8x8x1_2x1_intrawave_v3},
{{2048, 7168, 8192},
fp8_rowwise_256x256x224x128_16x16_8x7_8x32x1_8x32x1_1x64x1x4_8x8x1_2x1_intrawave_v3},
fp8_rowwise_256x256x192x128_32x32_4x3_8x32x1_8x32x1_1x32x1x8_8x8x1_1x1_intrawave_v3},
{{4096, 7168, 8192},
fp8_rowwise_256x224x256x128_16x16_7x8_8x32x1_8x32x1_1x32x1x8_8x8x1_1x2_intrawave_v3},
fp8_rowwise_256x256x192x128_32x32_4x3_8x32x1_8x32x1_1x32x1x8_8x8x1_1x1_intrawave_v3},
{{8192, 7168, 8192},
fp8_rowwise_256x256x256x128_16x16_8x8_8x32x1_8x32x1_1x32x1x8_8x8x1_1x2_intrawave_v3},
// Support for decode across batch sizes for [8192, 3584]
Expand All @@ -84,7 +84,7 @@ static const std::unordered_map<std::tuple<int, int, int>, RowwiseKernel, IntTup
{{128, 8192, 3584},
fp8_rowwise_256x64x64x128_32x32_1x1_8x32x1_8x32x1_1x32x1x8_8x8x1_1x1_intrawave_v3},
{{1024, 8192, 3584},
fp8_rowwise_256x256x224x128_16x16_8x7_8x32x1_8x32x1_1x64x1x4_8x8x1_2x1_intrawave_v3},
fp8_rowwise_256x256x128x128_32x32_4x2_8x32x1_8x32x1_1x32x1x8_8x8x1_1x1_intrawave_v3},
{{2048, 8192, 3584},
fp8_rowwise_256x256x224x128_16x16_8x7_8x32x1_8x32x1_1x64x1x4_8x8x1_2x1_intrawave_v3},
{{4096, 8192, 3584},
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
/*
* Copyright (c) Meta Platforms, Inc. and affiliates.
* All rights reserved.
*
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree.
*/

#include "fp8_rowwise_common.h"

at::Tensor
fp8_rowwise_256x256x128x128_32x32_4x2_8x32x1_8x32x1_1x32x1x8_8x8x1_1x1_intrawave_v3(
at::Tensor XQ,
at::Tensor WQ,
at::Tensor x_scale,
at::Tensor w_scale,
at::Tensor Y) {
// Check if this input needs to be padded.
int M = size_to_dim_(XQ.dim() - 1, XQ.sizes());
int N = WQ.size(0);
int K = WQ.size(1);
bool pad = (K % 128 != 0);

// Dispatch based on whether padding is needed or not.
if (pad) {
using DeviceGemmInstance = DeviceGemmHelper<
256,
256,
128,
128,
32,
32,
4,
2,
S<8, 32, 1>,
S<8, 32, 1>,
S<1, 32, 1, 8>,
S<8, 8, 1>,
1,
1,
ck::BlockGemmPipelineScheduler::Intrawave,
ck::BlockGemmPipelineVersion::v3,
ck::tensor_operation::device::GemmSpecialization::KPadding>;
// Run kernel instance.
return f8f8bf16_rowwise_impl<DeviceGemmInstance>(
XQ, WQ, x_scale, w_scale, Y);
} else {
using DeviceGemmInstance = DeviceGemmHelper<
256,
256,
128,
128,
32,
32,
4,
2,
S<8, 32, 1>,
S<8, 32, 1>,
S<1, 32, 1, 8>,
S<8, 8, 1>,
1,
1,
ck::BlockGemmPipelineScheduler::Intrawave,
ck::BlockGemmPipelineVersion::v3,
ck::tensor_operation::device::GemmSpecialization::Default>;
// Run kernel instance.
return f8f8bf16_rowwise_impl<DeviceGemmInstance>(
XQ, WQ, x_scale, w_scale, Y);
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
/*
* Copyright (c) Meta Platforms, Inc. and affiliates.
* All rights reserved.
*
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree.
*/

#include "fp8_rowwise_common.h"

at::Tensor
fp8_rowwise_256x256x96x128_32x32_2x3_8x32x1_8x32x1_1x64x1x4_8x8x1_2x1_intrawave_v3(
at::Tensor XQ,
at::Tensor WQ,
at::Tensor x_scale,
at::Tensor w_scale,
at::Tensor Y) {
// Check if this input needs to be padded.
int M = size_to_dim_(XQ.dim() - 1, XQ.sizes());
int N = WQ.size(0);
int K = WQ.size(1);
bool kpad = (K % 128 != 0);

// Dispatch based on whether padding is needed or not.
if (kpad) {
using DeviceGemmInstance = DeviceGemmHelper<
256,
256,
96,
128,
32,
32,
2,
3,
S<8, 32, 1>,
S<8, 32, 1>,
S<1, 64, 1, 4>,
S<8, 8, 1>,
2,
1,
ck::BlockGemmPipelineScheduler::Intrawave,
ck::BlockGemmPipelineVersion::v3,
ck::tensor_operation::device::GemmSpecialization::KPadding>;
// Run kernel instance.
return f8f8bf16_rowwise_impl<DeviceGemmInstance>(
XQ, WQ, x_scale, w_scale, Y);
} else {
using DeviceGemmInstance = DeviceGemmHelper<
256,
256,
96,
128,
32,
32,
2,
3,
S<8, 32, 1>,
S<8, 32, 1>,
S<1, 64, 1, 4>,
S<8, 8, 1>,
2,
1,
ck::BlockGemmPipelineScheduler::Intrawave,
ck::BlockGemmPipelineVersion::v3,
ck::tensor_operation::device::GemmSpecialization::Default>;
// Run kernel instance.
return f8f8bf16_rowwise_impl<DeviceGemmInstance>(
XQ, WQ, x_scale, w_scale, Y);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -289,3 +289,19 @@ fp8_rowwise_256x192x256x128_16x16_6x8_8x32x1_8x32x1_1x32x1x8_8x8x1_2x2_intrawave
at::Tensor x_scale,
at::Tensor w_scale,
at::Tensor Y);

at::Tensor
fp8_rowwise_256x256x96x128_32x32_2x3_8x32x1_8x32x1_1x64x1x4_8x8x1_2x1_intrawave_v3(
at::Tensor XQ,
at::Tensor WQ,
at::Tensor x_scale,
at::Tensor w_scale,
at::Tensor Y);

at::Tensor
fp8_rowwise_256x256x128x128_32x32_4x2_8x32x1_8x32x1_1x32x1x8_8x8x1_1x1_intrawave_v3(
at::Tensor XQ,
at::Tensor WQ,
at::Tensor x_scale,
at::Tensor w_scale,
at::Tensor Y);

0 comments on commit dd70f61

Please sign in to comment.