@@ -1302,25 +1302,16 @@ struct BatchNormTransformInputVectorizedKernelFunctor {
1302
1302
for (index_t feature_vec_begin = item.get_local_id (1 ) * VEC_SIZE;
1303
1303
feature_vec_begin < fs;
1304
1304
feature_vec_begin += VEC_SIZE * item.get_local_range (1 )) {
1305
- auto remaining = fs - feature_vec_begin;
1306
- if (remaining < VEC_SIZE) {
1307
- for (index_t idx = 0 ; idx < remaining; ++idx) {
1308
- index_t feature = feature_vec_begin + idx;
1309
- o[feature] = static_cast <input_scalar_t >(
1310
- gamma * (i[feature] - mean) * invstd + beta);
1311
- }
1312
- } else {
1313
- using vec_t = memory::aligned_vector<input_scalar_t , VEC_SIZE>;
1314
- vec_t vec;
1305
+ using vec_t = memory::aligned_vector<input_scalar_t , VEC_SIZE>;
1306
+ vec_t vec;
1315
1307
#pragma unroll
1316
- for (int vt = 0 ; vt < VEC_SIZE; ++vt) {
1317
- index_t feature = feature_vec_begin + vt;
1318
- vec[vt] = static_cast <input_scalar_t >(
1319
- gamma * (i[feature] - mean) * invstd + beta);
1320
- }
1321
- input_scalar_t * write_ptr = &o[feature_vec_begin];
1322
- *(reinterpret_cast <vec_t *>(write_ptr)) = vec;
1308
+ for (int vt = 0 ; vt < VEC_SIZE; ++vt) {
1309
+ index_t feature = feature_vec_begin + vt;
1310
+ vec[vt] = static_cast <input_scalar_t >(
1311
+ gamma * (i[feature] - mean) * invstd + beta);
1323
1312
}
1313
+ input_scalar_t * write_ptr = &o[feature_vec_begin];
1314
+ *(reinterpret_cast <vec_t *>(write_ptr)) = vec;
1324
1315
}
1325
1316
}
1326
1317
}
@@ -1456,13 +1447,27 @@ void batch_norm_elemt_template(
1456
1447
nwg_y = std::min<int >(nwg_y, syclMaxWorkItemsPerTile () / (tf * tb));
1457
1448
sycl::range<2 > global_range (nwg_y * tb, nwg_x * tf);
1458
1449
1459
- auto kfn = BatchNormTransformInputKernelFunctor<
1460
- input_scalar_t ,
1461
- stat_scalar_t ,
1462
- stat_accscalar_t ,
1463
- true ,
1464
- index_t >(input, output, mean, invstd, weight, bias, dummy_epsilon);
1465
- sycl_kernel_submit (global_range, local_range, queue, kfn);
1450
+ auto output_ptr = (char *)output_reshaped.data_ptr ();
1451
+ if (output_reshaped.is_contiguous () &&
1452
+ memory::can_vectorize_up_to<input_scalar_t >(output_ptr) >= 4 &&
1453
+ sizeof (input_scalar_t ) < sizeof (float ) && input.size (2 ) % 4 == 0 ) {
1454
+ auto kfn = BatchNormTransformInputVectorizedKernelFunctor<
1455
+ 4 ,
1456
+ input_scalar_t ,
1457
+ stat_scalar_t ,
1458
+ stat_accscalar_t ,
1459
+ true ,
1460
+ index_t >(input, output, mean, invstd, weight, bias, dummy_epsilon);
1461
+ sycl_kernel_submit (global_range, local_range, queue, kfn);
1462
+ } else {
1463
+ auto kfn = BatchNormTransformInputKernelFunctor<
1464
+ input_scalar_t ,
1465
+ stat_scalar_t ,
1466
+ stat_accscalar_t ,
1467
+ true ,
1468
+ index_t >(input, output, mean, invstd, weight, bias, dummy_epsilon);
1469
+ sycl_kernel_submit (global_range, local_range, queue, kfn);
1470
+ }
1466
1471
}
1467
1472
1468
1473
template <
0 commit comments