Skip to content

Commit

Permalink
[ONNX] Extended BatchNorm to support opsets 1-6-7-9-14-15 (openvinoto…
Browse files Browse the repository at this point in the history
…olkit#23337)

### Details:
 - *Addition of Batchnorm opset 1 6 7 9 14 & 15*
- *Creation of the tests CPP for the respective opset, and their
prototxt models.*

### Tickets:
- 20554 :
[20554](openvinotoolkit#20554)
- 18485 :
[18485](openvinotoolkit#18485)

---------

Co-authored-by: Ubuntu <[email protected]>
  • Loading branch information
Bepitic and Ubuntu authored Mar 29, 2024
1 parent 3e114be commit 34ebb77
Show file tree
Hide file tree
Showing 10 changed files with 814 additions and 1 deletion.
42 changes: 41 additions & 1 deletion src/frontends/onnx/frontend/src/op/batch_norm.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,12 @@ ov::OutputVector batch_norm(const ov::frontend::onnx::Node& node) {
OPENVINO_THROW("Cannot create OpenVINO batch norm with unsupported number of inputs");
}
} // namespace set_1
/*
Opset 6 is skipped because there are no significant difference between opset1 and opset6.
Found difference is:
1. In Training, the computation of ReduceMean and ReduceVar uses float
to avoid overflow for float16 inputs.
*/

namespace set_7 {
// This version supports ONNX BatchNormalization-7 and BatchNormalization-9
Expand All @@ -71,8 +77,42 @@ ov::OutputVector batch_norm(const ov::frontend::onnx::Node& node) {

return {std::make_shared<v5::BatchNormInference>(x, scale, bias, mean, var, epsilon)};
}

} // namespace set_7
/*
Opset 9 is skipped because there are no significant difference between opset7 and opset9.
Found difference is:
1. removed -> spatial : int (default is 1)
If true, compute the mean and variance across per activation. If false, compute the mean and variance across
per feature over each mini-batch.
*/

namespace set_14 {
// This version supports ONNX BatchNormalization-14 BatchNormalization-15
ov::OutputVector batch_norm(const ov::frontend::onnx::Node& node) {
ov::OutputVector inputs{node.get_ov_inputs()};
auto x = inputs.at(0);
auto scale = inputs.at(1);
auto bias = inputs.at(2);
auto mean = inputs.at(3);
auto var = inputs.at(4);

double epsilon{node.get_attribute_value<double>("epsilon", 1e-5)};
int64_t training_mode{node.get_attribute_value<int64_t>("training_mode", 0)};

CHECK_VALID_NODE(node,
training_mode == false && node.get_outputs_size() == 1,
"Training mode of BatchNormalization is not supported.");
return {std::make_shared<v5::BatchNormInference>(x, scale, bias, mean, var, epsilon)};
}
} // namespace set_14
/*
Opset 15 is skipped because there are no significant difference between opset14 and opset15.
Found difference is:
1. In Training, the computation of ReduceMean and ReduceVar uses float
to avoid overflow for float16 inputs.
*/

} // namespace op
} // namespace onnx
} // namespace frontend
Expand Down
5 changes: 5 additions & 0 deletions src/frontends/onnx/frontend/src/op/batch_norm.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,11 @@ namespace set_7 {
ov::OutputVector batch_norm(const ov::frontend::onnx::Node& node);

} // namespace set_7

namespace set_14 {
ov::OutputVector batch_norm(const ov::frontend::onnx::Node& node);

} // namespace set_14
} // namespace op
} // namespace onnx
} // namespace frontend
Expand Down
1 change: 1 addition & 0 deletions src/frontends/onnx/frontend/src/ops_bridge.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -360,6 +360,7 @@ OperatorsBridge::OperatorsBridge() {
REGISTER_OPERATOR("AveragePool", 1, average_pool);
REGISTER_OPERATOR("BatchNormalization", 1, batch_norm);
REGISTER_OPERATOR("BatchNormalization", 7, batch_norm);
REGISTER_OPERATOR("BatchNormalization", 14, batch_norm);
REGISTER_OPERATOR("BitShift", 1, bitshift);
REGISTER_OPERATOR("BitwiseAnd", 1, bitwise_and);
REGISTER_OPERATOR("BitwiseNot", 1, bitwise_not);
Expand Down
113 changes: 113 additions & 0 deletions src/frontends/onnx/tests/models/batchnorm_opset1.prototxt
Original file line number Diff line number Diff line change
@@ -0,0 +1,113 @@
ir_version: 3
producer_name: "OpenVINO ONNX Frontend"
graph {
node {
input: "x"
input: "s"
input: "bias"
input: "mean"
input: "var"
output: "y"
op_type: "BatchNormalization"
}
name: "test_batchnorm_example"
input {
name: "x"
type {
tensor_type {
elem_type: 1
shape {
dim {
dim_value: 1
}
dim {
dim_value: 2
}
dim {
dim_value: 1
}
dim {
dim_value: 3
}
}
}
}
}
input {
name: "s"
type {
tensor_type {
elem_type: 1
shape {
dim {
dim_value: 2
}
}
}
}
}
input {
name: "bias"
type {
tensor_type {
elem_type: 1
shape {
dim {
dim_value: 2
}
}
}
}
}
input {
name: "mean"
type {
tensor_type {
elem_type: 1
shape {
dim {
dim_value: 2
}
}
}
}
}
input {
name: "var"
type {
tensor_type {
elem_type: 1
shape {
dim {
dim_value: 2
}
}
}
}
}
output {
name: "y"
type {
tensor_type {
elem_type: 1
shape {
dim {
dim_value: 1
}
dim {
dim_value: 2
}
dim {
dim_value: 1
}
dim {
dim_value: 3
}
}
}
}
}
}
opset_import {
version: 1
}
113 changes: 113 additions & 0 deletions src/frontends/onnx/tests/models/batchnorm_opset14.prototxt
Original file line number Diff line number Diff line change
@@ -0,0 +1,113 @@
ir_version: 3
producer_name: "OpenVINO ONNX Frontend"
graph {
node {
input: "x"
input: "s"
input: "bias"
input: "mean"
input: "var"
output: "y"
op_type: "BatchNormalization"
}
name: "test_batchnorm_example"
input {
name: "x"
type {
tensor_type {
elem_type: 1
shape {
dim {
dim_value: 1
}
dim {
dim_value: 2
}
dim {
dim_value: 1
}
dim {
dim_value: 3
}
}
}
}
}
input {
name: "s"
type {
tensor_type {
elem_type: 1
shape {
dim {
dim_value: 2
}
}
}
}
}
input {
name: "bias"
type {
tensor_type {
elem_type: 1
shape {
dim {
dim_value: 2
}
}
}
}
}
input {
name: "mean"
type {
tensor_type {
elem_type: 1
shape {
dim {
dim_value: 2
}
}
}
}
}
input {
name: "var"
type {
tensor_type {
elem_type: 1
shape {
dim {
dim_value: 2
}
}
}
}
}
output {
name: "y"
type {
tensor_type {
elem_type: 1
shape {
dim {
dim_value: 1
}
dim {
dim_value: 2
}
dim {
dim_value: 1
}
dim {
dim_value: 3
}
}
}
}
}
}
opset_import {
version: 14
}
Loading

0 comments on commit 34ebb77

Please sign in to comment.