diff --git a/src/ATen/native/xpu/sycl/BatchNormKernels.cpp b/src/ATen/native/xpu/sycl/BatchNormKernels.cpp index 30a3bf5c2..61a960e30 100644 --- a/src/ATen/native/xpu/sycl/BatchNormKernels.cpp +++ b/src/ATen/native/xpu/sycl/BatchNormKernels.cpp @@ -1287,7 +1287,7 @@ struct BatchNormTransformInputVectorizedKernelFunctor { } else { invstd = static_cast(1) / - device_sqrt( + std::sqrt( static_cast(var_or_invstd_[plane]) + epsilon_); } @@ -1459,7 +1459,7 @@ void batch_norm_elemt_template( auto output_ptr = (char*)output_reshaped.data_ptr(); if (output_reshaped.is_contiguous() && memory::can_vectorize_up_to(output_ptr) >= 4 && - sizeof(input_scalar_t) < sizeof(float)) { + sizeof(input_scalar_t) < sizeof(float) && input.size(2) % 4 == 0) { auto kfn = BatchNormTransformInputVectorizedKernelFunctor< 4, input_scalar_t,