forked from pytorch/pytorch
-
Notifications
You must be signed in to change notification settings - Fork 0
/
TensorCompare.cu
41 lines (38 loc) · 1.1 KB
/
TensorCompare.cu
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
#include "ATen/NativeFunctions.h"
#include "ATen/Dispatch.h"
#include "ATen/cuda/CUDAApplyUtils.cuh"
namespace {

// Element-wise selection kernel launcher.
//
// For each element position, stores into `ret` the element of `self` when
// the corresponding element of `condition` is non-zero, and the element of
// `other` otherwise.
//
// Template parameter:
//   scalar_t - element type used to access `ret`, `self` and `other`.
// The elements of `condition` are accessed as uint8_t.
//
// NOTE(review): nothing in this function validates shapes or dtypes; it
// presumably relies on the caller and on CUDA_tensor_apply4 for that —
// confirm.
template <typename scalar_t>
void where_cuda(
    at::Tensor& ret,
    const at::Tensor& condition,
    const at::Tensor& self,
    const at::Tensor& other) {
  // The "CUDA_" prefix is redundant inside at::cuda, but it mirrors the CPU
  // counterpart (CPU_tensor_apply4), which has no CPU namespace or
  // directory of its own to disambiguate it.
  at::cuda::CUDA_tensor_apply4<scalar_t, uint8_t, scalar_t, scalar_t>(
      ret,
      condition,
      self,
      other,
      [] __device__(
          scalar_t & out_elem,
          const uint8_t& take_self,
          const scalar_t& when_true,
          const scalar_t& when_false) {
        // Any non-zero condition byte selects the `self` element.
        if (take_self) {
          out_elem = when_true;
        } else {
          out_elem = when_false;
        }
      });
}

} // namespace
namespace at { namespace native {

// CUDA entry point for the element-wise `where` selection.
//
// Creates a new tensor with the sizes and options of `self`, fills it by
// picking between `self` and `other` according to `condition`, and returns
// it. The element type used for the selection is chosen by dispatching on
// the type of the freshly allocated result (all types plus half, per the
// dispatch macro used).
//
// NOTE(review): no shape or dtype checks are performed here; presumably the
// caller has already made `condition`, `self` and `other` the same shape and
// ensured `condition` holds byte values — confirm against the caller.
Tensor _s_where_cuda(
    const Tensor& condition,
    const Tensor& self,
    const Tensor& other) {
  // The result takes its shape and options from `self`.
  const auto result_sizes = self.sizes();
  const auto result_options = self.options();
  auto result = at::empty(result_sizes, result_options);

  // Instantiate the launcher for the result's scalar type.
  AT_DISPATCH_ALL_TYPES_AND_HALF(result.type(), "where", [&]() {
    where_cuda<scalar_t>(result, condition, self, other);
  });

  return result;
}

}} // namespace at::native