From 0039b11a7db6713c9df7063945c2cfb703fd0f8c Mon Sep 17 00:00:00 2001 From: shivadbhavsar <105248561+shivadbhavsar@users.noreply.github.com> Date: Wed, 15 Nov 2023 11:12:16 -0800 Subject: [PATCH] Support per-axis quantization (#2390) Reworked the simplify_qdq pass to support: Per-axis quantization (ie. allow 1D scales and zero points) Allow broadcast and transpose ops between dq and quant_op --- src/include/migraphx/matcher.hpp | 19 ++ src/simplify_qdq.cpp | 138 +++++++--- test/quantization.cpp | 140 +++++----- test/simplify_qdq_test.cpp | 433 ++++++++++++++++++++++++++----- 4 files changed, 555 insertions(+), 175 deletions(-) diff --git a/src/include/migraphx/matcher.hpp b/src/include/migraphx/matcher.hpp index 5237dc9c5c6..a546d1b25e9 100644 --- a/src/include/migraphx/matcher.hpp +++ b/src/include/migraphx/matcher.hpp @@ -591,6 +591,19 @@ MIGRAPHX_PRED_MATCHER(same_input_shapes, instruction_ref ins) ins->inputs().begin(), ins->inputs().end(), [&](auto x) { return x->get_shape() == s; }); } +MIGRAPHX_PRED_MATCHER(has_same_value, instruction_ref ins) +{ + if(ins->name() != "@literal") + return false; + bool all_same = false; + ins->get_literal().visit([&](auto s) { + all_same = std::all_of(s.begin() + 1, s.end(), [&](const auto& scale) { + return float_equal(scale, s.front()); + }); + }); + return all_same; +} + MIGRAPHX_BASIC_MATCHER(output, const matcher_context&, instruction_ref ins) { if(ins->outputs().size() == 1) @@ -844,6 +857,12 @@ auto skip_broadcasts_converts(Ms... ms) return skip(name("broadcast", "multibroadcast", "contiguous", "convert"))(ms...); } +template +auto skip_broadcasts_transposes_contiguous(Ms... ms) +{ + return skip(name("broadcast", "multibroadcast", "contiguous", "transpose"))(ms...); +} + template inline auto has_value(T x, float tolerance = 1e-6) { diff --git a/src/simplify_qdq.cpp b/src/simplify_qdq.cpp index 21d691f7d41..b054bc9fc62 100644 --- a/src/simplify_qdq.cpp +++ b/src/simplify_qdq.cpp @@ -45,77 +45,145 @@ std::unordered_set get_quantizable_op_names() return s; } -MIGRAPHX_PRED_MATCHER(has_same_value, instruction_ref ins) +struct match_find_quantizable_ops { - if(ins->name() != "@literal") - return false; - bool all_same = false; - ins->get_literal().visit([&](auto s) { - all_same = std::all_of(s.begin() + 1, s.end(), [&](const auto& scale) { - return float_equal(scale, s.front()); + static bool + is_valid_scale(instruction_ref scale, std::vector lens, std::size_t axis) + { + return scale->get_shape().scalar() or scale->get_shape().elements() == lens.at(axis); + } + + static bool is_valid_zero_point(instruction_ref zp) + { + if(not zp->can_eval()) + return false; + + bool all_zeros = false; + zp->eval().visit([&](auto z) { + all_zeros = + std::all_of(z.begin(), z.end(), [&](auto val) { return float_equal(val, 0); }); }); - }); - return all_same; -} + return all_zeros; + } -struct match_find_quantizable_ops -{ + static auto + scale_broadcast_op(instruction_ref scale, std::vector lens, std::size_t axis) + { + if(scale->get_shape().scalar()) + { + return migraphx::make_op("multibroadcast", {{"out_lens", lens}}); + } + else + { + return migraphx::make_op("broadcast", {{"out_lens", lens}, {"axis", axis}}); + } + } - static auto dequantizelinear_op(const std::string& name, const std::string& scale) + // Helper function to insert quantized versions of any broadcasts and transpose ops that + // occur between dequantizelinear and the quantized op + static auto + propagate_quantized_ins(module& m, const instruction_ref dqins, const instruction_ref qop) + { + auto qinp = dqins->inputs().front(); + auto next_ins = dqins; + + while(next_ins != qop) + { + if(next_ins->name() != "dequantizelinear") + { + qinp = m.insert_instruction(qop, next_ins->get_operator(), qinp); + } + next_ins = next_ins->outputs().front(); + } + return qinp; + } + + static auto dequantizelinear_op(const std::string& scale, const std::string& zp) { return match::name("dequantizelinear")( - match::arg(0)(match::skip(match::name("quantizelinear"))(match::any().bind(name))), - match::arg(1)(match::skip_broadcasts(has_same_value().bind(scale))), - match::arg(2)(match::skip_broadcasts(match::all_of(match::has_value(0))))); + match::arg(0)(match::skip(match::name("quantizelinear"))(match::any())), + match::arg(1)(match::skip_broadcasts(match::is_constant().bind(scale))), + match::arg(2)(match::skip_broadcasts(match::is_constant().bind(zp)))); } auto matcher() const { return match::name(get_quantizable_op_names())( - match::arg(0)(dequantizelinear_op("x1", "scale1")), - match::arg(1)(dequantizelinear_op("x2", "scale2"))); + match::arg(0)(match::skip_broadcasts_transposes_contiguous( + dequantizelinear_op("scale1", "zp1").bind("dq1"))), + match::arg(1)(match::skip_broadcasts_transposes_contiguous( + dequantizelinear_op("scale2", "zp2").bind("dq2")))); } void apply(module& m, const match::matcher_result& r) const { auto qop = r.result; - auto q1 = r.instructions["x1"]; - auto q2 = r.instructions["x2"]; + auto dq1 = r.instructions["dq1"]; + auto dq2 = r.instructions["dq2"]; auto scale1 = r.instructions["scale1"]; auto scale2 = r.instructions["scale2"]; + auto zp1 = r.instructions["zp1"]; + auto zp2 = r.instructions["zp2"]; // Only INT8 type currently supported - if(q1->get_shape().type() != migraphx::shape::int8_type or - q2->get_shape().type() != migraphx::shape::int8_type) + if(dq1->inputs().front()->get_shape().type() != migraphx::shape::int8_type or + dq2->inputs().front()->get_shape().type() != migraphx::shape::int8_type) return; - double scale; - visit_all(scale1->get_literal(), scale2->get_literal())( - [&](const auto s1, const auto s2) { scale = s1.front() * s2.front(); }); + // Only symmetric quantization supported (ie. non-zero zero_points not allowed) + if(not(is_valid_zero_point(zp1) and is_valid_zero_point(zp2))) + return; + // Only support scalar and 1D scales + if(scale1->get_shape().lens().size() != 1 or scale2->get_shape().lens().size() != 1) + return; + + // Propagate q1 and q2 through any broadcasts and transposes before qop auto qop_args = qop->inputs(); - qop_args.at(0) = q1; - qop_args.at(1) = q2; + qop_args.at(0) = propagate_quantized_ins(m, dq1, qop); + qop_args.at(1) = propagate_quantized_ins(m, dq2, qop); instruction_ref dq; - instruction_ref dq_scale; + instruction_ref out_scale; instruction_ref zero_point; if(qop->name() == "convolution") { auto conv_val = qop->get_operator().to_value(); dq = m.insert_instruction( qop, migraphx::make_op("quant_convolution", conv_val), qop_args); + auto out_lens = dq->get_shape().lens(); + + // Input scale should always be scalar and weight scale can be scalar or 1D of the + // same lens as the output channel dim (dim 1 in the output) + if(not(is_valid_scale(scale1, out_lens, 1) and is_valid_scale(scale2, out_lens, 1))) + return; + + auto s1_bcast = + m.insert_instruction(qop, scale_broadcast_op(scale1, out_lens, 1), scale1); + auto s2_bcast = + m.insert_instruction(qop, scale_broadcast_op(scale2, out_lens, 1), scale2); + + out_scale = m.insert_instruction(qop, migraphx::make_op("mul"), s1_bcast, s2_bcast); } else if(qop->name() == "dot") { - dq = m.insert_instruction(qop, migraphx::make_op("quant_dot"), qop_args); + dq = m.insert_instruction(qop, migraphx::make_op("quant_dot"), qop_args); + auto out_lens = dq->get_shape().lens(); + + // For (..., M, N) x (..., N, K) dot, only support cases where quantization axis is M + // for input1 and K for input 2 + if(not(is_valid_scale(scale1, out_lens, out_lens.size() - 2) and + is_valid_scale(scale2, out_lens, out_lens.size() - 1))) + return; + + auto s1_bcast = m.insert_instruction( + qop, scale_broadcast_op(scale1, out_lens, out_lens.size() - 2), scale1); + auto s2_bcast = m.insert_instruction( + qop, scale_broadcast_op(scale2, out_lens, out_lens.size() - 1), scale2); + + out_scale = m.insert_instruction(qop, migraphx::make_op("mul"), s1_bcast, s2_bcast); } - auto ins_type = qop->get_shape().type(); - dq_scale = m.add_literal(literal({ins_type}, {scale})); - auto lens = dq->get_shape().lens(); - auto scale_mb = - m.insert_instruction(qop, make_op("multibroadcast", {{"out_lens", lens}}), dq_scale); - dq = m.insert_instruction(qop, make_op("dequantizelinear"), dq, scale_mb); + dq = m.insert_instruction(qop, make_op("dequantizelinear"), dq, out_scale); m.replace_instruction(qop, dq); } }; diff --git a/test/quantization.cpp b/test/quantization.cpp index 64797d83aae..e148743a901 100644 --- a/test/quantization.cpp +++ b/test/quantization.cpp @@ -636,13 +636,12 @@ TEST_CASE(dot_float) migraphx::make_op("multibroadcast", {{"out_lens", sb.lens()}}), scale); auto zp_b = mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", sb.lens()}}), zp); - auto quant_b = mm->add_instruction(migraphx::make_op("quantizelinear"), pb, scale_b, zp_b); - auto quant = mm->add_instruction(migraphx::make_op("quant_dot"), quant_a, quant_b); - std::vector vec(sc.elements(), 100.0f); - auto dc = mm->add_literal(100.0f); - auto mdc = - mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", sc.lens()}}), dc); - auto r = mm->add_instruction(migraphx::make_op("dequantizelinear"), quant, mdc); + auto quant_b = mm->add_instruction(migraphx::make_op("quantizelinear"), pb, scale_b, zp_b); + auto quant = mm->add_instruction(migraphx::make_op("quant_dot"), quant_a, quant_b); + auto scale_mb = mm->add_instruction( + migraphx::make_op("multibroadcast", {{"out_lens", quant->get_shape().lens()}}), scale); + auto out_scale = mm->add_instruction(migraphx::make_op("mul"), scale_mb, scale_mb); + auto r = mm->add_instruction(migraphx::make_op("dequantizelinear"), quant, out_scale); mm->add_return({r}); return p; @@ -717,24 +716,28 @@ TEST_CASE(dot_double_2args) auto pa = mm->add_parameter("a", sa); auto pb = mm->add_parameter("b", sb); - auto scale_a = mm->add_literal(10.0); - auto zp = mm->add_literal(static_cast(0)); - scale_a = mm->add_instruction( - migraphx::make_op("multibroadcast", {{"out_lens", sa.lens()}}), scale_a); + auto scale_a_lit = mm->add_literal(10.0); + auto zp = mm->add_literal(static_cast(0)); + auto scale_a = mm->add_instruction( + migraphx::make_op("multibroadcast", {{"out_lens", sa.lens()}}), scale_a_lit); auto zp_a = mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", sa.lens()}}), zp); - auto qa = mm->add_instruction(migraphx::make_op("quantizelinear"), pa, scale_a, zp_a); - auto scale_b = mm->add_literal(5.0); - scale_b = mm->add_instruction( - migraphx::make_op("multibroadcast", {{"out_lens", sb.lens()}}), scale_b); + auto qa = mm->add_instruction(migraphx::make_op("quantizelinear"), pa, scale_a, zp_a); + auto scale_b_lit = mm->add_literal(5.0); + auto scale_b = mm->add_instruction( + migraphx::make_op("multibroadcast", {{"out_lens", sb.lens()}}), scale_b_lit); auto zp_b = mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", sb.lens()}}), zp); - auto qb = mm->add_instruction(migraphx::make_op("quantizelinear"), pb, scale_b, zp_b); - auto qdot = mm->add_instruction(migraphx::make_op("quant_dot"), qa, qb); - auto scale = mm->add_literal(50.0); - scale = mm->add_instruction( - migraphx::make_op("multibroadcast", {{"out_lens", qdot->get_shape().lens()}}), scale); - auto r = mm->add_instruction(migraphx::make_op("dequantizelinear"), qdot, scale); + auto qb = mm->add_instruction(migraphx::make_op("quantizelinear"), pb, scale_b, zp_b); + auto qdot = mm->add_instruction(migraphx::make_op("quant_dot"), qa, qb); + auto scale_a_mb = mm->add_instruction( + migraphx::make_op("multibroadcast", {{"out_lens", qdot->get_shape().lens()}}), + scale_a_lit); + auto scale_b_mb = mm->add_instruction( + migraphx::make_op("multibroadcast", {{"out_lens", qdot->get_shape().lens()}}), + scale_b_lit); + auto out_scale = mm->add_instruction(migraphx::make_op("mul"), scale_a_mb, scale_b_mb); + auto r = mm->add_instruction(migraphx::make_op("dequantizelinear"), qdot, out_scale); mm->add_return({r}); return p; }; @@ -798,19 +801,16 @@ TEST_CASE(dot_half_1arg) migraphx::shape sa{migraphx::shape::half_type, {9, 9}}; auto x = mm->add_parameter("x", sa); - auto zp = mm->add_literal(static_cast(0)); - auto scale = mm->add_literal(migraphx::literal({sa.type()}, {10.0})); - scale = mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", sa.lens()}}), - scale); + auto zp = mm->add_literal(static_cast(0)); + auto scale_lit = mm->add_literal(migraphx::literal({sa.type()}, {10.0})); + auto scale = mm->add_instruction( + migraphx::make_op("multibroadcast", {{"out_lens", sa.lens()}}), scale_lit); zp = mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", sa.lens()}}), zp); - auto qx = mm->add_instruction(migraphx::make_op("quantizelinear"), x, scale, zp); - auto qdot = mm->add_instruction(migraphx::make_op("quant_dot"), qx, qx); - auto dq_scale = mm->add_literal(migraphx::literal({sa.type()}, {100.0})); - dq_scale = mm->add_instruction( - migraphx::make_op("multibroadcast", {{"out_lens", qdot->get_shape().lens()}}), - dq_scale); - auto r = mm->add_instruction(migraphx::make_op("dequantizelinear"), qdot, dq_scale); + auto qx = mm->add_instruction(migraphx::make_op("quantizelinear"), x, scale, zp); + auto qdot = mm->add_instruction(migraphx::make_op("quant_dot"), qx, qx); + auto out_scale = mm->add_instruction(migraphx::make_op("mul"), scale, scale); + auto r = mm->add_instruction(migraphx::make_op("dequantizelinear"), qdot, out_scale); mm->add_return({r}); return p; }; @@ -851,10 +851,10 @@ TEST_CASE(conv_float) auto px = mm->add_parameter("x", sx); auto pw = mm->add_parameter("w", sw); - auto zp = mm->add_literal(static_cast(0)); - auto scale = mm->add_literal(10.0f); - scale = mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", sx.lens()}}), - scale); + auto zp = mm->add_literal(static_cast(0)); + auto scale_lit = mm->add_literal(10.0f); + auto scale = mm->add_instruction( + migraphx::make_op("multibroadcast", {{"out_lens", sx.lens()}}), scale_lit); zp = mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", sx.lens()}}), zp); auto quant_x = mm->add_instruction(migraphx::make_op("quantizelinear"), px, scale, zp); @@ -862,13 +862,11 @@ TEST_CASE(conv_float) auto quant = mm->add_instruction(migraphx::make_op("quant_convolution"), quant_x, quant_w); - migraphx::shape sc{migraphx::shape::float_type, {4, 4, 1, 1}}; - std::vector vec(sc.elements(), 100.0f); - migraphx::shape s_scale{migraphx::shape::float_type, sc.lens()}; - auto d_scale = mm->add_literal(100.0f); - d_scale = mm->add_instruction( - migraphx::make_op("multibroadcast", {{"out_lens", {4, 4, 1, 1}}}), d_scale); - auto r = mm->add_instruction(migraphx::make_op("dequantizelinear"), quant, d_scale); + auto scale_mb = mm->add_instruction( + migraphx::make_op("multibroadcast", {{"out_lens", quant->get_shape().lens()}}), + scale_lit); + auto out_scale = mm->add_instruction(migraphx::make_op("mul"), scale_mb, scale_mb); + auto r = mm->add_instruction(migraphx::make_op("dequantizelinear"), quant, out_scale); mm->add_return({r}); return p; @@ -930,20 +928,21 @@ TEST_CASE(conv_half) auto px = mm->add_parameter("x", sx); auto pw = mm->add_parameter("w", sw); - auto zp = mm->add_literal(static_cast(0)); - auto scale = mm->add_literal(migraphx::literal({sx.type()}, {10.0})); - scale = mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", sx.lens()}}), - scale); + auto zp = mm->add_literal(static_cast(0)); + auto scale_lit = mm->add_literal(migraphx::literal({sx.type()}, {10.0})); + auto scale = mm->add_instruction( + migraphx::make_op("multibroadcast", {{"out_lens", sx.lens()}}), scale_lit); zp = mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", sx.lens()}}), zp); auto quant_x = mm->add_instruction(migraphx::make_op("quantizelinear"), px, scale, zp); auto quant_w = mm->add_instruction(migraphx::make_op("quantizelinear"), pw, scale, zp); auto quant = mm->add_instruction(migraphx::make_op("quant_convolution"), quant_x, quant_w); - auto d_scale = mm->add_literal(migraphx::literal({sx.type()}, {100.0})); - d_scale = mm->add_instruction( - migraphx::make_op("multibroadcast", {{"out_lens", {4, 4, 1, 1}}}), d_scale); - auto r = mm->add_instruction(migraphx::make_op("dequantizelinear"), quant, d_scale); + auto scale_mb = mm->add_instruction( + migraphx::make_op("multibroadcast", {{"out_lens", quant->get_shape().lens()}}), + scale_lit); + auto out_scale = mm->add_instruction(migraphx::make_op("mul"), scale_mb, scale_mb); + auto r = mm->add_instruction(migraphx::make_op("dequantizelinear"), quant, out_scale); mm->add_return({r}); return p; @@ -1185,12 +1184,12 @@ TEST_CASE(int8_subgraph) migraphx::make_op("multibroadcast", {{"out_lens", sy.lens()}}), s1); auto zpb = then_mod->add_instruction( migraphx::make_op("multibroadcast", {{"out_lens", sy.lens()}}), zp1); - auto qb = then_mod->add_instruction(migraphx::make_op("quantizelinear"), b, sb, zpb); - auto qdot = then_mod->add_instruction(migraphx::make_op("quant_dot"), qa, qb); - auto so = then_mod->add_literal(100.0f); - so = then_mod->add_instruction( - migraphx::make_op("multibroadcast", {{"out_lens", sout.lens()}}), so); - auto r = then_mod->add_instruction(migraphx::make_op("dequantizelinear"), qdot, so); + auto qb = then_mod->add_instruction(migraphx::make_op("quantizelinear"), b, sb, zpb); + auto qdot = then_mod->add_instruction(migraphx::make_op("quant_dot"), qa, qb); + auto s1_mb = then_mod->add_instruction( + migraphx::make_op("multibroadcast", {{"out_lens", qdot->get_shape().lens()}}), s1); + auto so = then_mod->add_instruction(migraphx::make_op("mul"), s1_mb, s1_mb); + auto r = then_mod->add_instruction(migraphx::make_op("dequantizelinear"), qdot, so); then_mod->add_return({r}); migraphx::shape sd{migraphx::shape::float_type, {2, 2, 4, 6}}; @@ -1199,24 +1198,25 @@ TEST_CASE(int8_subgraph) auto w = mm->add_parameter("w", sw); // else submod auto* else_mod = p.create_module("If_6_else"); - auto sax = else_mod->add_literal(2.0f); + auto sax_lit = else_mod->add_literal(2.0f); auto zp = else_mod->add_literal(static_cast(0)); - sax = else_mod->add_instruction( - migraphx::make_op("multibroadcast", {{"out_lens", sd.lens()}}), sax); + auto sax = else_mod->add_instruction( + migraphx::make_op("multibroadcast", {{"out_lens", sd.lens()}}), sax_lit); auto zpx = else_mod->add_instruction( migraphx::make_op("multibroadcast", {{"out_lens", sd.lens()}}), zp); - auto qx = else_mod->add_instruction(migraphx::make_op("quantizelinear"), x, sax, zpx); - auto ssw = else_mod->add_literal(1.66667f); - ssw = else_mod->add_instruction( - migraphx::make_op("multibroadcast", {{"out_lens", sw.lens()}}), ssw); + auto qx = else_mod->add_instruction(migraphx::make_op("quantizelinear"), x, sax, zpx); + auto ssw_lit = else_mod->add_literal(1.66667f); + auto ssw = else_mod->add_instruction( + migraphx::make_op("multibroadcast", {{"out_lens", sw.lens()}}), ssw_lit); auto zpw = else_mod->add_instruction( migraphx::make_op("multibroadcast", {{"out_lens", sw.lens()}}), zp); - auto qw = else_mod->add_instruction(migraphx::make_op("quantizelinear"), w, ssw, zpw); - auto qconv = else_mod->add_instruction(migraphx::make_op("quant_convolution"), qx, qw); - auto so1 = else_mod->add_literal(3.33333f); - so1 = else_mod->add_instruction( - migraphx::make_op("multibroadcast", {{"out_lens", sout.lens()}}), so1); - auto r1 = else_mod->add_instruction(migraphx::make_op("dequantizelinear"), qconv, so1); + auto qw = else_mod->add_instruction(migraphx::make_op("quantizelinear"), w, ssw, zpw); + auto qconv = else_mod->add_instruction(migraphx::make_op("quant_convolution"), qx, qw); + auto ssw_mb = else_mod->add_instruction( + migraphx::make_op("multibroadcast", {{"out_lens", qconv->get_shape().lens()}}), + ssw_lit); + auto so1 = else_mod->add_instruction(migraphx::make_op("mul"), sax, ssw_mb); + auto r1 = else_mod->add_instruction(migraphx::make_op("dequantizelinear"), qconv, so1); else_mod->add_return({r1}); auto ret = mm->add_instruction(migraphx::make_op("if"), {cond}, {then_mod, else_mod}); diff --git a/test/simplify_qdq_test.cpp b/test/simplify_qdq_test.cpp index b9010e495f1..3cc4f77ff23 100644 --- a/test/simplify_qdq_test.cpp +++ b/test/simplify_qdq_test.cpp @@ -44,20 +44,34 @@ void run_pass(migraphx::module& m) sqdq.apply(m); } -migraphx::instruction_ref add_quantize_op(migraphx::module& m, - const std::string& name, - migraphx::instruction_ref x, +migraphx::instruction_ref broadcast_scale(migraphx::module& m, migraphx::instruction_ref scale, - migraphx::instruction_ref shift) + const std::vector& out_lens, + std::size_t axis) { - auto lens = x->get_shape().lens(); + if(scale->get_shape().lens() == out_lens) + return scale; + migraphx::instruction_ref scale_mb; - if(scale->get_shape().lens().front() == 1) + auto scale_lens = scale->get_shape().lens(); + if(scale_lens.front() == 1 and scale_lens.size() == 1) scale_mb = - m.add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", lens}}), scale); + m.add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", out_lens}}), scale); else scale_mb = m.add_instruction( - migraphx::make_op("broadcast", {{"axis", 1}, {"out_lens", lens}}), scale); + migraphx::make_op("broadcast", {{"axis", axis}, {"out_lens", out_lens}}), scale); + return scale_mb; +} + +migraphx::instruction_ref add_quantize_op(migraphx::module& m, + const std::string& name, + migraphx::instruction_ref x, + migraphx::instruction_ref scale, + migraphx::instruction_ref shift, + std::size_t q_axis = 1) +{ + auto lens = x->get_shape().lens(); + auto scale_mb = broadcast_scale(m, scale, lens, q_axis); auto shift_mb = m.add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", lens}}), shift); return m.add_instruction(migraphx::make_op(name), x, scale_mb, shift_mb); @@ -66,19 +80,26 @@ migraphx::instruction_ref add_quantize_op(migraphx::module& m, migraphx::instruction_ref add_quantize_op(migraphx::module& m, const std::string& name, migraphx::instruction_ref x, - migraphx::instruction_ref scale) + migraphx::instruction_ref scale, + std::size_t q_axis = 1) { - auto lens = x->get_shape().lens(); - migraphx::instruction_ref scale_mb; - if(scale->get_shape().lens().front() == 1) - scale_mb = - m.add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", lens}}), scale); - else - scale_mb = m.add_instruction( - migraphx::make_op("broadcast", {{"axis", 1}, {"out_lens", lens}}), scale); + auto lens = x->get_shape().lens(); + auto scale_mb = broadcast_scale(m, scale, lens, q_axis); return m.add_instruction(migraphx::make_op(name), x, scale_mb); } +migraphx::instruction_ref add_scale_mul(migraphx::module& m, + migraphx::instruction_ref scale1, + migraphx::instruction_ref scale2, + std::size_t axis1, + std::size_t axis2, + const std::vector& out_lens) +{ + auto scale1_mb = broadcast_scale(m, scale1, out_lens, axis1); + auto scale2_mb = broadcast_scale(m, scale2, out_lens, axis2); + return m.add_instruction(migraphx::make_op("mul"), scale1_mb, scale2_mb); +} + TEST_CASE(remove_qdq) { migraphx::shape sh1{migraphx::shape::float_type, {100, 100}}; @@ -159,18 +180,62 @@ TEST_CASE(dot) m1.add_return({dot}); } + migraphx::module m2; + { + auto t1 = m2.add_parameter("t1", sh1); + auto t2 = m2.add_parameter("t2", sh2); + auto scale = m2.add_literal(0.5f); + auto zero = m2.add_literal(std::int8_t{0}); + + auto q1 = add_quantize_op(m2, "quantizelinear", t1, scale, zero); + auto q2 = add_quantize_op(m2, "quantizelinear", t2, scale, zero); + + auto dot = m2.add_instruction(migraphx::make_op("quant_dot"), q1, q2); + auto out_scale = add_scale_mul(m2, scale, scale, 1, 1, dot->get_shape().lens()); + auto d3 = add_quantize_op(m2, "dequantizelinear", dot, out_scale); + m2.add_return({d3}); + } + + run_pass(m1); + EXPECT(m1 == m2); +} + +TEST_CASE(dot_multi_scale) +{ + migraphx::shape sh1{migraphx::shape::float_type, {1280, 1000}}; + migraphx::shape sh2{migraphx::shape::float_type, {1000, 1024}}; + migraphx::shape sh3{migraphx::shape::float_type, {1280}}; + + migraphx::module m1; + { + auto t1 = m1.add_parameter("t1", sh1); + auto t2 = m1.add_parameter("t2", sh2); + auto scale1 = m1.add_literal(migraphx::generate_literal(sh3, 0)); + auto scale2 = m1.add_literal(0.4f); + auto zero = m1.add_literal(std::int8_t{0}); + + auto q1 = add_quantize_op(m1, "quantizelinear", t1, scale1, zero, 0); + auto d1 = add_quantize_op(m1, "dequantizelinear", q1, scale1, zero, 0); + auto q2 = add_quantize_op(m1, "quantizelinear", t2, scale2, zero, 1); + auto d2 = add_quantize_op(m1, "dequantizelinear", q2, scale2, zero, 1); + auto dot = m1.add_instruction(migraphx::make_op("dot"), d1, d2); + m1.add_return({dot}); + } + migraphx::module m2; { auto t1 = m2.add_parameter("t1", sh1); auto t2 = m2.add_parameter("t2", sh2); - auto scale = m2.add_literal(0.5f); + auto scale1 = m2.add_literal(migraphx::generate_literal(sh3, 0)); + auto scale2 = m2.add_literal(0.4f); auto zero = m2.add_literal(std::int8_t{0}); - auto scale1 = m2.add_literal(0.25f); - auto q1 = add_quantize_op(m2, "quantizelinear", t1, scale, zero); - auto q2 = add_quantize_op(m2, "quantizelinear", t2, scale, zero); - auto dot = m2.add_instruction(migraphx::make_op("quant_dot"), q1, q2); - auto d3 = add_quantize_op(m2, "dequantizelinear", dot, scale1); + auto q1 = add_quantize_op(m2, "quantizelinear", t1, scale1, zero, 0); + auto q2 = add_quantize_op(m2, "quantizelinear", t2, scale2, zero, 1); + + auto dot = m2.add_instruction(migraphx::make_op("quant_dot"), q1, q2); + auto out_scale = add_scale_mul(m2, scale1, scale2, 0, 1, dot->get_shape().lens()); + auto d3 = add_quantize_op(m2, "dequantizelinear", dot, out_scale); m2.add_return({d3}); } @@ -178,6 +243,180 @@ TEST_CASE(dot) EXPECT(m1 == m2); } +TEST_CASE(dot_broadcasted) +{ + migraphx::shape sh1{migraphx::shape::float_type, {2, 1280, 1000}}; + migraphx::shape sh2{migraphx::shape::float_type, {1000, 1024}}; + + migraphx::module m1; + { + auto t1 = m1.add_parameter("t1", sh1); + auto t2 = m1.add_parameter("t2", sh2); + auto scale = m1.add_literal(0.5f); + auto zero = m1.add_literal(std::int8_t{0}); + + auto q1 = add_quantize_op(m1, "quantizelinear", t1, scale, zero); + auto d1 = add_quantize_op(m1, "dequantizelinear", q1, scale, zero); + auto q2 = add_quantize_op(m1, "quantizelinear", t2, scale, zero); + auto d2 = add_quantize_op(m1, "dequantizelinear", q2, scale, zero); + auto d2_mb = m1.add_instruction( + migraphx::make_op("multibroadcast", {{"out_lens", {2, 1000, 1024}}}), d2); + auto dot = m1.add_instruction(migraphx::make_op("dot"), d1, d2_mb); + m1.add_return({dot}); + } + + migraphx::module m2; + { + auto t1 = m2.add_parameter("t1", sh1); + auto t2 = m2.add_parameter("t2", sh2); + auto scale = m2.add_literal(0.5f); + auto zero = m2.add_literal(std::int8_t{0}); + + auto q1 = add_quantize_op(m2, "quantizelinear", t1, scale, zero); + auto q2 = add_quantize_op(m2, "quantizelinear", t2, scale, zero); + auto q2_mb = m2.add_instruction( + migraphx::make_op("multibroadcast", {{"out_lens", {2, 1000, 1024}}}), q2); + + auto dot = m2.add_instruction(migraphx::make_op("quant_dot"), q1, q2_mb); + auto out_scale = add_scale_mul(m2, scale, scale, 1, 1, dot->get_shape().lens()); + auto d3 = add_quantize_op(m2, "dequantizelinear", dot, out_scale); + m2.add_return({d3}); + } + + run_pass(m1); + EXPECT(m1 == m2); +} + +TEST_CASE(dot_transposed) +{ + migraphx::shape sh1{migraphx::shape::float_type, {1280, 1000}}; + migraphx::shape sh2{migraphx::shape::float_type, {1024, 1000}}; + + migraphx::module m1; + { + auto t1 = m1.add_parameter("t1", sh1); + auto t2 = m1.add_parameter("t2", sh2); + auto scale = m1.add_literal(0.5f); + auto zero = m1.add_literal(std::int8_t{0}); + + auto q1 = add_quantize_op(m1, "quantizelinear", t1, scale, zero); + auto d1 = add_quantize_op(m1, "dequantizelinear", q1, scale, zero); + auto q2 = add_quantize_op(m1, "quantizelinear", t2, scale, zero); + auto d2 = add_quantize_op(m1, "dequantizelinear", q2, scale, zero); + auto d2_t = + m1.add_instruction(migraphx::make_op("transpose", {{"permutation", {1, 0}}}), d2); + auto dot = m1.add_instruction(migraphx::make_op("dot"), d1, d2_t); + m1.add_return({dot}); + } + + migraphx::module m2; + { + auto t1 = m2.add_parameter("t1", sh1); + auto t2 = m2.add_parameter("t2", sh2); + auto scale = m2.add_literal(0.5f); + auto zero = m2.add_literal(std::int8_t{0}); + + auto q1 = add_quantize_op(m2, "quantizelinear", t1, scale, zero); + auto q2 = add_quantize_op(m2, "quantizelinear", t2, scale, zero); + auto q2_t = + m2.add_instruction(migraphx::make_op("transpose", {{"permutation", {1, 0}}}), q2); + + auto dot = m2.add_instruction(migraphx::make_op("quant_dot"), q1, q2_t); + auto out_scale = add_scale_mul(m2, scale, scale, 1, 1, dot->get_shape().lens()); + auto d3 = add_quantize_op(m2, "dequantizelinear", dot, out_scale); + m2.add_return({d3}); + } + + run_pass(m1); + EXPECT(m1 == m2); +} + +TEST_CASE(dot_multi_scale_transposed_broadcasted) +{ + migraphx::shape sh1{migraphx::shape::float_type, {2, 3, 1280, 1000}}; + migraphx::shape sh2{migraphx::shape::float_type, {1024, 1000}}; + migraphx::shape sh3{migraphx::shape::float_type, {1280}}; + migraphx::shape sh4{migraphx::shape::float_type, {1024}}; + + migraphx::module m1; + { + auto t1 = m1.add_parameter("t1", sh1); + auto t2 = m1.add_parameter("t2", sh2); + auto scale1 = m1.add_literal(migraphx::generate_literal(sh3, 0)); + auto scale2 = m1.add_literal(migraphx::generate_literal(sh4, 0)); + auto zero = m1.add_literal(std::int8_t{0}); + + auto q1 = add_quantize_op(m1, "quantizelinear", t1, scale1, zero, 2); + auto d1 = add_quantize_op(m1, "dequantizelinear", q1, scale1, zero, 2); + auto q2 = add_quantize_op(m1, "quantizelinear", t2, scale2, zero, 0); + auto d2 = add_quantize_op(m1, "dequantizelinear", q2, scale2, zero, 0); + auto d2_t = + m1.add_instruction(migraphx::make_op("transpose", {{"permutation", {1, 0}}}), d2); + auto d2_mb = m1.add_instruction( + migraphx::make_op("multibroadcast", {{"out_lens", {2, 3, 1000, 1024}}}), d2_t); + auto dot = m1.add_instruction(migraphx::make_op("dot"), d1, d2_mb); + m1.add_return({dot}); + } + + migraphx::module m2; + { + auto t1 = m2.add_parameter("t1", sh1); + auto t2 = m2.add_parameter("t2", sh2); + auto scale1 = m2.add_literal(migraphx::generate_literal(sh3, 0)); + auto scale2 = m2.add_literal(migraphx::generate_literal(sh4, 0)); + auto zero = m2.add_literal(std::int8_t{0}); + + auto q1 = add_quantize_op(m2, "quantizelinear", t1, scale1, zero, 2); + auto q2 = add_quantize_op(m2, "quantizelinear", t2, scale2, zero, 0); + auto q2_t = + m2.add_instruction(migraphx::make_op("transpose", {{"permutation", {1, 0}}}), q2); + auto q2_mb = m2.add_instruction( + migraphx::make_op("multibroadcast", {{"out_lens", {2, 3, 1000, 1024}}}), q2_t); + + auto dot = m2.add_instruction(migraphx::make_op("quant_dot"), q1, q2_mb); + auto out_scale = add_scale_mul(m2, scale1, scale2, 2, 3, dot->get_shape().lens()); + auto d3 = add_quantize_op(m2, "dequantizelinear", dot, out_scale); + m2.add_return({d3}); + } + + run_pass(m1); + EXPECT(m1 == m2); +} + +TEST_CASE(dot_multi_scale_unsupported_axis) +{ + migraphx::shape sh1{migraphx::shape::float_type, {1280, 1000}}; + migraphx::shape sh2{migraphx::shape::float_type, {1000, 1024}}; + migraphx::shape sh3{migraphx::shape::float_type, {1000}}; + + migraphx::module m1; + { + auto t1 = m1.add_parameter("t1", sh1); + auto t2 = m1.add_parameter("t2", sh2); + auto scale1 = m1.add_literal(migraphx::generate_literal(sh3, 0)); + auto scale2 = m1.add_literal(0.4f); + auto zero = m1.add_literal(std::int8_t{0}); + + auto q1 = add_quantize_op(m1, "quantizelinear", t1, scale1, zero, 1); + auto d1 = add_quantize_op(m1, "dequantizelinear", q1, scale1, zero, 1); + auto q2 = add_quantize_op(m1, "quantizelinear", t2, scale2, zero, 1); + auto d2 = add_quantize_op(m1, "dequantizelinear", q2, scale2, zero, 1); + auto dot = m1.add_instruction(migraphx::make_op("dot"), d1, d2); + m1.add_return({dot}); + } + + migraphx::module m2; + { + auto t1 = m2.add_parameter("t1", sh1); + auto t2 = m2.add_parameter("t2", sh2); + auto dot = m2.add_instruction(migraphx::make_op("dot"), t1, t2); + m2.add_return({dot}); + } + + run_pass(m1); + EXPECT(m1 == m2); +} + TEST_CASE(dot_non_zero_point) { migraphx::shape sh1{migraphx::shape::float_type, {1280, 1000}}; @@ -269,18 +508,18 @@ TEST_CASE(dot_add) migraphx::module m2; { - auto t1 = m2.add_parameter("t1", sh1); - auto t2 = m2.add_parameter("t2", sh2); - auto ab = m2.add_parameter("ab", sh3); - auto scale = m2.add_literal(0.5f); - auto zero = m2.add_literal(std::int8_t{0}); - auto scale1 = m2.add_literal(0.25f); - - auto q1 = add_quantize_op(m2, "quantizelinear", t1, scale, zero); - auto q2 = add_quantize_op(m2, "quantizelinear", t2, scale, zero); - auto dot = m2.add_instruction(migraphx::make_op("quant_dot"), q1, q2); - auto d3 = add_quantize_op(m2, "dequantizelinear", dot, scale1); - auto add = m2.add_instruction(migraphx::make_op("add"), d3, ab); + auto t1 = m2.add_parameter("t1", sh1); + auto t2 = m2.add_parameter("t2", sh2); + auto ab = m2.add_parameter("ab", sh3); + auto scale = m2.add_literal(0.5f); + auto zero = m2.add_literal(std::int8_t{0}); + + auto q1 = add_quantize_op(m2, "quantizelinear", t1, scale, zero); + auto q2 = add_quantize_op(m2, "quantizelinear", t2, scale, zero); + auto dot = m2.add_instruction(migraphx::make_op("quant_dot"), q1, q2); + auto out_scale = add_scale_mul(m2, scale, scale, 1, 1, dot->get_shape().lens()); + auto d3 = add_quantize_op(m2, "dequantizelinear", dot, out_scale); + auto add = m2.add_instruction(migraphx::make_op("add"), d3, ab); m2.add_return({add}); } @@ -320,26 +559,80 @@ TEST_CASE(conv) auto weights = m2.add_parameter("weights", s4); auto scale = m2.add_literal(0.5f); auto zero = m2.add_literal(std::int8_t{0}); - auto scale1 = m2.add_literal(0.25f); - auto q1 = add_quantize_op(m2, "quantizelinear", input, scale, zero); - auto c1 = m2.add_instruction(migraphx::make_op("quant_convolution", + auto q1 = add_quantize_op(m2, "quantizelinear", input, scale, zero); + auto c1 = m2.add_instruction(migraphx::make_op("quant_convolution", + {{"padding", {0, 0, 0, 0}}, + {"stride", {1, 1}}, + {"dilation", {1, 1}}, + {"group", 1}, + {"padding_mode", 0}}), + q1, + weights); + auto out_scale = add_scale_mul(m2, scale, scale, 1, 1, c1->get_shape().lens()); + auto d6 = add_quantize_op(m2, "dequantizelinear", c1, out_scale); + m2.add_return({d6}); + } + + run_pass(m1); + EXPECT(m1 == m2); +} + +TEST_CASE(conv_multi_scale) +{ + migraphx::shape s4{migraphx::shape::int8_type, {1280, 320, 1, 1}}; + migraphx::shape s7{migraphx::shape::float_type, {1, 320, 7, 7}}; + migraphx::shape s8{migraphx::shape::float_type, {1280}}; + + migraphx::module m1; + { + auto input = m1.add_parameter("input", s7); + auto weights = m1.add_parameter("weights", s4); + auto w_scale = m1.add_literal(migraphx::generate_literal(s8, 0)); + auto inp_scale = m1.add_literal(0.5f); + auto zero = m1.add_literal(std::int8_t{0}); + + auto d1 = add_quantize_op(m1, "dequantizelinear", weights, w_scale, zero, 0); + auto q1 = add_quantize_op(m1, "quantizelinear", input, inp_scale, zero); + auto d5 = add_quantize_op(m1, "dequantizelinear", q1, inp_scale, zero); + auto c1 = m1.add_instruction(migraphx::make_op("convolution", {{"padding", {0, 0, 0, 0}}, {"stride", {1, 1}}, {"dilation", {1, 1}}, {"group", 1}, {"padding_mode", 0}}), - q1, + d5, + d1); + m1.add_return({c1}); + } + + migraphx::module m2; + { + auto input = m2.add_parameter("input", s7); + auto weights = m2.add_parameter("weights", s4); + auto w_scale = m2.add_literal(migraphx::generate_literal(s8, 0)); + auto inp_scale = m2.add_literal(0.5f); + auto zero = m2.add_literal(std::int8_t{0}); + + auto q_inp = add_quantize_op(m2, "quantizelinear", input, inp_scale, zero); + auto c1 = m2.add_instruction(migraphx::make_op("quant_convolution", + {{"padding", {0, 0, 0, 0}}, + {"stride", {1, 1}}, + {"dilation", {1, 1}}, + {"group", 1}, + {"padding_mode", 0}}), + q_inp, weights); - auto d6 = add_quantize_op(m2, "dequantizelinear", c1, scale1); - m2.add_return({d6}); + auto out_scale = add_scale_mul(m2, inp_scale, w_scale, 1, 1, c1->get_shape().lens()); + auto d1 = add_quantize_op(m2, "dequantizelinear", c1, out_scale); + m2.add_return({d1}); } run_pass(m1); EXPECT(m1 == m2); } -TEST_CASE(conv_multi_scale) +TEST_CASE(conv_multi_scale_unsupported_axis) { migraphx::shape s4{migraphx::shape::int8_type, {1280, 320, 1, 1}}; migraphx::shape s7{migraphx::shape::float_type, {1, 320, 7, 7}}; @@ -430,20 +723,20 @@ TEST_CASE(conv_bias_add) auto scale = m2.add_literal(0.5f); auto zero = m2.add_literal(std::int8_t{0}); auto zero32 = m2.add_literal(std::int32_t{0}); - auto scale1 = m2.add_literal(0.25f); - auto d2 = add_quantize_op(m2, "dequantizelinear", bias, scale, zero32); - auto q1 = add_quantize_op(m2, "quantizelinear", input, scale, zero); - auto c1 = m2.add_instruction(migraphx::make_op("quant_convolution", - {{"padding", {0, 0, 0, 0}}, - {"stride", {1, 1}}, - {"dilation", {1, 1}}, - {"group", 1}, - {"padding_mode", 0}}), + auto d2 = add_quantize_op(m2, "dequantizelinear", bias, scale, zero32); + auto q1 = add_quantize_op(m2, "quantizelinear", input, scale, zero); + auto c1 = m2.add_instruction(migraphx::make_op("quant_convolution", + {{"padding", {0, 0, 0, 0}}, + {"stride", {1, 1}}, + {"dilation", {1, 1}}, + {"group", 1}, + {"padding_mode", 0}}), q1, weights); - auto d6 = add_quantize_op(m2, "dequantizelinear", c1, scale1); - auto b1 = m2.add_instruction( + auto out_scale = add_scale_mul(m2, scale, scale, 1, 1, c1->get_shape().lens()); + auto d6 = add_quantize_op(m2, "dequantizelinear", c1, out_scale); + auto b1 = m2.add_instruction( migraphx::make_op("broadcast", {{"axis", 1}, {"out_lens", {1, 1280, 7, 7}}}), d2); auto a1 = m2.add_instruction(migraphx::make_op("add"), d6, b1); m2.add_return({a1}); @@ -519,22 +812,21 @@ TEST_CASE(conv_pooling_dot) auto scale = m2.add_literal(0.5f); auto zero = m2.add_literal(std::int8_t{0}); auto zero32 = m2.add_literal(std::int32_t{0}); - auto scale1 = m2.add_literal(0.25f); - auto scale2 = m2.add_literal(0.25f); - auto d2 = add_quantize_op(m2, "dequantizelinear", bias, scale, zero32); - auto d3 = add_quantize_op(m2, "dequantizelinear", ab, scale, zero); - auto q1 = add_quantize_op(m2, "quantizelinear", input, scale, zero); - auto c1 = m2.add_instruction(migraphx::make_op("quant_convolution", - {{"padding", {0, 0, 0, 0}}, - {"stride", {1, 1}}, - {"dilation", {1, 1}}, - {"group", 1}, - {"padding_mode", 0}}), + auto d2 = add_quantize_op(m2, "dequantizelinear", bias, scale, zero32); + auto d3 = add_quantize_op(m2, "dequantizelinear", ab, scale, zero); + auto q1 = add_quantize_op(m2, "quantizelinear", input, scale, zero); + auto c1 = m2.add_instruction(migraphx::make_op("quant_convolution", + {{"padding", {0, 0, 0, 0}}, + {"stride", {1, 1}}, + {"dilation", {1, 1}}, + {"group", 1}, + {"padding_mode", 0}}), q1, weights); - auto d5 = add_quantize_op(m2, "dequantizelinear", c1, scale1); - auto bc1 = m2.add_instruction( + auto out_scale1 = add_scale_mul(m2, scale, scale, 1, 1, c1->get_shape().lens()); + auto d5 = add_quantize_op(m2, "dequantizelinear", c1, out_scale1); + auto bc1 = m2.add_instruction( migraphx::make_op("broadcast", {{"axis", 1}, {"out_lens", {1, 1280, 7, 7}}}), d2); auto a1 = m2.add_instruction(migraphx::make_op("add"), d5, bc1); auto ap = @@ -545,10 +837,11 @@ TEST_CASE(conv_pooling_dot) {"lengths", {7, 7}}, {"ceil_mode", 0}}), a1); - auto fl = m2.add_instruction(migraphx::make_op("flatten", {{"axis", 1}}), ap); - auto q4 = add_quantize_op(m2, "quantizelinear", fl, scale, zero); - auto dot = m2.add_instruction(migraphx::make_op("quant_dot"), q4, db); - auto d9 = add_quantize_op(m2, "dequantizelinear", dot, scale2); + auto fl = m2.add_instruction(migraphx::make_op("flatten", {{"axis", 1}}), ap); + auto q4 = add_quantize_op(m2, "quantizelinear", fl, scale, zero); + auto dot = m2.add_instruction(migraphx::make_op("quant_dot"), q4, db); + auto out_scale2 = add_scale_mul(m2, scale, scale, 1, 0, dot->get_shape().lens()); + auto d9 = add_quantize_op(m2, "dequantizelinear", dot, out_scale2); auto mb1 = m2.add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {1, 1000}}}), d3); auto a2 = m2.add_instruction(migraphx::make_op("add"), d9, mb1);