Add aten::logical_and/or/xor and variant operators (#529)
- [x] logical_and
- [x] logical_or
- [x] logical_xor
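
A minimal usage sketch (not part of this commit; assumes an XPU-enabled build and uses hypothetical sample values) showing the three ops reaching the new kernels through the regular ATen surface; the non-`out` variants allocate a fresh `kBool` result:

```cpp
#include <ATen/ATen.h>
#include <iostream>

int main() {
  // Hypothetical sample data; at::kXPU routes dispatch to the XPU kernels.
  auto a = at::tensor({0, 1, 1, 0}, at::device(at::kXPU).dtype(at::kInt));
  auto b = at::tensor({0, 1, 0, 1}, at::device(at::kXPU).dtype(at::kInt));

  std::cout << at::logical_and(a, b) << "\n"; // 0, 1, 0, 0
  std::cout << at::logical_or(a, b) << "\n";  // 0, 1, 1, 1
  std::cout << at::logical_xor(a, b) << "\n"; // 0, 0, 1, 1
}
```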

---------

Co-authored-by: Feng Yuan <[email protected]>
yucai-intel and fengyuan14 authored Jul 17, 2024
1 parent 4b6c788 commit 348cd80
Showing 6 changed files with 198 additions and 3 deletions.
95 changes: 95 additions & 0 deletions src/ATen/native/xpu/BinaryOps.cpp
@@ -7,6 +7,7 @@
#include <ATen/native/xpu/sycl/BinaryBitwiseOpsKernels.h>
#include <ATen/native/xpu/sycl/BinaryGeometricKernels.h>
#include <ATen/native/xpu/sycl/BinaryKernels.h>
#include <ATen/native/xpu/sycl/BinaryLogicalOpsKernels.h>
#include <ATen/native/xpu/sycl/BinaryMiscBackwardOpsKernels.h>
#include <ATen/native/xpu/sycl/BinaryRemainderKernel.h>
#include <ATen/native/xpu/sycl/CopysignKernel.h>
@@ -525,4 +526,98 @@ Tensor XPUNativeFunctions::copysign(const Tensor& self, const Tensor& other) {
  return iter.output();
}

// We need an explicit cast to OutFunc because each *_out function is
// overloaded twice. Without the cast, merely referring to a *_out function
// is ambiguous.
using OutFunc =
    std::add_const<Tensor& (&)(Tensor&, const Tensor&, const Tensor&)>::type;

// Allocates a bool result tensor and forwards to the given out= overload.
template <typename OutImpl>
Tensor comparison_op(
    const Tensor& self,
    const Tensor& other,
    OutImpl& out_impl) {
  Tensor result = at::empty({0}, self.options().dtype(kBool));
  return out_impl(result, self, other);
}

// In-place variant: writes the result back into self.
template <typename OutImpl>
Tensor& comparison_op_(Tensor& self, const Tensor& other, OutImpl& out_impl) {
  return out_impl(self, self, other);
}

// Scalar overloads: wrap the scalar into a 0-dim tensor and reuse the
// tensor-tensor paths.
template <typename OutImpl>
Tensor& comparison_op_out(
    Tensor& result,
    const Tensor& self,
    const Scalar& other,
    OutImpl& out_impl) {
  return out_impl(result, self, native::wrapped_scalar_tensor(other));
}

template <typename OutImpl>
Tensor comparison_op(
    const Tensor& self,
    const Scalar& other,
    OutImpl& out_impl) {
  return comparison_op(self, native::wrapped_scalar_tensor(other), out_impl);
}

template <typename OutImpl>
Tensor& comparison_op_(Tensor& self, const Scalar& other, OutImpl& out_impl) {
  return out_impl(self, self, native::wrapped_scalar_tensor(other));
}

Tensor& XPUNativeFunctions::logical_and_out(
    const Tensor& self,
    const Tensor& other,
    Tensor& out) {
  auto iter = TensorIterator::comparison_op(out, self, other);
  native::xpu::logical_and_kernel(iter);
  return out;
}

Tensor XPUNativeFunctions::logical_and(
    const Tensor& self,
    const Tensor& other) {
  return comparison_op(self, other, static_cast<OutFunc>(at::logical_and_out));
}

Tensor& XPUNativeFunctions::logical_and_(Tensor& self, const Tensor& other) {
  return comparison_op_(self, other, static_cast<OutFunc>(at::logical_and_out));
}

Tensor& XPUNativeFunctions::logical_or_out(
    const Tensor& self,
    const Tensor& other,
    Tensor& out) {
  auto iter = TensorIterator::comparison_op(out, self, other);
  native::xpu::logical_or_kernel(iter);
  return out;
}

Tensor XPUNativeFunctions::logical_or(const Tensor& self, const Tensor& other) {
  return comparison_op(self, other, static_cast<OutFunc>(at::logical_or_out));
}

Tensor& XPUNativeFunctions::logical_or_(Tensor& self, const Tensor& other) {
  return comparison_op_(self, other, static_cast<OutFunc>(at::logical_or_out));
}

Tensor& XPUNativeFunctions::logical_xor_out(
    const Tensor& self,
    const Tensor& other,
    Tensor& out) {
  auto iter = TensorIterator::comparison_op(out, self, other);
  native::xpu::logical_xor_kernel(iter);
  return out;
}

Tensor XPUNativeFunctions::logical_xor(const Tensor& self, const Tensor& other) {
  return comparison_op(self, other, static_cast<OutFunc>(at::logical_xor_out));
}

Tensor& XPUNativeFunctions::logical_xor_(Tensor& self, const Tensor& other) {
  return comparison_op_(self, other, static_cast<OutFunc>(at::logical_xor_out));
}

} // namespace at
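
The `static_cast<OutFunc>` calls above are needed because taking a reference to an overloaded function is ambiguous until a target type pins down one overload. A self-contained sketch of the same pattern, using hypothetical stand-in `Tensor`/`Scalar` types and an overloaded `f` rather than the real ATen declarations:

```cpp
#include <type_traits>

struct Tensor {};
struct Scalar {};

// Two overloads, mirroring how each at::*_out function is overloaded.
Tensor& f(Tensor& out, const Tensor&, const Tensor&) { return out; }
Tensor& f(Tensor& out, const Tensor&, const Scalar&) { return out; }

using OutFunc =
    std::add_const<Tensor& (&)(Tensor&, const Tensor&, const Tensor&)>::type;

template <typename Fn>
Tensor call(const Tensor& a, const Tensor& b, Fn& fn) {
  Tensor out;
  return fn(out, a, b); // forwards to whichever overload was selected
}

int main() {
  Tensor a, b;
  // call(a, b, f);                    // error: ambiguous overload set
  call(a, b, static_cast<OutFunc>(f)); // OK: the cast selects one overload
}
```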
3 changes: 0 additions & 3 deletions src/ATen/native/xpu/XPUFallback.template
@@ -246,9 +246,6 @@ TORCH_LIBRARY_IMPL(aten, XPU, m) {
"logaddexp2.out",
"logaddexp.out",
"_logcumsumexp",
"logical_and.out",
"logical_or.out",
"logical_xor.out",
"logit",
"logit_backward.grad_input",
"log_normal_",
78 changes: 78 additions & 0 deletions src/ATen/native/xpu/sycl/BinaryLogicalOpsKernels.cpp
@@ -0,0 +1,78 @@
#include <ATen/ATen.h>
#include <ATen/Dispatch.h>
#include <ATen/native/TensorIterator.h>

#include <ATen/native/BinaryOps.h>
#include <ATen/native/xpu/sycl/BinaryInternal.h>
#include <ATen/native/xpu/sycl/Loops.h>

namespace at::native::xpu {

template <typename scalar_t>
struct LogicalAndFunctor {
  bool operator()(scalar_t a, scalar_t b) const {
    return a && b;
  }
};

void logical_and_kernel(TensorIteratorBase& iter) {
  auto dtype = iter.common_dtype();
  if (at::isComplexType(dtype)) {
    AT_DISPATCH_COMPLEX_TYPES(dtype, "logical_and_xpu", [&]() {
      opmath_symmetric_gpu_kernel_with_scalars<scalar_t, bool>(
          iter, LogicalAndFunctor<scalar_t>());
    });
  } else {
    AT_DISPATCH_ALL_TYPES_AND3(
        kHalf, kBool, ScalarType::BFloat16, dtype, "logical_and_xpu", [&]() {
          opmath_symmetric_gpu_kernel_with_scalars<scalar_t, bool>(
              iter, LogicalAndFunctor<scalar_t>());
        });
  }
}

template <typename scalar_t>
struct LogicalOrFunctor {
  bool operator()(scalar_t a, scalar_t b) const {
    return a || b;
  }
};

void logical_or_kernel(TensorIteratorBase& iter) {
  auto dtype = iter.common_dtype();
  if (at::isComplexType(dtype)) {
    AT_DISPATCH_COMPLEX_TYPES(dtype, "logical_or_xpu", [&]() {
      gpu_kernel_with_scalars(iter, LogicalOrFunctor<scalar_t>());
    });
  } else {
    AT_DISPATCH_ALL_TYPES_AND3(
        kHalf, kBool, ScalarType::BFloat16, dtype, "logical_or_xpu", [&]() {
          opmath_symmetric_gpu_kernel_with_scalars<scalar_t, bool>(
              iter, LogicalOrFunctor<scalar_t>());
        });
  }
}

template <typename scalar_t>
struct LogicalXorFunctor {
  bool operator()(scalar_t a, scalar_t b) const {
    // C++ has no logical XOR operator, so compare the operands' truth values.
    return bool(a) != bool(b);
  }
};

void logical_xor_kernel(TensorIteratorBase& iter) {
  auto dtype = iter.common_dtype();
  if (at::isComplexType(dtype)) {
    AT_DISPATCH_COMPLEX_TYPES(dtype, "logical_xor_xpu", [&]() {
      gpu_kernel_with_scalars(iter, LogicalXorFunctor<scalar_t>());
    });
  } else {
    AT_DISPATCH_ALL_TYPES_AND3(
        kHalf, kBool, ScalarType::BFloat16, dtype, "logical_xor_xpu", [&]() {
          opmath_symmetric_gpu_kernel_with_scalars<scalar_t, bool>(
              iter, LogicalXorFunctor<scalar_t>());
        });
  }
}

} // namespace at::native::xpu
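
The three kernels differ only in the functor they hand to the loop launcher; the functor itself is plain elementwise logic in which any non-zero value counts as true. A host-side sketch (an ordinary loop over hypothetical sample data, not the SYCL launch) of what the xor functor computes per element:

```cpp
#include <array>
#include <cstddef>
#include <cstdio>

// Same shape as the functor in the kernel file above.
template <typename scalar_t>
struct LogicalXorFunctor {
  bool operator()(scalar_t a, scalar_t b) const {
    return bool(a) != bool(b);
  }
};

int main() {
  std::array<float, 4> a{0.f, 2.f, 5.f, 0.f};
  std::array<float, 4> b{0.f, 3.f, 0.f, 7.f};
  LogicalXorFunctor<float> f;
  for (std::size_t i = 0; i < a.size(); ++i) {
    std::printf("%d ", int(f(a[i], b[i]))); // prints: 0 0 1 1
  }
  std::printf("\n");
}
```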
13 changes: 13 additions & 0 deletions src/ATen/native/xpu/sycl/BinaryLogicalOpsKernels.h
@@ -0,0 +1,13 @@
#pragma once

#include <ATen/native/TensorIterator.h>

namespace at::native::xpu {

void logical_and_kernel(TensorIteratorBase& iter);

void logical_or_kernel(TensorIteratorBase& iter);

void logical_xor_kernel(TensorIteratorBase& iter);

} // namespace at::native::xpu
3 changes: 3 additions & 0 deletions test/xpu/xpu_test_utils.py
@@ -78,6 +78,9 @@
"le",
"log",
"lt",
"logical_and",
"logical_or",
"logical_xor",
"logical_not",
"masked_fill",
"maximum",
9 changes: 9 additions & 0 deletions yaml/xpu_functions.yaml
@@ -175,6 +175,15 @@ supported:
- log
- log_
- log.out
- logical_and
- logical_and_
- logical_and.out
- logical_or
- logical_or_
- logical_or.out
- logical_xor
- logical_xor_
- logical_xor.out
- logical_not
- logical_not_
- logical_not.out
