Skip to content

Commit

Permalink
Update BatchNormKernels.cpp
Browse files (browse the repository at this point in the history)
  • Loading branch information
xytintel authored Dec 31, 2024
1 parent 1f3dfe0 commit 4d3d601
Showing 1 changed file with 2 additions and 2 deletions.
4 changes: 2 additions & 2 deletions src/ATen/native/xpu/sycl/BatchNormKernels.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1287,7 +1287,7 @@ struct BatchNormTransformInputVectorizedKernelFunctor {
} else {
invstd =
static_cast<stat_accscalar_t>(1) /
- device_sqrt(
+ std::sqrt(
static_cast<stat_accscalar_t>(var_or_invstd_[plane]) + epsilon_);
}

Expand Down Expand Up @@ -1459,7 +1459,7 @@ void batch_norm_elemt_template(
auto output_ptr = (char*)output_reshaped.data_ptr();
if (output_reshaped.is_contiguous() &&
memory::can_vectorize_up_to<input_scalar_t>(output_ptr) >= 4 &&
- sizeof(input_scalar_t) < sizeof(float)) {
+ sizeof(input_scalar_t) < sizeof(float) && input.size(2) % 4 == 0) {
auto kfn = BatchNormTransformInputVectorizedKernelFunctor<
4,
input_scalar_t,
Expand Down

0 comments on commit 4d3d601

Please sign in to comment.