forked from pytorch/pytorch
-
Notifications
You must be signed in to change notification settings - Fork 0
/
SummaryOps.cu
293 lines (271 loc) · 11.3 KB
/
SummaryOps.cu
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
#include "ATen/ATen.h"
#include "ATen/cuda/CUDAContext.h"
#include "ATen/cuda/CUDAApplyUtils.cuh"
namespace at {
namespace cuda {
// Bin-count cutoffs consulted by `CUDA_tensor_histogram` when it picks a
// `CUDAHistogramMemoryType`:
//   - fewer bins than THRESH_NUMBER_BINS_FOR_MULTI_BLOCK_MEM makes the
//     shared-memory implementation a candidate;
//   - fewer bins than THRESH_NUMBER_BINS_FOR_GLOBAL_MEM makes the per-block
//     global-memory implementation a candidate;
//   - otherwise the single global histogram implementation is used.
// Both macros are #undef'ed again further down in this file.
#define THRESH_NUMBER_BINS_FOR_MULTI_BLOCK_MEM 100
#define THRESH_NUMBER_BINS_FOR_GLOBAL_MEM 1000
// Loop header that visits every index in [0, lim) owned by the calling thread:
// it starts at the thread's global linear id and advances by the total number
// of threads in the grid.  A type named `IndexType` must be in scope where the
// macro is expanded.
//
// `blockIdx.x` and `gridDim.x` are converted to `IndexType` *before* being
// multiplied by `blockDim.x`.  The built-in variables are 32-bit unsigned, so
// without the conversion the products would be evaluated in 32-bit arithmetic
// and could wrap around before being stored in a wider `IndexType`.
#define FOR_KERNEL_LOOP(i, lim)                                           \
  for (IndexType i =                                                      \
           static_cast<IndexType>(blockIdx.x) * blockDim.x + threadIdx.x; \
       i < lim;                                                           \
       i += static_cast<IndexType>(gridDim.x) * blockDim.x)
/*
  Memory types used for the 3 histogram implementations.
  See `CUDA_tensor_histogram` below.

    SHARED:      each block accumulates into its own shared-memory histogram
                 and then atomically adds it to the output tensor.
    MULTI_BLOCK: each block accumulates into its own region of a partial
                 output tensor in global memory and then atomically adds it
                 to the output tensor.
    GLOBAL:      every thread atomically adds straight into the output tensor.
*/
enum class CUDAHistogramMemoryType { SHARED, MULTI_BLOCK, GLOBAL };
/*
  Kernel that accumulates the 1-D histogram of the input `b` into `a`.

  For every input element the bin is `b[element] / binsize`, and the value
  `getOp(linearIndex)` is atomically added to that bin.  `MemoryType` is a
  compile-time switch that decides where the intermediate counts are kept:

    SHARED:      a per-block array carved out of dynamic shared memory.  The
                 array is zeroed here, filled, and finally added to `a`.  The
                 launch must provide room for `a.sizes[0]` values of
                 `output_t`.
    MULTI_BLOCK: the region of the partial output `p` that starts at linear
                 index `p.strides[0] * blockIdx.x`.  `p` is not cleared by the
                 kernel, and the region is read back as consecutive offsets.
    GLOBAL:      no intermediate storage; threads add directly into `a`.

  `a` is only ever added to, never reset, and bin indices are not range
  checked, so the caller has to size and initialise the tensors accordingly.
*/
template <
    typename output_t,
    typename input_t,
    typename IndexType,
    int ADims,
    int PDims,
    int BDims,
    CUDAHistogramMemoryType MemoryType = CUDAHistogramMemoryType::MULTI_BLOCK,
    typename Op>
__global__ void kernelHistogram1D(
    detail::TensorInfo<output_t, IndexType> a, /* output */
    detail::TensorInfo<output_t, IndexType> p, /* partial output */
    detail::TensorInfo<input_t, IndexType> b, /* input */
    int binsize,
    IndexType totalElements,
    Op getOp) {
  extern __shared__ unsigned char my_smem[];

  // Number of bins, i.e. the extent of the output along its first dimension.
  const IndexType numBins = a.sizes[0];

  switch (MemoryType) {
    case CUDAHistogramMemoryType::SHARED: {
      // Per-block histogram living in dynamic shared memory.
      output_t* blockHist = reinterpret_cast<output_t*>(my_smem);

      // Shared memory starts out uninitialised: clear it cooperatively.
      for (IndexType bin = threadIdx.x; bin < numBins; bin += blockDim.x) {
        blockHist[bin] = 0;
      }
      __syncthreads();

      FOR_KERNEL_LOOP(elem, totalElements) {
        // Translate the linear element index into an offset into `b`.
        const IndexType inOffset =
            detail::IndexToOffset<input_t, IndexType, BDims>::get(elem, b);
        // The input value, scaled by `binsize`, selects the bin.
        const IndexType bin = b.data[inOffset] / binsize;
        atomicAdd(&blockHist[bin], getOp(elem));
      }
      // Every thread of the block must be done before the counts are read.
      __syncthreads();

      // The merge into `a` has to be atomic: `__syncthreads()` orders threads
      // of one block only, and other blocks update `a` concurrently.
      for (IndexType bin = threadIdx.x; bin < numBins; bin += blockDim.x) {
        const IndexType outOffset =
            detail::IndexToOffset<output_t, IndexType, ADims>::get(bin, a);
        atomicAdd(&a.data[outOffset], blockHist[bin]);
      }
      break;
    }

    case CUDAHistogramMemoryType::MULTI_BLOCK: {
      // Linear index in `p` at which this block's private counts begin.
      const IndexType rowStart = p.strides[0] * blockIdx.x;

      FOR_KERNEL_LOOP(elem, totalElements) {
        // Translate the linear element index into an offset into `b`.
        const IndexType inOffset =
            detail::IndexToOffset<input_t, IndexType, BDims>::get(elem, b);
        const auto value = b.data[inOffset];
        // The bin is relative to the start of this block's region of `p`.
        const IndexType partialIdx = rowStart + value / binsize;
        const IndexType partialOffset =
            detail::IndexToOffset<output_t, IndexType, PDims>::get(
                partialIdx, p);
        atomicAdd(&p.data[partialOffset], getOp(elem));
      }
      // Every thread of the block must be done before the counts are read.
      __syncthreads();

      // The merge into `a` has to be atomic: `__syncthreads()` orders threads
      // of one block only, and other blocks update `a` concurrently.
      const IndexType rowOffset =
          detail::IndexToOffset<output_t, IndexType, PDims>::get(rowStart, p);
      for (IndexType bin = threadIdx.x; bin < numBins; bin += blockDim.x) {
        const IndexType outOffset =
            detail::IndexToOffset<output_t, IndexType, ADims>::get(bin, a);
        atomicAdd(&a.data[outOffset], p.data[rowOffset + bin]);
      }
      break;
    }

    default: {
      // GLOBAL: all threads add directly into the output tensor.
      FOR_KERNEL_LOOP(elem, totalElements) {
        // Translate the linear element index into an offset into `b`.
        const IndexType inOffset =
            detail::IndexToOffset<input_t, IndexType, BDims>::get(elem, b);
        const auto value = b.data[inOffset];
        // The input value, scaled by `binsize`, is the bin index in `a`.
        const IndexType bin = value / binsize;
        const IndexType outOffset =
            detail::IndexToOffset<output_t, IndexType, ADims>::get(bin, a);
        atomicAdd(&a.data[outOffset], getOp(elem));
      }
      break;
    }
  }
}
// Launches `kernelHistogram1D` for one compile-time memory type on the current
// CUDA stream and then asserts that `cudaGetLastError()` reports success.
//
// The expansion refers to these names, which must exist at the use site:
// `output_t`, `input_t`, `IndexType`, `grid`, `block`, `sharedMem`, `aInfo`,
// `pInfo`, `bInfo`, `binsize` and `totalElements`.  The dimensionality
// template arguments are fixed to ADims = 1, PDims = 2, BDims = 1, and dynamic
// shared memory (`sharedMem` bytes) is requested only for the SHARED type.
//
// NOTE: this expands to two statements (launch, then assert), so it must not
// be used as the body of an unbraced `if`/`else`/loop.
#define HANDLE_CASE(MEMORY_TYPE, WEIGHTS_OP) \
kernelHistogram1D<output_t, input_t, IndexType, 1, 2, 1, MEMORY_TYPE> \
<<<grid, \
block, \
(MEMORY_TYPE == CUDAHistogramMemoryType::SHARED) ? sharedMem : 0, \
getCurrentCUDAStream()>>>( \
aInfo, pInfo, bInfo, binsize, totalElements, WEIGHTS_OP); \
AT_ASSERTM(cudaGetLastError() == cudaSuccess, "kernelHistogram1D failed");
// Maps the runtime value `mType` onto the matching compile-time instantiation
// of the kernel via `HANDLE_CASE`.  Any value other than SHARED or MULTI_BLOCK
// falls through to the GLOBAL implementation.
#define HANDLE_SWITCH_CASE(mType, getOp) \
switch (mType) { \
case CUDAHistogramMemoryType::SHARED: \
HANDLE_CASE(CUDAHistogramMemoryType::SHARED, getOp); \
break; \
case CUDAHistogramMemoryType::MULTI_BLOCK: \
HANDLE_CASE(CUDAHistogramMemoryType::MULTI_BLOCK, getOp); \
break; \
default: \
HANDLE_CASE(CUDAHistogramMemoryType::GLOBAL, getOp); \
}
/*
  Returns the amount of free device memory, in bytes, as reported by
  `cudaMemGetInfo`.

  The status returned by `cudaMemGetInfo` itself is checked, rather than the
  runtime's global last-error state, so that an unrelated error left behind by
  earlier work is not reported as a failure of this query.

  Raises (via AT_ASSERTM) if the query fails.
*/
inline int64_t getFreeGlobalMemory() {
  // no need to use `cudaSetDevice`
  size_t free_mem = 0;
  size_t total_mem = 0;
  const cudaError_t status = cudaMemGetInfo(&free_mem, &total_mem);
  if (status != cudaSuccess) {
    // Reset the runtime's last-error state so that this failure is not
    // picked up a second time by a later `cudaGetLastError()` check.
    (void)cudaGetLastError();
  }
  AT_ASSERTM(
      status == cudaSuccess,
      "CUDA_tensor_histogram failed to get free global memory");
  return static_cast<int64_t>(free_mem);
}
/*
  Calculate the frequency of the input values.

  `a` contains the final output or the histogram.  Counts are *added* to `a`;
  it is not cleared here.
  Input `b` is assumed to be 1-D non-negative int array.
  `c` optionally contains the weight vector (only read when `HasWeights`).
  `nbins` is the number of bins and `binsize` the (positive) width of a bin.
  `aType`, `bType` and `cType` are accepted but currently unused.
  See `help torch.bincount` for details on the math.

  3 implementations based on input size and memory usage:
    case: #bins < THRESH_NUMBER_BINS_FOR_MULTI_BLOCK_MEM and enough shared mem
        SHARED: Each block atomically adds to its own **shared** hist copy,
        then atomically updates the global tensor.
    case: #bins < THRESH_NUMBER_BINS_FOR_GLOBAL_MEM and enough global mem
        MULTI_BLOCK: Each block atomically adds to its own **global** hist
        copy, then atomically updates the global tensor.
    case: otherwise
        GLOBAL: all threads atomically update to a single **global** hist copy.

  Returns `false` if there is no current device or no launch grid could be
  obtained for the input; returns `true` otherwise, including for an empty
  input, in which case `a` is left untouched.
  Raises (via AT_ASSERTM) if `binsize` is not positive or the launch fails.
*/
template <typename output_t, typename input_t, bool HasWeights>
bool CUDA_tensor_histogram(
    at::Tensor a, /* output */
    at::Tensor b, /* input */
    at::Tensor c, /* weights(optional) */
    int64_t nbins,
    int binsize,
    TensorArgType aType = TensorArgType::ReadWrite,
    TensorArgType bType = TensorArgType::ReadOnly,
    TensorArgType cType = TensorArgType::ReadOnly) {
  checkBackend("CUDA_tensor_histogram", {a, b}, Backend::CUDA);
  if (HasWeights) {
    checkBackend("CUDA_tensor_histogram", {c}, Backend::CUDA);
  }
  // The kernel divides every input value by `binsize`; reject values that
  // would make that division meaningless before anything is launched.
  AT_ASSERTM(binsize > 0, "CUDA_tensor_histogram: binsize must be positive");

  auto totalElements = b.size(0);
  if (totalElements == 0) {
    // Nothing to accumulate.  Returning here also avoids attempting a kernel
    // launch sized for zero elements.
    return true;
  }

  const dim3 block = getApplyBlock();
  dim3 grid;
  int64_t curDevice = current_device();
  if (curDevice == -1 || !getApplyGrid(totalElements, grid, curDevice)) {
    return false;
  }

  CUDAHistogramMemoryType memType = CUDAHistogramMemoryType::GLOBAL;
  auto maxSharedMem = getCurrentDeviceProperties()->sharedMemPerBlock;
  auto sharedMem = nbins * sizeof(output_t) + 8; // 8 guard bytes
  auto maxGlobalMem = getFreeGlobalMemory();
  auto multiBlockMem = nbins * grid.x * sizeof(output_t) + 8; // 8 guard bytes
  // determine memory type to use in the kernel
  if (nbins < THRESH_NUMBER_BINS_FOR_MULTI_BLOCK_MEM &&
      sharedMem < maxSharedMem) {
    memType = CUDAHistogramMemoryType::SHARED;
  } else if (
      nbins < THRESH_NUMBER_BINS_FOR_GLOBAL_MEM &&
      multiBlockMem < (maxGlobalMem / 2)) {
    // check against half of free mem to be extra safe
    // due to cached allocator, we may anyway have slightly more free mem
    memType = CUDAHistogramMemoryType::MULTI_BLOCK;
  }

  using IndexType = int64_t;
  auto aInfo = detail::getTensorInfo<output_t, IndexType>(a);
  auto bInfo = detail::getTensorInfo<input_t, IndexType>(b);
  // Placeholder passed to the kernel when no partial output is needed.
  detail::TensorInfo<output_t, IndexType> pInfo(nullptr, 0, {}, {});
  // alloc memory for MULTI_BLOCK: one zero-initialised row of bins per block
  Tensor partial_output;
  if (memType == CUDAHistogramMemoryType::MULTI_BLOCK) {
    partial_output = native::zeros({grid.x, nbins}, a.options());
    pInfo = detail::getTensorInfo<output_t, IndexType>(partial_output);
  }

  if (HasWeights) {
    auto cInfo = detail::getTensorInfo<output_t, IndexType>(c);
    // Contribution of element `cIndex` is its weight.
    const auto getWeightsOp = [cInfo] __device__(IndexType cIndex) {
      const IndexType cOffset =
          detail::IndexToOffset<output_t, IndexType, 1>::get(cIndex, cInfo);
      return cInfo.data[cOffset];
    };
    HANDLE_SWITCH_CASE(memType, getWeightsOp)
  } else {
    // Without weights every element contributes a count of one.
    static const auto getDummyOp = [] __device__(IndexType) { return 1L; };
    HANDLE_SWITCH_CASE(memType, getDummyOp)
  }
  return true;
}
#undef HANDLE_CASE
#undef HANDLE_SWITCH_CASE
#undef FOR_KERNEL_LOOP
#undef THRESH_NUMBER_BINS_FOR_GLOBAL_MEM
#undef THRESH_NUMBER_BINS_FOR_MULTI_BLOCK_MEM
} // namespace cuda
namespace {
///////////////// bincount /////////////////
/*
  bincount on CUDA for one (input dtype, weights dtype) combination.

  self:      1-D tensor of non-negative integers of type `input_t`.
  weights:   optional (may be undefined) tensor with the same length as
             `self`; when defined, the result has the options of `weights`.
  minlength: minimum number of bins in the result; must be >= 0.

  Returns a tensor with `max(self.max() + 1, minlength)` bins.  An empty 1-D
  input yields `minlength` zeros of dtype long, whether or not weights are
  given.

  Raises if `minlength` is negative, if `self` is not 1-D or contains a
  negative value, if the lengths of `self` and `weights` differ, or if the
  histogram computation reports that it could not run.
*/
template <typename input_t, typename weights_t>
Tensor _bincount_cuda_template(
    const Tensor& self,
    const Tensor& weights,
    int64_t minlength) {
  if (minlength < 0) {
    AT_ERROR("minlength should be >= 0");
  }
  if (self.dim() == 1 && self.numel() == 0) {
    return native::zeros({minlength}, device(kCUDA).dtype(kLong));
  }
  // uint8 cannot hold negative values, so the minimum is only fetched (a
  // device-to-host copy) for the signed input types.
  if (self.dim() != 1 ||
      (!std::is_same<input_t, uint8_t>::value &&
       *self.min().cpu().data<input_t>() < 0)) {
    AT_ERROR("bincount only supports 1-d non-negative integral inputs.");
  }

  bool has_weights = weights.defined();
  if (has_weights && weights.size(0) != self.size(0)) {
    AT_ERROR("input and weights should have the same length");
  }

  auto nbins = self.max().item<int64_t>() + 1L;
  nbins = std::max(nbins, minlength);
  // alloc output counter on GPU
  Tensor output;
  bool computed = false;
  if (has_weights) {
    output = native::zeros({nbins}, weights.options());
    computed = cuda::CUDA_tensor_histogram<weights_t, input_t, true>(
        output, self, weights, nbins, 1);
  } else {
    output = native::zeros({nbins}, device(DeviceType::CUDA).dtype(kLong));
    computed = cuda::CUDA_tensor_histogram<int64_t, input_t, false>(
        output, self, weights, nbins, 1);
  }
  // A `false` result means no kernel was run; surfacing it prevents the
  // still all-zero `output` from being returned as if it were the answer.
  if (!computed) {
    AT_ERROR("bincount: the CUDA histogram kernel could not be launched");
  }
  return output;
}
} // namespace
namespace native {
/*
  CUDA entry point for bincount.

  Dispatches on the integral dtype of `self`.  Weights that are undefined or
  of type Float are forwarded unchanged and handled with `float` as the
  weights type; weights of any other dtype are first converted to a CUDA
  double tensor and handled with `double`.
*/
Tensor _bincount_cuda(
    const Tensor& self,
    const Tensor& weights,
    int64_t minlength) {
  return AT_DISPATCH_INTEGRAL_TYPES(self.type(), "bincount", [&] {
    const auto weightsDtype = weights.type().scalarType();
    const bool useFloatPath = weightsDtype == ScalarType::Undefined ||
        weightsDtype == ScalarType::Float;
    if (!useFloatPath) {
      // Everything else is accumulated in double precision.
      return _bincount_cuda_template<scalar_t, double>(
          self, weights.toType(CUDA(kDouble)), minlength);
    }
    return _bincount_cuda_template<scalar_t, float>(self, weights, minlength);
  });
}
} // namespace native
} // namespace at