Add aten::isin and its variant operators
Variants covered: Tensor_Tensor, Scalar_Tensor, and Tensor_Scalar.

Signed-off-by: Feng Yuan <[email protected]>
fengyuan14 committed Jul 9, 2024
1 parent 9df4ab4 commit 4dffbc0
Showing 8 changed files with 259 additions and 19 deletions.
232 changes: 231 additions & 1 deletion src/ATen/native/xpu/TensorCompare.cpp
@@ -552,7 +552,7 @@ ::std::tuple<Tensor&, Tensor&> XPUNativeFunctions::min_out(
return {values, indices};
}

-::std::tuple<Tensor&, Tensor&> XPUNativeFunctions::max_out(
+std::tuple<Tensor&, Tensor&> XPUNativeFunctions::max_out(
const Tensor& self,
int64_t dim,
bool keepdim,
@@ -568,4 +568,234 @@ ::std::tuple<Tensor&, Tensor&> XPUNativeFunctions::max_out(
return {values, indices};
}

static inline void check_for_unsupported_isin_dtype(const ScalarType type) {
// Bail out for dtypes unsupported by the sorting algorithm to keep the
// interface consistent.
TORCH_CHECK(
type != ScalarType::Bool && type != ScalarType::BFloat16 &&
type != ScalarType::ComplexFloat && type != ScalarType::ComplexDouble,
"Unsupported input type encountered for isin(): ",
type);
}
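
Both operands are validated up front so every isin variant rejects the same dtypes. As a rough standalone illustration of the failure mode (hypothetical snippet, assuming a libtorch build; not part of this commit):

    #include <ATen/ATen.h>

    int main() {
      at::Tensor elements = at::tensor({1, 2, 3});
      at::Tensor bad = at::ones({2}, at::kBool);  // Bool inputs are rejected
      // at::isin(elements, bad);  // raises: Unsupported input type encountered for isin(): Bool
      return 0;
    }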

// Sorting-based algorithm for isin(); used when the number of test elements is
// large.
static void isin_sorting(
const Tensor& elements,
const Tensor& test_elements,
bool assume_unique,
bool invert,
const Tensor& out) {
// 1. Concatenate unique elements with unique test elements in 1D form. If
// assume_unique is true, skip calls to unique().
Tensor elements_flat, test_elements_flat, unique_order;
if (assume_unique) {
elements_flat = elements.ravel();
test_elements_flat = test_elements.ravel();
} else {
std::tie(elements_flat, unique_order) =
at::_unique(elements, /*sorted=*/false, /*return_inverse=*/true);
std::tie(test_elements_flat, std::ignore) =
at::_unique(test_elements, /*sorted=*/false);
}

// 2. Stable sort all elements, maintaining order indices to reverse the
// operation. Stable sort is necessary to keep elements before test
// elements within the sorted list.
Tensor all_elements =
at::cat({std::move(elements_flat), std::move(test_elements_flat)});
auto [sorted_elements, sorted_order] = all_elements.sort(
/*stable=*/true, /*dim=*/0, /*descending=*/false);

// 3. Create a mask for locations of adjacent duplicate values within the
// sorted list. Duplicate values are in both elements and test elements.
Tensor duplicate_mask =
at::empty_like(sorted_elements, TensorOptions(ScalarType::Bool));
Tensor sorted_except_first = sorted_elements.slice(0, 1, at::indexing::None);
Tensor sorted_except_last = sorted_elements.slice(0, 0, -1);
duplicate_mask.slice(0, 0, -1).copy_(
invert ? sorted_except_first.ne(sorted_except_last)
: sorted_except_first.eq(sorted_except_last));
duplicate_mask.index_put_({-1}, invert);

// 4. Reorder the mask to match the pre-sorted element order.
Tensor mask = at::empty_like(duplicate_mask);
mask.index_copy_(0, sorted_order, duplicate_mask);

// 5. Index the mask to match the pre-unique element order. If
// assume_unique is true, just take the first N items of the mask,
// where N is the original number of elements.
if (assume_unique) {
out.copy_(mask.slice(0, 0, elements.numel()).view_as(out));
} else {
out.copy_(at::index(mask, {std::optional<Tensor>(unique_order)}));
}
}
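
To make the five steps concrete, here is a hypothetical standalone trace of the non-inverted assume_unique path (libtorch build assumed; values are illustrative, not from this commit):

    #include <ATen/ATen.h>
    #include <iostream>

    int main() {
      at::Tensor elements = at::tensor({3, 1, 7});  // already unique
      at::Tensor test = at::tensor({1, 3});         // already unique
      // Steps 1-2: concatenate, then stable-sort keeping order indices.
      at::Tensor all = at::cat({elements, test});   // [3, 1, 7, 1, 3]
      auto [sorted, order] = all.sort(/*stable=*/true, /*dim=*/0, /*descending=*/false);
      // sorted = [1, 1, 3, 3, 7], order = [1, 3, 0, 4, 2]
      // Step 3: mark adjacent duplicates; a value equal to its right neighbor
      // must occur in both inputs, since each input is unique.
      at::Tensor dup = at::zeros({5}, at::kBool);
      dup.slice(0, 0, -1).copy_(sorted.slice(0, 1).eq(sorted.slice(0, 0, -1)));
      // dup = [1, 0, 1, 0, 0]
      // Steps 4-5: scatter back to pre-sort order, keep the first numel(elements).
      at::Tensor mask = at::empty_like(dup);
      mask.index_copy_(0, order, dup);
      std::cout << mask.slice(0, 0, elements.numel());  // [1, 1, 0]: 3 and 1 hit, 7 does not
      return 0;
    }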

void isin_Tensor_Tensor_meta(
const Tensor& elements,
Tensor test_elements,
bool assume_unique,
bool invert,
Tensor& out) {
check_for_unsupported_isin_dtype(elements.scalar_type());
check_for_unsupported_isin_dtype(test_elements.scalar_type());
auto output_options =
TensorOptions(elements.device()).dtype(ScalarType::Bool);
if (out.defined()) {
xpu::resize_out(out, elements.sizes(), {}, output_options);
} else {
out = xpu::create_out(elements.sizes(), {}, output_options);
}
}
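
The meta function only allocates the boolean output, or resizes it for the out= overload; no values are computed here. A sketch of the resulting out= behavior, with illustrative values, using the generated at::isin_out entry point (out-first argument order in the C++ API):

    at::Tensor e = at::tensor({1, 2, 3});
    at::Tensor t = at::tensor({2});
    at::Tensor out = at::empty({0}, at::kBool);
    at::isin_out(out, e, t);  // out resized to e.sizes(): [0, 1, 0]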

void isin_Tensor_Tensor_impl(
const Tensor& elements,
Tensor test_elements,
bool assume_unique,
bool invert,
const Tensor& out) {
if (elements.numel() == 0) {
return;
}

// Heuristic taken from numpy's implementation.
if (test_elements.numel() <
static_cast<int64_t>(
10.0f * std::pow(static_cast<double>(elements.numel()), 0.145))) {
out.fill_(invert);
native::xpu::isin_kernel(elements, test_elements, invert, out);
} else {
isin_sorting(elements, test_elements, assume_unique, invert, out);
}
}
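
For intuition about the crossover: with N = elements.numel(), the broadcast kernel is used while test_elements.numel() < 10 * N^0.145, the same formula numpy applies. A quick check of the threshold at a few sizes (standalone snippet, not part of this commit):

    #include <cmath>
    #include <cstdio>

    int main() {
      for (long n : {100L, 10000L, 1000000L}) {
        // Sorting wins once the test set reaches roughly this size.
        double threshold = 10.0 * std::pow(static_cast<double>(n), 0.145);
        std::printf("elements=%-8ld -> sort when test size >= %ld\n",
                    n, static_cast<long>(threshold));
      }
      // elements=100      -> sort when test size >= 19
      // elements=10000    -> sort when test size >= 38
      // elements=1000000  -> sort when test size >= 74
      return 0;
    }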

Tensor& XPUNativeFunctions::isin_out(
const Tensor& elements,
const Tensor& test_elements,
bool assume_unique,
bool invert,
Tensor& out) {
isin_Tensor_Tensor_meta(elements, test_elements, assume_unique, invert, out);
isin_Tensor_Tensor_impl(elements, test_elements, assume_unique, invert, out);
return out;
}

Tensor XPUNativeFunctions::isin(
const Tensor& elements,
const Tensor& test_elements,
bool assume_unique,
bool invert) {
Tensor out;
isin_Tensor_Tensor_meta(elements, test_elements, assume_unique, invert, out);
isin_Tensor_Tensor_impl(elements, test_elements, assume_unique, invert, out);
return out;
}
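
Together, the out= and functional Tensor/Tensor entry points behave like the stock operator; a small usage sketch with illustrative values:

    at::Tensor e = at::tensor({1, 2, 3, 4});
    at::Tensor t = at::tensor({2, 4});
    at::isin(e, t);                                           // [0, 1, 0, 1]
    at::isin(e, t, /*assume_unique=*/true, /*invert=*/true);  // [1, 0, 1, 0]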

void isin_Tensor_Scalar_meta(
const Tensor& elements,
const Scalar& test_elements,
bool assume_unique,
bool invert,
Tensor& out) {
check_for_unsupported_isin_dtype(elements.scalar_type());
check_for_unsupported_isin_dtype(test_elements.type());
auto output_options =
TensorOptions(elements.device()).dtype(ScalarType::Bool);
if (out.defined()) {
xpu::resize_out(out, elements.sizes(), {}, output_options);
} else {
out = xpu::create_out(elements.sizes(), {}, output_options);
}
}

void isin_Tensor_Scalar_impl(
const Tensor& elements,
const Scalar& test_elements,
bool assume_unique,
bool invert,
const Tensor& out) {
if (invert) {
at::ne_out(const_cast<Tensor&>(out), elements, test_elements);
} else {
at::eq_out(const_cast<Tensor&>(out), elements, test_elements);
}
}
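
With a single scalar to test against there is nothing to sort, so the operator collapses to one elementwise comparison and assume_unique has no effect. The equivalences, with illustrative values:

    at::Tensor e = at::tensor({1, 2, 3});
    at::isin(e, 2);                                            // same as e.eq(2): [0, 1, 0]
    at::isin(e, 2, /*assume_unique=*/false, /*invert=*/true);  // same as e.ne(2): [1, 0, 1]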

Tensor& XPUNativeFunctions::isin_out(
const Tensor& elements,
const Scalar& test_elements,
bool assume_unique,
bool invert,
Tensor& out) {
isin_Tensor_Scalar_meta(elements, test_elements, assume_unique, invert, out);
isin_Tensor_Scalar_impl(elements, test_elements, assume_unique, invert, out);
return out;
}

Tensor XPUNativeFunctions::isin(
const Tensor& elements,
const Scalar& test_elements,
bool assume_unique,
bool invert) {
Tensor out;
isin_Tensor_Scalar_meta(elements, test_elements, assume_unique, invert, out);
isin_Tensor_Scalar_impl(elements, test_elements, assume_unique, invert, out);
return out;
}

void isin_Scalar_Tensor_meta(
const Scalar& elements,
const Tensor& test_elements,
bool assume_unique,
bool invert,
Tensor& out) {
check_for_unsupported_isin_dtype(elements.type());
check_for_unsupported_isin_dtype(test_elements.scalar_type());
auto output_options =
TensorOptions(test_elements.device()).dtype(ScalarType::Bool);
if (out.defined()) {
xpu::resize_out(out, {0}, {}, output_options);
} else {
out = xpu::create_out({0}, {}, output_options);
}
}

void isin_Scalar_Tensor_impl(
const Scalar& elements,
const Tensor& test_elements,
bool assume_unique,
bool invert,
const Tensor& out) {
// Redispatch to the Tensor_Tensor overload, wrapping the scalar as a
// 0-dim tensor on test_elements' device.
at::isin_out(
const_cast<Tensor&>(out),
at::native::wrapped_scalar_tensor(elements, test_elements.device()),
test_elements,
assume_unique,
invert);
}
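
Because the scalar is wrapped as a 0-dim tensor and the call redispatches to the Tensor/Tensor path above, the result is a 0-dim boolean tensor. An illustrative call:

    at::Tensor t = at::tensor({1, 2, 3});
    at::isin(3, t);  // 0-dim bool tensor holding true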

Tensor& XPUNativeFunctions::isin_out(
const Scalar& elements,
const Tensor& test_elements,
bool assume_unique,
bool invert,
Tensor& out) {
isin_Scalar_Tensor_meta(elements, test_elements, assume_unique, invert, out);
isin_Scalar_Tensor_impl(elements, test_elements, assume_unique, invert, out);
return out;
}

Tensor XPUNativeFunctions::isin(
const Scalar& elements,
const Tensor& test_elements,
bool assume_unique,
bool invert) {
Tensor out;
isin_Scalar_Tensor_meta(elements, test_elements, assume_unique, invert, out);
isin_Scalar_Tensor_impl(elements, test_elements, assume_unique, invert, out);
return out;
}

} // namespace at
1 change: 0 additions & 1 deletion src/ATen/native/xpu/XPUFallback.template
@@ -231,7 +231,6 @@ TORCH_LIBRARY_IMPL(aten, XPU, m) {
"index_fill_.int_Scalar",
"index_fill_.int_Tensor",
"index_reduce.out",
"isin.Tensor_Tensor_out",
"isneginf.out",
"isposinf.out",
"kthvalue.values",
2 changes: 1 addition & 1 deletion src/ATen/native/xpu/sycl/IndexingUtils.h
@@ -99,7 +99,7 @@ static std::tuple<Tensor, int64_t, int64_t, int64_t> computeLinearIndex(
static std::
tuple<Tensor, Tensor, int64_t, int64_t, int64_t, std::vector<int64_t>>
makeLinearIndex(Tensor self, IOptTensorListRef orig, bool check_range) {
-checkIndexTensorTypes(orig, /*allow_int*/ true);
+checkIndexTensorTypes(orig);
// first expand BoolTensor (masks) or ByteTensor (masks) into 1 or more
// LongTensors
auto indices = expandTensors(self, orig);
12 changes: 12 additions & 0 deletions src/ATen/native/xpu/sycl/TensorCompareKernels.cpp
@@ -105,6 +105,18 @@ void clamp_max_scalar_kernel(TensorIteratorBase& iter, Scalar max) {
launch_clamp_scalar(iter, max, max, at::native::detail::ClampLimits::Max);
}

void isin_kernel(
const Tensor& elements,
const Tensor& test_elements,
bool invert,
const Tensor& out) {
std::vector<int64_t> bc_shape(elements.dim(), 1);
bc_shape.push_back(-1);
out.copy_(
invert ? elements.unsqueeze(-1).ne(test_elements.view(bc_shape)).all(-1)
: elements.unsqueeze(-1).eq(test_elements.view(bc_shape)).any(-1));
}
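
The fallback kernel is a broadcast-and-reduce: elements gains a trailing length-1 axis, test_elements is viewed as a single trailing row, and eq/ne followed by any/all reduces over the broadcast axis. A hypothetical shape trace with illustrative values:

    at::Tensor e = at::tensor({5, 6, 7});  // shape [3]
    at::Tensor t = at::tensor({6, 9});     // shape [2]
    // [3, 1] eq [1, 2] broadcasts to [3, 2]; any over the last dim gives [3].
    e.unsqueeze(-1).eq(t.view({1, -1})).any(-1);  // [0, 1, 0]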

} // namespace xpu
} // namespace native
} // namespace at
14 changes: 8 additions & 6 deletions src/ATen/native/xpu/sycl/TensorCompareKernels.h
@@ -2,9 +2,7 @@

#include <ATen/native/TensorIterator.h>

-namespace at {
-namespace native {
-namespace xpu {
+namespace at::native::xpu {

void where_kernel(TensorIterator& iter);

@@ -19,6 +17,10 @@ void clamp_min_scalar_kernel(TensorIteratorBase& iter, Scalar min);

void clamp_max_scalar_kernel(TensorIteratorBase& iter, Scalar max);

-} // namespace xpu
-} // namespace native
-} // namespace at
+void isin_kernel(
+    const Tensor& elements,
+    const Tensor& test_elements,
+    bool invert,
+    const Tensor& out);
+
+} // namespace at::native::xpu
10 changes: 0 additions & 10 deletions test/xpu/run_test_with_skip.py
@@ -856,16 +856,6 @@ def launch_test(test_case, skip_list=None, exe_list=None):


skip_list = (
-# The following isin case fails on CPU fallback, as it could be backend-specific.
-"test_isin_xpu_float16", # RuntimeError: "isin_default_cpu" not implemented for 'Half'
-"test_isin_different_devices_xpu_float32", # AssertionError: RuntimeError not raised
-"test_isin_different_devices_xpu_float64", # AssertionError: RuntimeError not raised
-"test_isin_different_devices_xpu_int16", # AssertionError: RuntimeError not raised
-"test_isin_different_devices_xpu_int32", # AssertionError: RuntimeError not raised
-"test_isin_different_devices_xpu_int64", # AssertionError: RuntimeError not raised
-"test_isin_different_devices_xpu_int8", # AssertionError: RuntimeError not raised
-"test_isin_different_devices_xpu_uint8", # AssertionError: RuntimeError not raised
-"test_isin_different_dtypes_xpu", # RuntimeError: "isin_default_cpu" not implemented for 'Half'
"test_sort_large_slice_xpu", # Hard code CUDA
)
res += launch_test("test_sort_and_select_xpu.py", skip_list)
1 change: 1 addition & 0 deletions test/xpu/xpu_test_utils.py
@@ -64,6 +64,7 @@
"index_add",
"index_put",
"index_select",
"isin",
"isnan",
"le",
"log",
6 changes: 6 additions & 0 deletions yaml/xpu_functions.yaml
@@ -87,6 +87,12 @@ supported:
- ge_.Tensor
- cat
- cat.out
+- isin.Tensor_Tensor_out
+- isin.Tensor_Tensor
+- isin.Tensor_Scalar_out
+- isin.Tensor_Scalar
+- isin.Scalar_Tensor_out
+- isin.Scalar_Tensor
- isnan
- isnan.out
- masked_fill_.Tensor
