Update Reduce.h
xytintel authored Aug 5, 2024
1 parent 2f75c47 commit a9de4ae
Showing 1 changed file with 28 additions and 182 deletions.
src/ATen/native/xpu/sycl/Reduce.h: 28 additions, 182 deletions

@@ -241,19 +241,19 @@ struct ReduceConfig {
   template <typename T, class KernelClass>
   void set_group_dimension(int64_t dim0, int64_t dim1) {
     auto max_wg_sz = syclMaxWorkGroupSize<KernelClass>();
-    auto max_sg_sz = syclMaxSubGroupSize();
+    auto min_sg_sz = syclMinSubGroupSize();
     const int max_num_items = max_wg_sz / output_vec_size;
     int dim0_pow2 = dim0 < max_num_items ? static_cast<int>(last_pow2(dim0))
                                          : max_num_items;
     int dim1_pow2 = dim1 < max_num_items ? static_cast<int>(last_pow2(dim1))
                                          : max_num_items;
-    group_width = std::min(dim0_pow2, int(max_sg_sz));
+    group_width = std::min(dim0_pow2, int(min_sg_sz));
     group_height = std::min(dim1_pow2, int(max_num_items / group_width));
     group_width = std::min(dim0_pow2, int(max_num_items / group_height));
     num_items = group_width * group_height;

-    if (num_items < max_sg_sz)
-      group_width = max_sg_sz;
+    if (num_items < min_sg_sz)
+      group_width = min_sg_sz;
   }

   int split_input(int parallelism) {
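Note on the hunk above: set_group_dimension now clamps group_width by the minimum sub-group size instead of the maximum, and both dimensions are first rounded down through last_pow2. A minimal sketch of a previous-power-of-two helper, assuming last_pow2 follows the usual bit-smearing pattern (the name and body here are illustrative, not the file's actual definition):

#include <algorithm>
#include <cstdint>

// Round n down to the largest power of two <= n (returns 1 for n <= 1).
// Smearing the top bit downward fills every lower bit, so n - (n >> 1)
// leaves only the highest set bit.
inline int64_t last_pow2_sketch(int64_t n) {
  n |= (n >> 1);
  n |= (n >> 2);
  n |= (n >> 4);
  n |= (n >> 8);
  n |= (n >> 16);
  n |= (n >> 32);
  return std::max<int64_t>(1, n - (n >> 1));
}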
@@ -542,27 +542,17 @@ struct ReduceOp {
           (const scalar_t*)((const char*)src + base_offsets1);
       value = item_reduce<output_vec_size>(pos, input_slice);
     }
-    // TODO: Currently, there are bugs with sycl::shift_group_left when the
-    // arg_t is a pair for half dtype, We temporarily workaround to do
-    // "reduce_for_compound_dtype" function.
-    constexpr bool is_pair =
-        std::is_same<at::xpu::pair<scalar_t, int64_t>, arg_t>::value;

     auto combine = [=](arg1_t value, arg2_t other) -> arg1_t {
       return ops.combine(value, other);
     };

     if (config.should_group_x_reduce() && config.should_group_y_reduce()) {
-      if constexpr (is_pair) {
-        value = group_reduce_for_compound_dtype<output_vec_size>(
-            pos, value, shared);
-      } else {
-        value = group_reduce<
-            arg_t,
-            decltype(pos),
-            decltype(combine),
-            output_vec_size>(pos, config.num_items, shared, value, combine);
-      }
+      value = group_reduce<
+          arg_t,
+          decltype(pos),
+          decltype(combine),
+          output_vec_size>(pos, config.num_items, shared, value, combine);
     } else {
       if (config.should_group_y_reduce()) {
         value = group_y_reduce<
@@ -572,16 +562,11 @@
             output_vec_size>(pos, shared, value, combine);
       }
       if (config.should_group_x_reduce()) {
-        if constexpr (is_pair) {
-          value = group_x_reduce_for_compound_dtype<output_vec_size>(
-              pos, value, shared);
-        } else {
-          value = group_x_reduce<
-              arg_t,
-              decltype(pos),
-              decltype(combine),
-              output_vec_size>(pos, shared, value, combine);
-        }
+        value = group_x_reduce<
+            arg_t,
+            decltype(pos),
+            decltype(combine),
+            output_vec_size>(pos, shared, value, combine);
       }
     }

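With the is_pair special case gone, both reduction paths above funnel through the generic group_reduce / group_x_reduce, which are parameterized only on arg_t and the combine lambda. A minimal sketch of that pattern (a stand-in, not the library's group_reduce; assumes the work-group size is a power of two) showing why a caller-supplied combine covers scalar and pair-valued accumulators alike:

#include <sycl/sycl.hpp>

// Work-group tree reduction over local memory with a caller-supplied combine.
// Nothing here depends on the layout of T, so the same routine serves scalar
// accumulators and pair-like ones (e.g. value plus index for argmax).
template <typename T, typename Combine>
T slm_tree_reduce_sketch(
    sycl::nd_item<2> pos,
    T* shared, // points into work-group local memory, one slot per work-item
    T value,
    Combine combine) {
  const int lid = static_cast<int>(pos.get_local_linear_id());
  const int n =
      static_cast<int>(pos.get_local_range(0) * pos.get_local_range(1));
  shared[lid] = value;
  sycl::group_barrier(pos.get_group());
  // Halve the active work-items each round; assumes n is a power of two.
  for (int offset = n / 2; offset > 0; offset >>= 1) {
    if (lid < offset) {
      shared[lid] = combine(shared[lid], shared[lid + offset]);
    }
    sycl::group_barrier(pos.get_group());
  }
  return shared[0];
}

Because the combine is opaque to the reduction, the same code path handles arg_t being an at::xpu::pair without any compound-dtype branching.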
@@ -701,7 +686,7 @@ struct ReduceOp {
         config.should_reduce_tail(pos)) {
       value = ops.reduce(
           value,
-          data[pos.get_local_id(1) - shift],
+          c10::load(data + pos.get_local_id(1) - shift),
           pos.get_local_id(1) - shift);
     }
     // align data to vector start
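The direct data[...] indexing above becomes c10::load. Assuming the helper matches the upstream c10/util/Load.h behavior, ordinary types are simply dereferenced while bool is read through its byte representation so any nonzero stored payload normalizes to true; a rough sketch of that idea (hypothetical names, not the actual implementation):

// Ordinary types: a plain dereference.
template <typename T>
T load_sketch(const T* src) {
  return *src;
}

// bool: read the underlying byte and normalize, so a stored 0x02 still
// yields true rather than an indeterminate bool value.
inline bool load_sketch(const bool* src) {
  return *reinterpret_cast<const unsigned char*>(src) != 0;
}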
@@ -713,9 +698,6 @@
       shift = align_elements - shift;
     }

-    // Do the vectorized reduction
-    using load_t = at::native::memory::aligned_vector<scalar_t, input_vec_size>;
-
     index_t idx = config.input_idx(pos);
     const index_t stride = config.step_input;

@@ -728,14 +710,12 @@
       value_list[i] = ident;
     }

-    load_t values;
-
     while (idx * input_vec_size + input_vec_size - 1 < end) {
-      values = reinterpret_cast<const load_t*>(data)[idx];
+      const auto values_vec = memory::load_vector<input_vec_size>(data, idx);
 #pragma unroll
       for (index_t i = 0; i < input_vec_size; ++i) {
         value_list[i] = ops.reduce(
-            value_list[i], values[i], shift + idx * input_vec_size + i);
+            value_list[i], values_vec[i], shift + idx * input_vec_size + i);
       }
       idx += stride;
     }
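The reinterpret_cast of the raw input pointer above is likewise folded into memory::load_vector. A minimal sketch, assuming the XPU helper mirrors the CUDA-side load_vector (aligned_vector_sketch is a stand-in for the real aligned_vector type):

#include <cstdint>

// Stand-in for at::native::memory::aligned_vector: vec_size elements carrying
// vector alignment so the backend can emit one wide load.
template <typename scalar_t, int vec_size>
struct alignas(sizeof(scalar_t) * vec_size) aligned_vector_sketch {
  scalar_t val[vec_size];
  scalar_t operator[](int i) const {
    return val[i];
  }
};

// Treat base_ptr as an array of aligned vectors and fetch the offset-th one.
template <int vec_size, typename scalar_t>
aligned_vector_sketch<scalar_t, vec_size> load_vector_sketch(
    const scalar_t* base_ptr, uint32_t offset) {
  using vec_t = aligned_vector_sketch<scalar_t, vec_size>;
  const vec_t* from = reinterpret_cast<const vec_t*>(base_ptr);
  return from[offset];
}

The loop then indexes values_vec[i] element-wise exactly as before, so only the load itself changes.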
@@ -745,7 +725,8 @@
     if (config.should_reduce_tail(pos)) {
       int idx = tail_start + pos.get_local_id(1);
       if ((index_t)idx < end) {
-        value_list[0] = ops.reduce(value_list[0], data[idx], idx + shift);
+        const auto value = c10::load(data + idx);
+        value_list[0] = ops.reduce(value_list[0], value, idx + shift);
       }
     }

@@ -769,7 +750,6 @@
     using arg_vec_t = at::detail::Array<arg_t, output_vec_size>;
     using load_t =
         at::native::memory::aligned_vector<scalar_t, output_vec_size>;
-    const load_t* data = reinterpret_cast<const load_t*>(data_);

     arg_vec_t value_list[vt0];

@@ -786,7 +766,8 @@
     while (idx + (vt0 - 1) * stride < end) {
 #pragma unroll(vt0)
       for (index_t i = 0; i < vt0; ++i) {
-        values[i] = data[calc(idx + i * stride) / output_vec_size];
+        const auto offset = calc(idx + i * stride) / output_vec_size;
+        values[i] = memory::load_vector<output_vec_size>(data_, offset);
       }
 #pragma unroll(vt0)
       for (index_t i = 0; i < vt0; ++i) {
@@ -806,7 +787,8 @@
         if (idx >= end) {
           break;
         }
-        values[i] = data[calc(idx) / output_vec_size];
+        const auto offset = calc(idx) / output_vec_size;
+        values[i] = memory::load_vector<output_vec_size>(data_, offset);
         idx += stride;
       }
       idx = idx_;
@@ -833,133 +815,6 @@
     return value_list[0];
   }

-  // TODO: Currently, there are bugs with sycl::shift_group_left when the arg_t
-  // is a pair with half dtype, We temporarily workaround to do
-  // "reduce_for_compound_dtype" function.
-  template <int output_vec_size>
-  at::detail::Array<arg_t, output_vec_size> group_reduce_for_compound_dtype(
-      sycl::nd_item<2> pos,
-      at::detail::Array<arg_t, output_vec_size> value,
-      sycl_local_ptr<void> shared_memory) const {
-    auto sg = pos.get_sub_group();
-    uint32_t sbgrpSize = sg.get_local_range()[0];
-    int l_x = pos.get_local_linear_id();
-    int sg_lid = sg.get_local_linear_id();
-    int sg_gid = sg.get_group_linear_id();
-    int sg_range = sg.get_group_range()[0];
-
-    for (int offset = 1; offset < (int)sbgrpSize; offset <<= 1) {
-#pragma unroll(output_vec_size)
-      for (int i = 0; i < output_vec_size; ++i) {
-        arg_t other = sycl::shift_group_left(sg, value[i], offset);
-        value[i] = ops.combine(value[i], other);
-      }
-    }
-
-    using args_vec_t = at::detail::Array<arg_t, output_vec_size>;
-    sycl_local_ptr<args_vec_t> shared{shared_memory};
-
-    if (sg_lid == 0) {
-      shared[sg_gid] = value;
-    }
-    pos.barrier(sycl_local_fence);
-
-    if (sg_range <= (int)sbgrpSize) {
-      // sub-group reduce
-#pragma unroll(output_vec_size)
-      for (int i = 0; i < output_vec_size; i++) {
-        value[i] = ident;
-      }
-      if (sg_gid == 0 && sg_lid < sg_range) {
-        value = shared[sg_lid];
-        for (int offset = 1; offset < sg_range; offset <<= 1) {
-#pragma unroll(output_vec_size)
-          for (int i = 0; i < output_vec_size; ++i) {
-            // Shuffle down separately for first and second pair.
-            at::xpu::
-                pair<typename arg_t::first_type, typename arg_t::second_type>
-                    other = at::xpu::pair<
-                        typename arg_t::first_type,
-                        typename arg_t::second_type>(
-                        sycl::shift_group_left(sg, value[i].first, offset),
-                        sycl::shift_group_left(sg, value[i].second, offset));
-            value[i] = ops.combine(value[i], other);
-          }
-        }
-      }
-    } else {
-      // work item tree reduce
-      if (l_x < sg_range) {
-        value = shared[l_x];
-      }
-
-      for (int offset = sg_range / 2; offset > 0; offset >>= 1) {
-        if (l_x < offset) {
-          args_vec_t other = shared[l_x + offset];
-#pragma unroll(output_vec_size)
-          for (int i = 0; i < output_vec_size; ++i) {
-            value[i] = ops.combine(value[i], other[i]);
-          }
-          shared[l_x] = value;
-        }
-        pos.barrier(sycl_local_fence);
-      }
-    }
-
-    return value;
-  }
-
-  // TODO: Currently, there are bugs with sycl::shift_group_left when the arg_t
-  // is a pair for half dtype, We temporarily workaround to do
-  // "reduce_for_compound_dtype" function.
-  template <int output_vec_size>
-  at::detail::Array<arg_t, output_vec_size> group_x_reduce_for_compound_dtype(
-      sycl::nd_item<2> pos,
-      at::detail::Array<arg_t, output_vec_size> value,
-      sycl_local_ptr<void> shared_memory) const {
-    using args_vec_t = at::detail::Array<arg_t, output_vec_size>;
-    auto l_x = pos.get_local_id(1), l_y = pos.get_local_id(0);
-    auto gp_x = pos.get_local_range(1);
-
-    int dim_x = gp_x;
-    sycl_local_ptr<args_vec_t> shared(shared_memory);
-    auto sg = pos.get_sub_group();
-    uint32_t sbgrpSize = sg.get_local_range()[0];
-    if (dim_x > (int)sbgrpSize) {
-      int address_base = l_x + l_y * gp_x;
-      shared[address_base] = value;
-      for (int offset = dim_x / 2; offset >= (int)sbgrpSize; offset >>= 1) {
-        pos.barrier(sycl_local_fence);
-        if ((int)l_x < offset &&
-            (int)l_x + offset < (int)gp_x /* redundant??? */) {
-          args_vec_t other = shared[address_base + offset];
-#pragma unroll(output_vec_size)
-          for (int i = 0; i < output_vec_size; ++i) {
-            value[i] = ops.combine(value[i], other[i]);
-          }
-          shared[address_base] = value;
-        }
-      }
-      dim_x = sbgrpSize;
-    }
-
-    pos.barrier(sycl_local_fence);
-
-    // sub-group reduction
-    for (int offset = 1; offset < dim_x; offset <<= 1) {
-#pragma unroll(output_vec_size)
-      for (int i = 0; i < output_vec_size; ++i) {
-        at::xpu::pair<typename arg_t::first_type, typename arg_t::second_type>
-            other = xpu::
-                pair<typename arg_t::first_type, typename arg_t::second_type>(
-                    sycl::shift_group_left(sg, value[i].first, offset),
-                    sycl::shift_group_left(sg, value[i].second, offset));
-        value[i] = ops.combine(value[i], other);
-      }
-    }
-    return value;
-  }
-
   // In/out from slm pointers
   void mark_group_finished(sycl::nd_item<2> pos, sycl_local_ptr<bool> finished)
       const {
@@ -1124,20 +979,11 @@ struct ReduceOp {
           decltype(combine),
           output_vec_size>(pos, shared_memory, value, combine);
       if (config.should_group_x_reduce()) {
-        // TODO: workaround because sycl::shift_group_left will fail on `half`
-        // dtype.
-        constexpr bool is_pair =
-            std::is_same<at::xpu::pair<scalar_t, int64_t>, arg_t>::value;
-        if constexpr (is_pair) {
-          value = group_x_reduce_for_compound_dtype<output_vec_size>(
-              pos, value, shared_memory);
-        } else {
-          value = group_x_reduce<
-              arg_t,
-              decltype(pos),
-              decltype(combine),
-              output_vec_size>(pos, shared_memory, value, combine);
-        }
+        value = group_x_reduce<
+            arg_t,
+            decltype(pos),
+            decltype(combine),
+            output_vec_size>(pos, shared_memory, value, combine);
       }
       if (should_store) {
         if (accumulate) {
