From c5eee1a34314591ade7043322e7cae0f979f91b3 Mon Sep 17 00:00:00 2001 From: Charlie Lin Date: Mon, 17 Apr 2023 13:00:40 -0400 Subject: [PATCH] Convert a fully fixed map_dyn_input_dims value to a static shape when parsing ONNX (#1682) Fixes the above behavior This needs to be changed to allow for setting static shapes with map_dyn_input_dims since you cannot also use map_input_dims --- src/onnx/onnx_parser.cpp | 27 ++++++++++++++++----------- test/onnx/onnx_test.cpp | 21 +++++++++++++++++++-- 2 files changed, 35 insertions(+), 13 deletions(-) diff --git a/src/onnx/onnx_parser.cpp b/src/onnx/onnx_parser.cpp index 2f1c3703365..ef9e43b1b5c 100644 --- a/src/onnx/onnx_parser.cpp +++ b/src/onnx/onnx_parser.cpp @@ -39,6 +39,20 @@ namespace migraphx { inline namespace MIGRAPHX_INLINE_NS { +static shape shape_from_dyn_dims(shape::type_t shape_type, + const std::vector& dyn_dims) +{ + if(std::all_of(dyn_dims.begin(), dyn_dims.end(), [](auto dd) { return dd.is_fixed(); })) + { + std::vector dims; + std::transform(dyn_dims.cbegin(), dyn_dims.cend(), std::back_inserter(dims), [](auto d) { + return d.max; + }); + return {shape_type, dims}; + } + return {shape_type, dyn_dims}; +} + namespace onnx { static onnx_parser::attribute_map get_attributes(const onnx::NodeProto& node) @@ -300,7 +314,7 @@ onnx_parser::parse_graph(module* mod, const onnx::GraphProto& graph, bool inlini else if(map_dyn_input_dims.count(name) > 0) { shape::type_t shape_type = get_type(input.type().tensor_type().elem_type()); - s = {shape_type, map_dyn_input_dims.at(name)}; + s = shape_from_dyn_dims(shape_type, map_dyn_input_dims.at(name)); } else { @@ -503,16 +517,7 @@ shape onnx_parser::parse_type(const onnx::TypeProto& t, { return {shape_type}; } - if(std::all_of(dynamic_dims.begin(), dynamic_dims.end(), [](auto dd) { return dd.is_fixed(); })) - { - std::vector dims; - std::transform(dynamic_dims.begin(), - dynamic_dims.end(), - std::back_inserter(dims), - [](auto d) { return d.max; }); - return {shape_type, dims}; - } - return {shape_type, dynamic_dims}; + return shape_from_dyn_dims(shape_type, dynamic_dims); } shape::type_t get_type(int dtype) diff --git a/test/onnx/onnx_test.cpp b/test/onnx/onnx_test.cpp index 4374ee868a5..ad26672260a 100644 --- a/test/onnx/onnx_test.cpp +++ b/test/onnx/onnx_test.cpp @@ -4959,13 +4959,13 @@ TEST_CASE(reducemax_dyn_test) migraphx::program p; auto* mm = p.get_main_module(); auto l0 = mm->add_parameter( - "x", migraphx::shape{migraphx::shape::float_type, {{3, 3}, {4, 4}, {5, 5}, {6, 6}}}); + "x", migraphx::shape{migraphx::shape::float_type, {{3, 5}, {4, 4}, {5, 5}, {6, 6}}}); auto r0 = mm->add_instruction(migraphx::make_op("reduce_max", {{"axes", {2}}}), l0); auto r1 = mm->add_instruction(migraphx::make_op("squeeze", {{"axes", {2}}}), r0); mm->add_return({r1}); migraphx::onnx_options options; - options.map_dyn_input_dims["x"] = {{3, 3}, {4, 4}, {5, 5}, {6, 6}}; + options.map_dyn_input_dims["x"] = {{3, 5}, {4, 4}, {5, 5}, {6, 6}}; auto prog = migraphx::parse_onnx("reducemax_dyn_test.onnx", options); EXPECT(p == prog); @@ -6953,6 +6953,23 @@ TEST_CASE(variable_batch_user_input_test6) EXPECT(test::throws([&] { migraphx::parse_onnx("variable_batch_test.onnx", options); })); } +TEST_CASE(variable_batch_user_input_test7) +{ + // if entry in map_dyn_input_dims is all fixed dynamic_dimensions, convert it to a static shape + migraphx::program p; + auto* mm = p.get_main_module(); + auto l0 = mm->add_parameter("0", migraphx::shape{migraphx::shape::float_type, {2, 3, 16, 16}}); + auto r = mm->add_instruction(migraphx::make_op("identity"), l0); + mm->add_return({r}); + + migraphx::onnx_options options; + options.map_dyn_input_dims["0"] = {{2, 2, {2}}, {3, 3}, {16, 16}, {16, 16}}; + + auto prog = migraphx::parse_onnx("variable_batch_test.onnx", options); + + EXPECT(p == prog); +} + TEST_CASE(variable_batch_leq_zero_test) { migraphx::program p;