diff --git a/src/ATen/native/xpu/sycl/BatchNormKernels.cpp b/src/ATen/native/xpu/sycl/BatchNormKernels.cpp index cc1b3d326..f50524db5 100644 --- a/src/ATen/native/xpu/sycl/BatchNormKernels.cpp +++ b/src/ATen/native/xpu/sycl/BatchNormKernels.cpp @@ -1302,25 +1302,16 @@ struct BatchNormTransformInputVectorizedKernelFunctor { for (index_t feature_vec_begin = item.get_local_id(1) * VEC_SIZE; feature_vec_begin < fs; feature_vec_begin += VEC_SIZE * item.get_local_range(1)) { - auto remaining = fs - feature_vec_begin; - if (remaining < VEC_SIZE) { - for (index_t idx = 0; idx < remaining; ++idx) { - index_t feature = feature_vec_begin + idx; - o[feature] = static_cast( - gamma * (i[feature] - mean) * invstd + beta); - } - } else { - using vec_t = memory::aligned_vector; - vec_t vec; + using vec_t = memory::aligned_vector; + vec_t vec; #pragma unroll - for (int vt = 0; vt < VEC_SIZE; ++vt) { - index_t feature = feature_vec_begin + vt; - vec[vt] = static_cast( - gamma * (i[feature] - mean) * invstd + beta); - } - input_scalar_t* write_ptr = &o[feature_vec_begin]; - *(reinterpret_cast(write_ptr)) = vec; + for (int vt = 0; vt < VEC_SIZE; ++vt) { + index_t feature = feature_vec_begin + vt; + vec[vt] = static_cast( + gamma * (i[feature] - mean) * invstd + beta); } + input_scalar_t* write_ptr = &o[feature_vec_begin]; + *(reinterpret_cast(write_ptr)) = vec; } } } @@ -1456,13 +1447,27 @@ void batch_norm_elemt_template( nwg_y = std::min(nwg_y, syclMaxWorkItemsPerTile() / (tf * tb)); sycl::range<2> global_range(nwg_y * tb, nwg_x * tf); - auto kfn = BatchNormTransformInputKernelFunctor< - input_scalar_t, - stat_scalar_t, - stat_accscalar_t, - true, - index_t>(input, output, mean, invstd, weight, bias, dummy_epsilon); - sycl_kernel_submit(global_range, local_range, queue, kfn); + auto output_ptr = (char*)output_reshaped.data_ptr(); + if (output_reshaped.is_contiguous() && + memory::can_vectorize_up_to(output_ptr) >= 4 && + sizeof(input_scalar_t) < sizeof(float) && input.size(2) % 4 == 0) { + auto kfn = BatchNormTransformInputVectorizedKernelFunctor< + 4, + input_scalar_t, + stat_scalar_t, + stat_accscalar_t, + true, + index_t>(input, output, mean, invstd, weight, bias, dummy_epsilon); + sycl_kernel_submit(global_range, local_range, queue, kfn); + } else { + auto kfn = BatchNormTransformInputKernelFunctor< + input_scalar_t, + stat_scalar_t, + stat_accscalar_t, + true, + index_t>(input, output, mean, invstd, weight, bias, dummy_epsilon); + sycl_kernel_submit(global_range, local_range, queue, kfn); + } } template <