linear_int4_kernel for XPU (#1130)
Pure SYCL path for int4 GEMM.

Benchmark results on PVC-1100 are shown below. The remaining gap to peak comes from not yet using 2D block loads.

| M | K     | N     | SrcT    | WeiT    | DstT    | Bandwidth usage |
|---|-------|-------|---------|---------|---------|-----------------|
| 1 | 4096  | 4096  | float16 | float16 | float16 | 53.7%           |
| 1 | 4096  | 11008 | float16 | float16 | float16 | 57.4%           |
| 1 | 4096  | 16384 | float16 | float16 | float16 | 59.7%           |
| 1 | 12288 | 4096  | float16 | float16 | float16 | 77.3%           |



Besides PVC, the kernel achieves:

- 92.7% bandwidth usage on MTL
- 84.7% bandwidth usage on A750
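
For context, here is a minimal sketch of how this path can be exercised from ATen C++, assuming the XPU wrapper below is registered to the existing `aten::_weight_int4pack_mm` overload (the wrapper name and its device-check strings suggest this, but the registration itself is not shown in this diff). Shapes follow the checks in `LinearInt4.cpp`: `A` is `[M, K]`, `B` is a packed int4 weight with `N` rows, and `qScaleAndZeros` is `[N, K / qGroupSize, 2]`; the packed weights and scales below are placeholders, not real quantized data.

```cpp
#include <ATen/ATen.h>

// Hypothetical driver matching the first benchmark row (M = 1, K = 4096, N = 4096).
// A real caller would fill B and qScaleAndZeros with properly packed weights.
int main() {
  const int64_t M = 1, K = 4096, N = 4096, qGroupSize = 32;
  auto opts = at::TensorOptions().dtype(at::kHalf).device(at::kXPU);

  at::Tensor A = at::ones({M, K}, opts);                        // fp16 activation
  at::Tensor B = at::zeros({N, K / 2}, opts.dtype(at::kByte));  // two int4 per byte
  at::Tensor qScaleAndZeros =
      at::ones({N, K / qGroupSize, 2}, opts);                   // (scale, zero) per group

  // Expected to dispatch to _weight_int4pack_mm_xpu -> linear_int4_kernel.
  at::Tensor C = at::_weight_int4pack_mm(A, B, qGroupSize, qScaleAndZeros);
  return (C.size(0) == M && C.size(1) == N) ? 0 : 1;
}
```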

---------

Co-authored-by: Yutao Xu <[email protected]>
Co-authored-by: mengfei25 <[email protected]>
Co-authored-by: LuFengqing <[email protected]>
Co-authored-by: Ratnam Parikh <[email protected]>
Co-authored-by: Feng Yuan <[email protected]>
Co-authored-by: Yu, Guangye <[email protected]>
Co-authored-by: ZhiweiYan-96 <[email protected]>
Co-authored-by: Meng, Hengyu <[email protected]>
9 people authored Jan 6, 2025
1 parent ad8f244 commit d4432d0
Showing 5 changed files with 438 additions and 0 deletions.
63 changes: 63 additions & 0 deletions src/ATen/native/xpu/LinearInt4.cpp
@@ -0,0 +1,63 @@

#include <ATen/core/op_registration/adaption.h>
#include <ATen/div_rtn.h>
#include <ATen/native/TensorIterator.h>
#include <torch/library.h>

#include <ATen/native/xpu/sycl/LinearInt4.h>
#include <comm/xpu_aten.h>

namespace at::native {

// A: [M, K] activation (fp32/fp16/bf16); B: packed int4 weight with N rows
// (two int4 per byte, or eight per int32); qScaleAndZeros: [N, K / qGroupSize, 2]
// holding a (scale, zero_point) pair per quantization group.
Tensor _weight_int4pack_mm_xpu(
    const Tensor& A,
    const Tensor& B,
    int64_t qGroupSize,
    const Tensor& qScaleAndZeros) {
  auto M = A.size(0);
  auto N = B.size(0);
  auto K = A.size(1);
  TORCH_CHECK(
      A.dtype() == kBFloat16 || A.dtype() == kHalf || A.dtype() == kFloat,
      __func__,
      " : expect A to be either 32-bit or 16-bit float tensor.");
  TORCH_CHECK(A.is_contiguous(), __func__, " : expect A to be contiguous.");
  TORCH_CHECK(A.dim() == 2, __func__, " : expect A to be 2D tensor.");

  TORCH_CHECK(
      B.dtype() == kInt || B.dtype() == kUInt32 || B.dtype() == kByte,
      __func__,
      " : expect B to be int32, uint32 or uint8 tensor.");
  TORCH_CHECK(B.is_contiguous(), __func__, " : expect B to be contiguous.");
  TORCH_CHECK(B.dim() == 2, __func__, " : expect B to be 2D tensor.");

  TORCH_CHECK(
      qGroupSize == 32 || qGroupSize == 64 || qGroupSize == 128 ||
          qGroupSize == 256,
      __func__,
      ": expect qGroupSize to be 32, 64, 128 or 256, got ",
      qGroupSize);

  TORCH_CHECK(
      qScaleAndZeros.dim() == 3 && qScaleAndZeros.size(0) == N &&
          qScaleAndZeros.size(2) == 2,
      __func__,
      ": expect qScaleAndZeros to be 3D tensor with sizes [",
      N,
      ", :, 2]");

  std::optional<Device> common_device = std::nullopt;
  c10::impl::check_and_update_common_device(
      common_device, A, "xpu::_weight_int4pack_mm", "A");
  c10::impl::check_and_update_common_device(
      common_device, B, "xpu::_weight_int4pack_mm", "B");
  c10::impl::check_and_update_common_device(
      common_device,
      qScaleAndZeros,
      "xpu::_weight_int4pack_mm",
      "qScaleAndZeros");
  Tensor C = at::empty({M, N}, A.options());

  at::native::xpu::linear_int4_kernel(A, B, qGroupSize, qScaleAndZeros, C);
  return C;
}
} // namespace at::native
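
The checks above leave the middle dimension of `qScaleAndZeros` implicit; the SYCL launcher below passes `ldb = K / qGroupSize`, and the kernel reads the scale at `ScaleAndZeros + n * ldb * 2 + group * 2` with the zero point one element later. A small, illustration-only sketch of that flat layout (the helper names are mine, not part of the commit):

```cpp
#include <cstddef>

// Flat-index helpers for a contiguous [N, K / qGroupSize, 2] tensor, mirroring
// the sptr / zptr pointer arithmetic in sycl/LinearInt4.cpp below.
inline size_t scale_index(size_t n, size_t group, size_t groups_per_row) {
  return (n * groups_per_row + group) * 2;      // [n][group][0] -> scale
}

inline size_t zero_index(size_t n, size_t group, size_t groups_per_row) {
  return (n * groups_per_row + group) * 2 + 1;  // [n][group][1] -> zero point
}
```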
247 changes: 247 additions & 0 deletions src/ATen/native/xpu/sycl/LinearInt4.cpp
@@ -0,0 +1,247 @@
#include <ATen/native/xpu/sycl/LinearInt4.h>
#include <comm/SYCLContext.h>

namespace at::native::xpu {

static inline int padto_le(int src, int padding) {
  return src / padding * padding;
}

static inline int64_t padto_le(int64_t src, int64_t padding) {
  return src / padding * padding;
}

static inline size_t padto_le(size_t src, int padding) {
  return src / size_t(padding) * size_t(padding);
}

// Each work-group computes one output element: C[n] = sum_k A[k] * dequant(B[n][k]),
// where B packs two 4-bit weights per byte and dequant applies a per-group scale
// and zero point taken from ScaleAndZeros.
template <typename scalar_t = sycl::half, int block_size = 32>
struct LinearInt4KernelFunctor : public __SYCL_KER_CONFIG_CONVENTION__ {
  LinearInt4KernelFunctor(
      const scalar_t* A,
      const uint8_t* B,
      scalar_t* C,
      const scalar_t* ScaleAndZeros,
      int m,
      int n,
      int k,
      int lda,
      int ldb,
      int ldc)
      : A(A),
        B(B),
        C(C),
        ScaleAndZeros(ScaleAndZeros),
        m(m),
        n(n),
        k(k),
        lda(lda),
        ldb(ldb),
        ldc(ldc) {}
  void sycl_ker_config_convention(sycl::handler& cgh) {}

  void operator()(sycl::nd_item<1> it) const {
    int constexpr Unroll = 2;
    int constexpr SgSize = 16;
    int constexpr blocksize = block_size;
    using scalarx2_t = sycl::vec<scalar_t, 2>;

    if (k % (SgSize * 32 * Unroll) == 0) {
      // Fast path: k is a multiple of SgSize * TileK * Unroll (= 1024), so the
      // whole reduction uses 32-wide tiles per work-item.
      int constexpr TileK = 32;
      int constexpr GroupK = SgSize * TileK;

      int g_idx = it.get_group(0);
      auto sg = it.get_sub_group();
      int sg_id = sg.get_local_id()[0];
      int g_n = g_idx;
      auto sptr = ScaleAndZeros + g_n * ldb * 2;
      auto zptr = ScaleAndZeros + g_n * ldb * 2 + 1;
      auto bptr = B + g_n * k / 2;
      auto aptr = A;
      auto cptr = C + g_n;

      sycl::float2 tmpAcc = {0.f, 0.f};
      for (int i = 0; i < k; i += GroupK * Unroll) {
#pragma unroll
        for (int iu = 0; iu < Unroll; iu++) {
          uint8_t tmps8[TileK / 2];
          *(sycl::vec<uint8_t, TileK / 2>*)tmps8 =
              *(sycl::vec<uint8_t, TileK / 2>*)(bptr + sg_id * TileK / 2);
          int scale_offset = sg_id * (TileK / blocksize) * 2;
          int zp_offset = sg_id * (TileK / blocksize) * 2;
          scalar_t scale = *(sptr + scale_offset);
          scalar_t zero_point = *(zptr + zp_offset);
#pragma unroll
          for (int ikk = 0; ikk < TileK; ikk += 2) {
            scalarx2_t tmpA = *(scalarx2_t*)(aptr + sg_id * TileK + ikk);
            // Two int4 weights per byte: low nibble first, both re-centered by -8.
            scalarx2_t tmpB = {
                static_cast<int8_t>((tmps8[ikk / 2] & 0x0f) - 8),
                static_cast<int8_t>((tmps8[ikk / 2] >> 4) - 8)};
            auto tmpAmulB = tmpA * (tmpB * scale + zero_point);
            tmpAcc += {tmpAmulB[0], tmpAmulB[1]};
          }
          sptr += (GroupK / blocksize) * 2;
          aptr += GroupK;
          bptr += GroupK / 2;
        }
      }
      sycl::float2 sum = {0.f, 0.f};
      sum += sycl::reduce_over_group(sg, tmpAcc, sycl::plus<>());
      if (sg_id == 0) {
        *cptr = static_cast<scalar_t>(sum[0] + sum[1]);
      }
    } else { // k % (SgSize * 32 * Unroll) != 0
      // Tail path: process the bulk with 32-wide tiles, then 8-wide tiles, then
      // 2 elements per work-item for the remainder.
      int constexpr TileK = 32;
      int constexpr GroupK = SgSize * TileK;
      int k_body = padto_le(k, GroupK * Unroll);
      int constexpr TileK2 = 8;
      int constexpr GroupK2 = SgSize * TileK2;
      int k_body2 = padto_le(k, GroupK2 * Unroll);
      int g_idx = it.get_group(0);
      auto sg = it.get_sub_group();
      int sg_id = sg.get_local_id()[0];
      int g_n = g_idx;
      auto sptr = ScaleAndZeros + g_n * ldb * 2;
      auto zptr = ScaleAndZeros + g_n * ldb * 2 + 1;
      auto bptr = B + g_n * k / 2;
      auto aptr = A;
      auto cptr = C + g_n;
      sycl::float2 tmpAcc = {0.f, 0.f};
      int i = 0;
      for (; i < k_body; i += GroupK * Unroll) {
#pragma unroll
        for (int iu = 0; iu < Unroll; iu++) {
          uint8_t tmps8[TileK / 2];
          *(sycl::vec<uint8_t, TileK / 2>*)tmps8 =
              *(sycl::vec<uint8_t, TileK / 2>*)(bptr + sg_id * TileK / 2);

          int scale_offset = sg_id * (TileK / blocksize) * 2;
          int zp_offset = sg_id * (TileK / blocksize) * 2;
          scalar_t scale = *(sptr + scale_offset);
          scalar_t zero_point = *(zptr + zp_offset);
#pragma unroll
          for (int ikk = 0; ikk < TileK; ikk += 2) {
            scalarx2_t tmpA = *(scalarx2_t*)(aptr + sg_id * TileK + ikk);
            scalarx2_t tmpB = {
                static_cast<int8_t>((tmps8[ikk / 2] & 0x0f) - 8),
                static_cast<int8_t>((tmps8[ikk / 2] >> 4) - 8)};
            auto tmpAmulB = tmpA * (tmpB * scale + zero_point);
            tmpAcc += {tmpAmulB[0], tmpAmulB[1]};
          }
          sptr += (GroupK / blocksize) * 2;
          aptr += GroupK;
          bptr += GroupK / 2;
        }
      }
      if (i + GroupK2 * Unroll < k_body2) {
        for (; i < k_body2; i += GroupK2 * Unroll) {
#pragma unroll
          for (int iu = 0; iu < Unroll; iu++) {
            uint8_t tmps8[TileK2 / 2];
            *(sycl::vec<uint8_t, TileK2 / 2>*)tmps8 =
                *(sycl::vec<uint8_t, TileK2 / 2>*)(bptr + sg_id * TileK2 / 2);

            int scale_offset = sg_id * (TileK2 / blocksize) * 2;
            int zp_offset = sg_id * (TileK2 / blocksize) * 2;
            scalar_t scale = *(sptr + scale_offset);
            scalar_t zero_point = *(zptr + zp_offset);
#pragma unroll
            for (int ikk = 0; ikk < TileK2; ikk += 2) {
              scalarx2_t tmpA = *(scalarx2_t*)(aptr + sg_id * TileK2 + ikk);
              scalarx2_t tmpB = {
                  static_cast<int8_t>((tmps8[ikk / 2] & 0x0f) - 8),
                  static_cast<int8_t>((tmps8[ikk / 2] >> 4) - 8)};
              auto tmpAmulB = tmpA * (tmpB * scale + zero_point);
              tmpAcc += {tmpAmulB[0], tmpAmulB[1]};
            }
            sptr += (GroupK2 / blocksize) * 2;
            aptr += GroupK2;
            bptr += GroupK2 / 2;
          }
        }
      }
      if (i + SgSize * 2 <= k) {
        for (; i < k; i += SgSize * 2) {
          uint8_t tmps8 = *(bptr + sg_id);
          scalarx2_t tmpB = {
              static_cast<int8_t>((tmps8 & 0x0f) - 8),
              static_cast<int8_t>((tmps8 >> 4) - 8)};

          int scale_offset = sg_id * (2 / blocksize) * 2;
          int zp_offset = sg_id * (2 / blocksize) * 2;
          scalar_t scale = *(sptr + scale_offset);
          scalar_t zero_point = *(zptr + zp_offset);
          scalarx2_t tmpA = *(scalarx2_t*)(aptr + sg_id * 2);
          auto tmpAmulB = tmpA * (tmpB * scale + zero_point);
          tmpAcc += {tmpAmulB[0], tmpAmulB[1]};
          sptr += (SgSize * 2 / blocksize) * 2;
          aptr += SgSize * 2;
          bptr += SgSize * 2 / 2;
        }
      }
      sycl::float2 sum = {0.f, 0.f};
      sum += sycl::reduce_over_group(sg, tmpAcc, sycl::plus<>());
      if (sg_id == 0) {
        *cptr = static_cast<scalar_t>(sum[0] + sum[1]);
      }
    }
  }

 private:
  const scalar_t* A;
  const uint8_t* B;
  scalar_t* C;
  const scalar_t* ScaleAndZeros;
  int m;
  int n;
  int k;
  int lda;
  int ldb;
  int ldc;
};

void linear_int4_kernel(
    const Tensor& A,
    const Tensor& B,
    int qGroupSize,
    const Tensor& qScaleAndZeros,
    Tensor& C) {
  auto& sycl_queue = at::xpu::getCurrentSYCLQueue();
  int64_t m = A.size(0);
  int64_t n = C.size(1);
  int64_t k = A.size(1);
  int constexpr SgSize = 16;
  sycl::range<1> local_range{SgSize};
  sycl::range<1> global_range{static_cast<size_t>(n) * SgSize};
  AT_DISPATCH_REDUCED_FLOATING_TYPES(
      A.scalar_type(), "linear_int4_kernel", [&]() {
        using scalar_sycl_t = std::conditional_t<
            std::is_same_v<scalar_t, at::Half>,
            sycl::half,
            sycl::ext::oneapi::bfloat16>;

        const scalar_sycl_t* input_data =
            reinterpret_cast<scalar_sycl_t*>(A.data_ptr<scalar_t>());
        uint8_t* weight_data =
            reinterpret_cast<uint8_t*>(B.data_ptr()); // int4x2 or int4x8

        scalar_sycl_t* output_data =
            reinterpret_cast<scalar_sycl_t*>(C.data_ptr<scalar_t>());
        scalar_sycl_t* scale_zeros_data = reinterpret_cast<scalar_sycl_t*>(
            qScaleAndZeros.data_ptr<scalar_t>());
        // One work-group (a single sub-group of SgSize work-items) per output
        // element along N; the kernel reads only the first row of A, i.e. the
        // M = 1 GEMV case covered by the benchmarks above.
        LinearInt4KernelFunctor<scalar_sycl_t, 32> kfn =
            LinearInt4KernelFunctor<scalar_sycl_t, 32>(
                input_data,
                weight_data,
                output_data,
                scale_zeros_data,
                m,
                n,
                k,
                k,
                k / qGroupSize,
                n);
        sycl_kernel_submit(global_range, local_range, sycl_queue, kfn);
      });
}

} // namespace at::native::xpu
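
As a reading aid, here is a plain C++ sketch (not part of the commit) of the per-byte decode the kernel performs: each byte of `B` carries two 4-bit weights, low nibble first, each re-centered by subtracting 8 and then dequantized as `q * scale + zero_point`, exactly as in the `tmpB * scale + zero_point` lines above.

```cpp
#include <cstdint>
#include <cstdio>

// Decode one packed byte into two dequantized weights, mirroring the kernel's
// (tmps8 & 0x0f) - 8 and (tmps8 >> 4) - 8 unpack followed by q * scale + zero.
static void dequant_byte(
    uint8_t packed, float scale, float zero_point, float out[2]) {
  int q_lo = static_cast<int>(packed & 0x0f) - 8;  // low nibble -> first weight
  int q_hi = static_cast<int>(packed >> 4) - 8;    // high nibble -> second weight
  out[0] = q_lo * scale + zero_point;
  out[1] = q_hi * scale + zero_point;
}

int main() {
  float w[2];
  dequant_byte(/*packed=*/0x9A, /*scale=*/0.05f, /*zero_point=*/0.0f, w);
  // 0x9A: low nibble 0xA -> (10 - 8) * 0.05 = 0.10, high nibble 0x9 -> 0.05
  std::printf("%.2f %.2f\n", w[0], w[1]);
  return 0;
}
```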
14 changes: 14 additions & 0 deletions src/ATen/native/xpu/sycl/LinearInt4.h
@@ -0,0 +1,14 @@
#pragma once
#include <ATen/native/TensorIterator.h>
#include <comm/xpu_aten.h>

namespace at::native::xpu {

TORCH_XPU_API void linear_int4_kernel(
    const Tensor& input,
    const Tensor& weight,
    int qGroupSize,
    const Tensor& weight_scale_zero_point,
    Tensor& output);

} // namespace at::native::xpu