Add CPU implementation for torch._int_mm (s8*s8->s32) (pytorch#121792)
Fixes pytorch#121647

**Description**
Currently, the op `torch._int_mm` is only supported on CUDA devices. This PR adds a CPU implementation.
Beyond the request in the issue, this op may also be useful for planned CPU implementations of [LLM.int8()](https://arxiv.org/abs/2208.07339) in [Bitsandbytes](https://github.com/TimDettmers/bitsandbytes).

The implementation prefers mkldnn (oneDNN) kernels. If mkldnn is disabled, unavailable, or its kernel fails at runtime, a reference implementation with nested for loops is used instead.
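
For illustration, here is a minimal usage sketch of the new CPU path (not part of this PR). The shapes and dtypes follow the parametrized test below, and the float `torch.mm` comparison mirrors the test's correctness check:

```python
import torch

# int8 x int8 -> int32 matmul on CPU: (m, k) @ (k, n) -> (m, n)
a = torch.randint(-10, 10, (17, 16), dtype=torch.int8)
b = torch.randint(-10, 10, (16, 32), dtype=torch.int8)

c = torch._int_mm(a, b)          # result dtype is torch.int32
assert c.dtype == torch.int32

# out= variant: expects a contiguous int32 tensor of shape (m, n)
out = torch.empty(17, 32, dtype=torch.int32)
torch._int_mm(a, b, out=out)

# For small inputs the int32 accumulation matches a float reference exactly
torch.testing.assert_close(c.float(), torch.mm(a.float(), b.float()))
```

Disabling mkldnn (e.g. with `torch.backends.mkldnn.flags(enabled=False)`) should exercise the reference-loop fallback.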

**Test plan**
`python test/test_linalg.py -k test__int_mm_cpu`

Pull Request resolved: pytorch#121792
Approved by: https://github.com/jgong5, https://github.com/lezcano
Xia-Weiwen authored and pytorchmergebot committed Mar 19, 2024
1 parent 0d845f7 commit 8168338
Showing 5 changed files with 221 additions and 0 deletions.
60 changes: 60 additions & 0 deletions aten/src/ATen/native/LinearAlgebra.cpp
@@ -31,6 +31,7 @@
#include <ATen/ops/_addmm_activation_native.h>
#include <ATen/ops/_compute_linear_combination_native.h>
#include <ATen/ops/_convert_weight_to_int4pack_native.h>
#include <ATen/ops/_int_mm_native.h>
#include <ATen/ops/_linalg_check_errors.h>
#include <ATen/ops/_linalg_det.h>
#include <ATen/ops/_linalg_det_native.h>
@@ -3506,5 +3507,64 @@ Tensor _weight_int8pack_mm_cpu(
return C;
}

Tensor& _int_mm_out_cpu(const Tensor& self, const Tensor& mat2, Tensor& result) {
static constexpr c10::string_view func_name = "int_mm_out_cpu";
TORCH_CHECK(self.dim() == 2, func_name, ": Expected self to be of dimension 2 but got ", self.dim());
TORCH_CHECK(mat2.dim() == 2, func_name, ": Expected mat2 to be of dimension 2 but got ", mat2.dim());
TORCH_CHECK(self.size(1) == mat2.size(0), func_name, ": self.size(1) needs to match mat2.size(0) but got ", self.size(1), " and ", mat2.size(0));
TORCH_CHECK(self.dtype() == at::kChar, func_name, ": Expected self dtype to be of type int8 but got ", self.dtype());
TORCH_CHECK(mat2.dtype() == at::kChar, func_name, ": Expected mat2 dtype to be of type int8 but got ", mat2.dtype());
TORCH_CHECK(result.dtype() == at::kInt, func_name, ": Expected result dtype to be of type kInt but got ", result.dtype());
TORCH_CHECK(result.size(0) == self.size(0), func_name, ": Expected result.size(0) to be ", self.size(0), " but got ", result.size(0));
TORCH_CHECK(result.size(1) == mat2.size(1), func_name, ": Expected result.size(1) to be ", mat2.size(1), " but got ", result.size(1));
TORCH_CHECK(result.dim() == 2, func_name, ": Expected result to be of dimension 2 but got ", result.dim());
TORCH_CHECK(result.is_contiguous(), func_name, ": Expected result to be contiguous.");

if (result.numel() == 0 || self.size(1) == 0) {
return result.zero_();
}

bool dispatched = false;
if (at::globalContext().userEnabledMkldnn()) {
try {
mkldnn_matmul_i8i8i32(self, mat2, result);
dispatched = true;
} catch (const std::exception& e) {
TORCH_WARN(func_name, ": mkldnn path failed, falling back to the reference implementation: ", e.what());
}
}
if (!dispatched) {
auto a = reinterpret_cast<int8_t*>(self.data_ptr());
auto b = reinterpret_cast<int8_t*>(mat2.data_ptr());
auto c = reinterpret_cast<int32_t*>(result.data_ptr());
const int64_t m = result.size(0);
const int64_t n = result.size(1);
const int64_t k = self.size(1);
const int64_t lda_0 = self.strides()[0];
const int64_t lda_1 = self.strides()[1];
const int64_t ldb_0 = mat2.strides()[0];
const int64_t ldb_1 = mat2.strides()[1];
const int64_t ldc = result.strides()[0];
// Reference kernel: compute each output element as an int32 dot product.
parallel_for(0, m * n, 1, [&](int64_t start, int64_t end) {
for (const auto i : c10::irange(start, end)) {
auto row = i / n;
auto col = i % n;
c[row * ldc + col] = 0;
// Use a distinct index name to avoid shadowing k = self.size(1)
for (const auto kk : c10::irange(k)) {
c[row * ldc + col] = c[row * ldc + col] +
static_cast<int32_t>(a[row * lda_0 + kk * lda_1]) *
static_cast<int32_t>(b[kk * ldb_0 + col * ldb_1]);
}
}
});
}
return result;
}

Tensor _int_mm_cpu(const Tensor& self, const Tensor& mat2) {
Tensor result = at::empty({self.size(0), mat2.size(1)}, self.options().dtype(at::kInt));
return _int_mm_out_cpu(self, mat2, result);
}

} // namespace native
} // namespace at
110 changes: 110 additions & 0 deletions aten/src/ATen/native/mkldnn/Matmul.cpp
@@ -78,6 +78,13 @@ bool use_mkldnn_matmul(
return false;
}

void mkldnn_matmul_i8i8i32(
const Tensor &mat1,
const Tensor &mat2,
const Tensor &result) {
TORCH_INTERNAL_ASSERT(false, __func__, ": ATen not compiled with MKLDNN support");
}

} // namespace native
} // namespace at

@@ -402,6 +409,109 @@ bool use_mkldnn_matmul(
return (use_mkldnn_bf16_matmul(mat1, mat2, result) || use_mkldnn_fp16_matmul(mat1, mat2, result) || use_mkldnn_bf32_matmul(mat1, mat2, result));
}

static void _mkldnn_matmul_i8i8i32_with_primitive(
const Tensor &mat1,
const Tensor &mat2,
const Tensor &result) {
// Create ideep tensors for oneDNN computation
auto src = ideep::tensor(
{mat1.sizes().vec(),
ideep::tensor::data_type::s8,
mat1.strides().vec()},
mat1.data_ptr());
auto wei = ideep::tensor(
{mat2.sizes().vec(),
ideep::tensor::data_type::s8,
mat2.strides().vec()},
mat2.data_ptr());
auto dst = ideep::tensor(
{result.sizes().vec(),
ideep::tensor::data_type::s32,
result.strides().vec()},
result.data_ptr());
// Create primitive desc
auto engine = ideep::engine::cpu_engine();
ideep::attr_t op_attr;
op_attr.set_scratchpad_mode(dnnl::scratchpad_mode::user);
auto src_desc = src.get_desc();
auto wei_desc = wei.get_desc();
auto dst_desc = dst.get_desc();
auto prim_desc = dnnl::matmul::primitive_desc(
engine, src_desc, wei_desc, dst_desc, op_attr);
// Reorder mat2 if needed
auto expected_weight = wei.reorder_if_differ_in(prim_desc.weights_desc());
// Prepare args for primitive
ideep::tensor scratchpad(prim_desc.scratchpad_desc());
ideep::exec_args args;
args.insert({DNNL_ARG_SRC, src});
args.insert({DNNL_ARG_WEIGHTS, expected_weight});
args.insert({DNNL_ARG_DST, dst});
args.insert({DNNL_ARG_SCRATCHPAD, scratchpad});
// Create primitive and execute
auto primitive = dnnl::matmul(prim_desc);
primitive.execute(ideep::stream::default_stream(), args);
}

static void _mkldnn_gemm_i8i8i32_with_blas(
const Tensor& self,
const Tensor& mat2,
const Tensor& result) {
const int m = result.size(0);
const int n = result.size(1);
const int k = self.size(1);

const char transa = self.strides()[1] == 1 ? 'N' : 'T';
const char transb = mat2.strides()[1] == 1 ? 'N' : 'T';
const char offsetc = 'F';

const int lda = transa == 'T' ? self.stride(1) : self.stride(0);
const int ldb = transb == 'T' ? mat2.stride(1) : mat2.stride(0);
const int ldc = n;

const float alpha = 1;
const float beta = 0;

int8_t ao = 0;
int8_t bo = 0;
int32_t co = 0;

dnnl::gemm_s8s8s32(
transa,
transb,
offsetc,
m,
n,
k,
alpha,
(int8_t*)self.data_ptr(),
lda,
ao,
(int8_t*)mat2.data_ptr(),
ldb,
bo,
beta,
(int32_t*)result.data_ptr(),
ldc,
&co);
}

void mkldnn_matmul_i8i8i32(
const Tensor &mat1,
const Tensor &mat2,
const Tensor &result) {
// x:s8 * w:s8 -> y:s32
// Both inputs should be 2D.
// In most cases the oneDNN BLAS API is faster, but it requires mat1/mat2 to be contiguous along one dimension.
bool a_is_contiguous = (mat1.stride(0) == 1 || mat1.stride(1) == 1);
bool b_is_contiguous = (mat2.stride(0) == 1 || mat2.stride(1) == 1);

if (a_is_contiguous && b_is_contiguous) {
_mkldnn_gemm_i8i8i32_with_blas(mat1, mat2, result);
} else {
_mkldnn_matmul_i8i8i32_with_primitive(mat1, mat2, result);
}
}

} // namespace native
} // namespace at

6 changes: 6 additions & 0 deletions aten/src/ATen/native/mkldnn/Matmul.h
@@ -67,6 +67,12 @@ bool use_mkldnn_matmul(
const Tensor& mat2,
const Tensor& result);

// x:s8 * w:s8 -> y:s32
TORCH_API void mkldnn_matmul_i8i8i32(
const Tensor &mat1,
const Tensor &mat2,
const Tensor &result);

}

}
2 changes: 2 additions & 0 deletions aten/src/ATen/native/native_functions.yaml
@@ -4084,10 +4084,12 @@

- func: _int_mm(Tensor self, Tensor mat2) -> Tensor
dispatch:
CPU: _int_mm_cpu
CUDA: _int_mm_cuda

- func: _int_mm.out(Tensor self, Tensor mat2, *, Tensor(a!) out) -> Tensor(a!)
dispatch:
CPU: _int_mm_out_cpu
CUDA: _int_mm_out_cuda

- func: _convert_weight_to_int4pack(Tensor self, int innerKTiles) -> Tensor
43 changes: 43 additions & 0 deletions test/test_linalg.py
@@ -5866,6 +5866,49 @@ def _gen_pair(m, k, n):
r"Expected result.size\(0\) to be 17 but got 16",
lambda: torch._int_mm(genf_int(17, 8), genf_int(8, 32), out=genf_int(16, 31).int()))

@onlyCPU
@parametrize("m", [0, 8, 17])
@parametrize("k", [0, 16, 32])
@parametrize("n", [16, 32])
@parametrize("use_transpose_a", [True, False])
@parametrize("use_transpose_b", [True, False])
@parametrize("non_contig_type", [0, 1, 2])
def test__int_mm_cpu(self, device, m, k, n, use_transpose_a, use_transpose_b, non_contig_type):
# non_contig_type:
# 0: the whole data buffer is contiguous (can be transposed)
# 1: stride of one dimension is 1, but the whole buffer is not contiguous
# 2: Neither stride is 1

def genf_int_float(x, y, use_transpose, non_contig_type):
if use_transpose:
x, y = y, x
if non_contig_type != 0:
y = y * 2
x_int8 = torch.randint(-10, 10, (x, y), dtype=torch.int8, device=device)
x_float = x_int8.to(torch.float32)
if non_contig_type == 1:
x_int8 = x_int8[:, : y // 2]
x_float = x_float[:, : y // 2]
elif non_contig_type == 2:
x_int8 = x_int8[:, ::2]
x_float = x_float[:, ::2]
if use_transpose:
return x_int8.t(), x_float.t()
return x_int8, x_float

if non_contig_type != 0 and (m == 0 or k == 0):
return
a_int8, a_float = genf_int_float(m, k, use_transpose_a, non_contig_type)
b_int8, b_float = genf_int_float(k, n, use_transpose_b, non_contig_type)
c_int32 = torch._int_mm(a_int8, b_int8)
self.assertTrue(c_int32.dtype is torch.int32)
self.assertEqual(c_int32.device, torch.device(device))
self.assertEqual(c_int32.float(), torch.mm(a_float, b_float))
c_int32_result = c_int32.new_empty(c_int32.size())
# Checking out variant
torch._int_mm(a_int8, b_int8, out=c_int32_result)
self.assertEqual(c_int32_result.float(), torch.mm(a_float, b_float))

def _group_quantize_tensor(self, w, n_bit=4, q_group_size=16):
assert w.dim() == 2
w = w.transpose(0, 1).contiguous()
