@@ -1287,7 +1287,7 @@ struct BatchNormTransformInputVectorizedKernelFunctor {
1287
1287
} else {
1288
1288
invstd =
1289
1289
static_cast <stat_accscalar_t >(1 ) /
1290
- device_sqrt (
1290
+ std::sqrt (
1291
1291
static_cast <stat_accscalar_t >(var_or_invstd_[plane]) + epsilon_);
1292
1292
}
1293
1293
@@ -1456,27 +1456,13 @@ void batch_norm_elemt_template(
1456
1456
nwg_y = std::min<int >(nwg_y, syclMaxWorkItemsPerTile () / (tf * tb));
1457
1457
sycl::range<2 > global_range (nwg_y * tb, nwg_x * tf);
1458
1458
1459
- auto output_ptr = (char *)output_reshaped.data_ptr ();
1460
- if (output_reshaped.is_contiguous () &&
1461
- memory::can_vectorize_up_to<input_scalar_t >(output_ptr) >= 4 &&
1462
- sizeof (input_scalar_t ) < sizeof (float )) {
1463
- auto kfn = BatchNormTransformInputVectorizedKernelFunctor<
1464
- 4 ,
1465
- input_scalar_t ,
1466
- stat_scalar_t ,
1467
- stat_accscalar_t ,
1468
- true ,
1469
- index_t >(input, output, mean, invstd, weight, bias, dummy_epsilon);
1470
- sycl_kernel_submit (global_range, local_range, queue, kfn);
1471
- } else {
1472
- auto kfn = BatchNormTransformInputKernelFunctor<
1473
- input_scalar_t ,
1474
- stat_scalar_t ,
1475
- stat_accscalar_t ,
1476
- true ,
1477
- index_t >(input, output, mean, invstd, weight, bias, dummy_epsilon);
1478
- sycl_kernel_submit (global_range, local_range, queue, kfn);
1479
- }
1459
+ auto kfn = BatchNormTransformInputKernelFunctor<
1460
+ input_scalar_t ,
1461
+ stat_scalar_t ,
1462
+ stat_accscalar_t ,
1463
+ true ,
1464
+ index_t >(input, output, mean, invstd, weight, bias, dummy_epsilon);
1465
+ sycl_kernel_submit (global_range, local_range, queue, kfn);
1480
1466
}
1481
1467
1482
1468
template <
0 commit comments