-
Notifications
You must be signed in to change notification settings - Fork 94
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Horizontal fusions of gemms and convolutions (#472)
* Add decompose pass * Add decompose test * Formatting * Add remap * Formatting * Add compute method for dot * Formatting * Add finder for horizontal fusion * Formatting * Formatting * Reuse predicate * Add gemm fusions * Formatting * Add some fixes for convolution * Formatting * Fix shape tests * Formatting * Reuse axis equal * Add initial split fusion * Formatting * Update offset * Workaround outputs that cant accept nonstandard shapes * Formatting * Add check for split concat * Formatting * Add missing headers * Formatting * Add tests * Formatting * Add more testing * Formatting * Fix when there is duplicate splits in inputs * Formatting * Fix mismatch iterators * Add tests for dot fusions * Formatting * Add test for convolution * Formatting * Fix tidy issues * Add more tests * Formatting * Ignore build directory for codecov * Add test for groups * Formatting * Add more tests for groups * Formatting * Add test for missing end slice * Add newline * Remove unused function * Add support for when beta is not 1 * Formatting * Add test for scalar * Add one more scalar test Co-authored-by: mvermeulen <[email protected]>
- Loading branch information
1 parent
45bb91e
commit 1a4ff50
Showing
21 changed files
with
1,370 additions
and
6 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,4 +1,4 @@ | ||
ignore: | ||
- "test/" | ||
- "src/driver" | ||
|
||
- "build/" |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,48 @@ | ||
#include <migraphx/decompose.hpp> | ||
#include <migraphx/program.hpp> | ||
#include <migraphx/instruction.hpp> | ||
#include <migraphx/iterator_for.hpp> | ||
#include <migraphx/functional.hpp> | ||
#include <migraphx/ranges.hpp> | ||
#include <migraphx/float_equal.hpp> | ||
#include <migraphx/matcher.hpp> | ||
#include <migraphx/op/dot.hpp> | ||
#include <migraphx/op/multibroadcast.hpp> | ||
#include <migraphx/op/mul.hpp> | ||
#include <migraphx/op/add.hpp> | ||
|
||
namespace migraphx { | ||
inline namespace MIGRAPHX_INLINE_NS { | ||
namespace { | ||
struct find_dot_add | ||
{ | ||
auto matcher() const { return match::name("dot")(match::nargs(3)); } | ||
|
||
void apply(program& p, const match::matcher_result& r) const | ||
{ | ||
auto ins = r.result; | ||
auto dot = any_cast<op::dot>(ins->get_operator()); | ||
if(not float_equal(dot.beta, 1) and | ||
not contains({shape::float_type, shape::half_type, shape::double_type}, | ||
ins->get_shape().type())) | ||
return; | ||
auto dot_ins = | ||
p.insert_instruction(ins, op::dot{dot.alpha, 0}, ins->inputs()[0], ins->inputs()[1]); | ||
auto c_ins = ins->inputs()[2]; | ||
if(not float_equal(dot.beta, 1)) | ||
{ | ||
auto beta = p.add_literal(literal{shape{ins->get_shape().type()}, {dot.beta}}); | ||
auto beta_broadcast = | ||
p.insert_instruction(ins, op::multibroadcast{ins->get_shape().lens()}, beta); | ||
c_ins = p.insert_instruction(ins, op::mul{}, c_ins, beta_broadcast); | ||
} | ||
p.replace_instruction(ins, op::add{}, dot_ins, c_ins); | ||
} | ||
}; | ||
|
||
} // namespace | ||
|
||
void decompose::apply(program& p) const { match::find_matches(p, find_dot_add{}); } | ||
|
||
} // namespace MIGRAPHX_INLINE_NS | ||
} // namespace migraphx |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,24 @@ | ||
#ifndef MIGRAPHX_GUARD_RTGLIB_ALGORITHM_HPP | ||
#define MIGRAPHX_GUARD_RTGLIB_ALGORITHM_HPP | ||
|
||
#include <algorithm> | ||
#include <migraphx/config.hpp> | ||
|
||
namespace migraphx { | ||
inline namespace MIGRAPHX_INLINE_NS { | ||
|
||
template <class Iterator, class Output, class Predicate> | ||
void group_by(Iterator start, Iterator last, Output out, Predicate pred) | ||
{ | ||
while(start != last) | ||
{ | ||
auto it = std::partition(start, last, [&](auto x) { return pred(x, *start); }); | ||
out(start, it); | ||
start = it; | ||
} | ||
} | ||
|
||
} // namespace MIGRAPHX_INLINE_NS | ||
} // namespace migraphx | ||
|
||
#endif |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,25 @@ | ||
#ifndef MIGRAPHX_GUARD_RTGLIB_DECOMPOSE_HPP | ||
#define MIGRAPHX_GUARD_RTGLIB_DECOMPOSE_HPP | ||
|
||
#include <string> | ||
#include <migraphx/instruction_ref.hpp> | ||
#include <migraphx/config.hpp> | ||
|
||
namespace migraphx { | ||
inline namespace MIGRAPHX_INLINE_NS { | ||
|
||
struct program; | ||
|
||
/** | ||
* Decompose operators. | ||
*/ | ||
struct decompose | ||
{ | ||
std::string name() const { return "decompose"; } | ||
void apply(program& p) const; | ||
}; | ||
|
||
} // namespace MIGRAPHX_INLINE_NS | ||
} // namespace migraphx | ||
|
||
#endif |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,39 @@ | ||
#ifndef MIGRAPHX_GUARD_RTGLIB_GEMM_HPP | ||
#define MIGRAPHX_GUARD_RTGLIB_GEMM_HPP | ||
|
||
#include <migraphx/config.hpp> | ||
#include <migraphx/dfor.hpp> | ||
#include <migraphx/shape_for_each.hpp> | ||
#include <migraphx/tensor_view.hpp> | ||
|
||
namespace migraphx { | ||
inline namespace MIGRAPHX_INLINE_NS { | ||
|
||
template <class T, class F> | ||
void gemm(tensor_view<T> cmat, tensor_view<T> amat, tensor_view<T> bmat, F alpha, F beta) | ||
{ | ||
std::size_t n_dims = cmat.get_shape().lens().size(); | ||
std::size_t dim_0 = n_dims - 2; | ||
std::size_t dim_1 = n_dims - 1; | ||
auto k = amat.get_shape().lens()[dim_1]; | ||
|
||
assert(amat.get_shape().lens()[dim_1] == bmat.get_shape().lens()[dim_0]); | ||
assert(cmat.get_shape().lens()[dim_0] == amat.get_shape().lens()[dim_0]); | ||
assert(cmat.get_shape().lens()[dim_1] == bmat.get_shape().lens()[dim_1]); | ||
|
||
shape_for_each(cmat.get_shape(), [&](const auto& c_idx) { | ||
auto a_idx = c_idx; | ||
auto b_idx = c_idx; | ||
double s = 0.0; | ||
dfor(k)([&](auto kk) { | ||
a_idx[dim_1] = b_idx[dim_0] = kk; | ||
s += amat(a_idx.begin(), a_idx.end()) * bmat(b_idx.begin(), b_idx.end()); | ||
}); | ||
cmat(c_idx.begin(), c_idx.end()) = alpha * s + cmat(c_idx.begin(), c_idx.end()) * beta; | ||
}); | ||
} | ||
|
||
} // namespace MIGRAPHX_INLINE_NS | ||
} // namespace migraphx | ||
|
||
#endif |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,25 @@ | ||
#ifndef MIGRAPHX_GUARD_RTGLIB_REMAP_HPP | ||
#define MIGRAPHX_GUARD_RTGLIB_REMAP_HPP | ||
|
||
#include <string> | ||
#include <migraphx/instruction_ref.hpp> | ||
#include <migraphx/config.hpp> | ||
|
||
namespace migraphx { | ||
inline namespace MIGRAPHX_INLINE_NS { | ||
|
||
struct program; | ||
|
||
/** | ||
* Decompose operators. | ||
*/ | ||
struct remap | ||
{ | ||
std::string name() const { return "remap"; } | ||
void apply(program& p) const; | ||
}; | ||
|
||
} // namespace MIGRAPHX_INLINE_NS | ||
} // namespace migraphx | ||
|
||
#endif |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,42 @@ | ||
#include <migraphx/remap.hpp> | ||
#include <migraphx/program.hpp> | ||
#include <migraphx/instruction.hpp> | ||
#include <migraphx/iterator_for.hpp> | ||
#include <migraphx/functional.hpp> | ||
#include <migraphx/ranges.hpp> | ||
#include <migraphx/float_equal.hpp> | ||
#include <migraphx/matcher.hpp> | ||
#include <migraphx/op/dot.hpp> | ||
#include <migraphx/op/add.hpp> | ||
|
||
namespace migraphx { | ||
inline namespace MIGRAPHX_INLINE_NS { | ||
namespace { | ||
struct find_dot_add | ||
{ | ||
auto matcher() const | ||
{ | ||
return match::name("add")(match::any_of( | ||
match::args(match::name("dot")(match::nargs(2)).bind("dot"), match::any().bind("a")), | ||
match::args(match::used_once().bind("a"), | ||
match::name("dot")(match::nargs(2)).bind("dot")))); | ||
} | ||
|
||
void apply(program& p, match::matcher_result r) const | ||
{ | ||
auto ins = r.result; | ||
auto dot_ins = r.instructions["dot"]; | ||
auto a_ins = r.instructions["a"]; | ||
|
||
auto dot = any_cast<op::dot>(dot_ins->get_operator()); | ||
|
||
dot.beta = 1; | ||
p.replace_instruction(ins, dot, dot_ins->inputs()[0], dot_ins->inputs()[1], a_ins); | ||
} | ||
}; | ||
} // namespace | ||
|
||
void remap::apply(program& p) const { match::find_matches(p, find_dot_add{}); } | ||
|
||
} // namespace MIGRAPHX_INLINE_NS | ||
} // namespace migraphx |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.