Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add FP8 ONNX tests #3041

Merged
merged 12 commits into from
Jun 12, 2024
16 changes: 16 additions & 0 deletions test/onnx/add_fp8_test.onnx
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
  add_fp8_test:Q

0
12"Add add_fp8_testZ
0


Z
1


b
2


B
Binary file added test/onnx/binary_dyn_brcst_mul_fp8_test.onnx
Binary file not shown.
19 changes: 19 additions & 0 deletions test/onnx/conv_1d_fp8_test.onnx
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
 conv_1d_fp8_test:n

0
12"Convconv_1d_fp8_testZ
0



Z
1



b
2



B
13 changes: 13 additions & 0 deletions test/onnx/cos_fp8_test.onnx
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
  cos_fp8_test:=

xy"Cos cos_fp8_testZ
x



b
y



B
16 changes: 16 additions & 0 deletions test/onnx/div_fp8_test.onnx
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
  div_fp8_test:a

0
1out"Div div_fp8_testZ
0


Z
1


b
out


B
Binary file added test/onnx/gemm_fp8_test.onnx
Binary file not shown.
221 changes: 221 additions & 0 deletions test/onnx/gen_onnx.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,6 +124,21 @@ def add_fp16_test():
])


@onnx_test()
def add_fp8_test():
x = helper.make_tensor_value_info('0', TensorProto.FLOAT8E4M3FNUZ, [1])
y = helper.make_tensor_value_info('1', TensorProto.FLOAT8E4M3FNUZ, [1])
z = helper.make_tensor_value_info('2', TensorProto.FLOAT8E4M3FNUZ, [1])

node = onnx.helper.make_node(
'Add',
inputs=['0', '1'],
outputs=['2'],
)

return ([node], [x, y], [z])


@onnx_test()
def add_scalar_test():
x = helper.make_tensor_value_info('0', TensorProto.UINT8, [2, 3, 4, 5])
Expand Down Expand Up @@ -618,6 +633,42 @@ def binary_dyn_brcst_mul_test():
return ([node], [arg0, arg1], [arg_out])


@onnx_test()
def binary_dyn_brcst_mul_fp8_test():
arg0 = helper.make_tensor_value_info('0', TensorProto.FLOAT8E4M3FNUZ,
[None, 3, 4, 5])
arg1 = helper.make_tensor_value_info('1', TensorProto.FLOAT8E4M3FNUZ,
[4, 1])
arg_out = helper.make_tensor_value_info('out', TensorProto.FLOAT8E4M3FNUZ,
[None, 3, 4, 5])

node = onnx.helper.make_node(
'Mul',
inputs=['0', '1'],
outputs=['out'],
)

return ([node], [arg0, arg1], [arg_out])


@onnx_test()
def div_fp8_test():
arg0 = helper.make_tensor_value_info('0', TensorProto.FLOAT8E4M3FNUZ,
[2, 3])
arg1 = helper.make_tensor_value_info('1', TensorProto.FLOAT8E4M3FNUZ,
[2, 3])
arg_out = helper.make_tensor_value_info('out', TensorProto.FLOAT8E4M3FNUZ,
[2, 3])

node = onnx.helper.make_node(
'Div',
inputs=['0', '1'],
outputs=['out'],
)

return ([node], [arg0, arg1], [arg_out])


@onnx_test()
def cast_test():
x = helper.make_tensor_value_info('x', TensorProto.FLOAT16, [10])
Expand Down Expand Up @@ -1252,6 +1303,20 @@ def conv_3d_test():
return ([node], [x, y], [out])


@onnx_test()
def conv_1d_fp8_test():
x = helper.make_tensor_value_info('0', TensorProto.FLOAT8E4M3FNUZ,
[1, 3, 5])
y = helper.make_tensor_value_info('1', TensorProto.FLOAT8E4M3FNUZ,
[1, 3, 3])
out = helper.make_tensor_value_info('2', TensorProto.FLOAT8E4M3FNUZ,
[1, 1, 3])

node = onnx.helper.make_node('Conv', inputs=['0', '1'], outputs=['2'])

return ([node], [x, y], [out])


@onnx_test()
def conv_attr_fail_test():
x = helper.make_tensor_value_info('0', TensorProto.FLOAT, [1, 3, 5])
Expand Down Expand Up @@ -1688,6 +1753,20 @@ def cos_test():
return ([node], [x], [y])


@onnx_test()
def cos_fp8_test():
x = helper.make_tensor_value_info('x', TensorProto.FLOAT8E4M3FNUZ, [10])
y = helper.make_tensor_value_info('y', TensorProto.FLOAT8E4M3FNUZ, [10])

node = onnx.helper.make_node(
'Cos',
inputs=['x'],
outputs=['y'],
)

return ([node], [x], [y])


@onnx_test()
def cosh_test():
x = helper.make_tensor_value_info('x', TensorProto.FLOAT, [1])
Expand Down Expand Up @@ -3626,6 +3705,23 @@ def gemm_half_test():
return ([node], [A, B, C], [Y])


@onnx_test()
def gemm_fp8_test():
A = helper.make_tensor_value_info('A', TensorProto.FLOAT8E4M3FNUZ, [8, 6])
B = helper.make_tensor_value_info('B', TensorProto.FLOAT8E4M3FNUZ, [8, 7])
C = helper.make_tensor_value_info('C', TensorProto.FLOAT8E4M3FNUZ, [6, 1])
Y = helper.make_tensor_value_info('Y', TensorProto.FLOAT8E4M3FNUZ, [6, 7])

node = onnx.helper.make_node('Gemm',
inputs=['A', 'B', 'C'],
outputs=['Y'],
alpha=0.5,
beta=0.8,
transA=1)

return ([node], [A, B, C], [Y])


@onnx_test()
def gemm_dyn_inner_test():
A = helper.make_tensor_value_info('A', TensorProto.FLOAT, [None, 6])
Expand Down Expand Up @@ -3705,6 +3801,22 @@ def globalavgpool_test():
return ([node], [x], [y])


@onnx_test()
def globalavgpool_fp8_test():
x = helper.make_tensor_value_info('0', TensorProto.FLOAT8E4M3FNUZ,
[1, 3, 16, 16])
y = helper.make_tensor_value_info('1', TensorProto.FLOAT8E4M3FNUZ,
[1, 3, 1, 1])

node = onnx.helper.make_node(
'GlobalAveragePool',
inputs=['0'],
outputs=['1'],
)

return ([node], [x], [y])


@onnx_test()
def globalavgpool_dyn_test():
x = helper.make_tensor_value_info('0', TensorProto.FLOAT,
Expand Down Expand Up @@ -3763,6 +3875,22 @@ def globalmaxpool_test():
return ([node], [x], [y])


@onnx_test()
def globalmaxpool_fp8_test():
x = helper.make_tensor_value_info('0', TensorProto.FLOAT8E4M3FNUZ,
[1, 3, 16, 16])
y = helper.make_tensor_value_info('1', TensorProto.FLOAT8E4M3FNUZ,
[1, 3, 1, 1])

node = onnx.helper.make_node(
'GlobalMaxPool',
inputs=['0'],
outputs=['1'],
)

return ([node], [x], [y])


@onnx_test()
def globalmaxpool_dyn_test():
x = helper.make_tensor_value_info('0', TensorProto.FLOAT,
Expand Down Expand Up @@ -8610,6 +8738,24 @@ def reducemax_test():
return ([node], [x], [y])


@onnx_test()
def reducemax_fp8_test():
x = helper.make_tensor_value_info('x', TensorProto.FLOAT8E4M3FNUZ,
[3, 4, 5, 6])
y = helper.make_tensor_value_info('y', TensorProto.FLOAT8E4M3FNUZ,
[3, 4, 6])

axes = [2]

node = onnx.helper.make_node('ReduceMax',
inputs=['x'],
outputs=['y'],
axes=axes,
keepdims=0)

return ([node], [x], [y])


@onnx_test
def reducemax_dyn_test():
x = helper.make_tensor_value_info('x', TensorProto.FLOAT, [None, 4, 5, 6])
Expand Down Expand Up @@ -8699,6 +8845,22 @@ def reducesum_test():
return ([node], [x], [y])


@onnx_test()
def reducesum_fp8_test():
x = helper.make_tensor_value_info('x', TensorProto.FLOAT8E4M3FNUZ,
[3, 4, 5, 6])
y = helper.make_tensor_value_info('y', TensorProto.FLOAT8E4M3FNUZ,
[3, 4, 1, 6])

node = onnx.helper.make_node('ReduceSum',
inputs=['x'],
outputs=['y'],
axes=[2],
keepdims=0)

return ([node], [x], [y])


@onnx_test()
def reducesum_empty_axes_test():
x = helper.make_tensor_value_info('x', TensorProto.FLOAT, [3, 4, 5, 6])
Expand Down Expand Up @@ -9962,6 +10124,22 @@ def shrink_int8_test():
return ([node], [x], [y])


@onnx_test()
def shrink_fp8_test():
x = helper.make_tensor_value_info('x', TensorProto.FLOAT8E4M3FNUZ, [3, 3])
y = helper.make_tensor_value_info('y', TensorProto.FLOAT8E4M3FNUZ, [3, 3])

node = onnx.helper.make_node(
"Shrink",
inputs=["x"],
outputs=["y"],
lambd=1.5,
bias=1.5,
)

return ([node], [x], [y])


@onnx_test()
def shrink_uint8_test():
x = helper.make_tensor_value_info('x', TensorProto.UINT8, [3, 3])
Expand Down Expand Up @@ -10006,6 +10184,20 @@ def sin_test():
return ([node], [x], [y])


@onnx_test()
def sin_fp8_test():
x = helper.make_tensor_value_info('x', TensorProto.FLOAT8E4M3FNUZ, [10])
y = helper.make_tensor_value_info('y', TensorProto.FLOAT8E4M3FNUZ, [10])

node = onnx.helper.make_node(
'Sin',
inputs=['x'],
outputs=['y'],
)

return ([node], [x], [y])


@onnx_test()
def sinh_test():
x = helper.make_tensor_value_info('x', TensorProto.FLOAT, [10])
Expand Down Expand Up @@ -10070,6 +10262,19 @@ def size_int_test():
return ([node], [x], [y])


@onnx_test()
def size_fp8_test():
x = helper.make_tensor_value_info('x', TensorProto.FLOAT8E4M3FNUZ,
[2, 5, 3])
y = helper.make_tensor_value_info('y', TensorProto.INT64, [1])
node = onnx.helper.make_node(
'Size',
inputs=['x'],
outputs=['y'],
)
return ([node], [x], [y])


@onnx_test()
def size_verify_test():
x = helper.make_tensor_value_info('x', TensorProto.FLOAT, [2, 5, 3])
Expand Down Expand Up @@ -10913,6 +11118,22 @@ def sqrt_test():
return ([node], [x], [y])


@onnx_test()
def sqrt_fp8_test():
x = helper.make_tensor_value_info('x', TensorProto.FLOAT8E4M3FNUZ,
[10, 15])
y = helper.make_tensor_value_info('y', TensorProto.FLOAT8E4M3FNUZ,
[10, 15])

node = onnx.helper.make_node(
'Sqrt',
inputs=['x'],
outputs=['y'],
)

return ([node], [x], [y])


@onnx_test()
def squeeze_axes_input_test():
x = helper.make_tensor_value_info('x', TensorProto.FLOAT, [3, 1, 5, 1])
Expand Down
Loading
Loading