[Release/2.6] Fix batch norm vectorize path accuracy issue by enforcing shape alignment (#1241)

When implementing a kernel that handles all shapes with vectorized LD/ST, we have
to handle a non-aligned head (base address) and a short tail (tail < vector
size). Before this commit, head handling was missing, which led to vectorized
STs that were not aligned to the vector size.
The fix enforces shape requirements for the vectorized path: the vectorized
kernel is invoked only when the feature dimension is divisible by the vector
size.
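
For context, the general pattern the message refers to looks like the sketch below. This is a hypothetical, standalone illustration rather than the kernel touched by this commit: vectorized_transform, in, out, and n are made-up names, and the element-wise copy stands in for the real normalization math.

#include <cstddef>
#include <cstdint>

template <typename T, int VEC_SIZE>
void vectorized_transform(const T* in, T* out, std::size_t n) {
  std::size_t i = 0;
  // Head: scalar steps until the write address is aligned to the vector
  // width (this is the handling the original kernel lacked).
  while (i < n &&
         reinterpret_cast<std::uintptr_t>(out + i) %
                 (VEC_SIZE * sizeof(T)) != 0) {
    out[i] = in[i];
    ++i;
  }
  // Body: full, aligned vectors only.
  for (; i + VEC_SIZE <= n; i += VEC_SIZE) {
    for (int v = 0; v < VEC_SIZE; ++v) {  // stands in for one vector LD/ST
      out[i + v] = in[i + v];
    }
  }
  // Tail: fewer than VEC_SIZE elements remain; fall back to scalar.
  for (; i < n; ++i) {
    out[i] = in[i];
  }
}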
xytintel authored Jan 3, 2025
1 parent 28cfac2 commit fe4ac6a
Showing 1 changed file with 29 additions and 24 deletions.

src/ATen/native/xpu/sycl/BatchNormKernels.cpp
@@ -1302,25 +1302,16 @@ struct BatchNormTransformInputVectorizedKernelFunctor {
     for (index_t feature_vec_begin = item.get_local_id(1) * VEC_SIZE;
          feature_vec_begin < fs;
          feature_vec_begin += VEC_SIZE * item.get_local_range(1)) {
-      auto remaining = fs - feature_vec_begin;
-      if (remaining < VEC_SIZE) {
-        for (index_t idx = 0; idx < remaining; ++idx) {
-          index_t feature = feature_vec_begin + idx;
-          o[feature] = static_cast<input_scalar_t>(
-              gamma * (i[feature] - mean) * invstd + beta);
-        }
-      } else {
-        using vec_t = memory::aligned_vector<input_scalar_t, VEC_SIZE>;
-        vec_t vec;
+      using vec_t = memory::aligned_vector<input_scalar_t, VEC_SIZE>;
+      vec_t vec;
 #pragma unroll
-        for (int vt = 0; vt < VEC_SIZE; ++vt) {
-          index_t feature = feature_vec_begin + vt;
-          vec[vt] = static_cast<input_scalar_t>(
-              gamma * (i[feature] - mean) * invstd + beta);
-        }
-        input_scalar_t* write_ptr = &o[feature_vec_begin];
-        *(reinterpret_cast<vec_t*>(write_ptr)) = vec;
-      }
+      for (int vt = 0; vt < VEC_SIZE; ++vt) {
+        index_t feature = feature_vec_begin + vt;
+        vec[vt] = static_cast<input_scalar_t>(
+            gamma * (i[feature] - mean) * invstd + beta);
+      }
+      input_scalar_t* write_ptr = &o[feature_vec_begin];
+      *(reinterpret_cast<vec_t*>(write_ptr)) = vec;
     }
   }
 }
@@ -1456,13 +1447,27 @@ void batch_norm_elemt_template(
   nwg_y = std::min<int>(nwg_y, syclMaxWorkItemsPerTile() / (tf * tb));
   sycl::range<2> global_range(nwg_y * tb, nwg_x * tf);
 
-  auto kfn = BatchNormTransformInputKernelFunctor<
-      input_scalar_t,
-      stat_scalar_t,
-      stat_accscalar_t,
-      true,
-      index_t>(input, output, mean, invstd, weight, bias, dummy_epsilon);
-  sycl_kernel_submit(global_range, local_range, queue, kfn);
+  auto output_ptr = (char*)output_reshaped.data_ptr();
+  if (output_reshaped.is_contiguous() &&
+      memory::can_vectorize_up_to<input_scalar_t>(output_ptr) >= 4 &&
+      sizeof(input_scalar_t) < sizeof(float) && input.size(2) % 4 == 0) {
+    auto kfn = BatchNormTransformInputVectorizedKernelFunctor<
+        4,
+        input_scalar_t,
+        stat_scalar_t,
+        stat_accscalar_t,
+        true,
+        index_t>(input, output, mean, invstd, weight, bias, dummy_epsilon);
+    sycl_kernel_submit(global_range, local_range, queue, kfn);
+  } else {
+    auto kfn = BatchNormTransformInputKernelFunctor<
+        input_scalar_t,
+        stat_scalar_t,
+        stat_accscalar_t,
+        true,
+        index_t>(input, output, mean, invstd, weight, bias, dummy_epsilon);
+    sycl_kernel_submit(global_range, local_range, queue, kfn);
+  }
 }
 
 template <
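
A note on the new dispatch guard: the vectorized functor is used only when the output is contiguous, its base pointer supports 4-wide vectorized access, the element type is narrower than float, and the feature dimension (input.size(2)) is divisible by 4, so every 4-wide store is fully in-bounds and the tail branch removed above is never needed. The sketch below restates that condition in isolation; can_use_vectorized_path is a hypothetical helper, and the pointer-alignment test is only an assumption about what memory::can_vectorize_up_to checks.

#include <cstdint>

template <typename scalar_t>
bool can_use_vectorized_path(
    const scalar_t* base,
    long feature_dim,
    bool contiguous) {
  constexpr int kVecSize = 4;  // mirrors the hard-coded 4 in the dispatch
  // Approximation of the alignment query; the real ATen helper inspects
  // the pointer itself.
  const bool aligned =
      reinterpret_cast<std::uintptr_t>(base) %
          (kVecSize * sizeof(scalar_t)) == 0;
  // Divisibility guarantees every vector-wide store is fully in-bounds,
  // so no scalar tail path is required.
  return contiguous && aligned && sizeof(scalar_t) < sizeof(float) &&
      feature_dim % kVecSize == 0;
}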
