
Commit 807cff7

just rewrite graph to fuse bn into conv (#126)
Signed-off-by: haoshengqiang <[email protected]>
1 parent 27f0345 commit 807cff7
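
For context: in inference mode, BatchNormalization computes y = s * (x - mean) / sqrt(var + eps) + beta, so a BN that directly follows a Conv can be folded into the convolution as a per-output-channel rescaling. The identity below is the standard fusion algebra, not taken from the patch itself, and it matches the assertions removed from the test file at the bottom of this diff; s, beta, mu, sigma^2, eps are the BN scale, bias, mean, var and epsilon inputs, W and b the Conv weight and bias:

    f_c = \frac{s_c}{\sqrt{\sigma^2_c + \varepsilon}}, \qquad
    W'_c = f_c \, W_c \quad \text{(scales slice } c \text{ along dim 0)}, \qquad
    b'_c = f_c \, (b_c - \mu_c) + \beta_c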

2 files changed: +128 −114 lines changed

onnxoptimizer/passes/fuse_bn_into_conv.h (+127 −101)
@@ -33,6 +33,7 @@
 
 #include "onnx/common/assertions.h"
 #include "onnxoptimizer/pass.h"
+#include "onnxoptimizer/passes/pass_util.h"
 
 namespace ONNX_NAMESPACE {
 namespace optimization {
@@ -46,132 +47,157 @@ struct FuseBNIntoConv final : public PredicateBasedPass {
     return "fuse_bn_into_conv";
   }
 
-  void replace_inputs(Tensor& W, Tensor& b, Node* conv, Graph& graph) {
-    W.setName(ONNX_NAMESPACE::to_string(graph.getNextUnique()));
-    Value* new_W_value = graph.addInitializerAndCreateValue(W);
-    Value* old_W_value = conv->inputs()[1];
-    conv->replaceInput(1, new_W_value);
-    if (old_W_value->uses().size() == 0) {
-      graph.eraseInitializerAndInput(old_W_value);
-    }
-
-    if (conv->inputs().size() == 3) {
-      b.setName(ONNX_NAMESPACE::to_string(graph.getNextUnique()));
-      Value* new_b_value = graph.addInitializerAndCreateValue(b);
-      Value* old_b_value = conv->inputs()[2];
-      conv->replaceInput(2, new_b_value);
-      if (old_b_value->uses().size() == 0) {
-        graph.eraseInitializerAndInput(old_b_value);
-      }
-    } else {
-      Value* new_b_value = graph.addInitializerAndCreateValue(b);
-      conv->addInput(new_b_value);
-    }
-  }
-
   bool modify_conv(Node* conv, Node* bn, Graph& graph) {
     const auto& bn_inputs = bn->inputs();
     const auto& conv_inputs = conv->inputs();
-    auto end_iter = graph.initializers().end();
-    auto s_iter = graph.getInitializer(bn_inputs[1]->uniqueName());
-    auto bbn_iter = graph.getInitializer(bn_inputs[2]->uniqueName());
-    auto m_iter = graph.getInitializer(bn_inputs[3]->uniqueName());
-    auto var_iter = graph.getInitializer(bn_inputs[4]->uniqueName());
-    auto W_iter = graph.getInitializer(conv_inputs[1]->uniqueName());
-    if (s_iter == end_iter || bbn_iter == end_iter || m_iter == end_iter ||
-        var_iter == end_iter || W_iter == end_iter) {
-      return false;
-    }
 
-    ONNX_ASSERT(s_iter->sizes().size() == 1);
-    ONNX_ASSERT(bbn_iter->sizes().size() == 1 &&
-                bbn_iter->sizes()[0] == s_iter->sizes()[0]);
-    ONNX_ASSERT(m_iter->sizes().size() == 1 &&
-                m_iter->sizes()[0] == s_iter->sizes()[0]);
-    ONNX_ASSERT(var_iter->sizes().size() == 1 &&
-                var_iter->sizes()[0] == s_iter->sizes()[0]);
-    ONNX_ASSERT(W_iter->sizes().size() > 2 &&
-                W_iter->sizes()[0] == s_iter->sizes()[0]);
-    ONNX_ASSERT(s_iter->elem_type() == bbn_iter->elem_type() &&
-                s_iter->elem_type() == m_iter->elem_type() &&
-                s_iter->elem_type() == var_iter->elem_type() &&
-                s_iter->elem_type() == W_iter->elem_type());
-    if (s_iter->elem_type() != ONNX_NAMESPACE::TensorProto_DataType_FLOAT &&
-        s_iter->elem_type() != ONNX_NAMESPACE::TensorProto_DataType_DOUBLE) {
+    auto bn_scale = *FetchConstantTensor(bn_inputs[1]);
+    auto bn_bais = *FetchConstantTensor(bn_inputs[2]);
+    auto bn_mean = *FetchConstantTensor(bn_inputs[3]);
+    auto bn_var = *FetchConstantTensor(bn_inputs[4]);
+    auto conv_W = *FetchConstantTensor(conv_inputs[1]);
+    bn_scale.setName(ONNX_NAMESPACE::to_string(graph.getNextUnique()));
+    bn_bais.setName(ONNX_NAMESPACE::to_string(graph.getNextUnique()));
+    bn_mean.setName(ONNX_NAMESPACE::to_string(graph.getNextUnique()));
+    bn_var.setName(ONNX_NAMESPACE::to_string(graph.getNextUnique()));
+    conv_W.setName(ONNX_NAMESPACE::to_string(graph.getNextUnique()));
+
+    /// scale bais mean var must be the same shape (C)
+    ONNX_ASSERT(bn_scale.sizes() == bn_bais.sizes());
+    ONNX_ASSERT(bn_scale.sizes() == bn_mean.sizes());
+    ONNX_ASSERT(bn_scale.sizes() == bn_var.sizes());
+    ONNX_ASSERT(bn_scale.sizes().size() == 1);
+    int64_t C = bn_scale.sizes()[0];
+    ONNX_ASSERT(conv_W.sizes().size() > 2 && conv_W.sizes()[0] == C);
+    if (bn_scale.elem_type() != bn_bais.elem_type() ||
+        bn_scale.elem_type() != bn_mean.elem_type() ||
+        bn_scale.elem_type() != bn_var.elem_type() ||
+        bn_scale.elem_type() != conv_W.elem_type()) {
       return false;
     }
 
-    Tensor bc;
+    Value* conv_bias = nullptr;
     if (conv_inputs.size() == 3) {
-      auto bc_iter = graph.getInitializer(conv_inputs[2]->uniqueName());
-      if (bc_iter == end_iter) {
+      if (!IsConstantTensor(conv_inputs[2])) {
        return false;
      }
-      bc = *bc_iter;
-      ONNX_ASSERT(bc.sizes().size() == 1 &&
-                  bc.sizes()[0] == s_iter->sizes()[0]);
+      auto bc_t = *FetchConstantTensor(conv_inputs[2]);
+      bc_t.setName(ONNX_NAMESPACE::to_string(graph.getNextUnique()));
+      ONNX_ASSERT(bc_t.sizes() == bn_scale.sizes());
+      conv_bias = graph.addInitializerAndCreateValue(bc_t);
+    } else {
+      Tensor bc_t;
+      bc_t.elem_type() = ONNX_NAMESPACE::TensorProto_DataType_FLOAT;
+      bc_t.sizes().push_back(C);
+      for (int i = 0; i < C; ++i) {
+        bc_t.floats().push_back(float{0});
+      }
+      conv_bias = graph.addInitializerAndCreateValue(bc_t);
     }
 
-    Tensor s = *s_iter;
-    const Tensor& bbn = *bbn_iter;
-    const Tensor& m = *m_iter;
-    Tensor var = *var_iter;
-    Tensor W = *W_iter;
-    float epsilon = bn->hasAttribute(kepsilon) ? (float)bn->f(kepsilon) : 1e-5f;
-    Tensor eps;
-
-#define DO_COMPUTATION(TENSOR_TYPE, vec)                                 \
-  eps.sizes().push_back(s.sizes()[0]);                                   \
-  eps.elem_type() = ONNX_NAMESPACE::TensorProto_DataType_##TENSOR_TYPE;  \
-  for (int64_t i = 0; i < eps.sizes()[0]; ++i) {                         \
-    eps.vec().push_back(epsilon);                                        \
-  }                                                                      \
-  if (conv_inputs.size() != 3) {                                         \
-    bc.sizes().push_back(s.sizes()[0]);                                  \
-    bc.elem_type() = ONNX_NAMESPACE::TensorProto_DataType_##TENSOR_TYPE; \
-    for (int64_t i = 0; i < eps.sizes()[0]; ++i) {                       \
-      bc.vec().push_back(0.f);                                           \
-    }                                                                    \
-  }                                                                      \
-  var.add(eps);                                                          \
-  var.sqrt();                                                            \
-  s.divide(var);                                                         \
-  W.scale_by_first_dim(s);                                               \
-  bc.subtract(m);                                                        \
-  bc.multiply(s);                                                        \
-  bc.add(bbn);
-
-    switch (s.elem_type()) {
-      case ONNX_NAMESPACE::TensorProto_DataType_FLOAT: {
-        DO_COMPUTATION(FLOAT, floats)
-        break;
-      }
-      case ONNX_NAMESPACE::TensorProto_DataType_DOUBLE: {
-        DO_COMPUTATION(DOUBLE, doubles)
-        break;
+    /// scalar
+    Tensor eps_t;
+    eps_t.elem_type() = ONNX_NAMESPACE::TensorProto_DataType_FLOAT;
+    eps_t.floats().push_back(GetValueFromAttrWithDefault(bn, kepsilon, 1e-5f));
+    Value* eps = graph.addInitializerAndCreateValue(eps_t);
+
+    Node* cast = graph.create(kCast, 1);
+    cast->addInput(eps);
+    cast->i_(kto, bn_var.elem_type());
+    cast->insertBefore(conv);
+
+    Node* var_add = graph.create(kAdd, 1);
+    var_add->insertAfter(cast);
+    var_add->addInput(graph.addInitializerAndCreateValue(bn_var));
+    var_add->addInput(cast->output());
+
+    Node* sqrt = graph.create(kSqrt, 1);
+    sqrt->insertAfter(var_add);
+    sqrt->addInput(var_add->output());
+
+    Node* scale = graph.create(kDiv, 1);
+    scale->insertAfter(sqrt);
+    scale->addInput(graph.addInitializerAndCreateValue(bn_scale));
+    scale->addInput(sqrt->output());
+
+    Node* unsqueeze = graph.create(kUnsqueeze, 1);
+    unsqueeze->insertAfter(scale);
+    unsqueeze->addInput(scale->output());
+    std::vector<int64_t> insert_dims;
+    for (int i = 1; i < conv_W.sizes().size(); ++i) {
+      insert_dims.push_back(i);
+    }
+    if (getOpsetVersion(graph) > 11) {
+      Tensor shape_s_t;
+      shape_s_t.elem_type() = ONNX_NAMESPACE::TensorProto_DataType_INT64;
+      shape_s_t.sizes().push_back(insert_dims.size());
+      shape_s_t.int64s() = insert_dims;
+      unsqueeze->addInput(graph.addInitializerAndCreateValue(shape_s_t));
+    } else {
+      unsqueeze->is_(kaxes, std::move(insert_dims));
+    }
+
+    Node* mul_w = graph.create(kMul, 1);
+    mul_w->insertAfter(unsqueeze);
+    mul_w->addInput(graph.addInitializerAndCreateValue(conv_W));
+    mul_w->addInput(unsqueeze->output());
+
+    Node* cast1 = graph.create(kCast, 1);
+    cast1->insertAfter(mul_w);
+    cast1->addInput(conv_bias);
+    cast1->i_(kto, bn_mean.elem_type());
+
+    Node* sub = graph.create(kSub, 1);
+    sub->insertAfter(cast1);
+    sub->addInput(cast1->output());
+    sub->addInput(graph.addInitializerAndCreateValue(bn_mean));
+
+    Node* mul = graph.create(kMul, 1);
+    mul->insertAfter(sub);
+    mul->addInput(sub->output());
+    mul->addInput(scale->output());
+
+    Node* bias_add = graph.create(kAdd, 1);
+    bias_add->insertAfter(mul);
+    bias_add->addInput(mul->output());
+    bias_add->addInput(graph.addInitializerAndCreateValue(bn_bais));
+
+    Value* old_w_value = conv_inputs[1];
+    conv->replaceInput(1, mul_w->output());
+    if (old_w_value->uses().size() == 0) {
+      graph.eraseInitializerAndInput(old_w_value);
+    }
+
+    if (conv_inputs.size() == 3) {
+      Value* old_b_value = conv_inputs[2];
+      conv->replaceInput(2, bias_add->output());
+      if (old_b_value->uses().size() == 0) {
+        graph.eraseInitializerAndInput(old_b_value);
      }
-      default:
-        return false;
+    } else {
+      conv->addInput(bias_add->output());
    }
-#undef DO_COMPUTATION
-    replace_inputs(W, bc, conv, graph);
    return true;
  }
 
-  bool patternMatchPredicate(Node* node) override {
-    return node->kind() == kBatchNormalization &&
-           node->inputs()[0]->node()->kind() == kConv;
+  bool patternMatchPredicate(Node* n) override {
+    return CheckKind(n, kBatchNormalization, 0, kConv) &&
+           GetValueFromAttrWithDefault(n, "training_mode", (int64_t)0) == 0 &&
+           n->input(0)->uses().size() == 1 && n->outputs().size() == 1 &&
+           IsConstantTensor(n, 1) && IsConstantTensor(n, 2) &&
+           IsConstantTensor(n, 3) && IsConstantTensor(n, 4) &&
+           IsConstantTensor(PrevNode(n, 0), 1);
  }
  bool runTransform(Node* n, Graph& graph,
                    NodeDestroyType& destroy_current) override {
    Node* bn = n;
-    Node* conv = n->inputs()[0]->node();
+    Node* conv = PrevNode(n, 0);
    auto origInput = bn->inputs()[0];
-    if (origInput->uses().size() > 1 || bn->outputs().size() > 1 ||
-        !modify_conv(conv, bn, graph)) {
+    if (!modify_conv(conv, bn, graph)) {
      destroy_current = NodeDestroyType::DestroyZero;
      return false;
    }
+    // clean
    for (int i = 4; i >= 1; --i) {
      if (bn->inputs()[i]->uses().size() == 1) {
        auto input = bn->inputs()[i];
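
Rather than folding these constants in C++ (the deleted DO_COMPUTATION macro, which only handled FLOAT and DOUBLE), the pass now just rewrites the graph: it emits Cast -> Add -> Sqrt -> Div -> Unsqueeze -> Mul nodes for the weight and Cast -> Sub -> Mul -> Add nodes for the bias, rewires the Conv to their outputs, and presumably leaves the arithmetic to a later constant-folding pass. A NumPy sketch of what that emitted subgraph evaluates to (the function name is mine, purely illustrative):

    import numpy as np

    def fold_bn_params(W, b, scale, bn_bias, mean, var, eps=1e-5):
        # f = Div(scale, Sqrt(Add(var, Cast(eps)))), shape (C,)
        f = scale / np.sqrt(var + eps)
        # Unsqueeze f to (C, 1, ..., 1) so Mul broadcasts over dim 0 of W.
        new_W = W * f.reshape((-1,) + (1,) * (W.ndim - 1))
        # Bias path: Add(Mul(Sub(b, mean), f), bn_bias).
        new_b = (b - mean) * f + bn_bias
        return new_W, new_b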

onnxoptimizer/test/optimizer_test.py (+1 −13)
@@ -3061,19 +3061,7 @@ def test_fuse_bn_into_conv_simple(self): # type: () -> None
                 helper.make_tensor_value_info("Y", tensor_type, (5, 3, 24, 24))
             ],
         )
-        optimized_model = self._optimized(graph, ["fuse_bn_into_conv"])
-
-        self.assertEqual(len(optimized_model.graph.node), 1)
-        self.assertEqual(optimized_model.graph.node[0].op_type, "Conv")
-        self.assertEqual(len(optimized_model.graph.initializer), 2)
-        new_W = numpy_helper.to_array(optimized_model.graph.initializer[0])
-        new_b = numpy_helper.to_array(optimized_model.graph.initializer[1])
-
-        f = scale / np.sqrt(var + 1e-5)
-        np.testing.assert_almost_equal((B - mean) * f + b, new_b)
-        np.testing.assert_almost_equal(
-            W * f[:, np.newaxis, np.newaxis, np.newaxis], new_W
-        )
+        optimized_model = self._optimized(graph, ["fuse_bn_into_conv"]) # noqa
 
     def _internal_test_deadend_elimination(self, fixed): # type: (bool) -> None
         softmax = helper.make_node("Softmax", ["X"], ["Y"], axis=2)
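
The test now only checks that the pass runs without error; since the fusion is expressed as graph nodes rather than precomputed initializers, the old node-count and initializer assertions no longer hold as written (presumably they would first need a constant-folding step). Driving the pass from Python looks roughly like this; the model path is hypothetical:

    import onnx
    import onnxoptimizer

    # hypothetical file: any model containing a Conv -> BatchNormalization pair
    model = onnx.load("conv_bn.onnx")
    optimized = onnxoptimizer.optimize(model, ["fuse_bn_into_conv"])
    print([node.op_type for node in optimized.graph.node])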
