From fefef2cc38bc2c7ddc875395d63137b62e5a4977 Mon Sep 17 00:00:00 2001
From: yucai
Date: Fri, 13 Dec 2024 05:50:48 +0000
Subject: [PATCH 1/2] add index_select_sparse_xpu

---
 src/ATen/native/xpu/Indexing.cpp           |   8 +
 src/ATen/native/xpu/sycl/Indexing.cpp      | 271 +++++++++++++++++++++
 src/ATen/native/xpu/sycl/IndexingKernels.h |   5 +
 yaml/native/native_functions.yaml          |   1 +
 4 files changed, 285 insertions(+)

diff --git a/src/ATen/native/xpu/Indexing.cpp b/src/ATen/native/xpu/Indexing.cpp
index 6ba148607..bb8c07a92 100644
--- a/src/ATen/native/xpu/Indexing.cpp
+++ b/src/ATen/native/xpu/Indexing.cpp
@@ -123,5 +123,13 @@ Tensor& masked_select_out_xpu(
   namedinference::compute_broadcast_outnames(self, mask);
   return masked_select_out_impl(result, self, mask);
 }
+
+Tensor index_select_sparse_xpu(
+    const Tensor& self,
+    int64_t dim,
+    const Tensor& index) {
+  return xpu::index_select_sparse_kernel(self, dim, index);
+}
+
 } // namespace native
 } // namespace at
diff --git a/src/ATen/native/xpu/sycl/Indexing.cpp b/src/ATen/native/xpu/sycl/Indexing.cpp
index bcbd50c42..7c05f8f54 100644
--- a/src/ATen/native/xpu/sycl/Indexing.cpp
+++ b/src/ATen/native/xpu/sycl/Indexing.cpp
@@ -14,6 +14,11 @@
 #include
 #include
 #include
+#include
+#include
+#include
+#include
+#include
 #include
 #include
 
@@ -1146,6 +1151,272 @@ void put_kernel(
       });
 }
 
+// ForwardIt: only legacy random access iterator is supported.
+template <typename ForwardIt, typename T, bool is_lower = true>
+static ForwardIt find_bound(ForwardIt first, ForwardIt last, const T& value) {
+  ForwardIt it;
+  typename std::iterator_traits<ForwardIt>::difference_type count, step;
+  // NOTE: std::distance(first, last) compiles but produces wrong results here,
+  // so only legacy random access iterators are safe in this code.
+  count = last - first;
+
+  while (count > 0) {
+    it = first;
+    step = count / 2;
+    // avoiding std::advance(it, step),
+    // although it does work unlike std::distance
+    it += step;
+    if (is_lower ? *it < value : value >= *it) {
+      first = ++it;
+      count -= step + 1;
+    } else {
+      count = step;
+    }
+  }
+  return first;
+}
+
+template <typename index_t>
+struct IndexSelectSparse1Functor {
+  index_t operator()(index_t idx) const {
+    SYCL_KERNEL_ASSERT(
+        idx >= -size && idx < size && "index_select(): index out of bounds");
+    return idx < 0 ? idx + size : idx;
+  }
+  IndexSelectSparse1Functor(index_t size) : size(size) {}
+
+ private:
+  index_t size;
+};
+
+template <typename index_t>
+struct IndexSelectSparse2Functor {
+  index_t operator()(index_t idx_val, index_t idx_idx) const {
+    auto* lb = find_bound<const index_t*, index_t, true>(
+        ptr_sorted_dim_indices, ptr_sorted_dim_indices + nnz, idx_val);
+    auto* ub = find_bound<const index_t*, index_t, false>(
+        ptr_sorted_dim_indices, ptr_sorted_dim_indices + nnz, idx_val);
+    const auto idx_count = ub - lb;
+    ptr_intrsc_counts_nneg_index[idx_idx] = idx_count;
+
+    return lb - ptr_sorted_dim_indices;
+  }
+
+  IndexSelectSparse2Functor(
+      index_t* ptr_intrsc_counts_nneg_index,
+      const index_t* ptr_sorted_dim_indices,
+      int64_t nnz)
+      : ptr_intrsc_counts_nneg_index(ptr_intrsc_counts_nneg_index),
+        ptr_sorted_dim_indices(ptr_sorted_dim_indices),
+        nnz(nnz) {}
+
+ private:
+  index_t* ptr_intrsc_counts_nneg_index;
+  const index_t* ptr_sorted_dim_indices;
+  int64_t nnz;
+};
+
+template <typename index_t>
+struct IndexSelectSparse3Functor {
+  index_t operator()(
+      index_t idx_idx,
+      index_t count,
+      index_t offset,
+      index_t first_match) const {
+    index_t* __restrict__ ptr_res_dim_indices_out =
+        ptr_res_dim_indices + offset;
+    const index_t* __restrict__ ptr_argsort_dim_indices_in =
+        ptr_argsort_dim_indices + first_match;
+    index_t* __restrict__ ptr_selected_dim_indices_out =
+        ptr_selected_dim_indices + offset;
+    for (index_t i = 0; i < count; ++i) {
+      *ptr_res_dim_indices_out++ = idx_idx;
+      *ptr_selected_dim_indices_out++ = *ptr_argsort_dim_indices_in++;
+    }
+    // A dummy return scalar for a dummy output
+    return static_cast<index_t>(1);
+  }
+  IndexSelectSparse3Functor(
+      index_t* ptr_res_dim_indices,
+      index_t* ptr_selected_dim_indices,
+      const index_t* ptr_argsort_dim_indices)
+      : ptr_res_dim_indices(ptr_res_dim_indices),
+        ptr_selected_dim_indices(ptr_selected_dim_indices),
+        ptr_argsort_dim_indices(ptr_argsort_dim_indices) {}
+
+ private:
+  index_t* ptr_res_dim_indices;
+  index_t* ptr_selected_dim_indices;
+  const index_t* ptr_argsort_dim_indices;
+};
+
+Tensor index_select_sparse_kernel(
+    const Tensor& self,
+    int64_t dim,
+    const Tensor& index) {
+  const auto ndim = self.dim();
+  TORCH_CHECK_INDEX(
+      ndim, "index_select() cannot be applied to a 0-dim tensor.");
+  TORCH_CHECK_INDEX(
+      index.dim() == 1 && index.dtype() == at::kLong &&
+          index.options().layout() == at::kStrided,
+      "index_select() argument index must be 1-D strided (non-sparse) long-tensor.");
+
+  dim = maybe_wrap_dim(dim, ndim);
+  const auto size = self.size(dim);
+  const auto sparse_dim = self.sparse_dim();
+  const auto dense_dim = self.dense_dim();
+  const auto indices = self._indices();
+  const auto values = self._values();
+  const auto nnz = values.size(0);
+  const auto index_len = index.size(0);
+  auto res_sizes = self.sizes().vec();
+  res_sizes[dim] = index_len;
+
+  // If indexing into sparse dimensions
+  if (dim < sparse_dim) {
+    const auto make_output =
+        [dim, sparse_dim, dense_dim, res_sizes, &self, &indices, &values](
+            const Tensor& selected_dim_indices,
+            const Tensor& res_dim_indices) -> Tensor {
+      auto res_indices = indices.index_select(1, selected_dim_indices);
+      res_indices[dim] = res_dim_indices;
+      const auto res_values = values.index_select(0, selected_dim_indices);
+
+      return at::_sparse_coo_tensor_with_dims_and_tensors(
+          sparse_dim,
+          dense_dim,
+          res_sizes,
+          res_indices,
+          res_values,
+          self.options());
+    };
+
+    // short-circuit if index is empty
+    if (!index_len) {
+      return make_output(index, index);
+    }
+
+    const auto nneg_index = [&index, size]() -> Tensor {
+      auto nneg_index = at::empty_like(index, at::MemoryFormat::Contiguous);
+
+      auto iter = TensorIteratorConfig()
+                      .add_output(nneg_index)
+                      .add_input(index)
+                      .build();
+
+      AT_DISPATCH_INDEX_TYPES(
+          index.scalar_type(), "index_select_sparse_xpu", [&]() {
+            gpu_kernel(iter, IndexSelectSparse1Functor<index_t>(size));
+          });
+      return nneg_index;
+    }();
+
+    const auto dim_indices = indices[dim].contiguous();
+    const auto idx_nneg_index = at::arange(index_len, nneg_index.options());
+    const auto idx_dim_indices = at::arange(nnz, dim_indices.options());
+
+    Tensor sorted_dim_indices, argsort_dim_indices;
+    std::tie(sorted_dim_indices, argsort_dim_indices) =
+        [&]() -> std::tuple<Tensor, Tensor> {
+      if (dim == 0 && self.is_coalesced()) {
+        return std::make_tuple(dim_indices, idx_dim_indices);
+      } else {
+        return dim_indices.sort();
+      }
+    }();
+
+    Tensor intrsc_counts_nneg_index;
+    Tensor intrsc_first_match_nneg_index;
+    std::tie(intrsc_counts_nneg_index, intrsc_first_match_nneg_index) =
+        [&]() -> std::tuple<Tensor, Tensor> {
+      auto intrsc_counts_nneg_index = at::zeros_like(nneg_index);
+      auto intrsc_first_match_nneg_index = at::zeros_like(nneg_index);
+
+      auto iter = TensorIteratorConfig()
+                      .add_output(intrsc_first_match_nneg_index)
+                      .add_input(nneg_index)
+                      .add_input(idx_nneg_index)
+                      .build();
+
+      AT_DISPATCH_INDEX_TYPES(
+          nneg_index.scalar_type(), "index_select_sparse_xpu", [&]() {
+            index_t* ptr_intrsc_counts_nneg_index =
+                intrsc_counts_nneg_index.mutable_data_ptr<index_t>();
+            const index_t* ptr_sorted_dim_indices =
+                sorted_dim_indices.const_data_ptr<index_t>();
+            gpu_kernel(
+                iter,
+                IndexSelectSparse2Functor<index_t>(
+                    ptr_intrsc_counts_nneg_index, ptr_sorted_dim_indices, nnz));
+          });
+
+      return std::make_tuple(
+          intrsc_counts_nneg_index, intrsc_first_match_nneg_index);
+    }();
+
+    // Unavoidable sync since the shape of the result is not known in advance
+    auto res_len = intrsc_counts_nneg_index.sum().item<int64_t>();
+    // Short-circuit if empty intersection
+    if (!res_len) {
+      auto empty_idx = at::empty({0}, nneg_index.options());
+      return make_output(empty_idx, empty_idx);
+    }
+
+    auto [selected_dim_indices, res_dim_indices] =
+        [&]() -> std::tuple<Tensor, Tensor> {
+      auto res_dim_indices = at::empty({res_len}, nneg_index.options());
+      auto selected_dim_indices = at::empty_like(res_dim_indices);
+      auto selected_dim_indices_offsets =
+          intrsc_counts_nneg_index.cumsum(0).sub_(intrsc_counts_nneg_index);
+
+      // Need to have output as TensorIterator does not allow having void
+      // lambdas.
+      auto dummy_output = at::empty({1}, dim_indices.options())
+                              .expand(IntArrayRef({index_len}));
+      auto iter = TensorIteratorConfig()
+                      .add_output(dummy_output)
+                      // All iterations map to a single element in dummy_output
+                      // by design, hence removed output memory overlap check.
+                      .set_check_mem_overlap(false)
+                      .add_input(idx_nneg_index)
+                      .add_input(intrsc_counts_nneg_index)
+                      .add_input(selected_dim_indices_offsets)
+                      .add_input(intrsc_first_match_nneg_index)
+                      .build();
+
+      AT_DISPATCH_INDEX_TYPES(
+          nneg_index.scalar_type(), "index_select_sparse_xpu", [&]() {
+            index_t* ptr_res_dim_indices =
+                res_dim_indices.mutable_data_ptr<index_t>();
+            index_t* ptr_selected_dim_indices =
+                selected_dim_indices.mutable_data_ptr<index_t>();
+            const index_t* ptr_argsort_dim_indices =
+                argsort_dim_indices.const_data_ptr<index_t>();
+            gpu_kernel(
+                iter,
+                IndexSelectSparse3Functor<index_t>(
+                    ptr_res_dim_indices,
+                    ptr_selected_dim_indices,
+                    ptr_argsort_dim_indices));
+          });
+
+      return std::make_tuple(selected_dim_indices, res_dim_indices);
+    }();
+
+    return make_output(selected_dim_indices, res_dim_indices);
+  }
+  // If indexing into dense dimensions
+  else {
+    // It is sufficient to just perform `index_select` on values
+    // if `dim` refers to dense dimensions.
+    const auto res_values = values.index_select(dim - sparse_dim + 1, index);
+
+    return _sparse_coo_tensor_with_dims_and_tensors(
+        sparse_dim, dense_dim, res_sizes, indices, res_values, self.options());
+  }
+}
+
 } // namespace at::native::xpu
 
 #pragma GCC diagnostic pop
diff --git a/src/ATen/native/xpu/sycl/IndexingKernels.h b/src/ATen/native/xpu/sycl/IndexingKernels.h
index 87deedaa5..e5f434585 100644
--- a/src/ATen/native/xpu/sycl/IndexingKernels.h
+++ b/src/ATen/native/xpu/sycl/IndexingKernels.h
@@ -65,4 +65,9 @@ TORCH_XPU_API void put_kernel(
 
 TORCH_XPU_API void take_kernel(TensorIterator& iter, const TensorBase& input);
 
+TORCH_XPU_API Tensor index_select_sparse_kernel(
+    const Tensor& self,
+    int64_t dim,
+    const Tensor& index);
+
 } // namespace at::native::xpu
diff --git a/yaml/native/native_functions.yaml b/yaml/native/native_functions.yaml
index f76f49fb8..9b5386db0 100644
--- a/yaml/native/native_functions.yaml
+++ b/yaml/native/native_functions.yaml
@@ -796,6 +796,7 @@
   variants: method, function
   dispatch:
     XPU: index_select_xpu_
+    SparseCPU: index_select_sparse_xpu
   tags: core
 
 - func: gcd.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!)

From 26766cdfb811dcf1614c37548020952f9f49c5c4 Mon Sep 17 00:00:00 2001
From: yucai
Date: Fri, 13 Dec 2024 05:53:03 +0000
Subject: [PATCH 2/2] fix

---
 yaml/native/native_functions.yaml | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/yaml/native/native_functions.yaml b/yaml/native/native_functions.yaml
index 9b5386db0..9aff1c8cc 100644
--- a/yaml/native/native_functions.yaml
+++ b/yaml/native/native_functions.yaml
@@ -796,7 +796,7 @@
   variants: method, function
   dispatch:
     XPU: index_select_xpu_
-    SparseCPU: index_select_sparse_xpu
+    SparseXPU: index_select_sparse_xpu
  tags: core
 
 - func: gcd.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!)
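
Note on the two commits (reviewer remark, not part of the patches): the follow-up fix is needed because the kernel must be registered under the SparseXPU dispatch key; registering it under SparseCPU would route CPU sparse tensors to the XPU kernel and leave sparse XPU tensors without an implementation. With both patches applied, a smoke test along the following lines could exercise the new entry. This is only a sketch, assuming an XPU-enabled ATen build with sparse COO support on the at::kXPU device; the shapes, index values, and the main() wrapper are illustrative and not taken from the patch.

#include <ATen/ATen.h>
#include <vector>

int main() {
  // Dense reference data on the XPU device; dim 0 becomes a sparse dimension.
  auto dense = at::randn({4, 3}, at::device(at::kXPU));
  auto self = dense.to_sparse();

  // 1-D strided long index, including a repeated entry.
  auto index = at::tensor(std::vector<int64_t>{0, 2, 2},
                          at::dtype(at::kLong).device(at::kXPU));

  // Dispatches to index_select_sparse_xpu via the SparseXPU key.
  auto out = at::index_select(self, /*dim=*/0, index);

  // Cross-check the sparse result against the dense path.
  bool ok = out.to_dense().allclose(dense.index_select(0, index));
  return ok ? 0 : 1;
}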