Skip to content

Commit

Permalink
Remove reduce workaround since using xpu::pair directly
Browse files Browse the repository at this point in the history
  • Loading branch information
guangyey committed Aug 6, 2024
1 parent 2f75c47 commit 889f3f3
Showing 1 changed file with 15 additions and 166 deletions.
181 changes: 15 additions & 166 deletions src/ATen/native/xpu/sycl/Reduce.h
Original file line number Diff line number Diff line change
Expand Up @@ -542,27 +542,17 @@ struct ReduceOp {
(const scalar_t*)((const char*)src + base_offsets1);
value = item_reduce<output_vec_size>(pos, input_slice);
}
// TODO: Currently, there are bugs with sycl::shift_group_left when the
// arg_t is a pair for half dtype, We temporarily workaround to do
// "reduce_for_compound_dtype" function.
constexpr bool is_pair =
std::is_same<at::xpu::pair<scalar_t, int64_t>, arg_t>::value;

auto combine = [=](arg1_t value, arg2_t other) -> arg1_t {
return ops.combine(value, other);
};

if (config.should_group_x_reduce() && config.should_group_y_reduce()) {
if constexpr (is_pair) {
value = group_reduce_for_compound_dtype<output_vec_size>(
pos, value, shared);
} else {
value = group_reduce<
arg_t,
decltype(pos),
decltype(combine),
output_vec_size>(pos, config.num_items, shared, value, combine);
}
value = group_reduce<
arg_t,
decltype(pos),
decltype(combine),
output_vec_size>(pos, config.num_items, shared, value, combine);
} else {
if (config.should_group_y_reduce()) {
value = group_y_reduce<
Expand All @@ -572,16 +562,11 @@ struct ReduceOp {
output_vec_size>(pos, shared, value, combine);
}
if (config.should_group_x_reduce()) {
if constexpr (is_pair) {
value = group_x_reduce_for_compound_dtype<output_vec_size>(
pos, value, shared);
} else {
value = group_x_reduce<
arg_t,
decltype(pos),
decltype(combine),
output_vec_size>(pos, shared, value, combine);
}
value = group_x_reduce<
arg_t,
decltype(pos),
decltype(combine),
output_vec_size>(pos, shared, value, combine);
}
}

Expand Down Expand Up @@ -833,133 +818,6 @@ struct ReduceOp {
return value_list[0];
}

// TODO: Currently, there are bugs with sycl::shift_group_left when the arg_t
// is a pair with half dtype, so the pair members are shuffled separately as a
// temporary workaround ("reduce_for_compound_dtype").
//
// Group-wide reduction of `value` (this work-item's partial accumulators,
// one per output vector lane) across the whole work-group. `shared_memory`
// provides one args_vec_t slot per sub-group as scratch space.
template <int output_vec_size>
at::detail::Array<arg_t, output_vec_size> group_reduce_for_compound_dtype(
    sycl::nd_item<2> pos,
    at::detail::Array<arg_t, output_vec_size> value,
    sycl_local_ptr<void> shared_memory) const {
  auto sg = pos.get_sub_group();
  uint32_t sbgrpSize = sg.get_local_range()[0];
  int l_x = pos.get_local_linear_id();
  int sg_lid = sg.get_local_linear_id();
  int sg_gid = sg.get_group_linear_id();
  int sg_range = sg.get_group_range()[0];

  // Stage 1: butterfly reduction within each sub-group. Shuffle the pair
  // members separately: shuffling the whole pair would call
  // sycl::shift_group_left on a compound type, which is exactly the buggy
  // case this workaround function exists to avoid.
  for (int offset = 1; offset < (int)sbgrpSize; offset <<= 1) {
#pragma unroll(output_vec_size)
    for (int i = 0; i < output_vec_size; ++i) {
      at::xpu::pair<typename arg_t::first_type, typename arg_t::second_type>
          other = at::xpu::pair<
              typename arg_t::first_type,
              typename arg_t::second_type>(
              sycl::shift_group_left(sg, value[i].first, offset),
              sycl::shift_group_left(sg, value[i].second, offset));
      value[i] = ops.combine(value[i], other);
    }
  }

  // Each sub-group leader publishes its partial result to local memory.
  using args_vec_t = at::detail::Array<arg_t, output_vec_size>;
  sycl_local_ptr<args_vec_t> shared{shared_memory};

  if (sg_lid == 0) {
    shared[sg_gid] = value;
  }
  pos.barrier(sycl_local_fence);

  if (sg_range <= (int)sbgrpSize) {
    // Stage 2a: few enough partials that the first sub-group alone can
    // combine them with another round of member-wise shuffles.
#pragma unroll(output_vec_size)
    for (int i = 0; i < output_vec_size; i++) {
      value[i] = ident; // neutral element for lanes with no partial to load
    }
    if (sg_gid == 0 && sg_lid < sg_range) {
      value = shared[sg_lid];
      for (int offset = 1; offset < sg_range; offset <<= 1) {
#pragma unroll(output_vec_size)
        for (int i = 0; i < output_vec_size; ++i) {
          // Shuffle down separately for first and second pair member.
          at::xpu::
              pair<typename arg_t::first_type, typename arg_t::second_type>
              other = at::xpu::pair<
                  typename arg_t::first_type,
                  typename arg_t::second_type>(
                  sycl::shift_group_left(sg, value[i].first, offset),
                  sycl::shift_group_left(sg, value[i].second, offset));
          value[i] = ops.combine(value[i], other);
        }
      }
    }
  } else {
    // Stage 2b: too many partials for one sub-group — binary tree reduce in
    // local memory, halving the active range each iteration.
    if (l_x < sg_range) {
      value = shared[l_x];
    }

    for (int offset = sg_range / 2; offset > 0; offset >>= 1) {
      if (l_x < offset) {
        args_vec_t other = shared[l_x + offset];
#pragma unroll(output_vec_size)
        for (int i = 0; i < output_vec_size; ++i) {
          value[i] = ops.combine(value[i], other[i]);
        }
        shared[l_x] = value;
      }
      pos.barrier(sycl_local_fence);
    }
  }

  return value;
}

// TODO: Currently, there are bugs with sycl::shift_group_left when the arg_t
// is a pair for half dtype, We temporarily workaround to do
// "reduce_for_compound_dtype" function.
//
// Reduce `value` across the x dimension of the work-group only (rows reduce
// independently). `shared_memory` is scratch local memory holding one
// args_vec_t per work-item along x. The pair members are shuffled separately
// to avoid calling shift_group_left on the compound type.
template <int output_vec_size>
at::detail::Array<arg_t, output_vec_size> group_x_reduce_for_compound_dtype(
sycl::nd_item<2> pos,
at::detail::Array<arg_t, output_vec_size> value,
sycl_local_ptr<void> shared_memory) const {
using args_vec_t = at::detail::Array<arg_t, output_vec_size>;
auto l_x = pos.get_local_id(1), l_y = pos.get_local_id(0);
auto gp_x = pos.get_local_range(1);

int dim_x = gp_x;
sycl_local_ptr<args_vec_t> shared(shared_memory);
auto sg = pos.get_sub_group();
uint32_t sbgrpSize = sg.get_local_range()[0];
// Phase 1: when the row is wider than one sub-group, tree-reduce through
// local memory until the remaining width fits in a single sub-group.
if (dim_x > (int)sbgrpSize) {
int address_base = l_x + l_y * gp_x;
shared[address_base] = value;
for (int offset = dim_x / 2; offset >= (int)sbgrpSize; offset >>= 1) {
// Barrier first so every item's previous store is visible before reads.
pos.barrier(sycl_local_fence);
if ((int)l_x < offset &&
(int)l_x + offset < (int)gp_x /* redundant??? */) {
args_vec_t other = shared[address_base + offset];
#pragma unroll(output_vec_size)
for (int i = 0; i < output_vec_size; ++i) {
value[i] = ops.combine(value[i], other[i]);
}
shared[address_base] = value;
}
}
// Remaining active width is exactly one sub-group.
dim_x = sbgrpSize;
}

pos.barrier(sycl_local_fence);

// Phase 2: finish within the sub-group, shuffling first/second members
// separately (workaround for the compound-type shift_group_left bug).
// sub-group reduction
for (int offset = 1; offset < dim_x; offset <<= 1) {
#pragma unroll(output_vec_size)
for (int i = 0; i < output_vec_size; ++i) {
at::xpu::pair<typename arg_t::first_type, typename arg_t::second_type>
other = xpu::
pair<typename arg_t::first_type, typename arg_t::second_type>(
sycl::shift_group_left(sg, value[i].first, offset),
sycl::shift_group_left(sg, value[i].second, offset));
value[i] = ops.combine(value[i], other);
}
}
return value;
}

// In/out from slm pointers
void mark_group_finished(sycl::nd_item<2> pos, sycl_local_ptr<bool> finished)
const {
Expand Down Expand Up @@ -1124,20 +982,11 @@ struct ReduceOp {
decltype(combine),
output_vec_size>(pos, shared_memory, value, combine);
if (config.should_group_x_reduce()) {
// TODO: workaround because sycl::shift_group_left will fail on `half`
// dtype.
constexpr bool is_pair =
std::is_same<at::xpu::pair<scalar_t, int64_t>, arg_t>::value;
if constexpr (is_pair) {
value = group_x_reduce_for_compound_dtype<output_vec_size>(
pos, value, shared_memory);
} else {
value = group_x_reduce<
arg_t,
decltype(pos),
decltype(combine),
output_vec_size>(pos, shared_memory, value, combine);
}
value = group_x_reduce<
arg_t,
decltype(pos),
decltype(combine),
output_vec_size>(pos, shared_memory, value, combine);
}
if (should_store) {
if (accumulate) {
Expand Down

0 comments on commit 889f3f3

Please sign in to comment.