mtl: reduction: Bypass group reduction on SLM by sparing workload to other SGs when SIMD=8 (#692)

Fixes: #611 (test_min_xpu_bool)
Tracking: #698
Another change: align the memory access utility `c10::load` with PyTorch
usage; `c10::load` makes bool data value-compatible (a stored byte other
than 0 or 1 is still read as a valid `true`).

---------

Co-authored-by: Feng Yuan <[email protected]>
xytintel and fengyuan14 authored Aug 7, 2024
1 parent 718bc42 commit fb365ac
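
As a rough illustration of the `c10::load` point above (a minimal sketch of the idea only; `load_sketch` is a hypothetical helper, not the real `c10::load` from PyTorch's c10/util/Load.h): a bool read straight from memory may hold a byte other than 0 or 1, and comparing such a value directly can give wrong reduction results; loading through an unsigned byte first collapses every non-zero pattern to a well-formed `true`.

#include <cstdio>
#include <cstring>

// Generic load: byte-copy the value out of possibly unaligned storage.
template <typename T>
T load_sketch(const void* src) {
  T value;
  std::memcpy(&value, src, sizeof(T));
  return value;
}

// Bool case: read the raw byte and normalize, so a stored 0x02 still
// behaves as `true` in later comparisons.
template <>
bool load_sketch<bool>(const void* src) {
  unsigned char byte;
  std::memcpy(&byte, src, sizeof(byte));
  return byte != 0;
}

int main() {
  unsigned char raw = 0x02;  // a "bool" byte that is neither 0 nor 1
  std::printf("normalized: %d\n", load_sketch<bool>(&raw) ? 1 : 0);  // prints 1
  return 0;
}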
Showing 1 changed file with 17 additions and 13 deletions: src/ATen/native/xpu/sycl/Reduce.h
@@ -241,7 +241,14 @@ struct ReduceConfig {
   template <typename T, class KernelClass>
   void set_group_dimension(int64_t dim0, int64_t dim1) {
     auto max_wg_sz = syclMaxWorkGroupSize<KernelClass>();
-    auto max_sg_sz = syclMaxSubGroupSize();
+    // Bypass reduction on SLM by sparing the workload to other SGs. As a
+    // result, reduction of a small-shape input only requires some shift
+    // operations inside a SG. It is a functional WA; we got case failures
+    // on some platforms supporting SIMD8.
+    // https://github.com/intel/torch-xpu-ops/issues/698
+    auto max_sg_sz = syclMinSubGroupSize() == 8
+        ? syclMinSubGroupSize()
+        : syclMaxSubGroupSize();
     const int max_num_items = max_wg_sz / output_vec_size;
     int dim0_pow2 = dim0 < max_num_items ? static_cast<int>(last_pow2(dim0))
                                          : max_num_items;
@@ -686,7 +693,7 @@ struct ReduceOp {
           config.should_reduce_tail(pos)) {
         value = ops.reduce(
             value,
-            data[pos.get_local_id(1) - shift],
+            c10::load(data + pos.get_local_id(1) - shift),
             pos.get_local_id(1) - shift);
       }
       // align data to vector start
@@ -698,9 +705,6 @@
       shift = align_elements - shift;
     }
 
-    // Do the vectorized reduction
-    using load_t = at::native::memory::aligned_vector<scalar_t, input_vec_size>;
-
     index_t idx = config.input_idx(pos);
     const index_t stride = config.step_input;
 
@@ -713,14 +717,12 @@
       value_list[i] = ident;
     }
 
-    load_t values;
-
     while (idx * input_vec_size + input_vec_size - 1 < end) {
-      values = reinterpret_cast<const load_t*>(data)[idx];
+      const auto values_vec = memory::load_vector<input_vec_size>(data, idx);
 #pragma unroll
       for (index_t i = 0; i < input_vec_size; ++i) {
         value_list[i] = ops.reduce(
-            value_list[i], values[i], shift + idx * input_vec_size + i);
+            value_list[i], values_vec[i], shift + idx * input_vec_size + i);
       }
       idx += stride;
     }
@@ -730,7 +732,8 @@
     if (config.should_reduce_tail(pos)) {
       int idx = tail_start + pos.get_local_id(1);
       if ((index_t)idx < end) {
-        value_list[0] = ops.reduce(value_list[0], data[idx], idx + shift);
+        const auto value = c10::load(data + idx);
+        value_list[0] = ops.reduce(value_list[0], value, idx + shift);
       }
     }
 
@@ -754,7 +757,6 @@
     using arg_vec_t = at::detail::Array<arg_t, output_vec_size>;
     using load_t =
         at::native::memory::aligned_vector<scalar_t, output_vec_size>;
-    const load_t* data = reinterpret_cast<const load_t*>(data_);
 
     arg_vec_t value_list[vt0];
 
@@ -771,7 +773,8 @@
     while (idx + (vt0 - 1) * stride < end) {
 #pragma unroll(vt0)
       for (index_t i = 0; i < vt0; ++i) {
-        values[i] = data[calc(idx + i * stride) / output_vec_size];
+        const auto offset = calc(idx + i * stride) / output_vec_size;
+        values[i] = memory::load_vector<output_vec_size>(data_, offset);
       }
 #pragma unroll(vt0)
       for (index_t i = 0; i < vt0; ++i) {
@@ -791,7 +794,8 @@
       if (idx >= end) {
         break;
       }
-      values[i] = data[calc(idx) / output_vec_size];
+      const auto offset = calc(idx) / output_vec_size;
+      values[i] = memory::load_vector<output_vec_size>(data_, offset);
      idx += stride;
     }
     idx = idx_;
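
For the first hunk above, here is a condensed, standalone sketch of the sub-group-size decision and the width cap it is assumed to feed into (the query functions, `last_pow2_sketch`, and `pick_reduction_width` are illustrative stand-ins with assumed values, not the actual syclMinSubGroupSize/syclMaxSubGroupSize utilities or ReduceConfig code): when the smallest sub-group size a platform supports is 8, the kernel may be compiled for SIMD8, so capping the reduction width at 8 keeps each output's reduction inside a single sub-group and the SLM-based cross-sub-group step is never needed.

#include <algorithm>
#include <cstdint>
#include <cstdio>

// Assumed device properties, for illustration only.
int query_min_sub_group_size() { return 8; }   // platform supports SIMD8
int query_max_sub_group_size() { return 32; }

// Largest power of two not exceeding n (n >= 1).
int last_pow2_sketch(int n) {
  int p = 1;
  while (p * 2 <= n) {
    p *= 2;
  }
  return p;
}

// Simplified sketch: choose the sub-group size as in the hunk, then cap the
// reduction width at it (the cap itself is assumed from the surrounding
// ReduceConfig logic shown as context).
int pick_reduction_width(int64_t dim0, int max_wg_sz, int output_vec_size) {
  const int sg_sz = query_min_sub_group_size() == 8
      ? query_min_sub_group_size()
      : query_max_sub_group_size();
  const int max_num_items = max_wg_sz / output_vec_size;
  const int dim0_pow2 =
      dim0 < max_num_items ? last_pow2_sketch(static_cast<int>(dim0))
                           : max_num_items;
  return std::min(dim0_pow2, sg_sz);
}

int main() {
  // A 16-element reduced dimension now maps to width 8 rather than 32,
  // so the whole reduction fits in one 8-wide sub-group.
  std::printf("width = %d\n", pick_reduction_width(16, 1024, 1));
  return 0;
}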
