Skip to content

Commit

Permalink
FP8 Onnx tests
Browse files Browse the repository at this point in the history
  • Loading branch information
ahsan-ca committed May 6, 2024
1 parent 2bdd02d commit 44ad3d5
Show file tree
Hide file tree
Showing 15 changed files with 577 additions and 0 deletions.
16 changes: 16 additions & 0 deletions test/onnx/add_fp8_test.onnx
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
  add_fp8_test:Q

0
12"Add add_fp8_testZ
0


Z
1


b
2


B
19 changes: 19 additions & 0 deletions test/onnx/conv_1d_fp8_test.onnx
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
 conv_1d_fp8_test:n

0
12"Convconv_1d_fp8_testZ
0



Z
1



b
2



B
Binary file added test/onnx/gemm_fp8_test.onnx
Binary file not shown.
73 changes: 73 additions & 0 deletions test/onnx/gen_onnx.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,6 +124,24 @@ def add_fp16_test():
])


@onnx_test()
def add_fp8_test():
x = helper.make_tensor_value_info('0', TensorProto.FLOAT8E4M3FNUZ, [1])
y = helper.make_tensor_value_info('1', TensorProto.FLOAT8E4M3FNUZ, [1])
z = helper.make_tensor_value_info('2', TensorProto.FLOAT8E4M3FNUZ, [1])

node = onnx.helper.make_node(
'Add',
inputs=['0', '1'],
outputs=['2'],
)

return (
[node],
[x, y],
[z])


@onnx_test()
def add_scalar_test():
x = helper.make_tensor_value_info('0', TensorProto.UINT8, [2, 3, 4, 5])
Expand Down Expand Up @@ -1252,6 +1270,17 @@ def conv_3d_test():
return ([node], [x, y], [out])


@onnx_test()
def conv_1d_fp8_test():
x = helper.make_tensor_value_info('0', TensorProto.FLOAT8E4M3FNUZ, [1, 3, 5])
y = helper.make_tensor_value_info('1', TensorProto.FLOAT8E4M3FNUZ, [1, 3, 3])
out = helper.make_tensor_value_info('2', TensorProto.FLOAT8E4M3FNUZ, [1, 1, 3])

node = onnx.helper.make_node('Conv', inputs=['0', '1'], outputs=['2'])

return ([node], [x, y], [out])


@onnx_test()
def conv_attr_fail_test():
x = helper.make_tensor_value_info('0', TensorProto.FLOAT, [1, 3, 5])
Expand Down Expand Up @@ -2719,6 +2748,22 @@ def gemm_half_test():
return ([node], [A, B, C], [Y])


@onnx_test()
def gemm_fp8_test():
A = helper.make_tensor_value_info('A', TensorProto.FLOAT8E4M3FNUZ, [8, 6])
B = helper.make_tensor_value_info('B', TensorProto.FLOAT8E4M3FNUZ, [8, 7])
C = helper.make_tensor_value_info('C', TensorProto.FLOAT8E4M3FNUZ, [6, 1])
Y = helper.make_tensor_value_info('Y', TensorProto.FLOAT8E4M3FNUZ, [6, 7])

node = onnx.helper.make_node('Gemm',
inputs=['A', 'B', 'C'],
outputs=['Y'],
alpha=0.5,
beta=0.8,
transA=1)

return ([node], [A, B, C], [Y])

@onnx_test()
def gemm_dyn_inner_test():
A = helper.make_tensor_value_info('A', TensorProto.FLOAT, [None, 6])
Expand Down Expand Up @@ -9055,6 +9100,22 @@ def shrink_int8_test():
return ([node], [x], [y])


@onnx_test()
def shrink_fp8_test():
x = helper.make_tensor_value_info('x', TensorProto.FLOAT8E4M3FNUZ, [3, 3])
y = helper.make_tensor_value_info('y', TensorProto.FLOAT8E4M3FNUZ, [3, 3])

node = onnx.helper.make_node(
"Shrink",
inputs=["x"],
outputs=["y"],
lambd=1.5,
bias=1.5,
)

return ([node], [x], [y])


@onnx_test()
def shrink_uint8_test():
x = helper.make_tensor_value_info('x', TensorProto.UINT8, [3, 3])
Expand Down Expand Up @@ -9163,6 +9224,18 @@ def size_int_test():
return ([node], [x], [y])


@onnx_test()
def size_fp8_test():
x = helper.make_tensor_value_info('x', TensorProto.FLOAT8E4M3FNUZ, [2, 5, 3])
y = helper.make_tensor_value_info('y', TensorProto.INT64, [1])
node = onnx.helper.make_node(
'Size',
inputs=['x'],
outputs=['y'],
)
return ([node], [x], [y])


@onnx_test()
def size_verify_test():
x = helper.make_tensor_value_info('x', TensorProto.FLOAT, [2, 5, 3])
Expand Down
39 changes: 39 additions & 0 deletions test/onnx/parse/add_fp8_test.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/

#include <onnx_test.hpp>

TEST_CASE(add_fp8_test)
{
migraphx::program p;
auto* mm = p.get_main_module();
auto p0 = mm->add_parameter("0", migraphx::shape{migraphx::shape::fp8e4m3fnuz_type, {1}});
auto p1 = mm->add_parameter("1", migraphx::shape{migraphx::shape::fp8e4m3fnuz_type, {1}});

mm->add_instruction(migraphx::make_op("add"), p0, p1);

auto prog = optimize_onnx("add_fp8_test.onnx");

EXPECT(p == prog);
}
40 changes: 40 additions & 0 deletions test/onnx/parse/conv_1d_fp8_test.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/

#include <onnx_test.hpp>

TEST_CASE(conv_1d_fp8_test)
{
migraphx::program p;
auto* mm = p.get_main_module();
auto l0 = mm->add_parameter("0", {migraphx::shape::fp8e4m3fnuz_type, {1, 3, 5}});
auto l1 = mm->add_parameter("1", {migraphx::shape::fp8e4m3fnuz_type, {1, 3, 3}});
mm->add_instruction(
migraphx::make_op("convolution", {{"padding", {0}}, {"stride", {1}}, {"dilation", {1}}}),
l0,
l1);

auto prog = optimize_onnx("conv_1d_fp8_test.onnx");
EXPECT(p == prog);
}
56 changes: 56 additions & 0 deletions test/onnx/parse/gemm_fp8_test.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/

#include <onnx_test.hpp>
#include <migraphx/apply_alpha_beta.hpp>

TEST_CASE(gemm_fp8_test)
{
migraphx::program p;
auto* mm = p.get_main_module();
auto l0 = mm->add_parameter("A", migraphx::shape{migraphx::shape::fp8e4m3fnuz_type, {8, 6}});
auto l1 = mm->add_parameter("B", migraphx::shape{migraphx::shape::fp8e4m3fnuz_type, {8, 7}});
auto l2 = mm->add_parameter("C", migraphx::shape{migraphx::shape::fp8e4m3fnuz_type, {6, 1}});
auto alpha = 0.5f;
auto beta = 0.8f;
auto a_l = mm->add_literal(alpha);
auto t_a = add_common_op(*mm, migraphx::make_op("mul"), {a_l, l0});
t_a = mm->add_instruction(
migraphx::make_op("convert", {{"target_type", migraphx::shape::fp8e4m3fnuz_type}}), t_a);
t_a = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {1, 0}}}), t_a);
std::vector<std::size_t> lens = {6, 7};
auto dot = migraphx::add_apply_alpha_beta(*mm, {t_a, l1}, migraphx::make_op("dot"), 1.0f, 0.0f);
l2 = mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", lens}}), l2);
l2 = mm->add_instruction(
migraphx::make_op("convert", {{"target_type", migraphx::shape::float_type}}), l2);
auto b_l = mm->add_literal(beta);
auto b_b = mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", lens}}), b_l);
auto l2_b = mm->add_instruction(migraphx::make_op("mul"), l2, b_b);
l2_b = mm->add_instruction(
migraphx::make_op("convert", {{"target_type", migraphx::shape::fp8e4m3fnuz_type}}), l2_b);
mm->add_instruction(migraphx::make_op("add"), dot, l2_b);

auto prog = optimize_onnx("gemm_fp8_test.onnx");
EXPECT(p == prog);
}
60 changes: 60 additions & 0 deletions test/onnx/parse/shrink_fp8_test.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/

#include <onnx_test.hpp>

TEST_CASE(shrink_fp8_test)
{
migraphx::program p;
float bias = 1.5;
float lambd = 1.5;
std::vector<size_t> lens{3, 3};
auto* mm = p.get_main_module();
auto x = mm->add_parameter("x", migraphx::shape{migraphx::shape::fp8e4m3fnuz_type, lens});
auto lit_bias = mm->add_literal(migraphx::literal{migraphx::shape::float_type, {bias}});
auto lit_neg_lambd = mm->add_literal(migraphx::literal{migraphx::shape::float_type, {-lambd}});
auto lit_lambd = mm->add_literal(migraphx::literal{migraphx::shape::float_type, {lambd}});

auto x_plus_bias = add_common_op(*mm, migraphx::make_op("add"), {x, lit_bias});
auto x_min_bias = add_common_op(*mm, migraphx::make_op("sub"), {x, lit_bias});

auto cond1 = add_common_op(*mm, migraphx::make_op("less"), {x, lit_neg_lambd});
auto cond2_a = add_common_op(*mm, migraphx::make_op("not"), {cond1});
auto cond2_b = add_common_op(*mm, migraphx::make_op("greater"), {x, lit_lambd});
auto cond2 = add_common_op(*mm, migraphx::make_op("logical_and"), {cond2_a, cond2_b});

auto mul1 = mm->add_instruction(
migraphx::make_op("convert", {{"target_type", migraphx::shape::fp8e4m3fnuz_type}}), cond1);
auto mul2 = mm->add_instruction(
migraphx::make_op("convert", {{"target_type", migraphx::shape::fp8e4m3fnuz_type}}), cond2);

auto first = add_common_op(*mm, migraphx::make_op("mul"), {mul1, x_plus_bias});
auto second = add_common_op(*mm, migraphx::make_op("mul"), {mul2, x_min_bias});
auto ret = add_common_op(*mm, migraphx::make_op("add"), {first, second});
mm->add_instruction(
migraphx::make_op("convert", {{"target_type", migraphx::shape::fp8e4m3fnuz_type}}), ret);
auto prog = optimize_onnx("shrink_fp8_test.onnx");

EXPECT(p == prog);
}
37 changes: 37 additions & 0 deletions test/onnx/parse/size_fp8_test.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/

#include <onnx_test.hpp>

TEST_CASE(size_fp8_test)
{
migraphx::program p;
auto* mm = p.get_main_module();
auto s = migraphx::shape{migraphx::shape::fp8e4m3fnuz_type, {2, 5, 3}};
mm->add_parameter("x", s);
mm->add_literal(migraphx::literal{migraphx::shape::int64_type, {s.elements()}});

auto prog = optimize_onnx("size_fp8_test.onnx");
EXPECT(p == prog);
}
Binary file added test/onnx/shrink_fp8_test.onnx
Binary file not shown.
12 changes: 12 additions & 0 deletions test/onnx/size_fp8_test.onnx
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
 size_fp8_test:G

xy"Sizesize_fp8_testZ
x



b
y


B
Expand Down
Loading

0 comments on commit 44ad3d5

Please sign in to comment.