From 81683380633f3c79cbf0debc0d491612b152f202 Mon Sep 17 00:00:00 2001
From: "Xia, Weiwen"
Date: Tue, 19 Mar 2024 08:44:33 +0000
Subject: [PATCH] Add CPU implementation for `torch._int_mm` (s8*s8->s32) (#121792)

Fixes #121647

**Description**
Currently, the op `torch._int_mm` only supports CUDA devices. This PR adds a CPU implementation for it.
Besides the request from the issue, this op may also be useful for planned CPU implementations of [LLM.int8()](https://arxiv.org/abs/2208.07339) in [Bitsandbytes](https://github.com/TimDettmers/bitsandbytes).

The implementation prefers mkldnn (oneDNN) kernels. If mkldnn is not available, a reference implementation with nested for loops is used.

**Test plan**
`python test/test_linalg.py -k test__int_mm_cpu`

Pull Request resolved: https://github.com/pytorch/pytorch/pull/121792
Approved by: https://github.com/jgong5, https://github.com/lezcano
---
 aten/src/ATen/native/LinearAlgebra.cpp     |  60 +++++++++++
 aten/src/ATen/native/mkldnn/Matmul.cpp     | 110 +++++++++++++++++++++
 aten/src/ATen/native/mkldnn/Matmul.h       |   6 ++
 aten/src/ATen/native/native_functions.yaml |   2 +
 test/test_linalg.py                        |  43 ++++++++
 5 files changed, 221 insertions(+)
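Illustration only, not part of the patch: a minimal usage sketch of the new CPU path, mirroring what the test plan exercises. The shapes and values below are arbitrary choices, not taken from the patch.

```python
import torch

# s8 x s8 -> s32 matmul; with this patch the op also dispatches on CPU.
a = torch.randint(-10, 10, (17, 16), dtype=torch.int8, device="cpu")
b = torch.randint(-10, 10, (16, 32), dtype=torch.int8, device="cpu")

c = torch._int_mm(a, b)  # dtype torch.int32, shape (17, 32)

# The integer result should match an fp32 matmul of the same values.
torch.testing.assert_close(c.float(), torch.mm(a.float(), b.float()))

# out= variant
out = torch.empty((17, 32), dtype=torch.int32)
torch._int_mm(a, b, out=out)
```
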
diff --git a/aten/src/ATen/native/LinearAlgebra.cpp b/aten/src/ATen/native/LinearAlgebra.cpp
index 81c7b8d941653..7ea9d74ad1b8d 100644
--- a/aten/src/ATen/native/LinearAlgebra.cpp
+++ b/aten/src/ATen/native/LinearAlgebra.cpp
@@ -31,6 +31,7 @@
 #include
 #include
 #include
+#include <ATen/native/mkldnn/Matmul.h>
 #include
 #include
 #include
@@ -3506,5 +3507,64 @@ Tensor _weight_int8pack_mm_cpu(
   return C;
 }
 
+Tensor& _int_mm_out_cpu(const Tensor& self, const Tensor& mat2, Tensor& result) {
+  static constexpr c10::string_view func_name = "int_mm_out_cpu";
+  TORCH_CHECK(self.dim() == 2, func_name, ": Expected self to be of dimension 2 but got ", self.dim());
+  TORCH_CHECK(mat2.dim() == 2, func_name, ": Expected mat2 to be of dimension 2 but got ", mat2.dim());
+  TORCH_CHECK(self.size(1) == mat2.size(0), func_name, ": self.size(1) needs to match mat2.size(0) but got ", self.size(1), " and ", mat2.size(0));
+  TORCH_CHECK(self.dtype() == at::kChar, func_name, ": Expected self dtype to be of type int8 but got ", self.dtype());
+  TORCH_CHECK(mat2.dtype() == at::kChar, func_name, ": Expected mat2 dtype to be of type int8 but got ", mat2.dtype());
+  TORCH_CHECK(result.dtype() == at::kInt, func_name, ": Expected result dtype to be of type kInt but got ", result.dtype());
+  TORCH_CHECK(result.size(0) == self.size(0), func_name, ": Expected result.size(0) to be ", self.size(0), " but got ", result.size(0));
+  TORCH_CHECK(result.size(1) == mat2.size(1), func_name, ": Expected result.size(1) to be ", mat2.size(1), " but got ", result.size(1));
+  TORCH_CHECK(result.dim() == 2, func_name, ": Expected result to be of dimension 2 but got ", result.dim());
+  TORCH_CHECK(result.is_contiguous(), func_name, ": Expected result to be contiguous.");
+
+  if (result.numel() == 0 || self.size(1) == 0) {
+    return result.zero_();
+  }
+
+  bool dispatched = false;
+  if (at::globalContext().userEnabledMkldnn()) {
+    try {
+      mkldnn_matmul_i8i8i32(self, mat2, result);
+      dispatched = true;
+    } catch (const std::exception& e) {
+      TORCH_WARN(func_name, " failed, switching to BLAS gemm: ", e.what());
+    }
+  }
+  if (!dispatched) {
+    auto a = reinterpret_cast<int8_t*>(self.data_ptr());
+    auto b = reinterpret_cast<int8_t*>(mat2.data_ptr());
+    auto c = reinterpret_cast<int32_t*>(result.data_ptr());
+    const int64_t m = result.size(0);
+    const int64_t n = result.size(1);
+    const int64_t k = self.size(1);
+    const int64_t lda_0 = self.strides()[0];
+    const int64_t lda_1 = self.strides()[1];
+    const int64_t ldb_0 = mat2.strides()[0];
+    const int64_t ldb_1 = mat2.strides()[1];
+    const int64_t ldc = result.strides()[0];
+    parallel_for(0, m * n, 1, [&](int64_t start, int64_t end) {
+      for (const auto i : c10::irange(start, end)) {
+        auto row = i / n;
+        auto col = i % n;
+        c[row * ldc + col] = 0;
+        for (const auto k : c10::irange(k)) {
+          c[row * ldc + col] = c[row * ldc + col] +
+              static_cast<int32_t>(a[row * lda_0 + k * lda_1]) *
+              static_cast<int32_t>(b[k * ldb_0 + col * ldb_1]);
+        }
+      }
+    });
+  }
+  return result;
+}
+
+Tensor _int_mm_cpu(const Tensor& self, const Tensor& mat2) {
+  Tensor result = at::empty({self.size(0), mat2.size(1)}, self.options().dtype(at::kInt));
+  return _int_mm_out_cpu(self, mat2, result);
+}
+
 } // namespace native
 } // namespace at
diff --git a/aten/src/ATen/native/mkldnn/Matmul.cpp b/aten/src/ATen/native/mkldnn/Matmul.cpp
index 227ab4dd70ba6..db02e5f3857a6 100644
--- a/aten/src/ATen/native/mkldnn/Matmul.cpp
+++ b/aten/src/ATen/native/mkldnn/Matmul.cpp
@@ -78,6 +78,13 @@ bool use_mkldnn_matmul(
   return false;
 }
 
+void mkldnn_matmul_i8i8i32(
+    const Tensor &mat1,
+    const Tensor &mat2,
+    const Tensor &result) {
+  TORCH_INTERNAL_ASSERT(false, __func__, ": ATen not compiled with MKLDNN support");
+}
+
 } // namespace native
 } // namespace at
@@ -402,6 +409,109 @@ bool use_mkldnn_matmul(
   return (use_mkldnn_bf16_matmul(mat1, mat2, result) || use_mkldnn_fp16_matmul(mat1, mat2, result) || use_mkldnn_bf32_matmul(mat1, mat2, result));
 }
 
+static void _mkldnn_matmul_i8i8i32_with_primitive(
+    const Tensor &mat1,
+    const Tensor &mat2,
+    const Tensor &result) {
+  // Create ideep tensors for oneDNN computation
+  auto src = ideep::tensor(
+      {mat1.sizes().vec(),
+       ideep::tensor::data_type::s8,
+       mat1.strides().vec()},
+      mat1.data_ptr());
+  auto wei = ideep::tensor(
+      {mat2.sizes().vec(),
+       ideep::tensor::data_type::s8,
+       mat2.strides().vec()},
+      mat2.data_ptr());
+  auto dst = ideep::tensor(
+      {result.sizes().vec(),
+       ideep::tensor::data_type::s32,
+       result.strides().vec()},
+      result.data_ptr());
+  // Create primitive desc
+  auto engine = ideep::engine::cpu_engine();
+  ideep::attr_t op_attr;
+  op_attr.set_scratchpad_mode(dnnl::scratchpad_mode::user);
+  auto src_desc = src.get_desc();
+  auto wei_desc = wei.get_desc();
+  auto dst_desc = dst.get_desc();
+  auto prim_desc = dnnl::matmul::primitive_desc(
+      engine, src_desc, wei_desc, dst_desc, op_attr);
+  // Reorder mat2 if needed
+  auto expected_weight = wei.reorder_if_differ_in(prim_desc.weights_desc());
+  // Prepare args for primitive
+  ideep::tensor scratchpad(prim_desc.scratchpad_desc());
+  ideep::exec_args args;
+  args.insert({DNNL_ARG_SRC, src});
+  args.insert({DNNL_ARG_WEIGHTS, expected_weight});
+  args.insert({DNNL_ARG_DST, dst});
+  args.insert({DNNL_ARG_SCRATCHPAD, scratchpad});
+  // Create primitive and execute
+  auto primitive = dnnl::matmul(prim_desc);
+  primitive.execute(ideep::stream::default_stream(), args);
+}
+
+static void _mkldnn_gemm_i8i8i32_with_blas(
+    const Tensor& self,
+    const Tensor& mat2,
+    const Tensor& result) {
+  const int m = result.size(0);
+  const int n = result.size(1);
+  const int k = self.size(1);
+
+  const char transa = self.strides()[1] == 1 ? 'N' : 'T';
+  const char transb = mat2.strides()[1] == 1 ? 'N' : 'T';
+  const char offsetc = 'F';
+
+  const int lda = transa == 'T' ? self.stride(1) : self.stride(0);
+  const int ldb = transb == 'T' ? mat2.stride(1) : mat2.stride(0);
+  const int ldc = n;
+
+  const float alpha = 1;
+  const float beta = 0;
+
+  int8_t ao = 0;
+  int8_t bo = 0;
+  int32_t co = 0;
+
+  dnnl::gemm_s8s8s32(
+      transa,
+      transb,
+      offsetc,
+      m,
+      n,
+      k,
+      alpha,
+      (int8_t*)self.data_ptr(),
+      lda,
+      ao,
+      (int8_t*)mat2.data_ptr(),
+      ldb,
+      bo,
+      beta,
+      (int32_t*)result.data_ptr(),
+      ldc,
+      &co);
+}
+
+void mkldnn_matmul_i8i8i32(
+    const Tensor &mat1,
+    const Tensor &mat2,
+    const Tensor &result) {
+  // x:s8 * w:s8 -> y:s32
+  // both inputs should be 2d
+  // In most cases, using the DNNL BLAS API is faster, but it requires a/b to be contiguous along one dimension
+  bool a_is_contiguous = (mat1.stride(0) == 1 || mat1.stride(1) == 1);
+  bool b_is_contiguous = (mat2.stride(0) == 1 || mat2.stride(1) == 1);
+
+  if (a_is_contiguous && b_is_contiguous) {
+    _mkldnn_gemm_i8i8i32_with_blas(mat1, mat2, result);
+  } else {
+    _mkldnn_matmul_i8i8i32_with_primitive(mat1, mat2, result);
+  }
+}
+
 } // namespace native
 } // namespace at
diff --git a/aten/src/ATen/native/mkldnn/Matmul.h b/aten/src/ATen/native/mkldnn/Matmul.h
index babcb0edb6553..d82bb310efeba 100644
--- a/aten/src/ATen/native/mkldnn/Matmul.h
+++ b/aten/src/ATen/native/mkldnn/Matmul.h
@@ -67,6 +67,12 @@ bool use_mkldnn_matmul(
     const Tensor& mat2,
     const Tensor& result);
 
+// x:s8 * w:s8 -> y:s32
+TORCH_API void mkldnn_matmul_i8i8i32(
+    const Tensor &mat1,
+    const Tensor &mat2,
+    const Tensor &result);
+
 }
 }
diff --git a/aten/src/ATen/native/native_functions.yaml b/aten/src/ATen/native/native_functions.yaml
index d3b751b8e8c18..e93de749892db 100644
--- a/aten/src/ATen/native/native_functions.yaml
+++ b/aten/src/ATen/native/native_functions.yaml
@@ -4084,10 +4084,12 @@
 - func: _int_mm(Tensor self, Tensor mat2) -> Tensor
   dispatch:
+    CPU: _int_mm_cpu
     CUDA: _int_mm_cuda
 
 - func: _int_mm.out(Tensor self, Tensor mat2, *, Tensor(a!) out) -> Tensor(a!)
   dispatch:
+    CPU: _int_mm_out_cpu
     CUDA: _int_mm_out_cuda
 
 - func: _convert_weight_to_int4pack(Tensor self, int innerKTiles) -> Tensor
diff --git a/test/test_linalg.py b/test/test_linalg.py
index 02343a552b6dd..fdae6630fbda0 100644
--- a/test/test_linalg.py
+++ b/test/test_linalg.py
@@ -5866,6 +5866,49 @@ def _gen_pair(m, k, n):
                      r"Expected result.size\(0\) to be 17 but got 16",
                      lambda: torch._int_mm(genf_int(17, 8), genf_int(8, 32), out=genf_int(16, 31).int()))
 
+    @onlyCPU
+    @parametrize("m", [0, 8, 17])
+    @parametrize("k", [0, 16, 32])
+    @parametrize("n", [16, 32])
+    @parametrize("use_transpose_a", [True, False])
+    @parametrize("use_transpose_b", [True, False])
+    @parametrize("non_contig_type", [0, 1, 2])
+    def test__int_mm_cpu(self, device, m, k, n, use_transpose_a, use_transpose_b, non_contig_type):
+        # non_contig_type:
+        # 0: the whole data buffer is contiguous (can be transposed)
+        # 1: stride of one dimension is 1, but the whole buffer is not contiguous
+        # 2: neither stride is 1
+
+        def genf_int_float(x, y, use_transpose, non_contig_type):
+            if use_transpose:
+                x, y = y, x
+            if non_contig_type != 0:
+                y = y * 2
+            x_int8 = torch.randint(-10, 10, (x, y), dtype=torch.int8, device=device)
+            x_float = x_int8.to(torch.float32)
+            if non_contig_type == 1:
+                x_int8 = x_int8[:, : y // 2]
+                x_float = x_float[:, : y // 2]
+            elif non_contig_type == 2:
+                x_int8 = x_int8[:, ::2]
+                x_float = x_float[:, ::2]
+            if use_transpose:
+                return x_int8.t(), x_float.t()
+            return x_int8, x_float
+
+        if non_contig_type != 0 and (m == 0 or k == 0):
+            return
+        a_int8, a_float = genf_int_float(m, k, use_transpose_a, non_contig_type)
+        b_int8, b_float = genf_int_float(k, n, use_transpose_b, non_contig_type)
+        c_int32 = torch._int_mm(a_int8, b_int8)
+        self.assertTrue(c_int32.dtype is torch.int32)
+        self.assertEqual(c_int32.device, torch.device(device))
+        self.assertEqual(c_int32.float(), torch.mm(a_float, b_float))
+        c_int32_result = c_int32.new_empty(c_int32.size())
+        # Check the out= variant
+        torch._int_mm(a_int8, b_int8, out=c_int32_result)
+        self.assertEqual(c_int32_result.float(), torch.mm(a_float, b_float))
+
     def _group_quantize_tensor(self, w, n_bit=4, q_group_size=16):
         assert w.dim() == 2
         w = w.transpose(0, 1).contiguous()