diff --git a/src/ATen/native/xpu/sycl/BatchNormKernels.cpp b/src/ATen/native/xpu/sycl/BatchNormKernels.cpp index 30a3bf5c2..cc1b3d326 100644 --- a/src/ATen/native/xpu/sycl/BatchNormKernels.cpp +++ b/src/ATen/native/xpu/sycl/BatchNormKernels.cpp @@ -1287,7 +1287,7 @@ struct BatchNormTransformInputVectorizedKernelFunctor { } else { invstd = static_cast(1) / - device_sqrt( + std::sqrt( static_cast(var_or_invstd_[plane]) + epsilon_); } @@ -1456,27 +1456,13 @@ void batch_norm_elemt_template( nwg_y = std::min(nwg_y, syclMaxWorkItemsPerTile() / (tf * tb)); sycl::range<2> global_range(nwg_y * tb, nwg_x * tf); - auto output_ptr = (char*)output_reshaped.data_ptr(); - if (output_reshaped.is_contiguous() && - memory::can_vectorize_up_to(output_ptr) >= 4 && - sizeof(input_scalar_t) < sizeof(float)) { - auto kfn = BatchNormTransformInputVectorizedKernelFunctor< - 4, - input_scalar_t, - stat_scalar_t, - stat_accscalar_t, - true, - index_t>(input, output, mean, invstd, weight, bias, dummy_epsilon); - sycl_kernel_submit(global_range, local_range, queue, kfn); - } else { - auto kfn = BatchNormTransformInputKernelFunctor< - input_scalar_t, - stat_scalar_t, - stat_accscalar_t, - true, - index_t>(input, output, mean, invstd, weight, bias, dummy_epsilon); - sycl_kernel_submit(global_range, local_range, queue, kfn); - } + auto kfn = BatchNormTransformInputKernelFunctor< + input_scalar_t, + stat_scalar_t, + stat_accscalar_t, + true, + index_t>(input, output, mean, invstd, weight, bias, dummy_epsilon); + sycl_kernel_submit(global_range, local_range, queue, kfn); } template <