subgroup shuffle members are deprecated, use sycl functions instead
guangyey committed Jul 23, 2024
1 parent 5f4970b commit f80a4c8
Showing 7 changed files with 32 additions and 32 deletions.
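The replacement is mechanical throughout: every deprecated sycl::sub_group shuffle member call becomes the matching SYCL 2020 free function, i.e. sg.shuffle_down(x, delta) becomes sycl::shift_group_left(sg, x, delta), and sg.shuffle_xor(x, mask) becomes sycl::permute_group_by_xor(sg, x, mask). For orientation, a minimal self-contained sketch of the new API (the helper below is illustrative, not part of this commit):

#include <sycl/sycl.hpp>

// Sub-group tree reduction using the SYCL 2020 free function that
// supersedes the deprecated sg.shuffle_down() member.
inline float subgroup_sum(sycl::nd_item<1> item, float val) {
  auto sg = item.get_sub_group();
  for (int offset = (int)sg.get_local_range()[0] >> 1; offset > 0; offset >>= 1) {
    // Read val from the work-item `offset` lanes above; after the loop,
    // lane 0 holds the sum over the whole sub-group.
    val += sycl::shift_group_left(sg, val, offset);
  }
  return val;
}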
16 changes: 8 additions & 8 deletions src/ATen/native/xpu/sycl/BatchNormKernels.cpp
@@ -263,7 +263,7 @@ static inline void group_reduce(
// uint32_t SIMD = sg.get_local_range()[0];
#pragma unroll
for (int i = 1; i < SIMD; i <<= 1) {
- val = bin_op(val, static_cast<accscalar_t>(sg.shuffle_down(val, i)));
+ val = bin_op(val, static_cast<accscalar_t>(sycl::shift_group_left(sg, val, i)));
}
if (sub_group_num == 1) {
if (lane_id == 0) {
@@ -294,7 +294,7 @@
}
#pragma unroll
for (int i = 1; i < SIMD; i <<= 1) {
- val = bin_op(val, static_cast<accscalar_t>(sg.shuffle_down(val, i)));
+ val = bin_op(val, static_cast<accscalar_t>(sycl::shift_group_left(sg, val, i)));
if (i >= ((sub_group_num + 1) >> 1))
break;
}
@@ -450,10 +450,10 @@ struct BatchNormCollectStatisticsKernelFunctor
// one value per subgroup
#pragma unroll
for (int i = 1; i < SIMD; i <<= 1) {
- stat_accscalar_t o_avg = sg.shuffle_xor(avg, i);
- int o_n = sg.shuffle_xor(n, i);
+ stat_accscalar_t o_avg = sycl::permute_group_by_xor(sg, avg, i);
+ int o_n = sycl::permute_group_by_xor(sg, n, i);
stat_accscalar_t factor = 1.0 / fmaxf(1.0, n + o_n);
- var_n += sg.shuffle_xor(var_n, i) +
+ var_n += sycl::permute_group_by_xor(sg, var_n, i) +
(avg - o_avg) * (avg - o_avg) * n * o_n * factor;
avg = (n * avg + o_n * o_avg) * factor;
n += o_n;
@@ -481,10 +481,10 @@
}
#pragma unroll
for (int i = 1; i < SIMD; i <<= 1) {
- stat_accscalar_t o_avg = sg.shuffle_xor(avg, i);
- int o_n = sg.shuffle_xor(n, i);
+ stat_accscalar_t o_avg = sycl::permute_group_by_xor(sg, avg, i);
+ int o_n = sycl::permute_group_by_xor(sg, n, i);
stat_accscalar_t factor = 1.0f / fmaxf(1.0f, n + o_n);
- var_n += sg.shuffle_xor(var_n, i) +
+ var_n += sycl::permute_group_by_xor(sg, var_n, i) +
(avg - o_avg) * (avg - o_avg) * n * o_n * factor;
avg = (n * avg + o_n * o_avg) * factor;
n += o_n;
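Two reduction shapes appear in this file: shift_group_left drives a tree reduction whose result lands in lane 0 (and is then broadcast), while permute_group_by_xor drives the butterfly exchange above, which leaves the merged Welford statistics in every lane with no broadcast step. A hedged sketch of the butterfly form in isolation, assuming a lane-private float val and the SIMD/sg of the kernel above:

// Butterfly all-reduce: lane i exchanges with lane i ^ mask each round,
// so after log2(SIMD) rounds every lane holds the full sum.
for (int mask = 1; mask < SIMD; mask <<= 1) {
  val += sycl::permute_group_by_xor(sg, val, mask);
}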
2 changes: 1 addition & 1 deletion src/ATen/native/xpu/sycl/DistanceKernels.cpp
@@ -120,7 +120,7 @@ scalar_t subgroup_reduce_agg_without_broadcast_impl(

#pragma unroll
for (int offset = (SG_SIZE >> 1); offset > 0; offset >>= 1) {
- F::agg(value, sg.shuffle_down(value, offset));
+ F::agg(value, sycl::shift_group_left(sg, value, offset));
}
return value;
}
8 changes: 4 additions & 4 deletions src/ATen/native/xpu/sycl/GroupNormKernels.cpp
@@ -27,10 +27,10 @@ struct WelfordOpsXPU
inline acc_t shfl_down(acc_t acc, int offset) const {
auto sg = item.get_sub_group();
return {
- sg.shuffle_down(acc.mean, offset),
- sg.shuffle_down(acc.m2, offset),
- sg.shuffle_down(acc.n, offset),
- sg.shuffle_down(acc.nf, offset)};
+ sycl::shift_group_left(sg, acc.mean, offset),
+ sycl::shift_group_left(sg, acc.m2, offset),
+ sycl::shift_group_left(sg, acc.n, offset),
+ sycl::shift_group_left(sg, acc.nf, offset)};
}

WelfordOpsXPU(acc_scalar_t correction, bool take_sqrt, sycl::nd_item<1>& item)
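shfl_down shifts each Welford field on its own because the SYCL shift functions take a single trivially copyable value; the compound accumulator cannot be shuffled in one call, and the TODOs in Reduce.h below record bugs even for pair shuffles on half. A hedged call-site sketch, assuming an accumulator acc, an ops instance of the struct above with a combine() member, and sub-group size SG_SIZE:

// Illustrative only (not part of this diff): fold Welford accumulators
// across a sub-group via the field-wise shfl_down defined above.
for (int offset = SG_SIZE >> 1; offset > 0; offset >>= 1) {
  acc = ops.combine(acc, ops.shfl_down(acc, offset));
}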
8 changes: 4 additions & 4 deletions src/ATen/native/xpu/sycl/Norm.h
@@ -39,8 +39,8 @@ static inline void norm_group_reduce(
// uint32_t SIMD = sg.get_local_range()[0];
#pragma unroll
for (int i = 1; i < SIMD; i <<= 1) {
- sum1 = bin_op(sum1, static_cast<accscalar_t>(sg.shuffle_down(sum1, i)));
- sum2 = bin_op(sum2, static_cast<accscalar_t>(sg.shuffle_down(sum2, i)));
+ sum1 = bin_op(sum1, static_cast<accscalar_t>(sycl::shift_group_left(sg, sum1, i)));
+ sum2 = bin_op(sum2, static_cast<accscalar_t>(sycl::shift_group_left(sg, sum2, i)));
}
if (sub_group_num == 1) {
sum1 = sycl::group_broadcast(sg, sum1, 0);
@@ -73,8 +73,8 @@ static inline void norm_group_reduce(
}
#pragma unroll
for (int i = 1; i < SIMD; i <<= 1) {
- sum1 = bin_op(sum1, static_cast<accscalar_t>(sg.shuffle_down(sum1, i)));
- sum2 = bin_op(sum2, static_cast<accscalar_t>(sg.shuffle_down(sum2, i)));
+ sum1 = bin_op(sum1, static_cast<accscalar_t>(sycl::shift_group_left(sg, sum1, i)));
+ sum2 = bin_op(sum2, static_cast<accscalar_t>(sycl::shift_group_left(sg, sum2, i)));
if (i >= ((sub_group_num + 1) >> 1))
break;
}
24 changes: 12 additions & 12 deletions src/ATen/native/xpu/sycl/Reduce.h
@@ -50,7 +50,7 @@ inline at::detail::Array<arg_t, out_vec_sz> group_reduce(
for (int offset = 1; offset < sg_size; offset <<= 1) {
#pragma unroll(out_vec_sz)
for (int i = 0; i < out_vec_sz; ++i) {
- arg_t other = sg.shuffle_down(value[i], offset);
+ arg_t other = sycl::shift_group_left(sg, value[i], offset);
value[i] = combine(value[i], other);
}
}
@@ -71,7 +71,7 @@
for (int offset = 1; offset < sg_range; offset <<= 1) {
#pragma unroll(out_vec_sz)
for (int i = 0; i < out_vec_sz; ++i) {
- arg_t other = sg.shuffle_down(value[i], offset);
+ arg_t other = sycl::shift_group_left(sg, value[i], offset);
value[i] = combine(value[i], other);
}
}
@@ -132,7 +132,7 @@ inline at::detail::Array<arg_t, out_vec_sz> group_x_reduce(
for (int offset = 1; offset < dim_x; offset <<= 1) {
#pragma unroll(out_vec_sz)
for (int i = 0; i < out_vec_sz; ++i) {
- arg_t other = sg.shuffle_down(value[i], offset);
+ arg_t other = sycl::shift_group_left(sg, value[i], offset);
value[i] = combine(value[i], other);
}
}
@@ -541,7 +541,7 @@ struct ReduceOp {
(const scalar_t*)((const char*)src + base_offsets1);
value = item_reduce<output_vec_size>(pos, input_slice);
}
- // TODO: Currently, there are bugs with shuffle_down when the arg_t is a
+ // TODO: Currently, there are bugs with sycl::shift_group_left when the arg_t is a
// pair for half dtype, We temporarily workaround to do
// "reduce_for_compound_dtype" function.
constexpr bool is_pair =
@@ -832,7 +832,7 @@ struct ReduceOp {
return value_list[0];
}

- // TODO: Currently, there are bugs with shuffle_down when the arg_t is a
+ // TODO: Currently, there are bugs with sycl::shift_group_left when the arg_t is a
// pair with half dtype, We temporarily workaround to do
// "reduce_for_compound_dtype" function.
template <int output_vec_size>
@@ -850,7 +850,7 @@
for (int offset = 1; offset < (int)sbgrpSize; offset <<= 1) {
#pragma unroll(output_vec_size)
for (int i = 0; i < output_vec_size; ++i) {
- arg_t other = sg.shuffle_down(value[i], offset);
+ arg_t other = sycl::shift_group_left(sg, value[i], offset);
value[i] = ops.combine(value[i], other);
}
}
@@ -879,8 +879,8 @@
other = std::pair<
typename arg_t::first_type,
typename arg_t::second_type>(
- sg.shuffle_down(value[i].first, offset),
- sg.shuffle_down(value[i].second, offset));
+ sycl::shift_group_left(sg, value[i].first, offset),
+ sycl::shift_group_left(sg, value[i].second, offset));
value[i] = ops.combine(value[i], other);
}
}
@@ -907,7 +907,7 @@
return value;
}

- // TODO: Currently, there are bugs with shuffle_down when the arg_t is a
+ // TODO: Currently, there are bugs with sycl::shift_group_left when the arg_t is a
// pair for half dtype, We temporarily workaround to do
// "reduce_for_compound_dtype" function.
template <int output_vec_size>
@@ -950,8 +950,8 @@
std::pair<typename arg_t::first_type, typename arg_t::second_type>
other = std::
pair<typename arg_t::first_type, typename arg_t::second_type>(
- sg.shuffle_down(value[i].first, offset),
- sg.shuffle_down(value[i].second, offset));
+ sycl::shift_group_left(sg, value[i].first, offset),
+ sycl::shift_group_left(sg, value[i].second, offset));
value[i] = ops.combine(value[i], other);
}
}
@@ -1121,7 +1121,7 @@
decltype(combine),
output_vec_size>(pos, shared_memory, value, combine);
if (config.should_group_x_reduce()) {
- // TODO: workaround because sg.shuffle_down will fail on `half` dtype.
+ // TODO: workaround because sycl::shift_group_left will fail on `half` dtype.
constexpr bool is_pair =
std::is_same<std::pair<scalar_t, int64_t>, arg_t>::value;
if constexpr (is_pair) {
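The workaround above shuffles the two members of the pair separately; the essential pattern, extracted with illustrative types (a half value plus an int64_t index, as in an argmin/argmax-style reduction):

// Shift a (value, index) pair member-by-member, avoiding the compound
// shuffle on half dtype that the TODOs above flag as buggy.
using pair_t = std::pair<sycl::half, int64_t>;
pair_t other{
    sycl::shift_group_left(sg, value.first, offset),
    sycl::shift_group_left(sg, value.second, offset)};
value = ops.combine(value, other);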
2 changes: 1 addition & 1 deletion src/ATen/native/xpu/sycl/SYCLGroupAlgorithm.h
@@ -6,7 +6,7 @@ template <typename T, typename sg_t>
inline T GroupReduceSumSGSizeEqualstoNumSG(sg_t& sg, T val) {
auto sg_size = sg.get_local_range()[0];
for (int offset = (sg_size >> 1); offset > 0; offset >>= 1) {
- val += sg.shuffle_down(val, offset);
+ val += sycl::shift_group_left(sg, val, offset);
}
return val;
}
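A hedged usage sketch for this helper: a two-stage work-group sum in which each sub-group reduces its lanes, publishes its partial to local memory, and the first sub-group reduces the partials. All names (item, partial, and scratch, e.g. a sycl::local_accessor<float>) are illustrative; the helper requires that the sub-group size equal the number of sub-groups, as its name states:

auto sg = item.get_sub_group();
// Stage 1: lane 0 of each sub-group obtains that sub-group's sum.
float v = GroupReduceSumSGSizeEqualstoNumSG(sg, partial);
if (sg.get_local_id()[0] == 0)
  scratch[sg.get_group_id()[0]] = v;
sycl::group_barrier(item.get_group());
// Stage 2: the first sub-group reduces one partial per sub-group.
if (sg.get_group_id()[0] == 0)
  v = GroupReduceSumSGSizeEqualstoNumSG(sg, scratch[sg.get_local_id()[0]]);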
4 changes: 2 additions & 2 deletions src/ATen/native/xpu/sycl/SoftMaxKernels.cpp
@@ -40,7 +40,7 @@ static inline void softmax_group_reduce(
// uint32_t SIMD = sg.get_local_range()[0];
#pragma unroll
for (int i = 1; i < SIMD; i <<= 1) {
- val = bin_op(val, static_cast<accscalar_t>(sg.shuffle_down(val, i)));
+ val = bin_op(val, static_cast<accscalar_t>(sycl::shift_group_left(sg, val, i)));
}
if (sub_group_num == 1) {
val = sycl::group_broadcast(sg, val, 0);
@@ -68,7 +68,7 @@ static inline void softmax_group_reduce(
}
#pragma unroll
for (int i = 1; i < SIMD; i <<= 1) {
- val = bin_op(val, static_cast<accscalar_t>(sg.shuffle_down(val, i)));
+ val = bin_op(val, static_cast<accscalar_t>(sycl::shift_group_left(sg, val, i)));
if (i >= ((sub_group_num + 1) >> 1))
break;
}
