From 48eddc20ddb632fb42479359f457da54adb51a90 Mon Sep 17 00:00:00 2001 From: Yutao Xu Date: Mon, 30 Dec 2024 15:01:44 +0800 Subject: [PATCH] Update BatchNormKernels.cpp --- src/ATen/native/xpu/sycl/BatchNormKernels.cpp | 30 +++++-------------- 1 file changed, 8 insertions(+), 22 deletions(-) diff --git a/src/ATen/native/xpu/sycl/BatchNormKernels.cpp b/src/ATen/native/xpu/sycl/BatchNormKernels.cpp index 30a3bf5c2..cc1b3d326 100644 --- a/src/ATen/native/xpu/sycl/BatchNormKernels.cpp +++ b/src/ATen/native/xpu/sycl/BatchNormKernels.cpp @@ -1287,7 +1287,7 @@ struct BatchNormTransformInputVectorizedKernelFunctor { } else { invstd = static_cast(1) / - device_sqrt( + std::sqrt( static_cast(var_or_invstd_[plane]) + epsilon_); } @@ -1456,27 +1456,13 @@ void batch_norm_elemt_template( nwg_y = std::min(nwg_y, syclMaxWorkItemsPerTile() / (tf * tb)); sycl::range<2> global_range(nwg_y * tb, nwg_x * tf); - auto output_ptr = (char*)output_reshaped.data_ptr(); - if (output_reshaped.is_contiguous() && - memory::can_vectorize_up_to(output_ptr) >= 4 && - sizeof(input_scalar_t) < sizeof(float)) { - auto kfn = BatchNormTransformInputVectorizedKernelFunctor< - 4, - input_scalar_t, - stat_scalar_t, - stat_accscalar_t, - true, - index_t>(input, output, mean, invstd, weight, bias, dummy_epsilon); - sycl_kernel_submit(global_range, local_range, queue, kfn); - } else { - auto kfn = BatchNormTransformInputKernelFunctor< - input_scalar_t, - stat_scalar_t, - stat_accscalar_t, - true, - index_t>(input, output, mean, invstd, weight, bias, dummy_epsilon); - sycl_kernel_submit(global_range, local_range, queue, kfn); - } + auto kfn = BatchNormTransformInputKernelFunctor< + input_scalar_t, + stat_scalar_t, + stat_accscalar_t, + true, + index_t>(input, output, mean, invstd, weight, bias, dummy_epsilon); + sycl_kernel_submit(global_range, local_range, queue, kfn); } template <