Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add aten::logical_and/or/xor and variant operators #529

Merged
merged 13 commits into from
Jul 17, 2024
95 changes: 95 additions & 0 deletions src/ATen/native/xpu/BinaryOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
#include <ATen/native/xpu/sycl/BinaryBitwiseOpsKernels.h>
#include <ATen/native/xpu/sycl/BinaryGeometricKernels.h>
#include <ATen/native/xpu/sycl/BinaryKernels.h>
#include <ATen/native/xpu/sycl/BinaryLogicalOpsKernels.h>
#include <ATen/native/xpu/sycl/BinaryMiscBackwardOpsKernels.h>
#include <ATen/native/xpu/sycl/BinaryRemainderKernel.h>
#include <ATen/native/xpu/sycl/CopysignKernel.h>
Expand Down Expand Up @@ -525,4 +526,98 @@ Tensor XPUNativeFunctions::copysign(const Tensor& self, const Tensor& other) {
return iter.output();
}

// Each at::logical_*_out function is overloaded (Tensor and Scalar `other`),
// so merely naming the function is ambiguous; an explicit cast to OutFunc
// selects the (Tensor&, const Tensor&, const Tensor&) overload.
using OutFunc =
    std::add_const<Tensor& (&)(Tensor&, const Tensor&, const Tensor&)>::type;

// Functional variant: allocates an empty bool tensor and lets the out-variant
// implementation resize and fill it.
template <typename OutImpl>
Tensor comparison_op(
    const Tensor& self,
    const Tensor& other,
    OutImpl& out_impl) {
  Tensor out = at::empty({0}, self.options().dtype(kBool));
  return out_impl(out, self, other);
}

// In-place variant: `self` serves as both the destination and the first input.
template <typename OutImpl>
Tensor& comparison_op_(Tensor& self, const Tensor& other, OutImpl& out_impl) {
  auto& result = self;
  return out_impl(result, self, other);
}

// Scalar-RHS out variant: promotes `other` to a 0-dim tensor before
// delegating to the tensor-tensor out implementation.
template <typename OutImpl>
Tensor& comparison_op_out(
    Tensor& result,
    const Tensor& self,
    const Scalar& other,
    OutImpl& out_impl) {
  auto other_tensor = native::wrapped_scalar_tensor(other);
  return out_impl(result, self, other_tensor);
}

// Scalar-RHS functional variant: wraps the scalar and reuses the
// tensor-tensor overload above.
template <typename OutImpl>
Tensor comparison_op(
    const Tensor& self,
    const Scalar& other,
    OutImpl& out_impl) {
  auto other_tensor = native::wrapped_scalar_tensor(other);
  return comparison_op(self, other_tensor, out_impl);
}

// Scalar-RHS in-place variant: `self` is both destination and first input.
template <typename OutImpl>
Tensor& comparison_op_(Tensor& self, const Scalar& other, OutImpl& out_impl) {
  auto other_tensor = native::wrapped_scalar_tensor(other);
  return out_impl(self, self, other_tensor);
}

// Element-wise logical AND of `self` and `other`, written into `out`.
// TensorIterator::comparison_op handles broadcasting and any casting needed
// for `out`'s dtype; the XPU kernel fills the iterator's output.
// NOTE: removed GitHub review-UI text that had been scraped into the source
// between the iterator construction and the kernel call.
Tensor& XPUNativeFunctions::logical_and_out(
    const Tensor& self,
    const Tensor& other,
    Tensor& out) {
  auto iter = TensorIterator::comparison_op(out, self, other);
  native::xpu::logical_and_kernel(iter);
  return out;
}

// Functional logical AND: allocate a bool result via comparison_op and
// delegate to the out-variant (cast disambiguates the overload set).
Tensor XPUNativeFunctions::logical_and(
    const Tensor& self,
    const Tensor& other) {
  OutFunc out_fn = static_cast<OutFunc>(at::logical_and_out);
  return comparison_op(self, other, out_fn);
}

// In-place logical AND: `self &&= other`, computed through the out-variant.
Tensor& XPUNativeFunctions::logical_and_(Tensor& self, const Tensor& other) {
  OutFunc out_fn = static_cast<OutFunc>(at::logical_and_out);
  return comparison_op_(self, other, out_fn);
}

// Element-wise logical OR of `self` and `other`, written into `out`.
// TensorIterator::comparison_op handles broadcasting and output-dtype
// casting; the XPU kernel fills the iterator's output.
// NOTE: removed GitHub review-UI text that had been scraped into the source
// between the iterator construction and the kernel call.
Tensor& XPUNativeFunctions::logical_or_out(
    const Tensor& self,
    const Tensor& other,
    Tensor& out) {
  auto iter = TensorIterator::comparison_op(out, self, other);
  native::xpu::logical_or_kernel(iter);
  return out;
}

// Functional logical OR: bool result allocated by comparison_op, filled by
// the out-variant (cast disambiguates the overload set).
Tensor XPUNativeFunctions::logical_or(const Tensor& self, const Tensor& other) {
  OutFunc out_fn = static_cast<OutFunc>(at::logical_or_out);
  return comparison_op(self, other, out_fn);
}

// In-place logical OR: `self ||= other`, computed through the out-variant.
Tensor& XPUNativeFunctions::logical_or_(Tensor& self, const Tensor& other) {
  OutFunc out_fn = static_cast<OutFunc>(at::logical_or_out);
  return comparison_op_(self, other, out_fn);
}

// Element-wise logical XOR of `self` and `other`, written into `out`.
Tensor& XPUNativeFunctions::logical_xor_out(
    const Tensor& self,
    const Tensor& other,
    Tensor& out) {
  // comparison_op sets up broadcasting and bool-result casting; the kernel
  // then fills the iterator's output.
  auto it = TensorIterator::comparison_op(out, self, other);
  native::xpu::logical_xor_kernel(it);
  return out;
}

// Functional logical XOR: bool result allocated by comparison_op, filled by
// the out-variant (cast disambiguates the overload set).
Tensor XPUNativeFunctions::logical_xor(const Tensor& self, const Tensor& other) {
  OutFunc out_fn = static_cast<OutFunc>(at::logical_xor_out);
  return comparison_op(self, other, out_fn);
}

// In-place logical XOR, computed through the out-variant.
Tensor& XPUNativeFunctions::logical_xor_(Tensor& self, const Tensor& other) {
  OutFunc out_fn = static_cast<OutFunc>(at::logical_xor_out);
  return comparison_op_(self, other, out_fn);
}

} // namespace at
3 changes: 0 additions & 3 deletions src/ATen/native/xpu/XPUFallback.template
Original file line number Diff line number Diff line change
Expand Up @@ -247,9 +247,6 @@ TORCH_LIBRARY_IMPL(aten, XPU, m) {
"logaddexp2.out",
"logaddexp.out",
"_logcumsumexp",
"logical_and.out",
"logical_or.out",
"logical_xor.out",
"logit",
"logit_backward.grad_input",
"log_normal_",
Expand Down
78 changes: 78 additions & 0 deletions src/ATen/native/xpu/sycl/BinaryLogicalOpsKernels.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
#include <ATen/ATen.h>
#include <ATen/Dispatch.h>
#include <ATen/native/TensorIterator.h>

#include <ATen/native/BinaryOps.h>
#include <ATen/native/xpu/sycl/BinaryInternal.h>
#include <ATen/native/xpu/sycl/Loops.h>

namespace at::native::xpu {

// Element-wise logical AND. `a && b` contextually converts each operand to
// bool; scalar_t ranges over all dispatched types, including complex
// (presumably c10::complex supports this conversion — mirrors the CUDA
// functor; confirm against c10/util/complex.h).
template <typename scalar_t>
struct LogicalAndFunctor {
  bool operator()(scalar_t a, scalar_t b) const {
    return a && b;
  }
};

// Dispatches the element-wise logical AND kernel for the iterator's common
// dtype: all standard types plus Half/Bool/BFloat16, with a separate complex
// dispatch. Both paths use the opmath symmetric kernel with a bool output.
void logical_and_kernel(TensorIteratorBase& iter) {
  auto dtype = iter.common_dtype();
  if (!at::isComplexType(dtype)) {
    AT_DISPATCH_ALL_TYPES_AND3(
        kHalf, kBool, ScalarType::BFloat16, dtype, "logical_and_xpu", [&]() {
          opmath_symmetric_gpu_kernel_with_scalars<scalar_t, bool>(
              iter, LogicalAndFunctor<scalar_t>());
        });
  } else {
    AT_DISPATCH_COMPLEX_TYPES(dtype, "logical_and_xpu", [&]() {
      opmath_symmetric_gpu_kernel_with_scalars<scalar_t, bool>(
          iter, LogicalAndFunctor<scalar_t>());
    });
  }
}

// Element-wise logical OR. `a || b` contextually converts each operand to
// bool; scalar_t ranges over all dispatched types, including complex
// (mirrors the CUDA functor).
template <typename scalar_t>
struct LogicalOrFunctor {
  bool operator()(scalar_t a, scalar_t b) const {
    return a || b;
  }
};

// Dispatches the element-wise logical OR kernel for the iterator's common
// dtype. NOTE(review): the complex path uses gpu_kernel_with_scalars while
// the non-complex path uses the opmath symmetric kernel — this asymmetry
// appears intentional (it matches logical_xor here); confirm against the
// CUDA implementation.
void logical_or_kernel(TensorIteratorBase& iter) {
  auto dtype = iter.common_dtype();
  if (!at::isComplexType(dtype)) {
    AT_DISPATCH_ALL_TYPES_AND3(
        kHalf, kBool, ScalarType::BFloat16, dtype, "logical_or_xpu", [&]() {
          opmath_symmetric_gpu_kernel_with_scalars<scalar_t, bool>(
              iter, LogicalOrFunctor<scalar_t>());
        });
  } else {
    AT_DISPATCH_COMPLEX_TYPES(dtype, "logical_or_xpu", [&]() {
      gpu_kernel_with_scalars(iter, LogicalOrFunctor<scalar_t>());
    });
  }
}

// Element-wise logical XOR: true when exactly one operand converts to true.
// Implemented as inequality of the bool conversions; scalar_t ranges over all
// dispatched types, including complex (mirrors the CUDA functor).
template <typename scalar_t>
struct LogicalXorFunctor {
  bool operator()(scalar_t a, scalar_t b) const {
    return bool(a) != bool(b);
  }
};

// Dispatches the element-wise logical XOR kernel for the iterator's common
// dtype. Like logical_or, the complex path goes through
// gpu_kernel_with_scalars while the non-complex path uses the opmath
// symmetric kernel.
void logical_xor_kernel(TensorIteratorBase& iter) {
  auto dtype = iter.common_dtype();
  if (!at::isComplexType(dtype)) {
    AT_DISPATCH_ALL_TYPES_AND3(
        kHalf, kBool, ScalarType::BFloat16, dtype, "logical_xor_xpu", [&]() {
          opmath_symmetric_gpu_kernel_with_scalars<scalar_t, bool>(
              iter, LogicalXorFunctor<scalar_t>());
        });
  } else {
    AT_DISPATCH_COMPLEX_TYPES(dtype, "logical_xor_xpu", [&]() {
      gpu_kernel_with_scalars(iter, LogicalXorFunctor<scalar_t>());
    });
  }
}

} // namespace at::native::xpu
13 changes: 13 additions & 0 deletions src/ATen/native/xpu/sycl/BinaryLogicalOpsKernels.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
#pragma once

#include <ATen/native/TensorIterator.h>

namespace at::native::xpu {

// XPU (SYCL) kernels for element-wise logical ops. Each consumes a configured
// TensorIterator and writes bool-valued results to its output.

// Element-wise logical AND over the iterator's operands.
void logical_and_kernel(TensorIteratorBase& iter);

// Element-wise logical OR over the iterator's operands.
void logical_or_kernel(TensorIteratorBase& iter);

// Element-wise logical XOR over the iterator's operands.
void logical_xor_kernel(TensorIteratorBase& iter);

} // namespace at::native::xpu
3 changes: 3 additions & 0 deletions test/xpu/xpu_test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,9 @@
"le",
"log",
"lt",
"logical_and",
"logical_or",
"logical_xor",
"logical_not",
"masked_fill",
"maximum",
Expand Down
9 changes: 9 additions & 0 deletions yaml/xpu_functions.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -169,6 +169,15 @@ supported:
- log
- log_
- log.out
- logical_and
- logical_and_
- logical_and.out
- logical_or
- logical_or_
- logical_or.out
- logical_xor
- logical_xor_
- logical_xor.out
- logical_not
- logical_not_
- logical_not.out
Expand Down