Skip to content

Commit 48eddc2

Browse files
authored
Update BatchNormKernels.cpp
1 parent 214f33b commit 48eddc2

File tree

1 file changed

+8
-22
lines changed

1 file changed

+8
-22
lines changed

src/ATen/native/xpu/sycl/BatchNormKernels.cpp

Lines changed: 8 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -1287,7 +1287,7 @@ struct BatchNormTransformInputVectorizedKernelFunctor {
12871287
} else {
12881288
invstd =
12891289
static_cast<stat_accscalar_t>(1) /
1290-
device_sqrt(
1290+
std::sqrt(
12911291
static_cast<stat_accscalar_t>(var_or_invstd_[plane]) + epsilon_);
12921292
}
12931293

@@ -1456,27 +1456,13 @@ void batch_norm_elemt_template(
14561456
nwg_y = std::min<int>(nwg_y, syclMaxWorkItemsPerTile() / (tf * tb));
14571457
sycl::range<2> global_range(nwg_y * tb, nwg_x * tf);
14581458

1459-
auto output_ptr = (char*)output_reshaped.data_ptr();
1460-
if (output_reshaped.is_contiguous() &&
1461-
memory::can_vectorize_up_to<input_scalar_t>(output_ptr) >= 4 &&
1462-
sizeof(input_scalar_t) < sizeof(float)) {
1463-
auto kfn = BatchNormTransformInputVectorizedKernelFunctor<
1464-
4,
1465-
input_scalar_t,
1466-
stat_scalar_t,
1467-
stat_accscalar_t,
1468-
true,
1469-
index_t>(input, output, mean, invstd, weight, bias, dummy_epsilon);
1470-
sycl_kernel_submit(global_range, local_range, queue, kfn);
1471-
} else {
1472-
auto kfn = BatchNormTransformInputKernelFunctor<
1473-
input_scalar_t,
1474-
stat_scalar_t,
1475-
stat_accscalar_t,
1476-
true,
1477-
index_t>(input, output, mean, invstd, weight, bias, dummy_epsilon);
1478-
sycl_kernel_submit(global_range, local_range, queue, kfn);
1479-
}
1459+
auto kfn = BatchNormTransformInputKernelFunctor<
1460+
input_scalar_t,
1461+
stat_scalar_t,
1462+
stat_accscalar_t,
1463+
true,
1464+
index_t>(input, output, mean, invstd, weight, bias, dummy_epsilon);
1465+
sycl_kernel_submit(global_range, local_range, queue, kfn);
14801466
}
14811467

14821468
template <

0 commit comments

Comments
 (0)