Add aten::sign/signbit and their variants #625

Merged: 13 commits, Aug 1, 2024
65 changes: 65 additions & 0 deletions src/ATen/native/xpu/UnaryOps.cpp
@@ -530,6 +530,71 @@ Tensor& XPUNativeFunctions::sigmoid_out(const Tensor& self, Tensor& out) {
  return out;
}

Tensor XPUNativeFunctions::sign(const Tensor& self) {
  TORCH_CHECK(
      !self.is_complex(),
      "Unlike NumPy, torch.sign is not intended to support complex numbers. Please use torch.sgn instead.");
  Tensor out;
  TensorIterator iter;
  iter.build_borrowing_unary_op(out, self);
  native::xpu::sign_kernel(iter);
  return iter.output();
}

Tensor& XPUNativeFunctions::sign_(Tensor& self) {
  TORCH_CHECK(
      !self.is_complex(),
      "Unlike NumPy, torch.sign is not intended to support complex numbers. Please use torch.sgn instead.");
  TensorIterator iter;
  iter.build_borrowing_unary_op(self, self);
  native::xpu::sign_kernel(iter);
  return self;
}

Tensor& XPUNativeFunctions::sign_out(const Tensor& self, Tensor& out) {
  TORCH_CHECK(
      !self.is_complex(),
      "Unlike NumPy, torch.sign is not intended to support complex numbers. Please use torch.sgn instead.");
  TensorIterator iter;
  iter.build_borrowing_unary_op(out, self);
  native::xpu::sign_kernel(iter);
  return out;
}

Tensor XPUNativeFunctions::signbit(const Tensor& self) {
  TORCH_CHECK(
      !self.is_complex(), "signbit is not implemented for complex tensors.");

  Tensor out;
  TensorIterator iter;
  iter.build_borrowing_unary_force_boolean_op(out, self);

  if (self.dtype() == at::kBool) {
    iter.output().fill_(false);
  } else {
    native::xpu::signbit_kernel(iter);
  }
  return iter.output();
}

Tensor& XPUNativeFunctions::signbit_out(const Tensor& self, Tensor& out) {
  TORCH_CHECK(
      !self.is_complex(), "signbit is not implemented for complex tensors.");
  TORCH_CHECK(
      out.dtype() == at::kBool,
      "signbit does not support non-boolean outputs.");

  TensorIterator iter;
  iter.build_borrowing_unary_force_boolean_op(out, self);

  if (self.dtype() == at::kBool) {
    out.fill_(false);
  } else {
    native::xpu::signbit_kernel(iter);
  }
  return out;
}

Tensor& XPUNativeFunctions::logit_out(
    const Tensor& self,
    std::optional<double> eps,
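For reference, a minimal behavioral sketch of the new sign entry points added above, assuming a build where the "xpu" device is available; the values in the comments are expectations, not captured output.

import torch

x = torch.tensor([-3.5, 0.0, 2.0], device="xpu")
torch.sign(x)   # expected: tensor([-1., 0., 1.]), input dtype preserved
x.sign_()       # in-place variant, same semantics

# Complex inputs are rejected up front, mirroring the TORCH_CHECK above:
# torch.sign(torch.tensor(1 + 2j, device="xpu"))  # RuntimeError suggesting torch.sgn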
2 changes: 0 additions & 2 deletions src/ATen/native/xpu/XPUFallback.template
@@ -247,8 +247,6 @@ TORCH_LIBRARY_IMPL(aten, XPU, m) {
"_scaled_mm",
"segment_reduce",
"_segment_reduce_backward",
"signbit.out",
"sign.out",
"sinc.out",
"special_airy_ai.out",
"special_bessel_j0.out",
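Dropping "sign.out" and "signbit.out" from the fallback list means those overloads now reach the native XPU implementations instead of the CPU fallback. A hedged sketch of the out= variants, again assuming an "xpu" device:

import torch

x = torch.randn(4, device="xpu")
res = torch.empty_like(x)
torch.sign(x, out=res)        # should dispatch to XPUNativeFunctions::sign_out

flags = torch.empty(4, dtype=torch.bool, device="xpu")
torch.signbit(x, out=flags)   # should dispatch to XPUNativeFunctions::signbit_out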
29 changes: 29 additions & 0 deletions src/ATen/native/xpu/sycl/UnarySignKernels.cpp
@@ -54,6 +54,35 @@ void sign_kernel(TensorIteratorBase& iter) {
}
}

template <typename scalar_t>
struct SignbitIntFunctor {
  bool operator()(scalar_t a) const {
    return is_negative(a);
  }
};

template <typename scalar_t>
struct SignbitFunctor {
  bool operator()(scalar_t a) const {
    using opmath_t = at::opmath_type<scalar_t>;
    return std::signbit(opmath_t{a});
  }
};

void signbit_kernel(TensorIteratorBase& iter) {
  // NOTE: signbit does not always support integral arguments.
  if (at::isIntegralType(iter.input_dtype(), /*includeBool=*/false)) {
    AT_DISPATCH_INTEGRAL_TYPES(iter.input_dtype(), "signbit_xpu", [&]() {
      gpu_kernel(iter, SignbitIntFunctor<scalar_t>());
    });
  } else {
    AT_DISPATCH_FLOATING_TYPES_AND2(
        kBFloat16, ScalarType::Half, iter.input_dtype(), "signbit_xpu", [&]() {
          gpu_kernel(iter, SignbitFunctor<scalar_t>());
        });
  }
}

template <typename scalar_t>
struct LogicalNotFunctor {
  scalar_t operator()(scalar_t a) const {
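A minimal sketch of what the dispatch split above implies for results, assuming an "xpu" device (values in the comments are expectations): floating-point inputs go through std::signbit, so negative zero reports True, while integral inputs only test for values below zero, and boolean inputs are filled with False at the call site.

import torch

f = torch.tensor([-0.0, 0.0, -1.5, float("inf")], device="xpu")
torch.signbit(f)   # expected: tensor([ True, False,  True, False])

i = torch.tensor([-2, 0, 3], device="xpu")
torch.signbit(i)   # expected: tensor([ True, False, False])

b = torch.tensor([True, False], device="xpu")
torch.signbit(b)   # expected: tensor([False, False])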
2 changes: 2 additions & 0 deletions src/ATen/native/xpu/sycl/UnarySignKernels.h
@@ -12,4 +12,6 @@ void sgn_kernel(TensorIteratorBase& iter);

void sign_kernel(TensorIteratorBase& iter);

void signbit_kernel(TensorIteratorBase& iter);

} // namespace at::native::xpu
6 changes: 6 additions & 0 deletions test/xpu/run_test_with_skip.py
@@ -2240,6 +2240,12 @@ def launch_test(test_case, skip_list=None, exe_list=None):
"test_scaled_mm_vs_emulated_float16_xpu",
"test_scaled_mm_vs_emulated_float32_xpu",
"test_scaled_mm_vs_emulated_row_wise_bfloat16_xpu",

# https://github.com/intel/torch-xpu-ops/issues/676
# Mismatched elements: 9 / 1003002 (0.0%)
# Greatest absolute difference: 711.126220703125 at index (472, 999) (up to 0.1 allowed)
# Greatest relative difference: 2.7107455730438232 at index (472, 997) (up to 0.1 allowed)
"test_cublas_addmm_size_1000_xpu_float32",
)
res += launch_test("test_matmul_cuda_xpu.py", skip_list=skip_list)

2 changes: 2 additions & 0 deletions test/xpu/xpu_test_utils.py
@@ -190,6 +190,8 @@
"sigmoid",
"logsigmoid",
"sgn",
"sign",
"signbit",
"round",
"nn.functional.embedding_bag",
"bucketize",
5 changes: 5 additions & 0 deletions yaml/xpu_functions.yaml
@@ -470,6 +470,11 @@ supported:
- sigmoid
- sigmoid.out
- sigmoid_
- sign
- sign.out
- sign_
- signbit
- signbit.out
- sigmoid_backward.grad_input
- sigmoid_backward
- hardsigmoid.out