From f80a4c81e62bd63efac269c1f229fe7b9552ad67 Mon Sep 17 00:00:00 2001
From: "Yu, Guangye"
Date: Tue, 23 Jul 2024 11:17:59 +0000
Subject: [PATCH] subgroup shuffle members are deprecated, use SYCL free
 functions instead

---
 src/ATen/native/xpu/sycl/BatchNormKernels.cpp | 16 ++++++-------
 src/ATen/native/xpu/sycl/DistanceKernels.cpp  |  2 +-
 src/ATen/native/xpu/sycl/GroupNormKernels.cpp |  8 +++----
 src/ATen/native/xpu/sycl/Norm.h               |  8 +++----
 src/ATen/native/xpu/sycl/Reduce.h             | 24 +++++++++----------
 src/ATen/native/xpu/sycl/SYCLGroupAlgorithm.h |  2 +-
 src/ATen/native/xpu/sycl/SoftMaxKernels.cpp   |  4 ++--
 7 files changed, 32 insertions(+), 32 deletions(-)

diff --git a/src/ATen/native/xpu/sycl/BatchNormKernels.cpp b/src/ATen/native/xpu/sycl/BatchNormKernels.cpp
index 10f7f0eec..24dfa485d 100644
--- a/src/ATen/native/xpu/sycl/BatchNormKernels.cpp
+++ b/src/ATen/native/xpu/sycl/BatchNormKernels.cpp
@@ -263,7 +263,7 @@ static inline void group_reduce(
   // uint32_t SIMD = sg.get_local_range()[0];
 #pragma unroll
   for (int i = 1; i < SIMD; i <<= 1) {
-    val = bin_op(val, static_cast(sg.shuffle_down(val, i)));
+    val = bin_op(val, static_cast(sycl::shift_group_left(sg, val, i)));
   }
   if (sub_group_num == 1) {
     if (lane_id == 0) {
@@ -294,7 +294,7 @@
     }
 #pragma unroll
     for (int i = 1; i < SIMD; i <<= 1) {
-      val = bin_op(val, static_cast(sg.shuffle_down(val, i)));
+      val = bin_op(val, static_cast(sycl::shift_group_left(sg, val, i)));
       if (i >= ((sub_group_num + 1) >> 1))
         break;
     }
@@ -450,10 +450,10 @@ struct BatchNormCollectStatisticsKernelFunctor
       // one value per subgroup
 #pragma unroll
       for (int i = 1; i < SIMD; i <<= 1) {
-        stat_accscalar_t o_avg = sg.shuffle_xor(avg, i);
-        int o_n = sg.shuffle_xor(n, i);
+        stat_accscalar_t o_avg = sycl::permute_group_by_xor(sg, avg, i);
+        int o_n = sycl::permute_group_by_xor(sg, n, i);
         stat_accscalar_t factor = 1.0 / fmaxf(1.0, n + o_n);
-        var_n += sg.shuffle_xor(var_n, i) +
+        var_n += sycl::permute_group_by_xor(sg, var_n, i) +
             (avg - o_avg) * (avg - o_avg) * n * o_n * factor;
         avg = (n * avg + o_n * o_avg) * factor;
         n += o_n;
@@ -481,10 +481,10 @@
       }
 #pragma unroll
       for (int i = 1; i < SIMD; i <<= 1) {
-        stat_accscalar_t o_avg = sg.shuffle_xor(avg, i);
-        int o_n = sg.shuffle_xor(n, i);
+        stat_accscalar_t o_avg = sycl::permute_group_by_xor(sg, avg, i);
+        int o_n = sycl::permute_group_by_xor(sg, n, i);
         stat_accscalar_t factor = 1.0f / fmaxf(1.0f, n + o_n);
-        var_n += sg.shuffle_xor(var_n, i) +
+        var_n += sycl::permute_group_by_xor(sg, var_n, i) +
             (avg - o_avg) * (avg - o_avg) * n * o_n * factor;
         avg = (n * avg + o_n * o_avg) * factor;
         n += o_n;
diff --git a/src/ATen/native/xpu/sycl/DistanceKernels.cpp b/src/ATen/native/xpu/sycl/DistanceKernels.cpp
index 8bd61bdd3..eb0f1f50e 100644
--- a/src/ATen/native/xpu/sycl/DistanceKernels.cpp
+++ b/src/ATen/native/xpu/sycl/DistanceKernels.cpp
@@ -120,7 +120,7 @@ scalar_t subgroup_reduce_agg_without_broadcast_impl(
 
 #pragma unroll
   for (int offset = (SG_SIZE >> 1); offset > 0; offset >>= 1) {
-    F::agg(value, sg.shuffle_down(value, offset));
+    F::agg(value, sycl::shift_group_left(sg, value, offset));
   }
   return value;
 }
diff --git a/src/ATen/native/xpu/sycl/GroupNormKernels.cpp b/src/ATen/native/xpu/sycl/GroupNormKernels.cpp
index 622e99ffe..6a662bb56 100644
--- a/src/ATen/native/xpu/sycl/GroupNormKernels.cpp
+++ b/src/ATen/native/xpu/sycl/GroupNormKernels.cpp
@@ -27,10 +27,10 @@ struct WelfordOpsXPU
   inline acc_t shfl_down(acc_t acc, int offset) const {
     auto sg = item.get_sub_group();
     return {
-        sg.shuffle_down(acc.mean, offset),
-        sg.shuffle_down(acc.m2, offset),
-        sg.shuffle_down(acc.n, offset),
-        sg.shuffle_down(acc.nf, offset)};
+        sycl::shift_group_left(sg, acc.mean, offset),
+        sycl::shift_group_left(sg, acc.m2, offset),
+        sycl::shift_group_left(sg, acc.n, offset),
+        sycl::shift_group_left(sg, acc.nf, offset)};
   }
 
   WelfordOpsXPU(acc_scalar_t correction, bool take_sqrt, sycl::nd_item<1>& item)
diff --git a/src/ATen/native/xpu/sycl/Norm.h b/src/ATen/native/xpu/sycl/Norm.h
index 9aee941cb..4e6378121 100644
--- a/src/ATen/native/xpu/sycl/Norm.h
+++ b/src/ATen/native/xpu/sycl/Norm.h
@@ -39,8 +39,8 @@ static inline void norm_group_reduce(
   // uint32_t SIMD = sg.get_local_range()[0];
 #pragma unroll
   for (int i = 1; i < SIMD; i <<= 1) {
-    sum1 = bin_op(sum1, static_cast(sg.shuffle_down(sum1, i)));
-    sum2 = bin_op(sum2, static_cast(sg.shuffle_down(sum2, i)));
+    sum1 = bin_op(sum1, static_cast(sycl::shift_group_left(sg, sum1, i)));
+    sum2 = bin_op(sum2, static_cast(sycl::shift_group_left(sg, sum2, i)));
   }
   if (sub_group_num == 1) {
     sum1 = sycl::group_broadcast(sg, sum1, 0);
@@ -73,8 +73,8 @@
     }
 #pragma unroll
     for (int i = 1; i < SIMD; i <<= 1) {
-      sum1 = bin_op(sum1, static_cast(sg.shuffle_down(sum1, i)));
-      sum2 = bin_op(sum2, static_cast(sg.shuffle_down(sum2, i)));
+      sum1 = bin_op(sum1, static_cast(sycl::shift_group_left(sg, sum1, i)));
+      sum2 = bin_op(sum2, static_cast(sycl::shift_group_left(sg, sum2, i)));
       if (i >= ((sub_group_num + 1) >> 1))
         break;
     }
diff --git a/src/ATen/native/xpu/sycl/Reduce.h b/src/ATen/native/xpu/sycl/Reduce.h
index 1be3a5e93..eba054157 100644
--- a/src/ATen/native/xpu/sycl/Reduce.h
+++ b/src/ATen/native/xpu/sycl/Reduce.h
@@ -50,7 +50,7 @@ inline at::detail::Array group_reduce(
   for (int offset = 1; offset < sg_size; offset <<= 1) {
 #pragma unroll(out_vec_sz)
     for (int i = 0; i < out_vec_sz; ++i) {
-      arg_t other = sg.shuffle_down(value[i], offset);
+      arg_t other = sycl::shift_group_left(sg, value[i], offset);
       value[i] = combine(value[i], other);
     }
   }
@@ -71,7 +71,7 @@
   for (int offset = 1; offset < sg_range; offset <<= 1) {
 #pragma unroll(out_vec_sz)
     for (int i = 0; i < out_vec_sz; ++i) {
-      arg_t other = sg.shuffle_down(value[i], offset);
+      arg_t other = sycl::shift_group_left(sg, value[i], offset);
       value[i] = combine(value[i], other);
     }
   }
@@ -132,7 +132,7 @@ inline at::detail::Array group_x_reduce(
   for (int offset = 1; offset < dim_x; offset <<= 1) {
 #pragma unroll(out_vec_sz)
     for (int i = 0; i < out_vec_sz; ++i) {
-      arg_t other = sg.shuffle_down(value[i], offset);
+      arg_t other = sycl::shift_group_left(sg, value[i], offset);
       value[i] = combine(value[i], other);
     }
   }
@@ -541,7 +541,7 @@ struct ReduceOp {
           (const scalar_t*)((const char*)src + base_offsets1);
       value = item_reduce(pos, input_slice);
     }
-    // TODO: Currently, there are bugs with shuffle_down when the arg_t is a
+    // TODO: Currently, there are bugs with sycl::shift_group_left when the arg_t is a
     // pair for half dtype, We temporarily workaround to do
    // "reduce_for_compound_dtype" function.
     constexpr bool is_pair =
@@ -832,7 +832,7 @@ struct ReduceOp {
     return value_list[0];
   }
 
-  // TODO: Currently, there are bugs with shuffle_down when the arg_t is a
+  // TODO: Currently, there are bugs with sycl::shift_group_left when the arg_t is a
   // pair with half dtype, We temporarily workaround to do
   // "reduce_for_compound_dtype" function.
   template
@@ -850,7 +850,7 @@ struct ReduceOp {
     for (int offset = 1; offset < (int)sbgrpSize; offset <<= 1) {
 #pragma unroll(output_vec_size)
       for (int i = 0; i < output_vec_size; ++i) {
-        arg_t other = sg.shuffle_down(value[i], offset);
+        arg_t other = sycl::shift_group_left(sg, value[i], offset);
         value[i] = ops.combine(value[i], other);
       }
     }
@@ -879,8 +879,8 @@ struct ReduceOp {
           other = std::pair<
               typename arg_t::first_type,
              typename arg_t::second_type>(
-              sg.shuffle_down(value[i].first, offset),
-              sg.shuffle_down(value[i].second, offset));
+              sycl::shift_group_left(sg, value[i].first, offset),
+              sycl::shift_group_left(sg, value[i].second, offset));
           value[i] = ops.combine(value[i], other);
         }
       }
@@ -907,7 +907,7 @@ struct ReduceOp {
     return value;
   }
 
-  // TODO: Currently, there are bugs with shuffle_down when the arg_t is a
+  // TODO: Currently, there are bugs with sycl::shift_group_left when the arg_t is a
   // pair for half dtype, We temporarily workaround to do
   // "reduce_for_compound_dtype" function.
   template
@@ -950,8 +950,8 @@ struct ReduceOp {
            std::pair other = std::
                pair(
-                    sg.shuffle_down(value[i].first, offset),
-                    sg.shuffle_down(value[i].second, offset));
+                    sycl::shift_group_left(sg, value[i].first, offset),
+                    sycl::shift_group_left(sg, value[i].second, offset));
            value[i] = ops.combine(value[i], other);
          }
        }
@@ -1121,7 +1121,7 @@ struct ReduceOp {
         decltype(combine),
         output_vec_size>(pos, shared_memory, value, combine);
     if (config.should_group_x_reduce()) {
-      // TODO: workaround because sg.shuffle_down will fail on `half` dtype.
+      // TODO: workaround because sycl::shift_group_left will fail on `half` dtype.
       constexpr bool is_pair =
           std::is_same, arg_t>::value;
       if constexpr (is_pair) {
diff --git a/src/ATen/native/xpu/sycl/SYCLGroupAlgorithm.h b/src/ATen/native/xpu/sycl/SYCLGroupAlgorithm.h
index 35f1d54a5..8729f31dd 100644
--- a/src/ATen/native/xpu/sycl/SYCLGroupAlgorithm.h
+++ b/src/ATen/native/xpu/sycl/SYCLGroupAlgorithm.h
@@ -6,7 +6,7 @@ template
 inline T GroupReduceSumSGSizeEqualstoNumSG(sg_t& sg, T val) {
   auto sg_size = sg.get_local_range()[0];
   for (int offset = (sg_size >> 1); offset > 0; offset >>= 1) {
-    val += sg.shuffle_down(val, offset);
+    val += sycl::shift_group_left(sg, val, offset);
   }
   return val;
 }
diff --git a/src/ATen/native/xpu/sycl/SoftMaxKernels.cpp b/src/ATen/native/xpu/sycl/SoftMaxKernels.cpp
index 57ebb9846..669e952bb 100644
--- a/src/ATen/native/xpu/sycl/SoftMaxKernels.cpp
+++ b/src/ATen/native/xpu/sycl/SoftMaxKernels.cpp
@@ -40,7 +40,7 @@ static inline void softmax_group_reduce(
   // uint32_t SIMD = sg.get_local_range()[0];
 #pragma unroll
   for (int i = 1; i < SIMD; i <<= 1) {
-    val = bin_op(val, static_cast(sg.shuffle_down(val, i)));
+    val = bin_op(val, static_cast(sycl::shift_group_left(sg, val, i)));
   }
   if (sub_group_num == 1) {
     val = sycl::group_broadcast(sg, val, 0);
@@ -68,7 +68,7 @@
     }
 #pragma unroll
     for (int i = 1; i < SIMD; i <<= 1) {
-      val = bin_op(val, static_cast(sg.shuffle_down(val, i)));
+      val = bin_op(val, static_cast(sycl::shift_group_left(sg, val, i)));
       if (i >= ((sub_group_num + 1) >> 1))
         break;
     }
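
For reference, the mapping applied throughout this patch is
sg.shuffle_down(x, delta) -> sycl::shift_group_left(sg, x, delta) and
sg.shuffle_xor(x, mask) -> sycl::permute_group_by_xor(sg, x, mask): the
deprecated sub_group member shuffles become the equivalent SYCL 2020
group-algorithm free functions, with the sub-group passed as the first
argument. The snippet below is only an illustrative sketch of that pattern,
not code from the patched files; the helper name subgroup_sum and the plain
float sum are made up for the example.

    #include <sycl/sycl.hpp>

    // Tree reduction over a sub-group (assumes a power-of-two sub-group size):
    // after the loop, the work-item with sub-group local id 0 holds the sum of
    // all lanes' contributions.
    inline float subgroup_sum(sycl::sub_group sg, float val) {
      for (int offset = sg.get_local_range()[0] >> 1; offset > 0; offset >>= 1) {
        // Deprecated member form: val += sg.shuffle_down(val, offset);
        val += sycl::shift_group_left(sg, val, offset);  // SYCL 2020 free function
      }
      return val;
    }

As in the kernels above, only lane 0 is guaranteed to hold the full result
(shift_group_left returns unspecified values for lanes that read past the end
of the sub-group), so callers broadcast it afterwards, e.g. with
sycl::group_broadcast(sg, val, 0).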