
Commit a044479

Add necessary Concat-bn-relu fusion checks to runtime kernel (#546)
* Add fusion checks back to runtime kernel
* Restore xu's dyn patch for concat-bn-relu fusion
* Fix typo
* Add warmup runs before UT checks in test_jit.py
Parent: d9ef0bc

File tree: 7 files changed, +136 -108 lines changed

intel_extension_for_pytorch/csrc/aten/cpu/kernels/jit_kernels/ConcatBnReluKrnl.cpp

+36 -23

@@ -16,8 +16,9 @@ namespace {
 
 at::Tensor concat_bn_relu_kernel_impl(
     const c10::List<at::Tensor>& a,
+    const at::Tensor& bn_scale,
     const at::Tensor& bn_beta,
-    const c10::optional<at::Tensor>& bn_scale,
+    const c10::optional<at::Tensor>& bn_weight,
     const c10::optional<at::Tensor>& bn_bias,
     const c10::optional<at::Tensor>& bn_mean,
     const c10::optional<at::Tensor>& bn_var,
@@ -27,36 +28,49 @@ at::Tensor concat_bn_relu_kernel_impl(
     bool bn_cudnn_enabled,
     int dim) {
   int64_t list_length = a.size();
+  std::vector<int64_t> output_dim = a[0].sizes().vec();
+  int64_t tensor_length = a[0].ndimension();
 
-  c10::MaybeOwned<at::Tensor> weight_maybe_owned =
-      at::borrow_from_optional_tensor(bn_scale);
-  const at::Tensor& bn_weight = *weight_maybe_owned;
-  std::vector<long int> output_dim(a[0].ndimension());
-  for (int64_t i = 0; i < list_length; ++i) {
-    output_dim[1] += a[i].size(1);
-  }
-  for (int64_t i = 0; i < a[0].ndimension(); ++i) {
-    if (i != 1) {
-      output_dim[i] = a[0].size(i);
+  // Check if the memory format is channelslast(3d) and if the channel size can
+  // be divided by 16
+  auto check_format_channelsize = [](at::Tensor tensor) {
+    return (
+        (tensor.suggest_memory_format() == at::MemoryFormat::ChannelsLast ||
+         tensor.suggest_memory_format() == at::MemoryFormat::ChannelsLast3d) &&
+        tensor.size(1) % 16 == 0);
+  };
+
+  // Check the first tensor
+  bool tensor_check = check_format_channelsize(a[0]);
+  // Check the rest input tensors
+  for (int64_t i = 1; i < list_length; ++i) {
+    tensor_check = (tensor_check && check_format_channelsize(a[i]));
+    for (int64_t j = 0; j < tensor_length; ++j) {
+      if (j == 1) {
+        output_dim[1] += a[i].size(j);
+      } else {
+        tensor_check = (tensor_check && a[i].size(j) == a[0].size(j));
+      }
     }
   }
-  at::Tensor output = at::empty(
-      output_dim,
-      a[0].options()
-          .dtype(at::kFloat)
-          .memory_format(a[0].suggest_memory_format()));
-
 #if defined(CPU_CAPABILITY_AVX512)
-  torch_ipex::cpu::kernel::vec::vec512::ConcatBnReluKernelImpl_ChannelsLast<
-      float>(a, bn_weight, bn_beta, output);
-  return output;
-#else
+  if (tensor_check) {
+    at::Tensor output = at::empty(
+        output_dim,
+        a[0].options()
+            .dtype(at::kFloat)
+            .memory_format(a[0].suggest_memory_format()));
+    torch_ipex::cpu::kernel::vec::vec512::ConcatBnReluKernelImpl_ChannelsLast<
+        float>(a, bn_scale, bn_beta, output);
+    return output;
+  }
+#endif
   std::vector<at::Tensor> concat_input(list_length);
   for (int64_t i = 0; i < list_length; ++i)
     concat_input[i] = a[i];
   auto bn_res = at::batch_norm(
       at::cat(concat_input, (int64_t)dim),
-      bn_scale,
+      bn_weight,
       bn_bias,
       bn_mean,
       bn_var,
@@ -65,7 +79,6 @@ at::Tensor concat_bn_relu_kernel_impl(
       bn_eps,
       bn_cudnn_enabled);
   return at::relu(bn_res);
-#endif
 }
 
 #if defined(DYN_DISP_BUILD)
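Note: the eligibility check added above can be summarized as follows. This is a minimal Python sketch, not code from the commit, and it uses is_contiguous() as a stand-in for the C++ suggest_memory_format() test: every input must be channels-last (2d or 3d), have a channel count divisible by 16, and match the first input on every non-channel dimension; otherwise concat_bn_relu_kernel_impl falls back to the plain at::cat + at::batch_norm + at::relu path.

import torch

def concat_bn_relu_fusable(tensors):
    # Hypothetical restatement of the runtime check added in this commit.
    ref = tensors[0]
    for t in tensors:
        if t.dim() == 4:
            fmt = torch.channels_last
        elif t.dim() == 5:
            fmt = torch.channels_last_3d
        else:
            return False
        # channels-last layout and channel count divisible by 16
        if not t.is_contiguous(memory_format=fmt) or t.size(1) % 16 != 0:
            return False
        # all non-channel dimensions must match the first input
        if any(t.size(d) != ref.size(d) for d in range(ref.dim()) if d != 1):
            return False
    return True

a = [torch.randn(1, 32, 13, 24).to(memory_format=torch.channels_last) for _ in range(3)]
b = torch.randn(1, 17, 13, 24).to(memory_format=torch.channels_last)
print(concat_bn_relu_fusable(a))        # True
print(concat_bn_relu_fusable(a + [b]))  # False: 17 channels is not a multiple of 16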

intel_extension_for_pytorch/csrc/cpu/vec512/concat_bn_relu.h

+1 -1

@@ -64,7 +64,7 @@ void ConcatBnReluKernelImpl_ChannelsLast(
 
   for (int64_t i = 0; i < list_length; ++i) {
     input_channels[i + 1] = input_channels[i] + a[i].size(1);
-    input_ptr[i] = a[i].data_ptr<T>();
+    input_ptr[i] = a[i].contiguous(a[i].suggest_memory_format()).data_ptr<T>();
   }
   // Return the product of all the input dimensions except for the channel
   // and check if the dimension and sizes of the tensors meet the fusion
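For context on the one-line change above: the vec512 kernel walks raw data_ptr()s assuming a dense channels-last layout, but a tensor whose suggested memory format is channels-last is not necessarily densely packed (e.g. a sliced view). Calling contiguous() in the suggested format first guarantees the layout the vectorized loads expect. A small illustrative PyTorch snippet (not from the commit):

import torch

x = torch.randn(1, 32, 13, 24)                # dense NCHW tensor
y = x.to(memory_format=torch.channels_last)   # dense channels-last strides
z = y[:, :, :, :12]                           # still "looks" channels-last, but not dense
print(y.is_contiguous(memory_format=torch.channels_last))  # True
print(z.is_contiguous(memory_format=torch.channels_last))  # False
# contiguous() in the suggested format restores the dense layout before
# the kernel takes the raw pointer:
w = z.contiguous(memory_format=torch.channels_last)
print(w.is_contiguous(memory_format=torch.channels_last))  # True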

intel_extension_for_pytorch/csrc/jit/cpu/kernels/ConcatBnRelu.cpp

+7 -4

@@ -23,8 +23,9 @@ DEFINE_DISPATCH(concat_bn_relu_kernel_stub);
 **/
 at::Tensor ConcatBnRelu(
     const c10::List<at::Tensor>& a,
+    const at::Tensor& bn_scale,
     const at::Tensor& bn_beta,
-    const c10::optional<at::Tensor>& bn_scale,
+    const c10::optional<at::Tensor>& bn_weight,
     const c10::optional<at::Tensor>& bn_bias,
     const c10::optional<at::Tensor>& bn_mean,
     const c10::optional<at::Tensor>& bn_var,
@@ -33,14 +34,15 @@ at::Tensor ConcatBnRelu(
     double bn_eps,
     bool bn_cudnn_enabled,
     int dim) {
-  IPEX_RECORD_FUNCTION("ConcatBnRelu", std::vector<c10::IValue>({}));
+  IPEX_RECORD_FUNCTION("ipex::concat_bn_relu", std::vector<c10::IValue>({}));
 
 #if defined(DYN_DISP_BUILD)
   return concat_bn_relu_kernel_stub(
       kCPU,
       a,
-      bn_beta,
       bn_scale,
+      bn_beta,
+      bn_weight,
       bn_bias,
       bn_mean,
       bn_var,
@@ -52,8 +54,9 @@ at::Tensor ConcatBnRelu(
 #else
   return concat_bn_relu_kernel_impl(
       a,
-      bn_beta,
       bn_scale,
+      bn_beta,
+      bn_weight,
      bn_bias,
      bn_mean,
      bn_var,

intel_extension_for_pytorch/csrc/jit/cpu/kernels/ConcatBnRelu.h

+5 -2

@@ -15,8 +15,9 @@ namespace cpu {
 * */
 at::Tensor ConcatBnRelu(
     const c10::List<at::Tensor>& a,
+    const at::Tensor& bn_scale,
     const at::Tensor& bn_beta,
-    const c10::optional<at::Tensor>& bn_scale,
+    const c10::optional<at::Tensor>& bn_weight,
     const c10::optional<at::Tensor>& bn_bias,
     const c10::optional<at::Tensor>& bn_mean,
     const c10::optional<at::Tensor>& bn_var,
@@ -32,8 +33,9 @@ namespace {
 
 at::Tensor concat_bn_relu_kernel_impl(
     const c10::List<at::Tensor>& a,
+    const at::Tensor& bn_scale,
     const at::Tensor& bn_beta,
-    const c10::optional<at::Tensor>& bn_scale,
+    const c10::optional<at::Tensor>& bn_weight,
     const c10::optional<at::Tensor>& bn_bias,
     const c10::optional<at::Tensor>& bn_mean,
     const c10::optional<at::Tensor>& bn_var,
@@ -50,6 +52,7 @@ at::Tensor concat_bn_relu_kernel_impl(
 using concat_bn_relu_kernel_fn = at::Tensor (*)(
     const c10::List<at::Tensor>&,
     const at::Tensor&,
+    const at::Tensor&,
     const c10::optional<at::Tensor>&,
     const c10::optional<at::Tensor>&,
     const c10::optional<at::Tensor>&,

intel_extension_for_pytorch/csrc/jit/cpu/passes/graph_rewrite.cpp

+4 -4

@@ -444,10 +444,10 @@ void FuseConcatBnRelu(std::shared_ptr<Graph>& graph) {
         %alpha: int = prim::Constant[value=1]()
         %u1 = aten::add(%running_var, %eps, %alpha)
         %u2 = aten::sqrt(%u1)
-        %u3 = aten::div(%running_mean, %u2)
-        %u4 = aten::mul(%weight, %u3)
-        %beta = aten::sub(%bias, %u4, %alpha)
-        %b = ipex::concat_bn_relu(%input, %beta, %weight, %bias, %running_mean, %running_var, %training, %momentum, %eps, %cudnn_enabled, %dim)
+        %scale = aten::div(%weight, %u2)
+        %u3 = aten::mul(%running_mean, %scale)
+        %beta = aten::sub(%bias, %u3, %alpha)
+        %b = ipex::concat_bn_relu(%input, %scale, %beta, %weight, %bias, %running_mean, %running_var, %training, %momentum, %eps, %cudnn_enabled, %dim)
         return (%b) )";
 
  auto fusion_filter = [](const Match& match,
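The rewritten pattern folds the inference-mode BatchNorm parameters into a per-channel scale and shift before calling the fused op: scale = weight / sqrt(running_var + eps) and beta = bias - running_mean * scale, so the fused kernel only has to compute relu(cat(inputs) * scale + beta). A plain-PyTorch sketch of that algebra (names are illustrative, not the IPEX implementation):

import torch

def fold_bn(weight, bias, running_mean, running_var, eps):
    # Per-channel constants corresponding to %scale and %beta in the pattern above.
    scale = weight / torch.sqrt(running_var + eps)
    beta = bias - running_mean * scale
    return scale, beta

def concat_bn_relu_reference(inputs, scale, beta, dim=1):
    # relu(cat(inputs) * scale + beta), broadcast over the channel dimension.
    x = torch.cat(list(inputs), dim=dim)
    shape = [1] * x.dim()
    shape[dim] = -1
    return torch.relu(x * scale.reshape(shape) + beta.reshape(shape))

# Quick numerical check against eager BatchNorm + ReLU in eval mode:
bn = torch.nn.BatchNorm2d(96).eval()
xs = [torch.randn(1, 32, 13, 24) for _ in range(3)]
with torch.no_grad():
    scale, beta = fold_bn(bn.weight, bn.bias, bn.running_mean, bn.running_var, bn.eps)
    expected = torch.relu(bn(torch.cat(xs, dim=1)))
    print(torch.allclose(concat_bn_relu_reference(xs, scale, beta), expected, atol=1e-6))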

intel_extension_for_pytorch/csrc/jit/cpu/passes/register_dnnl_jit_ops.cpp

+14 -13

@@ -675,24 +675,25 @@ RegisterOperators op({
         },
         aliasAnalysisFromSchema()),
     Operator(
-        "ipex::concat_bn_relu(Tensor[] a, Tensor bn_beta, "
+        "ipex::concat_bn_relu(Tensor[] a, Tensor bn_scale, Tensor bn_beta, "
         "Tensor? weight, Tensor? bias, Tensor? running_mean, Tensor? running_var, bool training, float momentum, float eps, bool cudnn_enabled, int dim) -> "
         "Tensor",
         [](const Node* node) -> Operation {
          return [](Stack* stack) {
            auto result = ConcatBnRelu(
-                (std::move(peek(stack, 0, 11))).toTensorList(),
-                (std::move(peek(stack, 1, 11))).toTensor(),
-                toOptionalTensor(std::move(peek(stack, 2, 11))),
-                toOptionalTensor(std::move(peek(stack, 3, 11))),
-                toOptionalTensor(std::move(peek(stack, 4, 11))),
-                toOptionalTensor(std::move(peek(stack, 5, 11))),
-                (std::move(peek(stack, 6, 11))).toBool(),
-                (std::move(peek(stack, 7, 11))).toDouble(),
-                (std::move(peek(stack, 8, 11))).toDouble(),
-                (std::move(peek(stack, 9, 11))).toBool(),
-                (std::move(peek(stack, 10, 11))).toInt());
-            drop(stack, 11);
+                (std::move(peek(stack, 0, 12))).toTensorList(),
+                (std::move(peek(stack, 1, 12))).toTensor(),
+                (std::move(peek(stack, 2, 12))).toTensor(),
+                toOptionalTensor(std::move(peek(stack, 3, 12))),
+                toOptionalTensor(std::move(peek(stack, 4, 12))),
+                toOptionalTensor(std::move(peek(stack, 5, 12))),
+                toOptionalTensor(std::move(peek(stack, 6, 12))),
+                (std::move(peek(stack, 7, 12))).toBool(),
+                (std::move(peek(stack, 8, 12))).toDouble(),
+                (std::move(peek(stack, 9, 12))).toDouble(),
+                (std::move(peek(stack, 10, 12))).toBool(),
+                (std::move(peek(stack, 11, 12))).toInt());
+            drop(stack, 12);
            pack(stack, std::move(result));
            return 0;
          };

tests/cpu/test_jit.py

+69 -61

@@ -961,46 +961,6 @@ def test_add_layernorm(self):
         node = "ipex::add_layernorm"
         self.assertTrue(any(n.kind() == node for n in trace_graph.nodes()))
 
-    def _test_concat_bn_relu(self, a1, a2, a3, enable_3d=True, use_channels_last=True):
-        if enable_3d:
-            if use_channels_last:
-                model = ConcatBnRelu3d().eval().to(memory_format=torch.channels_last_3d)
-                model = ipex.optimize(model, dtype=torch.float32, level='O0')
-                with torch.no_grad():
-                    jit_model = torch.jit.trace(model, (a1, a2, a3)).eval()
-                    jit_model = torch.jit.freeze(jit_model)
-                    jit_res = jit_model(a1, a2, a3)
-                    ori_res = model(a1, a2, a3)
-                    self.assertEqual(jit_res, ori_res)
-            else:
-                model = ConcatBnRelu3d().eval()
-                model = ipex.optimize(model, dtype=torch.float32, level='O0')
-                with torch.no_grad():
-                    jit_model = torch.jit.trace(model, (a1, a2, a3)).eval()
-                    jit_model = torch.jit.freeze(jit_model)
-                    jit_res = jit_model(a1, a2, a3)
-                    ori_res = model(a1, a2, a3)
-                    self.assertEqual(jit_res, ori_res)
-        else:
-            if use_channels_last:
-                model = ConcatBnRelu2d().eval().to(memory_format=torch.channels_last)
-                model = ipex.optimize(model, dtype=torch.float32, level='O0')
-                with torch.no_grad():
-                    jit_model = torch.jit.trace(model, (a1, a2, a3)).eval()
-                    jit_model = torch.jit.freeze(jit_model)
-                    jit_res = jit_model(a1, a2, a3)
-                    ori_res = model(a1, a2, a3)
-                    self.assertEqual(jit_res, ori_res)
-            else:
-                model = ConcatBnRelu2d().eval()
-                model = ipex.optimize(model, dtype=torch.float32, level='O0')
-                with torch.no_grad():
-                    jit_model = torch.jit.trace(model, (a1, a2, a3)).eval()
-                    jit_model = torch.jit.freeze(jit_model)
-                    jit_res = jit_model(a1, a2, a3)
-                    ori_res = model(a1, a2, a3)
-                    self.assertEqual(jit_res, ori_res)
-
     def test_concat_bn_relu(self):
         a1 = torch.randn(1, 32, 13, 24, dtype=torch.bfloat16).contiguous(memory_format=torch.channels_last)
         a2 = torch.randn(1, 32, 13, 24, dtype=torch.bfloat16).contiguous(memory_format=torch.channels_last)
@@ -1010,8 +970,10 @@ def test_concat_bn_relu(self):
         with torch.no_grad():
             jit_model = torch.jit.trace(model, (a1, a2, a3)).eval()
             jit_model = torch.jit.freeze(jit_model)
-            jit_res = jit_model(a1, a2, a3)
-            ori_res = model(a1, a2, a3)
+            #warmup run
+            for _ in range(2):
+                jit_res = jit_model(a1, a2, a3)
+            ori_res = model(a1, a2, a3)
             self.assertEqual(jit_res, ori_res)
 
         a1 = torch.randn(1, 32, 13, 24, dtype=torch.float).contiguous(memory_format=torch.channels_last)
@@ -1022,46 +984,92 @@ def test_concat_bn_relu(self):
         with torch.no_grad():
             jit_model = torch.jit.trace(model, (a1, a2, a3)).eval()
             jit_model = torch.jit.freeze(jit_model)
-            jit_res = jit_model(a1, a2, a3)
-            ori_res = model(a1, a2, a3)
+            #warmup run
+            for _ in range(2):
+                jit_res = jit_model(a1, a2, a3)
+            ori_res = model(a1, a2, a3)
             self.assertEqual(jit_res, ori_res)
 
-        self._test_concat_bn_relu(a1, a2, a3, enable_3d=False, use_channels_last=True)
+        model = ConcatBnRelu2d().eval().to(memory_format=torch.channels_last)
+        model = ipex.optimize(model, dtype=torch.float32, level='O0')
+        with torch.no_grad():
+            jit_model = torch.jit.trace(model, (a1, a2, a3)).eval()
+            jit_model = torch.jit.freeze(jit_model)
+            #warmup run
+            for _ in range(2):
+                jit_res = jit_model(a1, a2, a3)
+            ori_res = model(a1, a2, a3)
+            self.assertEqual(jit_res, ori_res)
 
-        a1 = torch.randn(1, 16, 13, 24, dtype=torch.float).contiguous(memory_format=torch.channels_last)
-        a2 = torch.randn(1, 48, 13, 24, dtype=torch.float).contiguous(memory_format=torch.channels_last)
-        a3 = torch.randn(1, 32, 13, 24, dtype=torch.float).contiguous(memory_format=torch.channels_last)
-        self._test_concat_bn_relu(a1, a2, a3, enable_3d=False, use_channels_last=True)
+        a1 = torch.randn(1, 32, 18, 53, dtype=torch.float).contiguous(memory_format=torch.channels_last)
+        a2 = torch.randn(1, 32, 18, 53, dtype=torch.float).contiguous(memory_format=torch.channels_last)
+        a3 = torch.randn(1, 32, 18, 53, dtype=torch.float).contiguous(memory_format=torch.channels_last)
+        with torch.no_grad():
+            jit_res = jit_model(a1, a2, a3)
+            ori_res = model(a1, a2, a3)
+            self.assertEqual(jit_res, ori_res)
 
-        a1 = torch.randn(1, 17, 13, 24, dtype=torch.float).contiguous(memory_format=torch.channels_last)
-        a2 = torch.randn(1, 47, 13, 24, dtype=torch.float).contiguous(memory_format=torch.channels_last)
-        a3 = torch.randn(1, 32, 13, 24, dtype=torch.float).contiguous(memory_format=torch.channels_last)
-        self._test_concat_bn_relu(a1, a2, a3, enable_3d=False, use_channels_last=True)
+        a1 = torch.randn(1, 16, 24, 116, dtype=torch.float).contiguous(memory_format=torch.channels_last)
+        a2 = torch.randn(1, 48, 24, 116, dtype=torch.float).contiguous(memory_format=torch.channels_last)
+        a3 = torch.randn(1, 32, 24, 116, dtype=torch.float).contiguous(memory_format=torch.channels_last)
+        with torch.no_grad():
+            jit_res = jit_model(a1, a2, a3)
+            ori_res = model(a1, a2, a3)
+            self.assertEqual(jit_res, ori_res)
+
+        a1 = torch.randn(1, 17, 15, 24, dtype=torch.float).contiguous(memory_format=torch.channels_last)
+        a2 = torch.randn(1, 47, 15, 24, dtype=torch.float).contiguous(memory_format=torch.channels_last)
+        a3 = torch.randn(1, 32, 15, 24, dtype=torch.float).contiguous(memory_format=torch.channels_last)
+        with torch.no_grad():
+            jit_res = jit_model(a1, a2, a3)
+            ori_res = model(a1, a2, a3)
+            self.assertEqual(jit_res, ori_res)
 
         a1 = torch.randn(1, 32, 13, 24, dtype=torch.float)
         a2 = torch.randn(1, 32, 13, 24, dtype=torch.float)
         a3 = torch.randn(1, 32, 13, 24, dtype=torch.float)
-        self._test_concat_bn_relu(a1, a2, a3, enable_3d=False, use_channels_last=False)
+        with torch.no_grad():
+            jit_res = jit_model(a1, a2, a3)
+            ori_res = model(a1, a2, a3)
+            self.assertEqual(jit_res, ori_res)
 
         a1 = torch.randn(1, 32, 13, 24, 33, dtype=torch.float).contiguous(memory_format=torch.channels_last_3d)
         a2 = torch.randn(1, 32, 13, 24, 33, dtype=torch.float).contiguous(memory_format=torch.channels_last_3d)
         a3 = torch.randn(1, 32, 13, 24, 33, dtype=torch.float).contiguous(memory_format=torch.channels_last_3d)
-        self._test_concat_bn_relu(a1, a2, a3, enable_3d=True, use_channels_last=True)
+        model = ConcatBnRelu3d().eval().to(memory_format=torch.channels_last_3d)
+        model = ipex.optimize(model, dtype=torch.float32, level='O0')
+        with torch.no_grad():
+            jit_model = torch.jit.trace(model, (a1, a2, a3)).eval()
+            jit_model = torch.jit.freeze(jit_model)
+            #warmup run
+            for _ in range(2):
+                jit_res = jit_model(a1, a2, a3)
+            ori_res = model(a1, a2, a3)
+            self.assertEqual(jit_res, ori_res)
 
-        a1 = torch.randn(1, 16, 13, 24, 33, dtype=torch.float).contiguous(memory_format=torch.channels_last_3d)
-        a2 = torch.randn(1, 48, 13, 24, 33, dtype=torch.float).contiguous(memory_format=torch.channels_last_3d)
-        a3 = torch.randn(1, 32, 13, 24, 33, dtype=torch.float).contiguous(memory_format=torch.channels_last_3d)
-        self._test_concat_bn_relu(a1, a2, a3, enable_3d=True, use_channels_last=True)
+        a1 = torch.randn(1, 16, 17, 14, 31, dtype=torch.float).contiguous(memory_format=torch.channels_last_3d)
+        a2 = torch.randn(1, 48, 17, 14, 31, dtype=torch.float).contiguous(memory_format=torch.channels_last_3d)
+        a3 = torch.randn(1, 32, 17, 14, 31, dtype=torch.float).contiguous(memory_format=torch.channels_last_3d)
+        with torch.no_grad():
+            jit_res = jit_model(a1, a2, a3)
+            ori_res = model(a1, a2, a3)
+            self.assertEqual(jit_res, ori_res)
 
         a1 = torch.randn(1, 17, 13, 24, 33, dtype=torch.float).contiguous(memory_format=torch.channels_last_3d)
         a2 = torch.randn(1, 47, 13, 24, 33, dtype=torch.float).contiguous(memory_format=torch.channels_last_3d)
         a3 = torch.randn(1, 32, 13, 24, 33, dtype=torch.float).contiguous(memory_format=torch.channels_last_3d)
-        self._test_concat_bn_relu(a1, a2, a3, enable_3d=True, use_channels_last=True)
+        with torch.no_grad():
+            jit_res = jit_model(a1, a2, a3)
+            ori_res = model(a1, a2, a3)
+            self.assertEqual(jit_res, ori_res)
 
         a1 = torch.randn(1, 32, 13, 24, 33, dtype=torch.float)
         a2 = torch.randn(1, 32, 13, 24, 33, dtype=torch.float)
         a3 = torch.randn(1, 32, 13, 24, 33, dtype=torch.float)
-        self._test_concat_bn_relu(a1, a2, a3, enable_3d=True, use_channels_last=False)
+        with torch.no_grad():
+            jit_res = jit_model(a1, a2, a3)
+            ori_res = model(a1, a2, a3)
+            self.assertEqual(jit_res, ori_res)
 
     def test_mha_scores_calculation(self):
         def _check_match_mha(trace_model, mat1, mat2, bias, node = "ipex::mha_scores_calc"):
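About the repeated "#warmup run" loops added to test_concat_bn_relu: TorchScript's profiling executor generally applies graph optimizations, including the ipex::concat_bn_relu rewrite, only after the first couple of invocations of a traced and frozen module, so asserting on the very first call could exercise the unfused graph. A self-contained sketch of the pattern in plain PyTorch (ipex.optimize is omitted here, so this only illustrates the warmup structure, not the IPEX fusion itself):

import torch
import torch.nn as nn

class TinyConcatBnRelu(nn.Module):
    # Stand-in for the ConcatBnRelu2d module used by the test suite.
    def __init__(self, channels=96):
        super().__init__()
        self.bn = nn.BatchNorm2d(channels)

    def forward(self, a, b, c):
        return torch.relu(self.bn(torch.cat([a, b, c], dim=1)))

model = TinyConcatBnRelu().eval().to(memory_format=torch.channels_last)
a, b, c = (torch.randn(1, 32, 13, 24).contiguous(memory_format=torch.channels_last)
           for _ in range(3))
with torch.no_grad():
    jit_model = torch.jit.freeze(torch.jit.trace(model, (a, b, c)).eval())
    for _ in range(2):  # warmup runs so profiling/fusion passes can take effect
        jit_res = jit_model(a, b, c)
    ori_res = model(a, b, c)
print(torch.allclose(jit_res, ori_res))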
