Skip to content

Commit

Permalink
Add ONNX parsing for SimplifiedLayerNormalization (#3129)
Browse files Browse the repository at this point in the history
* Add simplified_layer_normalization
  • Loading branch information
turneram authored Aug 8, 2024
1 parent e0a2325 commit 5510d75
Show file tree
Hide file tree
Showing 11 changed files with 442 additions and 4 deletions.
98 changes: 98 additions & 0 deletions src/onnx/parse_simplified_layer_normalization.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,98 @@
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include <migraphx/onnx/op_parser.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/instruction.hpp>

namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace onnx {

// ONNXRunTime implementation for reference:
// https://github.com/microsoft/onnxruntime/blob/main/onnxruntime/core/providers/cpu/nn/layer_norm_impl.cc

struct parse_simplified_layer_normalization : op_parser<parse_simplified_layer_normalization>
{
std::vector<op_desc> operators() const { return {{"SimplifiedLayerNormalization"}}; }

std::vector<instruction_ref> parse(const op_desc& /*opd*/,
const onnx_parser& parser,
const onnx_parser::node_info& info,
std::vector<instruction_ref> args) const
{
int64_t axis = -1;
if(contains(info.attributes, "axis"))
{
axis = parser.parse_value(info.attributes.at("axis")).at<int64_t>();
}
float epsilon = 1e-5f;
if(contains(info.attributes, "epsilon"))
{
epsilon = parser.parse_value(info.attributes.at("epsilon")).at<float>();
}
if(contains(info.attributes, "stash_type"))
{
std::cerr << "WARNING: SIMPLIFIED_LAYER_NORMALIZATION attribute stash_type is only "
"used for training.\n";
}

if(args.size() != 2)
{
MIGRAPHX_THROW(
"PARSE_SIMPLIFIED_LAYER_NORMALIZATION: invalid input count - expected 2 got " +
std::to_string(args.size()));
}

auto x = args.at(0);
auto scale = args.at(1);

auto x_shape = x->get_shape();
auto x_dtype = x_shape.type();
int64_t x_rank = x_shape.ndim();
axis = axis < 0 ? axis + x_rank : axis;

if(x_rank < 2 or x_rank > 3)
{
MIGRAPHX_THROW("PARSE_SIMPLIFIED_LAYER_NORMALIZATION: invalid input shape");
}

auto x_sq = info.add_common_op("mul", x, x);
auto rms = info.add_instruction(make_op("reduce_mean", {{"axes", {axis}}}), x_sq);
auto mean = rms;
epsilon =
(x_dtype == migraphx::shape::half_type and std::abs(epsilon) < 1e-7) ? 1e-7 : epsilon;
auto eps = info.add_literal(migraphx::literal{migraphx::shape{x_dtype}, {epsilon}});
rms = info.add_common_op("add", rms, eps);
auto rrms = info.add_instruction(make_op("rsqrt"), rms);
auto result = info.add_common_op("mul", x, rrms);
result = info.add_common_op("mul", result, scale);

return {result, mean, rrms};
}
};

} // namespace onnx
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
2 changes: 1 addition & 1 deletion src/targets/gpu/fuse_mlir.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -495,7 +495,7 @@ struct find_mlir_fused_ops
pw_ins, migraphx::make_op("get_tuple_elem", {{"index", 0}}), fused_ins);
// move all the reshape instructions and original GEMM instruction after the fused op to
// avoid generating invalid migraphx program
for(const auto orig_i : reverse(reshapes_vec))
for(const auto& orig_i : reverse(reshapes_vec))
{
mpm.get_module().move_instruction(orig_i, pw_ins);
}
Expand Down
50 changes: 50 additions & 0 deletions test/onnx/gen_onnx.py
Original file line number Diff line number Diff line change
Expand Up @@ -10539,6 +10539,56 @@ def sign_test():
return ([node], [x], [y])


@onnx_test()
def simplified_layer_normalization_test():
x = helper.make_tensor_value_info('x', TensorProto.FLOAT16, [2, 2, 4])
scale = helper.make_tensor_value_info('scale', TensorProto.FLOAT16, [4])
y = helper.make_tensor_value_info('y', TensorProto.FLOAT16, [2, 2, 4])

node = onnx.helper.make_node(
'SimplifiedLayerNormalization',
inputs=['x', 'scale'],
outputs=['y'],
axis=-1,
epsilon=1e-5,
stash_type=1,
)

return ([node], [x, scale], [y])


@onnx_test()
def simplified_layer_normalization_invalid_input_test():
x = helper.make_tensor_value_info('x', TensorProto.FLOAT16, [1, 2, 2, 4])
scale = helper.make_tensor_value_info('scale', TensorProto.FLOAT16, [4])
y = helper.make_tensor_value_info('y', TensorProto.FLOAT16, [1, 2, 2, 4])

node = onnx.helper.make_node(
'SimplifiedLayerNormalization',
inputs=['x', 'scale'],
outputs=['y'],
)

return ([node], [x, scale], [y])


@onnx_test()
def simplified_layer_normalization_invalid_n_args_test():
x = helper.make_tensor_value_info('x', TensorProto.FLOAT16, [2, 2, 4])
scale = helper.make_tensor_value_info('scale', TensorProto.FLOAT16, [4])
bias = helper.make_tensor_value_info('bias', TensorProto.FLOAT16,
[1, 2, 4])
y = helper.make_tensor_value_info('y', TensorProto.FLOAT16, [2, 2, 4])

node = onnx.helper.make_node(
'SimplifiedLayerNormalization',
inputs=['x', 'scale', 'bias'],
outputs=['y'],
)

return ([node], [x, scale, bias], [y])


@onnx_test()
def sin_test():
x = helper.make_tensor_value_info('x', TensorProto.FLOAT, [10])
Expand Down
39 changes: 36 additions & 3 deletions test/onnx/include/onnx_test_utils.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -171,9 +171,7 @@ make_layer_norm(const std::vector<int64_t>& input_shape,
{
bias = mm->add_parameter("bias", {dtype, scale_bias_shape});
}

auto eps = mm->add_literal(migraphx::literal{dtype, {eps_value}});

auto eps = mm->add_literal(migraphx::literal{dtype, {eps_value}});
auto mean = mm->add_instruction(migraphx::make_op("reduce_mean", {{"axes", reduce_axes}}), x);
auto x_sub_mean = add_common_op(*mm, migraphx::make_op("sub"), {x, mean});
auto x_sqdiff_mean = add_common_op(*mm, migraphx::make_op("sqdiff"), {x, mean});
Expand Down Expand Up @@ -201,7 +199,42 @@ make_layer_norm(const std::vector<int64_t>& input_shape,
{
mm->add_instruction(migraphx::make_op("add"), {scaled, bias_bcast});
}
return p;
}

inline migraphx::program
make_simplified_layer_norm(const std::vector<int64_t>& input_shape,
const std::vector<int64_t>& skip_shape,
const std::vector<int64_t>& scale_shape,
const int axis,
const float eps_value = 1e-5f,
const migraphx::shape::type_t dtype = migraphx::shape::half_type)
{
migraphx::program p;
auto* mm = p.get_main_module();
auto x = mm->add_parameter("x", {dtype, input_shape});
migraphx::instruction_ref skip;
migraphx::instruction_ref scale;
if(skip_shape.empty())
{
scale = mm->add_parameter("scale", {dtype, scale_shape});
}
else
{
skip = mm->add_parameter("skip", {dtype, skip_shape});
scale = mm->add_parameter("gamma", {dtype, scale_shape});
x = add_common_op(*mm, migraphx::make_op("add"), {x, skip});
}

auto eps = mm->add_literal(migraphx::literal{dtype, {eps_value}});

auto x_sq = add_common_op(*mm, migraphx::make_op("mul"), {x, x});
auto norm_axis = axis < 0 ? axis + x->get_shape().lens().size() : axis;
auto rms = mm->add_instruction(migraphx::make_op("reduce_mean", {{"axes", {norm_axis}}}), x_sq);
rms = add_common_op(*mm, migraphx::make_op("add"), {rms, eps});
auto rrms = mm->add_instruction(migraphx::make_op("rsqrt"), {rms});
auto result = add_common_op(*mm, migraphx::make_op("mul"), {x, rrms});
result = add_common_op(*mm, migraphx::make_op("mul"), {result, scale});
return p;
}

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/

#include <onnx_test.hpp>

TEST_CASE(simplified_layer_normalization_invalid_input_test)
{
EXPECT(test::throws(
[&] { migraphx::parse_onnx("simplified_layer_normalization_invalid_input_test.onnx"); }));
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/

#include <onnx_test.hpp>

TEST_CASE(simplified_layer_normalization_invalid_n_args_test)
{
EXPECT(test::throws(
[&] { migraphx::parse_onnx("simplified_layer_normalization_invalid_n_args_test.onnx"); }));
}
35 changes: 35 additions & 0 deletions test/onnx/parse/simplified_layer_normalization_test.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/

#include <onnx_test.hpp>
#include <onnx_test_utils.hpp>

TEST_CASE(simplified_layer_normalization_test)
{
migraphx::program p =
make_simplified_layer_norm({2, 2, 4}, {}, {4}, -1, 1e-5f, migraphx::shape::half_type);

auto prog = optimize_onnx("simplified_layer_normalization_test.onnx");
EXPECT(p == prog);
}
23 changes: 23 additions & 0 deletions test/onnx/simplified_layer_normalization_invalid_input_test.onnx
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
 1simplified_layer_normalization_invalid_input_test:�
+
x
scaley"SimplifiedLayerNormalization1simplified_layer_normalization_invalid_input_testZ
x





Z
scale



b
y





B
28 changes: 28 additions & 0 deletions test/onnx/simplified_layer_normalization_invalid_n_args_test.onnx
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
 2simplified_layer_normalization_invalid_n_args_test:�
1
x
scale
biasy"SimplifiedLayerNormalization2simplified_layer_normalization_invalid_n_args_testZ
x




Z
scale



Z
bias




b
y




B
Loading

0 comments on commit 5510d75

Please sign in to comment.