From e174012588d3857876311dc6f24730b147707e5b Mon Sep 17 00:00:00 2001 From: aarushjain29 Date: Wed, 18 Sep 2024 16:06:35 +0000 Subject: [PATCH] WIP for scalar_mul_add --- src/simplify_algebra.cpp | 36 +++++++++++++++++++++++++++++++++++- 1 file changed, 35 insertions(+), 1 deletion(-) diff --git a/src/simplify_algebra.cpp b/src/simplify_algebra.cpp index 7cae251c8b1..f9bd377d6a0 100644 --- a/src/simplify_algebra.cpp +++ b/src/simplify_algebra.cpp @@ -440,11 +440,44 @@ struct find_mul_add auto x_ins = r.instructions["x"]; assert(x_ins != b_ins); + if(a_ins->get_shape().scalar()) + { + return; + } + auto ax_ins = m.insert_instruction(ins, make_op("mul"), a_ins, x_ins); auto ab_ins = m.insert_instruction(ins, make_op("mul"), a_ins, b_ins); m.replace_instruction(ins, make_op("add"), ax_ins, ab_ins); } }; +struct find_scalar_mul_add +{ + auto matcher() const + { + return match::name("mul")(match::either_arg(0, 1)( + match::name("add")( + match::either_arg(0, 1)( + match::any().bind("x"), + match::any_of(conv_const_weights(), match::is_constant()).bind("b")), + match::none_of(match::args(match::is_constant(), match::is_constant())), + match::used_once()), + + match::is_scalar().bind("a"))); + } + + void apply(module& m, const match::matcher_result& r) const + { + auto ins = r.result; + auto a_ins = r.instructions["a"]; + auto b_ins = r.instructions["b"]; + auto x_ins = r.instructions["x"]; + assert(x_ins != b_ins); + + auto xb_ins = m.insert_instruction(ins, make_op("add"), x_ins, b_ins); + auto result_ins = m.insert_instruction(ins, make_op("mul"), a_ins, xb_ins); + m.replace_instruction(ins, result_ins); + } +}; struct find_dot_add { @@ -1595,7 +1628,8 @@ struct find_div_const } }; -struct find_unit_ops +struct +unit_ops { auto matcher() const {