From f634c3ccc24312b445be566270d2037190c2e6bc Mon Sep 17 00:00:00 2001 From: Yutao Xu Date: Fri, 3 Jan 2025 09:57:31 +0800 Subject: [PATCH] Fix batch norm vectorize path accuracy issue by enforcing shape alignment (#1238) When implementing a kernel for all shapes with vectorized LD/ST, we have to handle a non-aligned head (base address) and a short tail (tail < vector size). Before this commit, head handling was missing, which led to non-vector-size-aligned vectorized STs. This fix enforces shape requirements for the vectorized path: the vectorized kernel can be called only when the feature dim is divisible by the vector size, so there will always be an aligned head and an exact tail. --------- Co-authored-by: mengfei25 --- src/ATen/native/xpu/sycl/BatchNormKernels.cpp | 29 +++++++------------ 1 file changed, 10 insertions(+), 19 deletions(-) diff --git a/src/ATen/native/xpu/sycl/BatchNormKernels.cpp b/src/ATen/native/xpu/sycl/BatchNormKernels.cpp index 30a3bf5c2..f50524db5 100644 --- a/src/ATen/native/xpu/sycl/BatchNormKernels.cpp +++ b/src/ATen/native/xpu/sycl/BatchNormKernels.cpp @@ -1287,7 +1287,7 @@ struct BatchNormTransformInputVectorizedKernelFunctor { } else { invstd = static_cast(1) / - device_sqrt( + std::sqrt( static_cast(var_or_invstd_[plane]) + epsilon_); } @@ -1302,25 +1302,16 @@ for (index_t feature_vec_begin = item.get_local_id(1) * VEC_SIZE; feature_vec_begin < fs; feature_vec_begin += VEC_SIZE * item.get_local_range(1)) { - auto remaining = fs - feature_vec_begin; - if (remaining < VEC_SIZE) { - for (index_t idx = 0; idx < remaining; ++idx) { - index_t feature = feature_vec_begin + idx; - o[feature] = static_cast( - gamma * (i[feature] - mean) * invstd + beta); - } - } else { - using vec_t = memory::aligned_vector; - vec_t vec; + using vec_t = memory::aligned_vector; + vec_t vec; #pragma unroll - for (int vt = 0; vt < VEC_SIZE; ++vt) { - index_t feature = feature_vec_begin + vt; - vec[vt] = 
static_cast( - gamma * (i[feature] - mean) * invstd + beta); - } - input_scalar_t* write_ptr = &o[feature_vec_begin]; - *(reinterpret_cast(write_ptr)) = vec; + for (int vt = 0; vt < VEC_SIZE; ++vt) { + index_t feature = feature_vec_begin + vt; + vec[vt] = static_cast( + gamma * (i[feature] - mean) * invstd + beta); } + input_scalar_t* write_ptr = &o[feature_vec_begin]; + *(reinterpret_cast(write_ptr)) = vec; } } } @@ -1459,7 +1450,7 @@ void batch_norm_elemt_template( auto output_ptr = (char*)output_reshaped.data_ptr(); if (output_reshaped.is_contiguous() && memory::can_vectorize_up_to(output_ptr) >= 4 && - sizeof(input_scalar_t) < sizeof(float)) { + sizeof(input_scalar_t) < sizeof(float) && input.size(2) % 4 == 0) { auto kfn = BatchNormTransformInputVectorizedKernelFunctor< 4, input_scalar_t,