Add aten::index_select_sparse_xpu op #1167

Open · wants to merge 2 commits into main
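This change adds a SparseXPU backend for index_select on sparse COO tensors: a new native entry point index_select_sparse_xpu forwards to a SYCL kernel index_select_sparse_kernel, and native_functions.yaml gains the corresponding SparseXPU dispatch entry. A minimal usage sketch of what the op enables, with hypothetical values and assuming an XPU-enabled build (not part of the diff):

#include <ATen/ATen.h>

int main() {
  // 4x3 sparse COO tensor with three non-zeros (values are illustrative).
  auto indices = at::tensor({0, 2, 2, 0, 0, 1}, at::kLong).reshape({2, 3});
  auto values = at::tensor({1.0f, 2.0f, 3.0f});
  auto self = at::sparse_coo_tensor(indices, values, {4, 3}).to(at::kXPU);

  // 1-D strided long index tensor on the same device.
  auto index = at::tensor({2, 0}, at::kLong).to(at::kXPU);

  // Routes to index_select_sparse_xpu through the SparseXPU dispatch entry below.
  auto out = at::index_select(self, /*dim=*/0, index); // 2x3 sparse result
  return 0;
}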
8 changes: 8 additions & 0 deletions src/ATen/native/xpu/Indexing.cpp
@@ -123,5 +123,13 @@ Tensor& masked_select_out_xpu(
namedinference::compute_broadcast_outnames(self, mask);
return masked_select_out_impl(result, self, mask);
}

Tensor index_select_sparse_xpu(
const Tensor& self,
int64_t dim,
const Tensor& index) {
return xpu::index_select_sparse_kernel(self, dim, index);
}

} // namespace native
} // namespace at
271 changes: 271 additions & 0 deletions src/ATen/native/xpu/sycl/Indexing.cpp
@@ -14,6 +14,11 @@
#include <ATen/native/xpu/sycl/IndexingUtils.h>
#include <ATen/native/xpu/sycl/Loops.h>
#include <ATen/native/xpu/sycl/pstl/PSTLFunctions.h>
#include <ATen/ops/_sparse_coo_tensor_with_dims_and_tensors.h>
#include <ATen/ops/arange.h>
#include <ATen/ops/empty.h>
#include <ATen/ops/ones_like.h>
#include <ATen/ops/zeros_like.h>

#include <comm/SYCLContext.h>
#include <comm/TensorInfo.h>
@@ -1146,6 +1151,272 @@ void put_kernel(
});
}

// ForwardIt: only legacy random access iterator is supported.
template <class ForwardIt, class T, bool is_lower = true>
static ForwardIt find_bound(ForwardIt first, ForwardIt last, const T& value) {
ForwardIt it;
typename std::iterator_traits<ForwardIt>::difference_type count, step;
// NOTE: std::distance(first, last) compiles but produces wrong results in device code,
// so only legacy random access iterators are safe in this code.
count = last - first;

while (count > 0) {
it = first;
step = count / 2;
// avoiding std::advance(it, step),
// although it does work unlike std::distance
it += step;
if (is_lower ? *it < value : value >= *it) {
first = ++it;
count -= step + 1;
} else {
count = step;
}
}
return first;
}
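For reference, a host-side sketch of what find_bound computes, with hypothetical values and assuming it is called from this translation unit (not part of the diff): lb is the classic lower bound (first element not less than the query), ub the upper bound (first element greater than the query), so ub - lb is the number of matches.

#include <cassert>

void find_bound_example() {
  const int sorted[] = {0, 2, 2, 2, 5, 7};
  // First element >= 2.
  auto* lb = find_bound<const int*, int, /*is_lower=*/true>(sorted, sorted + 6, 2);
  // First element > 2.
  auto* ub = find_bound<const int*, int, /*is_lower=*/false>(sorted, sorted + 6, 2);
  assert(lb - sorted == 1); // position of the first 2
  assert(ub - sorted == 4); // one past the last 2
  assert(ub - lb == 3);     // number of occurrences of 2
}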

template <typename index_t>
struct IndexSelectSparse1Functor {
index_t operator()(index_t idx) const {
SYCL_KERNEL_ASSERT(
idx >= -size && idx < size && "index_select(): index out of bounds");
return idx < 0 ? idx + size : idx;
}
IndexSelectSparse1Functor(index_t size) : size(size) {}

private:
index_t size;
};
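Illustration with hypothetical values: for size = 4, an index entry of -1 is normalized to 3 and 2 stays 2; anything outside [-4, 4) trips the device-side assert.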

template <typename index_t>
struct IndexSelectSparse2Functor {
index_t operator()(index_t idx_val, index_t idx_idx) const {
auto* lb = find_bound<const index_t*, index_t, true>(
ptr_sorted_dim_indices, ptr_sorted_dim_indices + nnz, idx_val);
auto* ub = find_bound<const index_t*, index_t, false>(
ptr_sorted_dim_indices, ptr_sorted_dim_indices + nnz, idx_val);
const auto idx_count = ub - lb;
ptr_intrsc_counts_nneg_index[idx_idx] = idx_count;

return lb - ptr_sorted_dim_indices;
}

IndexSelectSparse2Functor(
index_t* ptr_intrsc_counts_nneg_index,
const index_t* ptr_sorted_dim_indices,
int64_t nnz)
: ptr_intrsc_counts_nneg_index(ptr_intrsc_counts_nneg_index),
ptr_sorted_dim_indices(ptr_sorted_dim_indices),
nnz(nnz) {}

private:
index_t* ptr_intrsc_counts_nneg_index;
const index_t* ptr_sorted_dim_indices;
int64_t nnz;
};
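Illustration with hypothetical values: for sorted_dim_indices = {0, 2, 2, 5} (nnz = 4) and a query idx_val = 2, lb points at position 1 and ub at position 3, so the functor records a count of 2 for that query and returns 1, the position of its first match in the sorted array; a query with no match gets a count of 0.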

template <typename index_t>
struct IndexSelectSparse3Functor {
index_t operator()(
index_t idx_idx,
index_t count,
index_t offset,
index_t first_match) const {
index_t* __restrict__ ptr_res_dim_indices_out =
ptr_res_dim_indices + offset;
const index_t* __restrict__ ptr_argsort_dim_indices_in =
ptr_argsort_dim_indices + first_match;
index_t* __restrict__ ptr_selected_dim_indices_out =
ptr_selected_dim_indices + offset;
for (index_t i = 0; i < count; ++i) {
*ptr_res_dim_indices_out++ = idx_idx;
*ptr_selected_dim_indices_out++ = *ptr_argsort_dim_indices_in++;
}
// A dummy return scalar for a dummy output
return static_cast<index_t>(1);
}
IndexSelectSparse3Functor(
index_t* ptr_res_dim_indices,
index_t* ptr_selected_dim_indices,
const index_t* ptr_argsort_dim_indices)
: ptr_res_dim_indices(ptr_res_dim_indices),
ptr_selected_dim_indices(ptr_selected_dim_indices),
ptr_argsort_dim_indices(ptr_argsort_dim_indices) {}

private:
index_t* ptr_res_dim_indices;
index_t* ptr_selected_dim_indices;
const index_t* ptr_argsort_dim_indices;
};
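Illustration, continuing the hypothetical values above: a query at position idx_idx with count = 2, offset = o and first_match = 1 writes idx_idx into res_dim_indices[o] and res_dim_indices[o + 1], and copies argsort_dim_indices[1] and argsort_dim_indices[2] into selected_dim_indices[o] and selected_dim_indices[o + 1]. The offsets come from an exclusive prefix sum of the counts, so writes from different queries never overlap, which is why the output memory-overlap check is disabled when this functor is launched below.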

Tensor index_select_sparse_kernel(
const Tensor& self,
int64_t dim,
const Tensor& index) {
const auto ndim = self.dim();
TORCH_CHECK_INDEX(
ndim, "index_select() cannot be applied to a 0-dim tensor.");
TORCH_CHECK_INDEX(
index.dim() == 1 && index.dtype() == at::kLong &&
index.options().layout() == at::kStrided,
"index_select() argument index must be 1-D strided (non-sparse) long-tensor.");

dim = maybe_wrap_dim(dim, ndim);
const auto size = self.size(dim);
const auto sparse_dim = self.sparse_dim();
const auto dense_dim = self.dense_dim();
const auto indices = self._indices();
const auto values = self._values();
const auto nnz = values.size(0);
const auto index_len = index.size(0);
auto res_sizes = self.sizes().vec();
res_sizes[dim] = index_len;

// If indexing into sparse dimensions
if (dim < sparse_dim) {
const auto make_output =
[dim, sparse_dim, dense_dim, res_sizes, &self, &indices, &values](
const Tensor& selected_dim_indices,
const Tensor& res_dim_indices) -> Tensor {
auto res_indices = indices.index_select(1, selected_dim_indices);
res_indices[dim] = res_dim_indices;
const auto res_values = values.index_select(0, selected_dim_indices);

return at::_sparse_coo_tensor_with_dims_and_tensors(
sparse_dim,
dense_dim,
res_sizes,
res_indices,
res_values,
self.options());
};

// short-circuit if index is empty
if (!index_len) {
return make_output(index, index);
}

const auto nneg_index = [&index, size]() -> Tensor {
auto nneg_index = at::empty_like(index, at::MemoryFormat::Contiguous);

auto iter = TensorIteratorConfig()
.add_output(nneg_index)
.add_input(index)
.build();

AT_DISPATCH_INDEX_TYPES(
index.scalar_type(), "index_select_sparse_xpu", [&]() {
gpu_kernel(iter, IndexSelectSparse1Functor<index_t>(size));
});
return nneg_index;
}();

const auto dim_indices = indices[dim].contiguous();
const auto idx_nneg_index = at::arange(index_len, nneg_index.options());
const auto idx_dim_indices = at::arange(nnz, dim_indices.options());

Tensor sorted_dim_indices, argsort_dim_indices;
std::tie(sorted_dim_indices, argsort_dim_indices) =
[&]() -> std::tuple<Tensor, Tensor> {
if (dim == 0 && self.is_coalesced()) {
return std::make_tuple(dim_indices, idx_dim_indices);
} else {
return dim_indices.sort();
}
}();

Tensor intrsc_counts_nneg_index;
Tensor intrsc_first_match_nneg_index;
std::tie(intrsc_counts_nneg_index, intrsc_first_match_nneg_index) =
[&]() -> std::tuple<Tensor, Tensor> {
auto intrsc_counts_nneg_index = at::zeros_like(nneg_index);
auto intrsc_first_match_nneg_index = at::zeros_like(nneg_index);

auto iter = TensorIteratorConfig()
.add_output(intrsc_first_match_nneg_index)
.add_input(nneg_index)
.add_input(idx_nneg_index)
.build();

AT_DISPATCH_INDEX_TYPES(
nneg_index.scalar_type(), "index_select_sparse_xpu", [&]() {
index_t* ptr_intrsc_counts_nneg_index =
intrsc_counts_nneg_index.mutable_data_ptr<index_t>();
const index_t* ptr_sorted_dim_indices =
sorted_dim_indices.const_data_ptr<index_t>();
gpu_kernel(
iter,
IndexSelectSparse2Functor<index_t>(
ptr_intrsc_counts_nneg_index, ptr_sorted_dim_indices, nnz));
});

return std::make_tuple(
intrsc_counts_nneg_index, intrsc_first_match_nneg_index);
}();

// Unavoidable sync since the shape of the result is not known in advance
auto res_len = intrsc_counts_nneg_index.sum().item<int64_t>();
// Short-circuit if empty intersection
if (!res_len) {
auto empty_idx = at::empty({0}, nneg_index.options());
return make_output(empty_idx, empty_idx);
}

auto [selected_dim_indices, res_dim_indices] =
[&]() -> std::tuple<Tensor, Tensor> {
auto res_dim_indices = at::empty({res_len}, nneg_index.options());
auto selected_dim_indices = at::empty_like(res_dim_indices);
auto selected_dim_indices_offsets =
intrsc_counts_nneg_index.cumsum(0).sub_(intrsc_counts_nneg_index);

// Need to have output as TensorIterator does not allow having void
// lambdas.
auto dummy_output = at::empty({1}, dim_indices.options())
.expand(IntArrayRef({index_len}));
auto iter = TensorIteratorConfig()
.add_output(dummy_output)
// All iterations map to a single element in dummy_output
// by design, hence removed output memory overlap check.
.set_check_mem_overlap(false)
.add_input(idx_nneg_index)
.add_input(intrsc_counts_nneg_index)
.add_input(selected_dim_indices_offsets)
.add_input(intrsc_first_match_nneg_index)
.build();

AT_DISPATCH_INDEX_TYPES(
nneg_index.scalar_type(), "index_select_sparse_xpu", [&]() {
index_t* ptr_res_dim_indices =
res_dim_indices.mutable_data_ptr<index_t>();
index_t* ptr_selected_dim_indices =
selected_dim_indices.mutable_data_ptr<index_t>();
const index_t* ptr_argsort_dim_indices =
argsort_dim_indices.const_data_ptr<index_t>();
gpu_kernel(
iter,
IndexSelectSparse3Functor<index_t>(
ptr_res_dim_indices,
ptr_selected_dim_indices,
ptr_argsort_dim_indices));
});

return std::make_tuple(selected_dim_indices, res_dim_indices);
}();

return make_output(selected_dim_indices, res_dim_indices);
}
// If indexing into dense dimensions
else {
// It is sufficient to just perform `index_select` on values
// if `dim` refers to dense dimensions.
const auto res_values = values.index_select(dim - sparse_dim + 1, index);

return at::_sparse_coo_tensor_with_dims_and_tensors(
sparse_dim, dense_dim, res_sizes, indices, res_values, self.options());
}
}
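Worked example of the dim < sparse_dim path with hypothetical values: let self have sizes {4, 2} with sparse_dim = 1, dense_dim = 1, dim_indices = indices[0] = {0, 2, 2} (nnz = 3), and let index = {2, -1} with dim = 0. Normalization gives nneg_index = {2, 3}; the per-query counts are {2, 0} with first matches {1, 3}, so res_len = 2 and the offsets are {0, 2}. The scatter step yields res_dim_indices = {0, 0} and selected_dim_indices = {1, 2}, and make_output gathers columns 1 and 2 of indices (overwriting row dim with {0, 0}) and rows 1 and 2 of values: both non-zeros of input row 2 land in output row 0, while index[1] = -1 (input row 3) contributes nothing.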

} // namespace at::native::xpu

#pragma GCC diagnostic pop
5 changes: 5 additions & 0 deletions src/ATen/native/xpu/sycl/IndexingKernels.h
@@ -65,4 +65,9 @@ TORCH_XPU_API void put_kernel(

TORCH_XPU_API void take_kernel(TensorIterator& iter, const TensorBase& input);

TORCH_XPU_API Tensor index_select_sparse_kernel(
const Tensor& self,
int64_t dim,
const Tensor& index);

} // namespace at::native::xpu
1 change: 1 addition & 0 deletions yaml/native/native_functions.yaml
@@ -796,6 +796,7 @@
variants: method, function
dispatch:
XPU: index_select_xpu_
SparseXPU: index_select_sparse_xpu
tags: core

- func: gcd.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!)