Skip to content

Commit

Permalink
ONNX greater_or_equal enabled
Browse files Browse the repository at this point in the history
  • Loading branch information
siddhant-0707 committed Nov 9, 2023
1 parent b1705e8 commit 76e3f96
Show file tree
Hide file tree
Showing 6 changed files with 237 additions and 0 deletions.
42 changes: 42 additions & 0 deletions src/frontends/onnx/frontend/src/op/greater_or_equal.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
// Copyright (C) 2018-2023 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#include "op/greater_or_equal.hpp"

#include <memory>

#include "default_opset.hpp"
#include "utils/common.hpp"
#define _USE_MATH_DEFINES
#include <math.h>

OPENVINO_SUPPRESS_DEPRECATED_START
namespace ngraph {
namespace onnx_import {
namespace op {
namespace set_1 {
OutputVector greater_or_equal(const Node& node) {
const auto A = node.get_ng_inputs().at(0);
const auto B = node.get_ng_inputs().at(1);

const auto C = std::make_shared<default_opset::GreaterEqual>(A, B);

return {C};
}
} // namespace set_1

namespace set_12 {
OutputVector greater_or_equal(const Node& node) {
const auto A = node.get_ng_inputs().at(0);
const auto B = node.get_ng_inputs().at(1);

const auto C = std::make_shared<default_opset::GreaterEqual>(A, B);

return {C};
}
} // namespace set_12
} // namespace op
} // namespace onnx_import
} // namespace ngraph
OPENVINO_SUPPRESS_DEPRECATED_END
28 changes: 28 additions & 0 deletions src/frontends/onnx/frontend/src/op/greater_or_equal.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
// Copyright (C) 2018-2023 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#pragma once

#include "openvino/core/deprecated.hpp"
OPENVINO_SUPPRESS_DEPRECATED_START

#include "ngraph/node.hpp"
#include "onnx_import/core/node.hpp"

namespace ngraph {
namespace onnx_import {
namespace op {
namespace set_1 {
OutputVector greater_or_equal(const Node& node);

} // namespace set_1

namespace set_12 {
OutputVector greater_or_equal(const Node& node);

} // namespace set_12
} // namespace op
} // namespace onnx_import
} // namespace ngraph
OPENVINO_SUPPRESS_DEPRECATED_END
3 changes: 3 additions & 0 deletions src/frontends/onnx/frontend/src/ops_bridge.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,7 @@
#include "op/global_average_pool.hpp"
#include "op/global_max_pool.hpp"
#include "op/greater.hpp"
#include "op/greater_or_equal.hpp"
#include "op/grid_sample.hpp"
#include "op/group_normalization.hpp"
#include "op/gru.hpp"
Expand Down Expand Up @@ -395,6 +396,8 @@ OperatorsBridge::OperatorsBridge() {
REGISTER_OPERATOR("GlobalLpPool", 1, global_lp_pool);
REGISTER_OPERATOR("GlobalMaxPool", 1, global_max_pool);
REGISTER_OPERATOR("Greater", 1, greater);
REGISTER_OPERATOR("Greater_Or_Equal", 1, greater_or_equal);
REGISTER_OPERATOR("Greater_Or_Equal", 12, greater_or_equal);
REGISTER_OPERATOR("GridSample", 1, grid_sample);
REGISTER_OPERATOR("GroupNormalization", 1, group_normalization);
REGISTER_OPERATOR("GRU", 1, gru);
Expand Down
53 changes: 53 additions & 0 deletions src/frontends/onnx/tests/models/greater_or_equal_float.prototxt
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
ir_version: 7
producer_name: "nGraph ONNX Importer"
graph {
node {
input: "A"
input: "B"
output: "C"
op_type: "Greater_Or_Equal"
}
name: "test_greater_or_equal_float"
input {
name: "A"
type {
tensor_type {
elem_type: 1
shape {
dim {
dim_value: 2
}
}
}
}
}
input {
name: "B"
type {
tensor_type {
elem_type: 1
shape {
dim {
dim_value: 2
}
}
}
}
}
output {
name: "C"
type {
tensor_type {
elem_type: 2
shape {
dim {
dim_value: 2
}
}
}
}
}
}
opset_import {
version: 16
}
53 changes: 53 additions & 0 deletions src/frontends/onnx/tests/models/greater_or_equal_int.prototxt
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
ir_version: 7
producer_name: "nGraph ONNX Importer"
graph {
node {
input: "A"
input: "B"
output: "C"
op_type: "Greater_Or_Equal"
}
name: "test_greater_or_equal_int"
input {
name: "A"
type {
tensor_type {
elem_type: 7
shape {
dim {
dim_value: 2
}
}
}
}
}
input {
name: "B"
type {
tensor_type {
elem_type: 7
shape {
dim {
dim_value: 2
}
}
}
}
}
output {
name: "C"
type {
tensor_type {
elem_type: 2
shape {
dim {
dim_value: 2
}
}
}
}
}
}
opset_import {
version: 16
}
58 changes: 58 additions & 0 deletions src/frontends/onnx/tests/onnx_import.in.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -6976,3 +6976,61 @@ OPENVINO_TEST(${BACKEND_NAME}, onnx_model_mm_nms_rotated) {

test_case.run();
}

OPENVINO_TEST(${BACKEND_NAME}, onnx_model_greater_or_equal_float) {
auto function = onnx_import::import_onnx_model(file_util::path_join(ov::test::utils::getExecutableDirectory(),
SERIALIZED_ZOO,
"onnx/greater_or_equal_float.onnx"));

auto test_case = ov::test::TestCase(function, s_device);

test_case.add_input<int64_t>({10});
test_case.add_expected_output<float>(Shape{10},
{-0.000000014901161f,
0.040212844f,
0.20077012f,
0.50978714f,
0.8492299f,
0.99999994f,
0.84922975f,
0.5097869f,
0.20077008f,
0.040212862f});

// GPU has an accuracy drop, need to use different tolerance
if (std::string("${BACKEND_NAME}") != std::string("IE_GPU")) {
test_case.run_with_tolerance_as_fp();
} else {
test_case.run_with_tolerance_as_fp(0.01f);
}
}

OPENVINO_TEST(${BACKEND_NAME}, onnx_model_greater_or_equal_int) {
auto function = onnx_import::import_onnx_model(
file_util::path_join(ov::test::utils::getExecutableDirectory(),
SERIALIZED_ZOO,
"onnx/greater_or_equal_int.onnx"));

auto test_case = ov::test::TestCase(function, s_device);

test_case.add_input<int64_t>(Shape{2}, {10, 20});
test_case.add_input<int64_t>(Shape{2}, {15, 15});
test_case.add_expected_output<bool>(Shape{2}, {false, true});

test_case.run();
}

OPENVINO_TEST(${BACKEND_NAME}, onnx_model_greater_or_equal_float) {
auto function = onnx_import::import_onnx_model(
file_util::path_join(ov::test::utils::getExecutableDirectory(),
SERIALIZED_ZOO,
"onnx/greater_or_equal_float.onnx"));

auto test_case = ov::test::TestCase(function, s_device);

test_case.add_input<float>(Shape{2}, {12.03513f, 22.03513f});
test_case.add_input<float>(Shape{2}, {5.84916f, 22.03513f});
test_case.add_expected_output<bool>(Shape{2}, {true, true});

test_case.run();
}

0 comments on commit 76e3f96

Please sign in to comment.