@@ -961,46 +961,6 @@ def test_add_layernorm(self):
         node = "ipex::add_layernorm"
         self.assertTrue(any(n.kind() == node for n in trace_graph.nodes()))

-    def _test_concat_bn_relu(self, a1, a2, a3, enable_3d=True, use_channels_last=True):
-        if enable_3d:
-            if use_channels_last:
-                model = ConcatBnRelu3d().eval().to(memory_format=torch.channels_last_3d)
-                model = ipex.optimize(model, dtype=torch.float32, level='O0')
-                with torch.no_grad():
-                    jit_model = torch.jit.trace(model, (a1, a2, a3)).eval()
-                    jit_model = torch.jit.freeze(jit_model)
-                    jit_res = jit_model(a1, a2, a3)
-                    ori_res = model(a1, a2, a3)
-                self.assertEqual(jit_res, ori_res)
-            else:
-                model = ConcatBnRelu3d().eval()
-                model = ipex.optimize(model, dtype=torch.float32, level='O0')
-                with torch.no_grad():
-                    jit_model = torch.jit.trace(model, (a1, a2, a3)).eval()
-                    jit_model = torch.jit.freeze(jit_model)
-                    jit_res = jit_model(a1, a2, a3)
-                    ori_res = model(a1, a2, a3)
-                self.assertEqual(jit_res, ori_res)
-        else:
-            if use_channels_last:
-                model = ConcatBnRelu2d().eval().to(memory_format=torch.channels_last)
-                model = ipex.optimize(model, dtype=torch.float32, level='O0')
-                with torch.no_grad():
-                    jit_model = torch.jit.trace(model, (a1, a2, a3)).eval()
-                    jit_model = torch.jit.freeze(jit_model)
-                    jit_res = jit_model(a1, a2, a3)
-                    ori_res = model(a1, a2, a3)
-                self.assertEqual(jit_res, ori_res)
-            else:
-                model = ConcatBnRelu2d().eval()
-                model = ipex.optimize(model, dtype=torch.float32, level='O0')
-                with torch.no_grad():
-                    jit_model = torch.jit.trace(model, (a1, a2, a3)).eval()
-                    jit_model = torch.jit.freeze(jit_model)
-                    jit_res = jit_model(a1, a2, a3)
-                    ori_res = model(a1, a2, a3)
-                self.assertEqual(jit_res, ori_res)
-
     def test_concat_bn_relu(self):
         a1 = torch.randn(1, 32, 13, 24, dtype=torch.bfloat16).contiguous(memory_format=torch.channels_last)
         a2 = torch.randn(1, 32, 13, 24, dtype=torch.bfloat16).contiguous(memory_format=torch.channels_last)
@@ -1010,8 +970,10 @@ def test_concat_bn_relu(self):
         with torch.no_grad():
             jit_model = torch.jit.trace(model, (a1, a2, a3)).eval()
             jit_model = torch.jit.freeze(jit_model)
-            jit_res = jit_model(a1, a2, a3)
-            ori_res = model(a1, a2, a3)
+            #warmup run
+            for _ in range(2):
+                jit_res = jit_model(a1, a2, a3)
+            ori_res = model(a1, a2, a3)
         self.assertEqual(jit_res, ori_res)

         a1 = torch.randn(1, 32, 13, 24, dtype=torch.float).contiguous(memory_format=torch.channels_last)
@@ -1022,46 +984,92 @@ def test_concat_bn_relu(self):
         with torch.no_grad():
             jit_model = torch.jit.trace(model, (a1, a2, a3)).eval()
             jit_model = torch.jit.freeze(jit_model)
-            jit_res = jit_model(a1, a2, a3)
-            ori_res = model(a1, a2, a3)
+            #warmup run
+            for _ in range(2):
+                jit_res = jit_model(a1, a2, a3)
+            ori_res = model(a1, a2, a3)
         self.assertEqual(jit_res, ori_res)

-        self._test_concat_bn_relu(a1, a2, a3, enable_3d=False, use_channels_last=True)
+        model = ConcatBnRelu2d().eval().to(memory_format=torch.channels_last)
+        model = ipex.optimize(model, dtype=torch.float32, level='O0')
+        with torch.no_grad():
+            jit_model = torch.jit.trace(model, (a1, a2, a3)).eval()
+            jit_model = torch.jit.freeze(jit_model)
+            #warmup run
+            for _ in range(2):
+                jit_res = jit_model(a1, a2, a3)
+            ori_res = model(a1, a2, a3)
+        self.assertEqual(jit_res, ori_res)

-        a1 = torch.randn(1, 16, 13, 24, dtype=torch.float).contiguous(memory_format=torch.channels_last)
-        a2 = torch.randn(1, 48, 13, 24, dtype=torch.float).contiguous(memory_format=torch.channels_last)
-        a3 = torch.randn(1, 32, 13, 24, dtype=torch.float).contiguous(memory_format=torch.channels_last)
-        self._test_concat_bn_relu(a1, a2, a3, enable_3d=False, use_channels_last=True)
+        a1 = torch.randn(1, 32, 18, 53, dtype=torch.float).contiguous(memory_format=torch.channels_last)
+        a2 = torch.randn(1, 32, 18, 53, dtype=torch.float).contiguous(memory_format=torch.channels_last)
+        a3 = torch.randn(1, 32, 18, 53, dtype=torch.float).contiguous(memory_format=torch.channels_last)
+        with torch.no_grad():
+            jit_res = jit_model(a1, a2, a3)
+            ori_res = model(a1, a2, a3)
+        self.assertEqual(jit_res, ori_res)

-        a1 = torch.randn(1, 17, 13, 24, dtype=torch.float).contiguous(memory_format=torch.channels_last)
-        a2 = torch.randn(1, 47, 13, 24, dtype=torch.float).contiguous(memory_format=torch.channels_last)
-        a3 = torch.randn(1, 32, 13, 24, dtype=torch.float).contiguous(memory_format=torch.channels_last)
-        self._test_concat_bn_relu(a1, a2, a3, enable_3d=False, use_channels_last=True)
+        a1 = torch.randn(1, 16, 24, 116, dtype=torch.float).contiguous(memory_format=torch.channels_last)
+        a2 = torch.randn(1, 48, 24, 116, dtype=torch.float).contiguous(memory_format=torch.channels_last)
+        a3 = torch.randn(1, 32, 24, 116, dtype=torch.float).contiguous(memory_format=torch.channels_last)
+        with torch.no_grad():
+            jit_res = jit_model(a1, a2, a3)
+            ori_res = model(a1, a2, a3)
+        self.assertEqual(jit_res, ori_res)
+
+        a1 = torch.randn(1, 17, 15, 24, dtype=torch.float).contiguous(memory_format=torch.channels_last)
+        a2 = torch.randn(1, 47, 15, 24, dtype=torch.float).contiguous(memory_format=torch.channels_last)
+        a3 = torch.randn(1, 32, 15, 24, dtype=torch.float).contiguous(memory_format=torch.channels_last)
+        with torch.no_grad():
+            jit_res = jit_model(a1, a2, a3)
+            ori_res = model(a1, a2, a3)
+        self.assertEqual(jit_res, ori_res)

         a1 = torch.randn(1, 32, 13, 24, dtype=torch.float)
         a2 = torch.randn(1, 32, 13, 24, dtype=torch.float)
         a3 = torch.randn(1, 32, 13, 24, dtype=torch.float)
-        self._test_concat_bn_relu(a1, a2, a3, enable_3d=False, use_channels_last=False)
+        with torch.no_grad():
+            jit_res = jit_model(a1, a2, a3)
+            ori_res = model(a1, a2, a3)
+        self.assertEqual(jit_res, ori_res)

         a1 = torch.randn(1, 32, 13, 24, 33, dtype=torch.float).contiguous(memory_format=torch.channels_last_3d)
         a2 = torch.randn(1, 32, 13, 24, 33, dtype=torch.float).contiguous(memory_format=torch.channels_last_3d)
         a3 = torch.randn(1, 32, 13, 24, 33, dtype=torch.float).contiguous(memory_format=torch.channels_last_3d)
-        self._test_concat_bn_relu(a1, a2, a3, enable_3d=True, use_channels_last=True)
+        model = ConcatBnRelu3d().eval().to(memory_format=torch.channels_last_3d)
+        model = ipex.optimize(model, dtype=torch.float32, level='O0')
+        with torch.no_grad():
+            jit_model = torch.jit.trace(model, (a1, a2, a3)).eval()
+            jit_model = torch.jit.freeze(jit_model)
+            #warmup run
+            for _ in range(2):
+                jit_res = jit_model(a1, a2, a3)
+            ori_res = model(a1, a2, a3)
+        self.assertEqual(jit_res, ori_res)

-        a1 = torch.randn(1, 16, 13, 24, 33, dtype=torch.float).contiguous(memory_format=torch.channels_last_3d)
-        a2 = torch.randn(1, 48, 13, 24, 33, dtype=torch.float).contiguous(memory_format=torch.channels_last_3d)
-        a3 = torch.randn(1, 32, 13, 24, 33, dtype=torch.float).contiguous(memory_format=torch.channels_last_3d)
-        self._test_concat_bn_relu(a1, a2, a3, enable_3d=True, use_channels_last=True)
+        a1 = torch.randn(1, 16, 17, 14, 31, dtype=torch.float).contiguous(memory_format=torch.channels_last_3d)
+        a2 = torch.randn(1, 48, 17, 14, 31, dtype=torch.float).contiguous(memory_format=torch.channels_last_3d)
+        a3 = torch.randn(1, 32, 17, 14, 31, dtype=torch.float).contiguous(memory_format=torch.channels_last_3d)
+        with torch.no_grad():
+            jit_res = jit_model(a1, a2, a3)
+            ori_res = model(a1, a2, a3)
+        self.assertEqual(jit_res, ori_res)

         a1 = torch.randn(1, 17, 13, 24, 33, dtype=torch.float).contiguous(memory_format=torch.channels_last_3d)
         a2 = torch.randn(1, 47, 13, 24, 33, dtype=torch.float).contiguous(memory_format=torch.channels_last_3d)
         a3 = torch.randn(1, 32, 13, 24, 33, dtype=torch.float).contiguous(memory_format=torch.channels_last_3d)
-        self._test_concat_bn_relu(a1, a2, a3, enable_3d=True, use_channels_last=True)
+        with torch.no_grad():
+            jit_res = jit_model(a1, a2, a3)
+            ori_res = model(a1, a2, a3)
+        self.assertEqual(jit_res, ori_res)

         a1 = torch.randn(1, 32, 13, 24, 33, dtype=torch.float)
         a2 = torch.randn(1, 32, 13, 24, 33, dtype=torch.float)
         a3 = torch.randn(1, 32, 13, 24, 33, dtype=torch.float)
-        self._test_concat_bn_relu(a1, a2, a3, enable_3d=True, use_channels_last=False)
+        with torch.no_grad():
+            jit_res = jit_model(a1, a2, a3)
+            ori_res = model(a1, a2, a3)
+        self.assertEqual(jit_res, ori_res)

     def test_mha_scores_calculation(self):
         def _check_match_mha(trace_model, mat1, mat2, bias, node="ipex::mha_scores_calc"):
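For context when reading this diff: the ConcatBnRelu2d and ConcatBnRelu3d modules exercised above are defined elsewhere in the test file. A minimal sketch of what they plausibly look like is given below, assuming the three inputs are concatenated along the channel dimension (the channel counts in every case above sum to 96) and then passed through BatchNorm and ReLU, the pattern that IPEX fuses into its concat_bn_relu kernel; the layer parameters here are assumptions, not the repository's actual definitions.

import torch
import torch.nn as nn

class ConcatBnRelu2d(nn.Module):
    # Hypothetical reconstruction: concat along channels -> BatchNorm2d -> ReLU.
    def __init__(self):
        super(ConcatBnRelu2d, self).__init__()
        self.bn = nn.BatchNorm2d(96)

    def forward(self, x1, x2, x3):
        x = torch.cat((x1, x2, x3), dim=1)
        return torch.relu(self.bn(x))

class ConcatBnRelu3d(nn.Module):
    # Hypothetical reconstruction: concat along channels -> BatchNorm3d -> ReLU.
    def __init__(self):
        super(ConcatBnRelu3d, self).__init__()
        self.bn = nn.BatchNorm3d(96)

    def forward(self, x1, x2, x3):
        x = torch.cat((x1, x2, x3), dim=1)
        return torch.relu(self.bn(x))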