forked from pytorch/pytorch
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathAbsKernel.cu
40 lines (32 loc) · 1.12 KB
/
AbsKernel.cu
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
#include <ATen/native/UnaryOps.h>
#include <ATen/native/cuda/Loops.cuh>
#include <ATen/Dispatch.h>
#include <ATen/native/DispatchStub.h>
#include <ATen/native/TensorIterator.h>
namespace at { namespace native {
// We manually overload abs because std::abs does not work with thrust::complex types and ROCm.
// Generic absolute value for real scalar types (integral and floating point).
// Defers to the global ::abs overload set, which is callable from both host
// and device code; this also works where std::abs is unavailable on device
// (thrust::complex / ROCm — see the complex overload below).
template<typename scalar_t>
__host__ __device__ static inline scalar_t abs_wrapper(scalar_t x) {
  return ::abs(x);
}
// Overload for complex inputs. std::abs here resolves to c10's overload for
// c10::complex<T>; the (real-valued) magnitude it yields is converted back to
// c10::complex<T> by the return type. NOTE(review): presumably the imaginary
// part of the result is zero after that conversion — confirm against
// c10::complex's converting constructor.
template<typename T>
__host__ __device__ static inline c10::complex<T> abs_wrapper(c10::complex<T> z) {
  return std::abs(z);
}
// uint8_t is unsigned, so every value is its own absolute value: identity.
__host__ __device__ static inline uint8_t abs_wrapper(uint8_t v) {
return v;
}
// abs of a bool is the bool itself (false -> false, true -> true): identity.
__host__ __device__ static inline bool abs_wrapper(bool v) {
return v;
}
// Launches the elementwise absolute-value kernel for the iterator's dtype.
// Dispatches over all standard types plus Half, BFloat16, Bool, and the
// complex types; BFloat16 is skipped on non-ROCm builds via
// AT_SKIP_BFLOAT16_IF_NOT_ROCM.
void abs_kernel_cuda(TensorIterator& iter) {
  AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND3(ScalarType::Half, ScalarType::BFloat16, ScalarType::Bool, iter.dtype(), "abs_cuda", [&]() {
    AT_SKIP_BFLOAT16_IF_NOT_ROCM(scalar_t, "abs_cuda", [&] {
      // One value in, one value out; gpu_kernel drives the iteration and
      // picks the appropriate abs_wrapper overload per scalar_t.
      gpu_kernel(iter, [] GPU_LAMBDA(scalar_t x) -> scalar_t {
        return abs_wrapper(x);
      });
    });
  });
}
// Register this CUDA implementation as the backend for the abs dispatch stub.
REGISTER_DISPATCH(abs_stub, &abs_kernel_cuda);
}} // namespace at::native