From 84ec399c028547c1d38e725cb597f6839f476e6a Mon Sep 17 00:00:00 2001 From: Chris Austen Date: Wed, 3 Apr 2024 14:24:09 -0400 Subject: [PATCH] Fix parse matmulinteger 611 (#2934) --- src/eliminate_data_type.cpp | 3 +- src/onnx/parse_matmul.cpp | 140 ++++++++++++------ src/targets/gpu/target.cpp | 2 - test/onnx/gen_onnx.py | 83 ++++++++++- ...matmulinteger_int8_uint8_dual_zp_test.onnx | 18 +++ ...linteger_int8_uint8_one_zp_error_test.onnx | 17 +++ .../matmulinteger_int8_uint8_one_zp_test.onnx | 17 +++ test/onnx/matmulinteger_int8_uint8_test.onnx | 16 ++ .../matmulinteger_invalid_type_error.onnx | Bin 0 -> 172 bytes test/onnx/matmulinteger_uns_zp_test.onnx | Bin 178 -> 178 bytes .../onnx/parse/matmulinteger_dual_zp_test.cpp | 66 +++++++++ .../matmulinteger_invalid_type_error.cpp | 30 ++++ test/onnx/parse/matmulinteger_one_zp_test.cpp | 62 ++++++++ .../matmulinteger_zp_mismatch_error_test.cpp | 31 ++++ .../matmulinteger_int8_uint8_dual_zp_test.cpp | 48 ++++++ .../matmulinteger_int8_uint8_one_zp_test.cpp | 48 ++++++ .../verify/matmulinteger_int8_uint8_test.cpp | 48 ++++++ .../onnx/verify/matmulinteger_uns_zp_test.cpp | 2 +- 18 files changed, 583 insertions(+), 48 deletions(-) create mode 100644 test/onnx/matmulinteger_int8_uint8_dual_zp_test.onnx create mode 100644 test/onnx/matmulinteger_int8_uint8_one_zp_error_test.onnx create mode 100644 test/onnx/matmulinteger_int8_uint8_one_zp_test.onnx create mode 100644 test/onnx/matmulinteger_int8_uint8_test.onnx create mode 100644 test/onnx/matmulinteger_invalid_type_error.onnx create mode 100644 test/onnx/parse/matmulinteger_dual_zp_test.cpp create mode 100644 test/onnx/parse/matmulinteger_invalid_type_error.cpp create mode 100644 test/onnx/parse/matmulinteger_one_zp_test.cpp create mode 100644 test/onnx/parse/matmulinteger_zp_mismatch_error_test.cpp create mode 100644 test/onnx/verify/matmulinteger_int8_uint8_dual_zp_test.cpp create mode 100644 test/onnx/verify/matmulinteger_int8_uint8_one_zp_test.cpp create mode 100644 test/onnx/verify/matmulinteger_int8_uint8_test.cpp diff --git a/src/eliminate_data_type.cpp b/src/eliminate_data_type.cpp index fc8144489c4..28d0a831e2f 100644 --- a/src/eliminate_data_type.cpp +++ b/src/eliminate_data_type.cpp @@ -108,7 +108,8 @@ void eliminate_data_type::apply(module& m) const "scatternd_add", "scatternd_mul", "scatternd_none", - "select_module"}; + "select_module", + "quantizelinear"}; if(unsupported_types.empty()) return; diff --git a/src/onnx/parse_matmul.cpp b/src/onnx/parse_matmul.cpp index d1fdbe8e2c4..fa6dd8e0bf0 100644 --- a/src/onnx/parse_matmul.cpp +++ b/src/onnx/parse_matmul.cpp @@ -38,6 +38,77 @@ struct parse_matmul : op_parser return {{"MatMul", "dot"}, {"MatMulInteger", "quant_dot"}}; } + static void broadcast_dimensions(const onnx_parser::node_info& info, + const std::vector& s0_lens, + const std::vector& s1_lens, + const instruction_ref& a0, + const instruction_ref& a1, + instruction_ref& ba0, + instruction_ref& ba1) + { + // try broadcasting if dimensions other than last two do not match + if(not std::equal( + s0_lens.rbegin() + 2, s0_lens.rend(), s1_lens.rbegin() + 2, s1_lens.rend())) + { + auto l0_it = s0_lens.begin() + s0_lens.size() - 2; + std::vector l0_broadcasted_lens(s0_lens.begin(), l0_it); + auto l1_it = s1_lens.begin() + s1_lens.size() - 2; + std::vector l1_broadcasted_lens(s1_lens.begin(), l1_it); + auto output_lens = compute_broadcasted_lens(l0_broadcasted_lens, l1_broadcasted_lens); + l0_broadcasted_lens = output_lens; + l0_broadcasted_lens.insert(l0_broadcasted_lens.end(), l0_it, s0_lens.end()); + l1_broadcasted_lens = output_lens; + l1_broadcasted_lens.insert(l1_broadcasted_lens.end(), l1_it, s1_lens.end()); + if(s0_lens != l0_broadcasted_lens) + { + ba0 = info.add_instruction( + make_op("multibroadcast", {{"out_lens", l0_broadcasted_lens}}), a0); + } + if(s1_lens != l1_broadcasted_lens) + { + ba1 = info.add_instruction( + make_op("multibroadcast", {{"out_lens", l1_broadcasted_lens}}), a1); + } + } + } + + // Convert to int16 prior to a shift to ensure we preserve accuracy here then + // convert back to int8 + static instruction_ref add_int8_shift(const onnx_parser::node_info& info, + instruction_ref& unshifted_input) + { + auto int8_shift = info.add_literal( + migraphx::literal{migraphx::shape{migraphx::shape::int16_type}, {-128}}); + + auto unshifted_input_int16 = info.add_instruction( + migraphx::make_op("convert", {{"target_type", migraphx::shape::int16_type}}), + unshifted_input); + + auto input_shifted_int16 = info.add_common_op("add", unshifted_input_int16, int8_shift); + + return info.add_instruction( + migraphx::make_op("convert", {{"target_type", migraphx::shape::int8_type}}), + input_shifted_int16); + } + + static instruction_ref set_bias_arg(const onnx_parser::node_info& info, + const std::vector& args, + const int index, + const instruction_ref& input) + { + if(args.size() > index) + { + instruction_ref bias_arg = args[index]; + if(bias_arg->get_shape().type() != input->get_shape().type()) + { + MIGRAPHX_THROW("PARSE_QUANT_DOT: zero point must be the same type as data"); + } + + return info.add_common_op("sub", input, bias_arg); + } + return input; + } + instruction_ref parse(const op_desc& opd, const onnx_parser& /*parser*/, const onnx_parser::node_info& info, @@ -85,55 +156,40 @@ struct parse_matmul : op_parser { auto s0_lens = a0->get_shape().lens(); auto s1_lens = a1->get_shape().lens(); - instruction_ref ba0 = a0; - instruction_ref ba1 = a1; - // try broadcasting if dimensions other than last two do not match - if(not std::equal( - s0_lens.rbegin() + 2, s0_lens.rend(), s1_lens.rbegin() + 2, s1_lens.rend())) + + if(not is_quant_dot and args.size() > 2) { - auto l0_it = s0_lens.begin() + s0_lens.size() - 2; - std::vector l0_broadcasted_lens(s0_lens.begin(), l0_it); - auto l1_it = s1_lens.begin() + s1_lens.size() - 2; - std::vector l1_broadcasted_lens(s1_lens.begin(), l1_it); - auto output_lens = - compute_broadcasted_lens(l0_broadcasted_lens, l1_broadcasted_lens); - l0_broadcasted_lens = output_lens; - l0_broadcasted_lens.insert(l0_broadcasted_lens.end(), l0_it, s0_lens.end()); - l1_broadcasted_lens = output_lens; - l1_broadcasted_lens.insert(l1_broadcasted_lens.end(), l1_it, s1_lens.end()); - if(s0_lens != l0_broadcasted_lens) - { - ba0 = info.add_instruction( - make_op("multibroadcast", {{"out_lens", l0_broadcasted_lens}}), a0); - } - if(s1_lens != l1_broadcasted_lens) - { - ba1 = info.add_instruction( - make_op("multibroadcast", {{"out_lens", l1_broadcasted_lens}}), a1); - } + MIGRAPHX_THROW("PARSE_MATMUL: Bias Args not supported for MatMul"); } - // parse a_zero_point and b_zero_point values - if(args.size() > 2) + instruction_ref ba0 = set_bias_arg(info, args, 2, a0); + instruction_ref ba1 = set_bias_arg(info, args, 3, a1); + + // Only INT8 or UINT8 type currently supported + std::set supported_types = {migraphx::shape::uint8_type, + migraphx::shape::int8_type}; + const auto ba0_type = ba0->get_shape().type(); + const auto ba1_type = ba1->get_shape().type(); + + if(is_quant_dot and + (not contains(supported_types, ba0_type) or not contains(supported_types, ba1_type))) { - ba0 = info.add_instruction( - make_op("convert", {{"target_type", migraphx::shape::float_type}}), ba0); - - ba0 = info.add_common_op("sub", ba0, args[2]); - if(args.size() > 3) - { - ba1 = info.add_instruction( - make_op("convert", {{"target_type", migraphx::shape::float_type}}), ba1); - ba1 = info.add_common_op("sub", ba1, args[3]); - } - dot_res = info.add_instruction(make_op("dot"), ba0, ba1); - dot_res = info.add_instruction( - make_op("convert", {{"target_type", migraphx::shape::int32_type}}), dot_res); + MIGRAPHX_THROW("PARSE_MATMULINTEGER: Unsupported type"); } - else + + auto is_same_type = (ba0_type == ba1_type); + + if(is_quant_dot and not is_same_type) { - dot_res = info.add_instruction(make_op(opd.op_name), ba0, ba1); + if(ba0_type == migraphx::shape::uint8_type) + ba0 = add_int8_shift(info, ba0); + + if(ba1_type == migraphx::shape::uint8_type) + ba1 = add_int8_shift(info, ba1); } + + broadcast_dimensions(info, s0_lens, s1_lens, a0, a1, ba0, ba1); + dot_res = info.add_instruction(make_op(opd.op_name), ba0, ba1); } // squeeze the appended or prepended dimensions diff --git a/src/targets/gpu/target.cpp b/src/targets/gpu/target.cpp index 2cf139aceee..b8f5ba7f38b 100644 --- a/src/targets/gpu/target.cpp +++ b/src/targets/gpu/target.cpp @@ -91,7 +91,6 @@ std::vector target::get_passes(migraphx::context& gctx, const compile_opti unsupported_types.erase(shape::type_t::half_type); unsupported_types.erase(shape::type_t::bool_type); unsupported_types.erase(shape::type_t::int8_type); - unsupported_types.erase(shape::type_t::uint8_type); unsupported_types.erase(shape::type_t::int32_type); unsupported_types.erase(shape::type_t::tuple_type); // whiltelist supported Ops for the FP8 @@ -131,7 +130,6 @@ std::vector target::get_passes(migraphx::context& gctx, const compile_opti enable_pass(not mlir_enabled(), rewrite_quantization{}), dead_code_elimination{}, // workaround for rocBLAS unsupported error when using uint8 in quant_dot - eliminate_data_type{{migraphx::shape::uint8_type}, shape::float_type, {"quant_dot"}}, eliminate_data_type{unsupported_types, shape::type_t::float_type}, simplify_reshapes{}, eliminate_identity{}, diff --git a/test/onnx/gen_onnx.py b/test/onnx/gen_onnx.py index db5ea150ab5..9b2903d0a32 100644 --- a/test/onnx/gen_onnx.py +++ b/test/onnx/gen_onnx.py @@ -5130,6 +5130,21 @@ def matmulinteger_dyn_error(): return ([node], [m1, m2], [y]) +@onnx_test() +def matmulinteger_invalid_type_error(): + m1 = helper.make_tensor_value_info('1', TensorProto.INT8, [None, 6, 16]) + m2 = helper.make_tensor_value_info('2', TensorProto.INT16, [None, 16, 8]) + y = helper.make_tensor_value_info('y', TensorProto.INT32, [None, 6, 8]) + + node = onnx.helper.make_node( + 'MatMulInteger', + inputs=['1', '2'], + outputs=['y'], + ) + + return ([node], [m1, m2], [y]) + + @onnx_test() def matmulinteger_uns_test(): m1 = helper.make_tensor_value_info('1', TensorProto.UINT8, [4, 3]) @@ -5145,12 +5160,76 @@ def matmulinteger_uns_test(): return ([node], [m1, m2], [y]) +@onnx_test() +def matmulinteger_int8_uint8_test(): + m1 = helper.make_tensor_value_info('1', TensorProto.INT8, [4, 3]) + m2 = helper.make_tensor_value_info('2', TensorProto.UINT8, [3, 2]) + y = helper.make_tensor_value_info('y', TensorProto.INT32, [4, 2]) + + node = onnx.helper.make_node( + 'MatMulInteger', + inputs=['1', '2'], + outputs=['y'], + ) + + return ([node], [m1, m2], [y]) + + @onnx_test() def matmulinteger_uns_zp_test(): m1 = helper.make_tensor_value_info('1', TensorProto.UINT8, [4, 3]) m2 = helper.make_tensor_value_info('2', TensorProto.UINT8, [3, 2]) - zp1 = helper.make_tensor('3', TensorProto.UINT8, [], [12]) - zp2 = helper.make_tensor('4', TensorProto.UINT8, [], [0]) + zp1 = helper.make_tensor('3', TensorProto.UINT8, [], [0]) + zp2 = helper.make_tensor('4', TensorProto.UINT8, [], [1]) + y = helper.make_tensor_value_info('y', TensorProto.INT32, [4, 2]) + + node = onnx.helper.make_node( + 'MatMulInteger', + inputs=['1', '2', '3', '4'], + outputs=['y'], + ) + + return ([node], [m1, m2], [y], [zp1, zp2]) + + +@onnx_test() +def matmulinteger_int8_uint8_one_zp_test(): + m1 = helper.make_tensor_value_info('1', TensorProto.INT8, [4, 3]) + m2 = helper.make_tensor_value_info('2', TensorProto.UINT8, [3, 2]) + zp1 = helper.make_tensor('3', TensorProto.INT8, [], [5]) + y = helper.make_tensor_value_info('y', TensorProto.INT32, [4, 2]) + + node = onnx.helper.make_node( + 'MatMulInteger', + inputs=['1', '2', '3'], + outputs=['y'], + ) + + return ([node], [m1, m2], [y], [zp1]) + + +@onnx_test() +def matmulinteger_int8_uint8_one_zp_error_test(): + m1 = helper.make_tensor_value_info('1', TensorProto.UINT8, [4, 3]) + m2 = helper.make_tensor_value_info('2', TensorProto.UINT8, [3, 2]) + zp1 = helper.make_tensor('3', TensorProto.INT8, [], [5]) + y = helper.make_tensor_value_info('y', TensorProto.INT32, [4, 2]) + + node = onnx.helper.make_node( + 'MatMulInteger', + inputs=['1', '2', '3'], + outputs=['y'], + ) + + return ([node], [m1, m2], [y], [zp1]) + + +@onnx_test() +def matmulinteger_int8_uint8_dual_zp_test(): + m1 = helper.make_tensor_value_info('1', TensorProto.INT8, [4, 3]) + m2 = helper.make_tensor_value_info('2', TensorProto.UINT8, [3, 2]) + zp1 = helper.make_tensor('3', TensorProto.INT8, [], [1]) + zp2 = helper.make_tensor('4', TensorProto.UINT8, [], [1]) y = helper.make_tensor_value_info('y', TensorProto.INT32, [4, 2]) node = onnx.helper.make_node( diff --git a/test/onnx/matmulinteger_int8_uint8_dual_zp_test.onnx b/test/onnx/matmulinteger_int8_uint8_dual_zp_test.onnx new file mode 100644 index 00000000000..bd551d70b6a --- /dev/null +++ b/test/onnx/matmulinteger_int8_uint8_dual_zp_test.onnx @@ -0,0 +1,18 @@ + %matmulinteger_int8_uint8_dual_zp_test:š + +1 +2 +3 +4y" MatMulInteger%matmulinteger_int8_uint8_dual_zp_test**B3**B4Z +1 +  + +Z +2 +  + +b +y +  + +B \ No newline at end of file diff --git a/test/onnx/matmulinteger_int8_uint8_one_zp_error_test.onnx b/test/onnx/matmulinteger_int8_uint8_one_zp_error_test.onnx new file mode 100644 index 00000000000..0df62a14e4b --- /dev/null +++ b/test/onnx/matmulinteger_int8_uint8_one_zp_error_test.onnx @@ -0,0 +1,17 @@ + *matmulinteger_int8_uint8_one_zp_error_test:’ + +1 +2 +3y" MatMulInteger*matmulinteger_int8_uint8_one_zp_error_test**B3Z +1 +  + +Z +2 +  + +b +y +  + +B \ No newline at end of file diff --git a/test/onnx/matmulinteger_int8_uint8_one_zp_test.onnx b/test/onnx/matmulinteger_int8_uint8_one_zp_test.onnx new file mode 100644 index 00000000000..851736a0c4d --- /dev/null +++ b/test/onnx/matmulinteger_int8_uint8_one_zp_test.onnx @@ -0,0 +1,17 @@ + $matmulinteger_int8_uint8_one_zp_test:Œ + +1 +2 +3y" MatMulInteger$matmulinteger_int8_uint8_one_zp_test**B3Z +1 +  + +Z +2 +  + +b +y +  + +B \ No newline at end of file diff --git a/test/onnx/matmulinteger_int8_uint8_test.onnx b/test/onnx/matmulinteger_int8_uint8_test.onnx new file mode 100644 index 00000000000..4c8e3f074fe --- /dev/null +++ b/test/onnx/matmulinteger_int8_uint8_test.onnx @@ -0,0 +1,16 @@ + matmulinteger_int8_uint8_test:x + +1 +2y" MatMulIntegermatmulinteger_int8_uint8_testZ +1 +  + +Z +2 +  + +b +y +  + +B \ No newline at end of file diff --git a/test/onnx/matmulinteger_invalid_type_error.onnx b/test/onnx/matmulinteger_invalid_type_error.onnx new file mode 100644 index 0000000000000000000000000000000000000000..3715de39176a99d4c9faa5a90ac8935a0c4edd6b GIT binary patch literal 172 zcmdNBO;0U~&&(@J%*jlNFR3g@jZZBq$}h5NWaN_IVl?DpG!kO0 zRO0nbEb%SP@r0-+tUXE;s8>jUi;sg@h>MGXi;05`hy_5xMqpu9sIUM~m?H@&TnQFt MLsHG*#3UdL0Q)p2H2?qr literal 0 HcmV?d00001 diff --git a/test/onnx/matmulinteger_uns_zp_test.onnx b/test/onnx/matmulinteger_uns_zp_test.onnx index bdc96166b2a5919111b5fb2002aaf71ec171e2de..ead7f229d6d00fb841376481c2d8651a335341ef 100644 GIT binary patch delta 21 ccmdnQxQTH>A~%B*qp=o;0FxFYA~%l{qp=o;0FxFY!^9jn065zOmjD0& diff --git a/test/onnx/parse/matmulinteger_dual_zp_test.cpp b/test/onnx/parse/matmulinteger_dual_zp_test.cpp new file mode 100644 index 00000000000..733dcde92a2 --- /dev/null +++ b/test/onnx/parse/matmulinteger_dual_zp_test.cpp @@ -0,0 +1,66 @@ +/* + * The MIT License (MIT) + * + * Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN + * THE SOFTWARE. + */ + +#include + +TEST_CASE(matmulinteger_dual_zp_test) +{ + migraphx::program p; + auto* mm = p.get_main_module(); + + auto l2 = mm->add_literal( + migraphx::literal(migraphx::shape{migraphx::shape::int8_type, {1}, {0}}, {1})); + auto l3 = mm->add_literal( + migraphx::literal(migraphx::shape{migraphx::shape::uint8_type, {1}, {0}}, {1})); + auto l0 = mm->add_parameter("1", migraphx::shape{migraphx::shape::int8_type, {4, 3}}); + auto l1 = mm->add_parameter("2", migraphx::shape{migraphx::shape::uint8_type, {3, 2}}); + + auto mbr1 = + mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {4, 3}}}), l2); + auto sub1 = mm->add_instruction(migraphx::make_op("sub"), l0, mbr1); + auto mbr2 = + mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {3, 2}}}), l3); + auto sub2 = mm->add_instruction(migraphx::make_op("sub"), l1, mbr2); + + auto int8_shift = + mm->add_literal(migraphx::literal{migraphx::shape{migraphx::shape::int16_type}, {-128}}); + + auto unshifted_input_int16 = mm->add_instruction( + migraphx::make_op("convert", {{"target_type", migraphx::shape::int16_type}}), sub2); + + auto mbr3 = mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {3, 2}}}), + int8_shift); + auto input_shifted_int16 = + mm->add_instruction(migraphx::make_op("add"), unshifted_input_int16, mbr3); + + l1 = mm->add_instruction( + migraphx::make_op("convert", {{"target_type", migraphx::shape::int8_type}}), + input_shifted_int16); + + mm->add_instruction(migraphx::make_op("quant_dot"), sub1, l1); + + auto prog = optimize_onnx("matmulinteger_int8_uint8_dual_zp_test.onnx"); + + EXPECT(p == prog); +} diff --git a/test/onnx/parse/matmulinteger_invalid_type_error.cpp b/test/onnx/parse/matmulinteger_invalid_type_error.cpp new file mode 100644 index 00000000000..bcf106d026e --- /dev/null +++ b/test/onnx/parse/matmulinteger_invalid_type_error.cpp @@ -0,0 +1,30 @@ +/* + * The MIT License (MIT) + * + * Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN + * THE SOFTWARE. + */ + +#include + +TEST_CASE(matmulinteger_type_error_test) +{ + EXPECT(test::throws([&] { migraphx::parse_onnx("matmulinteger_invalid_type_error.onnx"); })); +} diff --git a/test/onnx/parse/matmulinteger_one_zp_test.cpp b/test/onnx/parse/matmulinteger_one_zp_test.cpp new file mode 100644 index 00000000000..6933308971e --- /dev/null +++ b/test/onnx/parse/matmulinteger_one_zp_test.cpp @@ -0,0 +1,62 @@ +/* + * The MIT License (MIT) + * + * Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN + * THE SOFTWARE. + */ + +#include + +TEST_CASE(matmulinteger_one_zp_test) +{ + migraphx::program p; + auto* mm = p.get_main_module(); + + auto l2 = mm->add_literal( + migraphx::literal(migraphx::shape{migraphx::shape::int8_type, {1}, {0}}, {5})); + + auto l0 = mm->add_parameter("1", migraphx::shape{migraphx::shape::int8_type, {4, 3}}); + auto l1 = mm->add_parameter("2", migraphx::shape{migraphx::shape::uint8_type, {3, 2}}); + + auto mb1 = mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {4, 3}}}), l2); + auto sub = mm->add_instruction(migraphx::make_op("sub"), l0, mb1); + + auto int8_shift = + mm->add_literal(migraphx::literal{migraphx::shape{migraphx::shape::int16_type}, {-128}}); + + auto unshifted_input_int16 = mm->add_instruction( + migraphx::make_op("convert", {{"target_type", migraphx::shape::int16_type}}), l1); + + auto mbr3 = mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {3, 2}}}), + int8_shift); + + auto input_shifted_int16 = + mm->add_instruction(migraphx::make_op("add"), unshifted_input_int16, mbr3); + + l1 = mm->add_instruction( + migraphx::make_op("convert", {{"target_type", migraphx::shape::int8_type}}), + input_shifted_int16); + + mm->add_instruction(migraphx::make_op("quant_dot"), sub, l1); + + auto prog = optimize_onnx("matmulinteger_int8_uint8_one_zp_test.onnx"); + + EXPECT(p == prog); +} diff --git a/test/onnx/parse/matmulinteger_zp_mismatch_error_test.cpp b/test/onnx/parse/matmulinteger_zp_mismatch_error_test.cpp new file mode 100644 index 00000000000..688cf68385b --- /dev/null +++ b/test/onnx/parse/matmulinteger_zp_mismatch_error_test.cpp @@ -0,0 +1,31 @@ +/* + * The MIT License (MIT) + * + * Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN + * THE SOFTWARE. + */ + +#include + +TEST_CASE(matmulinteger_zp_mismatch_error_test) +{ + EXPECT(test::throws( + [&] { migraphx::parse_onnx("matmulinteger_invalid_int8_uint8_one_error_test.onnx"); })); +} diff --git a/test/onnx/verify/matmulinteger_int8_uint8_dual_zp_test.cpp b/test/onnx/verify/matmulinteger_int8_uint8_dual_zp_test.cpp new file mode 100644 index 00000000000..e20138c73d2 --- /dev/null +++ b/test/onnx/verify/matmulinteger_int8_uint8_dual_zp_test.cpp @@ -0,0 +1,48 @@ +/* + * The MIT License (MIT) + * + * Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN + * THE SOFTWARE. + */ + +#include +#include +#include + +TEST_CASE(matmulinteger_int8_uint8_dual_zp_test) +{ + migraphx::program p = migraphx::parse_onnx("matmulinteger_int8_uint8_dual_zp_test.onnx"); + p.compile(migraphx::make_target("ref")); + + migraphx::shape s0{migraphx::shape::int8_type, {4, 3}}; + std::vector data0 = {-1, 5, -9, -2, 6, 10, -3, 7, -11, -4, 8, 0}; + migraphx::shape s1{migraphx::shape::uint8_type, {3, 2}}; + std::vector data1 = {128, 129, 126, 131, 124, 133}; + + migraphx::parameter_map pp; + pp["1"] = migraphx::argument(s0, data0.data()); + pp["2"] = migraphx::argument(s1, data1.data()); + + auto result = p.eval(pp).back(); + std::vector result_vector; + result.visit([&](auto output) { result_vector.assign(output.begin(), output.end()); }); + std::vector gold = {40, -32, -57, 46, 46, -36, -11, 10}; + EXPECT(migraphx::verify::verify_rms_range(result_vector, gold)); +} diff --git a/test/onnx/verify/matmulinteger_int8_uint8_one_zp_test.cpp b/test/onnx/verify/matmulinteger_int8_uint8_one_zp_test.cpp new file mode 100644 index 00000000000..6777df675c4 --- /dev/null +++ b/test/onnx/verify/matmulinteger_int8_uint8_one_zp_test.cpp @@ -0,0 +1,48 @@ +/* + * The MIT License (MIT) + * + * Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN + * THE SOFTWARE. + */ + +#include +#include +#include + +TEST_CASE(matmulinteger_int8_uint8_one_zp_test) +{ + migraphx::program p = migraphx::parse_onnx("matmulinteger_int8_uint8_one_zp_test.onnx"); + p.compile(migraphx::make_target("ref")); + + migraphx::shape s0{migraphx::shape::int8_type, {4, 3}}; + std::vector data0 = {-1, 5, -9, -2, 6, 10, -3, 7, -11, -4, 8, 0}; + migraphx::shape s1{migraphx::shape::uint8_type, {3, 2}}; + std::vector data1 = {128, 129, 126, 131, 124, 133}; + + migraphx::parameter_map pp; + pp["1"] = migraphx::argument(s0, data0.data()); + pp["2"] = migraphx::argument(s1, data1.data()); + + auto result = p.eval(pp).back(); + std::vector result_vector; + result.visit([&](auto output) { result_vector.assign(output.begin(), output.end()); }); + std::vector gold = {56, -76, -22, 21, 60, -82, 14, -25}; + EXPECT(migraphx::verify::verify_rms_range(result_vector, gold)); +} diff --git a/test/onnx/verify/matmulinteger_int8_uint8_test.cpp b/test/onnx/verify/matmulinteger_int8_uint8_test.cpp new file mode 100644 index 00000000000..5448942d988 --- /dev/null +++ b/test/onnx/verify/matmulinteger_int8_uint8_test.cpp @@ -0,0 +1,48 @@ +/* + * The MIT License (MIT) + * + * Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN + * THE SOFTWARE. + */ + +#include +#include +#include + +TEST_CASE(matmulinteger_int8_uint8_test) +{ + migraphx::program p = migraphx::parse_onnx("matmulinteger_int8_uint8_test.onnx"); + p.compile(migraphx::make_target("ref")); + + migraphx::shape s0{migraphx::shape::int8_type, {4, 3}}; + std::vector data0 = {-1, 5, -9, -2, 6, 10, -3, 7, -11, -4, 8, 0}; + migraphx::shape s1{migraphx::shape::uint8_type, {3, 2}}; + std::vector data1 = {128, 129, 126, 131, 124, 133}; + + migraphx::parameter_map pp; + pp["1"] = migraphx::argument(s0, data0.data()); + pp["2"] = migraphx::argument(s1, data1.data()); + + auto result = p.eval(pp).back(); + std::vector result_vector; + result.visit([&](auto output) { result_vector.assign(output.begin(), output.end()); }); + std::vector gold = {26, -31, -52, 66, 30, -37, -16, 20}; + EXPECT(migraphx::verify::verify_rms_range(result_vector, gold)); +} diff --git a/test/onnx/verify/matmulinteger_uns_zp_test.cpp b/test/onnx/verify/matmulinteger_uns_zp_test.cpp index 86e06340fa9..0eae90d71cd 100644 --- a/test/onnx/verify/matmulinteger_uns_zp_test.cpp +++ b/test/onnx/verify/matmulinteger_uns_zp_test.cpp @@ -43,6 +43,6 @@ TEST_CASE(matmulinteger_uns_zp_test) auto result = p.eval(pp).back(); std::vector result_vector; result.visit([&](auto output) { result_vector.assign(output.begin(), output.end()); }); - std::vector gold = {-38, -83, -44, -98, -50, -113, -56, -128}; + std::vector gold = {13, 76, 10, 64, 7, 52, 4, 40}; EXPECT(migraphx::verify::verify_rms_range(result_vector, gold)); }