Skip to content

Commit

Permalink
Update BatchNormKernels.cpp
Browse files Browse the repository at this point in the history
  • Loading branch information
xytintel authored Jan 2, 2025
1 parent 28cfac2 commit c17e295
Showing 1 changed file with 21 additions and 7 deletions.
28 changes: 21 additions & 7 deletions src/ATen/native/xpu/sycl/BatchNormKernels.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1456,13 +1456,27 @@ void batch_norm_elemt_template(
nwg_y = std::min<int>(nwg_y, syclMaxWorkItemsPerTile() / (tf * tb));
sycl::range<2> global_range(nwg_y * tb, nwg_x * tf);

auto kfn = BatchNormTransformInputKernelFunctor<
input_scalar_t,
stat_scalar_t,
stat_accscalar_t,
true,
index_t>(input, output, mean, invstd, weight, bias, dummy_epsilon);
sycl_kernel_submit(global_range, local_range, queue, kfn);
auto output_ptr = (char*)output_reshaped.data_ptr();
if (output_reshaped.is_contiguous() &&
memory::can_vectorize_up_to<input_scalar_t>(output_ptr) >= 4 &&
sizeof(input_scalar_t) < sizeof(float) && input.size(2) % 4 == 0) {
auto kfn = BatchNormTransformInputVectorizedKernelFunctor<
4,
input_scalar_t,
stat_scalar_t,
stat_accscalar_t,
true,
index_t>(input, output, mean, invstd, weight, bias, dummy_epsilon);
sycl_kernel_submit(global_range, local_range, queue, kfn);
} else {
auto kfn = BatchNormTransformInputKernelFunctor<
input_scalar_t,
stat_scalar_t,
stat_accscalar_t,
true,
index_t>(input, output, mean, invstd, weight, bias, dummy_epsilon);
sycl_kernel_submit(global_range, local_range, queue, kfn);
}
}

template <
Expand Down

0 comments on commit c17e295

Please sign in to comment.