#include "onnx/common/assertions.h"
#include "onnxoptimizer/pass.h"
+#include "onnxoptimizer/passes/pass_util.h"

namespace ONNX_NAMESPACE {
namespace optimization {
@@ -46,132 +47,157 @@ struct FuseBNIntoConv final : public PredicateBasedPass {
    return "fuse_bn_into_conv";
  }

-  void replace_inputs(Tensor& W, Tensor& b, Node* conv, Graph& graph) {
-    W.setName(ONNX_NAMESPACE::to_string(graph.getNextUnique()));
-    Value* new_W_value = graph.addInitializerAndCreateValue(W);
-    Value* old_W_value = conv->inputs()[1];
-    conv->replaceInput(1, new_W_value);
-    if (old_W_value->uses().size() == 0) {
-      graph.eraseInitializerAndInput(old_W_value);
-    }
-
-    if (conv->inputs().size() == 3) {
-      b.setName(ONNX_NAMESPACE::to_string(graph.getNextUnique()));
-      Value* new_b_value = graph.addInitializerAndCreateValue(b);
-      Value* old_b_value = conv->inputs()[2];
-      conv->replaceInput(2, new_b_value);
-      if (old_b_value->uses().size() == 0) {
-        graph.eraseInitializerAndInput(old_b_value);
-      }
-    } else {
-      Value* new_b_value = graph.addInitializerAndCreateValue(b);
-      conv->addInput(new_b_value);
-    }
-  }
-
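+  /// Folds BatchNormalization into the preceding Conv: with the per-channel
+  /// multiplier s = scale / sqrt(var + eps), the rewrite sets
+  ///   W' = W * s            (broadcast over output channels)
+  ///   b' = (b - mean) * s + bn_bias
+  /// so that Conv(X, W', b') == BN(Conv(X, W, b)).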
  bool modify_conv(Node* conv, Node* bn, Graph& graph) {
    const auto& bn_inputs = bn->inputs();
    const auto& conv_inputs = conv->inputs();
-    auto end_iter = graph.initializers().end();
-    auto s_iter = graph.getInitializer(bn_inputs[1]->uniqueName());
-    auto bbn_iter = graph.getInitializer(bn_inputs[2]->uniqueName());
-    auto m_iter = graph.getInitializer(bn_inputs[3]->uniqueName());
-    auto var_iter = graph.getInitializer(bn_inputs[4]->uniqueName());
-    auto W_iter = graph.getInitializer(conv_inputs[1]->uniqueName());
-    if (s_iter == end_iter || bbn_iter == end_iter || m_iter == end_iter ||
-        var_iter == end_iter || W_iter == end_iter) {
-      return false;
-    }

-    ONNX_ASSERT(s_iter->sizes().size() == 1);
-    ONNX_ASSERT(bbn_iter->sizes().size() == 1 &&
-                bbn_iter->sizes()[0] == s_iter->sizes()[0]);
-    ONNX_ASSERT(m_iter->sizes().size() == 1 &&
-                m_iter->sizes()[0] == s_iter->sizes()[0]);
-    ONNX_ASSERT(var_iter->sizes().size() == 1 &&
-                var_iter->sizes()[0] == s_iter->sizes()[0]);
-    ONNX_ASSERT(W_iter->sizes().size() > 2 &&
-                W_iter->sizes()[0] == s_iter->sizes()[0]);
-    ONNX_ASSERT(s_iter->elem_type() == bbn_iter->elem_type() &&
-                s_iter->elem_type() == m_iter->elem_type() &&
-                s_iter->elem_type() == var_iter->elem_type() &&
-                s_iter->elem_type() == W_iter->elem_type());
-    if (s_iter->elem_type() != ONNX_NAMESPACE::TensorProto_DataType_FLOAT &&
-        s_iter->elem_type() != ONNX_NAMESPACE::TensorProto_DataType_DOUBLE) {
+    auto bn_scale = *FetchConstantTensor(bn_inputs[1]);
+    auto bn_bais = *FetchConstantTensor(bn_inputs[2]);
+    auto bn_mean = *FetchConstantTensor(bn_inputs[3]);
+    auto bn_var = *FetchConstantTensor(bn_inputs[4]);
+    auto conv_W = *FetchConstantTensor(conv_inputs[1]);
+    bn_scale.setName(ONNX_NAMESPACE::to_string(graph.getNextUnique()));
+    bn_bais.setName(ONNX_NAMESPACE::to_string(graph.getNextUnique()));
+    bn_mean.setName(ONNX_NAMESPACE::to_string(graph.getNextUnique()));
+    bn_var.setName(ONNX_NAMESPACE::to_string(graph.getNextUnique()));
+    conv_W.setName(ONNX_NAMESPACE::to_string(graph.getNextUnique()));
+
+    /// scale, bias, mean and var must all have the same shape (C)
+    ONNX_ASSERT(bn_scale.sizes() == bn_bais.sizes());
+    ONNX_ASSERT(bn_scale.sizes() == bn_mean.sizes());
+    ONNX_ASSERT(bn_scale.sizes() == bn_var.sizes());
+    ONNX_ASSERT(bn_scale.sizes().size() == 1);
+    int64_t C = bn_scale.sizes()[0];
+    ONNX_ASSERT(conv_W.sizes().size() > 2 && conv_W.sizes()[0] == C);
+    if (bn_scale.elem_type() != bn_bais.elem_type() ||
+        bn_scale.elem_type() != bn_mean.elem_type() ||
+        bn_scale.elem_type() != bn_var.elem_type() ||
+        bn_scale.elem_type() != conv_W.elem_type()) {
      return false;
    }

-    Tensor bc;
+    Value* conv_bias = nullptr;
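+    /// reuse Conv's own bias if it is a constant tensor; otherwise
+    /// synthesize a zero bias of length C so the BN shift can be folded in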
    if (conv_inputs.size() == 3) {
-      auto bc_iter = graph.getInitializer(conv_inputs[2]->uniqueName());
-      if (bc_iter == end_iter) {
+      if (!IsConstantTensor(conv_inputs[2])) {
        return false;
      }
-      bc = *bc_iter;
-      ONNX_ASSERT(bc.sizes().size() == 1 &&
-                  bc.sizes()[0] == s_iter->sizes()[0]);
+      auto bc_t = *FetchConstantTensor(conv_inputs[2]);
+      bc_t.setName(ONNX_NAMESPACE::to_string(graph.getNextUnique()));
+      ONNX_ASSERT(bc_t.sizes() == bn_scale.sizes());
+      conv_bias = graph.addInitializerAndCreateValue(bc_t);
+    } else {
+      Tensor bc_t;
+      bc_t.elem_type() = ONNX_NAMESPACE::TensorProto_DataType_FLOAT;
+      bc_t.sizes().push_back(C);
+      for (int i = 0; i < C; ++i) {
+        bc_t.floats().push_back(float{0});
+      }
+      conv_bias = graph.addInitializerAndCreateValue(bc_t);
    }

-    Tensor s = *s_iter;
-    const Tensor& bbn = *bbn_iter;
-    const Tensor& m = *m_iter;
-    Tensor var = *var_iter;
-    Tensor W = *W_iter;
-    float epsilon = bn->hasAttribute(kepsilon) ? (float)bn->f(kepsilon) : 1e-5f;
-    Tensor eps;
-
-#define DO_COMPUTATION(TENSOR_TYPE, vec)                                 \
-  eps.sizes().push_back(s.sizes()[0]);                                   \
-  eps.elem_type() = ONNX_NAMESPACE::TensorProto_DataType_##TENSOR_TYPE;  \
-  for (int64_t i = 0; i < eps.sizes()[0]; ++i) {                         \
-    eps.vec().push_back(epsilon);                                        \
-  }                                                                      \
-  if (conv_inputs.size() != 3) {                                         \
-    bc.sizes().push_back(s.sizes()[0]);                                  \
-    bc.elem_type() = ONNX_NAMESPACE::TensorProto_DataType_##TENSOR_TYPE; \
-    for (int64_t i = 0; i < eps.sizes()[0]; ++i) {                       \
-      bc.vec().push_back(0.f);                                           \
-    }                                                                    \
-  }                                                                      \
-  var.add(eps);                                                          \
-  var.sqrt();                                                            \
-  s.divide(var);                                                         \
-  W.scale_by_first_dim(s);                                               \
-  bc.subtract(m);                                                        \
-  bc.multiply(s);                                                        \
-  bc.add(bbn);
-
-    switch (s.elem_type()) {
-      case ONNX_NAMESPACE::TensorProto_DataType_FLOAT: {
-        DO_COMPUTATION(FLOAT, floats)
-        break;
-      }
-      case ONNX_NAMESPACE::TensorProto_DataType_DOUBLE: {
-        DO_COMPUTATION(DOUBLE, doubles)
-        break;
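+    /// the folded arithmetic is materialized as elementwise nodes in front
+    /// of the Conv rather than computed in place, which also lifts the old
+    /// float/double-only restriction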
+    /// scalar
+    Tensor eps_t;
+    eps_t.elem_type() = ONNX_NAMESPACE::TensorProto_DataType_FLOAT;
+    eps_t.floats().push_back(GetValueFromAttrWithDefault(bn, kepsilon, 1e-5f));
+    Value* eps = graph.addInitializerAndCreateValue(eps_t);
+
+    Node* cast = graph.create(kCast, 1);
+    cast->addInput(eps);
+    cast->i_(kto, bn_var.elem_type());
+    cast->insertBefore(conv);
+
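+    /// per-channel multiplier s = scale / sqrt(var + eps)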
+    Node* var_add = graph.create(kAdd, 1);
+    var_add->insertAfter(cast);
+    var_add->addInput(graph.addInitializerAndCreateValue(bn_var));
+    var_add->addInput(cast->output());
+
+    Node* sqrt = graph.create(kSqrt, 1);
+    sqrt->insertAfter(var_add);
+    sqrt->addInput(var_add->output());
+
+    Node* scale = graph.create(kDiv, 1);
+    scale->insertAfter(sqrt);
+    scale->addInput(graph.addInitializerAndCreateValue(bn_scale));
+    scale->addInput(sqrt->output());
+
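+    /// reshape s from (C) to (C, 1, ..., 1) so it broadcasts over the kernel
+    /// dims of W; newer opsets take the Unsqueeze axes as an input rather
+    /// than an attribute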
+    Node* unsqueeze = graph.create(kUnsqueeze, 1);
+    unsqueeze->insertAfter(scale);
+    unsqueeze->addInput(scale->output());
+    std::vector<int64_t> insert_dims;
+    for (int i = 1; i < conv_W.sizes().size(); ++i) {
+      insert_dims.push_back(i);
+    }
+    if (getOpsetVersion(graph) > 11) {
+      Tensor shape_s_t;
+      shape_s_t.elem_type() = ONNX_NAMESPACE::TensorProto_DataType_INT64;
+      shape_s_t.sizes().push_back(insert_dims.size());
+      shape_s_t.int64s() = insert_dims;
+      unsqueeze->addInput(graph.addInitializerAndCreateValue(shape_s_t));
+    } else {
+      unsqueeze->is_(kaxes, std::move(insert_dims));
+    }
+
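+    /// W' = W * s, scaling each output channel of the weight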
+    Node* mul_w = graph.create(kMul, 1);
+    mul_w->insertAfter(unsqueeze);
+    mul_w->addInput(graph.addInitializerAndCreateValue(conv_W));
+    mul_w->addInput(unsqueeze->output());
+
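+    /// b' = (b - mean) * s + bn_bias; the Cast keeps the bias in the same
+    /// element type as the BN constants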
+    Node* cast1 = graph.create(kCast, 1);
+    cast1->insertAfter(mul_w);
+    cast1->addInput(conv_bias);
+    cast1->i_(kto, bn_mean.elem_type());
+
+    Node* sub = graph.create(kSub, 1);
+    sub->insertAfter(cast1);
+    sub->addInput(cast1->output());
+    sub->addInput(graph.addInitializerAndCreateValue(bn_mean));
+
+    Node* mul = graph.create(kMul, 1);
+    mul->insertAfter(sub);
+    mul->addInput(sub->output());
+    mul->addInput(scale->output());
+
+    Node* bias_add = graph.create(kAdd, 1);
+    bias_add->insertAfter(mul);
+    bias_add->addInput(mul->output());
+    bias_add->addInput(graph.addInitializerAndCreateValue(bn_bais));
+
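+    /// rewire Conv onto the folded weight and bias, erasing initializers
+    /// that are no longer referenced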
+    Value* old_w_value = conv_inputs[1];
+    conv->replaceInput(1, mul_w->output());
+    if (old_w_value->uses().size() == 0) {
+      graph.eraseInitializerAndInput(old_w_value);
+    }
+
+    if (conv_inputs.size() == 3) {
+      Value* old_b_value = conv_inputs[2];
+      conv->replaceInput(2, bias_add->output());
+      if (old_b_value->uses().size() == 0) {
+        graph.eraseInitializerAndInput(old_b_value);
      }
-      default:
-        return false;
+    } else {
+      conv->addInput(bias_add->output());
    }
-#undef DO_COMPUTATION
-    replace_inputs(W, bc, conv, graph);
    return true;
  }

-  bool patternMatchPredicate(Node* node) override {
-    return node->kind() == kBatchNormalization &&
-           node->inputs()[0]->node()->kind() == kConv;
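+  /// fuse only inference-mode BN (training_mode == 0) that is the sole
+  /// consumer of a Conv, has a single output, and whose parameters and the
+  /// Conv weight are all constant tensors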
+  bool patternMatchPredicate(Node* n) override {
+    return CheckKind(n, kBatchNormalization, 0, kConv) &&
+           GetValueFromAttrWithDefault(n, "training_mode", (int64_t)0) == 0 &&
+           n->input(0)->uses().size() == 1 && n->outputs().size() == 1 &&
+           IsConstantTensor(n, 1) && IsConstantTensor(n, 2) &&
+           IsConstantTensor(n, 3) && IsConstantTensor(n, 4) &&
+           IsConstantTensor(PrevNode(n, 0), 1);
  }
  bool runTransform(Node* n, Graph& graph,
                    NodeDestroyType& destroy_current) override {
    Node* bn = n;
-    Node* conv = n->inputs()[0]->node();
+    Node* conv = PrevNode(n, 0);
    auto origInput = bn->inputs()[0];
-    if (origInput->uses().size() > 1 || bn->outputs().size() > 1 ||
-        !modify_conv(conv, bn, graph)) {
+    if (!modify_conv(conv, bn, graph)) {
      destroy_current = NodeDestroyType::DestroyZero;
      return false;
    }
+    // clean
    for (int i = 4; i >= 1; --i) {
      if (bn->inputs()[i]->uses().size() == 1) {
        auto input = bn->inputs()[i];