Commit fe4ac6a

[Release/2.6] Fix batch norm vectorize path accuracy issue by enforcing shape alignment (#1241)
When implementing a kernel that must serve all shapes with vectorized loads/stores (LD/ST), both a non-aligned head (base address not aligned to the vector size) and a short tail (remainder smaller than the vector size) have to be handled. Before this commit, head handling was missing, which led to vectorized stores that were not aligned to the vector size. The fix enforces shape requirements on the vectorized path: the vectorized kernel is dispatched only when the feature dimension is divisible by the vector size.
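For readers unfamiliar with the head/body/tail pattern the message refers to, below is a minimal, self-contained sketch in plain C++ (not the SYCL kernel itself; Vec4 and scale_all_shapes are illustrative names, not code from this repository). A fully general vectorized kernel peels unaligned leading elements and the trailing remainder into scalar loops, so the full-width store in the middle always hits an aligned slot:

// Minimal sketch of head/body/tail handling for vectorized stores.
// Vec4 and scale_all_shapes are hypothetical illustrations.
#include <cstdint>
#include <cstdio>
#include <vector>

constexpr int kVec = 4;
struct alignas(kVec * sizeof(float)) Vec4 {
  float v[kVec];
};

// Handles any length n and any base alignment:
// scalar head, vectorized body, scalar tail.
void scale_all_shapes(float* out, const float* in, int64_t n) {
  int64_t i = 0;
  // Head: peel elements one by one until out + i is aligned
  // for a full-width store.
  while (i < n &&
         reinterpret_cast<uintptr_t>(out + i) % alignof(Vec4) != 0) {
    out[i] = in[i] * 2.0f;
    ++i;
  }
  // Body: full-width stores, safe only because the head loop
  // guaranteed alignment (this mirrors the kernel's
  // reinterpret_cast<vec_t*> store).
  for (; i + kVec <= n; i += kVec) {
    Vec4 vec;
    for (int vt = 0; vt < kVec; ++vt) {
      vec.v[vt] = in[i + vt] * 2.0f;
    }
    *reinterpret_cast<Vec4*>(out + i) = vec;
  }
  // Tail: fewer than kVec elements remain.
  for (; i < n; ++i) {
    out[i] = in[i] * 2.0f;
  }
}

int main() {
  std::vector<float> in(10, 1.5f), out(10, 0.0f);
  scale_all_shapes(out.data(), in.data(), static_cast<int64_t>(in.size()));
  std::printf("out[9] = %g\n", out[9]);  // prints 3
  return 0;
}

The commit takes the simpler route for this kernel: rather than adding a head loop, it keeps the kernel full-width only and moves the burden to dispatch, which now refuses shapes and addresses that would produce a head or a tail.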
1 parent 28cfac2 commit fe4ac6a

File tree: 1 file changed, +29 −24 lines


src/ATen/native/xpu/sycl/BatchNormKernels.cpp

Lines changed: 29 additions & 24 deletions
@@ -1302,25 +1302,16 @@ struct BatchNormTransformInputVectorizedKernelFunctor {
     for (index_t feature_vec_begin = item.get_local_id(1) * VEC_SIZE;
          feature_vec_begin < fs;
          feature_vec_begin += VEC_SIZE * item.get_local_range(1)) {
-      auto remaining = fs - feature_vec_begin;
-      if (remaining < VEC_SIZE) {
-        for (index_t idx = 0; idx < remaining; ++idx) {
-          index_t feature = feature_vec_begin + idx;
-          o[feature] = static_cast<input_scalar_t>(
-              gamma * (i[feature] - mean) * invstd + beta);
-        }
-      } else {
-        using vec_t = memory::aligned_vector<input_scalar_t, VEC_SIZE>;
-        vec_t vec;
+      using vec_t = memory::aligned_vector<input_scalar_t, VEC_SIZE>;
+      vec_t vec;
 #pragma unroll
-        for (int vt = 0; vt < VEC_SIZE; ++vt) {
-          index_t feature = feature_vec_begin + vt;
-          vec[vt] = static_cast<input_scalar_t>(
-              gamma * (i[feature] - mean) * invstd + beta);
-        }
-        input_scalar_t* write_ptr = &o[feature_vec_begin];
-        *(reinterpret_cast<vec_t*>(write_ptr)) = vec;
+      for (int vt = 0; vt < VEC_SIZE; ++vt) {
+        index_t feature = feature_vec_begin + vt;
+        vec[vt] = static_cast<input_scalar_t>(
+            gamma * (i[feature] - mean) * invstd + beta);
       }
+      input_scalar_t* write_ptr = &o[feature_vec_begin];
+      *(reinterpret_cast<vec_t*>(write_ptr)) = vec;
     }
   }
 }

@@ -1456,13 +1447,27 @@ void batch_norm_elemt_template(
   nwg_y = std::min<int>(nwg_y, syclMaxWorkItemsPerTile() / (tf * tb));
   sycl::range<2> global_range(nwg_y * tb, nwg_x * tf);
 
-  auto kfn = BatchNormTransformInputKernelFunctor<
-      input_scalar_t,
-      stat_scalar_t,
-      stat_accscalar_t,
-      true,
-      index_t>(input, output, mean, invstd, weight, bias, dummy_epsilon);
-  sycl_kernel_submit(global_range, local_range, queue, kfn);
+  auto output_ptr = (char*)output_reshaped.data_ptr();
+  if (output_reshaped.is_contiguous() &&
+      memory::can_vectorize_up_to<input_scalar_t>(output_ptr) >= 4 &&
+      sizeof(input_scalar_t) < sizeof(float) && input.size(2) % 4 == 0) {
+    auto kfn = BatchNormTransformInputVectorizedKernelFunctor<
+        4,
+        input_scalar_t,
+        stat_scalar_t,
+        stat_accscalar_t,
+        true,
+        index_t>(input, output, mean, invstd, weight, bias, dummy_epsilon);
+    sycl_kernel_submit(global_range, local_range, queue, kfn);
+  } else {
+    auto kfn = BatchNormTransformInputKernelFunctor<
+        input_scalar_t,
+        stat_scalar_t,
+        stat_accscalar_t,
+        true,
+        index_t>(input, output, mean, invstd, weight, bias, dummy_epsilon);
+    sycl_kernel_submit(global_range, local_range, queue, kfn);
+  }
 }
 
 template <
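Read as a plain predicate, the new dispatch guard takes the vectorized path only when four conditions all hold: the reshaped output is contiguous, its base address supports at least 4-wide vector accesses, the element type is narrower than float (e.g. half or bfloat16), and the feature dimension (input.size(2)) is a multiple of the vector width, which is what rules out the short tail. A standalone sketch of that guard follows; addr_vec_width is a hypothetical stand-in for the repository's memory::can_vectorize_up_to helper, written here only for illustration:

#include <cstdint>

// Hypothetical stand-in for memory::can_vectorize_up_to<scalar_t>(ptr):
// the widest power-of-two vector width (capped at 4) the address supports.
template <typename scalar_t>
int addr_vec_width(const char* ptr) {
  auto addr = reinterpret_cast<uintptr_t>(ptr);
  for (int w = 4; w > 1; w /= 2) {
    if (addr % (w * sizeof(scalar_t)) == 0) return w;
  }
  return 1;
}

// Mirrors the commit's dispatch condition in isolation.
template <typename scalar_t>
bool use_vectorized_path(bool contiguous,
                         const char* out_ptr,
                         int64_t feature_dim) {
  constexpr int kVec = 4;
  return contiguous &&
      addr_vec_width<scalar_t>(out_ptr) >= kVec &&  // aligned base address
      sizeof(scalar_t) < sizeof(float) &&           // half/bfloat16 only
      feature_dim % kVec == 0;                      // no short tail
}

Under this guard, every work-item's feature_vec_begin is a multiple of VEC_SIZE and fs is too, so the remaining span at each iteration is always a full vector: every reinterpret_cast store in the kernel is both full-width and aligned, and the removed tail branch is never needed.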
