Skip to content

Commit

Permalink
ConvInteger: fix parsing for x_zero_point and w_zero_point (#3763)
Browse files Browse the repository at this point in the history
kahmed10 authored Jan 21, 2025
1 parent 976ae75 commit f36eba4
Showing 10 changed files with 189 additions and 95 deletions.
35 changes: 21 additions & 14 deletions src/onnx/parse_convolution.cpp
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved.
* Copyright (c) 2015-2025 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
@@ -141,17 +141,14 @@ struct parse_convolution : op_parser<parse_convolution>
return all_zeros;
}

static auto
static migraphx::operation
qparam_broadcast_op(instruction_ref qparam, std::vector<std::size_t> lens, std::size_t axis)
{
if(qparam->get_shape().scalar())
if(qparam->get_shape().elements() == 1)
{
return migraphx::make_op("multibroadcast", {{"out_lens", lens}});
}
else
{
return migraphx::make_op("broadcast", {{"out_lens", lens}, {"axis", axis}});
}
return migraphx::make_op("broadcast", {{"out_lens", lens}, {"axis", axis}});
}

static instruction_ref handle_quant_bias(const operation& op,
@@ -162,27 +159,37 @@ struct parse_convolution : op_parser<parse_convolution>
const instruction_ref& w_zp,
onnx_parser::node_info& info)
{
// to handle the bias, apply the following transformation:
// conv(x-x_zp,w-w_zp) = conv(x,w) - conv(x_zp,w) - conv(x,w_zp) + conv(x_zp,w_zp)
instruction_ref ret = input;

// multibroadcast (or broadcast) zero points according to spec
// x_zp should be a scalar or literal with one element
// w_zp can be either a single element or a 1d tensor with size out_channels
migraphx::operation x_zp_bc =
migraphx::make_op("multibroadcast", {{"out_lens", x->get_shape().lens()}});
migraphx::operation w_zp_bc = qparam_broadcast_op(w_zp, weights->get_shape().lens(), 0);

if(not is_symmetric_zero_point(x_zp))
{
auto out_zp_1 = info.add_common_op(op.name(), x_zp, weights);
auto x_zp_mb = info.add_instruction(x_zp_bc, x_zp);
auto out_zp_1 = info.add_instruction(op, x_zp_mb, weights);
ret = info.add_common_op("sub", ret, out_zp_1);
}

if(not is_symmetric_zero_point(w_zp))
{
auto out_zp_2 = info.add_common_op(op.name(), x, w_zp);
auto w_zp_mb = info.add_instruction(w_zp_bc, w_zp);
auto out_zp_2 = info.add_instruction(op, x, w_zp_mb);
ret = info.add_common_op("sub", ret, out_zp_2);
}

if(not(is_symmetric_zero_point(x_zp)) and not(is_symmetric_zero_point(w_zp)))
{
auto x_zp_bc =
info.add_instruction(qparam_broadcast_op(x_zp, x->get_shape().lens(), 0), x_zp);
auto w_zp_bc = info.add_instruction(
qparam_broadcast_op(w_zp, weights->get_shape().lens(), 0), w_zp);
auto x_zp_mb = info.add_instruction(x_zp_bc, x_zp);
auto w_zp_mb = info.add_instruction(w_zp_bc, w_zp);

auto out_zp_3 = info.add_instruction(op, x_zp_bc, w_zp_bc);
auto out_zp_3 = info.add_instruction(op, x_zp_mb, w_zp_mb);

ret = info.add_common_op("add", ret, out_zp_3);
}
4 changes: 3 additions & 1 deletion src/simplify_algebra.cpp
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved.
* Copyright (c) 2015-2025 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
@@ -761,6 +761,8 @@ struct find_inner_broadcast
void apply(module& m, const match::matcher_result& r) const
{
auto ins = r.result;
if(ins->get_operator().name() == "layout")
return;
const auto& broadcasts = ins->inputs();
if(broadcasts.empty())
return;
6 changes: 3 additions & 3 deletions test/onnx/convinteger_bias_test.onnx
Original file line number Diff line number Diff line change
@@ -7,13 +7,13 @@
strides@@�convinteger_bias_testZ
0





 Z
1





Z
@@ -23,7 +23,7 @@
b
3





B
34 changes: 34 additions & 0 deletions test/onnx/convinteger_dual_bias_simple_test.onnx
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
 !convinteger_dual_bias_simple_test:�
B
0
1
2
34" ConvInteger*
dilations@@�*
strides@@�!convinteger_dual_bias_simple_testZ
0




Z
1




Z
2


Z
3


b
4




B
20 changes: 11 additions & 9 deletions test/onnx/convinteger_dual_bias_test.onnx
Original file line number Diff line number Diff line change
@@ -8,16 +8,18 @@ B
strides@@�convinteger_dual_bias_testZ
0





Z



Z
1





Z

Z
2


@@ -28,7 +30,7 @@ B
b
4





B

B
23 changes: 20 additions & 3 deletions test/onnx/gen_onnx.py
Original file line number Diff line number Diff line change
@@ -1670,10 +1670,10 @@ def convinteger_no_bias_uint8_test():

@onnx_test()
def convinteger_bias_test():
x = helper.make_tensor_value_info('0', TensorProto.INT8, [1, 3, 32, 32])
y = helper.make_tensor_value_info('1', TensorProto.INT8, [1, 3, 5, 5])
x = helper.make_tensor_value_info('0', TensorProto.INT8, [2, 3, 32, 32])
y = helper.make_tensor_value_info('1', TensorProto.INT8, [4, 3, 5, 5])
z = helper.make_tensor_value_info('2', TensorProto.INT8, [1])
out = helper.make_tensor_value_info('3', TensorProto.INT32, [1, 2, 28, 28])
out = helper.make_tensor_value_info('3', TensorProto.INT32, [2, 4, 28, 28])

node = onnx.helper.make_node('ConvInteger',
inputs=['0', '1', '2'],
@@ -1686,6 +1686,23 @@ def convinteger_bias_test():

@onnx_test()
def convinteger_dual_bias_test():
x = helper.make_tensor_value_info('0', TensorProto.INT8, [2, 3, 10, 10])
y = helper.make_tensor_value_info('1', TensorProto.INT8, [4, 3, 3, 3])
z = helper.make_tensor_value_info('2', TensorProto.INT8, [1])
w = helper.make_tensor_value_info('3', TensorProto.INT8, [1])
out = helper.make_tensor_value_info('4', TensorProto.INT32, [2, 4, 8, 8])

node = onnx.helper.make_node('ConvInteger',
inputs=['0', '1', '2', '3'],
outputs=['4'],
dilations=[1, 1],
strides=[1, 1])

return ([node], [x, y, z, w], [out])


@onnx_test()
def convinteger_dual_bias_simple_test():
x = helper.make_tensor_value_info('0', TensorProto.INT8, [1, 3, 5, 5])
y = helper.make_tensor_value_info('1', TensorProto.INT8, [1, 3, 2, 2])
z = helper.make_tensor_value_info('2', TensorProto.INT8, [1])
14 changes: 5 additions & 9 deletions test/onnx/parse/convinteger_bias_test.cpp
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved.
* Copyright (c) 2015-2025 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
@@ -28,24 +28,20 @@ TEST_CASE(convinteger_bias_test)
{
migraphx::program p;
auto* mm = p.get_main_module();
auto data = mm->add_parameter("0", {migraphx::shape::int8_type, {1, 3, 32, 32}});
auto weights = mm->add_parameter("1", {migraphx::shape::int8_type, {1, 3, 5, 5}});
auto data = mm->add_parameter("0", {migraphx::shape::int8_type, {2, 3, 32, 32}});
auto weights = mm->add_parameter("1", {migraphx::shape::int8_type, {4, 3, 5, 5}});
auto data_bias = mm->add_parameter("2", {migraphx::shape::int8_type, {1}, {1}});

mm->add_literal(migraphx::literal{migraphx::shape{data->get_shape().type(), {1}, {0}}, {0}});
auto quant = mm->add_instruction(migraphx::make_op("quant_convolution"), data, weights);

auto bcast_data_bias = mm->add_instruction(
migraphx::make_op("multibroadcast", {{"out_lens", weights->get_shape().lens()}}),
data_bias);
migraphx::make_op("multibroadcast", {{"out_lens", data->get_shape().lens()}}), data_bias);

auto quant2 =
mm->add_instruction(migraphx::make_op("quant_convolution"), bcast_data_bias, weights);

auto bcast_quant2 = mm->add_instruction(
migraphx::make_op("multibroadcast", {{"out_lens", quant->get_shape().lens()}}), quant2);

mm->add_instruction(migraphx::make_op("sub"), quant, bcast_quant2);
mm->add_instruction(migraphx::make_op("sub"), quant, quant2);

auto prog = optimize_onnx("convinteger_bias_test.onnx");
EXPECT(p == prog);
32 changes: 14 additions & 18 deletions test/onnx/parse/convinteger_dual_bias_test.cpp
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved.
* Copyright (c) 2015-2025 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
@@ -28,41 +28,37 @@ TEST_CASE(convinteger_dual_bias_test)
{
migraphx::program p;
auto* mm = p.get_main_module();
auto data = mm->add_parameter("0", {migraphx::shape::int8_type, {1, 3, 5, 5}});
auto weight = mm->add_parameter("1", {migraphx::shape::int8_type, {1, 3, 2, 2}});
auto data = mm->add_parameter("0", {migraphx::shape::int8_type, {2, 3, 10, 10}});
auto weight = mm->add_parameter("1", {migraphx::shape::int8_type, {4, 3, 3, 3}});
auto data_bias = mm->add_parameter("2", {migraphx::shape::int8_type, {1}, {1}});
auto weight_bias = mm->add_parameter("3", {migraphx::shape::int8_type, {1}, {1}});

auto quant = mm->add_instruction(migraphx::make_op("quant_convolution"), data, weight);

auto mbcast_data_bias = mm->add_instruction(
migraphx::make_op("multibroadcast", {{"out_lens", weight->get_shape().lens()}}), data_bias);
migraphx::make_op("multibroadcast", {{"out_lens", data->get_shape().lens()}}), data_bias);

auto quant_db_w =
auto quant_mb_w =
mm->add_instruction(migraphx::make_op("quant_convolution"), mbcast_data_bias, weight);

auto quant_mb_w = mm->add_instruction(
migraphx::make_op("multibroadcast", {{"out_lens", quant->get_shape().lens()}}), quant_db_w);

quant = mm->add_instruction(migraphx::make_op("sub"), quant, quant_mb_w);

auto mbcast_weight_bias = mm->add_instruction(
migraphx::make_op("multibroadcast", {{"out_lens", data->get_shape().lens()}}), weight_bias);
migraphx::make_op("multibroadcast", {{"out_lens", weight->get_shape().lens()}}),
weight_bias);

auto quant_d_wb =
auto quant_md_wb =
mm->add_instruction(migraphx::make_op("quant_convolution"), data, mbcast_weight_bias);

auto quant_md_wb = mm->add_instruction(
migraphx::make_op("multibroadcast", {{"out_lens", quant->get_shape().lens()}}), quant_d_wb);

quant = mm->add_instruction(migraphx::make_op("sub"), quant, quant_md_wb);

auto bcast_data_bias = mm->add_instruction(
migraphx::make_op("broadcast", {{"out_lens", data->get_shape().lens()}}), data_bias);
auto bcast_weight_bias = mm->add_instruction(
migraphx::make_op("broadcast", {{"out_lens", weight->get_shape().lens()}}), weight_bias);
mbcast_data_bias = mm->add_instruction(
migraphx::make_op("multibroadcast", {{"out_lens", data->get_shape().lens()}}), data_bias);
mbcast_weight_bias = mm->add_instruction(
migraphx::make_op("multibroadcast", {{"out_lens", weight->get_shape().lens()}}),
weight_bias);
auto bias_quant = mm->add_instruction(
migraphx::make_op("quant_convolution"), bcast_data_bias, bcast_weight_bias);
migraphx::make_op("quant_convolution"), mbcast_data_bias, mbcast_weight_bias);

mm->add_instruction(migraphx::make_op("add"), quant, bias_quant);

62 changes: 42 additions & 20 deletions test/onnx/verify/quant_convolution_dual_bias_test.cpp
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved.
* Copyright (c) 2015-2025 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
@@ -24,11 +24,15 @@

#include <migraphx/register_target.hpp>
#include <migraphx/verify.hpp>
#include <migraphx/program.hpp>
#include <migraphx/module.hpp>
#include <migraphx/common.hpp>
#include <onnx_test.hpp>

TEST_CASE(quant_convolution_dual_zero_bias_test)
{
migraphx::program p = read_onnx("convinteger_dual_bias_test.onnx");
// TODO: use other dual_bias test, verify with other framework once convinteger supported
migraphx::program p = read_onnx("convinteger_dual_bias_simple_test.onnx");
p.compile(migraphx::make_target("ref"));

migraphx::shape a{migraphx::shape::int8_type, {1, 3, 5, 5}};
@@ -82,7 +86,7 @@ TEST_CASE(quant_convolution_dual_zero_bias_test)
TEST_CASE(quant_convolution_dual_non_zero_bias_test)
{
// github.com/microsoft/onnxruntime/blob/main/docs/ContribOperators.md#com.microsoft.QLinearMul
migraphx::program p = read_onnx("convinteger_dual_bias_test.onnx");
migraphx::program p = read_onnx("convinteger_dual_bias_simple_test.onnx");
p.compile(migraphx::make_target("ref"));

migraphx::shape a{migraphx::shape::int8_type, {1, 3, 5, 5}};
@@ -113,22 +117,40 @@ TEST_CASE(quant_convolution_dual_non_zero_bias_test)
std::vector<int32_t> result_vector;
result.visit([&](auto output) { result_vector.assign(output.begin(), output.end()); });

std::vector<int32_t> gold = {-6088,
6248,
-6472,
6632,
6664,
-8264,
8520,
-8713,
-3788,
-1446,
1488,
-1586,
-712,
745,
-914,
1019};
// create the following program to compare:
// conv(x-x_bias,w-w_bias)
// where datatypes for x,w,x_bias,w_bias are int32
migraphx::program p2;
migraphx::module* mm = p2.get_main_module();

EXPECT(migraphx::verify::verify_rms_range(result_vector, gold));
migraphx::shape a_i32{migraphx::shape::int32_type, {1, 3, 5, 5}};
migraphx::shape b_i32{migraphx::shape::int32_type, {1, 3, 2, 2}};

migraphx::shape bias_i32{migraphx::shape::int32_type, {1}, {1}};
auto x = mm->add_parameter("0", a_i32);
auto weights = mm->add_parameter("1", b_i32);
auto x_bias = mm->add_parameter("2", bias_i32);
auto weights_bias = mm->add_parameter("3", bias_i32);

auto sub_input = add_common_op(*mm, migraphx::make_op("sub"), {x, x_bias});
auto sub_weights = add_common_op(*mm, migraphx::make_op("sub"), {weights, weights_bias});
mm->add_instruction(migraphx::make_op("convolution"), sub_input, sub_weights);

std::vector<int32_t> data_a_i32(data_a.begin(), data_a.end());
std::vector<int32_t> data_b_i32(data_b.begin(), data_b.end());
std::vector<int32_t> data_a_bias_i32 = {10};
std::vector<int32_t> data_b_bias_i32 = {-2};

migraphx::parameter_map pp2;
pp2["0"] = migraphx::argument(a_i32, data_a_i32.data());
pp2["1"] = migraphx::argument(b_i32, data_b_i32.data());
pp2["2"] = migraphx::argument(bias_i32, data_a_bias_i32.data());
pp2["3"] = migraphx::argument(bias_i32, data_b_bias_i32.data());

auto result2 = p2.eval(pp2).back();

std::vector<int32_t> result_vector_i32;
result2.visit([&](auto output) { result_vector_i32.assign(output.begin(), output.end()); });

EXPECT(migraphx::verify::verify_rms_range(result_vector, result_vector_i32));
}
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved.
* Copyright (c) 2015-2025 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
@@ -112,22 +112,40 @@ TEST_CASE(quant_convolution_mismatched_inputs_dual_non_zero_bias_test)
std::vector<int32_t> result_vector;
result.visit([&](auto output) { result_vector.assign(output.begin(), output.end()); });

std::vector<int32_t> gold = {-6088,
6248,
-6472,
6632,
6664,
-8264,
8520,
-8713,
-3788,
-1446,
1488,
-1586,
-712,
745,
-914,
1019};
// create the following program to compare:
// conv(x-x_bias,w-w_bias)
// where datatypes for x,w,x_bias,w_bias are int32
migraphx::program p2;
migraphx::module* mm = p2.get_main_module();

EXPECT(migraphx::verify::verify_rms_range(result_vector, gold));
migraphx::shape a_i32{migraphx::shape::int32_type, {1, 3, 5, 5}};
migraphx::shape b_i32{migraphx::shape::int32_type, {1, 3, 2, 2}};

migraphx::shape bias_i32{migraphx::shape::int32_type, {1}, {1}};
auto x = mm->add_parameter("0", a_i32);
auto weights = mm->add_parameter("1", b_i32);
auto x_bias = mm->add_parameter("2", bias_i32);
auto weights_bias = mm->add_parameter("3", bias_i32);

auto sub_input = add_common_op(*mm, migraphx::make_op("sub"), {x, x_bias});
auto sub_weights = add_common_op(*mm, migraphx::make_op("sub"), {weights, weights_bias});
mm->add_instruction(migraphx::make_op("convolution"), sub_input, sub_weights);

std::vector<int32_t> data_a_i32(data_a.begin(), data_a.end());
std::vector<int32_t> data_b_i32(data_b.begin(), data_b.end());
std::vector<int32_t> data_a_bias_i32 = {138};
std::vector<int32_t> data_b_bias_i32 = {-2};

migraphx::parameter_map pp2;
pp2["0"] = migraphx::argument(a_i32, data_a_i32.data());
pp2["1"] = migraphx::argument(b_i32, data_b_i32.data());
pp2["2"] = migraphx::argument(bias_i32, data_a_bias_i32.data());
pp2["3"] = migraphx::argument(bias_i32, data_b_bias_i32.data());

auto result2 = p2.eval(pp2).back();

std::vector<int32_t> result_vector_i32;
result2.visit([&](auto output) { result_vector_i32.assign(output.begin(), output.end()); });

EXPECT(migraphx::verify::verify_rms_range(result_vector, result_vector_i32));
}

0 comments on commit f36eba4

Please sign in to comment.