Skip to content

Commit

Permalink
modify by review
Browse files Browse the repository at this point in the history
  • Loading branch information
sunjiweiswift committed Dec 23, 2024
1 parent 8385f7e commit f44ed70
Showing 1 changed file with 23 additions and 11 deletions.
34 changes: 23 additions & 11 deletions src/ATen/native/xpu/sycl/LinearInt4.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@ struct LinearInt4KernelFunctor : public __SYCL_KER_CONFIG_CONVENTION__ {
int sg_id = sg.get_local_id()[0];
int g_n = g_idx;
auto sptr = ScaleAndZeros + g_n * ldb * 2;
auto zptr = ScaleAndZeros + g_n * ldb * 2 + 1;
auto bptr = B + g_n * k / 2;
auto aptr = A;
auto cptr = C + g_n;
Expand All @@ -65,8 +66,10 @@ struct LinearInt4KernelFunctor : public __SYCL_KER_CONFIG_CONVENTION__ {
uint8_t tmps8[TileK / 2];
*(sycl::vec<uint8_t, TileK / 2>*)tmps8 =
*(sycl::vec<uint8_t, TileK / 2>*)(bptr + sg_id * TileK / 2);
scalar_t scale = *(sptr + sg_id * (TileK / blocksize) * 2);
scalar_t zero_point = *(sptr + sg_id * (TileK / blocksize) * 2 + 1);
int scale_offset = sg_id * (TileK / blocksize) * 2;
int zp_offset = sg_id * (TileK / blocksize) * 2;
scalar_t scale = *(sptr + scale_offset);
scalar_t zero_point = *(zptr + zp_offset);
#pragma unroll
for (int ikk = 0; ikk < TileK; ikk += 2) {
scalarx2_t tmpA = *(scalarx2_t*)(aptr + sg_id * TileK + ikk);
Expand All @@ -84,7 +87,7 @@ struct LinearInt4KernelFunctor : public __SYCL_KER_CONFIG_CONVENTION__ {
sycl::float2 sum = {0.f, 0.f};
sum += sycl::reduce_over_group(sg, tmpAcc, sycl::plus<>());
if (sg_id == 0) {
*cptr = sum[0] + sum[1];
*cptr = static_cast<scalar_t>(sum[0] + sum[1]);
}
} else { // k % (SgSize * 32 * Unroll) != 0
int constexpr TileK = 32;
Expand All @@ -98,6 +101,7 @@ struct LinearInt4KernelFunctor : public __SYCL_KER_CONFIG_CONVENTION__ {
int sg_id = sg.get_local_id()[0];
int g_n = g_idx;
auto sptr = ScaleAndZeros + g_n * ldb * 2;
auto zptr = ScaleAndZeros + g_n * ldb * 2 + 1;
auto bptr = B + g_n * k / 2;
auto aptr = A;
auto cptr = C + g_n;
Expand All @@ -109,8 +113,11 @@ struct LinearInt4KernelFunctor : public __SYCL_KER_CONFIG_CONVENTION__ {
uint8_t tmps8[TileK / 2];
*(sycl::vec<uint8_t, TileK / 2>*)tmps8 =
*(sycl::vec<uint8_t, TileK / 2>*)(bptr + sg_id * TileK / 2);
scalar_t scale = *(sptr + sg_id * (TileK / blocksize) * 2);
scalar_t zero_point = *(sptr + sg_id * (TileK / blocksize) * 2 + 1);

int scale_offset = sg_id * (TileK / blocksize) * 2;
int zp_offset = sg_id * (TileK / blocksize) * 2;
scalar_t scale = *(sptr + scale_offset);
scalar_t zero_point = *(zptr + zp_offset);
#pragma unroll
for (int ikk = 0; ikk < TileK; ikk += 2) {
scalarx2_t tmpA = *(scalarx2_t*)(aptr + sg_id * TileK + ikk);
Expand All @@ -132,9 +139,11 @@ struct LinearInt4KernelFunctor : public __SYCL_KER_CONFIG_CONVENTION__ {
uint8_t tmps8[TileK2 / 2];
*(sycl::vec<uint8_t, TileK2 / 2>*)tmps8 =
*(sycl::vec<uint8_t, TileK2 / 2>*)(bptr + sg_id * TileK2 / 2);
scalar_t scale = *(sptr + sg_id * (TileK2 / blocksize) * 2);
scalar_t zero_point =
*(sptr + sg_id * (TileK2 / blocksize) * 2 + 1);

int scale_offset = sg_id * (TileK2 / blocksize) * 2;
int zp_offset = sg_id * (TileK2 / blocksize) * 2;
scalar_t scale = *(sptr + scale_offset);
scalar_t zero_point = *(zptr + zp_offset);
#pragma unroll
for (int ikk = 0; ikk < TileK2; ikk += 2) {
scalarx2_t tmpA = *(scalarx2_t*)(aptr + sg_id * TileK2 + ikk);
Expand All @@ -156,8 +165,11 @@ struct LinearInt4KernelFunctor : public __SYCL_KER_CONFIG_CONVENTION__ {
scalarx2_t tmpB = {
static_cast<int8_t>((tmps8 & 0x0f) - 8),
static_cast<int8_t>((tmps8 >> 4) - 8)};
scalar_t scale = *(sptr + (sg_id * 2 / blocksize) * 2);
scalar_t zero_point = *(sptr + (sg_id * 2 / blocksize) * 2 + 1);

int scale_offset = sg_id * (2 / blocksize) * 2;
int zp_offset = sg_id * (2 / blocksize) * 2;
scalar_t scale = *(sptr + scale_offset);
scalar_t zero_point = *(zptr + zp_offset);
scalarx2_t tmpA = *(scalarx2_t*)(aptr + sg_id * 2);
auto tmpAmulB = tmpA * (tmpB * scale + zero_point);
tmpAcc += {tmpAmulB[0], tmpAmulB[1]};
Expand All @@ -169,7 +181,7 @@ struct LinearInt4KernelFunctor : public __SYCL_KER_CONFIG_CONVENTION__ {
sycl::float2 sum = {0.f, 0.f};
sum += sycl::reduce_over_group(sg, tmpAcc, sycl::plus<>());
if (sg_id == 0) {
*cptr = sum[0] + sum[1];
*cptr = static_cast<scalar_t>(sum[0] + sum[1]);
}
}
}
Expand Down

0 comments on commit f44ed70

Please sign in to comment.