From 889f3f3cf4f384ba3f56e910abd89077306ed290 Mon Sep 17 00:00:00 2001 From: "Yu, Guangye" Date: Tue, 6 Aug 2024 01:12:49 +0000 Subject: [PATCH] Remove reduce workaround since using xpu::pair directly --- src/ATen/native/xpu/sycl/Reduce.h | 181 +++--------------------------- 1 file changed, 15 insertions(+), 166 deletions(-) diff --git a/src/ATen/native/xpu/sycl/Reduce.h b/src/ATen/native/xpu/sycl/Reduce.h index 276b21175..a2316dc87 100644 --- a/src/ATen/native/xpu/sycl/Reduce.h +++ b/src/ATen/native/xpu/sycl/Reduce.h @@ -542,27 +542,17 @@ struct ReduceOp { (const scalar_t*)((const char*)src + base_offsets1); value = item_reduce(pos, input_slice); } - // TODO: Currently, there are bugs with sycl::shift_group_left when the - // arg_t is a pair for half dtype, We temporarily workaround to do - // "reduce_for_compound_dtype" function. - constexpr bool is_pair = - std::is_same, arg_t>::value; auto combine = [=](arg1_t value, arg2_t other) -> arg1_t { return ops.combine(value, other); }; if (config.should_group_x_reduce() && config.should_group_y_reduce()) { - if constexpr (is_pair) { - value = group_reduce_for_compound_dtype( - pos, value, shared); - } else { - value = group_reduce< - arg_t, - decltype(pos), - decltype(combine), - output_vec_size>(pos, config.num_items, shared, value, combine); - } + value = group_reduce< + arg_t, + decltype(pos), + decltype(combine), + output_vec_size>(pos, config.num_items, shared, value, combine); } else { if (config.should_group_y_reduce()) { value = group_y_reduce< @@ -572,16 +562,11 @@ struct ReduceOp { output_vec_size>(pos, shared, value, combine); } if (config.should_group_x_reduce()) { - if constexpr (is_pair) { - value = group_x_reduce_for_compound_dtype( - pos, value, shared); - } else { - value = group_x_reduce< - arg_t, - decltype(pos), - decltype(combine), - output_vec_size>(pos, shared, value, combine); - } + value = group_x_reduce< + arg_t, + decltype(pos), + decltype(combine), + output_vec_size>(pos, shared, value, 
combine); } } @@ -833,133 +818,6 @@ struct ReduceOp { return value_list[0]; } - // TODO: Currently, there are bugs with sycl::shift_group_left when the arg_t - // is a pair with half dtype, We temporarily workaround to do - // "reduce_for_compound_dtype" function. - template - at::detail::Array group_reduce_for_compound_dtype( - sycl::nd_item<2> pos, - at::detail::Array value, - sycl_local_ptr shared_memory) const { - auto sg = pos.get_sub_group(); - uint32_t sbgrpSize = sg.get_local_range()[0]; - int l_x = pos.get_local_linear_id(); - int sg_lid = sg.get_local_linear_id(); - int sg_gid = sg.get_group_linear_id(); - int sg_range = sg.get_group_range()[0]; - - for (int offset = 1; offset < (int)sbgrpSize; offset <<= 1) { -#pragma unroll(output_vec_size) - for (int i = 0; i < output_vec_size; ++i) { - arg_t other = sycl::shift_group_left(sg, value[i], offset); - value[i] = ops.combine(value[i], other); - } - } - - using args_vec_t = at::detail::Array; - sycl_local_ptr shared{shared_memory}; - - if (sg_lid == 0) { - shared[sg_gid] = value; - } - pos.barrier(sycl_local_fence); - - if (sg_range <= (int)sbgrpSize) { - // sub-group reduce -#pragma unroll(output_vec_size) - for (int i = 0; i < output_vec_size; i++) { - value[i] = ident; - } - if (sg_gid == 0 && sg_lid < sg_range) { - value = shared[sg_lid]; - for (int offset = 1; offset < sg_range; offset <<= 1) { -#pragma unroll(output_vec_size) - for (int i = 0; i < output_vec_size; ++i) { - // Shuffle down separately for first and second pair. 
- at::xpu:: - pair - other = at::xpu::pair< - typename arg_t::first_type, - typename arg_t::second_type>( - sycl::shift_group_left(sg, value[i].first, offset), - sycl::shift_group_left(sg, value[i].second, offset)); - value[i] = ops.combine(value[i], other); - } - } - } - } else { - // work item tree reduce - if (l_x < sg_range) { - value = shared[l_x]; - } - - for (int offset = sg_range / 2; offset > 0; offset >>= 1) { - if (l_x < offset) { - args_vec_t other = shared[l_x + offset]; -#pragma unroll(output_vec_size) - for (int i = 0; i < output_vec_size; ++i) { - value[i] = ops.combine(value[i], other[i]); - } - shared[l_x] = value; - } - pos.barrier(sycl_local_fence); - } - } - - return value; - } - - // TODO: Currently, there are bugs with sycl::shift_group_left when the arg_t - // is a pair for half dtype, We temporarily workaround to do - // "reduce_for_compound_dtype" function. - template - at::detail::Array group_x_reduce_for_compound_dtype( - sycl::nd_item<2> pos, - at::detail::Array value, - sycl_local_ptr shared_memory) const { - using args_vec_t = at::detail::Array; - auto l_x = pos.get_local_id(1), l_y = pos.get_local_id(0); - auto gp_x = pos.get_local_range(1); - - int dim_x = gp_x; - sycl_local_ptr shared(shared_memory); - auto sg = pos.get_sub_group(); - uint32_t sbgrpSize = sg.get_local_range()[0]; - if (dim_x > (int)sbgrpSize) { - int address_base = l_x + l_y * gp_x; - shared[address_base] = value; - for (int offset = dim_x / 2; offset >= (int)sbgrpSize; offset >>= 1) { - pos.barrier(sycl_local_fence); - if ((int)l_x < offset && - (int)l_x + offset < (int)gp_x /* redundant??? 
*/) { - args_vec_t other = shared[address_base + offset]; -#pragma unroll(output_vec_size) - for (int i = 0; i < output_vec_size; ++i) { - value[i] = ops.combine(value[i], other[i]); - } - shared[address_base] = value; - } - } - dim_x = sbgrpSize; - } - - pos.barrier(sycl_local_fence); - - // sub-group reduction - for (int offset = 1; offset < dim_x; offset <<= 1) { -#pragma unroll(output_vec_size) - for (int i = 0; i < output_vec_size; ++i) { - at::xpu::pair - other = xpu:: - pair( - sycl::shift_group_left(sg, value[i].first, offset), - sycl::shift_group_left(sg, value[i].second, offset)); - value[i] = ops.combine(value[i], other); - } - } - return value; - } - // In/out from slm pointers void mark_group_finished(sycl::nd_item<2> pos, sycl_local_ptr finished) const { @@ -1124,20 +982,11 @@ struct ReduceOp { decltype(combine), output_vec_size>(pos, shared_memory, value, combine); if (config.should_group_x_reduce()) { - // TODO: workaround because sycl::shift_group_left will fail on `half` - // dtype. - constexpr bool is_pair = - std::is_same, arg_t>::value; - if constexpr (is_pair) { - value = group_x_reduce_for_compound_dtype( - pos, value, shared_memory); - } else { - value = group_x_reduce< - arg_t, - decltype(pos), - decltype(combine), - output_vec_size>(pos, shared_memory, value, combine); - } + value = group_x_reduce< + arg_t, + decltype(pos), + decltype(combine), + output_vec_size>(pos, shared_memory, value, combine); } if (should_store) { if (accumulate) {