Skip to content

Commit

Permalink
Horizontal fusions of gemms and convolutions (#472)
Browse files Browse the repository at this point in the history
* Add decompose pass

* Add decompose test

* Formatting

* Add remap

* Formatting

* Add compute method for dot

* Formatting

* Add finder for horizontal fusion

* Formatting

* Formatting

* Reuse predicate

* Add gemm fusions

* Formatting

* Add some fixes for convolution

* Formatting

* Fix shape tests

* Formatting

* Reuse axis equal

* Add initial split fusion

* Formatting

* Update offset

* Workaround outputs that cant accept nonstandard shapes

* Formatting

* Add check for split concat

* Formatting

* Add missing headers

* Formatting

* Add tests

* Formatting

* Add more testing

* Formatting

* Fix when there is duplicate splits in inputs

* Formatting

* Fix mismatch iterators

* Add tests for dot fusions

* Formatting

* Add test for convolution

* Formatting

* Fix tidy issues

* Add more tests

* Formatting

* Ignore build directory for codecov

* Add test for groups

* Formatting

* Add more tests for groups

* Formatting

* Add test for missing end slice

* Add newline

* Remove unused function

* Add support for when beta is not 1

* Formatting

* Add test for scalar

* Add one more scalar test

Co-authored-by: mvermeulen <[email protected]>
  • Loading branch information
pfultz2 and mvermeulen authored May 8, 2020
1 parent 45bb91e commit 1a4ff50
Show file tree
Hide file tree
Showing 21 changed files with 1,370 additions and 6 deletions.
2 changes: 1 addition & 1 deletion codecov.yml
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
ignore:
- "test/"
- "src/driver"

- "build/"
2 changes: 2 additions & 0 deletions src/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ include(ROCMPackageConfigHelpers)
add_library(migraphx
auto_contiguous.cpp
eliminate_common_subexpression.cpp
decompose.cpp
propagate_constant.cpp
dead_code_elimination.cpp
eliminate_allocation.cpp
Expand All @@ -20,6 +21,7 @@ add_library(migraphx
instruction.cpp
program.cpp
quantization.cpp
remap.cpp
shape.cpp
schedule.cpp
pass_manager.cpp
Expand Down
48 changes: 48 additions & 0 deletions src/decompose.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
#include <migraphx/decompose.hpp>
#include <migraphx/program.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/iterator_for.hpp>
#include <migraphx/functional.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/float_equal.hpp>
#include <migraphx/matcher.hpp>
#include <migraphx/op/dot.hpp>
#include <migraphx/op/multibroadcast.hpp>
#include <migraphx/op/mul.hpp>
#include <migraphx/op/add.hpp>

namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace {
struct find_dot_add
{
auto matcher() const { return match::name("dot")(match::nargs(3)); }

void apply(program& p, const match::matcher_result& r) const
{
auto ins = r.result;
auto dot = any_cast<op::dot>(ins->get_operator());
if(not float_equal(dot.beta, 1) and
not contains({shape::float_type, shape::half_type, shape::double_type},
ins->get_shape().type()))
return;
auto dot_ins =
p.insert_instruction(ins, op::dot{dot.alpha, 0}, ins->inputs()[0], ins->inputs()[1]);
auto c_ins = ins->inputs()[2];
if(not float_equal(dot.beta, 1))
{
auto beta = p.add_literal(literal{shape{ins->get_shape().type()}, {dot.beta}});
auto beta_broadcast =
p.insert_instruction(ins, op::multibroadcast{ins->get_shape().lens()}, beta);
c_ins = p.insert_instruction(ins, op::mul{}, c_ins, beta_broadcast);
}
p.replace_instruction(ins, op::add{}, dot_ins, c_ins);
}
};

} // namespace

void decompose::apply(program& p) const { match::find_matches(p, find_dot_add{}); }

} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
24 changes: 24 additions & 0 deletions src/include/migraphx/algorithm.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
#ifndef MIGRAPHX_GUARD_RTGLIB_ALGORITHM_HPP
#define MIGRAPHX_GUARD_RTGLIB_ALGORITHM_HPP

#include <algorithm>
#include <migraphx/config.hpp>

namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {

template <class Iterator, class Output, class Predicate>
void group_by(Iterator start, Iterator last, Output out, Predicate pred)
{
while(start != last)
{
auto it = std::partition(start, last, [&](auto x) { return pred(x, *start); });
out(start, it);
start = it;
}
}

} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx

#endif
25 changes: 25 additions & 0 deletions src/include/migraphx/decompose.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
#ifndef MIGRAPHX_GUARD_RTGLIB_DECOMPOSE_HPP
#define MIGRAPHX_GUARD_RTGLIB_DECOMPOSE_HPP

#include <string>
#include <migraphx/instruction_ref.hpp>
#include <migraphx/config.hpp>

namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {

struct program;

/**
* Decompose operators.
*/
struct decompose
{
std::string name() const { return "decompose"; }
void apply(program& p) const;
};

} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx

#endif
39 changes: 39 additions & 0 deletions src/include/migraphx/gemm.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
#ifndef MIGRAPHX_GUARD_RTGLIB_GEMM_HPP
#define MIGRAPHX_GUARD_RTGLIB_GEMM_HPP

#include <migraphx/config.hpp>
#include <migraphx/dfor.hpp>
#include <migraphx/shape_for_each.hpp>
#include <migraphx/tensor_view.hpp>

namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {

template <class T, class F>
void gemm(tensor_view<T> cmat, tensor_view<T> amat, tensor_view<T> bmat, F alpha, F beta)
{
std::size_t n_dims = cmat.get_shape().lens().size();
std::size_t dim_0 = n_dims - 2;
std::size_t dim_1 = n_dims - 1;
auto k = amat.get_shape().lens()[dim_1];

assert(amat.get_shape().lens()[dim_1] == bmat.get_shape().lens()[dim_0]);
assert(cmat.get_shape().lens()[dim_0] == amat.get_shape().lens()[dim_0]);
assert(cmat.get_shape().lens()[dim_1] == bmat.get_shape().lens()[dim_1]);

shape_for_each(cmat.get_shape(), [&](const auto& c_idx) {
auto a_idx = c_idx;
auto b_idx = c_idx;
double s = 0.0;
dfor(k)([&](auto kk) {
a_idx[dim_1] = b_idx[dim_0] = kk;
s += amat(a_idx.begin(), a_idx.end()) * bmat(b_idx.begin(), b_idx.end());
});
cmat(c_idx.begin(), c_idx.end()) = alpha * s + cmat(c_idx.begin(), c_idx.end()) * beta;
});
}

} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx

#endif
13 changes: 13 additions & 0 deletions src/include/migraphx/op/dot.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
#include <migraphx/literal.hpp>
#include <migraphx/shape_for_each.hpp>
#include <migraphx/config.hpp>
#include <migraphx/gemm.hpp>
#include <cmath>
#include <utility>

Expand Down Expand Up @@ -67,6 +68,18 @@ struct dot

return {t, out_lens};
}

argument compute(context&, shape output_shape, std::vector<argument> args) const
{
argument result;
if(args.size() == 3)
result = args[2];
else
result = argument{output_shape};
visit_all(result, args[0], args[1])(
[&](auto cmat, auto amat, auto bmat) { gemm(cmat, amat, bmat, alpha, beta); });
return result;
}
};

} // namespace op
Expand Down
3 changes: 3 additions & 0 deletions src/include/migraphx/program.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,7 @@ struct program
instruction_ref remove_instructions(instruction_ref first, instruction_ref last);

instruction_ref move_instruction(instruction_ref src, instruction_ref dst);
instruction_ref move_instructions(instruction_ref src, instruction_ref dst);

template <class... Ts>
instruction_ref add_literal(Ts&&... xs)
Expand Down Expand Up @@ -125,6 +126,8 @@ struct program

void annotate(std::ostream& os, std::function<void(instruction_ref)> a) const;

program& sort();

friend std::ostream& operator<<(std::ostream& os, const program& p);
friend bool operator==(const program& x, const program& y);
friend bool operator!=(const program& x, const program& y) { return !(x == y); }
Expand Down
25 changes: 25 additions & 0 deletions src/include/migraphx/remap.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
#ifndef MIGRAPHX_GUARD_RTGLIB_REMAP_HPP
#define MIGRAPHX_GUARD_RTGLIB_REMAP_HPP

#include <string>
#include <migraphx/instruction_ref.hpp>
#include <migraphx/config.hpp>

namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {

struct program;

/**
* Decompose operators.
*/
struct remap
{
std::string name() const { return "remap"; }
void apply(program& p) const;
};

} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx

#endif
2 changes: 2 additions & 0 deletions src/include/migraphx/shape.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,8 @@ struct shape
/// Returns true if all strides are equal to 0 (scalar tensor)
bool scalar() const;

shape normalize_standard() const;

friend bool operator==(const shape& x, const shape& y);
friend bool operator!=(const shape& x, const shape& y);
friend std::ostream& operator<<(std::ostream& os, const shape& x);
Expand Down
20 changes: 20 additions & 0 deletions src/program.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
#include <iostream>
#include <sstream>
#include <algorithm>
#include <set>
#include <utility>

namespace migraphx {
Expand Down Expand Up @@ -260,6 +261,14 @@ instruction_ref program::move_instruction(instruction_ref src, instruction_ref d
return src;
}

instruction_ref program::move_instructions(instruction_ref src, instruction_ref dst)
{
this->move_instruction(src, dst);
for(auto ins : src->inputs())
this->move_instruction(ins, src);
return src;
}

instruction_ref program::add_literal(literal l)
{
impl->instructions.emplace_front(std::move(l));
Expand Down Expand Up @@ -796,6 +805,17 @@ void program::annotate(std::ostream& os, std::function<void(instruction_ref)> a)
});
}

program& program::sort()
{
fix([&](auto self, auto ins) {
this->move_instruction(ins, this->begin());
for(auto child : ins->inputs())
self(child);
})(std::prev(this->end()));
assert(this->validate() == this->end());
return *this;
}

bool operator==(const program& x, const program& y) { return to_string(x) == to_string(y); }

std::ostream& operator<<(std::ostream& os, const program& p)
Expand Down
42 changes: 42 additions & 0 deletions src/remap.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
#include <migraphx/remap.hpp>
#include <migraphx/program.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/iterator_for.hpp>
#include <migraphx/functional.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/float_equal.hpp>
#include <migraphx/matcher.hpp>
#include <migraphx/op/dot.hpp>
#include <migraphx/op/add.hpp>

namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace {
struct find_dot_add
{
auto matcher() const
{
return match::name("add")(match::any_of(
match::args(match::name("dot")(match::nargs(2)).bind("dot"), match::any().bind("a")),
match::args(match::used_once().bind("a"),
match::name("dot")(match::nargs(2)).bind("dot"))));
}

void apply(program& p, match::matcher_result r) const
{
auto ins = r.result;
auto dot_ins = r.instructions["dot"];
auto a_ins = r.instructions["a"];

auto dot = any_cast<op::dot>(dot_ins->get_operator());

dot.beta = 1;
p.replace_instruction(ins, dot, dot_ins->inputs()[0], dot_ins->inputs()[1], a_ins);
}
};
} // namespace

void remap::apply(program& p) const { match::find_matches(p, find_dot_add{}); }

} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
8 changes: 8 additions & 0 deletions src/shape.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -195,6 +195,14 @@ bool shape::scalar() const

bool shape::standard() const { return impl->m_standard; }

shape shape::normalize_standard() const
{
if(this->standard())
return {this->type(), this->lens()};
else
return *this;
}

std::size_t shape::element_space() const { return impl->element_space(); }

std::string shape::type_string() const
Expand Down
Loading

0 comments on commit 1a4ff50

Please sign in to comment.