From 4dffbc070b3a1a51b3ff2f48df4b1ce102bf2121 Mon Sep 17 00:00:00 2001
From: Feng Yuan
Date: Tue, 9 Jul 2024 22:06:11 +0800
Subject: [PATCH] Add aten::isin and its variant operators
 Tensor_Tensor/Scalar_Tensor/Tensor_Scalar

Signed-off-by: Feng Yuan
---
 src/ATen/native/xpu/TensorCompare.cpp        | 232 +++++++++++++++++-
 src/ATen/native/xpu/XPUFallback.template     |   1 -
 src/ATen/native/xpu/sycl/IndexingUtils.h     |   2 +-
 .../native/xpu/sycl/TensorCompareKernels.cpp |  12 +
 .../native/xpu/sycl/TensorCompareKernels.h   |  14 +-
 test/xpu/run_test_with_skip.py               |  10 -
 test/xpu/xpu_test_utils.py                   |   1 +
 yaml/xpu_functions.yaml                      |   6 +
 8 files changed, 259 insertions(+), 19 deletions(-)

diff --git a/src/ATen/native/xpu/TensorCompare.cpp b/src/ATen/native/xpu/TensorCompare.cpp
index 680b40c87..7b8ac819a 100644
--- a/src/ATen/native/xpu/TensorCompare.cpp
+++ b/src/ATen/native/xpu/TensorCompare.cpp
@@ -552,7 +552,7 @@ ::std::tuple<Tensor&, Tensor&> XPUNativeFunctions::min_out(
   return {values, indices};
 }
 
-::std::tuple<Tensor&, Tensor&> XPUNativeFunctions::max_out(
+std::tuple<Tensor&, Tensor&> XPUNativeFunctions::max_out(
     const Tensor& self,
     int64_t dim,
     bool keepdim,
@@ -568,4 +568,234 @@ ::std::tuple<Tensor&, Tensor&> XPUNativeFunctions::max_out(
   return {values, indices};
 }
 
+static inline void check_for_unsupported_isin_dtype(const ScalarType type) {
+  // Bail out for dtypes unsupported by the sorting algorithm to keep the
+  // interface consistent.
+  TORCH_CHECK(
+      type != ScalarType::Bool && type != ScalarType::BFloat16 &&
+          type != ScalarType::ComplexFloat && type != ScalarType::ComplexDouble,
+      "Unsupported input type encountered for isin(): ",
+      type);
+}
+
+// Sorting-based algorithm for isin(); used when the number of test elements is
+// large.
+static void isin_sorting(
+    const Tensor& elements,
+    const Tensor& test_elements,
+    bool assume_unique,
+    bool invert,
+    const Tensor& out) {
+  // 1. Concatenate unique elements with unique test elements in 1D form. If
+  //    assume_unique is true, skip calls to unique().
+  Tensor elements_flat, test_elements_flat, unique_order;
+  if (assume_unique) {
+    elements_flat = elements.ravel();
+    test_elements_flat = test_elements.ravel();
+  } else {
+    std::tie(elements_flat, unique_order) =
+        at::_unique(elements, /*sorted=*/false, /*return_inverse=*/true);
+    std::tie(test_elements_flat, std::ignore) =
+        at::_unique(test_elements, /*sorted=*/false);
+  }
+
+  // 2. Stable sort all elements, maintaining order indices to reverse the
+  //    operation. Stable sort is necessary to keep elements before test
+  //    elements within the sorted list.
+  Tensor all_elements =
+      at::cat({std::move(elements_flat), std::move(test_elements_flat)});
+  auto [sorted_elements, sorted_order] = all_elements.sort(
+      /*stable=*/true, /*dim=*/0, /*descending=*/false);
+
+  // 3. Create a mask for locations of adjacent duplicate values within the
+  //    sorted list. Duplicate values are in both elements and test elements.
+  Tensor duplicate_mask =
+      at::empty_like(sorted_elements, TensorOptions(ScalarType::Bool));
+  Tensor sorted_except_first = sorted_elements.slice(0, 1, at::indexing::None);
+  Tensor sorted_except_last = sorted_elements.slice(0, 0, -1);
+  duplicate_mask.slice(0, 0, -1).copy_(
+      invert ? sorted_except_first.ne(sorted_except_last)
+             : sorted_except_first.eq(sorted_except_last));
+  duplicate_mask.index_put_({-1}, invert);
+
+  // 4. Reorder the mask to match the pre-sorted element order.
+  Tensor mask = at::empty_like(duplicate_mask);
+  mask.index_copy_(0, sorted_order, duplicate_mask);
+
+  // 5. Index the mask to match the pre-unique element order. If
+  //    assume_unique is true, just take the first N items of the mask,
+  //    where N is the original number of elements.
+  if (assume_unique) {
+    out.copy_(mask.slice(0, 0, elements.numel()).view_as(out));
+  } else {
+    out.copy_(at::index(mask, {std::optional<Tensor>(unique_order)}));
+  }
+}
+
+void isin_Tensor_Tensor_meta(
+    const Tensor& elements,
+    Tensor test_elements,
+    bool assume_unique,
+    bool invert,
+    Tensor& out) {
+  check_for_unsupported_isin_dtype(elements.scalar_type());
+  check_for_unsupported_isin_dtype(test_elements.scalar_type());
+  auto output_options =
+      TensorOptions(elements.device()).dtype(ScalarType::Bool);
+  if (out.defined()) {
+    xpu::resize_out(out, elements.sizes(), {}, output_options);
+  } else {
+    out = xpu::create_out(elements.sizes(), {}, output_options);
+  }
+}
+
+void isin_Tensor_Tensor_impl(
+    const Tensor& elements,
+    Tensor test_elements,
+    bool assume_unique,
+    bool invert,
+    const Tensor& out) {
+  if (elements.numel() == 0) {
+    return;
+  }
+
+  // Heuristic taken from numpy's implementation.
+  if (test_elements.numel() <
+      static_cast<int64_t>(
+          10.0f * std::pow(static_cast<double>(elements.numel()), 0.145))) {
+    out.fill_(invert);
+    native::xpu::isin_kernel(elements, test_elements, invert, out);
+  } else {
+    isin_sorting(elements, test_elements, assume_unique, invert, out);
+  }
+}
+
+Tensor& XPUNativeFunctions::isin_out(
+    const Tensor& elements,
+    const Tensor& test_elements,
+    bool assume_unique,
+    bool invert,
+    Tensor& out) {
+  isin_Tensor_Tensor_meta(elements, test_elements, assume_unique, invert, out);
+  isin_Tensor_Tensor_impl(elements, test_elements, assume_unique, invert, out);
+  return out;
+}
+
+Tensor XPUNativeFunctions::isin(
+    const Tensor& elements,
+    const Tensor& test_elements,
+    bool assume_unique,
+    bool invert) {
+  Tensor out;
+  isin_Tensor_Tensor_meta(elements, test_elements, assume_unique, invert, out);
+  isin_Tensor_Tensor_impl(elements, test_elements, assume_unique, invert, out);
+  return out;
+}
+
+void isin_Tensor_Scalar_meta(
+    const Tensor& elements,
+    const Scalar& test_elements,
+    bool assume_unique,
+    bool invert,
+    Tensor& out) {
+  check_for_unsupported_isin_dtype(elements.scalar_type());
+  check_for_unsupported_isin_dtype(test_elements.type());
+  auto output_options =
+      TensorOptions(elements.device()).dtype(ScalarType::Bool);
+  if (out.defined()) {
+    xpu::resize_out(out, elements.sizes(), {}, output_options);
+  } else {
+    out = xpu::create_out(elements.sizes(), {}, output_options);
+  }
+}
+
+void isin_Tensor_Scalar_impl(
+    const Tensor& elements,
+    const Scalar& test_elements,
+    bool assume_unique,
+    bool invert,
+    const Tensor& out) {
+  if (invert) {
+    at::ne_out(const_cast<Tensor&>(out), elements, test_elements);
+  } else {
+    at::eq_out(const_cast<Tensor&>(out), elements, test_elements);
+  }
+}
+
+Tensor& XPUNativeFunctions::isin_out(
+    const Tensor& elements,
+    const Scalar& test_elements,
+    bool assume_unique,
+    bool invert,
+    Tensor& out) {
+  isin_Tensor_Scalar_meta(elements, test_elements, assume_unique, invert, out);
+  isin_Tensor_Scalar_impl(elements, test_elements, assume_unique, invert, out);
+  return out;
+}
+
+Tensor XPUNativeFunctions::isin(
+    const Tensor& elements,
+    const Scalar& test_elements,
+    bool assume_unique,
+    bool invert) {
+  Tensor out;
+  isin_Tensor_Scalar_meta(elements, test_elements, assume_unique, invert, out);
+  isin_Tensor_Scalar_impl(elements, test_elements, assume_unique, invert, out);
+  return out;
+}
+
+void isin_Scalar_Tensor_meta(
+    const Scalar& elements,
+    const Tensor& test_elements,
+    bool assume_unique,
+    bool invert,
+    Tensor& out) {
+  check_for_unsupported_isin_dtype(elements.type());
+  check_for_unsupported_isin_dtype(test_elements.scalar_type());
+  auto output_options =
+      TensorOptions(test_elements.device()).dtype(ScalarType::Bool);
+  if (out.defined()) {
+    xpu::resize_out(out, {0}, {}, output_options);
+  } else {
+    out = xpu::create_out({0}, {}, output_options);
+  }
+}
+
+void isin_Scalar_Tensor_impl(
+    const Scalar& elements,
+    const Tensor& test_elements,
+    bool assume_unique,
+    bool invert,
+    const Tensor& out) {
+  // redispatch
+  at::isin_out(
+      const_cast<Tensor&>(out),
+      at::native::wrapped_scalar_tensor(elements, test_elements.device()),
+      test_elements,
+      assume_unique,
+      invert);
+}
+
+Tensor& XPUNativeFunctions::isin_out(
+    const Scalar& elements,
+    const Tensor& test_elements,
+    bool assume_unique,
+    bool invert,
+    Tensor& out) {
+  isin_Scalar_Tensor_meta(elements, test_elements, assume_unique, invert, out);
+  isin_Scalar_Tensor_impl(elements, test_elements, assume_unique, invert, out);
+  return out;
+}
+
+Tensor XPUNativeFunctions::isin(
+    const Scalar& elements,
+    const Tensor& test_elements,
+    bool assume_unique,
+    bool invert) {
+  Tensor out;
+  isin_Scalar_Tensor_meta(elements, test_elements, assume_unique, invert, out);
+  isin_Scalar_Tensor_impl(elements, test_elements, assume_unique, invert, out);
+  return out;
+}
+
 } // namespace at
diff --git a/src/ATen/native/xpu/XPUFallback.template b/src/ATen/native/xpu/XPUFallback.template
index fa7fafe13..0d96e5289 100644
--- a/src/ATen/native/xpu/XPUFallback.template
+++ b/src/ATen/native/xpu/XPUFallback.template
@@ -231,7 +231,6 @@ TORCH_LIBRARY_IMPL(aten, XPU, m) {
       "index_fill_.int_Scalar",
       "index_fill_.int_Tensor",
       "index_reduce.out",
-      "isin.Tensor_Tensor_out",
       "isneginf.out",
       "isposinf.out",
       "kthvalue.values",
diff --git a/src/ATen/native/xpu/sycl/IndexingUtils.h b/src/ATen/native/xpu/sycl/IndexingUtils.h
index 1c6d9c373..26eb2f1ea 100644
--- a/src/ATen/native/xpu/sycl/IndexingUtils.h
+++ b/src/ATen/native/xpu/sycl/IndexingUtils.h
@@ -99,7 +99,7 @@ static std::tuple<Tensor, int64_t, int64_t, int64_t> computeLinearIndex(
 static std::
     tuple<Tensor, Tensor, int64_t, int64_t, int64_t, std::vector<int64_t>>
     makeLinearIndex(Tensor self, IOptTensorListRef orig, bool check_range) {
-  checkIndexTensorTypes(orig, /*allow_int*/ true);
+  checkIndexTensorTypes(orig);
   // first expand BoolTensor (masks) or ByteTensor (masks) into 1 or more
   // LongTensors
   auto indices = expandTensors(self, orig);
diff --git a/src/ATen/native/xpu/sycl/TensorCompareKernels.cpp b/src/ATen/native/xpu/sycl/TensorCompareKernels.cpp
index 41e09a51d..8159db75f 100644
--- a/src/ATen/native/xpu/sycl/TensorCompareKernels.cpp
+++ b/src/ATen/native/xpu/sycl/TensorCompareKernels.cpp
@@ -105,6 +105,18 @@ void clamp_max_scalar_kernel(TensorIteratorBase& iter, Scalar max) {
   launch_clamp_scalar(iter, max, max, at::native::detail::ClampLimits::Max);
 }
 
+void isin_kernel(
+    const Tensor& elements,
+    const Tensor& test_elements,
+    bool invert,
+    const Tensor& out) {
+  std::vector<int64_t> bc_shape(elements.dim(), 1);
+  bc_shape.push_back(-1);
+  out.copy_(
+      invert ? elements.unsqueeze(-1).ne(test_elements.view(bc_shape)).all(-1)
+             : elements.unsqueeze(-1).eq(test_elements.view(bc_shape)).any(-1));
+}
+
 } // namespace xpu
 } // namespace native
 } // namespace at
diff --git a/src/ATen/native/xpu/sycl/TensorCompareKernels.h b/src/ATen/native/xpu/sycl/TensorCompareKernels.h
index ae44ae347..51cb74a40 100644
--- a/src/ATen/native/xpu/sycl/TensorCompareKernels.h
+++ b/src/ATen/native/xpu/sycl/TensorCompareKernels.h
@@ -2,9 +2,7 @@
 
 #include
 
-namespace at {
-namespace native {
-namespace xpu {
+namespace at::native::xpu {
 
 void where_kernel(TensorIterator& iter);
 
@@ -19,6 +17,10 @@ void clamp_min_scalar_kernel(TensorIteratorBase& iter, Scalar min);
 
 void clamp_max_scalar_kernel(TensorIteratorBase& iter, Scalar max);
 
-} // namespace xpu
-} // namespace native
-} // namespace at
+void isin_kernel(
+    const Tensor& elements,
+    const Tensor& test_elements,
+    bool invert,
+    const Tensor& out);
+
+} // namespace at::native::xpu
diff --git a/test/xpu/run_test_with_skip.py b/test/xpu/run_test_with_skip.py
index ea50fbc29..9a2dffe01 100644
--- a/test/xpu/run_test_with_skip.py
+++ b/test/xpu/run_test_with_skip.py
@@ -856,16 +856,6 @@ def launch_test(test_case, skip_list=None, exe_list=None):
 
 skip_list = (
-    # The following isin case fails on CPU fallback, as it could be backend-specific.
-    "test_isin_xpu_float16",  # RuntimeError: "isin_default_cpu" not implemented for 'Half'
-    "test_isin_different_devices_xpu_float32",  # AssertionError: RuntimeError not raised
-    "test_isin_different_devices_xpu_float64",  # AssertionError: RuntimeError not raised
-    "test_isin_different_devices_xpu_int16",  # AssertionError: RuntimeError not raised
-    "test_isin_different_devices_xpu_int32",  # AssertionError: RuntimeError not raised
-    "test_isin_different_devices_xpu_int64",  # AssertionError: RuntimeError not raised
-    "test_isin_different_devices_xpu_int8",  # AssertionError: RuntimeError not raised
-    "test_isin_different_devices_xpu_uint8",  # AssertionError: RuntimeError not raised
-    "test_isin_different_dtypes_xpu",  # RuntimeError: "isin_default_cpu" not implemented for 'Half'"
     "test_sort_large_slice_xpu",  # Hard code CUDA
 )
 res += launch_test("test_sort_and_select_xpu.py", skip_list)
 
diff --git a/test/xpu/xpu_test_utils.py b/test/xpu/xpu_test_utils.py
index 35c29d96b..c5f9ed35d 100644
--- a/test/xpu/xpu_test_utils.py
+++ b/test/xpu/xpu_test_utils.py
@@ -64,6 +64,7 @@
     "index_add",
     "index_put",
     "index_select",
+    "isin",
     "isnan",
     "le",
     "log",
diff --git a/yaml/xpu_functions.yaml b/yaml/xpu_functions.yaml
index 2ecc6790b..768c78c20 100644
--- a/yaml/xpu_functions.yaml
+++ b/yaml/xpu_functions.yaml
@@ -87,6 +87,12 @@ supported:
  - ge_.Tensor
  - cat
  - cat.out
+ - isin.Tensor_Tensor_out
+ - isin.Tensor_Tensor
+ - isin.Tensor_Scalar_out
+ - isin.Tensor_Scalar
+ - isin.Scalar_Tensor_out
+ - isin.Scalar_Tensor
  - isnan
  - isnan.out
  - masked_fill_.Tensor
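
Notes (illustrative only, not part of the patch):

The six schemas added to yaml/xpu_functions.yaml above map to the three Python-level call
patterns of torch.isin. The snippet below is a usage sketch, not something the patch adds;
it assumes a PyTorch build that exposes torch.xpu and falls back to CPU when no XPU device
is present.

import torch

# Hedged device selection: only an assumption about the local setup.
dev = "xpu" if hasattr(torch, "xpu") and torch.xpu.is_available() else "cpu"

x = torch.tensor([1, 2, 3, 4], device=dev)
t = torch.tensor([2, 4], device=dev)

# isin.Tensor_Tensor / isin.Tensor_Tensor_out
print(torch.isin(x, t))               # -> False, True, False, True
print(torch.isin(x, t, invert=True))  # -> True, False, True, False

# isin.Tensor_Scalar / isin.Tensor_Scalar_out: a scalar test element reduces to eq/ne
print(torch.isin(x, 3))               # -> False, False, True, False

# isin.Scalar_Tensor / isin.Scalar_Tensor_out: the scalar is wrapped into a 0-dim
# tensor and redispatched to the Tensor_Tensor overload
print(torch.isin(3, t))               # -> False (3 is not in t)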
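
isin_Tensor_Tensor_impl picks between the two code paths with the heuristic quoted from
numpy: the O(N*M) broadcast kernel is used only while the number of test elements stays
below 10 * N**0.145. A small, self-contained illustration of that threshold (the helper
name and the sample sizes are made up):

import math

def uses_bruteforce_path(num_elements: int, num_test_elements: int) -> bool:
    # Same expression as in isin_Tensor_Tensor_impl: prefer the broadcast kernel
    # only while the test set is tiny relative to num_elements**0.145.
    return num_test_elements < int(10.0 * math.pow(num_elements, 0.145))

print(uses_bruteforce_path(1_000_000, 50))   # True  -> isin_kernel (broadcast compare)
print(uses_bruteforce_path(1_000_000, 500))  # False -> isin_sorting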
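
The new isin_kernel in TensorCompareKernels.cpp implements the small-test-set path as a
broadcast comparison followed by a reduction over the test dimension. A rough Python
equivalent of that kernel, for illustration only (the function name is hypothetical):

import torch

def isin_bruteforce(elements: torch.Tensor, test_elements: torch.Tensor,
                    invert: bool = False) -> torch.Tensor:
    # Compare every element with every test element via broadcasting:
    # shape (*elements.shape, 1) against (1, ..., 1, test_elements.numel()).
    bc_shape = [1] * elements.dim() + [-1]
    tests = test_elements.reshape(bc_shape)
    if invert:
        # "not in": the element differs from all test elements.
        return elements.unsqueeze(-1).ne(tests).all(-1)
    # "in": the element equals at least one test element.
    return elements.unsqueeze(-1).eq(tests).any(-1)

x = torch.tensor([[1, 2], [3, 4]])
t = torch.tensor([2, 4, 9])
assert torch.equal(isin_bruteforce(x, t), torch.isin(x, t))
assert torch.equal(isin_bruteforce(x, t, invert=True), torch.isin(x, t, invert=True))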
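
The five numbered steps in isin_sorting can be traced with ordinary tensor ops. The sketch
below mirrors the non-assume_unique branch in Python; it is an illustration under the
assumption that torch.unique, sort(stable=True), and index_copy_ behave like the ATen calls
used in the patch, not a drop-in replacement (the function name is made up):

import torch

def isin_sorting_sketch(elements: torch.Tensor, test_elements: torch.Tensor,
                        invert: bool = False) -> torch.Tensor:
    # Step 1: deduplicate and flatten (skipped in the patch when assume_unique=True).
    elems, unique_order = torch.unique(elements, return_inverse=True)
    tests = torch.unique(test_elements)
    # Step 2: stable sort of the concatenation, so an element that also appears in
    # test_elements ends up adjacent to, and before, its test duplicate.
    all_vals = torch.cat([elems.ravel(), tests.ravel()])
    sorted_vals, sorted_order = all_vals.sort(stable=True)
    # Step 3: mark positions whose right neighbour holds the same value.
    dup = torch.empty_like(sorted_vals, dtype=torch.bool)
    dup[:-1] = sorted_vals[1:].ne(sorted_vals[:-1]) if invert \
        else sorted_vals[1:].eq(sorted_vals[:-1])
    dup[-1] = invert
    # Step 4: undo the sort.
    mask = torch.empty_like(dup)
    mask.index_copy_(0, sorted_order, dup)
    # Step 5: undo the unique() by indexing with the inverse mapping.
    return mask[unique_order].view_as(elements)

x = torch.tensor([5, 1, 5, 3])
t = torch.tensor([3, 5])
assert torch.equal(isin_sorting_sketch(x, t), torch.isin(x, t))
assert torch.equal(isin_sorting_sketch(x, t, invert=True), torch.isin(x, t, invert=True))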