Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

reorder_slice_add_mul matcher #3478

Open
wants to merge 90 commits into
base: develop
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 9 commits
Commits
Show all changes
90 commits
Select commit Hold shift + click to select a range
709daf0
find_scalar_mul_conv added
aarushjain29 Sep 25, 2024
94c624b
Formatting
aarushjain29 Sep 25, 2024
8304922
Support for scalar in simplify_reshape
aarushjain29 Sep 26, 2024
2e75e28
check for scalar in find_mul_add
aarushjain29 Sep 26, 2024
cd6d0bc
Changes in simplify reshape
aarushjain29 Sep 26, 2024
bc51516
check if output of mul_add is fed in conv
aarushjain29 Sep 26, 2024
f8a4505
Fix bugs in conv
aarushjain29 Sep 26, 2024
b0edd34
Merge branch 'develop' into 3432-improve-simplify_algebra-to-find-mor…
kahmed10 Oct 2, 2024
69457bd
Logic improved
aarushijai Oct 3, 2024
ba4ed3b
Revert "Logic improved"
aarushjain29 Oct 23, 2024
6a99db9
slicing added
aarushjain29 Oct 30, 2024
e17bb63
removed conv changes
aarushjain29 Oct 30, 2024
3a67dab
removed old code
aarushjain29 Oct 30, 2024
dd94dbe
Merge remote-tracking branch 'origin/develop' into 3432-improve-simpl…
aarushjain29 Oct 30, 2024
4dd052d
Merge branch 'develop' into 3432-improve-simplify_algebra-to-find-mor…
aarushjain29 Oct 30, 2024
7d7d5d4
Merge branch '3432-improve-simplify_algebra-to-find-more-horizontal-f…
aarushjain29 Oct 30, 2024
4e74cf4
Blocked any input operation for a
aarushjain29 Oct 30, 2024
5f90a6f
merge slice add
aarushjain29 Nov 8, 2024
95b31f7
slice mul add producing mlir_slice_mul_reshape
aarushjain29 Nov 8, 2024
4cd618a
Formatting
aarushjain29 Nov 8, 2024
ab5282b
WIP for make_basic_pred_matcher
aarushjain29 Nov 13, 2024
f49a289
WIP for make_basic_pred_matcher
aarushjain29 Nov 13, 2024
48b1f67
Working fusable_split for slice_add_mul
aarushjain29 Nov 13, 2024
cf62a4e
fusable split for find_mul_add matcher
aarushjain29 Nov 18, 2024
1101f35
formatting
aarushjain29 Nov 18, 2024
be20e61
corrected test case
aarushjain29 Nov 19, 2024
63b9737
corrected test case
aarushjain29 Nov 19, 2024
35c9441
formatting
aarushjain29 Nov 19, 2024
095192f
formatting
aarushjain29 Nov 19, 2024
aff69c3
Formatting errors
aarushjain29 Nov 19, 2024
d1fe6a1
Formatting errors
aarushjain29 Nov 19, 2024
4a787b3
Format
aarushjain29 Nov 19, 2024
4aaa2b9
format
aarushjain29 Nov 19, 2024
a979c11
skip_mul_add_for_horizontal_add_fusion
aarushjain29 Nov 20, 2024
4e11892
change fusable split
aarushjain29 Nov 21, 2024
7143f53
fusable_slice_add_mul_split
aarushjain29 Nov 22, 2024
e9b8c98
working slice_add_mul
aarushjain29 Nov 22, 2024
e4ec31d
refactor
aarushjain29 Nov 22, 2024
f8b4c05
refactor
aarushjain29 Nov 22, 2024
ffec7d8
Refactor
aarushjain29 Nov 22, 2024
0d5a6d5
skip the slice
aarushjain29 Nov 22, 2024
8f4b2b3
formatting
aarushjain29 Nov 22, 2024
4270a27
Formatting
aarushjain29 Nov 22, 2024
8788ba2
description for slice_add_mul
aarushjain29 Nov 22, 2024
7e28351
Removed 1 of the test which is not required
aarushjain29 Nov 25, 2024
9a31825
Changed some code by mistake
aarushjain29 Nov 25, 2024
2a10591
2 more test case added
aarushjain29 Nov 26, 2024
696b842
Formatting
aarushjain29 Nov 26, 2024
cec662f
test cases for clip model
aarushjain29 Nov 26, 2024
fa12fcc
add_dot_add_mul
aarushjain29 Nov 26, 2024
b05e0ba
formatting
aarushjain29 Nov 26, 2024
38b51e1
changed the name to slice_add
aarushjain29 Nov 26, 2024
7c30359
Formatting
aarushjain29 Nov 26, 2024
271bd2b
Formatting
aarushjain29 Nov 26, 2024
293dd51
change in mul_add struct - none_of added
aarushjain29 Nov 26, 2024
cef6880
changes in fusable_split
aarushjain29 Nov 27, 2024
f7f0996
Changed slice_mul_add to merge slice_mul
aarushijai Nov 27, 2024
892a26d
Revert "Changed slice_mul_add to merge slice_mul"
aarushijai Nov 27, 2024
e8e46d1
Changed slice_add_mul test case to fuse mul
aarushjain29 Nov 27, 2024
f478e04
test case for add_dot_add for each branch added
aarushjain29 Nov 27, 2024
e85f850
loop for literals
aarushjain29 Nov 27, 2024
ec4cb3e
removed older test case dot_add_mul
aarushjain29 Nov 27, 2024
318e36a
Added test for dot_add_mul with loop
aarushjain29 Nov 27, 2024
6be46f4
loop changes in add_dot_add_mul_2
aarushjain29 Nov 27, 2024
ad1bb34
seed values changed
aarushjain29 Nov 27, 2024
9804425
loop added for add_dot_add_mul_1
aarushjain29 Nov 27, 2024
7ebf297
shape s1 and s2
aarushjain29 Nov 27, 2024
b7f3037
CI Formatting
aarushjain29 Dec 2, 2024
256456a
CI Formatting
aarushjain29 Dec 2, 2024
9eceb9f
removing tidy errors
aarushjain29 Dec 2, 2024
6a50c51
use of literals reserve to remove tidy errors
aarushjain29 Dec 2, 2024
9b00fd8
use of literals reserve to remove tidy errors
aarushjain29 Dec 2, 2024
35a20e0
semicolon missing added
aarushjain29 Dec 2, 2024
58a4a0f
Removing tidy errors CI
aarushjain29 Dec 2, 2024
ce65be9
semicolon missing added
aarushjain29 Dec 2, 2024
3d070cb
Formatting
aarushjain29 Dec 2, 2024
75da7c3
Formatting
aarushjain29 Dec 2, 2024
28eb64d
generating literals in loop for add_dot_add_mul_1
aarushjain29 Dec 3, 2024
c59d2fa
generating literals in loop for add_dot_add_mul_1
aarushjain29 Dec 3, 2024
0eeb663
WIP change in add_dot_add_mul_2
aarushjain29 Dec 3, 2024
820bb12
loops added
aarushjain29 Dec 3, 2024
3950cd3
Formatting
aarushjain29 Dec 3, 2024
0d99d23
Formatting
aarushjain29 Dec 4, 2024
8bd694e
Merge branch 'develop' into 3432-improve-simplify_algebra-to-find-mor…
aarushjain29 Dec 30, 2024
c8290eb
Merge branch 'develop' into 3432-improve-simplify_algebra-to-find-mor…
aarushjain29 Jan 15, 2025
54ba97a
Formatting
aarushjain29 Jan 15, 2025
bde1680
Formatting
aarushjain29 Jan 15, 2025
1c676ba
licensing
aarushjain29 Jan 15, 2025
e8b91a9
format
aarushjain29 Jan 15, 2025
330c392
cpp check
aarushjain29 Jan 16, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
81 changes: 73 additions & 8 deletions src/simplify_algebra.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -422,14 +422,46 @@
{
auto matcher() const
{
return match::name("mul")(match::either_arg(0, 1)(
match::name("add")(
match::either_arg(0, 1)(
match::any().bind("x"),
match::any_of(conv_const_weights(), match::is_constant()).bind("b")),
match::none_of(match::args(match::is_constant(), match::is_constant())),
match::used_once()),
match::is_constant().bind("a")));
return match::name("mul")(
match::none_of[match::outputs()](match::name("convolution")),
match::either_arg(0, 1)(
match::name("add")(
match::either_arg(0, 1)(
match::any().bind("x"),
match::any_of(conv_const_weights(), match::is_constant()).bind("b")),
match::none_of(match::args(match::is_constant(), match::is_constant())),
match::used_once()),
match::is_constant().bind("a")));
}

void apply(module& m, const match::matcher_result& r) const
{
auto ins = r.result;
auto a_ins = r.instructions["a"];
auto b_ins = r.instructions["b"];
auto x_ins = r.instructions["x"];
assert(x_ins != b_ins);

auto ax_ins = m.insert_instruction(ins, make_op("mul"), a_ins, x_ins);
auto ab_ins = m.insert_instruction(ins, make_op("mul"), a_ins, b_ins);
m.replace_instruction(ins, make_op("add"), ax_ins, ab_ins);
}
};
/* Delete this later
struct find_mul_add
{
auto matcher() const
{
return match::name("mul")(
match::none_of[match::outputs()](match::name("convolution")),
match::either_arg(0, 1)(
match::name("add")(
match::either_arg(0, 1)(
match::any().bind("x"),
match::any_of(conv_const_weights(), match::is_constant()).bind("b")),
match::none_of(match::args(match::is_constant(), match::is_constant())),
match::used_once()),
match::is_constant().bind("a")));
}

void apply(module& m, const match::matcher_result& r) const
Expand All @@ -445,6 +477,38 @@
m.replace_instruction(ins, make_op("add"), ax_ins, ab_ins);
}
};
*/

struct find_scalar_mul_conv
{
auto matcher() const
{
return match::name("mul")(
match::either_arg(0, 1)(
match::is_constant().bind("scalar"),
match::name("convolution").bind("conv")
));
CharlieL7 marked this conversation as resolved.
Show resolved Hide resolved
}

void apply(module& m, const match::matcher_result& r) const

Check warning on line 493 in src/simplify_algebra.cpp

View check run for this annotation

Codecov / codecov/patch

src/simplify_algebra.cpp#L493

Added line #L493 was not covered by tests
{
auto ins = r.result;
auto scalar = r.instructions["scalar"];
auto conv_ins = r.instructions["conv"];

Check warning on line 497 in src/simplify_algebra.cpp

View check run for this annotation

Codecov / codecov/patch

src/simplify_algebra.cpp#L495-L497

Added lines #L495 - L497 were not covered by tests

// Get the convol's input and weights
auto conv_input = conv_ins->inputs().front(); // input to conv
auto conv_weights = conv_ins->inputs().back(); // weights of the conv

Check warning on line 501 in src/simplify_algebra.cpp

View check run for this annotation

Codecov / codecov/patch

src/simplify_algebra.cpp#L500-L501

Added lines #L500 - L501 were not covered by tests

auto scaled_weights = m.insert_instruction(ins, make_op("mul"), scalar, conv_weights);

Check warning on line 503 in src/simplify_algebra.cpp

View check run for this annotation

Codecov / codecov/patch

src/simplify_algebra.cpp#L503

Added line #L503 was not covered by tests

// new conv with modified weights
auto new_conv = m.insert_instruction(ins, conv_ins->get_operator(), conv_input, scaled_weights);

Check warning on line 506 in src/simplify_algebra.cpp

View check run for this annotation

Codecov / codecov/patch

src/simplify_algebra.cpp#L506

Added line #L506 was not covered by tests

m.replace_instruction(ins, new_conv);
}

Check warning on line 509 in src/simplify_algebra.cpp

View check run for this annotation

Codecov / codecov/patch

src/simplify_algebra.cpp#L508-L509

Added lines #L508 - L509 were not covered by tests
};


struct find_dot_add
{
Expand Down Expand Up @@ -1981,6 +2045,7 @@
find_dot_slice{},
find_dot_mul{},
find_mul_add{},
find_scalar_mul_conv{},
find_unit_ops{},
find_neg_unit_ops{},
eliminate_zero_point{},
Expand Down
3 changes: 2 additions & 1 deletion src/simplify_reshapes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -744,7 +744,8 @@ struct find_unary_shape_transforms
{
return ins->name() == "@literal" or
ins->get_operator().attributes().contains("pointwise") or
contains(ins->name(), "reduce");
contains(ins->name(), "reduce") or
ins->get_operator().attributes().contains("scalar"); //For scalar operation
CharlieL7 marked this conversation as resolved.
Show resolved Hide resolved
}

void apply(module& m, const match::matcher_result& mr) const
Expand Down
Loading