From c17e29501a71bcce563333b7edd0a72c99fa5550 Mon Sep 17 00:00:00 2001 From: Yutao Xu Date: Thu, 2 Jan 2025 15:40:22 +0800 Subject: [PATCH] Update BatchNormKernels.cpp --- src/ATen/native/xpu/sycl/BatchNormKernels.cpp | 28 ++++++++++++++----- 1 file changed, 21 insertions(+), 7 deletions(-) diff --git a/src/ATen/native/xpu/sycl/BatchNormKernels.cpp b/src/ATen/native/xpu/sycl/BatchNormKernels.cpp index cc1b3d326..61a960e30 100644 --- a/src/ATen/native/xpu/sycl/BatchNormKernels.cpp +++ b/src/ATen/native/xpu/sycl/BatchNormKernels.cpp @@ -1456,13 +1456,27 @@ void batch_norm_elemt_template( nwg_y = std::min(nwg_y, syclMaxWorkItemsPerTile() / (tf * tb)); sycl::range<2> global_range(nwg_y * tb, nwg_x * tf); - auto kfn = BatchNormTransformInputKernelFunctor< - input_scalar_t, - stat_scalar_t, - stat_accscalar_t, - true, - index_t>(input, output, mean, invstd, weight, bias, dummy_epsilon); - sycl_kernel_submit(global_range, local_range, queue, kfn); + auto output_ptr = (char*)output_reshaped.data_ptr(); + if (output_reshaped.is_contiguous() && + memory::can_vectorize_up_to(output_ptr) >= 4 && + sizeof(input_scalar_t) < sizeof(float) && input.size(2) % 4 == 0) { + auto kfn = BatchNormTransformInputVectorizedKernelFunctor< + 4, + input_scalar_t, + stat_scalar_t, + stat_accscalar_t, + true, + index_t>(input, output, mean, invstd, weight, bias, dummy_epsilon); + sycl_kernel_submit(global_range, local_range, queue, kfn); + } else { + auto kfn = BatchNormTransformInputKernelFunctor< + input_scalar_t, + stat_scalar_t, + stat_accscalar_t, + true, + index_t>(input, output, mean, invstd, weight, bias, dummy_epsilon); + sycl_kernel_submit(global_range, local_range, queue, kfn); + } } template <