diff --git a/aten/src/ATen/native/quantized/cpu/qlinear_dynamic.cpp b/aten/src/ATen/native/quantized/cpu/qlinear_dynamic.cpp
index 95854d14b2c25..091e309cd95d8 100644
--- a/aten/src/ATen/native/quantized/cpu/qlinear_dynamic.cpp
+++ b/aten/src/ATen/native/quantized/cpu/qlinear_dynamic.cpp
@@ -7,6 +7,7 @@
 #include
 #include
 #include
+#include
 #include
 #include
@@ -616,6 +617,92 @@ at::Tensor PackedLinearWeightsOnednn::apply_dynamic_relu(
       std::move(input), reduce_range);
 }

+static at::Tensor linear_dynamic_fp16_with_onednn_weight(
+    at::Tensor input,
+    at::Tensor onednn_weight, // fp16 tensor from MkldnnCPU
+    std::optional<at::Tensor> bias,
+    bool relu_fused) {
+  using ideep::tensor;
+  const int64_t dim = input.dim();
+  TORCH_CHECK(input.scalar_type() == c10::ScalarType::Float,
+      "onednn linear dynamic fp16: data type of input should be float.");
+  TORCH_CHECK(onednn_weight.scalar_type() == c10::ScalarType::Half,
+      "onednn linear dynamic fp16: data type of weight should be half.");
+
+  // If the input has more than two dimensions, reshape it to 2D for the
+  // computation and reshape the output back afterwards.
+  auto input_contig =
+      dim == 2 ? input.contiguous() : input.reshape({-1, input.size(dim - 1)}).contiguous();
+
+  auto src = at::native::itensor_from_tensor(input_contig);
+  auto packed_weight = at::native::itensor_from_mkldnn(onednn_weight);
+  int64_t K = input.size(dim - 1), M = input.numel() / K, N = packed_weight.get_dim(1);
+
+  auto output_size = input.sizes().vec();
+  output_size[dim - 1] = N;
+
+  std::optional<ideep::tensor> onednn_bias{std::nullopt};
+  bool with_bias = bias.has_value();
+  at::Tensor bias_val_float;
+  if (with_bias) {
+    bias_val_float = bias.value().to(at::kFloat);
+    if (bias_val_float.dim() == 1) {
+      auto b_reshape = bias_val_float.reshape({1, bias_val_float.size(0)});
+      onednn_bias = at::native::itensor_view_from_dense(b_reshape);
+    } else {
+      onednn_bias = at::native::itensor_view_from_dense(bias_val_float);
+    }
+  }
+  std::vector<int64_t> src_dims = {M, K};
+  std::vector<int64_t> dst_dims = {M, N};
+  at::Tensor output = at::empty(
+    dst_dims,
+    device(c10::kCPU)
+        .dtype(c10::kFloat)
+  );
+  if (output.numel() == 0) {
+    return output;
+  }
+  tensor dst = at::native::itensor_view_from_dense(output);
+  static tensor empty_tensor;
+  static tensor::desc empty_tensor_desc;
+
+  // Create matmul primitive
+  auto src_dtype = ideep::data_type::f32;
+  auto src_desc = tensor::desc(src_dims, src_dtype, ideep::format_tag::any);
+  // oneDNN does not support f32f16f32 matmul, so we create the primitive with an f32 weight desc;
+  // the weight is stored in f16 and reordered to f32 below by `reorder_if_differ_in`.
+  auto weights_desc = tensor::desc(packed_weight.get_dims(), ideep::data_type::f32, ideep::format_tag::any);
+  auto dst_dtype = dst.get_data_type();
+  auto dst_desc = tensor::desc(dst_dims, dst_dtype, ideep::format_tag::any);
+  auto bias_desc = with_bias ?
+      tensor::desc(onednn_bias.value().get_dims(), ideep::data_type::f32, ideep::format_tag::any) :
+      empty_tensor_desc;
+  // Get op attr for primitive
+  auto op_attr = relu_fused ? ideep::attr_t::fuse_relu() : ideep::attr_t();
+  op_attr.set_scratchpad_mode(dnnl::scratchpad_mode::user);
+  auto engine = ideep::engine::cpu_engine();
+  auto primitive_desc = with_bias ?
+      dnnl::matmul::primitive_desc(engine, src_desc, weights_desc, bias_desc, dst_desc, op_attr) :
+      dnnl::matmul::primitive_desc(engine, src_desc, weights_desc, dst_desc, op_attr);
+  auto primitive = dnnl::matmul(primitive_desc);
+
+  // Convert weight from f16 to f32 with layout changes
+  auto expected_weight = packed_weight.reorder_if_differ_in(primitive_desc.weights_desc());
+
+  // Prepare args and execute primitive
+  tensor scratchpad(primitive_desc.scratchpad_desc());
+  ideep::exec_args args;
+  args.insert({DNNL_ARG_SRC, src});
+  args.insert({DNNL_ARG_WEIGHTS, expected_weight});
+  args.insert({DNNL_ARG_DST, dst});
+  args.insert({DNNL_ARG_SCRATCHPAD, scratchpad});
+  if (with_bias) {
+    args.insert({DNNL_ARG_BIAS, onednn_bias.value()});
+  }
+  primitive.execute(ideep::stream::default_stream(), args);
+  return dim == 2 ? output : output.reshape(output_size);
+}
 #endif // #if AT_MKLDNN_ENABLED()

 namespace at::native {
@@ -786,6 +873,32 @@ at::Tensor wrapped_fbgemm_linear_fp16_weight_meta(const at::Tensor& input, const
 #endif // USE_FBGEMM
 }

+class LinearDynamicFp16Onednn final {
+ public:
+  static Tensor run(
+      Tensor act, // fp32 CPU tensor, not QTensor
+      Tensor onednn_weight, // fp16 tensor from MkldnnCPU
+      std::optional<Tensor> bias) {
+#if AT_MKLDNN_ENABLED()
+    return linear_dynamic_fp16_with_onednn_weight(
+        act, onednn_weight, bias, /*relu_fused*/false);
+#endif
+    TORCH_CHECK(false, "Unimplemented (linear_dynamic_fp16_with_onednn_weight)");
+  }
+
+  static Tensor run_relu(
+      Tensor act, // fp32 CPU tensor, not QTensor
+      Tensor onednn_weight, // fp16 tensor from MkldnnCPU
+      std::optional<Tensor> bias) {
+#if AT_MKLDNN_ENABLED()
+    return linear_dynamic_fp16_with_onednn_weight(
+        act, onednn_weight, bias, /*relu_fused*/true);
+#endif
+    TORCH_CHECK(false, "Unimplemented (linear_dynamic_fp16_with_onednn_weight)");
+  }
+
+};
+
 TORCH_LIBRARY_IMPL(quantized, CPU, m) {
   register_linear_params();
@@ -834,5 +947,11 @@ TORCH_LIBRARY_IMPL(_quantized, Meta, m) {
       wrapped_fbgemm_linear_fp16_weight_meta);
 }

+TORCH_LIBRARY_IMPL(onednn, MkldnnCPU, m) {
+  m.impl(TORCH_SELECTIVE_NAME("onednn::linear_dynamic_fp16"),
+      TORCH_FN(LinearDynamicFp16Onednn::run));
+  m.impl(TORCH_SELECTIVE_NAME("onednn::linear_relu_dynamic_fp16"),
+      TORCH_FN(LinearDynamicFp16Onednn::run_relu));
+}
 } // namespace
 } // namespace at::native
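For reference, here is a minimal Python sketch (not part of this patch) of how the two kernels registered above are intended to be driven: the weight is first packed by `onednn::linear_prepack_fp16`, which is added in the next file, and the test at the end of this diff exercises the same path. It assumes a PyTorch build with oneDNN (MKLDNN) enabled so the ops are available under `torch.ops.onednn`.

```python
# Minimal usage sketch, not part of this PR's sources.
import torch
import torch.nn.functional as F

x = torch.randn(4, 16)   # fp32 activation on CPU
w = torch.randn(8, 16)   # fp32 weight, shape (out_features, in_features)
b = torch.randn(8)

# Pack the weight: saturated to the fp16 range, cast to fp16, stored as a MkldnnCPU tensor.
w_packed = torch.ops.onednn.linear_prepack_fp16(w, x.shape)

y = torch.ops.onednn.linear_dynamic_fp16(x, w_packed, b)            # fp32 output
y_relu = torch.ops.onednn.linear_relu_dynamic_fp16(x, w_packed, b)  # fused ReLU variant

# Numerically this matches an fp32 linear whose weight was round-tripped through fp16.
ref = F.linear(x, w.half().float(), b)
torch.testing.assert_close(y, ref)
torch.testing.assert_close(y_relu, ref.relu())
```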
diff --git a/aten/src/ATen/native/quantized/cpu/qlinear_prepack.cpp b/aten/src/ATen/native/quantized/cpu/qlinear_prepack.cpp
index f4c55b2a3cfe4..d9e3d484d02d2 100644
--- a/aten/src/ATen/native/quantized/cpu/qlinear_prepack.cpp
+++ b/aten/src/ATen/native/quantized/cpu/qlinear_prepack.cpp
@@ -309,6 +309,23 @@ inline at::Tensor pack_weight_to_onednn_tensor(
   return packed_weight;
 }

+inline at::Tensor pack_weight_to_fp16_onednn_tensor(
+    at::Tensor& weight,
+    std::optional<std::vector<int64_t>>& input_shape) {
+  weight = at::_saturate_weight_to_fp16(weight);
+  std::vector<int64_t> w_dims = weight.sizes().vec();
+  auto weight_fp16 = weight.to(at::kHalf);
+  ideep::tensor wei = ideep::tensor({w_dims, dnnl::memory::data_type::f16}, weight_fp16.data_ptr());
+  auto expected_weight = wei.transpose(0, 1); // oneDNN requires transposed weight
+  // oneDNN does not support f32f16f32 matmul, so the weight is converted to f32 at compute time.
+  // Therefore, we just return the weight in plain format here.
+  auto packed_weight = at::native::new_with_itensor_mkldnn(
+      std::move(expected_weight),
+      c10::kHalf,
+      weight.options().device_opt());
+  return packed_weight;
+}
+
 #endif // #if AT_MKLDNN_ENABLED()

 namespace at::native {
@@ -672,6 +689,21 @@ class QLinearPackWeightInt8Onednn final {
   }
 };

+class QLinearPackWeightFp16Onednn final {
+ public:
+  static at::Tensor run(
+      // NOLINTNEXTLINE(performance-unnecessary-value-param)
+      [[maybe_unused]] at::Tensor weight, // Not QTensor
+      // NOLINTNEXTLINE(performance-unnecessary-value-param)
+      [[maybe_unused]] std::optional<std::vector<int64_t>> input_shape) {
+#if AT_MKLDNN_ENABLED()
+    return pack_weight_to_fp16_onednn_tensor(weight, input_shape);
+#else
+    TORCH_CHECK(false, "Unimplemented as onednn is not available.");
+#endif
+  }
+};
+
 TORCH_LIBRARY_IMPL(quantized, QuantizedCPU, m) {
   register_linear_params();
   m.impl(TORCH_SELECTIVE_NAME("quantized::linear_prepack"), TORCH_FN(QLinearPackWeightInt8::run));
@@ -716,5 +748,9 @@ TORCH_LIBRARY_IMPL(onednn, CPU, m) {
   m.impl(TORCH_SELECTIVE_NAME("onednn::qlinear_prepack"), TORCH_FN(QLinearPackWeightInt8Onednn::run));
 }

+TORCH_LIBRARY_IMPL(onednn, CPU, m) {
+  m.impl(TORCH_SELECTIVE_NAME("onednn::linear_prepack_fp16"), TORCH_FN(QLinearPackWeightFp16Onednn::run));
+}
+
 } // namespace
 } // namespace at::native
diff --git a/aten/src/ATen/native/quantized/library.cpp b/aten/src/ATen/native/quantized/library.cpp
index 05341366a9dfa..72dcda2b74de4 100644
--- a/aten/src/ATen/native/quantized/library.cpp
+++ b/aten/src/ATen/native/quantized/library.cpp
@@ -272,7 +272,11 @@ TORCH_LIBRARY(onednn, m) {
   // Linear prepack
   m.def(TORCH_SELECTIVE_SCHEMA("onednn::qlinear_prepack(Tensor weight, int[]? x_shape) -> Tensor"));
+  m.def(TORCH_SELECTIVE_SCHEMA("onednn::linear_prepack_fp16(Tensor weight, int[]? x_shape) -> Tensor"));
+  // Linear
+  m.def(TORCH_SELECTIVE_SCHEMA("onednn::linear_dynamic_fp16(Tensor x, Tensor w, Tensor? bias) -> Tensor"));
+  m.def(TORCH_SELECTIVE_SCHEMA("onednn::linear_relu_dynamic_fp16(Tensor x, Tensor w, Tensor? bias) -> Tensor"));
   // Linear with unary postop
   m.def(TORCH_SELECTIVE_SCHEMA("onednn::qlinear_pointwise(Tensor qx, float x_scale, int x_zero_point, Tensor qw, Tensor w_scale, Tensor w_zero_point, Tensor? bias, float output_scale, int output_zero_point, ScalarType? output_dtype, str post_op_name, Scalar?[] post_op_args, str post_op_algorithm) -> Tensor"));
   m.def(TORCH_SELECTIVE_SCHEMA("onednn::qlinear_pointwise.tensor(Tensor qx, Tensor x_scale, Tensor x_zero_point, Tensor qw, Tensor w_scale, Tensor w_zero_point, Tensor? bias, float output_scale, int output_zero_point, ScalarType? output_dtype, str post_op_name, Scalar?[] post_op_args, str post_op_algorithm) -> Tensor"));
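One behavior worth noting from `linear_dynamic_fp16_with_onednn_weight` above: activations with more than two dimensions are flattened to 2D for the matmul and the output is reshaped back, so the output shape is the input shape with the last dimension replaced by `out_features`. A small sketch of that path follows (again an illustration, not part of this PR; the test added below only covers 2D inputs, and passing `None` for `x_shape` relies on the schema's optional `int[]? x_shape`).

```python
# Sketch of the >2D input path; assumes a oneDNN-enabled build.
import torch
import torch.nn.functional as F

x = torch.randn(3, 5, 8)   # (batch, seq, in_features)
w = torch.randn(16, 8)     # (out_features, in_features)
b = torch.randn(16)

w_packed = torch.ops.onednn.linear_prepack_fp16(w, None)  # x_shape is optional per the schema
out = torch.ops.onednn.linear_dynamic_fp16(x, w_packed, b)

# Output keeps the leading dims and swaps the last dim for out_features.
assert out.shape == (3, 5, 16)
torch.testing.assert_close(out, F.linear(x, w.half().float(), b))
```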
diff --git a/test/quantization/core/test_quantized_op.py b/test/quantization/core/test_quantized_op.py
index f7c7330a8c991..0e419989d3560 100644
--- a/test/quantization/core/test_quantized_op.py
+++ b/test/quantization/core/test_quantized_op.py
@@ -3750,6 +3750,39 @@ def test_dynamic_convtranspose3d(self):
             return  # TODO: fix MakeDeConvOutputShape overflowing for convT3d with qnnpack
         self._test_qconv_op_impl(q_mod, dq_op, dim, dtype)

+    @skipIfNoONEDNN
+    def test_linear_dynamic_fp16_onednn(self):
+
+        options = itertools.product(
+            (2, 4),         # batch_size
+            (4, 5, 12),     # input_channels
+            (4, 7, 8),      # output_channels
+            (True, False),  # use_bias
+            (True, False),  # use_relu
+        )
+        for batch_size, input_channels, output_channels, use_bias, use_relu in options:
+            qlinear_prepack = torch.ops.onednn.linear_prepack_fp16
+            if use_relu:
+                qlinear_dynamic = torch.ops.onednn.linear_relu_dynamic_fp16
+            else:
+                qlinear_dynamic = torch.ops.onednn.linear_dynamic_fp16
+
+            x = torch.randn(batch_size, input_channels)
+            w = torch.randn(output_channels, input_channels)
+            bias = torch.randn(output_channels) if use_bias else None
+
+            w_packed = qlinear_prepack(w, x.shape)
+            out = qlinear_dynamic(x, w_packed, bias)
+
+            # qlinear_dynamic_fp16 uses FP32 activation tensors and FP16 weight tensors
+            # output is FP32
+            w_fp16 = w.to(torch.float16).to(torch.float32)
+            ref = F.linear(x, w_fp16, bias)
+            if use_relu:
+                ref.relu_()
+
+            self.assertEqual(out, ref)
+

 class TestQuantizedLinear(TestCase):
     def _test_qlinear_impl(self, batch_size, input_channels, output_channels, use_bias,