@@ -509,6 +509,14 @@ def __init__(self, dim=-1):
509
509
def forward (self , x ):
510
510
return self .softmax (x )
511
511
512
+ class AtenBatchNormRepalce (nn .Module ):
513
+ def __init__ (self ):
514
+ super (AtenBatchNormRepalce , self ).__init__ ()
515
+ self .bn = torch .nn .BatchNorm2d (10 )
516
+
517
+ def forward (self , x ):
518
+ return self .bn (x )
519
+
512
520
class AddLayerNorm (torch .nn .Module ):
513
521
def __init__ (self , dim = 32 ):
514
522
super (AddLayerNorm , self ).__init__ ()
@@ -925,35 +933,35 @@ def test_output_conv_bn_2d(self):
925
933
ConvBatchNorm_Fixed (2 , 3 , 32 , kernel_size = 3 , stride = 1 ),
926
934
torch .randn (32 , 3 , 64 , 64 ),
927
935
kind_in_graph = "ipex_prepack::convolution_run" ,
928
- kind_not_in_graph = "aten ::batch_norm" ,
936
+ kind_not_in_graph = "ipex ::batch_norm" ,
929
937
levels = ['O1' ])
930
938
self ._test_output_bf16 (
931
939
ConvBatchNorm_Fixed (2 , 3 , 32 , kernel_size = 3 , stride = 1 ),
932
940
torch .randn (32 , 3 , 64 , 64 ),
933
941
kind_in_graph = "ipex_prepack::convolution_run" ,
934
- kind_not_in_graph = "aten ::batch_norm" ,
942
+ kind_not_in_graph = "ipex ::batch_norm" ,
935
943
prec = 0.02 ,
936
944
levels = ['O1' ])
937
945
938
946
def test_output_bn_conv_2d (self ):
939
947
self ._test_output (
940
948
BatchNormConv_Fixed (2 , 3 , 32 , kernel_size = 3 , stride = 1 ),
941
949
torch .randn (32 , 3 , 64 , 64 ),
942
- kind_in_graph = "aten ::batch_norm" ,
950
+ kind_in_graph = "ipex ::batch_norm" ,
943
951
kind_not_in_graph = None )
944
952
945
953
def test_output_bn_conv_bn (self ):
946
954
self ._test_output (
947
955
BatchNorm_Conv_BatchNorm (2 , 3 , 32 , kernel_size = 3 , stride = 1 ),
948
956
torch .randn (32 , 3 , 64 , 64 ),
949
- kind_in_graph = "aten ::batch_norm" ,
957
+ kind_in_graph = "ipex ::batch_norm" ,
950
958
kind_not_in_graph = None )
951
959
952
960
def test_output_conv_reshape_bn_2d (self ):
953
961
self ._test_output (
954
962
ConvReshapeBatchNorm (2 , 3 , 32 , (64 , 16 , 62 , 62 ), kernel_size = 3 , stride = 1 ),
955
963
torch .randn (32 , 3 , 64 , 64 ),
956
- kind_in_graph = "aten ::batch_norm" ,
964
+ kind_in_graph = "ipex ::batch_norm" ,
957
965
kind_not_in_graph = None )
958
966
959
967
def test_output_conv_conv_concate (self ):
@@ -994,7 +1002,7 @@ def test_output_conv_bn_3d(self):
994
1002
ConvBatchNorm_Fixed (3 , 3 , 32 , kernel_size = 3 , stride = 1 ),
995
1003
torch .randn (32 , 3 , 32 , 32 , 32 ),
996
1004
kind_in_graph = "aten::conv3d" ,
997
- kind_not_in_graph = "aten ::batch_norm" )
1005
+ kind_not_in_graph = "ipex ::batch_norm" )
998
1006
999
1007
def test_output_conv_relu_2d (self ):
1000
1008
self ._test_output (
@@ -1061,25 +1069,25 @@ def test_output_cascaded_conv_bn_sum_relu_2d(self):
1061
1069
CascadedConvBnSumRelu (2 , 3 , 64 , 32 , kernel_size = 3 , stride = 1 ),
1062
1070
torch .rand (32 , 3 , 64 , 64 ),
1063
1071
kind_in_graph = "ipex_prepack::convolution_add_relu_run" ,
1064
- kind_not_in_graph = "aten ::batch_norm" )
1072
+ kind_not_in_graph = "ipex ::batch_norm" )
1065
1073
self ._test_output_bf16 (
1066
1074
CascadedConvBnSumRelu (2 , 3 , 64 , 32 , kernel_size = 3 , stride = 1 ),
1067
1075
torch .rand (32 , 3 , 64 , 64 ),
1068
1076
kind_in_graph = "ipex_prepack::convolution_add_relu_run" ,
1069
- kind_not_in_graph = "aten ::batch_norm" ,
1077
+ kind_not_in_graph = "ipex ::batch_norm" ,
1070
1078
prec = 0.02 )
1071
1079
1072
1080
def test_output_cascaded_conv_bn_sum_relu_3d (self ):
1073
1081
self ._test_output (
1074
1082
CascadedConvBnSumRelu (3 , 3 , 64 , 32 , kernel_size = 3 , stride = 1 ),
1075
1083
torch .rand (32 , 3 , 32 , 32 , 32 ),
1076
1084
kind_in_graph = "ipex::conv3d_sum_relu" ,
1077
- kind_not_in_graph = "aten ::batch_norm" )
1085
+ kind_not_in_graph = "ipex ::batch_norm" )
1078
1086
self ._test_output_bf16 (
1079
1087
CascadedConvBnSumRelu (3 , 3 , 64 , 32 , kernel_size = 3 , stride = 1 ),
1080
1088
torch .rand (32 , 3 , 32 , 32 , 32 ),
1081
1089
kind_in_graph = "ipex::conv3d_sum_relu" ,
1082
- kind_not_in_graph = "aten ::batch_norm" ,
1090
+ kind_not_in_graph = "ipex ::batch_norm" ,
1083
1091
prec = 0.02 )
1084
1092
1085
1093
def test_output_conv_transpose2d (self ):
@@ -1346,6 +1354,17 @@ def test_ipex_softmax(self):
1346
1354
kind_in_graph = "ipex::softmax" ,
1347
1355
prec = 5e-3 )
1348
1356
1357
+ def test_ipex_batch_norm (self ):
1358
+ self ._test_output (
1359
+ AtenBatchNormRepalce (),
1360
+ torch .rand (10 , 10 , 4 , 4 ),
1361
+ kind_in_graph = "ipex::batch_norm" )
1362
+ self ._test_output_bf16 (
1363
+ AtenBatchNormRepalce (),
1364
+ torch .rand (10 , 10 , 4 , 4 , dtype = torch .bfloat16 ),
1365
+ kind_in_graph = "ipex::batch_norm" ,
1366
+ prec = 5e-3 )
1367
+
1349
1368
def test_restore_inplace (self ):
1350
1369
class M (nn .Module ):
1351
1370
def __init__ (self , eltwise_fn , params_dict = {}):
0 commit comments