Skip to content

Commit b19a9c5

Browse files
XiaobingSuper authored and EikanWang committed
change batch_norm scheme from aten::batch_norm to ipex::batch_norm to disable TE fusion path (#404)
* change batch_norm scheme from aten::batch_norm to ipex::batch_norm to disable TE fusion path
* remove jira link
1 parent d8cd254 commit b19a9c5

File tree

6 files changed

+76
-17
lines changed

6 files changed

+76
-17
lines changed

tests/cpu/test_ipex_optimize.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ def test_optimize_parameters_behavior(self):
2626
x = torch.randn(1, 3, 224, 224)
2727
traced_model = torch.jit.trace(opt_M, x)
2828
trace_graph = traced_model.graph_for(x)
29-
self.assertTrue(any(n.kind() == "aten::batch_norm" for n in trace_graph.nodes()))
29+
self.assertTrue(any(n.kind() == "ipex::batch_norm" for n in trace_graph.nodes()))
3030
# TODO check weight_prepack.
3131

3232
def test_optimize_bf16_model(self):

tests/cpu/test_jit.py

+29-10
Original file line numberDiff line numberDiff line change
@@ -509,6 +509,14 @@ def __init__(self, dim=-1):
509509
def forward(self, x):
510510
return self.softmax(x)
511511

512+
class AtenBatchNormRepalce(nn.Module):
513+
def __init__(self):
514+
super(AtenBatchNormRepalce, self).__init__()
515+
self.bn = torch.nn.BatchNorm2d(10)
516+
517+
def forward(self, x):
518+
return self.bn(x)
519+
512520
class AddLayerNorm(torch.nn.Module):
513521
def __init__(self, dim=32):
514522
super(AddLayerNorm, self).__init__()
@@ -925,35 +933,35 @@ def test_output_conv_bn_2d(self):
925933
ConvBatchNorm_Fixed(2, 3, 32, kernel_size=3, stride=1),
926934
torch.randn(32, 3, 64, 64),
927935
kind_in_graph="ipex_prepack::convolution_run",
928-
kind_not_in_graph="aten::batch_norm",
936+
kind_not_in_graph="ipex::batch_norm",
929937
levels=['O1'])
930938
self._test_output_bf16(
931939
ConvBatchNorm_Fixed(2, 3, 32, kernel_size=3, stride=1),
932940
torch.randn(32, 3, 64, 64),
933941
kind_in_graph="ipex_prepack::convolution_run",
934-
kind_not_in_graph="aten::batch_norm",
942+
kind_not_in_graph="ipex::batch_norm",
935943
prec=0.02,
936944
levels=['O1'])
937945

938946
def test_output_bn_conv_2d(self):
939947
self._test_output(
940948
BatchNormConv_Fixed(2, 3, 32, kernel_size=3, stride=1),
941949
torch.randn(32, 3, 64, 64),
942-
kind_in_graph="aten::batch_norm",
950+
kind_in_graph="ipex::batch_norm",
943951
kind_not_in_graph=None)
944952

945953
def test_output_bn_conv_bn(self):
946954
self._test_output(
947955
BatchNorm_Conv_BatchNorm(2, 3, 32, kernel_size=3, stride=1),
948956
torch.randn(32, 3, 64, 64),
949-
kind_in_graph="aten::batch_norm",
957+
kind_in_graph="ipex::batch_norm",
950958
kind_not_in_graph=None)
951959

952960
def test_output_conv_reshape_bn_2d(self):
953961
self._test_output(
954962
ConvReshapeBatchNorm(2, 3, 32, (64, 16, 62, 62), kernel_size=3, stride=1),
955963
torch.randn(32, 3, 64, 64),
956-
kind_in_graph="aten::batch_norm",
964+
kind_in_graph="ipex::batch_norm",
957965
kind_not_in_graph=None)
958966

959967
def test_output_conv_conv_concate(self):
@@ -994,7 +1002,7 @@ def test_output_conv_bn_3d(self):
9941002
ConvBatchNorm_Fixed(3, 3, 32, kernel_size=3, stride=1),
9951003
torch.randn(32, 3, 32, 32, 32),
9961004
kind_in_graph="aten::conv3d",
997-
kind_not_in_graph="aten::batch_norm")
1005+
kind_not_in_graph="ipex::batch_norm")
9981006

9991007
def test_output_conv_relu_2d(self):
10001008
self._test_output(
@@ -1061,25 +1069,25 @@ def test_output_cascaded_conv_bn_sum_relu_2d(self):
10611069
CascadedConvBnSumRelu(2, 3, 64, 32, kernel_size=3, stride=1),
10621070
torch.rand(32, 3, 64, 64),
10631071
kind_in_graph="ipex_prepack::convolution_add_relu_run",
1064-
kind_not_in_graph="aten::batch_norm")
1072+
kind_not_in_graph="ipex::batch_norm")
10651073
self._test_output_bf16(
10661074
CascadedConvBnSumRelu(2, 3, 64, 32, kernel_size=3, stride=1),
10671075
torch.rand(32, 3, 64, 64),
10681076
kind_in_graph="ipex_prepack::convolution_add_relu_run",
1069-
kind_not_in_graph="aten::batch_norm",
1077+
kind_not_in_graph="ipex::batch_norm",
10701078
prec=0.02)
10711079

10721080
def test_output_cascaded_conv_bn_sum_relu_3d(self):
10731081
self._test_output(
10741082
CascadedConvBnSumRelu(3, 3, 64, 32, kernel_size=3, stride=1),
10751083
torch.rand(32, 3, 32, 32, 32),
10761084
kind_in_graph="ipex::conv3d_sum_relu",
1077-
kind_not_in_graph="aten::batch_norm")
1085+
kind_not_in_graph="ipex::batch_norm")
10781086
self._test_output_bf16(
10791087
CascadedConvBnSumRelu(3, 3, 64, 32, kernel_size=3, stride=1),
10801088
torch.rand(32, 3, 32, 32, 32),
10811089
kind_in_graph="ipex::conv3d_sum_relu",
1082-
kind_not_in_graph="aten::batch_norm",
1090+
kind_not_in_graph="ipex::batch_norm",
10831091
prec=0.02)
10841092

10851093
def test_output_conv_transpose2d(self):
@@ -1346,6 +1354,17 @@ def test_ipex_softmax(self):
13461354
kind_in_graph="ipex::softmax",
13471355
prec=5e-3)
13481356

1357+
def test_ipex_batch_norm(self):
1358+
self._test_output(
1359+
AtenBatchNormRepalce(),
1360+
torch.rand(10, 10, 4, 4),
1361+
kind_in_graph="ipex::batch_norm")
1362+
self._test_output_bf16(
1363+
AtenBatchNormRepalce(),
1364+
torch.rand(10, 10, 4, 4, dtype=torch.bfloat16),
1365+
kind_in_graph="ipex::batch_norm",
1366+
prec=5e-3)
1367+
13491368
def test_restore_inplace(self):
13501369
class M(nn.Module):
13511370
def __init__(self, eltwise_fn, params_dict={}):

torch_ipex/csrc/jit/fusion_pass.cpp

+3
Original file line numberDiff line numberDiff line change
@@ -342,6 +342,9 @@ void IPEXFusionPass(std::shared_ptr<Graph>& graph) {
342342
// replace aten::softmax with ipex::softmax
343343
graph_rewrite::replaceAtenSoftmaxWithIpexSoftmax(graph);
344344

345+
// Replace aten::batch_norm with ipex::batch_norm; this pass will be removed
346+
// once TensorExprs fixes the performance issue (IPB-808).
347+
graph_rewrite::replaceAtenBatchNormWithIpexBatchNorm(graph);
345348
// TODO: Some post processing?? ECS/EDC/Peephole???
346349
ConstantPropagation(graph);
347350
}

torch_ipex/csrc/jit/graph_rewrite.cpp

+18-3
Original file line numberDiff line numberDiff line change
@@ -320,7 +320,7 @@ void FuseAddLayerNorm(std::shared_ptr<Graph>& graph) {
320320
graph(%add_a, %add_b, %alpha, %shape:int[], %w, %b, %eps:float, %cudnn_enable:bool):
321321
%r = ipex::add_layernorm(%add_a, %add_b, %alpha, %shape, %w, %b, %eps, %cudnn_enable)
322322
return (%r) )";
323-
SubgraphRewriter rewriter_aten;
323+
IpexSubgraphRewriter rewriter_aten;
324324
rewriter_aten.RegisterRewritePattern(aten_add_layernorm, fused_add_layernorm);
325325
rewriter_aten.runOnGraph(graph);
326326
}
@@ -346,7 +346,7 @@ void FuseMHAScoreCalc(std::shared_ptr<Graph>& graph) {
346346
%scores = ipex::mha_scores_calc(%q, %k, %relative_qk, %alpha, %dim_per_head, %softmax_dim, %dtype)
347347
return (%scores) )";
348348

349-
SubgraphRewriter mha_fusion;
349+
IpexSubgraphRewriter mha_fusion;
350350
mha_fusion.RegisterRewritePattern(
351351
div_matmul_add_softmax, div_matmul_add_softmax_fusion);
352352
mha_fusion.RegisterRewritePattern(
@@ -384,7 +384,22 @@ void replaceAtenSoftmaxWithIpexSoftmax(std::shared_ptr<Graph>& graph) {
384384
rewriter_aten.runOnGraph(graph);
385385
}
386386

387-
void replaceEmbeddingBagWithQEmbeddingBag(std::shared_ptr<Graph> &graph) {
387+
void replaceAtenBatchNormWithIpexBatchNorm(std::shared_ptr<Graph>& graph) {
388+
std::string batch_norm = R"(
389+
graph(%a, %weight, %bias, %running_mean, %running_var, %training, %momentum, %eps, %cudnn_enabled):
390+
%r = aten::batch_norm(%a, %weight, %bias, %running_mean, %running_var, %training, %momentum, %eps, %cudnn_enabled)
391+
return (%r) )";
392+
std::string ipex_batch_norm = R"(
393+
graph(%a, %weight, %bias, %running_mean, %running_var, %training, %momentum, %eps, %cudnn_enabled):
394+
%r = ipex::batch_norm(%a, %weight, %bias, %running_mean, %running_var, %training, %momentum, %eps, %cudnn_enabled)
395+
return (%r) )";
396+
397+
IpexSubgraphRewriter rewriter_batch_norm;
398+
rewriter_batch_norm.RegisterRewritePattern(batch_norm, ipex_batch_norm);
399+
rewriter_batch_norm.runOnGraph(graph);
400+
}
401+
402+
void replaceEmbeddingBagWithQEmbeddingBag(std::shared_ptr<Graph>& graph) {
388403
std::string qembedingbag = R"(
389404
graph(%weight, %input, %offsets, %sparse, %include_last_offset, %o_scale, %o_zp, %o_dtype):
390405
%r = ipex::qembedding_bag(%weight, %input, %offsets, %sparse, %include_last_offset, %o_scale, %o_zp, %o_dtype)

torch_ipex/csrc/jit/graph_rewrite.h

+4-3
Original file line numberDiff line numberDiff line change
@@ -29,9 +29,10 @@ void FuseMHAScoreCalc(std::shared_ptr<Graph>& graph);
2929
void replaceAtenMaxPool2dWithIpexMaxPool2d(std::shared_ptr<Graph>& graph);
3030

3131
void replaceAtenSoftmaxWithIpexSoftmax(std::shared_ptr<Graph>& graph);
32-
void replaceAtenLayerNormWithIpexLayerNorm(std::shared_ptr<Graph> &graph);
33-
void replaceEmbeddingBagWithQEmbeddingBag(std::shared_ptr<Graph> &graph);
34-
void replaceInteractionWithQInteraction(std::shared_ptr<Graph> &graph);
32+
void replaceAtenBatchNormWithIpexBatchNorm(std::shared_ptr<Graph>& graph);
33+
void replaceAtenLayerNormWithIpexLayerNorm(std::shared_ptr<Graph>& graph);
34+
void replaceEmbeddingBagWithQEmbeddingBag(std::shared_ptr<Graph>& graph);
35+
void replaceInteractionWithQInteraction(std::shared_ptr<Graph>& graph);
3536

3637
void insertPrePackedConv2dOp(std::shared_ptr<Graph>& graph);
3738
void fuseConvWithEltwise(std::shared_ptr<Graph>& graph);

torch_ipex/csrc/jit/register_dnnl_jit_ops.cpp

+21
Original file line numberDiff line numberDiff line change
@@ -411,6 +411,27 @@ RegisterOperators op({
411411
},
412412
aliasAnalysisFromSchema()),
413413

414+
Operator(
415+
"ipex::batch_norm(Tensor input, Tensor? weight, Tensor? bias, Tensor? running_mean, Tensor? running_var, bool training, float momentum, float eps, bool cudnn_enabled) -> Tensor",
416+
[](const Node* node) -> Operation {
417+
return [](Stack* stack) {
418+
auto result = at::batch_norm(
419+
(std::move(peek(stack, 0, 9))).toTensor(),
420+
toOptionalTensor(std::move(peek(stack, 1, 9))),
421+
toOptionalTensor(std::move(peek(stack, 2, 9))),
422+
toOptionalTensor(std::move(peek(stack, 3, 9))),
423+
toOptionalTensor(std::move(peek(stack, 4, 9))),
424+
(std::move(peek(stack, 5, 9))).toBool(),
425+
(std::move(peek(stack, 6, 9))).toDouble(),
426+
(std::move(peek(stack, 7, 9))).toDouble(),
427+
(std::move(peek(stack, 8, 9))).toBool());
428+
drop(stack, 9);
429+
pack(stack, std::move(result));
430+
return 0;
431+
};
432+
},
433+
aliasAnalysisFromSchema()),
434+
414435
Operator(
415436
"ipex::qembedding_bag(Tensor weight, Tensor indices, Tensor offsets, "
416437
"bool sparse, bool include_last_offset, "

0 commit comments

Comments
 (0)