Skip to content

Commit

Permalink
Disable dot/mul optimizations when there are int4 weights (#3645)
Browse files Browse the repository at this point in the history
  • Loading branch information
pfultz2 authored Nov 22, 2024
1 parent 0ad8073 commit d688810
Show file tree
Hide file tree
Showing 2 changed files with 86 additions and 3 deletions.
23 changes: 20 additions & 3 deletions src/simplify_algebra.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,22 @@ auto conv_const_weights()
match::args(match::none_of(match::is_constant()), match::is_constant().bind("w")));
}

auto from_int4()
{
return match::make_predicate_matcher([](instruction_ref start) {
return fix<bool>([&](auto self, instruction_ref ins) {
auto alias = instruction::get_output_alias(ins);
if(contains({"reshape", "dequantizelinear"}, alias->name()))
return self(alias->inputs().front());
if(alias->name() == "concat")
return all_of(alias->inputs(), self);
return alias->name() == "unpack_int4";
})(start);
});
}

// Negation of from_int4(): matches only instructions for which the
// from_int4() predicate does not hold.
auto not_from_int4()
{
    auto int4_source = from_int4();
    return match::none_of(int4_source);
}

// Matches any instruction whose operator name contains "reduce".
auto reduction()
{
    auto is_reduce = match::name_contains("reduce");
    return is_reduce;
}

// conv(x, w) * a => conv(x, a * w)
Expand Down Expand Up @@ -208,8 +224,8 @@ struct find_mul_dot
{
// Matches mul(dot, c) in either argument order, where:
//   - "dot" is a dot instruction with at least one constant input that does
//     not originate from unpack_int4 (see not_from_int4()), and
//   - "c" is a broadcast or multibroadcast instruction.
auto matcher() const
{
    // A constant operand is only eligible when it is not derived from int4 weights.
    auto foldable_const = match::is_constant(not_from_int4());
    auto dot_with_const = match::name("dot")(match::any_of[match::inputs()](foldable_const));
    auto broadcasted    = match::name("broadcast", "multibroadcast");
    return match::name("mul")(
        match::either_arg(0, 1)(dot_with_const.bind("dot"), broadcasted.bind("c")));
}
Expand Down Expand Up @@ -358,7 +374,8 @@ struct find_dot_mul
match::used_once(),
match::either_arg(0, 1)(const_broadcast.bind("d"),
match::none_of(match::is_constant()).bind("z")));
return match::name("dot")(match::either_arg(0, 1)(mul, match::is_constant().bind("c")));
return match::name("dot")(
match::either_arg(0, 1)(mul, match::is_constant(not_from_int4()).bind("c")));
}

void apply(module& m, const match::matcher_result& r) const
Expand Down
66 changes: 66 additions & 0 deletions test/simplify_algebra_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4024,6 +4024,72 @@ TEST_CASE(mul_dot_b_not_k_broadcast)
EXPECT(m1.sort() == m2.sort());
}

// The pass must leave the module untouched when the second dot operand is
// built from packed weights via unpack_int4 -> dequantizelinear -> unsqueeze
// -> transpose, even though the first operand is a mul by a broadcast literal.
TEST_CASE(mul_dot_a_int4_dq)
{
    const migraphx::shape input_shape{migraphx::shape::float_type, {1, 32, 4096}};
    const migraphx::shape packed_shape{migraphx::shape::int8_type, {22016, 2048}};
    const migraphx::shape scale_shape{migraphx::shape::float_type, {22016, 4096}};
    migraphx::module m1;
    {
        const auto broadcast_op =
            migraphx::make_op("broadcast", {{"axis", 2}, {"out_lens", input_shape.lens()}});
        const auto unsqueeze_op = migraphx::make_op("unsqueeze", {{"axes", {0}}});
        const auto transpose_op = migraphx::make_op("transpose", {{"permutation", {0, 2, 1}}});

        auto input = m1.add_parameter("input", input_shape);

        // Left operand: input scaled by a literal broadcast along axis 2.
        auto factor =
            m1.add_literal(migraphx::generate_literal({migraphx::shape::float_type, {4096}}));
        auto factor_bcast = m1.add_instruction(broadcast_op, factor);
        auto scaled_input = m1.add_instruction(migraphx::make_op("mul"), input, factor_bcast);

        // Right operand: constant weights that originate from unpack_int4.
        auto packed      = m1.add_literal(migraphx::generate_literal(packed_shape));
        auto unpacked    = m1.add_instruction(migraphx::make_op("unpack_int4"), packed);
        auto scales      = m1.add_literal(migraphx::generate_literal(scale_shape));
        auto dequantized = m1.add_instruction(migraphx::make_op("dequantizelinear"), unpacked, scales);
        auto batched     = m1.add_instruction(unsqueeze_op, dequantized);
        auto weights     = m1.add_instruction(transpose_op, batched);

        auto product = m1.add_instruction(migraphx::make_op("dot"), scaled_input, weights);
        m1.add_return({product});
    }
    // Snapshot taken before the pass runs; the pass is expected to be a no-op.
    auto m2 = m1;
    run_pass(m1);

    EXPECT(m1.sort() == m2.sort());
}

// The pass must leave the module untouched when the second dot operand is a
// concat whose inputs each come from unpack_int4 -> dequantizelinear ->
// unsqueeze, while the first operand is a mul by a broadcast literal.
TEST_CASE(mul_dot_a_int4_dq_concat)
{
    const migraphx::shape input_shape{migraphx::shape::float_type, {1, 32, 4096}};
    const migraphx::shape packed_shape{migraphx::shape::int8_type, {4096, 5504}};
    const migraphx::shape scale_shape{migraphx::shape::float_type, {4096, 11008}};
    migraphx::module m1;
    {
        const auto broadcast_op =
            migraphx::make_op("broadcast", {{"axis", 2}, {"out_lens", input_shape.lens()}});
        const auto unsqueeze_op = migraphx::make_op("unsqueeze", {{"axes", {0}}});

        auto input = m1.add_parameter("input", input_shape);

        // Left operand: input scaled by a literal broadcast along axis 2.
        auto factor =
            m1.add_literal(migraphx::generate_literal({migraphx::shape::float_type, {4096}}));
        auto factor_bcast = m1.add_instruction(broadcast_op, factor);
        auto scaled_input = m1.add_instruction(migraphx::make_op("mul"), input, factor_bcast);

        // Right operand: two independently dequantized int4 weight blocks,
        // joined along axis 2.
        std::vector<migraphx::instruction_ref> parts;
        for(int block = 0; block < 2; ++block)
        {
            auto packed   = m1.add_literal(migraphx::generate_literal(packed_shape));
            auto unpacked = m1.add_instruction(migraphx::make_op("unpack_int4"), packed);
            auto scales   = m1.add_literal(migraphx::generate_literal(scale_shape));
            auto dequantized =
                m1.add_instruction(migraphx::make_op("dequantizelinear"), unpacked, scales);
            parts.push_back(m1.add_instruction(unsqueeze_op, dequantized));
        }
        auto weights = m1.add_instruction(migraphx::make_op("concat", {{"axis", 2}}), parts);

        auto product = m1.add_instruction(migraphx::make_op("dot"), scaled_input, weights);
        m1.add_return({product});
    }
    // Snapshot taken before the pass runs; the pass is expected to be a no-op.
    auto m2 = m1;
    run_pass(m1);

    EXPECT(m1.sort() == m2.sort());
}

TEST_CASE(dot_mul_a)
{
migraphx::shape as{migraphx::shape::float_type, {2, 256, 32}};
Expand Down

0 comments on commit d688810

Please sign in to comment.