Skip to content

Commit

Permalink
add code from onnx::pad
Browse files Browse the repository at this point in the history
  • Loading branch information
siddhant-0707 committed Mar 29, 2024
1 parent 6c8c132 commit 602618c
Show file tree
Hide file tree
Showing 4 changed files with 142 additions and 8 deletions.
66 changes: 61 additions & 5 deletions src/frontends/onnx/frontend/src/op/com.microsoft/pad.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,31 @@
#include "op/com.microsoft/pad.hpp"

#include "exceptions.hpp"
#include "op/pad.hpp"
#include "onnx_import/core/null_node.hpp"
#include "openvino/op/constant.hpp"
#include "openvino/op/pad.hpp"
#include "openvino/op/squeeze.hpp"
#include "openvino/op/util/op_types.hpp"
#include "ov_models/ov_builders/split.hpp"
#include "utils/reshape.hpp"

namespace {
ov::op::PadMode get_pad_mode(std::string mode) {
ov::op::PadMode pad_mode;

if (mode == "constant") {
pad_mode = ov::op::PadMode::CONSTANT;
} else if (mode == "reflect") {
pad_mode = ov::op::PadMode::REFLECT;
} else if (mode == "edge") {
pad_mode = ov::op::PadMode::EDGE;
} else {
OPENVINO_THROW("Unsupported padding mode: [" + mode + "]");
}

return pad_mode;
}
} // namespace

using namespace ov::op;

Expand All @@ -16,11 +39,44 @@ namespace op {
namespace custom {
namespace set_1 {
ov::OutputVector pad(const Node& node) {
auto node_1 = node;
node_1.get_ng_inputs()[1] = std::make_shared<v0::Squeeze>(node_1.get_ng_inputs()[1]);
auto result = set_11::pad(node_1).at(0);
const auto inputs = node.get_ng_inputs();
const auto& data = inputs[0];
const auto& pads_input = inputs[1];
auto pads = pads_input;
if (pads.get_shape().size() == 2) {
pads = std::make_shared<v0::Squeeze>(pads);
}
ov::Output<ov::Node> values;
ov::Output<ov::Node> padding_begin;
ov::Output<ov::Node> padding_end;

if (inputs.size() == 3 && !ov::op::util::is_null(inputs[2])) {
values = reshape::interpret_as_scalar(inputs[2]);
} else {
values = v0::Constant::create(data.get_element_type(), ov::Shape{}, {0});
}

if (ov::op::util::is_constant(pads.get_node())) {
std::vector<std::int64_t> pads_vector =
ov::as_type_ptr<v0::Constant>(pads.get_node_shared_ptr())->get_vector<std::int64_t>();

std::size_t const half_size = pads_vector.size() / 2;
std::vector<std::int64_t> padding_begin_values(pads_vector.begin(), pads_vector.begin() + half_size);
std::vector<std::int64_t> padding_end_values(pads_vector.begin() + half_size, pads_vector.end());

padding_begin = v0::Constant::create(ov::element::i64, ov::Shape{half_size}, padding_begin_values);
padding_end = v0::Constant::create(ov::element::i64, ov::Shape{half_size}, padding_end_values);
} else {
ov::OutputVector padding = ov::op::util::split(pads, 2, 0);

padding_begin = padding.at(0);
padding_end = padding.at(1);
}

const std::string mode = node.get_attribute_value<std::string>("mode", "constant");
ov::op::PadMode pad_mode = get_pad_mode(mode);

return {result};
return {std::make_shared<v12::Pad>(data, padding_begin, padding_end, values, pad_mode)};
}
} // namespace set_1
} // namespace custom
Expand Down
66 changes: 66 additions & 0 deletions src/frontends/onnx/tests/models/com.microsoft/pad_1d.prototxt
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
ir_version: 6
producer_name: "OV ONNX Frontend"
graph {
node {
input: "x"
input: "pads"
output: "y"
op_type: "Pad"
attribute {
name: "mode"
s: "constant"
type: STRING
}
domain: "com.microsoft"
}
name: "test_pad_1d_microsoft"
input {
name: "x"
type {
tensor_type {
elem_type: 1
shape {
dim {
dim_value: 3
}
dim {
dim_value: 2
}
}
}
}
}
input {
name: "pads"
type {
tensor_type {
elem_type: 7
shape {
dim {
dim_value: 4
}
}
}
}
}
output {
name: "y"
type {
tensor_type {
elem_type: 1
shape {
dim {
dim_value: 3
}
dim {
dim_value: 4
}
}
}
}
}
}
opset_import {
domain: "com.microsoft"
version: 1
}
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ graph {
name: "pads"
type {
tensor_type {
elem_type: 1
elem_type: 7
shape {
dim {
dim_value: 1
Expand Down
16 changes: 14 additions & 2 deletions src/frontends/onnx/tests/onnx_import_com_microsoft.in.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1263,12 +1263,24 @@ OPENVINO_TEST(${BACKEND_NAME}, onnx_com_microsoft_trilu_lower) {
// clang-format on
}

OPENVINO_TEST(${BACKEND_NAME}, onnx_com_microsoft_pad) {
OPENVINO_TEST(${BACKEND_NAME}, onnx_com_microsoft_pad_2d) {
const auto model = convert_model("com.microsoft/pad_2d.onnx");
auto test_case = ov::test::TestCase(model, s_device);

test_case.add_input<float>({1.f, 1.2f, 2.3f, 3.4f, 4.5f, 5.7f});
test_case.add_input<float>(Shape{1, 4}, {0.f, 2.f, 0.f, 0.f});
test_case.add_input<int64_t>({0, 2, 0, 0});
test_case.add_expected_output<float>(Shape{3, 4},
{0.f, 0.f, 1.f, 1.2f, 0.f, 0.f, 2.3f, 3.4f, 0.f, 0.f, 4.5f, 5.7f});

test_case.run();
}

OPENVINO_TEST(${BACKEND_NAME}, onnx_com_microsoft_pad_1d) {
const auto model = convert_model("com.microsoft/pad_1d.onnx");
auto test_case = ov::test::TestCase(model, s_device);

test_case.add_input<float>({1.f, 1.2f, 2.3f, 3.4f, 4.5f, 5.7f});
test_case.add_input<int64_t>({0, 2, 0, 0});
test_case.add_expected_output<float>(Shape{3, 4},
{0.f, 0.f, 1.f, 1.2f, 0.f, 0.f, 2.3f, 3.4f, 0.f, 0.f, 4.5f, 5.7f});

Expand Down

0 comments on commit 602618c

Please sign in to comment.