Loops.cuh
#pragma once
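// This header provides the host-side entry points for launching elementwise
// CUDA kernels over a TensorIterator: gpu_kernel() for plain elementwise
// loops, and gpu_kernel_with_scalars() for binary ops where one operand may
// be a CPU scalar. The per-device launch logic (gpu_kernel_impl) lives in
// CUDALoops.cuh / ROCmLoops.cuh, included below.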
#include <ATen/detail/FunctionTraits.h>
#include <ATen/native/TensorIterator.h>
#include <ATen/native/TensorIteratorDynamicCasting.h>
#include <ATen/cuda/detail/OffsetCalculator.cuh>
namespace at { namespace native {
#define NUM_THREADS (C10_WARP_SIZE * 2)
#define THREAD_WORK_SIZE 4
#define BLOCK_WORK_SIZE (THREAD_WORK_SIZE * num_threads)
constexpr int num_threads = NUM_THREADS;
constexpr int thread_work_size = THREAD_WORK_SIZE;
constexpr int block_work_size = BLOCK_WORK_SIZE;
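// Work decomposition: each thread handles thread_work_size elements per
// block-level iteration, so one block covers block_work_size elements.
// For example, with a warp size of 32 this is 64 threads and 256 elements
// per block; with ROCm's wavefront size of 64, it is 128 threads and 512
// elements.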
// Builds an OffsetCalculator over the N input operands of the iterator
// (operands 1..N; operand 0 is the output).
template<int N>
static OffsetCalculator<N> make_input_offset_calculator(const TensorIterator& iter) {
  // std::array cannot have size 0 (which would happen when N == 0), so clamp to 1.
  constexpr int array_size = std::max<int>(N, 1);
  TORCH_INTERNAL_ASSERT(N == iter.ntensors() - 1);
  std::array<const int64_t*, array_size> strides;
  int64_t element_sizes[array_size];
  for (int i = 0; i < N; i++) {
    strides[i] = iter.strides(i + 1).data();
    element_sizes[i] = iter.element_size(i + 1);
  }
  return OffsetCalculator<N>(iter.ndim(), iter.shape().data(), strides.data(), element_sizes);
}
// Builds an OffsetCalculator for the single output operand (operand 0).
static OffsetCalculator<1> make_output_offset_calculator(const TensorIterator& iter) {
  std::array<const int64_t*, 1> strides;
  strides[0] = iter.strides(0).data();
  int64_t element_size = iter.element_size(0);
  return OffsetCalculator<1>(iter.ndim(), iter.shape().data(), strides.data(), &element_size);
}
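// Usage sketch (illustrative; the names `input_calc`, `data`, and `arg` are
// not defined in this file): the calculators above map a flattened element
// index to per-operand byte offsets, so a kernel would typically do roughly
//
//   auto offsets = input_calc.get(linear_idx);
//   char* ptr = data[arg] + offsets[arg];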
}} // namespace at::native
// Note:
// CUDA and ROCm diverged in this PR:
// https://github.com/pytorch/pytorch/pull/32383
// because, for reasons that are not yet understood, enabling vectorized
// memory access introduced a regression on ROCm.
#ifndef __HIP_PLATFORM_HCC__
#include <ATen/native/cuda/CUDALoops.cuh>
#else
#include <ATen/native/cuda/ROCmLoops.cuh>
#endif
namespace at { namespace native {

template <typename func_t>
void gpu_kernel(TensorIterator& iter, const func_t& f) {
  ASSERT_HOST_DEVICE_LAMBDA(func_t);

  // All operands must live on a CUDA device.
  for (int arg = 0; arg < iter.ntensors(); arg++) {
    TORCH_INTERNAL_ASSERT(iter.device(arg).is_cuda());
  }

  if (iter.numel() == 0) {
    return;
  }

  // If the iterator is too large for 32-bit indexing, split it into
  // sub-iterators that fit and launch each one separately.
  if (!iter.can_use_32bit_indexing()) {
    for (auto& sub_iter : iter.with_32bit_indexing()) {
      gpu_kernel(sub_iter, f);
    }
    return;
  }

  gpu_kernel_impl(iter, f);
}
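// Usage sketch (illustrative; assumes a TensorIterator built elsewhere, e.g.
// with one output and two float inputs):
//
//   auto iter = at::TensorIterator::binary_op(out, a, b);
//   at::native::gpu_kernel(iter, [] GPU_LAMBDA (float x, float y) -> float {
//     return x * y;
//   });
//
// The lambda's arity and argument types must match the iterator's inputs.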
template <typename func_t>
void gpu_kernel_with_scalars(TensorIterator& iter, const func_t& f) {
  ASSERT_HOST_DEVICE_LAMBDA(func_t);
  TORCH_INTERNAL_ASSERT(iter.ntensors() == 3);

  using traits = function_traits<func_t>;
  static_assert(
      traits::arity == 2,
      "gpu_kernel_with_scalars only supports two input arguments");

  if (iter.is_cpu_scalar(1)) {
    // The first input is a CPU scalar: capture its value by copy and
    // launch a unary kernel over the remaining device operand.
    using arg1_t = typename traits::template arg<0>::type;
    using arg2_t = typename traits::template arg<1>::type;
    auto a = iter.scalar_value<arg1_t>(1);
    iter.remove_operand(1);
    const OptionalDeviceGuard device_guard(device_of(iter.tensor(1)));
    gpu_kernel(iter, [=]GPU_LAMBDA(arg2_t b) {
      return f(a, b);
    });
  } else if (iter.is_cpu_scalar(2)) {
    // Same as above, but the second input is the CPU scalar.
    using arg1_t = typename traits::template arg<0>::type;
    using arg2_t = typename traits::template arg<1>::type;
    auto b = iter.scalar_value<arg2_t>(2);
    iter.remove_operand(2);
    gpu_kernel(iter, [=]GPU_LAMBDA(arg1_t a) {
      return f(a, b);
    });
  } else {
    gpu_kernel(iter, f);
  }
}
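// Usage sketch (illustrative): for a binary op where one operand may be a
// 0-dim CPU scalar (e.g. `tensor * 2`), the scalar is folded into the lambda
// capture and a unary kernel is launched over the remaining GPU operand:
//
//   at::native::gpu_kernel_with_scalars(iter,
//       [] GPU_LAMBDA (float x, float y) -> float {
//     return x * y;
//   });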
}} //namespace at::native