Skip to content

Commit

Permalink
add experimental matcher
Browse files Browse the repository at this point in the history
  • Loading branch information
shivadbhavsar committed Jan 12, 2025
1 parent 6d02806 commit c3b6f22
Showing 1 changed file with 46 additions and 0 deletions.
46 changes: 46 additions & 0 deletions src/targets/gpu/fuse_mlir.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -911,6 +911,46 @@ struct find_mlir_attention_fused_ops : public find_mlir_standalone_attention_op
}
};

struct find_mlir_reshape_pointwise
{
mlir_mode conv_mode = mlir_mode::none;
mlir_mode dot_mode = mlir_mode::none;
auto matcher() const
{
auto reshapes = reshaper_names();
// Dont want to move ops before slice
reshapes.erase("slice");
auto rsp_dot_or_conv =
match::name(reshapes)(
match::any_of[match::inputs()](is_mlir_dot(dot_mode), is_mlir_conv(conv_mode))
.bind("mlir_op"))
.bind("rsp");
return mlir_pointwise()(match::any_of[match::inputs()](rsp_dot_or_conv));
}

void apply(module_pass_manager& mpm, const match::matcher_result& r) const
{
auto ins = r.result;
auto rsp = r.instructions["rsp"];

if(not std::all_of(ins->inputs().begin(), ins->inputs().end(), [&](auto i) {
return i->get_operator() == rsp->get_operator();
}))
return;

std::vector<instruction_ref> new_inps;
std::transform(ins->inputs().begin(),
ins->inputs().end(),
std::back_inserter(new_inps),
[&](auto i) { return i->inputs().front(); });

auto new_pw = mpm.get_module().insert_instruction(
ins, ins->get_operator(), new_inps, ins->module_inputs());

mpm.get_module().replace_instruction(ins, rsp->get_operator(), {new_pw});
}
};

struct find_pointwise_mlir
{
auto supported_pointwise() const { return mlir_input_pointwise(match::used_once()); }
Expand Down Expand Up @@ -1061,6 +1101,12 @@ void fuse_mlir::apply(module_pass_manager& mpm) const
return std::max(m1, m2);
};

match::find_matches(
mpm,
find_mlir_reshape_pointwise{.conv_mode = get_mode("fused_convolution", mlir_mode::fast),
.dot_mode = get_mode("fused_dot", mlir_mode::fast)});
mpm.run_pass(dead_code_elimination{});

// Attention offloads; default disabled
if(mlir_attention_enabled(ctx) or enable_extra)
{
Expand Down

0 comments on commit c3b6f22

Please sign in to comment.