From 6e826960dfa7c7e9e9a1a4c4d0b07410b821257b Mon Sep 17 00:00:00 2001 From: Yutao Xu Date: Thu, 2 Jan 2025 15:49:09 +0800 Subject: [PATCH] Update BatchNormKernels.cpp --- src/ATen/native/xpu/sycl/BatchNormKernels.cpp | 25 ++++++------------- 1 file changed, 8 insertions(+), 17 deletions(-) diff --git a/src/ATen/native/xpu/sycl/BatchNormKernels.cpp b/src/ATen/native/xpu/sycl/BatchNormKernels.cpp index 61a960e30..f50524db5 100644 --- a/src/ATen/native/xpu/sycl/BatchNormKernels.cpp +++ b/src/ATen/native/xpu/sycl/BatchNormKernels.cpp @@ -1302,25 +1302,16 @@ struct BatchNormTransformInputVectorizedKernelFunctor { for (index_t feature_vec_begin = item.get_local_id(1) * VEC_SIZE; feature_vec_begin < fs; feature_vec_begin += VEC_SIZE * item.get_local_range(1)) { - auto remaining = fs - feature_vec_begin; - if (remaining < VEC_SIZE) { - for (index_t idx = 0; idx < remaining; ++idx) { - index_t feature = feature_vec_begin + idx; - o[feature] = static_cast( - gamma * (i[feature] - mean) * invstd + beta); - } - } else { - using vec_t = memory::aligned_vector; - vec_t vec; + using vec_t = memory::aligned_vector; + vec_t vec; #pragma unroll - for (int vt = 0; vt < VEC_SIZE; ++vt) { - index_t feature = feature_vec_begin + vt; - vec[vt] = static_cast( - gamma * (i[feature] - mean) * invstd + beta); - } - input_scalar_t* write_ptr = &o[feature_vec_begin]; - *(reinterpret_cast(write_ptr)) = vec; + for (int vt = 0; vt < VEC_SIZE; ++vt) { + index_t feature = feature_vec_begin + vt; + vec[vt] = static_cast( + gamma * (i[feature] - mean) * invstd + beta); } + input_scalar_t* write_ptr = &o[feature_vec_begin]; + *(reinterpret_cast(write_ptr)) = vec; } } }