From 395c5fa79ced3739c8fe59c75cc933a3e1f393aa Mon Sep 17 00:00:00 2001 From: Ahsan Saghir Date: Mon, 15 Apr 2024 18:56:29 +0000 Subject: [PATCH 01/11] FP8 Onnx tests --- test/onnx/add_fp8_test.onnx | 16 ++ test/onnx/binary_dyn_brcst_mul_fp8_test.onnx | Bin 0 -> 165 bytes test/onnx/conv_1d_fp8_test.onnx | 19 ++ test/onnx/cos_fp8_test.onnx | 13 ++ test/onnx/div_fp8_test.onnx | 16 ++ test/onnx/gemm_fp8_test.onnx | Bin 0 -> 191 bytes test/onnx/gen_onnx.py | 207 ++++++++++++++++++ test/onnx/globalavgpool_fp8_test.onnx | 15 ++ test/onnx/globalmaxpool_fp8_test.onnx | 15 ++ test/onnx/parse/add_fp8_test.cpp | 39 ++++ .../parse/binary_dyn_brcst_mul_fp8_test.cpp | 53 +++++ test/onnx/parse/conv_1d_fp8_test.cpp | 40 ++++ test/onnx/parse/cos_fp8_test.cpp | 36 +++ test/onnx/parse/div_fp8_test.cpp | 39 ++++ test/onnx/parse/gemm_fp8_test.cpp | 56 +++++ test/onnx/parse/globalavgpool_fp8_test.cpp | 43 ++++ test/onnx/parse/globalmaxpool_fp8_test.cpp | 43 ++++ test/onnx/parse/reducemax_fp8_test.cpp | 37 ++++ test/onnx/parse/reducesum_fp8_test.cpp | 37 ++++ test/onnx/parse/shrink_fp8_test.cpp | 60 +++++ test/onnx/parse/sin_fp8_test.cpp | 36 +++ test/onnx/parse/size_fp8_test.cpp | 37 ++++ test/onnx/parse/sqrt_fp8_test.cpp | 36 +++ test/onnx/reducemax_fp8_test.onnx | Bin 0 -> 151 bytes test/onnx/reducesum_fp8_test.onnx | Bin 0 -> 155 bytes test/onnx/shrink_fp8_test.onnx | Bin 0 -> 133 bytes test/onnx/sin_fp8_test.onnx | 13 ++ test/onnx/size_fp8_test.onnx | 12 + test/onnx/sqrt_fp8_test.onnx | 13 ++ test/onnx/verify/add_fp8_test.cpp | 49 +++++ test/onnx/verify/gemm_fp8_test.cpp | 72 ++++++ test/onnx/verify/shrink_fp8_test.cpp | 50 +++++ test/onnx/verify/size_fp8_test.cpp | 43 ++++ 33 files changed, 1145 insertions(+) create mode 100644 test/onnx/add_fp8_test.onnx create mode 100644 test/onnx/binary_dyn_brcst_mul_fp8_test.onnx create mode 100644 test/onnx/conv_1d_fp8_test.onnx create mode 100644 test/onnx/cos_fp8_test.onnx create mode 100644 test/onnx/div_fp8_test.onnx create mode 100644 test/onnx/gemm_fp8_test.onnx create mode 100644 test/onnx/globalavgpool_fp8_test.onnx create mode 100644 test/onnx/globalmaxpool_fp8_test.onnx create mode 100644 test/onnx/parse/add_fp8_test.cpp create mode 100644 test/onnx/parse/binary_dyn_brcst_mul_fp8_test.cpp create mode 100644 test/onnx/parse/conv_1d_fp8_test.cpp create mode 100644 test/onnx/parse/cos_fp8_test.cpp create mode 100644 test/onnx/parse/div_fp8_test.cpp create mode 100644 test/onnx/parse/gemm_fp8_test.cpp create mode 100644 test/onnx/parse/globalavgpool_fp8_test.cpp create mode 100644 test/onnx/parse/globalmaxpool_fp8_test.cpp create mode 100644 test/onnx/parse/reducemax_fp8_test.cpp create mode 100644 test/onnx/parse/reducesum_fp8_test.cpp create mode 100644 test/onnx/parse/shrink_fp8_test.cpp create mode 100644 test/onnx/parse/sin_fp8_test.cpp create mode 100644 test/onnx/parse/size_fp8_test.cpp create mode 100644 test/onnx/parse/sqrt_fp8_test.cpp create mode 100644 test/onnx/reducemax_fp8_test.onnx create mode 100644 test/onnx/reducesum_fp8_test.onnx create mode 100644 test/onnx/shrink_fp8_test.onnx create mode 100644 test/onnx/sin_fp8_test.onnx create mode 100644 test/onnx/size_fp8_test.onnx create mode 100644 test/onnx/sqrt_fp8_test.onnx create mode 100644 test/onnx/verify/add_fp8_test.cpp create mode 100644 test/onnx/verify/gemm_fp8_test.cpp create mode 100644 test/onnx/verify/shrink_fp8_test.cpp create mode 100644 test/onnx/verify/size_fp8_test.cpp diff --git a/test/onnx/add_fp8_test.onnx b/test/onnx/add_fp8_test.onnx new file mode 100644 index 00000000000..734d503c79f --- /dev/null +++ b/test/onnx/add_fp8_test.onnx @@ -0,0 +1,16 @@ +  add_fp8_test:Q + +0 +12"Add add_fp8_testZ +0 + + +Z +1 + + +b +2 + + +B \ No newline at end of file diff --git a/test/onnx/binary_dyn_brcst_mul_fp8_test.onnx b/test/onnx/binary_dyn_brcst_mul_fp8_test.onnx new file mode 100644 index 0000000000000000000000000000000000000000..80178a7ed4e523373a22753d8cbb537d1f43ecba GIT binary patch literal 165 zcmdwebi;067h*^M`HA)y{5FZy0P>KU8$e1L}1u`1D4kso7 FVE|#YA{GDu literal 0 HcmV?d00001 diff --git a/test/onnx/conv_1d_fp8_test.onnx b/test/onnx/conv_1d_fp8_test.onnx new file mode 100644 index 00000000000..a6b910da71d --- /dev/null +++ b/test/onnx/conv_1d_fp8_test.onnx @@ -0,0 +1,19 @@ + conv_1d_fp8_test:n + +0 +12"Convconv_1d_fp8_testZ +0 + + + +Z +1 + + + +b +2 + + + +B \ No newline at end of file diff --git a/test/onnx/cos_fp8_test.onnx b/test/onnx/cos_fp8_test.onnx new file mode 100644 index 00000000000..d494bc558f5 --- /dev/null +++ b/test/onnx/cos_fp8_test.onnx @@ -0,0 +1,13 @@ +  cos_fp8_test:= + +xy"Cos cos_fp8_testZ +x + + + +b +y + + + +B \ No newline at end of file diff --git a/test/onnx/div_fp8_test.onnx b/test/onnx/div_fp8_test.onnx new file mode 100644 index 00000000000..4eb4923a296 --- /dev/null +++ b/test/onnx/div_fp8_test.onnx @@ -0,0 +1,16 @@ +  div_fp8_test:a + +0 +1out"Div div_fp8_testZ +0 +  + +Z +1 +  + +b +out +  + +B \ No newline at end of file diff --git a/test/onnx/gemm_fp8_test.onnx b/test/onnx/gemm_fp8_test.onnx new file mode 100644 index 0000000000000000000000000000000000000000..effebfeaa370893a80bfb57fa751750970a69745 GIT binary patch literal 191 zcmdwehLr92&i-`k- f*+7C$XoBn@L1!dEHlRwzB%oj + +TEST_CASE(add_fp8_test) +{ + migraphx::program p; + auto* mm = p.get_main_module(); + auto p0 = mm->add_parameter("0", migraphx::shape{migraphx::shape::fp8e4m3fnuz_type, {1}}); + auto p1 = mm->add_parameter("1", migraphx::shape{migraphx::shape::fp8e4m3fnuz_type, {1}}); + + mm->add_instruction(migraphx::make_op("add"), p0, p1); + + auto prog = optimize_onnx("add_fp8_test.onnx"); + + EXPECT(p == prog); +} diff --git a/test/onnx/parse/binary_dyn_brcst_mul_fp8_test.cpp b/test/onnx/parse/binary_dyn_brcst_mul_fp8_test.cpp new file mode 100644 index 00000000000..3cd6a933a75 --- /dev/null +++ b/test/onnx/parse/binary_dyn_brcst_mul_fp8_test.cpp @@ -0,0 +1,53 @@ +/* + * The MIT License (MIT) + * + * Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN + * THE SOFTWARE. + */ + +#include + +TEST_CASE(binary_dyn_brcst_mul_fp8_test) +{ + migraphx::program p; + auto* mm = p.get_main_module(); + auto l0 = mm->add_parameter( + "0", migraphx::shape{migraphx::shape::fp8e4m3fnuz_type, {{1, 4}, {3, 3}, {4, 4}, {5, 5}}}); + auto l1 = mm->add_parameter("1", migraphx::shape{migraphx::shape::fp8e4m3fnuz_type, {4, 1}}); + + auto bl0 = mm->add_instruction( + migraphx::make_op("multibroadcast", + {{"out_dyn_dims", to_value(l0->get_shape().dyn_dims())}}), + l0, + l1); + auto bl1 = mm->add_instruction( + migraphx::make_op("multibroadcast", + {{"out_dyn_dims", to_value(l0->get_shape().dyn_dims())}}), + l1, + bl0); + auto ret = mm->add_instruction(migraphx::make_op("mul"), bl0, bl1); + mm->add_return({ret}); + + migraphx::onnx_options options; + options.default_dyn_dim_value = {1, 4}; + auto prog = migraphx::parse_onnx("binary_dyn_brcst_mul_fp8_test.onnx", options); + + EXPECT(p == prog); +} diff --git a/test/onnx/parse/conv_1d_fp8_test.cpp b/test/onnx/parse/conv_1d_fp8_test.cpp new file mode 100644 index 00000000000..d28abe6987b --- /dev/null +++ b/test/onnx/parse/conv_1d_fp8_test.cpp @@ -0,0 +1,40 @@ +/* + * The MIT License (MIT) + * + * Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN + * THE SOFTWARE. + */ + +#include + +TEST_CASE(conv_1d_fp8_test) +{ + migraphx::program p; + auto* mm = p.get_main_module(); + auto l0 = mm->add_parameter("0", {migraphx::shape::fp8e4m3fnuz_type, {1, 3, 5}}); + auto l1 = mm->add_parameter("1", {migraphx::shape::fp8e4m3fnuz_type, {1, 3, 3}}); + mm->add_instruction( + migraphx::make_op("convolution", {{"padding", {0}}, {"stride", {1}}, {"dilation", {1}}}), + l0, + l1); + + auto prog = optimize_onnx("conv_1d_fp8_test.onnx"); + EXPECT(p == prog); +} diff --git a/test/onnx/parse/cos_fp8_test.cpp b/test/onnx/parse/cos_fp8_test.cpp new file mode 100644 index 00000000000..409f0fa7d6f --- /dev/null +++ b/test/onnx/parse/cos_fp8_test.cpp @@ -0,0 +1,36 @@ +/* + * The MIT License (MIT) + * + * Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN + * THE SOFTWARE. + */ + +#include + +TEST_CASE(cos_fp8_test) +{ + migraphx::program p; + auto* mm = p.get_main_module(); + auto input = mm->add_parameter("x", migraphx::shape{migraphx::shape::fp8e4m3fnuz_type, {10}}); + mm->add_instruction(migraphx::make_op("cos"), input); + + auto prog = optimize_onnx("cos_fp8_test.onnx"); + EXPECT(p == prog); +} diff --git a/test/onnx/parse/div_fp8_test.cpp b/test/onnx/parse/div_fp8_test.cpp new file mode 100644 index 00000000000..b3216421793 --- /dev/null +++ b/test/onnx/parse/div_fp8_test.cpp @@ -0,0 +1,39 @@ +/* + * The MIT License (MIT) + * + * Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN + * THE SOFTWARE. + */ + +#include + +TEST_CASE(div_fp8_test) +{ + migraphx::program p; + auto* mm = p.get_main_module(); + auto p0 = mm->add_parameter("0", migraphx::shape{migraphx::shape::fp8e4m3fnuz_type, {2, 3}}); + auto p1 = mm->add_parameter("1", migraphx::shape{migraphx::shape::fp8e4m3fnuz_type, {2, 3}}); + + mm->add_instruction(migraphx::make_op("div"), p0, p1); + + auto prog = optimize_onnx("div_fp8_test.onnx"); + + EXPECT(p == prog); +} diff --git a/test/onnx/parse/gemm_fp8_test.cpp b/test/onnx/parse/gemm_fp8_test.cpp new file mode 100644 index 00000000000..e01d0143aa7 --- /dev/null +++ b/test/onnx/parse/gemm_fp8_test.cpp @@ -0,0 +1,56 @@ +/* + * The MIT License (MIT) + * + * Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN + * THE SOFTWARE. + */ + +#include +#include + +TEST_CASE(gemm_fp8_test) +{ + migraphx::program p; + auto* mm = p.get_main_module(); + auto l0 = mm->add_parameter("A", migraphx::shape{migraphx::shape::fp8e4m3fnuz_type, {8, 6}}); + auto l1 = mm->add_parameter("B", migraphx::shape{migraphx::shape::fp8e4m3fnuz_type, {8, 7}}); + auto l2 = mm->add_parameter("C", migraphx::shape{migraphx::shape::fp8e4m3fnuz_type, {6, 1}}); + auto alpha = 0.5f; + auto beta = 0.8f; + auto a_l = mm->add_literal(alpha); + auto t_a = add_common_op(*mm, migraphx::make_op("mul"), {a_l, l0}); + t_a = mm->add_instruction( + migraphx::make_op("convert", {{"target_type", migraphx::shape::fp8e4m3fnuz_type}}), t_a); + t_a = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {1, 0}}}), t_a); + std::vector lens = {6, 7}; + auto dot = migraphx::add_apply_alpha_beta(*mm, {t_a, l1}, migraphx::make_op("dot"), 1.0f, 0.0f); + l2 = mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", lens}}), l2); + l2 = mm->add_instruction( + migraphx::make_op("convert", {{"target_type", migraphx::shape::float_type}}), l2); + auto b_l = mm->add_literal(beta); + auto b_b = mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", lens}}), b_l); + auto l2_b = mm->add_instruction(migraphx::make_op("mul"), l2, b_b); + l2_b = mm->add_instruction( + migraphx::make_op("convert", {{"target_type", migraphx::shape::fp8e4m3fnuz_type}}), l2_b); + mm->add_instruction(migraphx::make_op("add"), dot, l2_b); + + auto prog = optimize_onnx("gemm_fp8_test.onnx"); + EXPECT(p == prog); +} diff --git a/test/onnx/parse/globalavgpool_fp8_test.cpp b/test/onnx/parse/globalavgpool_fp8_test.cpp new file mode 100644 index 00000000000..62c54ab4236 --- /dev/null +++ b/test/onnx/parse/globalavgpool_fp8_test.cpp @@ -0,0 +1,43 @@ +/* + * The MIT License (MIT) + * + * Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN + * THE SOFTWARE. + */ + +#include +#include + +TEST_CASE(globalavgpool_fp8_test) +{ + migraphx::program p; + auto* mm = p.get_main_module(); + auto input = + mm->add_parameter("0", migraphx::shape{migraphx::shape::fp8e4m3fnuz_type, {1, 3, 16, 16}}); + auto op = migraphx::op::pooling{migraphx::op::pooling_mode::average}; + auto lens = input->get_shape().lens(); + op.lengths = {lens[2], lens[3]}; + op.padding = {0, 0, 0, 0}; + mm->add_instruction(op, input); + + auto prog = optimize_onnx("globalavgpool_fp8_test.onnx"); + + EXPECT(p == prog); +} diff --git a/test/onnx/parse/globalmaxpool_fp8_test.cpp b/test/onnx/parse/globalmaxpool_fp8_test.cpp new file mode 100644 index 00000000000..a11873970f5 --- /dev/null +++ b/test/onnx/parse/globalmaxpool_fp8_test.cpp @@ -0,0 +1,43 @@ +/* + * The MIT License (MIT) + * + * Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN + * THE SOFTWARE. + */ + +#include +#include + +TEST_CASE(globalmaxpool_fp8_test) +{ + migraphx::program p; + auto* mm = p.get_main_module(); + auto input = + mm->add_parameter("0", migraphx::shape{migraphx::shape::fp8e4m3fnuz_type, {1, 3, 16, 16}}); + auto op = migraphx::op::pooling{migraphx::op::pooling_mode::max}; + auto lens = input->get_shape().lens(); + op.lengths = {lens[2], lens[3]}; + op.padding = {0, 0, 0, 0}; + mm->add_instruction(op, input); + + auto prog = optimize_onnx("globalmaxpool_fp8_test.onnx"); + + EXPECT(p == prog); +} diff --git a/test/onnx/parse/reducemax_fp8_test.cpp b/test/onnx/parse/reducemax_fp8_test.cpp new file mode 100644 index 00000000000..3fb52a87538 --- /dev/null +++ b/test/onnx/parse/reducemax_fp8_test.cpp @@ -0,0 +1,37 @@ +/* + * The MIT License (MIT) + * + * Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN + * THE SOFTWARE. + */ + +#include + +TEST_CASE(reducemax_fp8_test) +{ + migraphx::program p; + auto* mm = p.get_main_module(); + auto l0 = mm->add_parameter("x", migraphx::shape{migraphx::shape::fp8e4m3fnuz_type, {3, 4, 5, 6}}); + auto l1 = mm->add_instruction(migraphx::make_op("reduce_max", {{"axes", {2}}}), l0); + mm->add_instruction(migraphx::make_op("squeeze", {{"axes", {2}}}), l1); + auto prog = optimize_onnx("reducemax_fp8_test.onnx"); + + EXPECT(p == prog); +} diff --git a/test/onnx/parse/reducesum_fp8_test.cpp b/test/onnx/parse/reducesum_fp8_test.cpp new file mode 100644 index 00000000000..ce27e9c7234 --- /dev/null +++ b/test/onnx/parse/reducesum_fp8_test.cpp @@ -0,0 +1,37 @@ +/* + * The MIT License (MIT) + * + * Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN + * THE SOFTWARE. + */ + +#include + +TEST_CASE(reducesum_fp8_test) +{ + migraphx::program p; + auto* mm = p.get_main_module(); + auto l0 = mm->add_parameter("x", migraphx::shape{migraphx::shape::fp8e4m3fnuz_type, {3, 4, 5, 6}}); + auto l1 = mm->add_instruction(migraphx::make_op("reduce_sum", {{"axes", {2}}}), l0); + mm->add_instruction(migraphx::make_op("squeeze", {{"axes", {2}}}), l1); + auto prog = optimize_onnx("reducesum_fp8_test.onnx"); + + EXPECT(p == prog); +} diff --git a/test/onnx/parse/shrink_fp8_test.cpp b/test/onnx/parse/shrink_fp8_test.cpp new file mode 100644 index 00000000000..ed35aac5817 --- /dev/null +++ b/test/onnx/parse/shrink_fp8_test.cpp @@ -0,0 +1,60 @@ +/* + * The MIT License (MIT) + * + * Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN + * THE SOFTWARE. + */ + +#include + +TEST_CASE(shrink_fp8_test) +{ + migraphx::program p; + float bias = 1.5; + float lambd = 1.5; + std::vector lens{3, 3}; + auto* mm = p.get_main_module(); + auto x = mm->add_parameter("x", migraphx::shape{migraphx::shape::fp8e4m3fnuz_type, lens}); + auto lit_bias = mm->add_literal(migraphx::literal{migraphx::shape::float_type, {bias}}); + auto lit_neg_lambd = mm->add_literal(migraphx::literal{migraphx::shape::float_type, {-lambd}}); + auto lit_lambd = mm->add_literal(migraphx::literal{migraphx::shape::float_type, {lambd}}); + + auto x_plus_bias = add_common_op(*mm, migraphx::make_op("add"), {x, lit_bias}); + auto x_min_bias = add_common_op(*mm, migraphx::make_op("sub"), {x, lit_bias}); + + auto cond1 = add_common_op(*mm, migraphx::make_op("less"), {x, lit_neg_lambd}); + auto cond2_a = add_common_op(*mm, migraphx::make_op("not"), {cond1}); + auto cond2_b = add_common_op(*mm, migraphx::make_op("greater"), {x, lit_lambd}); + auto cond2 = add_common_op(*mm, migraphx::make_op("logical_and"), {cond2_a, cond2_b}); + + auto mul1 = mm->add_instruction( + migraphx::make_op("convert", {{"target_type", migraphx::shape::fp8e4m3fnuz_type}}), cond1); + auto mul2 = mm->add_instruction( + migraphx::make_op("convert", {{"target_type", migraphx::shape::fp8e4m3fnuz_type}}), cond2); + + auto first = add_common_op(*mm, migraphx::make_op("mul"), {mul1, x_plus_bias}); + auto second = add_common_op(*mm, migraphx::make_op("mul"), {mul2, x_min_bias}); + auto ret = add_common_op(*mm, migraphx::make_op("add"), {first, second}); + mm->add_instruction( + migraphx::make_op("convert", {{"target_type", migraphx::shape::fp8e4m3fnuz_type}}), ret); + auto prog = optimize_onnx("shrink_fp8_test.onnx"); + + EXPECT(p == prog); +} diff --git a/test/onnx/parse/sin_fp8_test.cpp b/test/onnx/parse/sin_fp8_test.cpp new file mode 100644 index 00000000000..dc006f036f3 --- /dev/null +++ b/test/onnx/parse/sin_fp8_test.cpp @@ -0,0 +1,36 @@ +/* + * The MIT License (MIT) + * + * Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN + * THE SOFTWARE. + */ + +#include + +TEST_CASE(sin_fp8_test) +{ + migraphx::program p; + auto* mm = p.get_main_module(); + auto input = mm->add_parameter("x", migraphx::shape{migraphx::shape::fp8e4m3fnuz_type, {10}}); + mm->add_instruction(migraphx::make_op("sin"), input); + + auto prog = optimize_onnx("sin_fp8_test.onnx"); + EXPECT(p == prog); +} diff --git a/test/onnx/parse/size_fp8_test.cpp b/test/onnx/parse/size_fp8_test.cpp new file mode 100644 index 00000000000..76412b92181 --- /dev/null +++ b/test/onnx/parse/size_fp8_test.cpp @@ -0,0 +1,37 @@ +/* + * The MIT License (MIT) + * + * Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN + * THE SOFTWARE. + */ + +#include + +TEST_CASE(size_fp8_test) +{ + migraphx::program p; + auto* mm = p.get_main_module(); + auto s = migraphx::shape{migraphx::shape::fp8e4m3fnuz_type, {2, 5, 3}}; + mm->add_parameter("x", s); + mm->add_literal(migraphx::literal{migraphx::shape::int64_type, {s.elements()}}); + + auto prog = optimize_onnx("size_fp8_test.onnx"); + EXPECT(p == prog); +} diff --git a/test/onnx/parse/sqrt_fp8_test.cpp b/test/onnx/parse/sqrt_fp8_test.cpp new file mode 100644 index 00000000000..81956c487aa --- /dev/null +++ b/test/onnx/parse/sqrt_fp8_test.cpp @@ -0,0 +1,36 @@ +/* + * The MIT License (MIT) + * + * Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN + * THE SOFTWARE. + */ + +#include + +TEST_CASE(sqrt_fp8_test) +{ + migraphx::program p; + auto* mm = p.get_main_module(); + auto input = mm->add_parameter("x", migraphx::shape{migraphx::shape::fp8e4m3fnuz_type, {10, 15}}); + mm->add_instruction(migraphx::make_op("sqrt"), input); + + auto prog = optimize_onnx("sqrt_fp8_test.onnx"); + EXPECT(p == prog); +} diff --git a/test/onnx/reducemax_fp8_test.onnx b/test/onnx/reducemax_fp8_test.onnx new file mode 100644 index 0000000000000000000000000000000000000000..e5bd763dd134beee85c6b313040e44940f062eff GIT binary patch literal 151 zcmd*O({)I%}uO`Pb;v9FG(#fv8v|M=VGi7Vysl+3U5U|7J&gw2>JX^0>rF9%$6k1#aJmM R#3cZf*O({)IEiTQCPb;v9FG(#fv8w0N=VGi7Vysl+3U5U|7J&gw2>JX^0>rF9%$5X{ti&S8 L2$FPS5)cLeR_q`j literal 0 HcmV?d00001 diff --git a/test/onnx/shrink_fp8_test.onnx b/test/onnx/shrink_fp8_test.onnx new file mode 100644 index 0000000000000000000000000000000000000000..28f5f1a81d0b91db3c1c9ca65d9e1ea122e8069e GIT binary patch literal 133 zcmd +#include +#include +#include + +TEST_CASE(add_fp8_test) +{ + auto p = migraphx::parse_onnx("add_fp8_test.onnx"); + p.compile(migraphx::make_target("ref")); + + migraphx::shape s{migraphx::shape::fp8e4m3fnuz_type, {1}}; + + migraphx::parameter_map pp; + migraphx::literal l1{s, {3.25}}; + migraphx::literal l2{s, {2.25}}; + pp["0"] = l1.get_argument(); + pp["1"] = l2.get_argument(); + + auto result = p.eval(pp).back(); + std::vector result_vector; + result.visit([&](auto output) { result_vector.assign(output.begin(), output.end()); }); + std::vector gold{static_cast(5.5)}; + + EXPECT(migraphx::verify::verify_rms_range(result_vector, gold)); +} diff --git a/test/onnx/verify/gemm_fp8_test.cpp b/test/onnx/verify/gemm_fp8_test.cpp new file mode 100644 index 00000000000..914789c74c2 --- /dev/null +++ b/test/onnx/verify/gemm_fp8_test.cpp @@ -0,0 +1,72 @@ +/* + * The MIT License (MIT) + * + * Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN + * THE SOFTWARE. + */ + +#include +#include +#include +#include + +TEST_CASE(gemm_fp8_test) +{ + migraphx::program p = migraphx::parse_onnx("gemm_fp8_test.onnx"); + p.compile(migraphx::make_target("ref")); + + migraphx::shape a_shape{migraphx::shape::fp8e4m3fnuz_type, {8, 6}}; + std::vector tmp = {0.2646, 0.8525, 0.4192, 0.1415, 0.4321, 0.675, 0.4248, 0.8203, + 0.978, 0.5796, 0.6626, 0.479, 0.924, 0.734, 0.674, 0.8716, + 0.3733, 0.3328, 0.4272, 0.0247, 0.7583, 0.4873, 0.5835, 0.694, + 0.4375, 0.2406, 0.269, 0.6763, 0.542, 0.8994, 0.657, 0.5425, + 0.1412, 0.8994, 0.2183, 0.812, 0.937, 0.3438, 0.712, 0.9033, + 0.266, 0.8013, 0.803, 0.4993, 0.07196, 0.635, 0.7344, 0.3213}; + std::vector a_data{tmp.cbegin(), tmp.cend()}; + + migraphx::shape b_shape{migraphx::shape::fp8e4m3fnuz_type, {8, 7}}; + tmp = {0.7095, 0.612, 0.741, 0.02121, 0.3872, 0.4482, 0.6235, 0.02249, 0.2332, 0.7656, + 0.8955, 0.8154, 0.2239, 0.9277, 0.4622, 0.708, 0.566, 0.0736, 0.138, 0.8574, + 0.4055, 0.382, 0.6206, 0.424, 0.3674, 0.435, 0.998, 0.3594, 0.701, 0.6216, + 0.01826, 0.6313, 0.514, 0.1095, 0.3203, 0.01636, 0.537, 0.01952, 0.4502, 0.8965, + 0.5415, 0.7456, 0.793, 0.756, 0.9, 0.5264, 0.05368, 0.4214, 0.276, 0.1517, + 0.08453, 0.83, 0.417, 0.1682, 0.845, 0.1729}; + std::vector b_data{tmp.cbegin(), tmp.cend()}; + + migraphx::shape c_shape{migraphx::shape::fp8e4m3fnuz_type, {6, 1}}; + tmp = {0.10846, 0.672, 0.527, 0.94, 0.429, 0.2291}; + std::vector c_data{tmp.cbegin(), tmp.cend()}; + + migraphx::parameter_map params; + params["A"] = migraphx::argument(a_shape, a_data.data()); + params["B"] = migraphx::argument(b_shape, b_data.data()); + params["C"] = migraphx::argument(c_shape, c_data.data()); + + auto result = p.eval(params).back(); + std::vector result_vector; + result.visit([&](auto output) { result_vector.assign(output.begin(), output.end()); }); + + tmp = {1.071, 1.378, 1.465, 1.093, 0.968, 1.542, 1.145, 1.287, 1.533, 1.75, 1.338, + 1.449, 1.592, 1.668, 1.265, 1.531, 1.656, 1.348, 1.2705, 1.525, 1.479, 1.754, + 2.143, 2.062, 1.921, 1.836, 2.203, 1.952, 1.055, 1.225, 1.418, 1.209, 1.155, + 1.42, 1.234, 1.302, 1.593, 1.368, 1.289, 1.327, 1.451, 1.394}; + std::vector gold{tmp.cbegin(), tmp.cend()}; + EXPECT(migraphx::verify::verify_rms_range(result_vector, gold)); +} diff --git a/test/onnx/verify/shrink_fp8_test.cpp b/test/onnx/verify/shrink_fp8_test.cpp new file mode 100644 index 00000000000..14d53be0add --- /dev/null +++ b/test/onnx/verify/shrink_fp8_test.cpp @@ -0,0 +1,50 @@ +/* + * The MIT License (MIT) + * + * Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN + * THE SOFTWARE. + */ + +#include +#include +#include +#include + +TEST_CASE(shrink_fp8_test) +{ + migraphx::program p = migraphx::parse_onnx("shrink_fp8_test.onnx"); + p.compile(migraphx::make_target("ref")); + + migraphx::shape s{migraphx::shape::fp8e4m3fnuz_type, {3, 3}}; + // TODO: Make FP8 vector work for initializer list. + std::vector tmp_data{-4, -3, -2, -1, 0, 1, 2, 3, 4}; + std::vector data{tmp_data.cbegin(), tmp_data.cend()}; + + migraphx::parameter_map pp; + pp["x"] = migraphx::argument(s, data.data()); + + auto result = p.eval(pp).back(); + std::vector result_vector; + result.visit([&](auto output) { result_vector.assign(output.begin(), output.end()); }); + // TODO: Make FP8 vector work for initializer list. + std::vector tmp_gold = {-2, -1, 0, 0, 0, 0, 0, 1, 2}; + std::vector gold{tmp_gold.cbegin(), tmp_gold.cend()}; + EXPECT(migraphx::verify::verify_rms_range(result_vector, gold)); +} diff --git a/test/onnx/verify/size_fp8_test.cpp b/test/onnx/verify/size_fp8_test.cpp new file mode 100644 index 00000000000..ef4567ec36b --- /dev/null +++ b/test/onnx/verify/size_fp8_test.cpp @@ -0,0 +1,43 @@ +/* + * The MIT License (MIT) + * + * Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN + * THE SOFTWARE. + */ + +#include +#include +#include +#include + +TEST_CASE(size_fp8_test) +{ + migraphx::program p = migraphx::parse_onnx("size_fp8_test.onnx"); + p.compile(migraphx::make_target("ref")); + + migraphx::shape s{migraphx::shape::fp8e4m3fnuz_type, {2, 5, 3}}; + std::vector data(30, 1.); + migraphx::parameter_map pp; + pp["x"] = migraphx::argument(s, data.data()); + + auto result = p.eval(pp).back(); + auto size_result = result.at(); + EXPECT(size_result == int64_t{30}); +} From 4eac0770ad5ea46e1700e74847f6388e6dcf20e0 Mon Sep 17 00:00:00 2001 From: Ahsan Saghir Date: Mon, 3 Jun 2024 12:51:13 +0000 Subject: [PATCH 02/11] Fix CI failures --- test/onnx/add_fp8_test.onnx | 16 ------- test/onnx/binary_dyn_brcst_mul_fp8_test.onnx | Bin 165 -> 0 bytes test/onnx/conv_1d_fp8_test.onnx | 19 -------- test/onnx/cos_fp8_test.onnx | 13 ------ test/onnx/div_fp8_test.onnx | 16 ------- test/onnx/gemm_fp8_test.onnx | Bin 191 -> 0 bytes test/onnx/gen_onnx.py | 42 ++++++++++++------ test/onnx/globalavgpool_fp8_test.onnx | 15 ------- test/onnx/globalmaxpool_fp8_test.onnx | 15 ------- .../parse/binary_dyn_brcst_mul_fp8_test.cpp | 2 +- test/onnx/parse/reducemax_fp8_test.cpp | 5 ++- test/onnx/parse/reducesum_fp8_test.cpp | 5 ++- test/onnx/parse/sqrt_fp8_test.cpp | 5 ++- test/onnx/reducemax_fp8_test.onnx | Bin 151 -> 0 bytes test/onnx/reducesum_fp8_test.onnx | Bin 155 -> 0 bytes test/onnx/shrink_fp8_test.onnx | Bin 133 -> 0 bytes test/onnx/sin_fp8_test.onnx | 13 ------ test/onnx/size_fp8_test.onnx | 12 ----- test/onnx/sqrt_fp8_test.onnx | 13 ------ 19 files changed, 38 insertions(+), 153 deletions(-) delete mode 100644 test/onnx/add_fp8_test.onnx delete mode 100644 test/onnx/binary_dyn_brcst_mul_fp8_test.onnx delete mode 100644 test/onnx/conv_1d_fp8_test.onnx delete mode 100644 test/onnx/cos_fp8_test.onnx delete mode 100644 test/onnx/div_fp8_test.onnx delete mode 100644 test/onnx/gemm_fp8_test.onnx delete mode 100644 test/onnx/globalavgpool_fp8_test.onnx delete mode 100644 test/onnx/globalmaxpool_fp8_test.onnx delete mode 100644 test/onnx/reducemax_fp8_test.onnx delete mode 100644 test/onnx/reducesum_fp8_test.onnx delete mode 100644 test/onnx/shrink_fp8_test.onnx delete mode 100644 test/onnx/sin_fp8_test.onnx delete mode 100644 test/onnx/size_fp8_test.onnx delete mode 100644 test/onnx/sqrt_fp8_test.onnx diff --git a/test/onnx/add_fp8_test.onnx b/test/onnx/add_fp8_test.onnx deleted file mode 100644 index 734d503c79f..00000000000 --- a/test/onnx/add_fp8_test.onnx +++ /dev/null @@ -1,16 +0,0 @@ -  add_fp8_test:Q - -0 -12"Add add_fp8_testZ -0 - - -Z -1 - - -b -2 - - -B \ No newline at end of file diff --git a/test/onnx/binary_dyn_brcst_mul_fp8_test.onnx b/test/onnx/binary_dyn_brcst_mul_fp8_test.onnx deleted file mode 100644 index 80178a7ed4e523373a22753d8cbb537d1f43ecba..0000000000000000000000000000000000000000 GIT binary patch literal 0 HcmV?d00001 literal 165 zcmdwebi;067h*^M`HA)y{5FZy0P>KU8$e1L}1u`1D4kso7 FVE|#YA{GDu diff --git a/test/onnx/conv_1d_fp8_test.onnx b/test/onnx/conv_1d_fp8_test.onnx deleted file mode 100644 index a6b910da71d..00000000000 --- a/test/onnx/conv_1d_fp8_test.onnx +++ /dev/null @@ -1,19 +0,0 @@ - conv_1d_fp8_test:n - -0 -12"Convconv_1d_fp8_testZ -0 - - - -Z -1 - - - -b -2 - - - -B \ No newline at end of file diff --git a/test/onnx/cos_fp8_test.onnx b/test/onnx/cos_fp8_test.onnx deleted file mode 100644 index d494bc558f5..00000000000 --- a/test/onnx/cos_fp8_test.onnx +++ /dev/null @@ -1,13 +0,0 @@ -  cos_fp8_test:= - -xy"Cos cos_fp8_testZ -x - - - -b -y - - - -B \ No newline at end of file diff --git a/test/onnx/div_fp8_test.onnx b/test/onnx/div_fp8_test.onnx deleted file mode 100644 index 4eb4923a296..00000000000 --- a/test/onnx/div_fp8_test.onnx +++ /dev/null @@ -1,16 +0,0 @@ -  div_fp8_test:a - -0 -1out"Div div_fp8_testZ -0 -  - -Z -1 -  - -b -out -  - -B \ No newline at end of file diff --git a/test/onnx/gemm_fp8_test.onnx b/test/onnx/gemm_fp8_test.onnx deleted file mode 100644 index effebfeaa370893a80bfb57fa751750970a69745..0000000000000000000000000000000000000000 GIT binary patch literal 0 HcmV?d00001 literal 191 zcmdwehLr92&i-`k- f*+7C$XoBn@L1!dEHlRwzB%ojadd_parameter("x", migraphx::shape{migraphx::shape::fp8e4m3fnuz_type, {3, 4, 5, 6}}); - auto l1 = mm->add_instruction(migraphx::make_op("reduce_max", {{"axes", {2}}}), l0); + auto l0 = + mm->add_parameter("x", migraphx::shape{migraphx::shape::fp8e4m3fnuz_type, {3, 4, 5, 6}}); + auto l1 = mm->add_instruction(migraphx::make_op("reduce_max", {{"axes", {2}}}), l0); mm->add_instruction(migraphx::make_op("squeeze", {{"axes", {2}}}), l1); auto prog = optimize_onnx("reducemax_fp8_test.onnx"); diff --git a/test/onnx/parse/reducesum_fp8_test.cpp b/test/onnx/parse/reducesum_fp8_test.cpp index ce27e9c7234..b446d8e78af 100644 --- a/test/onnx/parse/reducesum_fp8_test.cpp +++ b/test/onnx/parse/reducesum_fp8_test.cpp @@ -28,8 +28,9 @@ TEST_CASE(reducesum_fp8_test) { migraphx::program p; auto* mm = p.get_main_module(); - auto l0 = mm->add_parameter("x", migraphx::shape{migraphx::shape::fp8e4m3fnuz_type, {3, 4, 5, 6}}); - auto l1 = mm->add_instruction(migraphx::make_op("reduce_sum", {{"axes", {2}}}), l0); + auto l0 = + mm->add_parameter("x", migraphx::shape{migraphx::shape::fp8e4m3fnuz_type, {3, 4, 5, 6}}); + auto l1 = mm->add_instruction(migraphx::make_op("reduce_sum", {{"axes", {2}}}), l0); mm->add_instruction(migraphx::make_op("squeeze", {{"axes", {2}}}), l1); auto prog = optimize_onnx("reducesum_fp8_test.onnx"); diff --git a/test/onnx/parse/sqrt_fp8_test.cpp b/test/onnx/parse/sqrt_fp8_test.cpp index 81956c487aa..4345f50cb75 100644 --- a/test/onnx/parse/sqrt_fp8_test.cpp +++ b/test/onnx/parse/sqrt_fp8_test.cpp @@ -27,8 +27,9 @@ TEST_CASE(sqrt_fp8_test) { migraphx::program p; - auto* mm = p.get_main_module(); - auto input = mm->add_parameter("x", migraphx::shape{migraphx::shape::fp8e4m3fnuz_type, {10, 15}}); + auto* mm = p.get_main_module(); + auto input = + mm->add_parameter("x", migraphx::shape{migraphx::shape::fp8e4m3fnuz_type, {10, 15}}); mm->add_instruction(migraphx::make_op("sqrt"), input); auto prog = optimize_onnx("sqrt_fp8_test.onnx"); diff --git a/test/onnx/reducemax_fp8_test.onnx b/test/onnx/reducemax_fp8_test.onnx deleted file mode 100644 index e5bd763dd134beee85c6b313040e44940f062eff..0000000000000000000000000000000000000000 GIT binary patch literal 0 HcmV?d00001 literal 151 zcmd*O({)I%}uO`Pb;v9FG(#fv8v|M=VGi7Vysl+3U5U|7J&gw2>JX^0>rF9%$6k1#aJmM R#3cZf*O({)IEiTQCPb;v9FG(#fv8w0N=VGi7Vysl+3U5U|7J&gw2>JX^0>rF9%$5X{ti&S8 L2$FPS5)cLeR_q`j diff --git a/test/onnx/shrink_fp8_test.onnx b/test/onnx/shrink_fp8_test.onnx deleted file mode 100644 index 28f5f1a81d0b91db3c1c9ca65d9e1ea122e8069e..0000000000000000000000000000000000000000 GIT binary patch literal 0 HcmV?d00001 literal 133 zcmd Date: Mon, 3 Jun 2024 14:41:47 +0000 Subject: [PATCH 03/11] Regenerate onnx files --- test/onnx/add_fp8_test.onnx | 16 ++++++++++++++++ test/onnx/binary_dyn_brcst_mul_fp8_test.onnx | Bin 0 -> 165 bytes test/onnx/conv_1d_fp8_test.onnx | 19 +++++++++++++++++++ test/onnx/cos_fp8_test.onnx | 13 +++++++++++++ test/onnx/div_fp8_test.onnx | 16 ++++++++++++++++ test/onnx/gemm_fp8_test.onnx | Bin 0 -> 191 bytes test/onnx/globalavgpool_fp8_test.onnx | 15 +++++++++++++++ test/onnx/globalmaxpool_fp8_test.onnx | 15 +++++++++++++++ test/onnx/reducemax_fp8_test.onnx | Bin 0 -> 151 bytes test/onnx/reducesum_fp8_test.onnx | Bin 0 -> 155 bytes test/onnx/shrink_fp8_test.onnx | Bin 0 -> 133 bytes test/onnx/sin_fp8_test.onnx | 13 +++++++++++++ test/onnx/size_fp8_test.onnx | 12 ++++++++++++ test/onnx/sqrt_fp8_test.onnx | 13 +++++++++++++ 14 files changed, 132 insertions(+) create mode 100644 test/onnx/add_fp8_test.onnx create mode 100644 test/onnx/binary_dyn_brcst_mul_fp8_test.onnx create mode 100644 test/onnx/conv_1d_fp8_test.onnx create mode 100644 test/onnx/cos_fp8_test.onnx create mode 100644 test/onnx/div_fp8_test.onnx create mode 100644 test/onnx/gemm_fp8_test.onnx create mode 100644 test/onnx/globalavgpool_fp8_test.onnx create mode 100644 test/onnx/globalmaxpool_fp8_test.onnx create mode 100644 test/onnx/reducemax_fp8_test.onnx create mode 100644 test/onnx/reducesum_fp8_test.onnx create mode 100644 test/onnx/shrink_fp8_test.onnx create mode 100644 test/onnx/sin_fp8_test.onnx create mode 100644 test/onnx/size_fp8_test.onnx create mode 100644 test/onnx/sqrt_fp8_test.onnx diff --git a/test/onnx/add_fp8_test.onnx b/test/onnx/add_fp8_test.onnx new file mode 100644 index 00000000000..734d503c79f --- /dev/null +++ b/test/onnx/add_fp8_test.onnx @@ -0,0 +1,16 @@ +  add_fp8_test:Q + +0 +12"Add add_fp8_testZ +0 + + +Z +1 + + +b +2 + + +B \ No newline at end of file diff --git a/test/onnx/binary_dyn_brcst_mul_fp8_test.onnx b/test/onnx/binary_dyn_brcst_mul_fp8_test.onnx new file mode 100644 index 0000000000000000000000000000000000000000..80178a7ed4e523373a22753d8cbb537d1f43ecba GIT binary patch literal 165 zcmdwebi;067h*^M`HA)y{5FZy0P>KU8$e1L}1u`1D4kso7 FVE|#YA{GDu literal 0 HcmV?d00001 diff --git a/test/onnx/conv_1d_fp8_test.onnx b/test/onnx/conv_1d_fp8_test.onnx new file mode 100644 index 00000000000..a6b910da71d --- /dev/null +++ b/test/onnx/conv_1d_fp8_test.onnx @@ -0,0 +1,19 @@ + conv_1d_fp8_test:n + +0 +12"Convconv_1d_fp8_testZ +0 + + + +Z +1 + + + +b +2 + + + +B \ No newline at end of file diff --git a/test/onnx/cos_fp8_test.onnx b/test/onnx/cos_fp8_test.onnx new file mode 100644 index 00000000000..d494bc558f5 --- /dev/null +++ b/test/onnx/cos_fp8_test.onnx @@ -0,0 +1,13 @@ +  cos_fp8_test:= + +xy"Cos cos_fp8_testZ +x + + + +b +y + + + +B \ No newline at end of file diff --git a/test/onnx/div_fp8_test.onnx b/test/onnx/div_fp8_test.onnx new file mode 100644 index 00000000000..4eb4923a296 --- /dev/null +++ b/test/onnx/div_fp8_test.onnx @@ -0,0 +1,16 @@ +  div_fp8_test:a + +0 +1out"Div div_fp8_testZ +0 +  + +Z +1 +  + +b +out +  + +B \ No newline at end of file diff --git a/test/onnx/gemm_fp8_test.onnx b/test/onnx/gemm_fp8_test.onnx new file mode 100644 index 0000000000000000000000000000000000000000..effebfeaa370893a80bfb57fa751750970a69745 GIT binary patch literal 191 zcmdwehLr92&i-`k- f*+7C$XoBn@L1!dEHlRwzB%oj*O({)I%}uO`Pb;v9FG(#fv8v|M=VGi7Vysl+3U5U|7J&gw2>JX^0>rF9%$6k1#aJmM R#3cZf*O({)IEiTQCPb;v9FG(#fv8w0N=VGi7Vysl+3U5U|7J&gw2>JX^0>rF9%$5X{ti&S8 L2$FPS5)cLeR_q`j literal 0 HcmV?d00001 diff --git a/test/onnx/shrink_fp8_test.onnx b/test/onnx/shrink_fp8_test.onnx new file mode 100644 index 0000000000000000000000000000000000000000..28f5f1a81d0b91db3c1c9ca65d9e1ea122e8069e GIT binary patch literal 133 zcmd Date: Mon, 3 Jun 2024 15:37:27 +0000 Subject: [PATCH 04/11] change parse_onnx to optimize_onnx --- test/onnx/parse/binary_dyn_brcst_mul_fp8_test.cpp | 2 +- test/onnx/parse/gemm_fp8_test.cpp | 9 +++++---- test/onnx/verify/add_fp8_test.cpp | 2 +- test/onnx/verify/gemm_fp8_test.cpp | 2 +- test/onnx/verify/shrink_fp8_test.cpp | 2 +- test/onnx/verify/size_fp8_test.cpp | 2 +- 6 files changed, 10 insertions(+), 9 deletions(-) diff --git a/test/onnx/parse/binary_dyn_brcst_mul_fp8_test.cpp b/test/onnx/parse/binary_dyn_brcst_mul_fp8_test.cpp index 2f900b83060..3c31c86184e 100644 --- a/test/onnx/parse/binary_dyn_brcst_mul_fp8_test.cpp +++ b/test/onnx/parse/binary_dyn_brcst_mul_fp8_test.cpp @@ -47,7 +47,7 @@ TEST_CASE(binary_dyn_brcst_mul_fp8_test) migraphx::onnx_options options; options.default_dyn_dim_value = {1, 4}; - auto prog = migraphx::parse_onnx("binary_dyn_brcst_mul_fp8_test.onnx", options); + auto prog = read_onnx("binary_dyn_brcst_mul_fp8_test.onnx", options); EXPECT(p == prog); } diff --git a/test/onnx/parse/gemm_fp8_test.cpp b/test/onnx/parse/gemm_fp8_test.cpp index e01d0143aa7..2fb495128f8 100644 --- a/test/onnx/parse/gemm_fp8_test.cpp +++ b/test/onnx/parse/gemm_fp8_test.cpp @@ -28,10 +28,11 @@ TEST_CASE(gemm_fp8_test) { migraphx::program p; - auto* mm = p.get_main_module(); - auto l0 = mm->add_parameter("A", migraphx::shape{migraphx::shape::fp8e4m3fnuz_type, {8, 6}}); - auto l1 = mm->add_parameter("B", migraphx::shape{migraphx::shape::fp8e4m3fnuz_type, {8, 7}}); - auto l2 = mm->add_parameter("C", migraphx::shape{migraphx::shape::fp8e4m3fnuz_type, {6, 1}}); + auto* mm = p.get_main_module(); + auto l0 = mm->add_parameter("A", migraphx::shape{migraphx::shape::fp8e4m3fnuz_type, {8, 6}}); + auto l1 = mm->add_parameter("B", migraphx::shape{migraphx::shape::fp8e4m3fnuz_type, {8, 7}}); + auto l2 = mm->add_parameter("C", migraphx::shape{migraphx::shape::fp8e4m3fnuz_type, {6, 1}}); + auto alpha = 0.5f; auto beta = 0.8f; auto a_l = mm->add_literal(alpha); diff --git a/test/onnx/verify/add_fp8_test.cpp b/test/onnx/verify/add_fp8_test.cpp index 64f9bffcefa..0ef99140c44 100644 --- a/test/onnx/verify/add_fp8_test.cpp +++ b/test/onnx/verify/add_fp8_test.cpp @@ -29,7 +29,7 @@ TEST_CASE(add_fp8_test) { - auto p = migraphx::parse_onnx("add_fp8_test.onnx"); + auto p = optimize_onnx("add_fp8_test.onnx"); p.compile(migraphx::make_target("ref")); migraphx::shape s{migraphx::shape::fp8e4m3fnuz_type, {1}}; diff --git a/test/onnx/verify/gemm_fp8_test.cpp b/test/onnx/verify/gemm_fp8_test.cpp index 914789c74c2..b5aa56d99f0 100644 --- a/test/onnx/verify/gemm_fp8_test.cpp +++ b/test/onnx/verify/gemm_fp8_test.cpp @@ -29,7 +29,7 @@ TEST_CASE(gemm_fp8_test) { - migraphx::program p = migraphx::parse_onnx("gemm_fp8_test.onnx"); + migraphx::program p = optimize_onnx("gemm_fp8_test.onnx"); p.compile(migraphx::make_target("ref")); migraphx::shape a_shape{migraphx::shape::fp8e4m3fnuz_type, {8, 6}}; diff --git a/test/onnx/verify/shrink_fp8_test.cpp b/test/onnx/verify/shrink_fp8_test.cpp index 14d53be0add..b1f541e415b 100644 --- a/test/onnx/verify/shrink_fp8_test.cpp +++ b/test/onnx/verify/shrink_fp8_test.cpp @@ -29,7 +29,7 @@ TEST_CASE(shrink_fp8_test) { - migraphx::program p = migraphx::parse_onnx("shrink_fp8_test.onnx"); + migraphx::program p = optimize_onnx("shrink_fp8_test.onnx"); p.compile(migraphx::make_target("ref")); migraphx::shape s{migraphx::shape::fp8e4m3fnuz_type, {3, 3}}; diff --git a/test/onnx/verify/size_fp8_test.cpp b/test/onnx/verify/size_fp8_test.cpp index ef4567ec36b..dbb97ed1b56 100644 --- a/test/onnx/verify/size_fp8_test.cpp +++ b/test/onnx/verify/size_fp8_test.cpp @@ -29,7 +29,7 @@ TEST_CASE(size_fp8_test) { - migraphx::program p = migraphx::parse_onnx("size_fp8_test.onnx"); + migraphx::program p = optimize_onnx("size_fp8_test.onnx"); p.compile(migraphx::make_target("ref")); migraphx::shape s{migraphx::shape::fp8e4m3fnuz_type, {2, 5, 3}}; From da91dbf2f811e09d8b994e557423550d9720d6c6 Mon Sep 17 00:00:00 2001 From: Ahsan Saghir Date: Wed, 5 Jun 2024 13:27:01 +0000 Subject: [PATCH 05/11] Modify shrink fp8 test --- test/onnx/parse/shrink_fp8_test.cpp | 9 ++------- 1 file changed, 2 insertions(+), 7 deletions(-) diff --git a/test/onnx/parse/shrink_fp8_test.cpp b/test/onnx/parse/shrink_fp8_test.cpp index ed35aac5817..76f51601805 100644 --- a/test/onnx/parse/shrink_fp8_test.cpp +++ b/test/onnx/parse/shrink_fp8_test.cpp @@ -44,13 +44,8 @@ TEST_CASE(shrink_fp8_test) auto cond2_b = add_common_op(*mm, migraphx::make_op("greater"), {x, lit_lambd}); auto cond2 = add_common_op(*mm, migraphx::make_op("logical_and"), {cond2_a, cond2_b}); - auto mul1 = mm->add_instruction( - migraphx::make_op("convert", {{"target_type", migraphx::shape::fp8e4m3fnuz_type}}), cond1); - auto mul2 = mm->add_instruction( - migraphx::make_op("convert", {{"target_type", migraphx::shape::fp8e4m3fnuz_type}}), cond2); - - auto first = add_common_op(*mm, migraphx::make_op("mul"), {mul1, x_plus_bias}); - auto second = add_common_op(*mm, migraphx::make_op("mul"), {mul2, x_min_bias}); + auto first = add_common_op(*mm, migraphx::make_op("mul"), {cond1, x_plus_bias}); + auto second = add_common_op(*mm, migraphx::make_op("mul"), {cond2, x_min_bias}); auto ret = add_common_op(*mm, migraphx::make_op("add"), {first, second}); mm->add_instruction( migraphx::make_op("convert", {{"target_type", migraphx::shape::fp8e4m3fnuz_type}}), ret); From 49f451d4d43a5f8cf005d18384aa0b89594b45c3 Mon Sep 17 00:00:00 2001 From: Ahsan Saghir Date: Wed, 5 Jun 2024 14:12:20 +0000 Subject: [PATCH 06/11] Change optimize_onnx to read_onnx --- test/onnx/verify/size_fp8_test.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/onnx/verify/size_fp8_test.cpp b/test/onnx/verify/size_fp8_test.cpp index dbb97ed1b56..7ae5d1d1647 100644 --- a/test/onnx/verify/size_fp8_test.cpp +++ b/test/onnx/verify/size_fp8_test.cpp @@ -29,7 +29,7 @@ TEST_CASE(size_fp8_test) { - migraphx::program p = optimize_onnx("size_fp8_test.onnx"); + migraphx::program p = read_onnx("size_fp8_test.onnx"); p.compile(migraphx::make_target("ref")); migraphx::shape s{migraphx::shape::fp8e4m3fnuz_type, {2, 5, 3}}; From 1927b3ec781204eaffe8ef058c8011809876d0ff Mon Sep 17 00:00:00 2001 From: Ahsan Saghir Date: Thu, 6 Jun 2024 08:29:36 +0000 Subject: [PATCH 07/11] Update gemm_fp8_test --- test/onnx/parse/gemm_fp8_test.cpp | 22 ++++++---------------- 1 file changed, 6 insertions(+), 16 deletions(-) diff --git a/test/onnx/parse/gemm_fp8_test.cpp b/test/onnx/parse/gemm_fp8_test.cpp index 2fb495128f8..6c4ac330f54 100644 --- a/test/onnx/parse/gemm_fp8_test.cpp +++ b/test/onnx/parse/gemm_fp8_test.cpp @@ -35,23 +35,13 @@ TEST_CASE(gemm_fp8_test) auto alpha = 0.5f; auto beta = 0.8f; - auto a_l = mm->add_literal(alpha); - auto t_a = add_common_op(*mm, migraphx::make_op("mul"), {a_l, l0}); - t_a = mm->add_instruction( - migraphx::make_op("convert", {{"target_type", migraphx::shape::fp8e4m3fnuz_type}}), t_a); - t_a = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {1, 0}}}), t_a); - std::vector lens = {6, 7}; - auto dot = migraphx::add_apply_alpha_beta(*mm, {t_a, l1}, migraphx::make_op("dot"), 1.0f, 0.0f); - l2 = mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", lens}}), l2); - l2 = mm->add_instruction( - migraphx::make_op("convert", {{"target_type", migraphx::shape::float_type}}), l2); - auto b_l = mm->add_literal(beta); - auto b_b = mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", lens}}), b_l); - auto l2_b = mm->add_instruction(migraphx::make_op("mul"), l2, b_b); - l2_b = mm->add_instruction( - migraphx::make_op("convert", {{"target_type", migraphx::shape::fp8e4m3fnuz_type}}), l2_b); - mm->add_instruction(migraphx::make_op("add"), dot, l2_b); + + auto l0_transposed = + mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {1, 0}}}), l0); + + add_apply_alpha_beta(*mm, {l0_transposed, l1, l2}, migraphx::make_op("dot"), alpha, beta); auto prog = optimize_onnx("gemm_fp8_test.onnx"); + EXPECT(p == prog); } From c245c5b4fc7451d62ef0761aef2bb8fb1a29da54 Mon Sep 17 00:00:00 2001 From: Ahsan Saghir Date: Thu, 6 Jun 2024 19:57:15 +0000 Subject: [PATCH 08/11] Revert "Update gemm_fp8_test" This reverts commit 1927b3ec781204eaffe8ef058c8011809876d0ff. --- test/onnx/parse/gemm_fp8_test.cpp | 22 ++++++++++++++++------ 1 file changed, 16 insertions(+), 6 deletions(-) diff --git a/test/onnx/parse/gemm_fp8_test.cpp b/test/onnx/parse/gemm_fp8_test.cpp index 6c4ac330f54..2fb495128f8 100644 --- a/test/onnx/parse/gemm_fp8_test.cpp +++ b/test/onnx/parse/gemm_fp8_test.cpp @@ -35,13 +35,23 @@ TEST_CASE(gemm_fp8_test) auto alpha = 0.5f; auto beta = 0.8f; - - auto l0_transposed = - mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {1, 0}}}), l0); - - add_apply_alpha_beta(*mm, {l0_transposed, l1, l2}, migraphx::make_op("dot"), alpha, beta); + auto a_l = mm->add_literal(alpha); + auto t_a = add_common_op(*mm, migraphx::make_op("mul"), {a_l, l0}); + t_a = mm->add_instruction( + migraphx::make_op("convert", {{"target_type", migraphx::shape::fp8e4m3fnuz_type}}), t_a); + t_a = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {1, 0}}}), t_a); + std::vector lens = {6, 7}; + auto dot = migraphx::add_apply_alpha_beta(*mm, {t_a, l1}, migraphx::make_op("dot"), 1.0f, 0.0f); + l2 = mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", lens}}), l2); + l2 = mm->add_instruction( + migraphx::make_op("convert", {{"target_type", migraphx::shape::float_type}}), l2); + auto b_l = mm->add_literal(beta); + auto b_b = mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", lens}}), b_l); + auto l2_b = mm->add_instruction(migraphx::make_op("mul"), l2, b_b); + l2_b = mm->add_instruction( + migraphx::make_op("convert", {{"target_type", migraphx::shape::fp8e4m3fnuz_type}}), l2_b); + mm->add_instruction(migraphx::make_op("add"), dot, l2_b); auto prog = optimize_onnx("gemm_fp8_test.onnx"); - EXPECT(p == prog); } From d23315ef087d07b4fc092671ccb77174a4699087 Mon Sep 17 00:00:00 2001 From: Ahsan Saghir Date: Thu, 6 Jun 2024 20:07:15 +0000 Subject: [PATCH 09/11] replace add_apply_alpha_beta with dot instruction --- test/onnx/parse/gemm_fp8_test.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/onnx/parse/gemm_fp8_test.cpp b/test/onnx/parse/gemm_fp8_test.cpp index 2fb495128f8..3d9e373e083 100644 --- a/test/onnx/parse/gemm_fp8_test.cpp +++ b/test/onnx/parse/gemm_fp8_test.cpp @@ -41,7 +41,7 @@ TEST_CASE(gemm_fp8_test) migraphx::make_op("convert", {{"target_type", migraphx::shape::fp8e4m3fnuz_type}}), t_a); t_a = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {1, 0}}}), t_a); std::vector lens = {6, 7}; - auto dot = migraphx::add_apply_alpha_beta(*mm, {t_a, l1}, migraphx::make_op("dot"), 1.0f, 0.0f); + auto dot = mm->add_instruction(migraphx::make_op("dot"), t_a, l1); l2 = mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", lens}}), l2); l2 = mm->add_instruction( migraphx::make_op("convert", {{"target_type", migraphx::shape::float_type}}), l2); From 20fac32d4c7bb4c2a218d8ab23b07c24afe9f2ec Mon Sep 17 00:00:00 2001 From: Ahsan Saghir Date: Thu, 6 Jun 2024 20:35:57 +0000 Subject: [PATCH 10/11] Fix formatting --- test/onnx/parse/gemm_fp8_test.cpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/test/onnx/parse/gemm_fp8_test.cpp b/test/onnx/parse/gemm_fp8_test.cpp index 3d9e373e083..111d2183965 100644 --- a/test/onnx/parse/gemm_fp8_test.cpp +++ b/test/onnx/parse/gemm_fp8_test.cpp @@ -42,8 +42,8 @@ TEST_CASE(gemm_fp8_test) t_a = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {1, 0}}}), t_a); std::vector lens = {6, 7}; auto dot = mm->add_instruction(migraphx::make_op("dot"), t_a, l1); - l2 = mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", lens}}), l2); - l2 = mm->add_instruction( + l2 = mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", lens}}), l2); + l2 = mm->add_instruction( migraphx::make_op("convert", {{"target_type", migraphx::shape::float_type}}), l2); auto b_l = mm->add_literal(beta); auto b_b = mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", lens}}), b_l); From 14759a30d4bfc27163aa832b2d5dc788810a7f4e Mon Sep 17 00:00:00 2001 From: Ahsan Saghir Date: Tue, 11 Jun 2024 21:04:24 +0000 Subject: [PATCH 11/11] Update gemm_fp8_test test to address comments to add add_common_op --- test/onnx/parse/gemm_fp8_test.cpp | 7 ++----- 1 file changed, 2 insertions(+), 5 deletions(-) diff --git a/test/onnx/parse/gemm_fp8_test.cpp b/test/onnx/parse/gemm_fp8_test.cpp index 111d2183965..54554d54420 100644 --- a/test/onnx/parse/gemm_fp8_test.cpp +++ b/test/onnx/parse/gemm_fp8_test.cpp @@ -42,12 +42,9 @@ TEST_CASE(gemm_fp8_test) t_a = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {1, 0}}}), t_a); std::vector lens = {6, 7}; auto dot = mm->add_instruction(migraphx::make_op("dot"), t_a, l1); - l2 = mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", lens}}), l2); - l2 = mm->add_instruction( - migraphx::make_op("convert", {{"target_type", migraphx::shape::float_type}}), l2); + l2 = mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", lens}}), l2); auto b_l = mm->add_literal(beta); - auto b_b = mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", lens}}), b_l); - auto l2_b = mm->add_instruction(migraphx::make_op("mul"), l2, b_b); + auto l2_b = add_common_op(*mm, migraphx::make_op("mul"), {l2, b_l}); l2_b = mm->add_instruction( migraphx::make_op("convert", {{"target_type", migraphx::shape::fp8e4m3fnuz_type}}), l2_b); mm->add_instruction(migraphx::make_op("add"), dot, l2_b);