From 348cd80b6b4bd6b1199a89542812278c6167cdf0 Mon Sep 17 00:00:00 2001
From: yucai-intel <108388355+yucai-intel@users.noreply.github.com>
Date: Wed, 17 Jul 2024 10:49:16 +0800
Subject: [PATCH] Add aten::logical_and/or/xor and variant operators (#529)

- [x] logical_and
- [x] logical_or
- [x] logical_xor
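
Example (illustration only, not part of the patch): a quick smoke test of the
new ops, assuming a PyTorch build with this XPU backend and an available `xpu`
device:

```python
import torch

a = torch.tensor([True, False, True, False], device="xpu")
b = torch.tensor([True, True, False, False], device="xpu")

print(torch.logical_and(a, b))  # [ True, False, False, False]
print(torch.logical_or(a, b))   # [ True,  True,  True, False]
print(torch.logical_xor(a, b))  # [False,  True,  True, False]

# Non-bool dtypes are accepted; nonzero values are treated as True.
x = torch.tensor([0, 1, 2], device="xpu")
y = torch.tensor([2, 0, 2], device="xpu")
print(torch.logical_and(x, y))  # [False, False,  True]

# In-place and out= variants route to the same kernels.
a.logical_and_(b)
out = torch.empty(4, dtype=torch.bool, device="xpu")
torch.logical_or(a, b, out=out)
```
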
---------

Co-authored-by: Feng Yuan
---
 src/ATen/native/xpu/BinaryOps.cpp             | 95 +++++++++++++++++++
 src/ATen/native/xpu/XPUFallback.template      |  3 -
 .../xpu/sycl/BinaryLogicalOpsKernels.cpp      | 78 +++++++++++++++
 .../native/xpu/sycl/BinaryLogicalOpsKernels.h | 13 +++
 test/xpu/xpu_test_utils.py                    |  3 +
 yaml/xpu_functions.yaml                       |  9 ++
 6 files changed, 198 insertions(+), 3 deletions(-)
 create mode 100644 src/ATen/native/xpu/sycl/BinaryLogicalOpsKernels.cpp
 create mode 100644 src/ATen/native/xpu/sycl/BinaryLogicalOpsKernels.h

diff --git a/src/ATen/native/xpu/BinaryOps.cpp b/src/ATen/native/xpu/BinaryOps.cpp
index 2ec722722..a88a1eee8 100644
--- a/src/ATen/native/xpu/BinaryOps.cpp
+++ b/src/ATen/native/xpu/BinaryOps.cpp
@@ -7,6 +7,7 @@
 #include <ATen/native/xpu/sycl/BinaryBitwiseOpsKernels.h>
 #include <ATen/native/xpu/sycl/BinaryGeometricKernels.h>
 #include <ATen/native/xpu/sycl/BinaryKernels.h>
+#include <ATen/native/xpu/sycl/BinaryLogicalOpsKernels.h>
 #include <ATen/native/xpu/sycl/BinaryMiscBackwardOpsKernels.h>
 #include <ATen/native/xpu/sycl/BinaryMiscOpsKernels.h>
 #include <ATen/native/xpu/sycl/BinaryRemainderKernel.h>
@@ -525,4 +526,98 @@ Tensor XPUNativeFunctions::copysign(const Tensor& self, const Tensor& other) {
   return iter.output();
 }
 
+// We need an explicit cast to OutFunc because each *_out func is overloaded twice.
+// Without an explicit cast, merely referring to the *_out function is ambiguous.
+using OutFunc =
+    std::add_const<Tensor& (&)(Tensor&, const Tensor&, const Tensor&)>::type;
+
+template <typename OutImpl>
+Tensor comparison_op(
+    const Tensor& self,
+    const Tensor& other,
+    OutImpl& out_impl) {
+  Tensor result = at::empty({0}, self.options().dtype(kBool));
+  return out_impl(result, self, other);
+}
+
+template <typename OutImpl>
+Tensor& comparison_op_(Tensor& self, const Tensor& other, OutImpl& out_impl) {
+  return out_impl(self, self, other);
+}
+
+template <typename OutImpl>
+Tensor& comparison_op_out(
+    Tensor& result,
+    const Tensor& self,
+    const Scalar& other,
+    OutImpl& out_impl) {
+  return out_impl(result, self, native::wrapped_scalar_tensor(other));
+}
+
+template <typename OutImpl>
+Tensor comparison_op(
+    const Tensor& self,
+    const Scalar& other,
+    OutImpl& out_impl) {
+  return comparison_op(self, native::wrapped_scalar_tensor(other), out_impl);
+}
+
+template <typename OutImpl>
+Tensor& comparison_op_(Tensor& self, const Scalar& other, OutImpl& out_impl) {
+  return out_impl(self, self, native::wrapped_scalar_tensor(other));
+}
+
+Tensor& XPUNativeFunctions::logical_and_out(
+    const Tensor& self,
+    const Tensor& other,
+    Tensor& out) {
+  auto iter = TensorIterator::comparison_op(out, self, other);
+  native::xpu::logical_and_kernel(iter);
+  return out;
+}
+
+Tensor XPUNativeFunctions::logical_and(
+    const Tensor& self,
+    const Tensor& other) {
+  return comparison_op(self, other, static_cast<OutFunc>(at::logical_and_out));
+}
+
+Tensor& XPUNativeFunctions::logical_and_(Tensor& self, const Tensor& other) {
+  return comparison_op_(self, other, static_cast<OutFunc>(at::logical_and_out));
+}
+
+Tensor& XPUNativeFunctions::logical_or_out(
+    const Tensor& self,
+    const Tensor& other,
+    Tensor& out) {
+  auto iter = TensorIterator::comparison_op(out, self, other);
+  native::xpu::logical_or_kernel(iter);
+  return out;
+}
+
+Tensor XPUNativeFunctions::logical_or(const Tensor& self, const Tensor& other) {
+  return comparison_op(self, other, static_cast<OutFunc>(at::logical_or_out));
+}
+
+Tensor& XPUNativeFunctions::logical_or_(Tensor& self, const Tensor& other) {
+  return comparison_op_(self, other, static_cast<OutFunc>(at::logical_or_out));
+}
+
+Tensor& XPUNativeFunctions::logical_xor_out(
+    const Tensor& self,
+    const Tensor& other,
+    Tensor& out) {
+  auto iter = TensorIterator::comparison_op(out, self, other);
+  native::xpu::logical_xor_kernel(iter);
+  return out;
+}
+
+Tensor XPUNativeFunctions::logical_xor(const Tensor& self, const Tensor& other) {
+  return comparison_op(self, other, static_cast<OutFunc>(at::logical_xor_out));
+}
+
+Tensor& XPUNativeFunctions::logical_xor_(Tensor& self, const Tensor& other) {
+  return comparison_op_(self, other, static_cast<OutFunc>(at::logical_xor_out));
+}
+
 } // namespace at
diff --git a/src/ATen/native/xpu/XPUFallback.template b/src/ATen/native/xpu/XPUFallback.template
index 53b7feb02..355eb8f43 100644
--- a/src/ATen/native/xpu/XPUFallback.template
+++ b/src/ATen/native/xpu/XPUFallback.template
@@ -246,9 +246,6 @@ TORCH_LIBRARY_IMPL(aten, XPU, m) {
       "logaddexp2.out",
       "logaddexp.out",
       "_logcumsumexp",
-      "logical_and.out",
-      "logical_or.out",
-      "logical_xor.out",
       "logit",
       "logit_backward.grad_input",
       "log_normal_",
diff --git a/src/ATen/native/xpu/sycl/BinaryLogicalOpsKernels.cpp b/src/ATen/native/xpu/sycl/BinaryLogicalOpsKernels.cpp
new file mode 100644
index 000000000..23146b47d
--- /dev/null
+++ b/src/ATen/native/xpu/sycl/BinaryLogicalOpsKernels.cpp
@@ -0,0 +1,78 @@
+#include <ATen/Dispatch.h>
+#include <ATen/OpMathType.h>
+#include <ATen/native/TensorIterator.h>
+
+#include <ATen/native/xpu/sycl/BinaryLogicalOpsKernels.h>
+#include <ATen/native/xpu/sycl/Loops.h>
+#include <comm/SYCLContext.h>
+
+namespace at::native::xpu {
+
+template <typename scalar_t>
+struct LogicalAndFunctor {
+  bool operator()(scalar_t a, scalar_t b) const {
+    return a && b;
+  }
+};
+
+void logical_and_kernel(TensorIteratorBase& iter) {
+  auto dtype = iter.common_dtype();
+  if (at::isComplexType(dtype)) {
+    AT_DISPATCH_COMPLEX_TYPES(dtype, "logical_and_xpu", [&]() {
+      opmath_symmetric_gpu_kernel_with_scalars<scalar_t, bool>(
+          iter, LogicalAndFunctor<scalar_t>());
+    });
+  } else {
+    AT_DISPATCH_ALL_TYPES_AND3(
+        kHalf, kBool, ScalarType::BFloat16, dtype, "logical_and_xpu", [&]() {
+          opmath_symmetric_gpu_kernel_with_scalars<scalar_t, bool>(
+              iter, LogicalAndFunctor<scalar_t>());
+        });
+  }
+}
+
+template <typename scalar_t>
+struct LogicalOrFunctor {
+  bool operator()(scalar_t a, scalar_t b) const {
+    return a || b;
+  }
+};
+
+void logical_or_kernel(TensorIteratorBase& iter) {
+  auto dtype = iter.common_dtype();
+  if (at::isComplexType(dtype)) {
+    AT_DISPATCH_COMPLEX_TYPES(dtype, "logical_or_xpu", [&]() {
+      gpu_kernel_with_scalars(iter, LogicalOrFunctor<scalar_t>());
+    });
+  } else {
+    AT_DISPATCH_ALL_TYPES_AND3(
+        kHalf, kBool, ScalarType::BFloat16, dtype, "logical_or_xpu", [&]() {
+          opmath_symmetric_gpu_kernel_with_scalars<scalar_t, bool>(
+              iter, LogicalOrFunctor<scalar_t>());
+        });
+  }
+}
+
+template <typename scalar_t>
+struct LogicalXorFunctor {
+  bool operator()(scalar_t a, scalar_t b) const {
+    return bool(a) != bool(b);
+  }
+};
+
+void logical_xor_kernel(TensorIteratorBase& iter) {
+  auto dtype = iter.common_dtype();
+  if (at::isComplexType(dtype)) {
+    AT_DISPATCH_COMPLEX_TYPES(dtype, "logical_xor_xpu", [&]() {
+      gpu_kernel_with_scalars(iter, LogicalXorFunctor<scalar_t>());
+    });
+  } else {
+    AT_DISPATCH_ALL_TYPES_AND3(
+        kHalf, kBool, ScalarType::BFloat16, dtype, "logical_xor_xpu", [&]() {
+          opmath_symmetric_gpu_kernel_with_scalars<scalar_t, bool>(
+              iter, LogicalXorFunctor<scalar_t>());
+        });
+  }
+}
+
+} // namespace at::native::xpu
diff --git a/src/ATen/native/xpu/sycl/BinaryLogicalOpsKernels.h b/src/ATen/native/xpu/sycl/BinaryLogicalOpsKernels.h
new file mode 100644
index 000000000..ee641d9fb
--- /dev/null
+++ b/src/ATen/native/xpu/sycl/BinaryLogicalOpsKernels.h
@@ -0,0 +1,13 @@
+#pragma once
+
+#include <ATen/native/TensorIterator.h>
+
+namespace at::native::xpu {
+
+void logical_and_kernel(TensorIteratorBase& iter);
+
+void logical_or_kernel(TensorIteratorBase& iter);
+
+void logical_xor_kernel(TensorIteratorBase& iter);
+
+} // namespace at::native::xpu
diff --git a/test/xpu/xpu_test_utils.py b/test/xpu/xpu_test_utils.py
index a47cd1c9e..0ee677e6b 100644
--- a/test/xpu/xpu_test_utils.py
+++ b/test/xpu/xpu_test_utils.py
@@ -78,6 +78,9 @@
     "le",
     "log",
     "lt",
+    "logical_and",
+    "logical_or",
+    "logical_xor",
     "logical_not",
     "masked_fill",
     "maximum",
diff --git a/yaml/xpu_functions.yaml b/yaml/xpu_functions.yaml
index 67d8ba5e6..3240c0db9 100644
--- a/yaml/xpu_functions.yaml
+++ b/yaml/xpu_functions.yaml
@@ -175,6 +175,15 @@ supported:
   - log
   - log_
   - log.out
+  - logical_and
+  - logical_and_
+  - logical_and.out
+  - logical_or
+  - logical_or_
+  - logical_or.out
+  - logical_xor
+  - logical_xor_
+  - logical_xor.out
   - logical_not
   - logical_not_
   - logical_not.out