diff --git a/src/ATen/native/xpu/sycl/LinearInt4.cpp b/src/ATen/native/xpu/sycl/LinearInt4.cpp
index 20dbf5d44..846fd3530 100644
--- a/src/ATen/native/xpu/sycl/LinearInt4.cpp
+++ b/src/ATen/native/xpu/sycl/LinearInt4.cpp
@@ -54,6 +54,7 @@ struct LinearInt4KernelFunctor : public __SYCL_KER_CONFIG_CONVENTION__ {
       int sg_id = sg.get_local_id()[0];
      int g_n = g_idx;
      auto sptr = ScaleAndZeros + g_n * ldb * 2;
+      auto zptr = ScaleAndZeros + g_n * ldb * 2 + 1;
      auto bptr = B + g_n * k / 2;
      auto aptr = A;
      auto cptr = C + g_n;
@@ -65,8 +66,10 @@ struct LinearInt4KernelFunctor : public __SYCL_KER_CONFIG_CONVENTION__ {
           uint8_t tmps8[TileK / 2];
          *(sycl::vec<uint8_t, TileK / 2>*)tmps8 =
              *(sycl::vec<uint8_t, TileK / 2>*)(bptr + sg_id * TileK / 2);
-          scalar_t scale = *(sptr + sg_id * (TileK / blocksize) * 2);
-          scalar_t zero_point = *(sptr + sg_id * (TileK / blocksize) * 2 + 1);
+          int scale_offset = sg_id * (TileK / blocksize) * 2;
+          int zp_offset = sg_id * (TileK / blocksize) * 2;
+          scalar_t scale = *(sptr + scale_offset);
+          scalar_t zero_point = *(zptr + zp_offset);
 #pragma unroll
           for (int ikk = 0; ikk < TileK; ikk += 2) {
            scalarx2_t tmpA = *(scalarx2_t*)(aptr + sg_id * TileK + ikk);
@@ -84,7 +87,7 @@ struct LinearInt4KernelFunctor : public __SYCL_KER_CONFIG_CONVENTION__ {
       sycl::float2 sum = {0.f, 0.f};
      sum += sycl::reduce_over_group(sg, tmpAcc, sycl::plus<>());
      if (sg_id == 0) {
-        *cptr = sum[0] + sum[1];
+        *cptr = static_cast<scalar_t>(sum[0] + sum[1]);
      }
    } else { // k % (SgSize * 32 * Unroll) != 0
      int constexpr TileK = 32;
@@ -98,6 +101,7 @@ struct LinearInt4KernelFunctor : public __SYCL_KER_CONFIG_CONVENTION__ {
       int sg_id = sg.get_local_id()[0];
      int g_n = g_idx;
      auto sptr = ScaleAndZeros + g_n * ldb * 2;
+      auto zptr = ScaleAndZeros + g_n * ldb * 2 + 1;
      auto bptr = B + g_n * k / 2;
      auto aptr = A;
      auto cptr = C + g_n;
@@ -109,8 +113,11 @@ struct LinearInt4KernelFunctor : public __SYCL_KER_CONFIG_CONVENTION__ {
           uint8_t tmps8[TileK / 2];
          *(sycl::vec<uint8_t, TileK / 2>*)tmps8 =
              *(sycl::vec<uint8_t, TileK / 2>*)(bptr + sg_id * TileK / 2);
-          scalar_t scale = *(sptr + sg_id * (TileK / blocksize) * 2);
-          scalar_t zero_point = *(sptr + sg_id * (TileK / blocksize) * 2 + 1);
+
+          int scale_offset = sg_id * (TileK / blocksize) * 2;
+          int zp_offset = sg_id * (TileK / blocksize) * 2;
+          scalar_t scale = *(sptr + scale_offset);
+          scalar_t zero_point = *(zptr + zp_offset);
 #pragma unroll
           for (int ikk = 0; ikk < TileK; ikk += 2) {
            scalarx2_t tmpA = *(scalarx2_t*)(aptr + sg_id * TileK + ikk);
@@ -132,9 +139,11 @@ struct LinearInt4KernelFunctor : public __SYCL_KER_CONFIG_CONVENTION__ {
             uint8_t tmps8[TileK2 / 2];
            *(sycl::vec<uint8_t, TileK2 / 2>*)tmps8 =
                *(sycl::vec<uint8_t, TileK2 / 2>*)(bptr + sg_id * TileK2 / 2);
-            scalar_t scale = *(sptr + sg_id * (TileK2 / blocksize) * 2);
-            scalar_t zero_point =
-                *(sptr + sg_id * (TileK2 / blocksize) * 2 + 1);
+
+            int scale_offset = sg_id * (TileK2 / blocksize) * 2;
+            int zp_offset = sg_id * (TileK2 / blocksize) * 2;
+            scalar_t scale = *(sptr + scale_offset);
+            scalar_t zero_point = *(zptr + zp_offset);
 #pragma unroll
             for (int ikk = 0; ikk < TileK2; ikk += 2) {
              scalarx2_t tmpA = *(scalarx2_t*)(aptr + sg_id * TileK2 + ikk);
@@ -156,8 +165,11 @@ struct LinearInt4KernelFunctor : public __SYCL_KER_CONFIG_CONVENTION__ {
           scalarx2_t tmpB = {
              static_cast<scalar_t>((tmps8 & 0x0f) - 8),
              static_cast<scalar_t>((tmps8 >> 4) - 8)};
-          scalar_t scale = *(sptr + (sg_id * 2 / blocksize) * 2);
-          scalar_t zero_point = *(sptr + (sg_id * 2 / blocksize) * 2 + 1);
+
+          int scale_offset = sg_id * (2 / blocksize) * 2;
+          int zp_offset = sg_id * (2 / blocksize) * 2;
+          scalar_t scale = *(sptr + scale_offset);
+          scalar_t zero_point = *(zptr + zp_offset);
           scalarx2_t tmpA = *(scalarx2_t*)(aptr + sg_id * 2);
          auto tmpAmulB = tmpA * (tmpB * scale + zero_point);
          tmpAcc += {tmpAmulB[0], tmpAmulB[1]};
@@ -169,7 +181,7 @@ struct LinearInt4KernelFunctor : public __SYCL_KER_CONFIG_CONVENTION__ {
       sycl::float2 sum = {0.f, 0.f};
      sum += sycl::reduce_over_group(sg, tmpAcc, sycl::plus<>());
      if (sg_id == 0) {
-        *cptr = sum[0] + sum[1];
+        *cptr = static_cast<scalar_t>(sum[0] + sum[1]);
      }
    }
  }
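For context, here is a minimal standalone sketch of the addressing scheme this patch touches. It is plain host C++, not the SYCL kernel: the `dequant_int4_column` helper, the `scalar_t = float` stand-in, and the toy data in `main` are all hypothetical illustration, with the layout assumptions (two int4 values per byte, low nibble first, a +8 bias, and `ScaleAndZeros` interleaving one (scale, zero_point) pair per quantization block) read off the diff above. The sketch walks K serially with `(i / blocksize) * 2` where the kernel distributes K across subgroup lanes via `sg_id`, but the pointer setup shows the point of the change: `zptr` is pre-advanced by one element into the interleaved stream, so scale and zero point are fetched with the same offset instead of a `+ 1` baked into every unrolled expression.

```cpp
#include <cstdint>
#include <cstdio>
#include <vector>

// Hypothetical stand-in for the kernel's scalar_t (sycl::half / bfloat16).
using scalar_t = float;

// Dequantize one K-length column of 4-bit weights: `packed` holds two
// int4 values per byte (low nibble first, stored with a +8 bias), and
// `scale_and_zeros` interleaves (scale, zero_point) pairs, one pair per
// `blocksize` elements of K.
std::vector<scalar_t> dequant_int4_column(
    const uint8_t* packed,
    const scalar_t* scale_and_zeros,
    int k,
    int blocksize) {
  const scalar_t* sptr = scale_and_zeros;      // points at the scales
  const scalar_t* zptr = scale_and_zeros + 1;  // points at the zero points
  std::vector<scalar_t> out(k);
  for (int i = 0; i < k; i += 2) {
    uint8_t byte = packed[i / 2];
    // Both nibbles of the byte belong to the same quantization group, so
    // scale and zero point share one offset into the interleaved pairs.
    int offset = (i / blocksize) * 2;
    scalar_t scale = sptr[offset];
    scalar_t zero_point = zptr[offset];
    out[i] = static_cast<scalar_t>((byte & 0x0f) - 8) * scale + zero_point;
    out[i + 1] = static_cast<scalar_t>((byte >> 4) - 8) * scale + zero_point;
  }
  return out;
}

int main() {
  // Toy example: k = 8, blocksize = 4 -> two (scale, zero_point) pairs.
  uint8_t packed[4] = {0x98, 0xba, 0x67, 0x45};  // nibbles 8,9,10,11,7,6,5,4
  scalar_t scale_and_zeros[4] = {0.5f, 1.0f, 2.0f, -1.0f};
  auto w = dequant_int4_column(packed, scale_and_zeros, 8, 4);
  for (scalar_t v : w)
    printf("%g ", static_cast<double>(v));
  printf("\n");
  return 0;
}
```

Built with `g++ -std=c++17`, the toy example prints `1 1.5 2 2.5 -3 -5 -7 -9` (dequantized as `q * scale + zero_point`, matching the kernel's `tmpB * scale + zero_point`). Splitting `sptr`/`zptr` keeps the two streams symmetric, which is what lets every branch of the kernel compute a single `scale_offset`/`zp_offset` pair instead of repeating the `+ 1` arithmetic at each load site.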