@@ -47,7 +47,7 @@ struct FuseBNIntoConv final : public PredicateBasedPass {
47
47
return " fuse_bn_into_conv" ;
48
48
}
49
49
50
- bool modify_conv (Node* conv, Node* bn, Graph& graph) {
50
+ bool modify_conv (Node* conv, Node* bn, Graph& graph, const bool is_conv ) {
51
51
const auto & bn_inputs = bn->inputs ();
52
52
const auto & conv_inputs = conv->inputs ();
53
53
@@ -123,10 +123,9 @@ struct FuseBNIntoConv final : public PredicateBasedPass {
123
123
Node* unsqueeze = graph.create (kUnsqueeze , 1 );
124
124
unsqueeze->insertAfter (scale);
125
125
unsqueeze->addInput (scale->output ());
126
- std::vector<int64_t > insert_dims;
127
- for (int i = 1 ; i < conv_W.sizes ().size (); ++i) {
128
- insert_dims.push_back (i);
129
- }
126
+ std::vector<int64_t > insert_dims (conv_W.sizes ().size ());
127
+ std::iota (insert_dims.begin (), insert_dims.end (), 0 );
128
+ insert_dims.erase (insert_dims.begin () + (is_conv ? 0 : 1 ));
130
129
if (getOpsetVersion (graph) > 11 ) {
131
130
Tensor shape_s_t ;
132
131
shape_s_t .elem_type () = ONNX_NAMESPACE::TensorProto_DataType_INT64;
@@ -181,7 +180,8 @@ struct FuseBNIntoConv final : public PredicateBasedPass {
181
180
}
182
181
183
182
bool patternMatchPredicate (Node* n) override {
184
- return CheckKind (n, kBatchNormalization , 0 , kConv ) &&
183
+ return (CheckKind (n, kBatchNormalization , 0 , kConv ) ||
184
+ CheckKind (n, kBatchNormalization , 0 , kConvTranspose )) &&
185
185
GetValueFromAttrWithDefault (n, " training_mode" , (int64_t )0 ) == 0 &&
186
186
n->input (0 )->uses ().size () == 1 && n->outputs ().size () == 1 &&
187
187
IsConstantTensor (n, 1 ) && IsConstantTensor (n, 2 ) &&
@@ -190,10 +190,12 @@ struct FuseBNIntoConv final : public PredicateBasedPass {
190
190
}
191
191
bool runTransform (Node* n, Graph& graph,
192
192
NodeDestroyType& destroy_current) override {
193
+ const bool is_conv = CheckKind (n, kBatchNormalization , 0 , kConv );
194
+
193
195
Node* bn = n;
194
196
Node* conv = PrevNode (n, 0 );
195
197
auto origInput = bn->inputs ()[0 ];
196
- if (!modify_conv (conv, bn, graph)) {
198
+ if (!modify_conv (conv, bn, graph, is_conv )) {
197
199
destroy_current = NodeDestroyType::DestroyZero;
198
200
return false ;
199
201
}
0 commit comments