Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

reorder_slice_add_mul matcher #3478

Open
wants to merge 91 commits into
base: develop
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
91 commits
Select commit Hold shift + click to select a range
709daf0
find_scalar_mul_conv added
aarushjain29 Sep 25, 2024
94c624b
Formatting
aarushjain29 Sep 25, 2024
8304922
Support for scalar in simplify_reshape
aarushjain29 Sep 26, 2024
2e75e28
check for scalar in find_mul_add
aarushjain29 Sep 26, 2024
cd6d0bc
Changes in simplify reshape
aarushjain29 Sep 26, 2024
bc51516
check if output of mul_add is fed in conv
aarushjain29 Sep 26, 2024
f8a4505
Fix bugs in conv
aarushjain29 Sep 26, 2024
b0edd34
Merge branch 'develop' into 3432-improve-simplify_algebra-to-find-mor…
kahmed10 Oct 2, 2024
69457bd
Logic improved
aarushijai Oct 3, 2024
ba4ed3b
Revert "Logic improved"
aarushjain29 Oct 23, 2024
6a99db9
slicing added
aarushjain29 Oct 30, 2024
e17bb63
removed conv changes
aarushjain29 Oct 30, 2024
3a67dab
removed old code
aarushjain29 Oct 30, 2024
dd94dbe
Merge remote-tracking branch 'origin/develop' into 3432-improve-simpl…
aarushjain29 Oct 30, 2024
4dd052d
Merge branch 'develop' into 3432-improve-simplify_algebra-to-find-mor…
aarushjain29 Oct 30, 2024
7d7d5d4
Merge branch '3432-improve-simplify_algebra-to-find-more-horizontal-f…
aarushjain29 Oct 30, 2024
4e74cf4
Blocked any input operation for a
aarushjain29 Oct 30, 2024
5f90a6f
merge slice add
aarushjain29 Nov 8, 2024
95b31f7
slice mul add producing mlir_slice_mul_reshape
aarushjain29 Nov 8, 2024
4cd618a
Formatting
aarushjain29 Nov 8, 2024
ab5282b
WIP for make_basic_pred_matcher
aarushjain29 Nov 13, 2024
f49a289
WIP for make_basic_pred_matcher
aarushjain29 Nov 13, 2024
48b1f67
Working fusable_split for slice_add_mul
aarushjain29 Nov 13, 2024
cf62a4e
fusable split for find_mul_add matcher
aarushjain29 Nov 18, 2024
1101f35
formatting
aarushjain29 Nov 18, 2024
be20e61
corrected test case
aarushjain29 Nov 19, 2024
63b9737
corrected test case
aarushjain29 Nov 19, 2024
35c9441
formatting
aarushjain29 Nov 19, 2024
095192f
formatting
aarushjain29 Nov 19, 2024
aff69c3
Formatting errors
aarushjain29 Nov 19, 2024
d1fe6a1
Formatting errors
aarushjain29 Nov 19, 2024
4a787b3
Format
aarushjain29 Nov 19, 2024
4aaa2b9
format
aarushjain29 Nov 19, 2024
a979c11
skip_mul_add_for_horizontal_add_fusion
aarushjain29 Nov 20, 2024
4e11892
change fusable split
aarushjain29 Nov 21, 2024
7143f53
fusable_slice_add_mul_split
aarushjain29 Nov 22, 2024
e9b8c98
working slice_add_mul
aarushjain29 Nov 22, 2024
e4ec31d
refactor
aarushjain29 Nov 22, 2024
f8b4c05
refactor
aarushjain29 Nov 22, 2024
ffec7d8
Refactor
aarushjain29 Nov 22, 2024
0d5a6d5
skip the slice
aarushjain29 Nov 22, 2024
8f4b2b3
formatting
aarushjain29 Nov 22, 2024
4270a27
Formatting
aarushjain29 Nov 22, 2024
8788ba2
description for slice_add_mul
aarushjain29 Nov 22, 2024
7e28351
Removed 1 of the test which is not required
aarushjain29 Nov 25, 2024
9a31825
Changed some code by mistake
aarushjain29 Nov 25, 2024
2a10591
2 more test case added
aarushjain29 Nov 26, 2024
696b842
Formatting
aarushjain29 Nov 26, 2024
cec662f
test cases for clip model
aarushjain29 Nov 26, 2024
fa12fcc
add_dot_add_mul
aarushjain29 Nov 26, 2024
b05e0ba
formatting
aarushjain29 Nov 26, 2024
38b51e1
changed the name to slice_add
aarushjain29 Nov 26, 2024
7c30359
Formatting
aarushjain29 Nov 26, 2024
271bd2b
Formatting
aarushjain29 Nov 26, 2024
293dd51
change in mul_add struct - none_of added
aarushjain29 Nov 26, 2024
cef6880
changes in fusable_split
aarushjain29 Nov 27, 2024
f7f0996
Changed slice_mul_add to merge slice_mul
aarushijai Nov 27, 2024
892a26d
Revert "Changed slice_mul_add to merge slice_mul"
aarushijai Nov 27, 2024
e8e46d1
Changed slice_add_mul test case to fuse mul
aarushjain29 Nov 27, 2024
f478e04
test case for add_dot_add for each branch added
aarushjain29 Nov 27, 2024
e85f850
loop for literals
aarushjain29 Nov 27, 2024
ec4cb3e
removed older test case dot_add_mul
aarushjain29 Nov 27, 2024
318e36a
Added test for dot_add_mul with loop
aarushjain29 Nov 27, 2024
6be46f4
loop changes in add_dot_add_mul_2
aarushjain29 Nov 27, 2024
ad1bb34
seed values changed
aarushjain29 Nov 27, 2024
9804425
loop added for add_dot_add_mul_1
aarushjain29 Nov 27, 2024
7ebf297
shape s1 and s2
aarushjain29 Nov 27, 2024
b7f3037
CI Formatting
aarushjain29 Dec 2, 2024
256456a
CI Formatting
aarushjain29 Dec 2, 2024
9eceb9f
removing tidy errors
aarushjain29 Dec 2, 2024
6a50c51
use of literals reserve to remove tidy errors
aarushjain29 Dec 2, 2024
9b00fd8
use of literals reserve to remove tidy errors
aarushjain29 Dec 2, 2024
35a20e0
semicolon missing added
aarushjain29 Dec 2, 2024
58a4a0f
Removing tidy errors CI
aarushjain29 Dec 2, 2024
ce65be9
semicolon missing added
aarushjain29 Dec 2, 2024
3d070cb
Formatting
aarushjain29 Dec 2, 2024
75da7c3
Formatting
aarushjain29 Dec 2, 2024
28eb64d
generating literals in loop for add_dot_add_mul_1
aarushjain29 Dec 3, 2024
c59d2fa
generating literals in loop for add_dot_add_mul_1
aarushjain29 Dec 3, 2024
0eeb663
WIP change in add_dot_add_mul_2
aarushjain29 Dec 3, 2024
820bb12
loops added
aarushjain29 Dec 3, 2024
3950cd3
Formatting
aarushjain29 Dec 3, 2024
0d99d23
Formatting
aarushjain29 Dec 4, 2024
8bd694e
Merge branch 'develop' into 3432-improve-simplify_algebra-to-find-mor…
aarushjain29 Dec 30, 2024
c8290eb
Merge branch 'develop' into 3432-improve-simplify_algebra-to-find-mor…
aarushjain29 Jan 15, 2025
54ba97a
Formatting
aarushjain29 Jan 15, 2025
bde1680
Formatting
aarushjain29 Jan 15, 2025
1c676ba
licensing
aarushjain29 Jan 15, 2025
e8b91a9
format
aarushjain29 Jan 15, 2025
330c392
cpp check
aarushjain29 Jan 16, 2025
d13d92f
Merge branch 'develop' into 3432-improve-simplify_algebra-to-find-mor…
aarushjain29 Feb 7, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
51 changes: 48 additions & 3 deletions src/simplify_algebra.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -429,6 +429,17 @@
}
};

auto fusable_split(const std::string& name)
aarushjain29 marked this conversation as resolved.
Show resolved Hide resolved
{
return match::make_basic_pred_matcher([&](instruction_ref ins) {
return all_of(ins->outputs(), [&](instruction_ref slice) {
if(slice->name() != "slice")
return true;
return any_of(slice->outputs(), [&](instruction_ref x) { return x->name() == name; });
});
});
}

// ******************************
// a * (x + b) => a * x + a * b
// ******************************
Expand All @@ -439,10 +450,11 @@
{
auto matcher() const
{
auto slice_1 = match::none_of(match::name("slice")(match::arg(0)(fusable_split("add"))));

Check warning on line 453 in src/simplify_algebra.cpp

View workflow job for this annotation

GitHub Actions / cppcheck

style: Too many nested parentheses can affect readability; consider using variables instead. [migraphx-MatcherNestedParentheses]
return match::name("mul")(match::either_arg(0, 1)(
match::name("add")(
match::either_arg(0, 1)(
match::any().bind("x"),
slice_1.bind("x"),
match::any_of(conv_const_weights(), match::is_constant()).bind("b")),
match::none_of(match::args(match::is_constant(), match::is_constant())),
match::used_once()),
Expand All @@ -455,6 +467,7 @@
auto a_ins = r.instructions["a"];
auto b_ins = r.instructions["b"];
auto x_ins = r.instructions["x"];

assert(x_ins != b_ins);

auto ax_ins = m.insert_instruction(ins, make_op("mul"), a_ins, x_ins);
Expand All @@ -463,6 +476,38 @@
}
};

// When a slice is followed by a*(x+b), a⋅(x+b), this matcher performs
// the add first, followed by the mul.
// If multiple slices originate from the same instruction and are followed by add,
// all the adds can be const folded and performed before the slicing.
struct find_slice_add_mul
{
auto matcher() const
{
auto slice_1 = match::name("slice")(match::arg(0)(fusable_split("add")));
return match::name("mul")(match::either_arg(0, 1)(
match::name("add")(
match::either_arg(0, 1)(
slice_1.bind("x"),
match::any_of(conv_const_weights(), match::is_constant()).bind("b")),
match::none_of(match::args(match::is_constant(), match::is_constant())),
match::used_once()),
match::is_constant().bind("a")));
}
void apply(module& m, const match::matcher_result& r) const
{
auto ins = r.result;
auto a_ins = r.instructions["a"];
auto b_ins = r.instructions["b"];
auto x_ins = r.instructions["x"];

assert(x_ins != b_ins);

auto ax_ins = m.insert_instruction(ins, make_op("add"), x_ins, b_ins);
m.replace_instruction(ins, make_op("mul"), ax_ins, a_ins);
}
};

struct find_dot_add
{
auto matcher() const
Expand Down Expand Up @@ -1533,8 +1578,7 @@

void apply(module& m, const match::matcher_result& r) const
{
auto ins = r.result;

auto ins = r.result;
auto pred = [](auto i, auto j) {
if(i->get_operator() != j->get_operator())
return false;
Expand Down Expand Up @@ -2012,6 +2056,7 @@
find_dot_slice{},
find_dot_mul{},
find_mul_add{},
find_slice_add_mul{},
find_unit_ops{},
find_neg_unit_ops{},
eliminate_zero_point{},
Expand Down
240 changes: 240 additions & 0 deletions test/simplify_algebra_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3885,6 +3885,132 @@ TEST_CASE(dot_fusion_reshape)
EXPECT(m1.sort() == m2.sort());
}

TEST_CASE(add_dot_add_mul_2)
{
migraphx::shape as{migraphx::shape::int8_type, {1, 77, 768}};

migraphx::shape s1{migraphx::shape::int8_type, {1, 768, 768}};
migraphx::shape s2{migraphx::shape::int8_type, {1, 77, 768}};

migraphx::module m1;
{
auto a = m1.add_parameter("a", as);

std::vector<migraphx::instruction_ref> final_add_results;
for(int i = 0; i < 3; ++i)
{
auto lit1 = m1.add_literal(migraphx::generate_literal(s2, i));
auto lit2 = m1.add_literal(migraphx::generate_literal(s1, i));
auto lit3 = m1.add_literal(migraphx::generate_literal(s2, i + 3));
auto add_i = m1.add_instruction(migraphx::make_op("add"), a, lit1);
auto dot_i = m1.add_instruction(migraphx::make_op("dot"), add_i, lit2);
auto add_final = m1.add_instruction(migraphx::make_op("add"), dot_i, lit3);
final_add_results.push_back(add_final);
}
aarushjain29 marked this conversation as resolved.
Show resolved Hide resolved

auto lit9 = m1.add_literal(migraphx::generate_literal(s2, 9));
auto mul = m1.add_instruction(migraphx::make_op("mul"), final_add_results[2], lit9);
m1.add_return({final_add_results[0], final_add_results[1], final_add_results[2], mul});
};
run_pass(m1);

migraphx::module m2;
{
auto a = m2.add_parameter("a", as);

std::vector<migraphx::instruction_ref> literals;
for(int i = 0; i < 3; ++i)
{
auto lit1 = m2.add_literal(migraphx::generate_literal(s2, i));
auto lit2 = m2.add_literal(migraphx::generate_literal(s1, i));
auto lit3 = m2.add_literal(migraphx::generate_literal(s2, i + 3));
literals.push_back(lit1);
literals.push_back(lit2);
literals.push_back(lit3);
}

auto lit9 = m2.add_literal(migraphx::generate_literal(s2, 9));
auto concat_1 = m2.add_instruction(
migraphx::make_op("concat", {{"axis", 2}}), literals[2], literals[6], literals[0]);
auto dot1 = m2.add_instruction(migraphx::make_op("dot"), literals[5], literals[4]);
auto dot2 = m2.add_instruction(migraphx::make_op("dot"), literals[8], literals[7]);
auto dot3 = m2.add_instruction(migraphx::make_op("dot"), lit9, literals[1]);
auto concat_2 =
m2.add_instruction(migraphx::make_op("concat", {{"axis", 2}}), dot3, dot2, dot1);
auto add1 = m2.add_instruction(migraphx::make_op("add"), concat_2, concat_1);
auto concat_3 = m2.add_instruction(
migraphx::make_op("concat", {{"axis", 2}}), literals[1], literals[7], literals[4]);
auto dot1_1 = m2.add_instruction(migraphx::make_op("dot"), a, concat_3);
auto add1_1 = m2.add_instruction(migraphx::make_op("add"), dot1_1, add1);

auto slice_c = m2.add_instruction(
migraphx::make_op("slice", {{"axes", {2}}, {"starts", {1536}}, {"ends", {2304}}}),
add1_1);
auto mul = m2.add_instruction(migraphx::make_op("mul"), slice_c, literals[3]);

auto slice_a = m2.add_instruction(
migraphx::make_op("slice", {{"axes", {2}}, {"starts", {0}}, {"ends", {768}}}), add1_1);
auto slice_b = m2.add_instruction(
migraphx::make_op("slice", {{"axes", {2}}, {"starts", {768}}, {"ends", {1536}}}),
add1_1);
m2.add_return({slice_a, slice_b, slice_c, mul});
};
EXPECT(m1.sort() == m2.sort());
}

TEST_CASE(add_dot_add_mul_1)
{
migraphx::shape as{migraphx::shape::int8_type, {1, 77, 768}};
migraphx::shape s1{migraphx::shape::int8_type, {1, 768, 768}};
migraphx::shape s2{migraphx::shape::int8_type, {1, 77, 768}};
migraphx::module m1;
{
auto a = m1.add_parameter("a", as);
std::vector<migraphx::instruction_ref> add_results;
for(int i = 0; i < 3; ++i)
{
auto lit1 = m1.add_literal(migraphx::generate_literal(s1, i));
auto lit2 = m1.add_literal(migraphx::generate_literal(s2, i));
auto dot_i = m1.add_instruction(migraphx::make_op("dot"), a, lit1);
auto add_i = m1.add_instruction(migraphx::make_op("add"), dot_i, lit2);
add_results.push_back(add_i);
}
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The loop should insert the dot/add instruction as well.

Copy link
Collaborator

@pfultz2 pfultz2 Dec 2, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Literals should be generated in the loop:

        for(int i = 0; i < 3; ++i)
        {
            auto lit1 = m1.add_literal(migraphx::generate_literal(s1, i));
            auto lit2 = m1.add_literal(migraphx::generate_literal(s2, i));
            auto dot_i = m1.add_instruction(migraphx::make_op("dot"), a, lit1);
            auto add_i = m1.add_instruction(migraphx::make_op("add"), dot_i, lit2);
            add_results.push_back(add_i);
        }


auto lit3 = m1.add_literal(migraphx::generate_literal(s2, 6));
auto mul = m1.add_instruction(migraphx::make_op("mul"), add_results[2], lit3);

m1.add_return({mul});
};
run_pass(m1);

migraphx::module m2;
{
auto a = m2.add_parameter("a", as);
std::vector<migraphx::instruction_ref> literals;
for(int i = 0; i < 6; i += 2)
{
auto lit1 = m2.add_literal(migraphx::generate_literal(s1, i));
auto lit2 = m2.add_literal(migraphx::generate_literal(s2, i + 1));
literals.push_back(lit1);
literals.push_back(lit2);
}
auto lit6 = m2.add_literal(migraphx::generate_literal(s2, 6));
auto concat_1 = m2.add_instruction(
migraphx::make_op("concat", {{"axis", 2}}), literals[0], literals[2], literals[4]);
auto concat_2 = m2.add_instruction(
migraphx::make_op("concat", {{"axis", 2}}), literals[1], literals[3], literals[5]);
auto dot1 = m2.add_instruction(migraphx::make_op("dot"), a, concat_1);
auto add1 = m2.add_instruction(migraphx::make_op("add"), dot1, concat_2);
auto slice_c = m2.add_instruction(
migraphx::make_op("slice", {{"axes", {2}}, {"starts", {1536}}, {"ends", {2304}}}),
add1);
auto mul = m2.add_instruction(migraphx::make_op("mul"), slice_c, lit6);

m2.add_return({mul});
};
EXPECT(m1.sort() == m2.sort());
}

TEST_CASE(mul_dot_a)
{
migraphx::shape as{migraphx::shape::float_type, {2, 256, 32}};
Expand Down Expand Up @@ -4256,6 +4382,120 @@ TEST_CASE(dot_slice_a)
EXPECT(m1.sort() == m2.sort());
}

TEST_CASE(slice_add)
{
migraphx::shape as{migraphx::shape::float_type, {1, 77, 2304}};

migraphx::module m1;
{
auto a = m1.add_parameter("a", as);
auto slice_a = m1.add_instruction(
migraphx::make_op("slice", {{"axes", {2}}, {"starts", {0}}, {"ends", {768}}}), a);
auto slice_b = m1.add_instruction(
migraphx::make_op("slice", {{"axes", {2}}, {"starts", {768}}, {"ends", {1536}}}), a);
auto slice_c = m1.add_instruction(
migraphx::make_op("slice", {{"axes", {2}}, {"starts", {1536}}, {"ends", {2304}}}), a);
auto one = m1.add_literal(
migraphx::generate_literal({migraphx::shape::float_type, {1, 77, 768}}, 3.0f));
auto two = m1.add_literal(
migraphx::generate_literal({migraphx::shape::float_type, {1, 77, 768}}, 2.0f));
auto three = m1.add_literal(
migraphx::generate_literal({migraphx::shape::float_type, {1, 77, 768}}, 4.0f));

auto add1 = m1.add_instruction(migraphx::make_op("add"), slice_a, one);
auto add2 = m1.add_instruction(migraphx::make_op("add"), slice_b, two);
auto add3 = m1.add_instruction(migraphx::make_op("add"), slice_c, three);

m1.add_return({add1, add2, add3});
};
run_pass(m1);

migraphx::module m2;
{
auto a = m2.add_parameter("a", as);
auto one = m2.add_literal(
migraphx::generate_literal({migraphx::shape::float_type, {1, 77, 768}}, 3.0f));
auto two = m2.add_literal(
migraphx::generate_literal({migraphx::shape::float_type, {1, 77, 768}}, 2.0f));
auto three = m2.add_literal(
migraphx::generate_literal({migraphx::shape::float_type, {1, 77, 768}}, 4.0f));

auto concat =
m2.add_instruction(migraphx::make_op("concat", {{"axis", 2}}), one, two, three);

auto add = m2.add_instruction(migraphx::make_op("add"), a, concat);

auto slice_a = m2.add_instruction(
migraphx::make_op("slice", {{"axes", {2}}, {"starts", {0}}, {"ends", {768}}}), add);
auto slice_b = m2.add_instruction(
migraphx::make_op("slice", {{"axes", {2}}, {"starts", {768}}, {"ends", {1536}}}), add);
auto slice_c = m2.add_instruction(
migraphx::make_op("slice", {{"axes", {2}}, {"starts", {1536}}, {"ends", {2304}}}), add);
m2.add_return({slice_a, slice_b, slice_c});
};
EXPECT(m1.sort() == m2.sort());
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is a problem. The add_mul should be rewritten.

Copy link
Contributor Author

@aarushjain29 aarushjain29 Nov 26, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The slice_add_mul testcase is checking the case when there is a slice->add->mul, but the other slices are not followed by an add.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I have added the add_dot_add_mul test case which considers three dots followed by 3 adds and one of them is followed by mul as per the issue found in model.

}
aarushjain29 marked this conversation as resolved.
Show resolved Hide resolved

TEST_CASE(slice_add_mul)
{
migraphx::shape as{migraphx::shape::float_type, {1, 77, 2304}};

migraphx::module m1;
{
auto a = m1.add_parameter("a", as);
auto slice_a = m1.add_instruction(
migraphx::make_op("slice", {{"axes", {2}}, {"starts", {0}}, {"ends", {768}}}), a);
auto slice_b = m1.add_instruction(
migraphx::make_op("slice", {{"axes", {2}}, {"starts", {768}}, {"ends", {1536}}}), a);
auto slice_c = m1.add_instruction(
migraphx::make_op("slice", {{"axes", {2}}, {"starts", {1536}}, {"ends", {2304}}}), a);
auto one = m1.add_literal(
migraphx::generate_literal({migraphx::shape::float_type, {1, 77, 768}}, 3.0f));
auto two = m1.add_literal(
migraphx::generate_literal({migraphx::shape::float_type, {1, 77, 768}}, 2.0f));
auto three = m1.add_literal(
migraphx::generate_literal({migraphx::shape::float_type, {1, 77, 768}}, 4.0f));
auto four = m1.add_literal(
migraphx::generate_literal({migraphx::shape::float_type, {1, 77, 768}}, 4.0f));

auto add1 = m1.add_instruction(migraphx::make_op("add"), slice_a, one);
auto mul1 = m1.add_instruction(migraphx::make_op("mul"), add1, four);

auto mul2 = m1.add_instruction(migraphx::make_op("mul"), slice_b, two);
auto mul3 = m1.add_instruction(migraphx::make_op("mul"), slice_c, three);
m1.add_return({mul1, mul2, mul3});
};
run_pass(m1);
migraphx::module m2;
{
auto a = m2.add_parameter("a", as);
auto one = m2.add_literal(
migraphx::generate_literal({migraphx::shape::float_type, {1, 77, 768}}, 3.0f));
auto two = m2.add_literal(
migraphx::generate_literal({migraphx::shape::float_type, {1, 77, 768}}, 2.0f));
auto three = m2.add_literal(
migraphx::generate_literal({migraphx::shape::float_type, {1, 77, 768}}, 4.0f));
auto four = m2.add_literal(
migraphx::generate_literal({migraphx::shape::float_type, {1, 77, 768}}, 4.0f));

auto concat =
m2.add_instruction(migraphx::make_op("concat", {{"axis", 2}}), four, two, three);
auto mul2 = m2.add_instruction(migraphx::make_op("mul"), a, concat);
auto slice_a = m2.add_instruction(
migraphx::make_op("slice", {{"axes", {2}}, {"starts", {0}}, {"ends", {768}}}), mul2);
auto slice_a_mul = m2.add_instruction(migraphx::make_op("mul"), four, one);
auto slice_a_add = m2.add_instruction(migraphx::make_op("add"), slice_a, slice_a_mul);
auto slice_b = m2.add_instruction(
migraphx::make_op("slice", {{"axes", {2}}, {"starts", {768}}, {"ends", {1536}}}), mul2);
auto slice_c = m2.add_instruction(
migraphx::make_op("slice", {{"axes", {2}}, {"starts", {1536}}, {"ends", {2304}}}),
mul2);

m2.add_return({slice_a_add, slice_b, slice_c});
};
EXPECT(m1.sort() == m2.sort());
}
aarushjain29 marked this conversation as resolved.
Show resolved Hide resolved

TEST_CASE(dot_slice_ab)
{
migraphx::shape as{migraphx::shape::float_type, {2, 256, 32}};
Expand Down
Loading