Add aten::lshift/rshift/prelu and their variants (#688)
- [x] prelu
- [x] __lshift__
- [x] __ilshift__
- [x] __rshift__
- [x] __irshift__
yucai-intel authored Aug 13, 2024
1 parent a818677 commit 00c9f3e
Showing 9 changed files with 236 additions and 4 deletions.
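
Everything added here is reachable from the regular PyTorch Python surface. A minimal smoke-test sketch, assuming a PyTorch build with the XPU backend and an available "xpu" device:

import torch
import torch.nn.functional as F

x = torch.randn(2, 3, device="xpu")
w = torch.full((3,), 0.25, device="xpu")
print(F.prelu(x, w))  # dispatches to aten::_prelu_kernel

i = torch.tensor([1, 2, 4], dtype=torch.int32, device="xpu")
print(i << 1)  # aten::__lshift__.Scalar
print(i >> 1)  # aten::__rshift__.Scalar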
34 changes: 33 additions & 1 deletion src/ATen/native/xpu/Activation.cpp
@@ -11,12 +11,13 @@
#include <ATen/native/xpu/sycl/ActivationLeakyReluKernels.h>
#include <ATen/native/xpu/sycl/ActivationLogSigmoidKernels.h>
#include <ATen/native/xpu/sycl/ActivationMishKernels.h>
#include <ATen/native/xpu/sycl/ActivationPreluKernels.h>
#include <ATen/native/xpu/sycl/ActivationSiluKernels.h>
#include <ATen/native/xpu/sycl/ActivationSoftplusKernels.h>
#include <ATen/native/xpu/sycl/ActivationSoftshrinkKernels.h>
#include <ATen/native/xpu/sycl/ActivationThresholdKernel.h>
namespace at {

Tensor XPUNativeFunctions::relu(const Tensor& self) {
  TORCH_CHECK(
      self.scalar_type() != at::kBool, "Boolean inputs not supported for relu");
@@ -633,6 +634,37 @@ Tensor& XPUNativeFunctions::softshrink_backward_out(
  return grad_input;
}

Tensor XPUNativeFunctions::_prelu_kernel(
    const Tensor& self,
    const Tensor& weight) {
  // Weight broadcasts over self and they have the same dtype
  auto result = at::empty_like(self);
  auto iter = TensorIteratorConfig()
                  .add_output(result)
                  .add_const_input(self)
                  .add_const_input(weight)
                  .build();
  native::xpu::prelu_kernel(iter);
  return result;
}

std::tuple<Tensor, Tensor> XPUNativeFunctions::_prelu_kernel_backward(
    const Tensor& grad_out,
    const Tensor& self,
    const Tensor& weight) {
  Tensor grad_self = at::empty({0}, self.options());
  Tensor grad_weight = at::empty({0}, weight.options());
  auto iter = TensorIteratorConfig()
                  .add_output(grad_self)
                  .add_output(grad_weight)
                  .add_const_input(self)
                  .add_const_input(weight)
                  .add_const_input(grad_out)
                  .build();
  native::xpu::prelu_backward_kernel(iter);
  return {grad_self, grad_weight};
}

std::tuple<Tensor&, Tensor&> XPUNativeFunctions::log_sigmoid_forward_out(
    const Tensor& input,
    Tensor& result,
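
The gradients the new _prelu_kernel_backward entry computes can be sanity-checked against autograd on CPU. A small verification sketch (the shapes and the 0.25 initializer are illustrative, not taken from this change):

import torch
import torch.nn.functional as F

x = torch.randn(4, 3, requires_grad=True)
w = torch.full((3,), 0.25, requires_grad=True)
F.prelu(x, w).backward(torch.ones(4, 3))

# Matching the PreluBackwardFunctor further down: dy/dx = 1 where x > 0,
# else w; dy/dw = 0 where x > 0, else x (reduced over the broadcast dims).
mask = (x > 0).to(x.dtype)
assert torch.allclose(x.grad, mask + (1 - mask) * w.detach())
assert torch.allclose(w.grad, ((1 - mask) * x.detach()).sum(dim=0))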
75 changes: 75 additions & 0 deletions src/ATen/native/xpu/BinaryOps.cpp
@@ -10,6 +10,7 @@
#include <ATen/native/xpu/sycl/BinaryLogicalOpsKernels.h>
#include <ATen/native/xpu/sycl/BinaryMiscBackwardOpsKernels.h>
#include <ATen/native/xpu/sycl/BinaryRemainderKernel.h>
#include <ATen/native/xpu/sycl/BinaryShiftOpsKernels.h>
#include <ATen/native/xpu/sycl/CopysignKernel.h>
#include <ATen/native/xpu/sycl/GcdLcmKernels.h>
#include <ATen/native/xpu/sycl/LogAddExpKernels.h>
@@ -301,6 +302,80 @@ Tensor& XPUNativeFunctions::bitwise_xor_out(
  return out;
}

Tensor XPUNativeFunctions::__lshift__(const Tensor& self, const Tensor& other) {
  Tensor result;
  auto iter = TensorIterator::binary_op(result, self, other);
  native::xpu::lshift_kernel(iter);
  return iter.output();
}

Tensor XPUNativeFunctions::__lshift__(const Tensor& self, const Scalar& other) {
  Tensor result;
  auto wrapper = native::wrapped_scalar_tensor(other);
  auto iter = TensorIterator::binary_op(result, self, wrapper);
  native::xpu::lshift_kernel(iter);
  return iter.output();
}

Tensor& XPUNativeFunctions::__ilshift__(Tensor& self, const Tensor& other) {
  auto iter = TensorIterator::binary_op(self, self, other);
  native::xpu::lshift_kernel(iter);
  return self;
}

Tensor& XPUNativeFunctions::__ilshift__(Tensor& self, const Scalar& other) {
  auto wrapper = native::wrapped_scalar_tensor(other);
  auto iter = TensorIterator::binary_op(self, self, wrapper);
  native::xpu::lshift_kernel(iter);
  return self;
}

Tensor& XPUNativeFunctions::bitwise_left_shift_out(
    const Tensor& self,
    const Tensor& other,
    Tensor& result) {
  auto iter = TensorIterator::borrowing_binary_op(result, self, other);
  native::xpu::lshift_kernel(iter);
  return result;
}

Tensor XPUNativeFunctions::__rshift__(const Tensor& self, const Tensor& other) {
  Tensor result;
  auto iter = TensorIterator::binary_op(result, self, other);
  native::xpu::rshift_kernel(iter);
  return iter.output();
}

Tensor XPUNativeFunctions::__rshift__(const Tensor& self, const Scalar& other) {
  Tensor result;
  auto wrapper = native::wrapped_scalar_tensor(other);
  auto iter = TensorIterator::binary_op(result, self, wrapper);
  native::xpu::rshift_kernel(iter);
  return iter.output();
}

Tensor& XPUNativeFunctions::__irshift__(Tensor& self, const Tensor& other) {
  auto iter = TensorIterator::binary_op(self, self, other);
  native::xpu::rshift_kernel(iter);
  return self;
}

Tensor& XPUNativeFunctions::__irshift__(Tensor& self, const Scalar& other) {
  auto wrapper = native::wrapped_scalar_tensor(other);
  auto iter = TensorIterator::binary_op(self, self, wrapper);
  native::xpu::rshift_kernel(iter);
  return self;
}

Tensor& XPUNativeFunctions::bitwise_right_shift_out(
    const Tensor& self,
    const Tensor& other,
    Tensor& result) {
  auto iter = TensorIterator::borrowing_binary_op(result, self, other);
  native::xpu::rshift_kernel(iter);
  return result;
}

Tensor XPUNativeFunctions::gcd(const Tensor& self, const Tensor& other) {
  Tensor out;
  auto iter = TensorIterator::borrowing_binary_op(out, self, other);
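
The blocks above register both the operator-style overloads and the out= variants; the full list appears in yaml/xpu_functions.yaml below. A sketch of how the Python-level operators presumably map onto these entry points, again assuming an XPU-enabled build:

import torch

a = torch.arange(4, dtype=torch.int32, device="xpu")
b = torch.ones(4, dtype=torch.int32, device="xpu")

a << b   # aten::__lshift__.Tensor
a << 2   # aten::__lshift__.Scalar
a <<= 2  # aten::__ilshift__.Scalar (in-place)

out = torch.empty_like(a)
torch.bitwise_left_shift(a, b, out=out)  # aten::bitwise_left_shift.Tensor_out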
3 changes: 0 additions & 3 deletions src/ATen/native/xpu/XPUFallback.template
@@ -236,11 +236,8 @@ TORCH_LIBRARY_IMPL(aten, XPU, m) {
"ormqr",
"_pdist_backward",
"_pdist_forward",
"_prelu_kernel",
"_prelu_kernel_backward",
"put_",
"rrelu_with_noise",
"__rshift__.Scalar",
"_scaled_dot_product_efficient_attention",
"_scaled_mm",
"segment_reduce",
43 changes: 43 additions & 0 deletions src/ATen/native/xpu/sycl/ActivationPreluKernels.cpp
@@ -0,0 +1,43 @@
#include <ATen/ATen.h>
#include <ATen/Dispatch.h>
#include <ATen/native/TensorIterator.h>

#include <ATen/native/xpu/sycl/Loops.h>

namespace at::native::xpu {

template <typename scalar_t>
struct PreluFunctor {
  scalar_t operator()(scalar_t input, scalar_t weight) const {
    return (input > 0) ? input : weight * input;
  }
};

template <typename scalar_t>
struct PreluBackwardFunctor {
  std::tuple<scalar_t, scalar_t> operator()(
      scalar_t input,
      scalar_t weight,
      scalar_t grad) const {
    auto mask = input > 0;
    auto grad_input = mask ? grad : weight * grad;
    auto grad_weight = mask ? scalar_t{0} : input * grad;
    return std::tuple<scalar_t, scalar_t>{grad_input, grad_weight};
  }
};

void prelu_kernel(TensorIterator& iter) {
  AT_DISPATCH_FLOATING_TYPES_AND2(
      kBFloat16, kHalf, iter.dtype(), "prelu_xpu", [&] {
        gpu_kernel(iter, PreluFunctor<scalar_t>());
      });
}

void prelu_backward_kernel(TensorIterator& iter) {
  AT_DISPATCH_FLOATING_TYPES_AND2(
      kBFloat16, kHalf, iter.dtype(), "prelu_backward_xpu", [&] {
        gpu_kernel_multiple_outputs(iter, PreluBackwardFunctor<scalar_t>());
      });
}

} // namespace at::native::xpu
11 changes: 11 additions & 0 deletions src/ATen/native/xpu/sycl/ActivationPreluKernels.h
@@ -0,0 +1,11 @@
#pragma once

#include <ATen/native/TensorIterator.h>

namespace at::native::xpu {

void prelu_kernel(TensorIterator& iter);

void prelu_backward_kernel(TensorIterator& iter);

} // namespace at::native::xpu
48 changes: 48 additions & 0 deletions src/ATen/native/xpu/sycl/BinaryShiftOpsKernels.cpp
@@ -0,0 +1,48 @@
#include <ATen/ATen.h>
#include <ATen/Dispatch.h>
#include <ATen/native/TensorIterator.h>

#include <ATen/native/xpu/sycl/Loops.h>

namespace at::native::xpu {

template <typename scalar_t>
struct LshiftFunctor {
  scalar_t operator()(scalar_t a, scalar_t b) const {
    constexpr scalar_t max_shift = sizeof(scalar_t) * CHAR_BIT;
    if ((static_cast<std::make_signed_t<scalar_t>>(b) < 0) ||
        (b >= max_shift)) {
      return 0;
    }
    return static_cast<std::make_unsigned_t<scalar_t>>(a) << b;
  }
};

void lshift_kernel(TensorIteratorBase& iter) {
  AT_DISPATCH_INTEGRAL_TYPES(iter.dtype(), "lshift_xpu", [&]() {
    gpu_kernel_with_scalars(iter, LshiftFunctor<scalar_t>());
  });
}

template <typename scalar_t>
struct RshiftFunctor {
  scalar_t operator()(scalar_t a, scalar_t b) const {
    // Right-shift amount that retains only the sign bit for signed types
    // and no bits for unsigned types.
    constexpr scalar_t max_shift =
        sizeof(scalar_t) * CHAR_BIT - std::is_signed_v<scalar_t>;
    if ((static_cast<std::make_signed_t<scalar_t>>(b) < 0) ||
        (b >= max_shift)) {
      return a >> max_shift;
    }
    return a >> b;
  }
};

void rshift_kernel(TensorIteratorBase& iter) {
  AT_DISPATCH_INTEGRAL_TYPES(iter.dtype(), "rshift_xpu", [&]() {
    gpu_kernel_with_scalars(iter, RshiftFunctor<scalar_t>());
  });
}

} // namespace at::native::xpu
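
The clamping in both functors makes out-of-range shift amounts well defined instead of leaving them as C++ undefined behavior. The expected results below are derived directly from the functor logic above; running them on device assumes an XPU build:

import torch

a = torch.tensor([1, -1], dtype=torch.int32, device="xpu")

# A shift amount that is negative or >= the bit width yields 0 on the left shift.
print(a << 32)  # -> [0, 0]

# The right shift clamps to max_shift (31 for int32), keeping only the
# sign bit: 0 for non-negative inputs, -1 for negative ones.
print(a >> 40)  # -> [0, -1]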
11 changes: 11 additions & 0 deletions src/ATen/native/xpu/sycl/BinaryShiftOpsKernels.h
@@ -0,0 +1,11 @@
#pragma once

#include <ATen/native/TensorIterator.h>

namespace at::native::xpu {

void lshift_kernel(TensorIteratorBase& iter);

void rshift_kernel(TensorIteratorBase& iter);

} // namespace at::native::xpu
3 changes: 3 additions & 0 deletions test/xpu/xpu_test_utils.py
@@ -54,6 +54,8 @@
"bitwise_not",
"bitwise_or",
"bitwise_xor",
"bitwise_left_shift",
"bitwise_right_shift",
"addcmul",
"addcdiv",
"clamp",
@@ -112,6 +114,7 @@
"nn.functional.glu",
"nn.functional.pad",
"nn.functional.leaky_relu",
"nn.functional.prelu",
"nn.functional.threshold",
"nn.functional.silu",
"nn.functional.hardsigmoid",
12 changes: 12 additions & 0 deletions yaml/xpu_functions.yaml
@@ -118,6 +118,18 @@ supported:
- relu
- relu_
- relu.out
- _prelu_kernel
- _prelu_kernel_backward
- __lshift__.Scalar
- __lshift__.Tensor
- __ilshift__.Scalar
- __ilshift__.Tensor
- __rshift__.Scalar
- __rshift__.Tensor
- __irshift__.Scalar
- __irshift__.Tensor
- bitwise_left_shift.Tensor_out
- bitwise_right_shift.Tensor_out
- threshold
- threshold_
- threshold.out
