Skip to content

Commit

Permalink
WIP for scalar_mul_add
Browse files Browse the repository at this point in the history
  • Loading branch information
aarushjain29 committed Sep 18, 2024
1 parent 7c2fdf5 commit e174012
Showing 1 changed file with 35 additions and 1 deletion.
36 changes: 35 additions & 1 deletion src/simplify_algebra.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -440,11 +440,44 @@ struct find_mul_add
auto x_ins = r.instructions["x"];
assert(x_ins != b_ins);

if(a_ins->get_shape().scalar())
{
return;
}

auto ax_ins = m.insert_instruction(ins, make_op("mul"), a_ins, x_ins);
auto ab_ins = m.insert_instruction(ins, make_op("mul"), a_ins, b_ins);
m.replace_instruction(ins, make_op("add"), ax_ins, ab_ins);
}
};
struct find_scalar_mul_add
{
auto matcher() const
{
return match::name("mul")(match::either_arg(0, 1)(
match::name("add")(
match::either_arg(0, 1)(
match::any().bind("x"),
match::any_of(conv_const_weights(), match::is_constant()).bind("b")),
match::none_of(match::args(match::is_constant(), match::is_constant())),
match::used_once()),

match::is_scalar().bind("a")));

Check warning on line 465 in src/simplify_algebra.cpp

View workflow job for this annotation

GitHub Actions / tidy

no member named 'is_scalar' in namespace 'migraphx::match' [clang-diagnostic-error]
}

void apply(module& m, const match::matcher_result& r) const
{
auto ins = r.result;
auto a_ins = r.instructions["a"];
auto b_ins = r.instructions["b"];
auto x_ins = r.instructions["x"];
assert(x_ins != b_ins);

auto xb_ins = m.insert_instruction(ins, make_op("add"), x_ins, b_ins);
auto result_ins = m.insert_instruction(ins, make_op("mul"), a_ins, xb_ins);
m.replace_instruction(ins, result_ins);
}
};

struct find_dot_add
{
Expand Down Expand Up @@ -1595,7 +1628,8 @@ struct find_div_const
}
};

struct find_unit_ops
struct
unit_ops
{
auto matcher() const
{
Expand Down

0 comments on commit e174012

Please sign in to comment.