Commit 00c9f3e

Add aten::lshift/rshift/prelu and their variants (#688)
- [x] prelu
- [x] __lshift__
- [x] __ilshift__
- [x] __rshift__
- [x] __irshift__
1 parent a818677 commit 00c9f3e
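
Taken together, the commit lets these operators run natively on XPU tensors instead of falling back to CPU. A minimal usage sketch (assumes a PyTorch build with the XPU backend available; the device string "xpu" and the concrete shapes are illustrative only):

import torch
import torch.nn.functional as F

x = torch.randn(4, 3, device="xpu")
w = torch.full((3,), 0.25, device="xpu")   # one PReLU weight per channel
y = F.prelu(x, w)                          # should route to _prelu_kernel

a = torch.tensor([1, 2, 4], dtype=torch.int32, device="xpu")
b = a << 2      # __lshift__.Scalar
a >>= 1         # __irshift__.Scalar, in place
c = torch.bitwise_left_shift(a, torch.tensor([1, 2, 3], dtype=torch.int32, device="xpu"))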

File tree

9 files changed, +236 -4 lines changed

src/ATen/native/xpu/Activation.cpp

Lines changed: 33 additions & 1 deletion

@@ -11,12 +11,13 @@
 #include <ATen/native/xpu/sycl/ActivationLeakyReluKernels.h>
 #include <ATen/native/xpu/sycl/ActivationLogSigmoidKernels.h>
 #include <ATen/native/xpu/sycl/ActivationMishKernels.h>
+#include <ATen/native/xpu/sycl/ActivationPreluKernels.h>
 #include <ATen/native/xpu/sycl/ActivationSiluKernels.h>
 #include <ATen/native/xpu/sycl/ActivationSoftplusKernels.h>
 #include <ATen/native/xpu/sycl/ActivationSoftshrinkKernels.h>
 #include <ATen/native/xpu/sycl/ActivationThresholdKernel.h>
-namespace at {
 
+namespace at {
 Tensor XPUNativeFunctions::relu(const Tensor& self) {
   TORCH_CHECK(
       self.scalar_type() != at::kBool, "Boolean inputs not supported for relu");
@@ -633,6 +634,37 @@ Tensor& XPUNativeFunctions::softshrink_backward_out(
   return grad_input;
 }
 
+Tensor XPUNativeFunctions::_prelu_kernel(
+    const Tensor& self,
+    const Tensor& weight) {
+  // Weight broadcasts over self and they have the same dtype
+  auto result = at::empty_like(self);
+  auto iter = TensorIteratorConfig()
+                  .add_output(result)
+                  .add_const_input(self)
+                  .add_const_input(weight)
+                  .build();
+  native::xpu::prelu_kernel(iter);
+  return result;
+}
+
+std::tuple<Tensor, Tensor> XPUNativeFunctions::_prelu_kernel_backward(
+    const Tensor& grad_out,
+    const Tensor& self,
+    const Tensor& weight) {
+  Tensor grad_self = at::empty({0}, self.options());
+  Tensor grad_weight = at::empty({0}, weight.options());
+  auto iter = TensorIteratorConfig()
+                  .add_output(grad_self)
+                  .add_output(grad_weight)
+                  .add_const_input(self)
+                  .add_const_input(weight)
+                  .add_const_input(grad_out)
+                  .build();
+  native::xpu::prelu_backward_kernel(iter);
+  return {grad_self, grad_weight};
+}
+
 std::tuple<Tensor&, Tensor&> XPUNativeFunctions::log_sigmoid_forward_out(
     const Tensor& input,
     Tensor& result,
src/ATen/native/xpu/BinaryOps.cpp

Lines changed: 75 additions & 0 deletions

@@ -10,6 +10,7 @@
 #include <ATen/native/xpu/sycl/BinaryLogicalOpsKernels.h>
 #include <ATen/native/xpu/sycl/BinaryMiscBackwardOpsKernels.h>
 #include <ATen/native/xpu/sycl/BinaryRemainderKernel.h>
+#include <ATen/native/xpu/sycl/BinaryShiftOpsKernels.h>
 #include <ATen/native/xpu/sycl/CopysignKernel.h>
 #include <ATen/native/xpu/sycl/GcdLcmKernels.h>
 #include <ATen/native/xpu/sycl/LogAddExpKernels.h>
@@ -301,6 +302,80 @@ Tensor& XPUNativeFunctions::bitwise_xor_out(
   return out;
 }
 
+Tensor XPUNativeFunctions::__lshift__(const Tensor& self, const Tensor& other) {
+  Tensor result;
+  auto iter = TensorIterator::binary_op(result, self, other);
+  native::xpu::lshift_kernel(iter);
+  return iter.output();
+}
+
+Tensor XPUNativeFunctions::__lshift__(const Tensor& self, const Scalar& other) {
+  Tensor result;
+  auto wrapper = native::wrapped_scalar_tensor(other);
+  auto iter = TensorIterator::binary_op(result, self, wrapper);
+  native::xpu::lshift_kernel(iter);
+  return iter.output();
+}
+
+Tensor& XPUNativeFunctions::__ilshift__(Tensor& self, const Tensor& other) {
+  auto iter = TensorIterator::binary_op(self, self, other);
+  native::xpu::lshift_kernel(iter);
+  return self;
+}
+
+Tensor& XPUNativeFunctions::__ilshift__(Tensor& self, const Scalar& other) {
+  auto wrapper = native::wrapped_scalar_tensor(other);
+  auto iter = TensorIterator::binary_op(self, self, wrapper);
+  native::xpu::lshift_kernel(iter);
+  return self;
+}
+
+Tensor& XPUNativeFunctions::bitwise_left_shift_out(
+    const Tensor& self,
+    const Tensor& other,
+    Tensor& result) {
+  auto iter = TensorIterator::borrowing_binary_op(result, self, other);
+  native::xpu::lshift_kernel(iter);
+  return result;
+}
+
+Tensor XPUNativeFunctions::__rshift__(const Tensor& self, const Tensor& other) {
+  Tensor result;
+  auto iter = TensorIterator::binary_op(result, self, other);
+  native::xpu::rshift_kernel(iter);
+  return iter.output();
+}
+
+Tensor XPUNativeFunctions::__rshift__(const Tensor& self, const Scalar& other) {
+  Tensor result;
+  auto wrapper = native::wrapped_scalar_tensor(other);
+  auto iter = TensorIterator::binary_op(result, self, wrapper);
+  native::xpu::rshift_kernel(iter);
+  return iter.output();
+}
+
+Tensor& XPUNativeFunctions::__irshift__(Tensor& self, const Tensor& other) {
+  auto iter = TensorIterator::binary_op(self, self, other);
+  native::xpu::rshift_kernel(iter);
+  return self;
+}
+
+Tensor& XPUNativeFunctions::__irshift__(Tensor& self, const Scalar& other) {
+  auto wrapper = native::wrapped_scalar_tensor(other);
+  auto iter = TensorIterator::binary_op(self, self, wrapper);
+  native::xpu::rshift_kernel(iter);
+  return self;
+}
+
+Tensor& XPUNativeFunctions::bitwise_right_shift_out(
+    const Tensor& self,
+    const Tensor& other,
+    Tensor& result) {
+  auto iter = TensorIterator::borrowing_binary_op(result, self, other);
+  native::xpu::rshift_kernel(iter);
+  return result;
+}
+
 Tensor XPUNativeFunctions::gcd(const Tensor& self, const Tensor& other) {
   Tensor out;
   auto iter = TensorIterator::borrowing_binary_op(out, self, other);
src/ATen/native/xpu/XPUFallback.template

Lines changed: 0 additions & 3 deletions

@@ -236,11 +236,8 @@ TORCH_LIBRARY_IMPL(aten, XPU, m) {
       "ormqr",
       "_pdist_backward",
       "_pdist_forward",
-      "_prelu_kernel",
-      "_prelu_kernel_backward",
       "put_",
       "rrelu_with_noise",
-      "__rshift__.Scalar",
       "_scaled_dot_product_efficient_attention",
       "_scaled_mm",
       "segment_reduce",
src/ATen/native/xpu/sycl/ActivationPreluKernels.cpp

Lines changed: 43 additions & 0 deletions

@@ -0,0 +1,43 @@
+#include <ATen/ATen.h>
+#include <ATen/Dispatch.h>
+#include <ATen/native/TensorIterator.h>
+
+#include <ATen/native/xpu/sycl/Loops.h>
+
+namespace at::native::xpu {
+
+template <typename scalar_t>
+struct PreluFunctor {
+  scalar_t operator()(scalar_t input, scalar_t weight) const {
+    return (input > 0) ? input : weight * input;
+  }
+};
+
+template <typename scalar_t>
+struct PreluBackwardFunctor {
+  std::tuple<scalar_t, scalar_t> operator()(
+      scalar_t input,
+      scalar_t weight,
+      scalar_t grad) const {
+    auto mask = input > 0;
+    auto grad_input = mask ? grad : weight * grad;
+    auto grad_weight = mask ? scalar_t{0} : input * grad;
+    return std::tuple<scalar_t, scalar_t>{grad_input, grad_weight};
+  }
+};
+
+void prelu_kernel(TensorIterator& iter) {
+  AT_DISPATCH_FLOATING_TYPES_AND2(
+      kBFloat16, kHalf, iter.dtype(), "prelu_xpu", [&] {
+        gpu_kernel(iter, PreluFunctor<scalar_t>());
+      });
+}
+
+void prelu_backward_kernel(TensorIterator& iter) {
+  AT_DISPATCH_FLOATING_TYPES_AND2(
+      kBFloat16, kHalf, iter.dtype(), "prelu_backward_xpu", [&] {
+        gpu_kernel_multiple_outputs(iter, PreluBackwardFunctor<scalar_t>());
+      });
+}
+
+} // namespace at::native::xpu
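
A pure-Python sketch of the per-element backward rule that PreluBackwardFunctor encodes may help when auditing the kernel; note that the grad_weight it produces is per element, before any later reduction back to the weight's shape (the sketch is illustrative and not part of the commit):

import torch

def prelu_backward_reference(grad_out, x, w):
    # mirrors PreluBackwardFunctor element-wise:
    #   grad_input  = grad_out  where x > 0, else w * grad_out
    #   grad_weight = 0         where x > 0, else x * grad_out
    mask = x > 0
    grad_input = torch.where(mask, grad_out, w * grad_out)
    grad_weight = torch.where(mask, torch.zeros_like(x), x * grad_out)
    return grad_input, grad_weight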
src/ATen/native/xpu/sycl/ActivationPreluKernels.h

Lines changed: 11 additions & 0 deletions

@@ -0,0 +1,11 @@
+#pragma once
+
+#include <ATen/native/TensorIterator.h>
+
+namespace at::native::xpu {
+
+void prelu_kernel(TensorIterator& iter);
+
+void prelu_backward_kernel(TensorIterator& iter);
+
+} // namespace at::native::xpu
src/ATen/native/xpu/sycl/BinaryShiftOpsKernels.cpp

Lines changed: 48 additions & 0 deletions

@@ -0,0 +1,48 @@
+#include <ATen/ATen.h>
+#include <ATen/Dispatch.h>
+#include <ATen/native/TensorIterator.h>
+
+#include <ATen/native/xpu/sycl/Loops.h>
+
+namespace at::native::xpu {
+
+template <typename scalar_t>
+struct LshiftFunctor {
+  scalar_t operator()(scalar_t a, scalar_t b) const {
+    constexpr scalar_t max_shift = sizeof(scalar_t) * CHAR_BIT;
+    if ((static_cast<std::make_signed_t<scalar_t>>(b) < 0) ||
+        (b >= max_shift)) {
+      return 0;
+    }
+    return static_cast<std::make_unsigned_t<scalar_t>>(a) << b;
+  }
+};
+
+void lshift_kernel(TensorIteratorBase& iter) {
+  AT_DISPATCH_INTEGRAL_TYPES(iter.dtype(), "lshift_xpu", [&]() {
+    gpu_kernel_with_scalars(iter, LshiftFunctor<scalar_t>());
+  });
+}
+
+template <typename scalar_t>
+struct RshiftFunctor {
+  scalar_t operator()(scalar_t a, scalar_t b) const {
+    // right shift value to retain sign bit for signed and no bits for
+    // unsigned
+    constexpr scalar_t max_shift =
+        sizeof(scalar_t) * CHAR_BIT - std::is_signed_v<scalar_t>;
+    if ((static_cast<std::make_signed_t<scalar_t>>(b) < 0) ||
+        (b >= max_shift)) {
+      return a >> max_shift;
+    }
+    return a >> b;
+  }
+};
+
+void rshift_kernel(TensorIteratorBase& iter) {
+  AT_DISPATCH_INTEGRAL_TYPES(iter.dtype(), "rshift_xpu", [&]() {
+    gpu_kernel_with_scalars(iter, RshiftFunctor<scalar_t>());
+  });
+}
+
+} // namespace at::native::xpu
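
Both functors clamp out-of-range shift amounts instead of leaving them undefined: a negative or too-large left shift yields 0, and a too-large right shift saturates to the sign bit. A small pure-Python model of that behaviour for a signed 8-bit element type (the helper names are made up for illustration):

def _to_signed(v: int, bits: int) -> int:
    # reinterpret the low `bits` bits as a two's-complement signed value
    v &= (1 << bits) - 1
    return v - (1 << bits) if v & (1 << (bits - 1)) else v

def lshift_ref(a: int, b: int, bits: int = 8) -> int:
    # mirrors LshiftFunctor: negative or too-large shift amounts produce 0
    if b < 0 or b >= bits:
        return 0
    return _to_signed(a << b, bits)

def rshift_ref(a: int, b: int, bits: int = 8) -> int:
    # mirrors RshiftFunctor for signed types: clamp the shift to bits - 1
    max_shift = bits - 1
    if b < 0 or b >= max_shift:
        return a >> max_shift    # -1 for negative a, 0 otherwise
    return a >> b

assert lshift_ref(3, 10) == 0       # shift amount >= bit width
assert rshift_ref(-16, 100) == -1   # saturates to the sign bit
assert rshift_ref(16, 100) == 0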
src/ATen/native/xpu/sycl/BinaryShiftOpsKernels.h

Lines changed: 11 additions & 0 deletions

@@ -0,0 +1,11 @@
+#pragma once
+
+#include <ATen/native/TensorIterator.h>
+
+namespace at::native::xpu {
+
+void lshift_kernel(TensorIteratorBase& iter);
+
+void rshift_kernel(TensorIteratorBase& iter);
+
+} // namespace at::native::xpu

test/xpu/xpu_test_utils.py

Lines changed: 3 additions & 0 deletions

@@ -54,6 +54,8 @@
     "bitwise_not",
     "bitwise_or",
     "bitwise_xor",
+    "bitwise_left_shift",
+    "bitwise_right_shift",
     "addcmul",
     "addcdiv",
     "clamp",
@@ -112,6 +114,7 @@
     "nn.functional.glu",
     "nn.functional.pad",
     "nn.functional.leaky_relu",
+    "nn.functional.prelu",
    "nn.functional.threshold",
     "nn.functional.silu",
     "nn.functional.hardsigmoid",

yaml/xpu_functions.yaml

Lines changed: 12 additions & 0 deletions

@@ -118,6 +118,18 @@ supported:
 - relu
 - relu_
 - relu.out
+- _prelu_kernel
+- _prelu_kernel_backward
+- __lshift__.Scalar
+- __lshift__.Tensor
+- __ilshift__.Scalar
+- __ilshift__.Tensor
+- __rshift__.Scalar
+- __rshift__.Tensor
+- __irshift__.Scalar
+- __irshift__.Tensor
+- bitwise_left_shift.Tensor_out
+- bitwise_right_shift.Tensor_out
 - threshold
 - threshold_
 - threshold.out
