Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add aten::topk and its variants #547

Merged
merged 11 commits into from
Jul 23, 2024
92 changes: 92 additions & 0 deletions src/ATen/native/xpu/TensorTopK.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,92 @@
#include <ATen/ATen.h>
#include <ATen/core/Tensor.h>
#include <ATen/native/TensorIterator.h>
#include <ATen/native/xpu/sycl/TensorTopKKernel.h>
#include <ATen/xpu/XPUNativeFunctions.h>
#include <comm/RegisterUtils.h>

namespace at {

void topk_meta(
const Tensor& self,
int64_t k,
int64_t dim_,
bool largest,
bool sorted,
Tensor& values,
Tensor& indices) {
int64_t dim = maybe_wrap_dim(dim_, self.dim(), /*wrap_scalar=*/true);
TORCH_CHECK(
k >= 0 && k <= (self.dim() > 0 ? self.size(dim) : 1),
"selected index k out of range");
int64_t sliceSize = self.dim() == 0 ? 1 : self.size(dim);
TORCH_CHECK(k >= 0 && k <= sliceSize, "k not in range for dimension");

// Build the output size, which is the dim being selected set to
// size k
DimVector topKSize(self.sizes().vec());
if (!topKSize.empty()) {
topKSize[dim] = k;
}

if (values.defined()) {
at::xpu::resize_out(values, topKSize, {}, self.options());
} else {
values = at::xpu::create_out(topKSize, {}, self.options());
}

if (indices.defined()) {
at::xpu::resize_out(indices, topKSize, {}, self.options().dtype(at::kLong));
} else {
indices =
at::xpu::create_out(topKSize, {}, self.options().dtype(at::kLong));
}
}

void topk_out_impl(
const Tensor& self,
int64_t k,
int64_t dim_,
bool largest,
bool sorted,
Tensor& values,
Tensor& indices) {
int64_t dim = maybe_wrap_dim(dim_, self.dim(), /*wrap_scalar=*/true);
TORCH_CHECK(
k >= 0 && k <= (self.dim() > 0 ? self.size(dim) : 1),
"selected index k out of range");
xytintel marked this conversation as resolved.
Show resolved Hide resolved

if (self.dim() == 0 && self.numel() == 1) {
values.copy_(self);
indices.zero_();
} else {
native::xpu::topk_kernel(self, k, dim, largest, sorted, values, indices);
}
}

std::tuple<Tensor, Tensor> XPUNativeFunctions::topk(
const Tensor& self,
int64_t k,
int64_t dim,
bool largest,
bool sorted) {
Tensor values, indices;
topk_meta(self, k, dim, largest, sorted, values, indices);
topk_out_impl(self, k, dim, largest, sorted, values, indices);
return std::tuple<Tensor, Tensor>(values, indices);
}

std::tuple<Tensor&, Tensor&> XPUNativeFunctions::topk_out(
const Tensor& self,
int64_t k,
int64_t dim,
bool largest,
bool sorted,
Tensor& values,
Tensor& indices) {
topk_meta(self, k, dim, largest, sorted, values, indices);
topk_out_impl(self, k, dim, largest, sorted, values, indices);
return std::forward_as_tuple(values, indices);
}

} // namespace at
1 change: 0 additions & 1 deletion src/ATen/native/xpu/XPUFallback.template
Original file line number Diff line number Diff line change
Expand Up @@ -338,7 +338,6 @@ TORCH_LIBRARY_IMPL(aten, XPU, m) {
"special_zeta.out",
"take",
"_thnn_fused_gru_cell",
"topk.values",
"_to_sparse",
"_to_sparse_csr",
"triangular_solve.X",
Expand Down
224 changes: 224 additions & 0 deletions src/ATen/native/xpu/sycl/SortingKernels.h
Original file line number Diff line number Diff line change
Expand Up @@ -403,6 +403,124 @@ void segmented_radix_sort_pairs_kernel(
}
}

// ======================= group radix select =======================

template <int n, typename index_t>
inline index_t make_alignment_n(index_t size) {
return (size + n - 1) / n * n;
}

template <typename method_t, typename key_t, typename value_t>
struct SegmentedGroupRadixSelectPairsFunctor
: public __SYCL_KER_CONFIG_CONVENTION__ {
enum {
MAX_KV_BYTES = std::max(sizeof(key_t), sizeof(value_t)),
};

[[intel::reqd_sub_group_size(method_t::SUBGROUP_SIZE)]] void operator()(
sycl::nd_item<1> item) const {
int seg_idx = item.get_group(0);
int seg_offset = seg_idx * nelements_;
auto method = method_t(item, slm_);

auto keys_in_seg = keys_in_ + seg_offset;
auto values_in_seg =
values_in_ == nullptr ? nullptr : values_in_ + seg_offset;

key_t* keys_temp = reinterpret_cast<key_t*>(
slm_.template get_multi_ptr<sycl::access::decorated::no>().get() +
make_alignment_n<MAX_KV_BYTES>(method_t::LocalMemorySize()));
value_t* values_temp = reinterpret_cast<value_t*>(
reinterpret_cast<char*>(keys_temp) +
make_alignment_n<MAX_KV_BYTES>(k_ * sizeof(key_t)));

method.load_keys(keys_in_seg, nelements_);
method.load_values(values_in_seg, nelements_);

int num_start = method_t::PROCESSING_LENGTH;
while (num_start < nelements_) {
method.topk(KeyTraits<key_t>::endbit(), 0, k_, keys_temp, values_temp);
item.barrier(sycl_local_fence);
method.topk_append_keys(
keys_in_seg, keys_temp, nelements_, num_start, k_);
method.topk_append_values(
values_in_seg, values_temp, nelements_, num_start, k_);
num_start += method_t::PROCESSING_LENGTH - k_;
item.barrier(sycl_local_fence);
}

method.topk(
KeyTraits<key_t>::endbit(),
0,
k_,
keys_out_ + seg_idx * k_,
values_out_ + seg_idx * k_);
}

void sycl_ker_config_convention(sycl::handler& cgh) {
slm_ = sycl_local_acc_t<char>(
make_alignment_n<MAX_KV_BYTES>(method_t::LocalMemorySize()) +
make_alignment_n<MAX_KV_BYTES>(k_ * sizeof(key_t)) +
k_ * sizeof(value_t),
fengyuan14 marked this conversation as resolved.
Show resolved Hide resolved
cgh);
}

SegmentedGroupRadixSelectPairsFunctor(
const key_t* keys_in,
key_t* keys_out,
const value_t* values_in,
value_t* values_out,
int nelements,
int k)
: keys_in_(keys_in),
keys_out_(keys_out),
values_in_(values_in),
values_out_(values_out),
nelements_(nelements),
k_(k) {}

private:
const key_t* keys_in_;
key_t* keys_out_;
const value_t* values_in_;
value_t* values_out_;
int nelements_;
int k_;
sycl_local_acc_t<char> slm_;
};

template <
typename key_t,
typename value_t,
bool IS_DESCENDING,
int KEYS_PER_ITEM,
int GROUP_SIZE,
int SUBGROUP_SIZE>
inline void group_radix_select_pairs_kernel(
const key_t* keys_in,
key_t* keys_out,
const value_t* values_in,
value_t* values_out,
int num_segments,
int num_elements,
int k) {
using method_t = GroupRadixSort<
key_t,
GROUP_SIZE,
SUBGROUP_SIZE,
KEYS_PER_ITEM,
IS_DESCENDING,
value_t>;
TORCH_CHECK(k <= method_t::PROCESSING_LENGTH);
auto caller = SegmentedGroupRadixSelectPairsFunctor<method_t, key_t, value_t>(
keys_in, keys_out, values_in, values_out, num_elements, k);
sycl_kernel_submit(
num_segments * GROUP_SIZE,
GROUP_SIZE,
at::xpu::getCurrentSYCLQueue(),
caller);
}

// ======================= interface =======================

// NOTE: Subgroup size of 32 provides better performance currently.
Expand Down Expand Up @@ -507,6 +625,112 @@ void sort_pairs(
keys_in, keys_out, values_in, values_out, 1, num_elements, descending);
}

inline uint64_t radix_select_last_power2(uint64_t n) {
n--;
n |= n >> 1;
n |= n >> 2;
n |= n >> 4;
n |= n >> 8;
n |= n >> 16;
n++;
return n;
}

template <
typename key_t,
typename value_t,
bool IS_DESCENDING,
int SUBGROUP_SIZE = 32>
void segmented_group_select_pairs_(
const key_t* keys_in,
key_t* keys_out,
const value_t* values_in,
value_t* values_out,
int num_segments,
int num_elements,
int k) {
#define RUN_RADIX_SELECT(PADDED_N) \
{ \
group_radix_select_pairs_kernel< \
key_t, \
value_t, \
IS_DESCENDING, \
4, \
PADDED_N / 4, \
SUBGROUP_SIZE>( \
keys_in, \
keys_out, \
values_in, \
values_out, \
num_segments, \
num_elements, \
k); \
}
constexpr int max_group_size = 1024; // simd32-specific
if (num_elements <= max_group_size * 4) {
switch (radix_select_last_power2(num_elements)) {
case 4096:
RUN_RADIX_SELECT(4096); // gsz 1024
break;
case 2048:
RUN_RADIX_SELECT(2048); // gsz 512
break;
case 1024:
RUN_RADIX_SELECT(1024); // gsz 256
break;
case 512:
RUN_RADIX_SELECT(512); // gsz 128
break;
default:
RUN_RADIX_SELECT(256); // gsz 64
break;
}
} else {
switch (max_group_size) {
case 1024:
RUN_RADIX_SELECT(4096);
break;
case 512:
RUN_RADIX_SELECT(2048);
break;
default:
RUN_RADIX_SELECT(1024);
break;
}
}
#undef RUN_RADIX_SELECT
}

template <typename key_t, typename value_t>
void segmented_group_select_pairs(
const key_t* keys_in,
key_t* keys_out,
const value_t* values_in,
value_t* values_out,
int num_segments,
int num_elements,
int k,
bool largest) {
if (largest)
segmented_group_select_pairs_<key_t, value_t, true>(
keys_in,
keys_out,
values_in,
values_out,
num_segments,
num_elements,
k);
else
segmented_group_select_pairs_<key_t, value_t, false>(
keys_in,
keys_out,
values_in,
values_out,
num_segments,
num_elements,
k);
}

} // namespace xpu
} // namespace native
} // namespace at
Loading