Skip to content

Commit 5d4d388

Browse files
author
wenyuchi.wyc
committed
Support fusing BN into ConvTranspose.
Signed-off-by: wenyuchi.wyc <[email protected]>
1 parent 807cff7 commit 5d4d388

File tree

2 files changed

+49
-7
lines changed

2 files changed

+49
-7
lines changed

onnxoptimizer/passes/fuse_bn_into_conv.h

+9-7
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,7 @@ struct FuseBNIntoConv final : public PredicateBasedPass {
4747
return "fuse_bn_into_conv";
4848
}
4949

50-
bool modify_conv(Node* conv, Node* bn, Graph& graph) {
50+
bool modify_conv(Node* conv, Node* bn, Graph& graph, const bool is_conv) {
5151
const auto& bn_inputs = bn->inputs();
5252
const auto& conv_inputs = conv->inputs();
5353

@@ -123,10 +123,9 @@ struct FuseBNIntoConv final : public PredicateBasedPass {
123123
Node* unsqueeze = graph.create(kUnsqueeze, 1);
124124
unsqueeze->insertAfter(scale);
125125
unsqueeze->addInput(scale->output());
126-
std::vector<int64_t> insert_dims;
127-
for (int i = 1; i < conv_W.sizes().size(); ++i) {
128-
insert_dims.push_back(i);
129-
}
126+
std::vector<int64_t> insert_dims(conv_W.sizes().size());
127+
std::iota(insert_dims.begin(), insert_dims.end(), 0);
128+
insert_dims.erase(insert_dims.begin() + (is_conv ? 0 : 1));
130129
if (getOpsetVersion(graph) > 11) {
131130
Tensor shape_s_t;
132131
shape_s_t.elem_type() = ONNX_NAMESPACE::TensorProto_DataType_INT64;
@@ -181,7 +180,8 @@ struct FuseBNIntoConv final : public PredicateBasedPass {
181180
}
182181

183182
bool patternMatchPredicate(Node* n) override {
184-
return CheckKind(n, kBatchNormalization, 0, kConv) &&
183+
return (CheckKind(n, kBatchNormalization, 0, kConv) ||
184+
CheckKind(n, kBatchNormalization, 0, kConvTranspose)) &&
185185
GetValueFromAttrWithDefault(n, "training_mode", (int64_t)0) == 0 &&
186186
n->input(0)->uses().size() == 1 && n->outputs().size() == 1 &&
187187
IsConstantTensor(n, 1) && IsConstantTensor(n, 2) &&
@@ -190,10 +190,12 @@ struct FuseBNIntoConv final : public PredicateBasedPass {
190190
}
191191
bool runTransform(Node* n, Graph& graph,
192192
NodeDestroyType& destroy_current) override {
193+
const bool is_conv = CheckKind(n, kBatchNormalization, 0, kConv);
194+
193195
Node* bn = n;
194196
Node* conv = PrevNode(n, 0);
195197
auto origInput = bn->inputs()[0];
196-
if (!modify_conv(conv, bn, graph)) {
198+
if (!modify_conv(conv, bn, graph, is_conv)) {
197199
destroy_current = NodeDestroyType::DestroyZero;
198200
return false;
199201
}

onnxoptimizer/test/optimizer_test.py

+40
Original file line numberDiff line numberDiff line change
@@ -3063,6 +3063,46 @@ def test_fuse_bn_into_conv_simple(self): # type: () -> None
30633063
)
30643064
optimized_model = self._optimized(graph, ["fuse_bn_into_conv"]) # noqa
30653065

3066+
def test_fuse_bn_into_conv_transpose_simple(self): # type: () -> None
3067+
for (tensor_type, np_type) in [(TensorProto.FLOAT, np.float32)]:
3068+
conv = helper.make_node("ConvTranspose", ["X", "W", "B"], ["Y"])
3069+
bn = helper.make_node(
3070+
"BatchNormalization", ["Y", "scale", "b", "mean", "var"], ["Z"]
3071+
)
3072+
3073+
W = np.random.randn(64, 64, 2, 2).astype(np_type) + 2
3074+
B = np.random.randn(64,).astype(np_type) + 2
3075+
scale = np.random.randn(64,).astype(np_type) + 2
3076+
b = np.random.randn(64,).astype(np_type) + 2
3077+
mean = np.random.randn(64,).astype(np_type) + 2
3078+
var = np.abs(np.random.randn(64,).astype(np_type)) + 2
3079+
3080+
initializers = [
3081+
helper.make_tensor(
3082+
name, tensor_type, npa.shape, npa.tobytes(), raw=True
3083+
)
3084+
for name, npa in [
3085+
("W", W),
3086+
("B", B),
3087+
("scale", scale),
3088+
("b", b),
3089+
("mean", mean),
3090+
("var", var),
3091+
]
3092+
]
3093+
graph = helper.make_graph(
3094+
[conv, bn],
3095+
"test",
3096+
[helper.make_tensor_value_info("X", tensor_type, (1, 64, 160, 160))],
3097+
[helper.make_tensor_value_info("Z", tensor_type, (1, 64, 320, 320))],
3098+
initializer=initializers,
3099+
value_info=[
3100+
helper.make_tensor_value_info("Y", tensor_type, (1, 64, 320, 320))
3101+
],
3102+
)
3103+
3104+
optimized_model = self._optimized(graph, ["fuse_bn_into_conv"])
3105+
30663106
def _internal_test_deadend_elimination(self, fixed): # type: (bool) -> None
30673107
softmax = helper.make_node("Softmax", ["X"], ["Y"], axis=2)
30683108
log = helper.make_node("Log", ["Y"], ["Z"])

0 commit comments

Comments
 (0)