diff --git a/src/ATen/native/xpu/sycl/BatchNormKernels.cpp b/src/ATen/native/xpu/sycl/BatchNormKernels.cpp
index 30a3bf5c2..f50524db5 100644
--- a/src/ATen/native/xpu/sycl/BatchNormKernels.cpp
+++ b/src/ATen/native/xpu/sycl/BatchNormKernels.cpp
@@ -1287,7 +1287,7 @@ struct BatchNormTransformInputVectorizedKernelFunctor {
     } else {
       invstd =
           static_cast<stat_accscalar_t>(1) /
-          device_sqrt(
+          std::sqrt(
               static_cast<stat_accscalar_t>(var_or_invstd_[plane]) + epsilon_);
     }
 
@@ -1302,25 +1302,16 @@ struct BatchNormTransformInputVectorizedKernelFunctor {
       for (index_t feature_vec_begin = item.get_local_id(1) * VEC_SIZE;
            feature_vec_begin < fs;
            feature_vec_begin += VEC_SIZE * item.get_local_range(1)) {
-        auto remaining = fs - feature_vec_begin;
-        if (remaining < VEC_SIZE) {
-          for (index_t idx = 0; idx < remaining; ++idx) {
-            index_t feature = feature_vec_begin + idx;
-            o[feature] = static_cast<input_scalar_t>(
-                gamma * (i[feature] - mean) * invstd + beta);
-          }
-        } else {
-          using vec_t = memory::aligned_vector<input_scalar_t, VEC_SIZE>;
-          vec_t vec;
+        using vec_t = memory::aligned_vector<input_scalar_t, VEC_SIZE>;
+        vec_t vec;
 #pragma unroll
-          for (int vt = 0; vt < VEC_SIZE; ++vt) {
-            index_t feature = feature_vec_begin + vt;
-            vec[vt] = static_cast<input_scalar_t>(
-                gamma * (i[feature] - mean) * invstd + beta);
-          }
-          input_scalar_t* write_ptr = &o[feature_vec_begin];
-          *(reinterpret_cast<vec_t*>(write_ptr)) = vec;
+        for (int vt = 0; vt < VEC_SIZE; ++vt) {
+          index_t feature = feature_vec_begin + vt;
+          vec[vt] = static_cast<input_scalar_t>(
+              gamma * (i[feature] - mean) * invstd + beta);
         }
+        input_scalar_t* write_ptr = &o[feature_vec_begin];
+        *(reinterpret_cast<vec_t*>(write_ptr)) = vec;
       }
     }
   }
@@ -1459,7 +1450,7 @@ void batch_norm_elemt_template(
   auto output_ptr = (char*)output_reshaped.data_ptr();
   if (output_reshaped.is_contiguous() &&
       memory::can_vectorize_up_to<input_scalar_t>(output_ptr) >= 4 &&
-      sizeof(input_scalar_t) < sizeof(float)) {
+      sizeof(input_scalar_t) < sizeof(float) && input.size(2) % 4 == 0) {
     auto kfn = BatchNormTransformInputVectorizedKernelFunctor<
         4,
         input_scalar_t,