From 00c9f3ee694a7a3864dce77448ea88132b02ca85 Mon Sep 17 00:00:00 2001 From: yucai-intel <108388355+yucai-intel@users.noreply.github.com> Date: Tue, 13 Aug 2024 17:44:04 +0800 Subject: [PATCH] Add aten::lshift/rshift/prelu and their variants (#688) - [x] prelu - [x] __lshift__ - [x] __ilshift__ - [x] __rshift__ - [x] __irshift__ --- src/ATen/native/xpu/Activation.cpp | 34 ++++++++- src/ATen/native/xpu/BinaryOps.cpp | 75 +++++++++++++++++++ src/ATen/native/xpu/XPUFallback.template | 3 - .../xpu/sycl/ActivationPreluKernels.cpp | 43 +++++++++++ .../native/xpu/sycl/ActivationPreluKernels.h | 11 +++ .../native/xpu/sycl/BinaryShiftOpsKernels.cpp | 48 ++++++++++++ .../native/xpu/sycl/BinaryShiftOpsKernels.h | 11 +++ test/xpu/xpu_test_utils.py | 3 + yaml/xpu_functions.yaml | 12 +++ 9 files changed, 236 insertions(+), 4 deletions(-) create mode 100644 src/ATen/native/xpu/sycl/ActivationPreluKernels.cpp create mode 100644 src/ATen/native/xpu/sycl/ActivationPreluKernels.h create mode 100644 src/ATen/native/xpu/sycl/BinaryShiftOpsKernels.cpp create mode 100644 src/ATen/native/xpu/sycl/BinaryShiftOpsKernels.h diff --git a/src/ATen/native/xpu/Activation.cpp b/src/ATen/native/xpu/Activation.cpp index fcbbe99db..38aa44dc6 100644 --- a/src/ATen/native/xpu/Activation.cpp +++ b/src/ATen/native/xpu/Activation.cpp @@ -11,12 +11,13 @@ #include #include #include +#include #include #include #include #include -namespace at { +namespace at { Tensor XPUNativeFunctions::relu(const Tensor& self) { TORCH_CHECK( self.scalar_type() != at::kBool, "Boolean inputs not supported for relu"); @@ -633,6 +634,37 @@ Tensor& XPUNativeFunctions::softshrink_backward_out( return grad_input; } +Tensor XPUNativeFunctions::_prelu_kernel( + const Tensor& self, + const Tensor& weight) { + // Weight broadcasts over self and they have the same dtype + auto result = at::empty_like(self); + auto iter = TensorIteratorConfig() + .add_output(result) + .add_const_input(self) + .add_const_input(weight) + .build(); + 
native::xpu::prelu_kernel(iter); + return result; +} + +std::tuple XPUNativeFunctions::_prelu_kernel_backward( + const Tensor& grad_out, + const Tensor& self, + const Tensor& weight) { + Tensor grad_self = at::empty({0}, self.options()); + Tensor grad_weight = at::empty({0}, weight.options()); + auto iter = TensorIteratorConfig() + .add_output(grad_self) + .add_output(grad_weight) + .add_const_input(self) + .add_const_input(weight) + .add_const_input(grad_out) + .build(); + native::xpu::prelu_backward_kernel(iter); + return {grad_self, grad_weight}; +} + std::tuple XPUNativeFunctions::log_sigmoid_forward_out( const Tensor& input, Tensor& result, diff --git a/src/ATen/native/xpu/BinaryOps.cpp b/src/ATen/native/xpu/BinaryOps.cpp index 14bc49990..e17309841 100644 --- a/src/ATen/native/xpu/BinaryOps.cpp +++ b/src/ATen/native/xpu/BinaryOps.cpp @@ -10,6 +10,7 @@ #include #include #include +#include #include #include #include @@ -301,6 +302,80 @@ Tensor& XPUNativeFunctions::bitwise_xor_out( return out; } +Tensor XPUNativeFunctions::__lshift__(const Tensor& self, const Tensor& other) { + Tensor result; + auto iter = TensorIterator::binary_op(result, self, other); + native::xpu::lshift_kernel(iter); + return iter.output(); +} + +Tensor XPUNativeFunctions::__lshift__(const Tensor& self, const Scalar& other) { + Tensor result; + auto wrapper = native::wrapped_scalar_tensor(other); + auto iter = TensorIterator::binary_op(result, self, wrapper); + native::xpu::lshift_kernel(iter); + return iter.output(); +} + +Tensor& XPUNativeFunctions::__ilshift__(Tensor& self, const Tensor& other) { + auto iter = TensorIterator::binary_op(self, self, other); + native::xpu::lshift_kernel(iter); + return self; +} + +Tensor& XPUNativeFunctions::__ilshift__(Tensor& self, const Scalar& other) { + auto wrapper = native::wrapped_scalar_tensor(other); + auto iter = TensorIterator::binary_op(self, self, wrapper); + native::xpu::lshift_kernel(iter); + return self; +} + +Tensor& 
XPUNativeFunctions::bitwise_left_shift_out( + const Tensor& self, + const Tensor& other, + Tensor& result) { + auto iter = TensorIterator::borrowing_binary_op(result, self, other); + native::xpu::lshift_kernel(iter); + return result; +} + +Tensor XPUNativeFunctions::__rshift__(const Tensor& self, const Tensor& other) { + Tensor result; + auto iter = TensorIterator::binary_op(result, self, other); + native::xpu::rshift_kernel(iter); + return iter.output(); +} + +Tensor XPUNativeFunctions::__rshift__(const Tensor& self, const Scalar& other) { + Tensor result; + auto wrapper = native::wrapped_scalar_tensor(other); + auto iter = TensorIterator::binary_op(result, self, wrapper); + native::xpu::rshift_kernel(iter); + return iter.output(); +} + +Tensor& XPUNativeFunctions::__irshift__(Tensor& self, const Tensor& other) { + auto iter = TensorIterator::binary_op(self, self, other); + native::xpu::rshift_kernel(iter); + return self; +} + +Tensor& XPUNativeFunctions::__irshift__(Tensor& self, const Scalar& other) { + auto wrapper = native::wrapped_scalar_tensor(other); + auto iter = TensorIterator::binary_op(self, self, wrapper); + native::xpu::rshift_kernel(iter); + return self; +} + +Tensor& XPUNativeFunctions::bitwise_right_shift_out( + const Tensor& self, + const Tensor& other, + Tensor& result) { + auto iter = TensorIterator::borrowing_binary_op(result, self, other); + native::xpu::rshift_kernel(iter); + return result; +} + Tensor XPUNativeFunctions::gcd(const Tensor& self, const Tensor& other) { Tensor out; auto iter = TensorIterator::borrowing_binary_op(out, self, other); diff --git a/src/ATen/native/xpu/XPUFallback.template b/src/ATen/native/xpu/XPUFallback.template index 5b2d6e5ff..75cbd3c1c 100644 --- a/src/ATen/native/xpu/XPUFallback.template +++ b/src/ATen/native/xpu/XPUFallback.template @@ -236,11 +236,8 @@ TORCH_LIBRARY_IMPL(aten, XPU, m) { "ormqr", "_pdist_backward", "_pdist_forward", - "_prelu_kernel", - "_prelu_kernel_backward", "put_", "rrelu_with_noise", - 
"__rshift__.Scalar", "_scaled_dot_product_efficient_attention", "_scaled_mm", "segment_reduce", diff --git a/src/ATen/native/xpu/sycl/ActivationPreluKernels.cpp b/src/ATen/native/xpu/sycl/ActivationPreluKernels.cpp new file mode 100644 index 000000000..8a2c7012f --- /dev/null +++ b/src/ATen/native/xpu/sycl/ActivationPreluKernels.cpp @@ -0,0 +1,43 @@ +#include +#include +#include + +#include + +namespace at::native::xpu { + +template +struct PreluFunctor { + scalar_t operator()(scalar_t input, scalar_t weight) const { + return (input > 0) ? input : weight * input; + } +}; + +template +struct PreluBackwardFunctor { + std::tuple operator()( + scalar_t input, + scalar_t weight, + scalar_t grad) const { + auto mask = input > 0; + auto grad_input = mask ? grad : weight * grad; + auto grad_weight = mask ? scalar_t{0} : input * grad; + return std::tuple{grad_input, grad_weight}; + } +}; + +void prelu_kernel(TensorIterator& iter) { + AT_DISPATCH_FLOATING_TYPES_AND2( + kBFloat16, kHalf, iter.dtype(), "prelu_xpu", [&] { + gpu_kernel(iter, PreluFunctor()); + }); +} + +void prelu_backward_kernel(TensorIterator& iter) { + AT_DISPATCH_FLOATING_TYPES_AND2( + kBFloat16, kHalf, iter.dtype(), "prelu_backward_xpu", [&] { + gpu_kernel_multiple_outputs(iter, PreluBackwardFunctor()); + }); +} + +} // namespace at::native::xpu \ No newline at end of file diff --git a/src/ATen/native/xpu/sycl/ActivationPreluKernels.h b/src/ATen/native/xpu/sycl/ActivationPreluKernels.h new file mode 100644 index 000000000..23f60c242 --- /dev/null +++ b/src/ATen/native/xpu/sycl/ActivationPreluKernels.h @@ -0,0 +1,11 @@ +#pragma once + +#include + +namespace at::native::xpu { + +void prelu_kernel(TensorIterator& iter); + +void prelu_backward_kernel(TensorIterator& iter); + +} // namespace at::native::xpu diff --git a/src/ATen/native/xpu/sycl/BinaryShiftOpsKernels.cpp b/src/ATen/native/xpu/sycl/BinaryShiftOpsKernels.cpp new file mode 100644 index 000000000..64adba17e --- /dev/null +++ 
b/src/ATen/native/xpu/sycl/BinaryShiftOpsKernels.cpp @@ -0,0 +1,48 @@ +#include +#include +#include + +#include + +namespace at::native::xpu { + +template +struct LshiftFunctor { + scalar_t operator()(scalar_t a, scalar_t b) const { + constexpr scalar_t max_shift = sizeof(scalar_t) * CHAR_BIT; + if ((static_cast>(b) < 0) || + (b >= max_shift)) { + return 0; + } + return static_cast>(a) << b; + } +}; + +void lshift_kernel(TensorIteratorBase& iter) { + AT_DISPATCH_INTEGRAL_TYPES(iter.dtype(), "lshift_xpu", [&]() { + gpu_kernel_with_scalars(iter, LshiftFunctor()); + }); +} + +template +struct RshiftFunctor { + scalar_t operator()(scalar_t a, scalar_t b) const { + // right shift value to retain sign bit for signed and no bits for + // unsigned + constexpr scalar_t max_shift = + sizeof(scalar_t) * CHAR_BIT - std::is_signed_v; + if ((static_cast>(b) < 0) || + (b >= max_shift)) { + return a >> max_shift; + } + return a >> b; + } +}; + +void rshift_kernel(TensorIteratorBase& iter) { + AT_DISPATCH_INTEGRAL_TYPES(iter.dtype(), "rshift_xpu", [&]() { + gpu_kernel_with_scalars(iter, RshiftFunctor()); + }); +} + +} // namespace at::native::xpu \ No newline at end of file diff --git a/src/ATen/native/xpu/sycl/BinaryShiftOpsKernels.h b/src/ATen/native/xpu/sycl/BinaryShiftOpsKernels.h new file mode 100644 index 000000000..7e661d919 --- /dev/null +++ b/src/ATen/native/xpu/sycl/BinaryShiftOpsKernels.h @@ -0,0 +1,11 @@ +#pragma once + +#include + +namespace at::native::xpu { + +void lshift_kernel(TensorIteratorBase& iter); + +void rshift_kernel(TensorIteratorBase& iter); + +} // namespace at::native::xpu diff --git a/test/xpu/xpu_test_utils.py b/test/xpu/xpu_test_utils.py index c3a268024..3bc809743 100644 --- a/test/xpu/xpu_test_utils.py +++ b/test/xpu/xpu_test_utils.py @@ -54,6 +54,8 @@ "bitwise_not", "bitwise_or", "bitwise_xor", + "bitwise_left_shift", + "bitwise_right_shift", "addcmul", "addcdiv", "clamp", @@ -112,6 +114,7 @@ "nn.functional.glu", "nn.functional.pad", 
"nn.functional.leaky_relu", + "nn.functional.prelu", "nn.functional.threshold", "nn.functional.silu", "nn.functional.hardsigmoid", diff --git a/yaml/xpu_functions.yaml b/yaml/xpu_functions.yaml index 263c522d0..8d92fa60e 100644 --- a/yaml/xpu_functions.yaml +++ b/yaml/xpu_functions.yaml @@ -118,6 +118,18 @@ supported: - relu - relu_ - relu.out + - _prelu_kernel + - _prelu_kernel_backward + - __lshift__.Scalar + - __lshift__.Tensor + - __ilshift__.Scalar + - __ilshift__.Tensor + - __rshift__.Scalar + - __rshift__.Tensor + - __irshift__.Scalar + - __irshift__.Tensor + - bitwise_left_shift.Tensor_out + - bitwise_right_shift.Tensor_out - threshold - threshold_ - threshold.out