From 1a4ff50411dce662f265a8fbf7a720bd28762bb3 Mon Sep 17 00:00:00 2001 From: Paul Fultz II Date: Fri, 8 May 2020 16:42:10 -0500 Subject: [PATCH] Horizontal fusions of gemms and convolutions (#472) * Add decompose pass * Add decompose test * Formatting * Add remap * Formatting * Add compute method for dot * Formatting * Add finder for horizontal fusion * Formatting * Formatting * Reuse predicate * Add gemm fusions * Formatting * Add some fixes for convolution * Formatting * Fix shape tests * Formatting * Reuse axis equal * Add initial split fusion * Formatting * Update offset * Workaround outputs that cant accept nonstandard shapes * Formatting * Add check for split concat * Formatting * Add missing headers * Formatting * Add tests * Formatting * Add more testing * Formatting * Fix when there is duplicate splits in inputs * Formatting * Fix mismatch iterators * Add tests for dot fusions * Formatting * Add test for convolution * Formatting * Fix tidy issues * Add more tests * Formatting * Ignore build directory for codecov * Add test for groups * Formatting * Add more tests for groups * Formatting * Add test for missing end slice * Add newline * Remove unused function * Add support for when beta is not 1 * Formatting * Add test for scalar * Add one more scalar test Co-authored-by: mvermeulen <5479696+mvermeulen@users.noreply.github.com> --- codecov.yml | 2 +- src/CMakeLists.txt | 2 + src/decompose.cpp | 48 ++ src/include/migraphx/algorithm.hpp | 24 + src/include/migraphx/decompose.hpp | 25 + src/include/migraphx/gemm.hpp | 39 ++ src/include/migraphx/op/dot.hpp | 13 + src/include/migraphx/program.hpp | 3 + src/include/migraphx/remap.hpp | 25 + src/include/migraphx/shape.hpp | 2 + src/program.cpp | 20 + src/remap.cpp | 42 ++ src/shape.cpp | 8 + src/simplify_algebra.cpp | 293 +++++++- src/targets/gpu/fuse_ops.cpp | 1 + src/targets/gpu/include/migraphx/gpu/gemm.hpp | 4 +- .../gpu/include/migraphx/gpu/miopen.hpp | 3 +- src/targets/gpu/target.cpp | 5 + test/decompose_test.cpp | 133 ++++ test/shape_test.cpp | 56 ++ test/simplify_algebra_test.cpp | 628 ++++++++++++++++++ 21 files changed, 1370 insertions(+), 6 deletions(-) create mode 100644 src/decompose.cpp create mode 100644 src/include/migraphx/algorithm.hpp create mode 100644 src/include/migraphx/decompose.hpp create mode 100644 src/include/migraphx/gemm.hpp create mode 100644 src/include/migraphx/remap.hpp create mode 100644 src/remap.cpp create mode 100644 test/decompose_test.cpp diff --git a/codecov.yml b/codecov.yml index 3e109b7ca9a..03abe2daeb2 100644 --- a/codecov.yml +++ b/codecov.yml @@ -1,4 +1,4 @@ ignore: - "test/" - "src/driver" - + - "build/" diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index 3302788d75f..ac7639014ce 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -5,6 +5,7 @@ include(ROCMPackageConfigHelpers) add_library(migraphx auto_contiguous.cpp eliminate_common_subexpression.cpp + decompose.cpp propagate_constant.cpp dead_code_elimination.cpp eliminate_allocation.cpp @@ -20,6 +21,7 @@ add_library(migraphx instruction.cpp program.cpp quantization.cpp + remap.cpp shape.cpp schedule.cpp pass_manager.cpp diff --git a/src/decompose.cpp b/src/decompose.cpp new file mode 100644 index 00000000000..ef995714329 --- /dev/null +++ b/src/decompose.cpp @@ -0,0 +1,48 @@ +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace migraphx { +inline namespace MIGRAPHX_INLINE_NS { +namespace { +struct find_dot_add +{ + auto matcher() const { return match::name("dot")(match::nargs(3)); } + + void apply(program& p, const match::matcher_result& r) const + { + auto ins = r.result; + auto dot = any_cast(ins->get_operator()); + if(not float_equal(dot.beta, 1) and + not contains({shape::float_type, shape::half_type, shape::double_type}, + ins->get_shape().type())) + return; + auto dot_ins = + p.insert_instruction(ins, op::dot{dot.alpha, 0}, ins->inputs()[0], ins->inputs()[1]); + auto c_ins = ins->inputs()[2]; + if(not float_equal(dot.beta, 1)) + { + auto beta = p.add_literal(literal{shape{ins->get_shape().type()}, {dot.beta}}); + auto beta_broadcast = + p.insert_instruction(ins, op::multibroadcast{ins->get_shape().lens()}, beta); + c_ins = p.insert_instruction(ins, op::mul{}, c_ins, beta_broadcast); + } + p.replace_instruction(ins, op::add{}, dot_ins, c_ins); + } +}; + +} // namespace + +void decompose::apply(program& p) const { match::find_matches(p, find_dot_add{}); } + +} // namespace MIGRAPHX_INLINE_NS +} // namespace migraphx diff --git a/src/include/migraphx/algorithm.hpp b/src/include/migraphx/algorithm.hpp new file mode 100644 index 00000000000..3ece4e413db --- /dev/null +++ b/src/include/migraphx/algorithm.hpp @@ -0,0 +1,24 @@ +#ifndef MIGRAPHX_GUARD_RTGLIB_ALGORITHM_HPP +#define MIGRAPHX_GUARD_RTGLIB_ALGORITHM_HPP + +#include +#include + +namespace migraphx { +inline namespace MIGRAPHX_INLINE_NS { + +template +void group_by(Iterator start, Iterator last, Output out, Predicate pred) +{ + while(start != last) + { + auto it = std::partition(start, last, [&](auto x) { return pred(x, *start); }); + out(start, it); + start = it; + } +} + +} // namespace MIGRAPHX_INLINE_NS +} // namespace migraphx + +#endif diff --git a/src/include/migraphx/decompose.hpp b/src/include/migraphx/decompose.hpp new file mode 100644 index 00000000000..60650a0cfe3 --- /dev/null +++ b/src/include/migraphx/decompose.hpp @@ -0,0 +1,25 @@ +#ifndef MIGRAPHX_GUARD_RTGLIB_DECOMPOSE_HPP +#define MIGRAPHX_GUARD_RTGLIB_DECOMPOSE_HPP + +#include +#include +#include + +namespace migraphx { +inline namespace MIGRAPHX_INLINE_NS { + +struct program; + +/** + * Decompose operators. + */ +struct decompose +{ + std::string name() const { return "decompose"; } + void apply(program& p) const; +}; + +} // namespace MIGRAPHX_INLINE_NS +} // namespace migraphx + +#endif diff --git a/src/include/migraphx/gemm.hpp b/src/include/migraphx/gemm.hpp new file mode 100644 index 00000000000..94d1eac2f21 --- /dev/null +++ b/src/include/migraphx/gemm.hpp @@ -0,0 +1,39 @@ +#ifndef MIGRAPHX_GUARD_RTGLIB_GEMM_HPP +#define MIGRAPHX_GUARD_RTGLIB_GEMM_HPP + +#include +#include +#include +#include + +namespace migraphx { +inline namespace MIGRAPHX_INLINE_NS { + +template +void gemm(tensor_view cmat, tensor_view amat, tensor_view bmat, F alpha, F beta) +{ + std::size_t n_dims = cmat.get_shape().lens().size(); + std::size_t dim_0 = n_dims - 2; + std::size_t dim_1 = n_dims - 1; + auto k = amat.get_shape().lens()[dim_1]; + + assert(amat.get_shape().lens()[dim_1] == bmat.get_shape().lens()[dim_0]); + assert(cmat.get_shape().lens()[dim_0] == amat.get_shape().lens()[dim_0]); + assert(cmat.get_shape().lens()[dim_1] == bmat.get_shape().lens()[dim_1]); + + shape_for_each(cmat.get_shape(), [&](const auto& c_idx) { + auto a_idx = c_idx; + auto b_idx = c_idx; + double s = 0.0; + dfor(k)([&](auto kk) { + a_idx[dim_1] = b_idx[dim_0] = kk; + s += amat(a_idx.begin(), a_idx.end()) * bmat(b_idx.begin(), b_idx.end()); + }); + cmat(c_idx.begin(), c_idx.end()) = alpha * s + cmat(c_idx.begin(), c_idx.end()) * beta; + }); +} + +} // namespace MIGRAPHX_INLINE_NS +} // namespace migraphx + +#endif diff --git a/src/include/migraphx/op/dot.hpp b/src/include/migraphx/op/dot.hpp index c5a70412ae7..d9daa26a64c 100644 --- a/src/include/migraphx/op/dot.hpp +++ b/src/include/migraphx/op/dot.hpp @@ -9,6 +9,7 @@ #include #include #include +#include #include #include @@ -67,6 +68,18 @@ struct dot return {t, out_lens}; } + + argument compute(context&, shape output_shape, std::vector args) const + { + argument result; + if(args.size() == 3) + result = args[2]; + else + result = argument{output_shape}; + visit_all(result, args[0], args[1])( + [&](auto cmat, auto amat, auto bmat) { gemm(cmat, amat, bmat, alpha, beta); }); + return result; + } }; } // namespace op diff --git a/src/include/migraphx/program.hpp b/src/include/migraphx/program.hpp index 8ef09bc4767..e8cee3b40f3 100644 --- a/src/include/migraphx/program.hpp +++ b/src/include/migraphx/program.hpp @@ -74,6 +74,7 @@ struct program instruction_ref remove_instructions(instruction_ref first, instruction_ref last); instruction_ref move_instruction(instruction_ref src, instruction_ref dst); + instruction_ref move_instructions(instruction_ref src, instruction_ref dst); template instruction_ref add_literal(Ts&&... xs) @@ -125,6 +126,8 @@ struct program void annotate(std::ostream& os, std::function a) const; + program& sort(); + friend std::ostream& operator<<(std::ostream& os, const program& p); friend bool operator==(const program& x, const program& y); friend bool operator!=(const program& x, const program& y) { return !(x == y); } diff --git a/src/include/migraphx/remap.hpp b/src/include/migraphx/remap.hpp new file mode 100644 index 00000000000..2fa79236966 --- /dev/null +++ b/src/include/migraphx/remap.hpp @@ -0,0 +1,25 @@ +#ifndef MIGRAPHX_GUARD_RTGLIB_REMAP_HPP +#define MIGRAPHX_GUARD_RTGLIB_REMAP_HPP + +#include +#include +#include + +namespace migraphx { +inline namespace MIGRAPHX_INLINE_NS { + +struct program; + +/** + * Decompose operators. + */ +struct remap +{ + std::string name() const { return "remap"; } + void apply(program& p) const; +}; + +} // namespace MIGRAPHX_INLINE_NS +} // namespace migraphx + +#endif diff --git a/src/include/migraphx/shape.hpp b/src/include/migraphx/shape.hpp index b535cfea677..a54f0a37d94 100644 --- a/src/include/migraphx/shape.hpp +++ b/src/include/migraphx/shape.hpp @@ -114,6 +114,8 @@ struct shape /// Returns true if all strides are equal to 0 (scalar tensor) bool scalar() const; + shape normalize_standard() const; + friend bool operator==(const shape& x, const shape& y); friend bool operator!=(const shape& x, const shape& y); friend std::ostream& operator<<(std::ostream& os, const shape& x); diff --git a/src/program.cpp b/src/program.cpp index 2790580d7f9..95a8246311e 100644 --- a/src/program.cpp +++ b/src/program.cpp @@ -11,6 +11,7 @@ #include #include #include +#include #include namespace migraphx { @@ -260,6 +261,14 @@ instruction_ref program::move_instruction(instruction_ref src, instruction_ref d return src; } +instruction_ref program::move_instructions(instruction_ref src, instruction_ref dst) +{ + this->move_instruction(src, dst); + for(auto ins : src->inputs()) + this->move_instruction(ins, src); + return src; +} + instruction_ref program::add_literal(literal l) { impl->instructions.emplace_front(std::move(l)); @@ -796,6 +805,17 @@ void program::annotate(std::ostream& os, std::function a) }); } +program& program::sort() +{ + fix([&](auto self, auto ins) { + this->move_instruction(ins, this->begin()); + for(auto child : ins->inputs()) + self(child); + })(std::prev(this->end())); + assert(this->validate() == this->end()); + return *this; +} + bool operator==(const program& x, const program& y) { return to_string(x) == to_string(y); } std::ostream& operator<<(std::ostream& os, const program& p) diff --git a/src/remap.cpp b/src/remap.cpp new file mode 100644 index 00000000000..ea9908aa08d --- /dev/null +++ b/src/remap.cpp @@ -0,0 +1,42 @@ +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace migraphx { +inline namespace MIGRAPHX_INLINE_NS { +namespace { +struct find_dot_add +{ + auto matcher() const + { + return match::name("add")(match::any_of( + match::args(match::name("dot")(match::nargs(2)).bind("dot"), match::any().bind("a")), + match::args(match::used_once().bind("a"), + match::name("dot")(match::nargs(2)).bind("dot")))); + } + + void apply(program& p, match::matcher_result r) const + { + auto ins = r.result; + auto dot_ins = r.instructions["dot"]; + auto a_ins = r.instructions["a"]; + + auto dot = any_cast(dot_ins->get_operator()); + + dot.beta = 1; + p.replace_instruction(ins, dot, dot_ins->inputs()[0], dot_ins->inputs()[1], a_ins); + } +}; +} // namespace + +void remap::apply(program& p) const { match::find_matches(p, find_dot_add{}); } + +} // namespace MIGRAPHX_INLINE_NS +} // namespace migraphx diff --git a/src/shape.cpp b/src/shape.cpp index cfcd1f325e3..81f06ca63e6 100644 --- a/src/shape.cpp +++ b/src/shape.cpp @@ -195,6 +195,14 @@ bool shape::scalar() const bool shape::standard() const { return impl->m_standard; } +shape shape::normalize_standard() const +{ + if(this->standard()) + return {this->type(), this->lens()}; + else + return *this; +} + std::size_t shape::element_space() const { return impl->element_space(); } std::string shape::type_string() const diff --git a/src/simplify_algebra.cpp b/src/simplify_algebra.cpp index aa80a1242da..2eb034fd47a 100644 --- a/src/simplify_algebra.cpp +++ b/src/simplify_algebra.cpp @@ -4,7 +4,9 @@ #include #include #include +#include #include +#include #include #include #include @@ -12,6 +14,7 @@ #include #include #include +#include namespace migraphx { inline namespace MIGRAPHX_INLINE_NS { @@ -246,6 +249,212 @@ struct find_concat_binary } }; +std::vector get_splits(instruction_ref ins) +{ + std::vector result; + std::copy_if(ins->outputs().begin(), + ins->outputs().end(), + std::back_inserter(result), + [&](auto i) { return i->name() == "slice"; }); + if(result.size() < 2) + return {}; + auto get_slice = [](auto& i) -> auto& { return any_cast(i->get_operator()); }; + auto&& axes = get_slice(result.front()).axes; + if(std::any_of(result.begin(), result.end(), [&](auto i) { return get_slice(i).axes != axes; })) + return {}; + auto get_start = [&](auto& i) -> auto& { return get_slice(i).starts; }; + auto get_end = [&](auto& i) -> auto& { return get_slice(i).ends; }; + std::sort( + result.begin(), result.end(), [&](auto x, auto y) { return get_start(x) < get_start(y); }); + if(std::any_of(get_start(result.front()).begin(), get_start(result.front()).end(), [&](auto i) { + return i != 0; + })) + return {}; + auto it = std::adjacent_find( + result.begin(), result.end(), [&](auto x, auto y) { return get_end(x) != get_start(y); }); + if(it != result.end()) + return {}; + for(std::size_t i = 0; i < axes.size(); i++) + { + auto axis = axes[i]; + if(ins->get_shape().lens()[axis] != get_slice(result.back()).ends[i]) + return {}; + } + return result; +} + +struct find_splits +{ + auto matcher() const + { + return match::any(match::any_of[match::outputs()](match::name("slice")( + match::any_of[match::outputs()](match::name("add", "mul", "relu"))))); + } + + static std::vector> + get_split_groups(const std::vector& splits) + { + std::vector> groups; + for(auto out : splits.front()->outputs()) + { + if(out->name() == "slice") + continue; + std::vector group; + for(auto split : splits) + { + auto it = + std::find_if(split->outputs().begin(), split->outputs().end(), [&](auto i) { + return i->get_operator() == out->get_operator(); + }); + if(it == split->outputs().end()) + break; + assert((*it)->name() != "slice"); + // If there is a duplicate bail + if(contains(group, *it)) + return {}; + group.push_back(*it); + } + if(group.size() != splits.size()) + continue; + groups.push_back(group); + } + return groups; + } + + void apply(program& p, const match::matcher_result& r) const + { + auto ins = r.result; + + auto splits = get_splits(ins); + if(splits.empty()) + return; + for(const auto& group : get_split_groups(splits)) + { + auto start = group.front(); + auto op = start->get_operator(); + if(op.name() == "slice") + continue; + + // Make sure there is no duplicates + assert(std::none_of( + std::next(group.begin()), group.end(), [&](auto i) { return i == start; })); + + auto split_idx = 0; + instruction_ref c = p.end(); + if(start->inputs().size() == 1) + { + c = p.insert_instruction(std::next(ins), op, ins); + } + else if(start->inputs().size() == 2) + { + assert(not std::none_of(start->inputs().begin(), start->inputs().end(), [](auto i) { + return i->name() == "slice"; + }) && "one argument must be a split"); + auto data_idx = 1; + if(start->inputs().back()->name() == "slice") + { + split_idx = 1; + data_idx = 0; + } + + std::vector data_args; + std::transform(group.begin(), + group.end(), + std::back_inserter(data_args), + [&](auto i) { return i->inputs()[data_idx]; }); + + // Data arguments must be a constant + if(std::any_of(data_args.begin(), data_args.end(), [](auto i) { + return not i->can_eval(); + })) + return; + + for(auto data : data_args) + p.move_instructions(data, ins); + + auto slice_op = any_cast(splits.front()->get_operator()); + assert(not slice_op.axes.empty()); + if(slice_op.axes.size() > 1) + return; + auto concat_axis = slice_op.axes.front(); + // TODO: Check if axises match + auto concat = p.insert_instruction(ins, op::concat{concat_axis}, data_args); + + std::vector args; + args.resize(2); + args[split_idx] = ins; + args[data_idx] = concat; + c = p.insert_instruction(std::next(ins), op, args); + } + if(c != p.end()) + { + for(auto i : group) + { + auto split = i->inputs()[split_idx]; + assert(split->name() == "slice"); + // Insert contiguous for reshapes + for(auto output : i->outputs()) + { + if(not contains({"reshape", "squeeze", "unsqueeze"}, output->name())) + continue; + auto x = p.insert_instruction(output, op::contiguous{}, output->inputs()); + p.replace_instruction(output, output->get_operator(), x); + } + + p.replace_instruction(i, split->get_operator(), c); + } + } + } + } +}; + +struct find_split_concat +{ + auto matcher() const + { + return match::any(match::any_of[match::outputs()]( + match::name("slice")(match::all_of[match::outputs()](match::name("concat"))))); + } + + void apply(program& p, const match::matcher_result& r) const + { + auto ins = r.result; + + auto splits = get_splits(ins); + if(splits.empty()) + return; + if(std::any_of( + splits.begin(), splits.end(), [](auto i) { return i->outputs().size() != 1; })) + return; + // Check for concat operator + auto concat = splits.front()->outputs().front(); + if(std::any_of(splits.begin(), splits.end(), [&](auto i) { + return i->outputs().front() != concat; + })) + return; + // Check axis match + auto concat_op = any_cast(concat->get_operator()); + auto split_op = any_cast(splits.front()->get_operator()); + if(split_op.axes.size() != 1) + return; + if(split_op.axes.front() != concat_op.axis) + return; + // Replace args + auto args = concat->inputs(); + auto it = + std::find_if(args.begin(), args.end(), [&](auto i) { return i == splits.front(); }); + if(std::distance(it, args.end()) < splits.size()) + return; + *it = splits.front()->inputs().front(); + args.erase(std::next(it), it + splits.size()); + + if(args.size() == 1) + p.replace_instruction(concat, args.front()); + else + p.replace_instruction(concat, concat->get_operator(), args); + } +}; + bool axis_equal(const std::vector& x, const std::vector& y, std::size_t axis) @@ -352,6 +561,83 @@ struct find_add_convs } }; +MIGRAPHX_PRED_MATCHER(horiz_conv_dot, instruction_ref ins) +{ + auto pred = [&](auto name) { + return [=](auto i) { + return i->name() == name and i->inputs().front() == ins and + i->inputs().at(1)->can_eval(); + }; + }; + auto dots = std::count_if(ins->outputs().begin(), ins->outputs().end(), pred("dot")); + auto convs = std::count_if(ins->outputs().begin(), ins->outputs().end(), pred("convolution")); + return !(dots < 2 and convs < 2); +} + +struct find_conv_dot_horiz_fusion +{ + auto matcher() const { return horiz_conv_dot(); } + + void apply(program& p, const match::matcher_result& r) const + { + auto ins = r.result; + + auto pred = [](auto i, auto j) { + if(i->get_operator() != j->get_operator()) + return false; + if(not contains({"dot", "convolution"}, i->name())) + return true; + auto x = i->inputs()[1]->get_shape().lens(); + auto y = j->inputs()[1]->get_shape().lens(); + if(x.size() != y.size()) + return false; + // Check that non-axises match + int axis = 1; + if(i->name() == "dot") + { + axis = x.size() - 1; + } + return axis_equal(x, y, axis); + }; + + auto each = [&](auto start, auto last) { + if(std::distance(start, last) < 2) + return; + auto&& name = (*start)->name(); + if(not contains({"dot", "convolution"}, name)) + return; + auto input = (*start)->inputs().front(); + std::vector args; + std::transform( + start, last, std::back_inserter(args), [&](auto x) { return x->inputs().at(1); }); + int axis = 1; + int concat_axis = 0; + if(name == "dot") + { + axis = int(args.front()->get_shape().lens().size() - 1); + concat_axis = axis; + } + + for(auto arg : args) + p.move_instructions(arg, input); + // TODO: Check if axises match + auto concat = p.insert_instruction(input, op::concat{concat_axis}, args); + auto fused = + p.insert_instruction(std::next(input), (*start)->get_operator(), input, concat); + int64_t offset = 0; + for(auto arg : range(start, last)) + { + int64_t len = arg->get_shape().lens()[axis]; + p.replace_instruction(arg, op::slice{{axis}, {offset}, {offset + len}}, fused); + offset += len; + } + }; + + auto outputs = ins->outputs(); + group_by(outputs.begin(), outputs.end(), each, pred); + } +}; + struct find_div_const { auto matcher() const @@ -412,20 +698,23 @@ struct find_rsqrt void simplify_algebra::apply(program& p) const { // Run simplifications multiple times - for(int i = 0; i < 4; i++) + for(int i = 0; i < 8; i++) { match::find_matches(p, find_inner_broadcast{}, find_double_add_lit_broadcast{}, find_add_lit_broadcast{}, find_add_convs{}, + find_conv_dot_horiz_fusion{}, find_mul_conv{}, find_mul_add{}, find_div_const{}, find_sub_const{}, find_rsqrt{}, find_concat_unary{}, - find_concat_binary{}); + find_concat_binary{}, + find_split_concat{}, + find_splits{}); dead_code_elimination{}.apply(p); } } diff --git a/src/targets/gpu/fuse_ops.cpp b/src/targets/gpu/fuse_ops.cpp index d6ec0d438bf..6b571806c85 100644 --- a/src/targets/gpu/fuse_ops.cpp +++ b/src/targets/gpu/fuse_ops.cpp @@ -40,6 +40,7 @@ struct fusion fusion(const shape& input) // : fp(make_fusion_plan(input)) { + assert(input.standard()); auto t = make_tensor(input); fp = make_fusion_plan(t); keep_alive(std::move(t)); diff --git a/src/targets/gpu/include/migraphx/gpu/gemm.hpp b/src/targets/gpu/include/migraphx/gpu/gemm.hpp index 2d0a6d69fbc..286721b1fb3 100644 --- a/src/targets/gpu/include/migraphx/gpu/gemm.hpp +++ b/src/targets/gpu/include/migraphx/gpu/gemm.hpp @@ -1,5 +1,5 @@ -#ifndef MIGRAPHX_GUARD_RTGLIB_GEMM_HPP -#define MIGRAPHX_GUARD_RTGLIB_GEMM_HPP +#ifndef MIGRAPHX_GUARD_RTGLIB_GPU_GEMM_HPP +#define MIGRAPHX_GUARD_RTGLIB_GPU_GEMM_HPP #include #include diff --git a/src/targets/gpu/include/migraphx/gpu/miopen.hpp b/src/targets/gpu/include/migraphx/gpu/miopen.hpp index 6e714fd0fc0..a6140f52380 100644 --- a/src/targets/gpu/include/migraphx/gpu/miopen.hpp +++ b/src/targets/gpu/include/migraphx/gpu/miopen.hpp @@ -38,8 +38,9 @@ Result make_obj(F f, Ts... xs) return r; } -inline tensor_descriptor make_tensor(const migraphx::shape& s, bool pack = false) +inline tensor_descriptor make_tensor(const migraphx::shape& os, bool pack = false) { + auto s = os.normalize_standard(); auto t = make_obj(&miopenCreateTensorDescriptor); // Convert to ints std::vector lens(s.lens().begin(), s.lens().end()); diff --git a/src/targets/gpu/target.cpp b/src/targets/gpu/target.cpp index 3a47316571a..eae05552dde 100644 --- a/src/targets/gpu/target.cpp +++ b/src/targets/gpu/target.cpp @@ -25,6 +25,8 @@ #include #include #include +#include +#include #include namespace migraphx { @@ -39,6 +41,7 @@ std::vector target::get_passes(migraphx::context& gctx, const compile_opti // clang-format off return { + decompose{}, dead_code_elimination{}, simplify_reshapes{}, dead_code_elimination{}, @@ -59,6 +62,8 @@ std::vector target::get_passes(migraphx::context& gctx, const compile_opti dead_code_elimination{}, propagate_constant{}, dead_code_elimination{}, + remap{}, + dead_code_elimination{}, lowering{&ctx, options.offload_copy}, eliminate_contiguous{}, dead_code_elimination{}, diff --git a/test/decompose_test.cpp b/test/decompose_test.cpp new file mode 100644 index 00000000000..a839205d557 --- /dev/null +++ b/test/decompose_test.cpp @@ -0,0 +1,133 @@ +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +void run_pass(migraphx::program& p) { migraphx::run_passes(p, {migraphx::decompose{}}); } + +TEST_CASE(dot_add) +{ + migraphx::program p1; + { + auto x = p1.add_parameter("x", migraphx::shape{migraphx::shape::float_type, {2, 2}}); + auto y = p1.add_parameter("y", migraphx::shape{migraphx::shape::float_type, {2, 2}}); + auto z = p1.add_parameter("z", migraphx::shape{migraphx::shape::float_type, {2, 2}}); + auto dot = p1.add_instruction(migraphx::op::dot{}, x, y, z); + p1.add_instruction(migraphx::op::identity{}, dot); + } + run_pass(p1); + migraphx::program p2; + { + auto x = p2.add_parameter("x", migraphx::shape{migraphx::shape::float_type, {2, 2}}); + auto y = p2.add_parameter("y", migraphx::shape{migraphx::shape::float_type, {2, 2}}); + auto z = p2.add_parameter("z", migraphx::shape{migraphx::shape::float_type, {2, 2}}); + auto dot = p2.add_instruction(migraphx::op::dot{1, 0}, x, y); + auto add = p2.add_instruction(migraphx::op::add{}, dot, z); + p2.add_instruction(migraphx::op::identity{}, add); + } + EXPECT(p1 == p2); +} + +TEST_CASE(dot_add_beta_float) +{ + migraphx::program p1; + { + auto x = p1.add_parameter("x", migraphx::shape{migraphx::shape::float_type, {2, 2}}); + auto y = p1.add_parameter("y", migraphx::shape{migraphx::shape::float_type, {2, 2}}); + auto z = p1.add_parameter("z", migraphx::shape{migraphx::shape::float_type, {2, 2}}); + auto dot = p1.add_instruction(migraphx::op::dot{1.0, 0.5}, x, y, z); + p1.add_instruction(migraphx::op::identity{}, dot); + } + run_pass(p1); + migraphx::program p2; + { + auto x = p2.add_parameter("x", migraphx::shape{migraphx::shape::float_type, {2, 2}}); + auto y = p2.add_parameter("y", migraphx::shape{migraphx::shape::float_type, {2, 2}}); + auto z = p2.add_parameter("z", migraphx::shape{migraphx::shape::float_type, {2, 2}}); + auto dot = p2.add_instruction(migraphx::op::dot{1, 0}, x, y); + auto beta = + p2.add_literal(migraphx::literal{migraphx::shape{migraphx::shape::float_type}, {0.5}}); + auto beta_broadcast = p2.add_instruction(migraphx::op::multibroadcast{{2, 2}}, beta); + auto mul = p2.add_instruction(migraphx::op::mul{}, z, beta_broadcast); + auto add = p2.add_instruction(migraphx::op::add{}, dot, mul); + p2.add_instruction(migraphx::op::identity{}, add); + } + EXPECT(p1 == p2); +} + +TEST_CASE(dot_add_beta_half) +{ + migraphx::program p1; + { + auto x = p1.add_parameter("x", migraphx::shape{migraphx::shape::half_type, {2, 2}}); + auto y = p1.add_parameter("y", migraphx::shape{migraphx::shape::half_type, {2, 2}}); + auto z = p1.add_parameter("z", migraphx::shape{migraphx::shape::half_type, {2, 2}}); + auto dot = p1.add_instruction(migraphx::op::dot{1.0, 0.5}, x, y, z); + p1.add_instruction(migraphx::op::identity{}, dot); + } + run_pass(p1); + migraphx::program p2; + { + auto x = p2.add_parameter("x", migraphx::shape{migraphx::shape::half_type, {2, 2}}); + auto y = p2.add_parameter("y", migraphx::shape{migraphx::shape::half_type, {2, 2}}); + auto z = p2.add_parameter("z", migraphx::shape{migraphx::shape::half_type, {2, 2}}); + auto dot = p2.add_instruction(migraphx::op::dot{1, 0}, x, y); + auto beta = + p2.add_literal(migraphx::literal{migraphx::shape{migraphx::shape::half_type}, {0.5}}); + auto beta_broadcast = p2.add_instruction(migraphx::op::multibroadcast{{2, 2}}, beta); + auto mul = p2.add_instruction(migraphx::op::mul{}, z, beta_broadcast); + auto add = p2.add_instruction(migraphx::op::add{}, dot, mul); + p2.add_instruction(migraphx::op::identity{}, add); + } + EXPECT(p1 == p2); +} + +TEST_CASE(dot_add_beta_double) +{ + migraphx::program p1; + { + auto x = p1.add_parameter("x", migraphx::shape{migraphx::shape::double_type, {2, 2}}); + auto y = p1.add_parameter("y", migraphx::shape{migraphx::shape::double_type, {2, 2}}); + auto z = p1.add_parameter("z", migraphx::shape{migraphx::shape::double_type, {2, 2}}); + auto dot = p1.add_instruction(migraphx::op::dot{1.0, 0.5}, x, y, z); + p1.add_instruction(migraphx::op::identity{}, dot); + } + run_pass(p1); + migraphx::program p2; + { + auto x = p2.add_parameter("x", migraphx::shape{migraphx::shape::double_type, {2, 2}}); + auto y = p2.add_parameter("y", migraphx::shape{migraphx::shape::double_type, {2, 2}}); + auto z = p2.add_parameter("z", migraphx::shape{migraphx::shape::double_type, {2, 2}}); + auto dot = p2.add_instruction(migraphx::op::dot{1, 0}, x, y); + auto beta = + p2.add_literal(migraphx::literal{migraphx::shape{migraphx::shape::double_type}, {0.5}}); + auto beta_broadcast = p2.add_instruction(migraphx::op::multibroadcast{{2, 2}}, beta); + auto mul = p2.add_instruction(migraphx::op::mul{}, z, beta_broadcast); + auto add = p2.add_instruction(migraphx::op::add{}, dot, mul); + p2.add_instruction(migraphx::op::identity{}, add); + } + EXPECT(p1 == p2); +} + +TEST_CASE(dot_add_beta_int) +{ + migraphx::program p1; + { + auto x = p1.add_parameter("x", migraphx::shape{migraphx::shape::int32_type, {2, 2}}); + auto y = p1.add_parameter("y", migraphx::shape{migraphx::shape::int32_type, {2, 2}}); + auto z = p1.add_parameter("z", migraphx::shape{migraphx::shape::int32_type, {2, 2}}); + auto dot = p1.add_instruction(migraphx::op::dot{1.0, 0.5}, x, y, z); + p1.add_instruction(migraphx::op::identity{}, dot); + } + migraphx::program p2 = p1; + run_pass(p1); + EXPECT(p1 == p2); +} + +int main(int argc, const char* argv[]) { test::run(argc, argv); } diff --git a/test/shape_test.cpp b/test/shape_test.cpp index cd1b347fc89..a60673424ff 100644 --- a/test/shape_test.cpp +++ b/test/shape_test.cpp @@ -47,6 +47,15 @@ TEST_CASE(test_shape_packed) EXPECT(not s.broadcasted()); } +TEST_CASE(test_shape_non_packed_single_dim) +{ + migraphx::shape s{migraphx::shape::float_type, {1, 64, 35, 35}, {156800, 1225, 35, 1}}; + EXPECT(s.standard()); + EXPECT(s.packed()); + EXPECT(not s.transposed()); + EXPECT(not s.broadcasted()); +} + TEST_CASE(test_shape_transposed1) { migraphx::shape s{migraphx::shape::float_type, {2, 2}, {1, 2}}; @@ -172,6 +181,53 @@ TEST_CASE(test_shape_default_copy) EXPECT(!(s1 != s2)); } +TEST_CASE(test_shape_normalize_standard1) +{ + migraphx::shape s{migraphx::shape::float_type, {2, 2, 3}, {6, 3, 1}}; + EXPECT(s.standard()); + auto n = s.normalize_standard(); + EXPECT(n == s); +} + +TEST_CASE(test_shape_normalize_standard2) +{ + migraphx::shape s{migraphx::shape::float_type, {1, 64, 35, 35}, {156800, 1225, 35, 1}}; + EXPECT(s.standard()); + auto n = s.normalize_standard(); + EXPECT(n.standard()); + EXPECT(n != s); + EXPECT(n.lens() == s.lens()); + EXPECT(n.type() == s.type()); +} + +TEST_CASE(test_shape_normalize_standard3) +{ + migraphx::shape s{migraphx::shape::float_type, {2, 2}, {1, 2}}; + EXPECT(not s.standard()); + auto n = s.normalize_standard(); + EXPECT(n == s); +} + +TEST_CASE(test_shape_normalize_scalar1) +{ + migraphx::shape s{migraphx::shape::float_type}; + EXPECT(s.standard()); + EXPECT(s.scalar()); + auto n = s.normalize_standard(); + EXPECT(n != s); + EXPECT(n.standard()); + EXPECT(not n.scalar()); +} + +TEST_CASE(test_shape_normalize_scalar2) +{ + migraphx::shape s{migraphx::shape::float_type, {2, 2}, {0, 0}}; + EXPECT(not s.standard()); + EXPECT(s.scalar()); + auto n = s.normalize_standard(); + EXPECT(n == s); +} + TEST_CASE(test_shape4) { migraphx::shape s{migraphx::shape::float_type, {100, 32, 8, 8}}; diff --git a/test/simplify_algebra_test.cpp b/test/simplify_algebra_test.cpp index 6c018424103..c0596be427c 100644 --- a/test/simplify_algebra_test.cpp +++ b/test/simplify_algebra_test.cpp @@ -563,6 +563,7 @@ TEST_CASE(simplify_rsqrt) migraphx::program p2; { + auto x = p2.add_parameter("x", {migraphx::shape::int32_type, {1}}); p2.add_instruction(migraphx::op::rsqrt{}, x); } @@ -585,4 +586,631 @@ TEST_CASE(simplify_rsqrt_multi_use) EXPECT(p1 == p2); } +TEST_CASE(simplify_split_add_relu) +{ + auto s = migraphx::shape{migraphx::shape::int32_type, {3, 2, 4}}; + migraphx::program p1; + { + auto b = migraphx::op::broadcast{1, {3, 1, 4}}; + auto input = p1.add_parameter("input", s); + auto x = p1.add_instruction(migraphx::op::slice{{1}, {0}, {1}}, input); + auto y = p1.add_instruction(migraphx::op::slice{{1}, {1}, {2}}, input); + auto one = p1.add_literal(1); + auto oneb = p1.add_instruction(b, one); + auto two = p1.add_literal(2); + auto twob = p1.add_instruction(b, two); + auto sum1 = p1.add_instruction(migraphx::op::add{}, x, oneb); + auto relu1 = p1.add_instruction(migraphx::op::relu{}, sum1); + auto sum2 = p1.add_instruction(migraphx::op::add{}, y, twob); + auto relu2 = p1.add_instruction(migraphx::op::relu{}, sum2); + auto add = p1.add_instruction(migraphx::op::add{}, relu1, relu2); + p1.add_instruction(pass_op{}, add); + } + run_pass(p1); + + migraphx::program p2; + { + auto b = migraphx::op::broadcast{1, {3, 2, 4}}; + auto input = p2.add_parameter("input", s); + auto one = p2.add_literal(1); + auto two = p2.add_literal(2); + auto concat = p2.add_instruction(migraphx::op::concat{0}, one, two); + auto concatb = p2.add_instruction(b, concat); + auto sum = p2.add_instruction(migraphx::op::add{}, input, concatb); + auto relu = p2.add_instruction(migraphx::op::relu{}, sum); + auto x = p2.add_instruction(migraphx::op::slice{{1}, {0}, {1}}, relu); + auto y = p2.add_instruction(migraphx::op::slice{{1}, {1}, {2}}, relu); + auto add = p2.add_instruction(migraphx::op::add{}, x, y); + p2.add_instruction(pass_op{}, add); + } + EXPECT(p1.sort() == p2.sort()); +} + +TEST_CASE(simplify_split_add_relu_reshape) +{ + auto s = migraphx::shape{migraphx::shape::int32_type, {3, 2, 4}}; + migraphx::program p1; + { + auto b = migraphx::op::broadcast{1, {3, 1, 4}}; + auto r = migraphx::op::reshape{{3, 4}}; + auto input = p1.add_parameter("input", s); + auto x = p1.add_instruction(migraphx::op::slice{{1}, {0}, {1}}, input); + auto y = p1.add_instruction(migraphx::op::slice{{1}, {1}, {2}}, input); + auto one = p1.add_literal(1); + auto oneb = p1.add_instruction(b, one); + auto two = p1.add_literal(2); + auto twob = p1.add_instruction(b, two); + auto sum1 = p1.add_instruction(migraphx::op::add{}, x, oneb); + auto relu1 = p1.add_instruction(migraphx::op::relu{}, sum1); + auto reshape1 = p1.add_instruction(r, relu1); + auto sum2 = p1.add_instruction(migraphx::op::add{}, y, twob); + auto relu2 = p1.add_instruction(migraphx::op::relu{}, sum2); + auto reshape2 = p1.add_instruction(r, relu2); + auto add = p1.add_instruction(migraphx::op::add{}, reshape1, reshape2); + p1.add_instruction(pass_op{}, add); + } + run_pass(p1); + + migraphx::program p2; + { + auto b = migraphx::op::broadcast{1, {3, 2, 4}}; + auto r = migraphx::op::reshape{{3, 4}}; + auto input = p2.add_parameter("input", s); + auto one = p2.add_literal(1); + auto two = p2.add_literal(2); + auto concat = p2.add_instruction(migraphx::op::concat{0}, one, two); + auto concatb = p2.add_instruction(b, concat); + auto sum = p2.add_instruction(migraphx::op::add{}, input, concatb); + auto relu = p2.add_instruction(migraphx::op::relu{}, sum); + auto slice1 = p2.add_instruction(migraphx::op::slice{{1}, {0}, {1}}, relu); + auto cont1 = p2.add_instruction(migraphx::op::contiguous{}, slice1); + auto reshape1 = p2.add_instruction(r, cont1); + auto slice2 = p2.add_instruction(migraphx::op::slice{{1}, {1}, {2}}, relu); + auto cont2 = p2.add_instruction(migraphx::op::contiguous{}, slice2); + auto reshape2 = p2.add_instruction(r, cont2); + auto add = p2.add_instruction(migraphx::op::add{}, reshape1, reshape2); + p2.add_instruction(pass_op{}, add); + } + EXPECT(p1.sort() == p2.sort()); +} + +TEST_CASE(simplify_slice_different_axis) +{ + auto s = migraphx::shape{migraphx::shape::int32_type, {3, 2, 4, 2}}; + migraphx::program p1; + { + auto r = migraphx::op::reshape{{3, 2, 4}}; + auto input = p1.add_parameter("input", s); + auto x = p1.add_instruction(migraphx::op::slice{{1}, {0}, {1}}, input); + auto y = p1.add_instruction(migraphx::op::slice{{3}, {0}, {1}}, input); + auto one = p1.add_literal(1); + auto oneb = p1.add_instruction(migraphx::op::broadcast{1, {3, 1, 4, 2}}, one); + auto two = p1.add_literal(2); + auto twob = p1.add_instruction(migraphx::op::broadcast{3, {3, 2, 4, 1}}, two); + auto sum1 = p1.add_instruction(migraphx::op::add{}, x, oneb); + auto relu1 = p1.add_instruction(migraphx::op::relu{}, sum1); + auto reshape1 = p1.add_instruction(r, relu1); + auto sum2 = p1.add_instruction(migraphx::op::add{}, y, twob); + auto relu2 = p1.add_instruction(migraphx::op::relu{}, sum2); + auto reshape2 = p1.add_instruction(r, relu2); + auto add = p1.add_instruction(migraphx::op::add{}, reshape1, reshape2); + p1.add_instruction(pass_op{}, add); + } + migraphx::program p2 = p1; + run_pass(p1); + + EXPECT(p1.sort() == p2.sort()); +} + +TEST_CASE(simplify_slice_missing_begining_slice) +{ + auto s = migraphx::shape{migraphx::shape::int32_type, {3, 3, 4}}; + migraphx::program p1; + { + auto b = migraphx::op::broadcast{1, {3, 1, 4}}; + auto input = p1.add_parameter("input", s); + auto x = p1.add_instruction(migraphx::op::slice{{1}, {2}, {3}}, input); + auto y = p1.add_instruction(migraphx::op::slice{{1}, {1}, {2}}, input); + auto one = p1.add_literal(1); + auto oneb = p1.add_instruction(b, one); + auto two = p1.add_literal(2); + auto twob = p1.add_instruction(b, two); + auto sum1 = p1.add_instruction(migraphx::op::add{}, x, oneb); + auto relu1 = p1.add_instruction(migraphx::op::relu{}, sum1); + auto sum2 = p1.add_instruction(migraphx::op::add{}, y, twob); + auto relu2 = p1.add_instruction(migraphx::op::relu{}, sum2); + auto add = p1.add_instruction(migraphx::op::add{}, relu1, relu2); + p1.add_instruction(pass_op{}, add); + } + migraphx::program p2 = p1; + run_pass(p1); + + EXPECT(p1.sort() == p2.sort()); +} + +TEST_CASE(simplify_slice_missing_middle_slice) +{ + auto s = migraphx::shape{migraphx::shape::int32_type, {3, 3, 4}}; + migraphx::program p1; + { + auto b = migraphx::op::broadcast{1, {3, 1, 4}}; + auto input = p1.add_parameter("input", s); + auto x = p1.add_instruction(migraphx::op::slice{{1}, {2}, {3}}, input); + auto y = p1.add_instruction(migraphx::op::slice{{1}, {0}, {1}}, input); + auto one = p1.add_literal(1); + auto oneb = p1.add_instruction(b, one); + auto two = p1.add_literal(2); + auto twob = p1.add_instruction(b, two); + auto sum1 = p1.add_instruction(migraphx::op::add{}, x, oneb); + auto relu1 = p1.add_instruction(migraphx::op::relu{}, sum1); + auto sum2 = p1.add_instruction(migraphx::op::add{}, y, twob); + auto relu2 = p1.add_instruction(migraphx::op::relu{}, sum2); + auto add = p1.add_instruction(migraphx::op::add{}, relu1, relu2); + p1.add_instruction(pass_op{}, add); + } + migraphx::program p2 = p1; + run_pass(p1); + + EXPECT(p1.sort() == p2.sort()); +} + +TEST_CASE(simplify_slice_missing_end_slice) +{ + auto s = migraphx::shape{migraphx::shape::int32_type, {3, 3, 4}}; + migraphx::program p1; + { + auto b = migraphx::op::broadcast{1, {3, 1, 4}}; + auto input = p1.add_parameter("input", s); + auto x = p1.add_instruction(migraphx::op::slice{{1}, {0}, {1}}, input); + auto y = p1.add_instruction(migraphx::op::slice{{1}, {1}, {2}}, input); + auto one = p1.add_literal(1); + auto oneb = p1.add_instruction(b, one); + auto two = p1.add_literal(2); + auto twob = p1.add_instruction(b, two); + auto sum1 = p1.add_instruction(migraphx::op::add{}, x, oneb); + auto relu1 = p1.add_instruction(migraphx::op::relu{}, sum1); + auto sum2 = p1.add_instruction(migraphx::op::add{}, y, twob); + auto relu2 = p1.add_instruction(migraphx::op::relu{}, sum2); + auto add = p1.add_instruction(migraphx::op::add{}, relu1, relu2); + p1.add_instruction(pass_op{}, add); + } + migraphx::program p2 = p1; + run_pass(p1); + + EXPECT(p1.sort() == p2.sort()); +} + +TEST_CASE(simplify_split_add_relu_concat_same_axis) +{ + auto s = migraphx::shape{migraphx::shape::int32_type, {3, 2, 4}}; + migraphx::program p1; + { + auto b = migraphx::op::broadcast{1, {3, 1, 4}}; + auto input = p1.add_parameter("input", s); + auto x = p1.add_instruction(migraphx::op::slice{{1}, {0}, {1}}, input); + auto y = p1.add_instruction(migraphx::op::slice{{1}, {1}, {2}}, input); + auto one = p1.add_literal(1); + auto oneb = p1.add_instruction(b, one); + auto two = p1.add_literal(2); + auto twob = p1.add_instruction(b, two); + auto sum1 = p1.add_instruction(migraphx::op::add{}, x, oneb); + auto relu1 = p1.add_instruction(migraphx::op::relu{}, sum1); + auto sum2 = p1.add_instruction(migraphx::op::add{}, y, twob); + auto relu2 = p1.add_instruction(migraphx::op::relu{}, sum2); + auto concat = p1.add_instruction(migraphx::op::concat{1}, relu1, relu2); + p1.add_instruction(pass_op{}, concat); + } + run_pass(p1); + + migraphx::program p2; + { + auto b = migraphx::op::broadcast{1, {3, 2, 4}}; + auto input = p2.add_parameter("input", s); + auto one = p2.add_literal(1); + auto two = p2.add_literal(2); + auto concat = p2.add_instruction(migraphx::op::concat{0}, one, two); + auto concatb = p2.add_instruction(b, concat); + auto sum = p2.add_instruction(migraphx::op::add{}, input, concatb); + auto relu = p2.add_instruction(migraphx::op::relu{}, sum); + p2.add_instruction(pass_op{}, relu); + } + EXPECT(p1.sort() == p2.sort()); +} + +TEST_CASE(simplify_split_add_relu_multi_axes) +{ + auto s = migraphx::shape{migraphx::shape::int32_type, {3, 2, 4, 6}}; + migraphx::program p1; + { + auto b = migraphx::op::broadcast{1, {3, 1, 4, 3}}; + auto input = p1.add_parameter("input", s); + auto x = p1.add_instruction(migraphx::op::slice{{1, 3}, {0, 0}, {1, 3}}, input); + auto y = p1.add_instruction(migraphx::op::slice{{1, 3}, {1, 3}, {2, 6}}, input); + auto one = p1.add_literal(1); + auto oneb = p1.add_instruction(b, one); + auto two = p1.add_literal(2); + auto twob = p1.add_instruction(b, two); + auto sum1 = p1.add_instruction(migraphx::op::add{}, x, oneb); + auto relu1 = p1.add_instruction(migraphx::op::relu{}, sum1); + auto sum2 = p1.add_instruction(migraphx::op::add{}, y, twob); + auto relu2 = p1.add_instruction(migraphx::op::relu{}, sum2); + auto add = p1.add_instruction(migraphx::op::add{}, relu1, relu2); + p1.add_instruction(pass_op{}, add); + } + migraphx::program p2 = p1; + run_pass(p1); + EXPECT(p1.sort() == p2.sort()); +} + +TEST_CASE(simplify_split_add_relu_used_multiple_split1) +{ + auto s = migraphx::shape{migraphx::shape::int32_type, {3, 2, 4}}; + migraphx::program p1; + { + auto b = migraphx::op::broadcast{1, {3, 1, 4}}; + auto input = p1.add_parameter("input", s); + auto x = p1.add_instruction(migraphx::op::slice{{1}, {0}, {1}}, input); + auto y = p1.add_instruction(migraphx::op::slice{{1}, {1}, {2}}, input); + auto one = p1.add_literal(1); + auto oneb = p1.add_instruction(b, one); + auto two = p1.add_literal(2); + auto twob = p1.add_instruction(b, two); + auto sum1 = p1.add_instruction(migraphx::op::add{}, x, oneb); + auto relu1 = p1.add_instruction(migraphx::op::relu{}, sum1); + auto sum2 = p1.add_instruction(migraphx::op::add{}, y, twob); + auto relu2 = p1.add_instruction(migraphx::op::relu{}, sum2); + auto add1 = p1.add_instruction(migraphx::op::add{}, relu1, relu2); + auto add2 = p1.add_instruction(migraphx::op::add{}, x, add1); + p1.add_instruction(pass_op{}, add2); + } + run_pass(p1); + + migraphx::program p2; + { + auto b = migraphx::op::broadcast{1, {3, 2, 4}}; + auto input = p2.add_parameter("input", s); + auto slice = p2.add_instruction(migraphx::op::slice{{1}, {0}, {1}}, input); + auto one = p2.add_literal(1); + auto two = p2.add_literal(2); + auto concat = p2.add_instruction(migraphx::op::concat{0}, one, two); + auto concatb = p2.add_instruction(b, concat); + auto sum = p2.add_instruction(migraphx::op::add{}, input, concatb); + auto relu = p2.add_instruction(migraphx::op::relu{}, sum); + auto x = p2.add_instruction(migraphx::op::slice{{1}, {0}, {1}}, relu); + auto y = p2.add_instruction(migraphx::op::slice{{1}, {1}, {2}}, relu); + auto add1 = p2.add_instruction(migraphx::op::add{}, x, y); + auto add2 = p2.add_instruction(migraphx::op::add{}, slice, add1); + p2.add_instruction(pass_op{}, add2); + } + EXPECT(p1.sort() == p2.sort()); +} + +TEST_CASE(simplify_split_add_relu_used_multiple_split2) +{ + auto s = migraphx::shape{migraphx::shape::int32_type, {3, 2, 4}}; + migraphx::program p1; + { + auto b = migraphx::op::broadcast{1, {3, 1, 4}}; + auto input = p1.add_parameter("input", s); + auto x = p1.add_instruction(migraphx::op::slice{{1}, {0}, {1}}, input); + auto y = p1.add_instruction(migraphx::op::slice{{1}, {1}, {2}}, input); + auto z = p1.add_instruction(migraphx::op::relu{}, x); + auto one = p1.add_literal(1); + auto oneb = p1.add_instruction(b, one); + auto two = p1.add_literal(2); + auto twob = p1.add_instruction(b, two); + auto sum1 = p1.add_instruction(migraphx::op::add{}, x, oneb); + auto relu1 = p1.add_instruction(migraphx::op::relu{}, sum1); + auto sum2 = p1.add_instruction(migraphx::op::add{}, y, twob); + auto relu2 = p1.add_instruction(migraphx::op::relu{}, sum2); + auto add1 = p1.add_instruction(migraphx::op::add{}, relu1, relu2); + auto add2 = p1.add_instruction(migraphx::op::add{}, z, add1); + p1.add_instruction(pass_op{}, add2); + } + run_pass(p1); + + migraphx::program p2; + { + auto b = migraphx::op::broadcast{1, {3, 2, 4}}; + auto input = p2.add_parameter("input", s); + auto slice = p2.add_instruction(migraphx::op::slice{{1}, {0}, {1}}, input); + auto z = p2.add_instruction(migraphx::op::relu{}, slice); + auto one = p2.add_literal(1); + auto two = p2.add_literal(2); + auto concat = p2.add_instruction(migraphx::op::concat{0}, one, two); + auto concatb = p2.add_instruction(b, concat); + auto sum = p2.add_instruction(migraphx::op::add{}, input, concatb); + auto relu = p2.add_instruction(migraphx::op::relu{}, sum); + auto x = p2.add_instruction(migraphx::op::slice{{1}, {0}, {1}}, relu); + auto y = p2.add_instruction(migraphx::op::slice{{1}, {1}, {2}}, relu); + auto add1 = p2.add_instruction(migraphx::op::add{}, x, y); + auto add2 = p2.add_instruction(migraphx::op::add{}, z, add1); + p2.add_instruction(pass_op{}, add2); + } + EXPECT(p1.sort() == p2.sort()); +} + +TEST_CASE(simplify_split_between_add) +{ + auto s = migraphx::shape{migraphx::shape::int32_type, {3, 2, 4}}; + migraphx::program p1; + { + auto input = p1.add_parameter("input", s); + auto x = p1.add_instruction(migraphx::op::slice{{1}, {0}, {1}}, input); + auto y = p1.add_instruction(migraphx::op::slice{{1}, {1}, {2}}, input); + auto sum = p1.add_instruction(migraphx::op::add{}, x, y); + p1.add_instruction(pass_op{}, sum); + } + migraphx::program p2 = p1; + run_pass(p1); + EXPECT(p1.sort() == p2.sort()); +} + +TEST_CASE(simplify_dot_horiz) +{ + auto s = migraphx::shape{migraphx::shape::int32_type, {3, 2, 2}}; + migraphx::program p1; + { + auto input = p1.add_parameter("input", s); + auto a = p1.add_literal(migraphx::generate_literal(s, 0)); + auto b = p1.add_literal(migraphx::generate_literal(s, 1)); + auto x = p1.add_instruction(migraphx::op::dot{}, input, a); + auto y = p1.add_instruction(migraphx::op::dot{}, input, b); + auto sum = p1.add_instruction(migraphx::op::add{}, x, y); + p1.add_instruction(pass_op{}, sum); + } + run_pass(p1); + + migraphx::program p2; + { + auto input = p2.add_parameter("input", s); + auto a = p2.add_literal(migraphx::generate_literal(s, 0)); + auto b = p2.add_literal(migraphx::generate_literal(s, 1)); + auto concat = p2.add_instruction(migraphx::op::concat{2}, a, b); + auto dot = p2.add_instruction(migraphx::op::dot{}, input, concat); + auto x = p2.add_instruction(migraphx::op::slice{{2}, {0}, {2}}, dot); + auto y = p2.add_instruction(migraphx::op::slice{{2}, {2}, {4}}, dot); + auto sum = p2.add_instruction(migraphx::op::add{}, x, y); + p2.add_instruction(pass_op{}, sum); + } + EXPECT(p1.sort() == p2.sort()); +} + +TEST_CASE(simplify_dot_horiz_same_constant) +{ + auto s = migraphx::shape{migraphx::shape::int32_type, {3, 2, 2}}; + migraphx::program p1; + { + auto input = p1.add_parameter("input", s); + auto a = p1.add_literal(migraphx::generate_literal(s, 0)); + auto x = p1.add_instruction(migraphx::op::dot{}, input, a); + auto y = p1.add_instruction(migraphx::op::dot{}, input, a); + auto sum = p1.add_instruction(migraphx::op::add{}, x, y); + p1.add_instruction(pass_op{}, sum); + } + run_pass(p1); + + migraphx::program p2; + { + auto input = p2.add_parameter("input", s); + auto a = p2.add_literal(migraphx::generate_literal(s, 0)); + auto concat = p2.add_instruction(migraphx::op::concat{2}, a, a); + auto dot = p2.add_instruction(migraphx::op::dot{}, input, concat); + auto x = p2.add_instruction(migraphx::op::slice{{2}, {0}, {2}}, dot); + auto y = p2.add_instruction(migraphx::op::slice{{2}, {2}, {4}}, dot); + auto sum = p2.add_instruction(migraphx::op::add{}, x, y); + p2.add_instruction(pass_op{}, sum); + } + EXPECT(p1.sort() == p2.sort()); +} + +TEST_CASE(simplify_dot_horiz_flipped) +{ + auto s = migraphx::shape{migraphx::shape::int32_type, {3, 2, 2}}; + migraphx::program p1; + { + auto input = p1.add_parameter("input", s); + auto a = p1.add_literal(migraphx::generate_literal(s, 0)); + auto b = p1.add_literal(migraphx::generate_literal(s, 1)); + auto x = p1.add_instruction(migraphx::op::dot{}, input, a); + auto y = p1.add_instruction(migraphx::op::dot{}, b, input); + auto sum = p1.add_instruction(migraphx::op::add{}, x, y); + p1.add_instruction(pass_op{}, sum); + } + + migraphx::program p2 = p1; + run_pass(p1); + EXPECT(p1.sort() == p2.sort()); +} + +TEST_CASE(simplify_conv_horiz) +{ + auto s = migraphx::shape{migraphx::shape::int32_type, {8, 3, 64, 64}}; + auto ws = migraphx::shape{migraphx::shape::int32_type, {12, 3, 3, 3}}; + migraphx::program p1; + { + auto input = p1.add_parameter("input", s); + auto a = p1.add_literal(migraphx::generate_literal(ws, 0)); + auto b = p1.add_literal(migraphx::generate_literal(ws, 1)); + auto x = p1.add_instruction(migraphx::op::convolution{}, input, a); + auto y = p1.add_instruction(migraphx::op::convolution{}, input, b); + auto sum = p1.add_instruction(migraphx::op::add{}, x, y); + p1.add_instruction(pass_op{}, sum); + } + run_pass(p1); + + migraphx::program p2; + { + auto input = p2.add_parameter("input", s); + auto a = p2.add_literal(migraphx::generate_literal(ws, 0)); + auto b = p2.add_literal(migraphx::generate_literal(ws, 1)); + auto concat = p2.add_instruction(migraphx::op::concat{0}, a, b); + auto conv = p2.add_instruction(migraphx::op::convolution{}, input, concat); + auto x = p2.add_instruction(migraphx::op::slice{{1}, {0}, {12}}, conv); + auto y = p2.add_instruction(migraphx::op::slice{{1}, {12}, {24}}, conv); + auto sum = p2.add_instruction(migraphx::op::add{}, x, y); + p2.add_instruction(pass_op{}, sum); + } + EXPECT(p1.sort() == p2.sort()); +} + +TEST_CASE(simplify_conv_horiz_groups) +{ + auto s = migraphx::shape{migraphx::shape::int32_type, {8, 6, 64, 64}}; + auto ws1 = migraphx::shape{migraphx::shape::int32_type, {6, 6, 3, 3}}; + auto ws2 = migraphx::shape{migraphx::shape::int32_type, {8, 6, 64, 64}}; + migraphx::program p1; + { + auto input = p1.add_parameter("input", s); + auto a = p1.add_literal(migraphx::generate_literal(ws1, 0)); + auto b = p1.add_literal(migraphx::generate_literal(ws1, 1)); + auto c = p1.add_literal(migraphx::generate_literal(ws2, 2)); + auto d = p1.add_literal(migraphx::generate_literal(ws2, 3)); + auto convx = p1.add_instruction(migraphx::op::convolution{{1, 1}}, input, a); + auto convy = p1.add_instruction(migraphx::op::convolution{{1, 1}}, input, b); + auto dotx = p1.add_instruction(migraphx::op::dot{}, input, c); + auto doty = p1.add_instruction(migraphx::op::dot{}, input, d); + auto sum1 = p1.add_instruction(migraphx::op::add{}, convx, convy); + auto sum2 = p1.add_instruction(migraphx::op::add{}, dotx, doty); + auto sum3 = p1.add_instruction(migraphx::op::add{}, sum1, sum2); + + p1.add_instruction(pass_op{}, sum3); + } + run_pass(p1); + + migraphx::program p2; + { + auto input = p2.add_parameter("input", s); + auto a = p2.add_literal(migraphx::generate_literal(ws1, 0)); + auto b = p2.add_literal(migraphx::generate_literal(ws1, 1)); + auto c = p2.add_literal(migraphx::generate_literal(ws2, 2)); + auto d = p2.add_literal(migraphx::generate_literal(ws2, 3)); + auto concat1 = p2.add_instruction(migraphx::op::concat{0}, a, b); + auto concat2 = p2.add_instruction(migraphx::op::concat{3}, c, d); + auto conv = p2.add_instruction(migraphx::op::convolution{{1, 1}}, input, concat1); + auto convx = p2.add_instruction(migraphx::op::slice{{1}, {0}, {6}}, conv); + auto convy = p2.add_instruction(migraphx::op::slice{{1}, {6}, {12}}, conv); + auto sum1 = p2.add_instruction(migraphx::op::add{}, convx, convy); + auto dot = p2.add_instruction(migraphx::op::dot{}, input, concat2); + auto dotx = p2.add_instruction(migraphx::op::slice{{3}, {0}, {64}}, dot); + auto doty = p2.add_instruction(migraphx::op::slice{{3}, {64}, {128}}, dot); + auto sum2 = p2.add_instruction(migraphx::op::add{}, dotx, doty); + auto sum3 = p2.add_instruction(migraphx::op::add{}, sum1, sum2); + p2.add_instruction(pass_op{}, sum3); + } + EXPECT(p1.sort() == p2.sort()); +} + +TEST_CASE(simplify_conv_horiz_groups_extra1) +{ + auto s = migraphx::shape{migraphx::shape::int32_type, {8, 6, 64, 64}}; + auto ws1 = migraphx::shape{migraphx::shape::int32_type, {6, 6, 3, 3}}; + auto ws2 = migraphx::shape{migraphx::shape::int32_type, {8, 6, 64, 64}}; + migraphx::program p1; + { + auto input = p1.add_parameter("input", s); + auto a = p1.add_literal(migraphx::generate_literal(ws1, 0)); + auto b = p1.add_literal(migraphx::generate_literal(ws1, 1)); + auto c = p1.add_literal(migraphx::generate_literal(ws2, 2)); + auto d = p1.add_literal(migraphx::generate_literal(ws2, 3)); + auto e = p1.add_literal(migraphx::generate_literal(s, 4)); + auto convx = p1.add_instruction(migraphx::op::convolution{{1, 1}}, input, a); + auto convy = p1.add_instruction(migraphx::op::convolution{{1, 1}}, input, b); + auto dotx = p1.add_instruction(migraphx::op::dot{}, input, c); + auto doty = p1.add_instruction(migraphx::op::dot{}, input, d); + auto sqdiffx = p1.add_instruction(migraphx::op::sqdiff{}, input, e); + auto sum1 = p1.add_instruction(migraphx::op::add{}, convx, convy); + auto sum2 = p1.add_instruction(migraphx::op::add{}, dotx, doty); + auto sum3 = sqdiffx; + auto sum4 = p1.add_instruction(migraphx::op::add{}, sum1, sum2); + auto sum5 = p1.add_instruction(migraphx::op::add{}, sum4, sum3); + p1.add_instruction(pass_op{}, sum5); + } + run_pass(p1); + + migraphx::program p2; + { + auto input = p2.add_parameter("input", s); + auto a = p2.add_literal(migraphx::generate_literal(ws1, 0)); + auto b = p2.add_literal(migraphx::generate_literal(ws1, 1)); + auto c = p2.add_literal(migraphx::generate_literal(ws2, 2)); + auto d = p2.add_literal(migraphx::generate_literal(ws2, 3)); + auto e = p2.add_literal(migraphx::generate_literal(s, 4)); + auto concat1 = p2.add_instruction(migraphx::op::concat{0}, a, b); + auto concat2 = p2.add_instruction(migraphx::op::concat{3}, c, d); + auto conv = p2.add_instruction(migraphx::op::convolution{{1, 1}}, input, concat1); + auto convx = p2.add_instruction(migraphx::op::slice{{1}, {0}, {6}}, conv); + auto convy = p2.add_instruction(migraphx::op::slice{{1}, {6}, {12}}, conv); + auto sum1 = p2.add_instruction(migraphx::op::add{}, convx, convy); + auto dot = p2.add_instruction(migraphx::op::dot{}, input, concat2); + auto dotx = p2.add_instruction(migraphx::op::slice{{3}, {0}, {64}}, dot); + auto doty = p2.add_instruction(migraphx::op::slice{{3}, {64}, {128}}, dot); + auto sum2 = p2.add_instruction(migraphx::op::add{}, dotx, doty); + auto sqdiffx = p2.add_instruction(migraphx::op::sqdiff{}, input, e); + auto sum3 = sqdiffx; + auto sum4 = p2.add_instruction(migraphx::op::add{}, sum1, sum2); + auto sum5 = p2.add_instruction(migraphx::op::add{}, sum4, sum3); + p2.add_instruction(pass_op{}, sum5); + } + EXPECT(p1.sort() == p2.sort()); +} + +TEST_CASE(simplify_conv_horiz_groups_extra2) +{ + auto s = migraphx::shape{migraphx::shape::int32_type, {8, 6, 64, 64}}; + auto ws1 = migraphx::shape{migraphx::shape::int32_type, {6, 6, 3, 3}}; + auto ws2 = migraphx::shape{migraphx::shape::int32_type, {8, 6, 64, 64}}; + migraphx::program p1; + { + auto input = p1.add_parameter("input", s); + auto a = p1.add_literal(migraphx::generate_literal(ws1, 0)); + auto b = p1.add_literal(migraphx::generate_literal(ws1, 1)); + auto c = p1.add_literal(migraphx::generate_literal(ws2, 2)); + auto d = p1.add_literal(migraphx::generate_literal(ws2, 3)); + auto e = p1.add_literal(migraphx::generate_literal(s, 4)); + auto f = p1.add_literal(migraphx::generate_literal(s, 5)); + auto convx = p1.add_instruction(migraphx::op::convolution{{1, 1}}, input, a); + auto convy = p1.add_instruction(migraphx::op::convolution{{1, 1}}, input, b); + auto dotx = p1.add_instruction(migraphx::op::dot{}, input, c); + auto doty = p1.add_instruction(migraphx::op::dot{}, input, d); + auto sqdiffx = p1.add_instruction(migraphx::op::sqdiff{}, input, e); + auto sqdiffy = p1.add_instruction(migraphx::op::sqdiff{}, input, f); + auto sum1 = p1.add_instruction(migraphx::op::add{}, convx, convy); + auto sum2 = p1.add_instruction(migraphx::op::add{}, dotx, doty); + auto sum3 = p1.add_instruction(migraphx::op::add{}, sqdiffx, sqdiffy); + auto sum4 = p1.add_instruction(migraphx::op::add{}, sum1, sum2); + auto sum5 = p1.add_instruction(migraphx::op::add{}, sum4, sum3); + p1.add_instruction(pass_op{}, sum5); + } + run_pass(p1); + + migraphx::program p2; + { + auto input = p2.add_parameter("input", s); + auto a = p2.add_literal(migraphx::generate_literal(ws1, 0)); + auto b = p2.add_literal(migraphx::generate_literal(ws1, 1)); + auto c = p2.add_literal(migraphx::generate_literal(ws2, 2)); + auto d = p2.add_literal(migraphx::generate_literal(ws2, 3)); + auto e = p2.add_literal(migraphx::generate_literal(s, 4)); + auto f = p2.add_literal(migraphx::generate_literal(s, 5)); + auto concat1 = p2.add_instruction(migraphx::op::concat{0}, a, b); + auto concat2 = p2.add_instruction(migraphx::op::concat{3}, c, d); + auto conv = p2.add_instruction(migraphx::op::convolution{{1, 1}}, input, concat1); + auto convx = p2.add_instruction(migraphx::op::slice{{1}, {0}, {6}}, conv); + auto convy = p2.add_instruction(migraphx::op::slice{{1}, {6}, {12}}, conv); + auto sum1 = p2.add_instruction(migraphx::op::add{}, convx, convy); + auto dot = p2.add_instruction(migraphx::op::dot{}, input, concat2); + auto dotx = p2.add_instruction(migraphx::op::slice{{3}, {0}, {64}}, dot); + auto doty = p2.add_instruction(migraphx::op::slice{{3}, {64}, {128}}, dot); + auto sum2 = p2.add_instruction(migraphx::op::add{}, dotx, doty); + auto sqdiffx = p2.add_instruction(migraphx::op::sqdiff{}, input, e); + auto sqdiffy = p2.add_instruction(migraphx::op::sqdiff{}, input, f); + auto sum3 = p2.add_instruction(migraphx::op::add{}, sqdiffx, sqdiffy); + auto sum4 = p2.add_instruction(migraphx::op::add{}, sum1, sum2); + auto sum5 = p2.add_instruction(migraphx::op::add{}, sum4, sum3); + p2.add_instruction(pass_op{}, sum5); + } + EXPECT(p1.sort() == p2.sort()); +} + int main(int argc, const char* argv[]) { test::run(argc, argv); }