diff --git a/src/ATen/native/xpu/LinearInt4.cpp b/src/ATen/native/xpu/LinearInt4.cpp
new file mode 100644
index 000000000..55ce4bda6
--- /dev/null
+++ b/src/ATen/native/xpu/LinearInt4.cpp
@@ -0,0 +1,63 @@
+
+#include <ATen/core/Tensor.h>
+#include <ATen/core/op_registration/adaption.h>
+#include <ATen/native/TensorIterator.h>
+#include <torch/library.h>
+
+#include <ATen/native/xpu/sycl/LinearInt4.h>
+#include <comm/xpu_aten.h>
+
+namespace at::native {
+Tensor _weight_int4pack_mm_xpu(
+    const Tensor& A,
+    const Tensor& B,
+    int64_t qGroupSize,
+    const Tensor& qScaleAndZeros) {
+  auto M = A.size(0);
+  auto N = B.size(0);
+  auto K = A.size(1);
+  TORCH_CHECK(
+      A.dtype() == kBFloat16 || A.dtype() == kHalf || A.dtype() == kFloat,
+      __func__,
+      " : expect A to be either 32-bit or 16-bit float tensor.");
+  TORCH_CHECK(A.is_contiguous(), __func__, " : expect A to be contiguous.");
+  TORCH_CHECK(A.dim() == 2, __func__, " : expect A to be 2D tensor.");
+
+  TORCH_CHECK(
+      B.dtype() == kInt || B.dtype() == kUInt32 || B.dtype() == kByte,
+      __func__,
+      " : expect B to be int32, uint32 or uint8 tensor.");
+  TORCH_CHECK(B.is_contiguous(), __func__, " : expect B to be contiguous.");
+  TORCH_CHECK(B.dim() == 2, __func__, " : expect B to be 2D tensor.");
+
+  TORCH_CHECK(
+      qGroupSize == 32 || qGroupSize == 64 || qGroupSize == 128 ||
+          qGroupSize == 256,
+      __func__,
+      ": expect qGroupSize to be 32, 64, 128 or 256, got ",
+      qGroupSize);
+
+  TORCH_CHECK(
+      qScaleAndZeros.dim() == 3 && qScaleAndZeros.size(0) == N &&
+          qScaleAndZeros.size(2) == 2,
+      __func__,
+      ": expect qScaleAndZeros to be 3d tensor with sizes [",
+      N,
+      ", :, 2]");
+
+  std::optional<Device> common_device = std::nullopt;
+  c10::impl::check_and_update_common_device(
+      common_device, A, "xpu::_weight_int4pack_mm", "A");
+  c10::impl::check_and_update_common_device(
+      common_device, B, "xpu::_weight_int4pack_mm", "B");
+  c10::impl::check_and_update_common_device(
+      common_device,
+      qScaleAndZeros,
+      "xpu::_weight_int4pack_mm",
+      "qScaleAndZeros");
+  Tensor C = at::empty({M, N}, A.options());
+
+  at::native::xpu::linear_int4_kernel(A, B, qGroupSize, qScaleAndZeros, C);
+  return C;
+}
+} // namespace at::native
diff --git a/src/ATen/native/xpu/sycl/LinearInt4.cpp b/src/ATen/native/xpu/sycl/LinearInt4.cpp
new file mode 100644
index 000000000..846fd3530
--- /dev/null
+++ b/src/ATen/native/xpu/sycl/LinearInt4.cpp
@@ -0,0 +1,247 @@
+#include <ATen/native/xpu/sycl/LinearInt4.h>
+#include <comm/SYCLContext.h>
+
+namespace at::native::xpu {
+static inline int padto_le(int src, int padding) {
+  return src / padding * padding;
+}
+
+static inline int64_t padto_le(int64_t src, int64_t padding) {
+  return src / padding * padding;
+}
+
+static inline size_t padto_le(size_t src, int padding) {
+  return src / size_t(padding) * size_t(padding);
+}
+
+// One work-group (a single sub-group of SgSize work-items) computes one
+// element of the output row: C[0][g_n] = dot(A[0, :], dequant(B[g_n, :])).
+template <typename scalar_t, int block_size>
+struct LinearInt4KernelFunctor : public __SYCL_KER_CONFIG_CONVENTION__ {
+  LinearInt4KernelFunctor(
+      const scalar_t* A,
+      const uint8_t* B,
+      scalar_t* C,
+      const scalar_t* ScaleAndZeros,
+      int m,
+      int n,
+      int k,
+      int lda,
+      int ldb,
+      int ldc)
+      : A(A),
+        B(B),
+        C(C),
+        ScaleAndZeros(ScaleAndZeros),
+        m(m),
+        n(n),
+        k(k),
+        lda(lda),
+        ldb(ldb),
+        ldc(ldc) {}
+  void sycl_ker_config_convention(sycl::handler& cgh) {}
+
+  void operator()(sycl::nd_item<1> it) const {
+    int constexpr Unroll = 2;
+    int constexpr SgSize = 16;
+    int constexpr blocksize = block_size;
+    using scalarx2_t = sycl::vec<scalar_t, 2>;
+
+    // Fast path: k is a multiple of SgSize * TileK * Unroll, so the whole
+    // reduction runs on 32-wide tiles with no tail handling.
+    if (k % (SgSize * 32 * Unroll) == 0) {
+      int constexpr TileK = 32;
+      int constexpr GroupK = SgSize * TileK;
+
+      int g_idx = it.get_group(0);
+      auto sg = it.get_sub_group();
+      int sg_id = sg.get_local_id()[0];
+      int g_n = g_idx;
+      auto sptr = ScaleAndZeros + g_n * ldb * 2;
+      auto zptr = ScaleAndZeros + g_n * ldb * 2 + 1;
+      auto bptr = B + g_n * k / 2;
+      auto aptr = A;
+      auto cptr = C + g_n;
+
+      sycl::float2 tmpAcc = {0.f, 0.f};
+      for (int i = 0; i < k; i += GroupK * Unroll) {
+#pragma unroll
+        for (int iu = 0; iu < Unroll; iu++) {
+          uint8_t tmps8[TileK / 2];
+          *(sycl::vec<uint8_t, TileK / 2>*)tmps8 =
+              *(sycl::vec<uint8_t, TileK / 2>*)(bptr + sg_id * TileK / 2);
+          int scale_offset = sg_id * (TileK / blocksize) * 2;
+          int zp_offset = sg_id * (TileK / blocksize) * 2;
+          scalar_t scale = *(sptr + scale_offset);
+          scalar_t zero_point = *(zptr + zp_offset);
+#pragma unroll
+          for (int ikk = 0; ikk < TileK; ikk += 2) {
+            scalarx2_t tmpA = *(scalarx2_t*)(aptr + sg_id * TileK + ikk);
+            // Each byte of B packs two signed int4 values (low nibble first),
+            // stored with a +8 offset.
+            scalarx2_t tmpB = {
+                static_cast<scalar_t>((tmps8[ikk / 2] & 0x0f) - 8),
+                static_cast<scalar_t>((tmps8[ikk / 2] >> 4) - 8)};
+            auto tmpAmulB = tmpA * (tmpB * scale + zero_point);
+            tmpAcc += {tmpAmulB[0], tmpAmulB[1]};
+          }
+          sptr += (GroupK / blocksize) * 2;
+          aptr += GroupK;
+          bptr += GroupK / 2;
+        }
+      }
+      sycl::float2 sum = {0.f, 0.f};
+      sum += sycl::reduce_over_group(sg, tmpAcc, sycl::plus<>());
+      if (sg_id == 0) {
+        *cptr = static_cast<scalar_t>(sum[0] + sum[1]);
+      }
+    } else { // k % (SgSize * 32 * Unroll) != 0
+      int constexpr TileK = 32;
+      int constexpr GroupK = SgSize * TileK;
+      int k_body = padto_le(k, GroupK * Unroll);
+      int constexpr TileK2 = 8;
+      int constexpr GroupK2 = SgSize * TileK2;
+      int k_body2 = padto_le(k, GroupK2 * Unroll);
+      int g_idx = it.get_group(0);
+      auto sg = it.get_sub_group();
+      int sg_id = sg.get_local_id()[0];
+      int g_n = g_idx;
+      auto sptr = ScaleAndZeros + g_n * ldb * 2;
+      auto zptr = ScaleAndZeros + g_n * ldb * 2 + 1;
+      auto bptr = B + g_n * k / 2;
+      auto aptr = A;
+      auto cptr = C + g_n;
+      sycl::float2 tmpAcc = {0.f, 0.f};
+      int i = 0;
+      for (; i < k_body; i += GroupK * Unroll) {
+#pragma unroll
+        for (int iu = 0; iu < Unroll; iu++) {
+          uint8_t tmps8[TileK / 2];
+          *(sycl::vec<uint8_t, TileK / 2>*)tmps8 =
+              *(sycl::vec<uint8_t, TileK / 2>*)(bptr + sg_id * TileK / 2);
+
+          int scale_offset = sg_id * (TileK / blocksize) * 2;
+          int zp_offset = sg_id * (TileK / blocksize) * 2;
+          scalar_t scale = *(sptr + scale_offset);
+          scalar_t zero_point = *(zptr + zp_offset);
+#pragma unroll
+          for (int ikk = 0; ikk < TileK; ikk += 2) {
+            scalarx2_t tmpA = *(scalarx2_t*)(aptr + sg_id * TileK + ikk);
+            scalarx2_t tmpB = {
+                static_cast<scalar_t>((tmps8[ikk / 2] & 0x0f) - 8),
+                static_cast<scalar_t>((tmps8[ikk / 2] >> 4) - 8)};
+            auto tmpAmulB = tmpA * (tmpB * scale + zero_point);
+            tmpAcc += {tmpAmulB[0], tmpAmulB[1]};
+          }
+          sptr += (GroupK / blocksize) * 2;
+          aptr += GroupK;
+          bptr += GroupK / 2;
+        }
+      }
+      if (i + GroupK2 * Unroll < k_body2) {
+        for (; i < k_body2; i += GroupK2 * Unroll) {
+#pragma unroll
+          for (int iu = 0; iu < Unroll; iu++) {
+            uint8_t tmps8[TileK2 / 2];
+            *(sycl::vec<uint8_t, TileK2 / 2>*)tmps8 =
+                *(sycl::vec<uint8_t, TileK2 / 2>*)(bptr + sg_id * TileK2 / 2);
+
+            int scale_offset = sg_id * (TileK2 / blocksize) * 2;
+            int zp_offset = sg_id * (TileK2 / blocksize) * 2;
+            scalar_t scale = *(sptr + scale_offset);
+            scalar_t zero_point = *(zptr + zp_offset);
+#pragma unroll
+            for (int ikk = 0; ikk < TileK2; ikk += 2) {
+              scalarx2_t tmpA = *(scalarx2_t*)(aptr + sg_id * TileK2 + ikk);
+              scalarx2_t tmpB = {
+                  static_cast<scalar_t>((tmps8[ikk / 2] & 0x0f) - 8),
+                  static_cast<scalar_t>((tmps8[ikk / 2] >> 4) - 8)};
+              auto tmpAmulB = tmpA * (tmpB * scale + zero_point);
+              tmpAcc += {tmpAmulB[0], tmpAmulB[1]};
+            }
+            sptr += (GroupK2 / blocksize) * 2;
+            aptr += GroupK2;
+            bptr += GroupK2 / 2;
+          }
+        }
+      }
+      if (i + SgSize * 2 <= k) {
+        for (; i < k; i += SgSize * 2) {
+          uint8_t tmps8 = *(bptr + sg_id);
+          scalarx2_t tmpB = {
+              static_cast<scalar_t>((tmps8 & 0x0f) - 8),
+              static_cast<scalar_t>((tmps8 >> 4) - 8)};
+
+          int scale_offset = sg_id * (2 / blocksize) * 2;
+          int zp_offset = sg_id * (2 / blocksize) * 2;
+          scalar_t scale = *(sptr + scale_offset);
+          scalar_t zero_point = *(zptr + zp_offset);
+          scalarx2_t tmpA = *(scalarx2_t*)(aptr + sg_id * 2);
+          auto tmpAmulB = tmpA * (tmpB * scale + zero_point);
+          tmpAcc += {tmpAmulB[0], tmpAmulB[1]};
+          sptr += (SgSize * 2 / blocksize) * 2;
+          aptr += SgSize * 2;
+          bptr += SgSize * 2 / 2;
+        }
+      }
+      sycl::float2 sum = {0.f, 0.f};
+      sum += sycl::reduce_over_group(sg, tmpAcc, sycl::plus<>());
+      if (sg_id == 0) {
+        *cptr = static_cast<scalar_t>(sum[0] + sum[1]);
+      }
+    }
+  }
+
+ private:
+  const scalar_t* A;
+  const uint8_t* B;
+  scalar_t* C;
+  const scalar_t* ScaleAndZeros;
+  int m;
+  int n;
+  int k;
+  int lda;
+  int ldb;
+  int ldc;
+};
+
+void linear_int4_kernel(
+    const Tensor& A,
+    const Tensor& B,
+    int qGroupSize,
+    const Tensor& qScaleAndZeros,
+    Tensor& C) {
+  auto& sycl_queue = at::xpu::getCurrentSYCLQueue();
+  int64_t m = A.size(0);
+  int64_t n = C.size(1);
+  int64_t k = A.size(1);
+  int constexpr SgSize = 16;
+  // One sub-group-sized work-group per output column of C.
+  sycl::range<1> local_range{SgSize};
+  sycl::range<1> global_range{static_cast<size_t>(n) * SgSize};
+  AT_DISPATCH_REDUCED_FLOATING_TYPES(
+      A.scalar_type(), "linear_int4_kernel", [&]() {
+        using scalar_sycl_t = std::conditional_t<
+            std::is_same_v<scalar_t, at::Half>,
+            sycl::half,
+            sycl::ext::oneapi::bfloat16>;
+
+        const scalar_sycl_t* input_data =
+            reinterpret_cast<const scalar_sycl_t*>(A.data_ptr());
+        uint8_t* weight_data =
+            reinterpret_cast<uint8_t*>(B.data_ptr()); // int4x2 or int4x8
+
+        scalar_sycl_t* output_data =
+            reinterpret_cast<scalar_sycl_t*>(C.data_ptr());
+        scalar_sycl_t* scale_zeros_data = reinterpret_cast<scalar_sycl_t*>(
+            qScaleAndZeros.data_ptr());
+        LinearInt4KernelFunctor<scalar_sycl_t, 32> kfn =
+            LinearInt4KernelFunctor<scalar_sycl_t, 32>(
+                input_data,
+                weight_data,
+                output_data,
+                scale_zeros_data,
+                m,
+                n,
+                k,
+                k,
+                k / qGroupSize,
+                n);
+        sycl_kernel_submit(global_range, local_range, sycl_queue, kfn);
+      });
+}
+
+} // namespace at::native::xpu
\ No newline at end of file
diff --git a/src/ATen/native/xpu/sycl/LinearInt4.h b/src/ATen/native/xpu/sycl/LinearInt4.h
new file mode 100644
index 000000000..c54f3df21
--- /dev/null
+++ b/src/ATen/native/xpu/sycl/LinearInt4.h
@@ -0,0 +1,14 @@
+#pragma once
+#include <ATen/core/Tensor.h>
+#include <comm/xpu_aten.h>
+
+namespace at::native::xpu {
+
+TORCH_XPU_API void linear_int4_kernel(
+    const Tensor& input,
+    const Tensor& weight,
+    int qGroupSize,
+    const Tensor& weight_scale_zero_point,
+    Tensor& output);
+
+} // namespace at::native::xpu
diff --git a/test/xpu/test_linalg_xpu.py b/test/xpu/test_linalg_xpu.py
index 3c1c8aed7..83dedf81f 100644
--- a/test/xpu/test_linalg_xpu.py
+++ b/test/xpu/test_linalg_xpu.py
@@ -6,6 +6,7 @@
 from torch.testing._internal.common_dtype import floating_and_complex_types_and
 from torch.testing._internal.common_cuda import tf32_on_and_off
 from torch.testing._internal.common_mkldnn import bf32_on_and_off
+from torch.testing._internal.common_quantization import _dynamically_quantize_per_channel
 from torch.testing import make_tensor
 import unittest
 import itertools
@@ -171,6 +172,112 @@ def _test(m, k, n, transpose_a, transpose_b, test_equal=True):
         if not use_transpose_a and not use_transpose_b:
             _test(17, k, n, use_transpose_a, use_transpose_b)

+@unittest.skipIf(IS_WINDOWS, "Skipped on Windows!")
+@parametrize("m", [1])
+@parametrize("k", [32, 64, 128, 256, 512, 1024])
+@parametrize("n", [32, 64, 128, 256, 512, 1024])
+def _int4_mm(self, device, m, k, n):
+    def _group_quantize_tensor(w, n_bit=4, q_group_size=16):
+        assert w.dim() == 2
+        w = w.transpose(0, 1).contiguous()
+        assert q_group_size > 1
+        assert w.shape[-1] % q_group_size == 0
+
+        to_quant = w.reshape(-1, q_group_size)
+        assert torch.isnan(to_quant).sum() == 0
+
+        max_val = to_quant.amax(dim=1, keepdim=True)
+        min_val = to_quant.amin(dim=1, keepdim=True)
+        max_int = 2 ** n_bit - 1
+        min_int = 0
+        scales = (max_val - min_val).clamp(min=1e-6) / max_int
+        assert torch.isnan(scales).sum() == 0
+        zeros = min_val + scales * (2 ** (n_bit - 1))
+        assert torch.isnan(zeros).sum() == 0
+
+        out = to_quant.sub(min_val).div(scales).round().clamp_(min_int, max_int)
+
+        assert torch.isnan(out).sum() == 0
+
+        out = out.to(dtype=torch.uint8).reshape(w.shape)
+
+        if out.device.type == 'xpu':
+            out = (out[::, 1::2] << 4 | out[::, ::2]).to(torch.uint8)
+        elif out.device != torch.device('cpu'):
+            out = (out[::, ::2] << 4 | out[::, 1::2]).to(torch.uint8)
+        # Scales and zeros for the same q-group should be contiguous, so we can
+        # load as a 32-bit word
+        scales = scales.view(w.shape[0], -1)
+        zeros = zeros.view(w.shape[0], -1)
+        scales_and_zeros = (
+            torch.cat(
+                [
+                    scales.reshape(scales.size(0), scales.size(1), 1),
+                    zeros.reshape(zeros.size(0), zeros.size(1), 1),
+                ],
+                2,
+            )
+        )
+
+        if out.device.type != 'xpu':
+            scales_and_zeros = scales_and_zeros.transpose(0, 1).contiguous()
+        return out, scales_and_zeros
+
+    def convert_weight_to_int4pack(b):
+        b_tmp, b_scales_and_zeros = _group_quantize_tensor(
+            b, n_bit=4, q_group_size=q_group
+        )
+
+        if self.device_type == 'cpu':
+            b_int4pack = torch._convert_weight_to_int4pack_for_cpu(
+                b_tmp, inner_k_tiles
+            )
+        elif self.device_type == 'xpu':
+            b_int4pack = b_tmp.view(torch.int32)
+        else:
+            b_int4pack = torch._convert_weight_to_int4pack(
+                b_tmp, inner_k_tiles
+            )
+
+        return b_int4pack, b_scales_and_zeros
+
+    def weight_int4pack_mm(a, b_int4pack, b_scales_and_zeros):
+        if self.device_type == 'cpu':
+            self.assertTrue(b_int4pack.dtype is torch.uint8)
+            self.assertTrue(b_int4pack.dim() == 2)
+            return torch._weight_int4pack_mm_for_cpu(
+                a, b_int4pack, q_group, b_scales_and_zeros
+            )
+        elif self.device_type == 'xpu':
+            self.assertTrue(b_int4pack.dtype is torch.int32)  # or b_int4pack.dtype is torch.uint8
+            self.assertTrue(b_int4pack.dim() == 2)
+            return torch._weight_int4pack_mm(
+                a, b_int4pack, q_group, b_scales_and_zeros
+            )
+        else:
+            self.assertTrue(b_int4pack.dtype is torch.int32)
+            self.assertTrue(b_int4pack.dim() == 4)
+            return torch._weight_int4pack_mm(
+                a, b_int4pack, q_group, b_scales_and_zeros
+            )
+
+    q_group = 32
+    inner_k_tiles = 2
+
+    torch.manual_seed(1)
+    a_bf16 = torch.rand((m, k), dtype=torch.bfloat16, device=device)
+    b_bf16 = torch.rand((k, n), dtype=torch.bfloat16, device=device)
+
+    b_int4pack, b_scales_and_zeros_bf16 = convert_weight_to_int4pack(b_bf16)
+    for dtype in [torch.bfloat16] + ([torch.float16, torch.float32] if device == "cpu" else [torch.float16] if "xpu" in device else []):
+        a = a_bf16.to(dtype=dtype)
+        b = b_bf16.to(dtype=dtype)
+        b_scales_and_zeros = b_scales_and_zeros_bf16.to(dtype=dtype)
+        ref = torch.mm(a, b)
+        res = weight_int4pack_mm(a, b_int4pack, b_scales_and_zeros)
+        mean_err = ((res - ref).abs() / ref).mean()
+        self.assertTrue(mean_err < 0.05)
+
 @dtypes(torch.float, torch.complex64)  # Integer matmul just supported on CPU
 @setBlasBackendsToDefaultFinally
 def matmul_small_brute_force_1d_Nd(self, device, dtype):
@@ -229,6 +336,7 @@ def ck_blas_library(self):
 TestLinalg.test_preferred_linalg_library=preferred_linalg_library
 TestLinalg.test_addbmm=addbmm
 TestLinalg.test__int_mm=_int_mm
+TestLinalg.test__int4_mm=_int4_mm
 TestLinalg.test_matmul_small_brute_force_1d_Nd=matmul_small_brute_force_1d_Nd
 TestLinalg.test_matmul_small_brute_force_2d_Nd=matmul_small_brute_force_2d_Nd
 TestLinalg.test_matmul_small_brute_force_3d_Nd=matmul_small_brute_force_3d_Nd
diff --git a/yaml/native/native_functions.yaml b/yaml/native/native_functions.yaml
index f19a57c7f..e65c276b2 100644
--- a/yaml/native/native_functions.yaml
+++ b/yaml/native/native_functions.yaml
@@ -8598,3 +8598,9 @@
   dispatch:
     SparseXPU: copy_sparse_
   autogen: copy_sparse_to_sparse, copy_sparse_to_sparse.out
+
+- func: _weight_int4pack_mm(Tensor self, Tensor mat2, int qGroupSize, Tensor qScaleAndZeros) -> Tensor
+  dispatch:
+    XPU: _weight_int4pack_mm_xpu
+  # autogen: _weight_int4pack_mm.out
+  # tags: core
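
Usage sketch (reviewer note, not part of the patch): the snippet below shows how the new XPU path is expected to be driven end to end, mirroring the quantization and low-nibble-first packing convention used by _group_quantize_tensor in the test above. It assumes a PyTorch build with XPU support and this op registered; the helper name quantize_int4_weight and the shapes are illustrative only, not an API added by this PR, and the test only exercises m == 1.

    import torch

    def quantize_int4_weight(w, group_size=32):
        # w: [K, N] bf16/fp16 weight; quantize each output column in groups of
        # `group_size` along K, as _group_quantize_tensor does in the test.
        w = w.transpose(0, 1).contiguous()                  # [N, K]
        groups = w.reshape(-1, group_size)
        max_val = groups.amax(dim=1, keepdim=True)
        min_val = groups.amin(dim=1, keepdim=True)
        scales = (max_val - min_val).clamp(min=1e-6) / 15   # 4-bit range 0..15
        zeros = min_val + scales * 8                        # mid-range zero point
        q = groups.sub(min_val).div(scales).round().clamp_(0, 15)
        q = q.to(torch.uint8).reshape(w.shape)              # [N, K]
        packed = q[:, 1::2] << 4 | q[:, ::2]                # low nibble = even k
        scales_and_zeros = torch.cat(
            [scales.view(w.shape[0], -1, 1), zeros.view(w.shape[0], -1, 1)], dim=2
        )                                                   # [N, K // group_size, 2]
        return packed.view(torch.int32), scales_and_zeros

    # Illustrative shapes; the kernel computes a single output row (m == 1).
    a = torch.rand(1, 256, dtype=torch.bfloat16, device="xpu")
    w = torch.rand(256, 128, dtype=torch.bfloat16, device="xpu")
    w_int4, scale_zp = quantize_int4_weight(w, group_size=32)
    out = torch._weight_int4pack_mm(a, w_int4, 32, scale_zp)  # [1, 128], bf16

The low-nibble-first packing has to match what the kernel unpacks: it reads (b & 0x0f) - 8 as the even-k element and (b >> 4) - 8 as the odd-k element of each byte.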