Skip to content

Commit

Permalink
Remove helper function for limits
Browse files Browse the repository at this point in the history
  • Loading branch information
gyulaz-htec committed Dec 11, 2023
1 parent 3358faf commit 7cb098b
Show file tree
Hide file tree
Showing 2 changed files with 39 additions and 40 deletions.
48 changes: 23 additions & 25 deletions src/onnx/parse_dynamicquantizelinear.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -87,52 +87,50 @@ struct parse_dynamicquantizelinear : op_parser<parse_dynamicquantizelinear>
{
std::vector<op_desc> operators() const { return {{"DynamicQuantizeLinear"}}; }

template <class T>
std::pair<int32_t, int32_t> num_limits() const
{
return std::make_pair(std::numeric_limits<T>::min(), std::numeric_limits<T>::max());
}

std::vector<instruction_ref> parse(const op_desc& /*opd*/,
const onnx_parser& /*parser*/,
const onnx_parser::node_info& info,
const std::vector<instruction_ref>& args) const
{
auto x = args[0];
auto x_shape = x->get_shape();
auto x_type = x_shape.type();
if(x_shape.dynamic())
MIGRAPHX_THROW("DYNAMICQUANTIZELINEAR: dynamic shapes are not supported");

auto x_type = x_shape.type();
auto x_reshaped =
(x_shape.lens().size() == 1)
? x
: info.add_instruction(
migraphx::make_op("reshape", {{"dims", {x_shape.elements()}}}), x);

auto lit_0 = info.add_literal(migraphx::literal{migraphx::shape{x_type}, {0}});
x_reshaped =
info.add_instruction(migraphx::make_op("concat", {{"axis", 0}}), x_reshaped, lit_0);

// 1. Computing y_scale
auto l0 = info.add_literal({0.f});
// DynamicQuantizeLinear only has uint8 quantization
auto limits = num_limits<uint8_t>();
auto q_min = info.add_literal(migraphx::literal{migraphx::shape{x_type}, {limits.first}});
auto q_max = info.add_literal(migraphx::literal{migraphx::shape{x_type}, {limits.second}});
auto q_scale = info.add_literal(
migraphx::literal{migraphx::shape{x_type}, {limits.second - limits.first}});
auto x_reshape = x;
if(x_shape.lens().size() != 1)
{
x_reshape = info.add_instruction(
migraphx::make_op("reshape", {{"dims", {x_shape.elements()}}}), x);
}
x_reshape = info.add_instruction(migraphx::make_op("concat", {{"axis", 0}}), x_reshape, l0);
// Note: currently, DynamicQuantizeLinear only has uint8 quantization:
const auto Q_MAX = std::numeric_limits<uint8_t>::max();
const auto Q_MIN = std::numeric_limits<uint8_t>::min();

auto q_range =
info.add_literal(migraphx::literal{migraphx::shape{x_type}, {Q_MAX - Q_MIN}});

// maximum(0, max(x))
auto max_x =
info.add_instruction(migraphx::make_op("reduce_max", {{"axes", {0}}}), x_reshape);

info.add_instruction(migraphx::make_op("reduce_max", {{"axes", {0}}}), x_reshaped);
// minimum(0, min(x))
auto min_x =
info.add_instruction(migraphx::make_op("reduce_min", {{"axes", {0}}}), x_reshape);
info.add_instruction(migraphx::make_op("reduce_min", {{"axes", {0}}}), x_reshaped);

// y_scale = (maximum(0, max(x)) - minimum(0, min(x))) / (qmax - qmin)
auto sub0 = info.add_common_op("sub", max_x, min_x);
auto y_scale = info.add_common_op("div", sub0, q_scale);
auto y_scale = info.add_common_op("div", sub0, q_range);

// 2. Computing y_zero_point
// intermediate_zero_point = qmin - min(x) / y_scale
auto q_min = info.add_literal(migraphx::literal{migraphx::shape{x_type}, {Q_MIN}});
auto q_max = info.add_literal(migraphx::literal{migraphx::shape{x_type}, {Q_MAX}});
auto sub1 = info.add_common_op("sub", q_min, min_x);
auto interm_zp = info.add_common_op("div", sub1, y_scale);
// y_zero_point = cast(round(saturate(itermediate_zero_point)))
Expand Down
31 changes: 16 additions & 15 deletions test/onnx/onnx_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1873,22 +1873,23 @@ TEST_CASE(dynamicquantizelinear_2d_test)
auto x_type = migraphx::shape::float_type;
auto x = mm->add_parameter("x", {x_type, x_dims});

auto l0 = mm->add_literal({0.f});
auto q_min = mm->add_literal(
migraphx::literal{migraphx::shape{x_type}, {std::numeric_limits<uint8_t>::min()}});
auto q_max = mm->add_literal(
migraphx::literal{migraphx::shape{x_type}, {std::numeric_limits<uint8_t>::max()}});
auto q_scale = mm->add_literal(
auto l0 = mm->add_literal({0.f});
auto x_reshaped = mm->add_instruction(migraphx::make_op("reshape", {{"dims", {12}}}), x);
x_reshaped = mm->add_instruction(migraphx::make_op("concat", {{"axis", 0}}), x_reshaped, l0);

auto q_range = mm->add_literal(
migraphx::literal{migraphx::shape{x_type}, {std::numeric_limits<uint8_t>::max()}});
auto x_reshape = mm->add_instruction(migraphx::make_op("reshape", {{"dims", {12}}}), x);
x_reshape = mm->add_instruction(migraphx::make_op("concat", {{"axis", 0}}), x_reshape, l0);

auto max_x = mm->add_instruction(migraphx::make_op("reduce_max", {{"axes", {0}}}), x_reshape);
auto min_x = mm->add_instruction(migraphx::make_op("reduce_min", {{"axes", {0}}}), x_reshape);
auto max_x = mm->add_instruction(migraphx::make_op("reduce_max", {{"axes", {0}}}), x_reshaped);
auto min_x = mm->add_instruction(migraphx::make_op("reduce_min", {{"axes", {0}}}), x_reshaped);

auto sub0 = mm->add_instruction(migraphx::make_op("sub"), max_x, min_x);
auto y_scale = mm->add_instruction(migraphx::make_op("div"), sub0, q_scale);
auto y_scale = mm->add_instruction(migraphx::make_op("div"), sub0, q_range);

auto q_min = mm->add_literal(
migraphx::literal{migraphx::shape{x_type}, {std::numeric_limits<uint8_t>::min()}});
auto q_max = mm->add_literal(
migraphx::literal{migraphx::shape{x_type}, {std::numeric_limits<uint8_t>::max()}});
auto sub1 = mm->add_instruction(migraphx::make_op("sub"), q_min, min_x);
auto interm_zp = mm->add_instruction(migraphx::make_op("div"), sub1, y_scale);
auto saturate = mm->add_instruction(migraphx::make_op("clip"), interm_zp, q_min, q_max);
Expand Down Expand Up @@ -2906,12 +2907,12 @@ migraphx::program make_group_norm(const std::vector<int64_t>& input_dims,

auto eps = mm->add_literal(migraphx::literal{dtype, {eps_value}});

auto x_reshaped =
auto x_reshapedd =
mm->add_instruction(migraphx::make_op("reshape", {{"dims", reshape_dims}}), x);
auto mean =
mm->add_instruction(migraphx::make_op("reduce_mean", {{"axes", reduce_axes}}), x_reshaped);
auto x_sub_mean = add_common_op(*mm, migraphx::make_op("sub"), {x_reshaped, mean});
auto x_sqdiff_mean = add_common_op(*mm, migraphx::make_op("sqdiff"), {x_reshaped, mean});
mm->add_instruction(migraphx::make_op("reduce_mean", {{"axes", reduce_axes}}), x_reshapedd);
auto x_sub_mean = add_common_op(*mm, migraphx::make_op("sub"), {x_reshapedd, mean});
auto x_sqdiff_mean = add_common_op(*mm, migraphx::make_op("sqdiff"), {x_reshapedd, mean});
auto var = mm->add_instruction(migraphx::make_op("reduce_mean", {{"axes", reduce_axes}}),
x_sqdiff_mean);
auto var_eps = add_common_op(*mm, migraphx::make_op("add"), {var, eps});
Expand Down

0 comments on commit 7cb098b

Please sign in to comment.