Skip to content

Commit

Permalink
Add aten::topk and its variants (#547)
Browse files Browse the repository at this point in the history
Task list:
- [x] topk
- [x] topk.values
  • Loading branch information
xytintel authored Jul 23, 2024
1 parent 6bb1633 commit 1e61056
Show file tree
Hide file tree
Showing 8 changed files with 741 additions and 1 deletion.
97 changes: 97 additions & 0 deletions src/ATen/native/xpu/TensorTopK.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,97 @@
#include <ATen/ATen.h>
#include <ATen/core/Tensor.h>
#include <ATen/native/TensorIterator.h>
#include <ATen/native/xpu/sycl/TensorTopKKernel.h>
#include <ATen/xpu/XPUNativeFunctions.h>
#include <comm/RegisterUtils.h>

namespace at {

void topk_meta(
const Tensor& self,
int64_t k,
int64_t dim_,
bool largest,
bool sorted,
Tensor& values,
Tensor& indices) {
int64_t dim = maybe_wrap_dim(dim_, self.dim(), /*wrap_scalar=*/true);
TORCH_CHECK(
k >= 0 && k <= (self.dim() > 0 ? self.size(dim) : 1),
"selected index k out of range");
int64_t sliceSize = self.dim() == 0 ? 1 : self.size(dim);
TORCH_CHECK(k >= 0 && k <= sliceSize, "k not in range for dimension");

// Build the output size, which is the dim being selected set to
// size k
DimVector topKSize(self.sizes().vec());
if (!topKSize.empty()) {
topKSize[dim] = k;
}

if (values.defined()) {
at::xpu::resize_out(values, topKSize, {}, self.options());
} else {
values = at::xpu::create_out(topKSize, {}, self.options());
}

if (indices.defined()) {
at::xpu::resize_out(indices, topKSize, {}, self.options().dtype(at::kLong));
} else {
indices =
at::xpu::create_out(topKSize, {}, self.options().dtype(at::kLong));
}
}

void topk_out_impl(
const Tensor& self,
int64_t k,
int64_t dim_,
bool largest,
bool sorted,
Tensor& values,
Tensor& indices) {
int64_t dim = maybe_wrap_dim(dim_, self.dim(), /*wrap_scalar=*/true);
TORCH_CHECK(
k >= 0 && k <= (self.dim() > 0 ? self.size(dim) : 1),
"selected index k out of range");

// If k is 0 the result is an empty tensor, so we don't need to launch a kernel.
if (k == 0) {
return;
}

if (self.dim() == 0 && self.numel() == 1) {
values.copy_(self);
indices.zero_();
} else {
native::xpu::topk_kernel(self, k, dim, largest, sorted, values, indices);
}
}

std::tuple<Tensor, Tensor> XPUNativeFunctions::topk(
const Tensor& self,
int64_t k,
int64_t dim,
bool largest,
bool sorted) {
Tensor values, indices;
topk_meta(self, k, dim, largest, sorted, values, indices);
topk_out_impl(self, k, dim, largest, sorted, values, indices);
return std::tuple<Tensor, Tensor>(values, indices);
}

std::tuple<Tensor&, Tensor&> XPUNativeFunctions::topk_out(
const Tensor& self,
int64_t k,
int64_t dim,
bool largest,
bool sorted,
Tensor& values,
Tensor& indices) {
topk_meta(self, k, dim, largest, sorted, values, indices);
topk_out_impl(self, k, dim, largest, sorted, values, indices);
return std::forward_as_tuple(values, indices);
}

} // namespace at
1 change: 0 additions & 1 deletion src/ATen/native/xpu/XPUFallback.template
Original file line number Diff line number Diff line change
Expand Up @@ -303,7 +303,6 @@ TORCH_LIBRARY_IMPL(aten, XPU, m) {
"special_zeta.out",
"take",
"_thnn_fused_gru_cell",
"topk.values",
"_to_sparse",
"_to_sparse_csr",
"triangular_solve.X",
Expand Down
224 changes: 224 additions & 0 deletions src/ATen/native/xpu/sycl/SortingKernels.h
Original file line number Diff line number Diff line change
Expand Up @@ -403,6 +403,124 @@ void segmented_radix_sort_pairs_kernel(
}
}

// ======================= group radix select =======================

template <int n, typename index_t>
inline index_t make_alignment_n(index_t size) {
return (size + n - 1) / n * n;
}

template <typename method_t, typename key_t, typename value_t>
struct SegmentedGroupRadixSelectPairsFunctor
: public __SYCL_KER_CONFIG_CONVENTION__ {
enum {
MAX_KV_BYTES = std::max(sizeof(key_t), sizeof(value_t)),
};

[[intel::reqd_sub_group_size(method_t::SUBGROUP_SIZE)]] void operator()(
sycl::nd_item<1> item) const {
int seg_idx = item.get_group(0);
int seg_offset = seg_idx * nelements_;
auto method = method_t(item, slm_);

auto keys_in_seg = keys_in_ + seg_offset;
auto values_in_seg =
values_in_ == nullptr ? nullptr : values_in_ + seg_offset;

key_t* keys_temp = reinterpret_cast<key_t*>(
slm_.template get_multi_ptr<sycl::access::decorated::no>().get() +
make_alignment_n<MAX_KV_BYTES>(method_t::LocalMemorySize()));
value_t* values_temp = reinterpret_cast<value_t*>(
reinterpret_cast<char*>(keys_temp) +
make_alignment_n<MAX_KV_BYTES>(k_ * sizeof(key_t)));

method.load_keys(keys_in_seg, nelements_);
method.load_values(values_in_seg, nelements_);

int num_start = method_t::PROCESSING_LENGTH;
while (num_start < nelements_) {
method.topk(KeyTraits<key_t>::endbit(), 0, k_, keys_temp, values_temp);
item.barrier(sycl_local_fence);
method.topk_append_keys(
keys_in_seg, keys_temp, nelements_, num_start, k_);
method.topk_append_values(
values_in_seg, values_temp, nelements_, num_start, k_);
num_start += method_t::PROCESSING_LENGTH - k_;
item.barrier(sycl_local_fence);
}

method.topk(
KeyTraits<key_t>::endbit(),
0,
k_,
keys_out_ + seg_idx * k_,
values_out_ + seg_idx * k_);
}

void sycl_ker_config_convention(sycl::handler& cgh) {
slm_ = sycl_local_acc_t<char>(
make_alignment_n<MAX_KV_BYTES>(method_t::LocalMemorySize()) +
make_alignment_n<MAX_KV_BYTES>(k_ * sizeof(key_t)) +
k_ * sizeof(value_t),
cgh);
}

SegmentedGroupRadixSelectPairsFunctor(
const key_t* keys_in,
key_t* keys_out,
const value_t* values_in,
value_t* values_out,
int nelements,
int k)
: keys_in_(keys_in),
keys_out_(keys_out),
values_in_(values_in),
values_out_(values_out),
nelements_(nelements),
k_(k) {}

private:
const key_t* keys_in_;
key_t* keys_out_;
const value_t* values_in_;
value_t* values_out_;
int nelements_;
int k_;
sycl_local_acc_t<char> slm_;
};

template <
typename key_t,
typename value_t,
bool IS_DESCENDING,
int KEYS_PER_ITEM,
int GROUP_SIZE,
int SUBGROUP_SIZE>
inline void group_radix_select_pairs_kernel(
const key_t* keys_in,
key_t* keys_out,
const value_t* values_in,
value_t* values_out,
int num_segments,
int num_elements,
int k) {
using method_t = GroupRadixSort<
key_t,
GROUP_SIZE,
SUBGROUP_SIZE,
KEYS_PER_ITEM,
IS_DESCENDING,
value_t>;
TORCH_CHECK(k <= method_t::PROCESSING_LENGTH);
auto caller = SegmentedGroupRadixSelectPairsFunctor<method_t, key_t, value_t>(
keys_in, keys_out, values_in, values_out, num_elements, k);
sycl_kernel_submit(
num_segments * GROUP_SIZE,
GROUP_SIZE,
at::xpu::getCurrentSYCLQueue(),
caller);
}

// ======================= interface =======================

// NOTE: Subgroup size of 32 provides better performance currently.
Expand Down Expand Up @@ -512,6 +630,112 @@ void sort_pairs(
keys_in, keys_out, values_in, values_out, 1, num_elements, descending);
}

inline uint64_t radix_select_last_power2(uint64_t n) {
n--;
n |= n >> 1;
n |= n >> 2;
n |= n >> 4;
n |= n >> 8;
n |= n >> 16;
n++;
return n;
}

template <
typename key_t,
typename value_t,
bool IS_DESCENDING,
int SUBGROUP_SIZE = 32>
void segmented_group_select_pairs_(
const key_t* keys_in,
key_t* keys_out,
const value_t* values_in,
value_t* values_out,
int num_segments,
int num_elements,
int k) {
#define RUN_RADIX_SELECT(PADDED_N) \
{ \
group_radix_select_pairs_kernel< \
key_t, \
value_t, \
IS_DESCENDING, \
4, \
PADDED_N / 4, \
SUBGROUP_SIZE>( \
keys_in, \
keys_out, \
values_in, \
values_out, \
num_segments, \
num_elements, \
k); \
}
constexpr int max_group_size = 1024; // simd32-specific
if (num_elements <= max_group_size * 4) {
switch (radix_select_last_power2(num_elements)) {
case 4096:
RUN_RADIX_SELECT(4096); // gsz 1024
break;
case 2048:
RUN_RADIX_SELECT(2048); // gsz 512
break;
case 1024:
RUN_RADIX_SELECT(1024); // gsz 256
break;
case 512:
RUN_RADIX_SELECT(512); // gsz 128
break;
default:
RUN_RADIX_SELECT(256); // gsz 64
break;
}
} else {
switch (max_group_size) {
case 1024:
RUN_RADIX_SELECT(4096);
break;
case 512:
RUN_RADIX_SELECT(2048);
break;
default:
RUN_RADIX_SELECT(1024);
break;
}
}
#undef RUN_RADIX_SELECT
}

template <typename key_t, typename value_t>
void segmented_group_select_pairs(
const key_t* keys_in,
key_t* keys_out,
const value_t* values_in,
value_t* values_out,
int num_segments,
int num_elements,
int k,
bool largest) {
if (largest)
segmented_group_select_pairs_<key_t, value_t, true>(
keys_in,
keys_out,
values_in,
values_out,
num_segments,
num_elements,
k);
else
segmented_group_select_pairs_<key_t, value_t, false>(
keys_in,
keys_out,
values_in,
values_out,
num_segments,
num_elements,
k);
}

} // namespace xpu
} // namespace native
} // namespace at
Loading

0 comments on commit 1e61056

Please sign in to comment.