From 51a24d87036aa8b7b54b1a890486611f3747ecf0 Mon Sep 17 00:00:00 2001 From: Paul Fultz II Date: Thu, 29 Aug 2024 16:08:08 -0500 Subject: [PATCH] Move auto_contiguous before lowering (#3343) --- src/include/migraphx/shape.hpp | 1 + src/shape.cpp | 24 ++++++++++++ src/targets/gpu/code_object_op.cpp | 2 +- src/targets/gpu/fuse_mlir.cpp | 3 +- src/targets/gpu/target.cpp | 4 +- test/shape_test.cpp | 63 ++++++++++++++++++++++++++++++ 6 files changed, 94 insertions(+), 3 deletions(-) diff --git a/src/include/migraphx/shape.hpp b/src/include/migraphx/shape.hpp index 0c1e7b269d4..2db61a60aa6 100644 --- a/src/include/migraphx/shape.hpp +++ b/src/include/migraphx/shape.hpp @@ -147,6 +147,7 @@ struct MIGRAPHX_EXPORT shape static std::string cpp_type(type_t t); static bool is_integral(type_t t); + static bool is_compatible(const shape& actual, const shape& expected); shape(); shape(type_t t); diff --git a/src/shape.cpp b/src/shape.cpp index f9a42361465..99c99d3bc16 100644 --- a/src/shape.cpp +++ b/src/shape.cpp @@ -260,6 +260,30 @@ bool shape::is_integral(shape::type_t t) return result; } +bool shape::is_compatible(const shape& actual, const shape& expected) +{ + // Check subshapes + if(expected.type() == shape::tuple_type) + return migraphx::equal(actual.sub_shapes(), expected.sub_shapes(), &is_compatible); + if(actual == expected) + return true; + if(actual.type() != expected.type()) + return false; + // Only the expected can be dynamic + if(expected.dynamic()) + return actual.ndim() == expected.ndim(); + if(actual.dynamic()) + return false; + if(actual.lens() != expected.lens()) + return false; + // Check strides from dimensions that are not 1 + return all_of(range(actual.lens().size()), [&](auto i) { + if(actual.lens()[i] == 1) + return true; + return actual.strides()[i] == expected.strides()[i]; + }); +} + shape::shape() : impl(shape_impl::default_shape()) {} shape::shape(type_t t) : impl(std::make_shared(t)) {} diff --git a/src/targets/gpu/code_object_op.cpp b/src/targets/gpu/code_object_op.cpp index e78316ab139..98a580dc435 100644 --- a/src/targets/gpu/code_object_op.cpp +++ b/src/targets/gpu/code_object_op.cpp @@ -40,7 +40,7 @@ shape code_object_op::compute_shape(std::vector inputs) const std::transform(einputs.begin(), einputs.end(), einputs.begin(), [](const shape& s) { return s.normalize_standard(); }); - if(flatten(einputs) != flatten(inputs)) + if(not migraphx::equal(flatten(einputs), flatten(inputs), &shape::is_compatible)) MIGRAPHX_THROW("Input shapes have changed: [" + to_string_range(einputs) + "] -> [" + to_string_range(inputs) + "]"); return output; diff --git a/src/targets/gpu/fuse_mlir.cpp b/src/targets/gpu/fuse_mlir.cpp index 1acdd36eab1..becf1dadea6 100644 --- a/src/targets/gpu/fuse_mlir.cpp +++ b/src/targets/gpu/fuse_mlir.cpp @@ -823,7 +823,8 @@ struct find_pointwise_mlir m->add_return({ret}); auto inputs = find_inputs(map_ins, &mpm.get_module(), m); - mpm.get_module().replace_instruction(ins, ins->get_operator(), inputs, {m}); + mpm.get_module().replace_instruction( + ins, ins->get_operator(), mlir_contiguous(mpm, inputs), {m}); } }; diff --git a/src/targets/gpu/target.cpp b/src/targets/gpu/target.cpp index a1627e38ed9..9db5ad562b1 100644 --- a/src/targets/gpu/target.cpp +++ b/src/targets/gpu/target.cpp @@ -154,7 +154,7 @@ std::vector target::get_passes(migraphx::context& gctx, const compile_opti dead_code_elimination{}, prefuse_ops{}, dead_code_elimination{}, - auto_contiguous{}, + enable_pass(not enabled(MIGRAPHX_ENABLE_NHWC{}), auto_contiguous{}), eliminate_data_type{{migraphx::shape::fp8e4m3fnuz_type}, shape::float_type, unsupported_fp8_ops}, dead_code_elimination{}, rewrite_reduce{}, @@ -172,6 +172,8 @@ std::vector target::get_passes(migraphx::context& gctx, const compile_opti dead_code_elimination{}, fuse_concat{}, dead_code_elimination{}, + enable_pass(enabled(MIGRAPHX_ENABLE_NHWC{}), auto_contiguous{}), + dead_code_elimination{}, lowering{&ctx, options.offload_copy}, eliminate_contiguous{"gpu::contiguous"}, dead_code_elimination{}, diff --git a/test/shape_test.cpp b/test/shape_test.cpp index 22ac7f54c0d..ab6d6e3f718 100644 --- a/test/shape_test.cpp +++ b/test/shape_test.cpp @@ -1099,5 +1099,68 @@ TEST_CASE(multi_within_bounds) std::vector multi_4 = {1, 2, 1}; EXPECT(not in_shape.multi_within_bounds(multi_4)); } +TEST_CASE(shape_is_compatible_diff_strides) +{ + migraphx::shape actual{migraphx::shape::float_type, {1, 1, 8}, {8, 8, 1}}; + migraphx::shape expected{migraphx::shape::float_type, {1, 1, 8}, {1, 1, 1}}; + EXPECT(actual != expected); + EXPECT(migraphx::shape::is_compatible(actual, expected)); +} +TEST_CASE(shape_is_compatible_diff_lens) +{ + migraphx::shape actual{migraphx::shape::float_type, {1, 2, 8}, {8, 8, 1}}; + migraphx::shape expected{migraphx::shape::float_type, {1, 1, 8}, {8, 8, 1}}; + EXPECT(actual != expected); + EXPECT(not migraphx::shape::is_compatible(actual, expected)); +} +TEST_CASE(shape_is_compatible_diff_type) +{ + migraphx::shape actual{migraphx::shape::float_type, {1, 2, 8}}; + migraphx::shape expected{migraphx::shape::half_type, {1, 2, 8}}; + EXPECT(actual != expected); + EXPECT(not migraphx::shape::is_compatible(actual, expected)); +} +TEST_CASE(shape_is_compatible_dynamic) +{ + migraphx::shape actual{migraphx::shape::float_type, {1, 2, 2, 4}}; + migraphx::shape expected{migraphx::shape::float_type, {{1, 1}, {2, 4}, {2, 4}, {2, 4}}}; + EXPECT(actual != expected); + EXPECT(migraphx::shape::is_compatible(actual, expected)); +} +TEST_CASE(shape_is_compatible_dynamic_actual) +{ + migraphx::shape actual{migraphx::shape::float_type, {{1, 1}, {2, 4}, {2, 4}, {2, 4}}}; + migraphx::shape expected{migraphx::shape::float_type, {1, 2, 2, 4}}; + EXPECT(actual != expected); + EXPECT(not migraphx::shape::is_compatible(actual, expected)); +} +TEST_CASE(shape_is_compatible_dynamic_diff_type) +{ + migraphx::shape actual{migraphx::shape::float_type, {1, 2, 2, 4}}; + migraphx::shape expected{migraphx::shape::half_type, {{1, 1}, {2, 4}, {2, 4}, {2, 4}}}; + EXPECT(actual != expected); + EXPECT(not migraphx::shape::is_compatible(actual, expected)); +} +TEST_CASE(shape_is_compatible_dynamic_diff_rank) +{ + migraphx::shape actual{migraphx::shape::float_type, {1, 2, 2}}; + migraphx::shape expected{migraphx::shape::half_type, {{1, 1}, {2, 4}, {2, 4}, {2, 4}}}; + EXPECT(actual != expected); + EXPECT(not migraphx::shape::is_compatible(actual, expected)); +} +TEST_CASE(shape_is_compatible_diff_strides_tuple) +{ + migraphx::shape actual{migraphx::shape{migraphx::shape::float_type, {1, 1, 8}, {8, 8, 1}}}; + migraphx::shape expected{migraphx::shape{migraphx::shape::float_type, {1, 1, 8}, {1, 1, 1}}}; + EXPECT(actual != expected); + EXPECT(migraphx::shape::is_compatible(actual, expected)); +} +TEST_CASE(shape_is_compatible_diff_lens_tuple) +{ + migraphx::shape actual{migraphx::shape{migraphx::shape::float_type, {1, 2, 8}}}; + migraphx::shape expected{migraphx::shape{migraphx::shape::float_type, {1, 1, 8}}}; + EXPECT(actual != expected); + EXPECT(not migraphx::shape::is_compatible(actual, expected)); +} int main(int argc, const char* argv[]) { test::run(argc, argv); }