From 51b5bcaef2ffb4d85fede4bfaac6e89f7acf9c09 Mon Sep 17 00:00:00 2001 From: turneram <71655887+turneram@users.noreply.github.com> Date: Thu, 8 Aug 2024 10:32:50 -0500 Subject: [PATCH] Add ONNX parsing for SimplifiedLayerNormalization (#3129) * Add simplified_layer_normalization --- .../parse_simplified_layer_normalization.cpp | 98 +++++++++++++++++++ src/targets/gpu/fuse_mlir.cpp | 2 +- test/onnx/gen_onnx.py | 50 ++++++++++ test/onnx/include/onnx_test_utils.hpp | 39 +++++++- ...layer_normalization_invalid_input_test.cpp | 31 ++++++ ...ayer_normalization_invalid_n_args_test.cpp | 31 ++++++ .../simplified_layer_normalization_test.cpp | 35 +++++++ ...ayer_normalization_invalid_input_test.onnx | 23 +++++ ...yer_normalization_invalid_n_args_test.onnx | 28 ++++++ .../simplified_layer_normalization_test.onnx | 25 +++++ .../verify/simplified_layer_normalization.cpp | 84 ++++++++++++++++ 11 files changed, 442 insertions(+), 4 deletions(-) create mode 100644 src/onnx/parse_simplified_layer_normalization.cpp create mode 100644 test/onnx/parse/simplified_layer_normalization_invalid_input_test.cpp create mode 100644 test/onnx/parse/simplified_layer_normalization_invalid_n_args_test.cpp create mode 100644 test/onnx/parse/simplified_layer_normalization_test.cpp create mode 100644 test/onnx/simplified_layer_normalization_invalid_input_test.onnx create mode 100644 test/onnx/simplified_layer_normalization_invalid_n_args_test.onnx create mode 100644 test/onnx/simplified_layer_normalization_test.onnx create mode 100644 test/onnx/verify/simplified_layer_normalization.cpp diff --git a/src/onnx/parse_simplified_layer_normalization.cpp b/src/onnx/parse_simplified_layer_normalization.cpp new file mode 100644 index 00000000000..6b99b4fbb33 --- /dev/null +++ b/src/onnx/parse_simplified_layer_normalization.cpp @@ -0,0 +1,98 @@ +/* + * The MIT License (MIT) + * + * Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved. 
+ * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN + * THE SOFTWARE. 
+ */ +#include +#include +#include +#include + +namespace migraphx { +inline namespace MIGRAPHX_INLINE_NS { +namespace onnx { + +// ONNXRunTime implementation for reference: +// https://github.com/microsoft/onnxruntime/blob/main/onnxruntime/core/providers/cpu/nn/layer_norm_impl.cc + /* ONNX parser for the SimplifiedLayerNormalization operator. It lowers the node to mul, reduce_mean, add, rsqrt and mul instructions and returns three instruction refs: the scaled result, the mean of x*x, and the reciprocal root. */ +struct parse_simplified_layer_normalization : op_parser +{ + std::vector operators() const { return {{"SimplifiedLayerNormalization"}}; } + /* Reads the optional axis (default -1) and epsilon (default 1e-5f) attributes, warns on std::cerr when stash_type is present, and throws unless exactly two inputs (x, scale) are given and x has rank 2 or 3. NOTE(review): a normalized axis that still falls outside [0, rank) is not rejected here - confirm whether reduce_mean validates it. */ + std::vector parse(const op_desc& /*opd*/, + const onnx_parser& parser, + const onnx_parser::node_info& info, + std::vector args) const + { + int64_t axis = -1; + if(contains(info.attributes, "axis")) + { + axis = parser.parse_value(info.attributes.at("axis")).at(); + } + float epsilon = 1e-5f; + if(contains(info.attributes, "epsilon")) + { + epsilon = parser.parse_value(info.attributes.at("epsilon")).at(); + } + if(contains(info.attributes, "stash_type")) + { + std::cerr << "WARNING: SIMPLIFIED_LAYER_NORMALIZATION attribute stash_type is only " + "used for training.\n"; + } + + if(args.size() != 2) + { + MIGRAPHX_THROW( + "PARSE_SIMPLIFIED_LAYER_NORMALIZATION: invalid input count - expected 2 got " + + std::to_string(args.size())); + } + + auto x = args.at(0); + auto scale = args.at(1); + + auto x_shape = x->get_shape(); + auto x_dtype = x_shape.type(); + int64_t x_rank = x_shape.ndim(); + axis = axis < 0 ? axis + x_rank : axis; + + if(x_rank < 2 or x_rank > 3) + { + MIGRAPHX_THROW("PARSE_SIMPLIFIED_LAYER_NORMALIZATION: invalid input shape"); + } + + auto x_sq = info.add_common_op("mul", x, x); + auto rms = info.add_instruction(make_op("reduce_mean", {{"axes", {axis}}}), x_sq); + auto mean = rms; /* keeps the mean of x*x before epsilon is added; this is the value returned as the second output */ + epsilon = + (x_dtype == migraphx::shape::half_type and std::abs(epsilon) < 1e-7) ? 
1e-7 : epsilon; /* for half inputs, an epsilon with magnitude below 1e-7 is replaced by 1e-7 */ + auto eps = info.add_literal(migraphx::literal{migraphx::shape{x_dtype}, {epsilon}}); + rms = info.add_common_op("add", rms, eps); + auto rrms = info.add_instruction(make_op("rsqrt"), rms); + auto result = info.add_common_op("mul", x, rrms); + result = info.add_common_op("mul", result, scale); + /* result = x * rsqrt(mean(x*x) + eps) * scale; mean and rrms are returned alongside it */ + return {result, mean, rrms}; + } +}; + +} // namespace onnx +} // namespace MIGRAPHX_INLINE_NS +} // namespace migraphx diff --git a/src/targets/gpu/fuse_mlir.cpp b/src/targets/gpu/fuse_mlir.cpp index befc57eb05f..c09830251e0 100644 --- a/src/targets/gpu/fuse_mlir.cpp +++ b/src/targets/gpu/fuse_mlir.cpp @@ -495,7 +495,7 @@ struct find_mlir_fused_ops pw_ins, migraphx::make_op("get_tuple_elem", {{"index", 0}}), fused_ins); // move all the reshape instructions and original GEMM instruction after the fused op to // avoid generating invalid migraphx program - for(const auto orig_i : reverse(reshapes_vec)) + for(const auto& orig_i : reverse(reshapes_vec)) { mpm.get_module().move_instruction(orig_i, pw_ins); } diff --git a/test/onnx/gen_onnx.py b/test/onnx/gen_onnx.py index 59ca063f8e2..3efc787c559 100644 --- a/test/onnx/gen_onnx.py +++ b/test/onnx/gen_onnx.py @@ -10539,6 +10539,56 @@ def sign_test(): return ([node], [x], [y]) +@onnx_test() +def simplified_layer_normalization_test(): + x = helper.make_tensor_value_info('x', TensorProto.FLOAT16, [2, 2, 4]) + scale = helper.make_tensor_value_info('scale', TensorProto.FLOAT16, [4]) + y = helper.make_tensor_value_info('y', TensorProto.FLOAT16, [2, 2, 4]) + + node = onnx.helper.make_node( + 'SimplifiedLayerNormalization', + inputs=['x', 'scale'], + outputs=['y'], + axis=-1, + epsilon=1e-5, + stash_type=1, + ) + + return ([node], [x, scale], [y]) + + +@onnx_test() +def simplified_layer_normalization_invalid_input_test(): + x = helper.make_tensor_value_info('x', TensorProto.FLOAT16, [1, 2, 2, 4]) + scale = helper.make_tensor_value_info('scale', TensorProto.FLOAT16, [4]) + y = helper.make_tensor_value_info('y', 
TensorProto.FLOAT16, [1, 2, 2, 4]) + + node = onnx.helper.make_node( + 'SimplifiedLayerNormalization', + inputs=['x', 'scale'], + outputs=['y'], + ) + + return ([node], [x, scale], [y]) + + +@onnx_test() +def simplified_layer_normalization_invalid_n_args_test(): + x = helper.make_tensor_value_info('x', TensorProto.FLOAT16, [2, 2, 4]) + scale = helper.make_tensor_value_info('scale', TensorProto.FLOAT16, [4]) + bias = helper.make_tensor_value_info('bias', TensorProto.FLOAT16, + [1, 2, 4]) + y = helper.make_tensor_value_info('y', TensorProto.FLOAT16, [2, 2, 4]) + + node = onnx.helper.make_node( + 'SimplifiedLayerNormalization', + inputs=['x', 'scale', 'bias'], + outputs=['y'], + ) + + return ([node], [x, scale, bias], [y]) + + @onnx_test() def sin_test(): x = helper.make_tensor_value_info('x', TensorProto.FLOAT, [10]) diff --git a/test/onnx/include/onnx_test_utils.hpp b/test/onnx/include/onnx_test_utils.hpp index f6ebbb150f8..120346c0930 100644 --- a/test/onnx/include/onnx_test_utils.hpp +++ b/test/onnx/include/onnx_test_utils.hpp @@ -171,9 +171,7 @@ make_layer_norm(const std::vector& input_shape, { bias = mm->add_parameter("bias", {dtype, scale_bias_shape}); } - - auto eps = mm->add_literal(migraphx::literal{dtype, {eps_value}}); - + auto eps = mm->add_literal(migraphx::literal{dtype, {eps_value}}); auto mean = mm->add_instruction(migraphx::make_op("reduce_mean", {{"axes", reduce_axes}}), x); auto x_sub_mean = add_common_op(*mm, migraphx::make_op("sub"), {x, mean}); auto x_sqdiff_mean = add_common_op(*mm, migraphx::make_op("sqdiff"), {x, mean}); @@ -201,7 +199,42 @@ make_layer_norm(const std::vector& input_shape, { mm->add_instruction(migraphx::make_op("add"), {scaled, bias_bcast}); } + return p; +} + /* Builds a program that computes x * rsqrt(reduce_mean(x * x) + eps) * scale in its main module. With an empty skip_shape the parameters are x and scale; otherwise skip and gamma parameters are added and x is replaced by x + skip before normalizing. eps_value defaults to 1e-5f and dtype to half_type. */ +inline migraphx::program +make_simplified_layer_norm(const std::vector& input_shape, + const std::vector& skip_shape, + const std::vector& scale_shape, + const int axis, + const float eps_value = 1e-5f, + const migraphx::shape::type_t dtype = 
migraphx::shape::half_type) +{ + migraphx::program p; + auto* mm = p.get_main_module(); + auto x = mm->add_parameter("x", {dtype, input_shape}); + migraphx::instruction_ref skip; + migraphx::instruction_ref scale; + if(skip_shape.empty()) + { + scale = mm->add_parameter("scale", {dtype, scale_shape}); + } + else + { + skip = mm->add_parameter("skip", {dtype, skip_shape}); + scale = mm->add_parameter("gamma", {dtype, scale_shape}); + x = add_common_op(*mm, migraphx::make_op("add"), {x, skip}); + } + + auto eps = mm->add_literal(migraphx::literal{dtype, {eps_value}}); + auto x_sq = add_common_op(*mm, migraphx::make_op("mul"), {x, x}); + auto norm_axis = axis < 0 ? axis + x->get_shape().lens().size() : axis; /* NOTE(review): axis is int while lens().size() is presumably an unsigned size type, so a negative axis would be normalized through unsigned wraparound and norm_axis deduced as that unsigned type - confirm this is intended */ + auto rms = mm->add_instruction(migraphx::make_op("reduce_mean", {{"axes", {norm_axis}}}), x_sq); + rms = add_common_op(*mm, migraphx::make_op("add"), {rms, eps}); + auto rrms = mm->add_instruction(migraphx::make_op("rsqrt"), {rms}); + auto result = add_common_op(*mm, migraphx::make_op("mul"), {x, rrms}); + result = add_common_op(*mm, migraphx::make_op("mul"), {result, scale}); return p; } diff --git a/test/onnx/parse/simplified_layer_normalization_invalid_input_test.cpp b/test/onnx/parse/simplified_layer_normalization_invalid_input_test.cpp new file mode 100644 index 00000000000..1e6ae3b5762 --- /dev/null +++ b/test/onnx/parse/simplified_layer_normalization_invalid_input_test.cpp @@ -0,0 +1,31 @@ +/* + * The MIT License (MIT) + * + * Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved. 
+ * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN + * THE SOFTWARE. + */ + +#include + /* Expects parse_onnx to throw for this model, which gen_onnx.py builds with a rank-4 x of shape [1, 2, 2, 4]. */ +TEST_CASE(simplified_layer_normalization_invalid_input_test) +{ + EXPECT(test::throws( + [&] { migraphx::parse_onnx("simplified_layer_normalization_invalid_input_test.onnx"); })); +} diff --git a/test/onnx/parse/simplified_layer_normalization_invalid_n_args_test.cpp b/test/onnx/parse/simplified_layer_normalization_invalid_n_args_test.cpp new file mode 100644 index 00000000000..91c21ba621c --- /dev/null +++ b/test/onnx/parse/simplified_layer_normalization_invalid_n_args_test.cpp @@ -0,0 +1,31 @@ +/* + * The MIT License (MIT) + * + * Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved. 
+ * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN + * THE SOFTWARE. + */ + +#include + /* Expects parse_onnx to throw for this model, which gen_onnx.py builds with three node inputs (x, scale, bias). */ +TEST_CASE(simplified_layer_normalization_invalid_n_args_test) +{ + EXPECT(test::throws( + [&] { migraphx::parse_onnx("simplified_layer_normalization_invalid_n_args_test.onnx"); })); +} diff --git a/test/onnx/parse/simplified_layer_normalization_test.cpp b/test/onnx/parse/simplified_layer_normalization_test.cpp new file mode 100644 index 00000000000..7c880fe8aa0 --- /dev/null +++ b/test/onnx/parse/simplified_layer_normalization_test.cpp @@ -0,0 +1,35 @@ +/* + * The MIT License (MIT) + * + * Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved. 
+ * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN + * THE SOFTWARE. 
+ */ + +#include +#include + +TEST_CASE(simplified_layer_normalization_test) +{ + migraphx::program p = + make_simplified_layer_norm({2, 2, 4}, {}, {4}, -1, 1e-5f, migraphx::shape::half_type); + + auto prog = optimize_onnx("simplified_layer_normalization_test.onnx"); + EXPECT(p == prog); +} diff --git a/test/onnx/simplified_layer_normalization_invalid_input_test.onnx b/test/onnx/simplified_layer_normalization_invalid_input_test.onnx new file mode 100644 index 00000000000..1331cc53639 --- /dev/null +++ b/test/onnx/simplified_layer_normalization_invalid_input_test.onnx @@ -0,0 +1,23 @@ + 1simplified_layer_normalization_invalid_input_test:¯ ++ +x +scaley"SimplifiedLayerNormalization1simplified_layer_normalization_invalid_input_testZ +x + + + + + +Z +scale + + + +b +y + + + + + +B \ No newline at end of file diff --git a/test/onnx/simplified_layer_normalization_invalid_n_args_test.onnx b/test/onnx/simplified_layer_normalization_invalid_n_args_test.onnx new file mode 100644 index 00000000000..119cbb42b9a --- /dev/null +++ b/test/onnx/simplified_layer_normalization_invalid_n_args_test.onnx @@ -0,0 +1,28 @@ + 2simplified_layer_normalization_invalid_n_args_test:Ê +1 +x +scale +biasy"SimplifiedLayerNormalization2simplified_layer_normalization_invalid_n_args_testZ +x + + + + +Z +scale + + + +Z +bias + + + + +b +y + + + + +B \ No newline at end of file diff --git a/test/onnx/simplified_layer_normalization_test.onnx b/test/onnx/simplified_layer_normalization_test.onnx new file mode 100644 index 00000000000..a7b73fb1462 --- /dev/null +++ b/test/onnx/simplified_layer_normalization_test.onnx @@ -0,0 +1,25 @@ + #simplified_layer_normalization_test:Õ +g +x +scaley"SimplifiedLayerNormalization* +axisÿÿÿÿÿÿÿÿÿ * +epsilon¬Å'7 * + +stash_type #simplified_layer_normalization_testZ +x + + + + +Z +scale + + + +b +y + + + + +B \ No newline at end of file diff --git a/test/onnx/verify/simplified_layer_normalization.cpp b/test/onnx/verify/simplified_layer_normalization.cpp new file mode 
100644 index 00000000000..f3838e4a551 --- /dev/null +++ b/test/onnx/verify/simplified_layer_normalization.cpp @@ -0,0 +1,84 @@ +/* + * The MIT License (MIT) + * + * Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN + * THE SOFTWARE. 
+ */ + +#include +#include +#include +#include + /* Reads simplified_layer_normalization_test.onnx, compiles it for the ref target, evaluates it on a {2, 2, 4} half x and a {4} half scale, and compares the last program output against precomputed gold values with verify_rms_range. */ +TEST_CASE(simplified_layer_normalization_test) +{ + using migraphx::half; + std::vector x{half{0.8}, + half{-0.5}, + half{0.0}, + half{1.0}, + half{0.5}, + half{0.2}, + half{0.3}, + half{-0.6}, + half{10.0}, + half{-1.0}, + half{0.0}, + half{1.0}, + half{1.2}, + half{3.2}, + half{-4.1}, + half{5.3}}; + std::vector scale{half{0.1}, half{0.2}, half{4.0}, half{-2.2}}; + + auto p = read_onnx("simplified_layer_normalization_test.onnx"); + p.compile(migraphx::make_target("ref")); + + migraphx::shape s_x{migraphx::shape::half_type, {2, 2, 4}}; + migraphx::shape s_s{migraphx::shape::half_type, {4}}; + + migraphx::parameter_map pp; + pp["x"] = migraphx::argument(s_x, x.data()); + pp["scale"] = migraphx::argument(s_s, scale.data()); + + auto result = p.eval(pp).back(); + + std::vector result_vector; + result.visit([&](auto output) { result_vector.assign(output.begin(), output.end()); }); + + std::vector gold = {half{0.11633}, + half{-0.1455}, + half{0.0}, + half{-3.2}, + half{0.1162}, + half{0.09296}, + half{2.791}, + half{3.068}, + half{0.198}, + half{-0.03958}, + half{0.0}, + half{-0.4355}, + half{0.0319}, + half{0.17}, + half{-4.363}, + half{-3.1}}; + + EXPECT(migraphx::verify::verify_rms_range(result_vector, gold)); +}