forked from pytorch/pytorch
-
Notifications
You must be signed in to change notification settings - Fork 0
/
ReduceOps.cpp
547 lines (473 loc) · 19.4 KB
/
ReduceOps.cpp
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
#include "ATen/native/ReduceOps.h"
#include "ATen/ATen.h"
#include "ATen/Dispatch.h"
#include "ATen/ExpandUtils.h"
#include "ATen/NativeFunctions.h"
#include "ATen/WrapDimUtils.h"
#include "ATen/WrapDimUtilsMulti.h"
#include "ReduceOpsUtils.h"
#include "TensorIterator.h"
#include <algorithm>
#include <functional>
#include <limits>
#include <numeric>
#include <vector>
#include <map>
namespace at {
namespace native {
// Dispatch stubs for the device-specific reduction kernels. In this file they
// are invoked as sum_stub(device_type, iter), prod_stub(device_type, iter)
// and norm_kernel(kCPU, ...). The kernels themselves are registered
// elsewhere — presumably in per-device kernel sources (TODO confirm).
DEFINE_DISPATCH(sum_stub);
DEFINE_DISPATCH(prod_stub);
DEFINE_DISPATCH(norm_kernel);
// Converts `self` to the type used by cumsum/cumprod. An explicit `dtype`
// wins. Otherwise integral inputs are widened to Long, and every other input
// keeps its own scalar type.
static inline Tensor integer_upcast(const Tensor& self, optional<ScalarType> dtype) {
  const ScalarType src = self.type().scalarType();
  const ScalarType fallback = at::isIntegralType(src) ? ScalarType::Long : src;
  return self.toType(dtype.value_or(fallback));
}
using DimMask = TensorIterator::DimMask;

// Builds the set of dimensions a reduction runs over, as a DimMask bitmask
// (a set bit marks a reduced dimension) for an `ndim`-dimensional input.
//  - An empty `dims` list flips every bit, i.e. reduces over all dimensions.
//  - Negative entries are wrapped (and validated) by maybe_wrap_dim.
static DimMask make_dim_mask(IntList dims, int ndim) {
  auto mask = DimMask();
  if (dims.empty()) {
    mask.flip();
  } else {
    // Iterate with the IntList element type (int64_t). Binding to `int`
    // would narrow out-of-range values before maybe_wrap_dim could reject
    // them.
    for (int64_t dim : dims) {
      mask.set(maybe_wrap_dim(dim, ndim));
    }
  }
  return mask;
}
// Gives `result` the shape of `self` reduced over the dimensions set in
// `mask`. A reduced dimension becomes size 1 when `keepdim` is true and is
// removed otherwise. A defined `result` is resized in place. An undefined
// one is created with self.type().toScalarType(dtype).
static void allocate_reduction_result(
    Tensor& result, const Tensor& self, DimMask mask, bool keepdim,
    ScalarType dtype)
{
  auto shape = DimVector(self.sizes());
  // Visit dimensions from last to first so that erasing one never shifts
  // the index of a dimension that has not been visited yet.
  int dim = static_cast<int>(shape.size());
  while (dim-- > 0) {
    if (!mask[dim]) {
      continue;
    }
    if (keepdim) {
      shape[dim] = 1;
    } else {
      shape.erase(shape.begin() + dim);
    }
  }
  if (!result.defined()) {
    result = at::empty(shape, self.type().toScalarType(dtype));
    return;
  }
  result.resize_(shape);
}
// Returns a view of `result` with the same rank (`ndim`) as the reduction
// input. With `keepdim`, `result` already has that rank and is returned
// unchanged. Otherwise every dimension set in `mask` is re-inserted with
// size 1 and stride 0.
static Tensor review_reduce_result(const Tensor& result, int ndim, DimMask mask, bool keepdim) {
  if (keepdim) {
    return result;
  }
  auto sizes = DimVector(result.sizes());
  auto strides = DimVector(result.strides());
  // Ascending order: every lower-numbered reduced dimension has already been
  // re-inserted, so position `d` is its final index.
  for (int d = 0; d < ndim; ++d) {
    if (!mask[d]) {
      continue;
    }
    sizes.insert(sizes.begin() + d, 1);
    strides.insert(strides.begin() + d, 0);
  }
  return result.as_strided(sizes, strides);
}
// Prepares a TensorIterator that reduces `self` over `dim` into `result`.
//
// `result` is allocated or resized for the reduced shape. It is then viewed
// with the input's rank so the iterator can pair input and output dims.
// `self` is cast to `dtype` when its scalar type differs. A pre-existing
// `result` must already have scalar type `dtype`; `name` is the op name
// used in that error message.
static std::unique_ptr<TensorIterator> make_reduction(
    const char* name, Tensor& result, const Tensor& self, IntList dim,
    bool keepdim, ScalarType dtype)
{
  AT_CHECK(
      !result.defined() || result.type().scalarType() == dtype,
      name, ": provided dtype must match dtype of result. Got ",
      at::toString(result.type().scalarType()),
      " and ",
      at::toString(dtype),
      ".");
  const int ndim = self.dim();
  const auto mask = make_dim_mask(dim, ndim);
  allocate_reduction_result(result, self, mask, keepdim, dtype);
  auto viewed_result = review_reduce_result(result, ndim, mask, keepdim);
  if (self.type().scalarType() == dtype) {
    return TensorIterator::reduce_op(viewed_result, self);
  }
  return TensorIterator::reduce_op(viewed_result, self.to(dtype));
}
// Cumulative sum along `dim`. The input is first converted by
// integer_upcast: to `dtype` if given, else integral inputs become Long.
static inline Tensor cumsum(const Tensor& self, int64_t dim, optional<ScalarType> dtype) {
  const Tensor upcast = integer_upcast(self, dtype);
  return at::_th_cumsum(upcast, dim);
}
// cumsum with an explicit result dtype.
Tensor cumsum(const Tensor& self, int64_t dim, ScalarType dtype) {
  const optional<ScalarType> requested(dtype);
  return at::native::cumsum(self, dim, requested);
}
// cumsum with the default dtype rules.
Tensor cumsum(const Tensor& self, int64_t dim) {
  const optional<ScalarType> requested = c10::nullopt;
  return at::native::cumsum(self, dim, requested);
}
// cumsum into a caller-provided `result`. The output's scalar type always
// wins: `self` is converted to it before calling TH. An explicit `dtype` is
// accepted only when it agrees with that type (NumPy doesn't check this).
static inline Tensor& cumsum_out(Tensor& result, const Tensor& self, int64_t dim, optional<ScalarType> dtype) {
  const ScalarType result_type = result.type().scalarType();
  const bool dtype_ok = !dtype.has_value() || result_type == dtype.value();
  AT_CHECK(
      dtype_ok,
      "provided dtype must match dtype of result in cumsum. Got ",
      at::toString(result_type),
      " and ",
      at::toString(dtype.value()),
      ".");
  return at::_th_cumsum_out(result, self.toType(result_type), dim);
}
// cumsum_out with an explicit dtype, which must match `result`'s type.
Tensor& cumsum_out(Tensor& result, const Tensor& self, int64_t dim, ScalarType dtype) {
  const optional<ScalarType> requested(dtype);
  return at::native::cumsum_out(result, self, dim, requested);
}
// cumsum_out computed in `result`'s type.
Tensor& cumsum_out(Tensor& result, const Tensor& self, int64_t dim) {
  const optional<ScalarType> requested = c10::nullopt;
  return at::native::cumsum_out(result, self, dim, requested);
}
// Cumulative product along `dim`. The input is first converted by
// integer_upcast: to `dtype` if given, else integral inputs become Long.
static inline Tensor cumprod(const Tensor& self, int64_t dim, optional<ScalarType> dtype) {
  const Tensor upcast = integer_upcast(self, dtype);
  return at::_th_cumprod(upcast, dim);
}
// cumprod with an explicit result dtype.
Tensor cumprod(const Tensor& self, int64_t dim, ScalarType dtype) {
  const optional<ScalarType> requested(dtype);
  return at::native::cumprod(self, dim, requested);
}
// cumprod with the default dtype rules.
Tensor cumprod(const Tensor& self, int64_t dim) {
  const optional<ScalarType> requested = c10::nullopt;
  return at::native::cumprod(self, dim, requested);
}
// cumprod into a caller-provided `result`. The output's scalar type always
// wins: `self` is converted to it before calling TH. An explicit `dtype` is
// accepted only when it agrees with that type (NumPy doesn't check this).
static inline Tensor& cumprod_out(Tensor& result, const Tensor& self, int64_t dim, optional<ScalarType> dtype) {
  const ScalarType result_type = result.type().scalarType();
  const bool dtype_ok = !dtype.has_value() || result_type == dtype.value();
  AT_CHECK(
      dtype_ok,
      "provided dtype must match dtype of result in cumprod. Got ",
      at::toString(result_type),
      " and ",
      at::toString(dtype.value()),
      ".");
  return at::_th_cumprod_out(result, self.toType(result_type), dim);
}
// cumprod_out with an explicit dtype, which must match `result`'s type.
Tensor& cumprod_out(Tensor& result, const Tensor& self, int64_t dim, ScalarType dtype) {
  const optional<ScalarType> requested(dtype);
  return at::native::cumprod_out(result, self, dim, requested);
}
// cumprod_out computed in `result`'s type.
Tensor& cumprod_out(Tensor& result, const Tensor& self, int64_t dim) {
  const optional<ScalarType> requested = c10::nullopt;
  return at::native::cumprod_out(result, self, dim, requested);
}
// ALL REDUCE #################################################################
// Mean over every element of `self`.
//
// `dtype`, when given, is the type the sum and the result are computed in;
// the sum's reduction path casts the input to it. Without `dtype`, `self`'s
// own type is used. This effective type must be a floating type. An input
// with no elements yields a 0-dim NaN tensor of the effective type.
static inline Tensor mean(const Tensor &self, optional<ScalarType> dtype) {
  ScalarType scalarType = dtype.value_or(self.type().scalarType());
  AT_CHECK(
      at::isFloatingType(scalarType),
      "Can only calculate the mean of floating types. Got ",
      at::toString(scalarType),
      " instead.");
  if (self.numel() > 0) {
    // Forward the requested dtype instead of silently summing in self's type.
    Tensor result = dtype.has_value() ? at::native::sum(self, dtype.value())
                                      : at::native::sum(self);
    return result.div_(self.numel());
  } else {
    return self.type().toScalarType(scalarType).scalarTensor(
        std::numeric_limits<double>::quiet_NaN());
  }
}
// Whole-tensor mean with an explicitly requested dtype.
Tensor mean(const Tensor &self, ScalarType dtype) {
  const optional<ScalarType> requested(dtype);
  return at::native::mean(self, requested);
}
// Whole-tensor mean with no dtype override.
Tensor mean(const Tensor &self) {
  const optional<ScalarType> requested = c10::nullopt;
  return at::native::mean(self, requested);
}
// Picks the scalar type a reduction computes in, by priority:
//   1. an explicitly requested `dtype`;
//   2. the type of an already-defined `result`;
//   3. `self`'s type, widened to Long for integral inputs when
//      `promote_integers` is set.
static ScalarType get_dtype(Tensor& result, const Tensor& self, optional<ScalarType> dtype,
                            bool promote_integers=false) {
  if (dtype.has_value()) {
    return dtype.value();
  }
  if (result.defined()) {
    return result.type().scalarType();
  }
  const ScalarType src_type = self.type().scalarType();
  const bool widen = promote_integers && at::isIntegralType(src_type);
  return widen ? kLong : src_type;
}
// Sum of `self` over `dim` into `result`. Integral inputs accumulate in Long
// unless `opt_dtype` or a defined `result` dictates otherwise. When there are
// no elements to visit the kernel is skipped and `result` is zeroed.
static Tensor& sum_out(Tensor& result, const Tensor& self, IntList dim,
                       bool keepdim, optional<ScalarType> opt_dtype) {
  const ScalarType dtype = get_dtype(result, self, opt_dtype, /*promote_integers=*/true);
  auto iter = make_reduction("sum", result, self, dim, keepdim, dtype);
  if (iter->numel() != 0) {
    sum_stub(iter->device_type(), *iter);
  } else {
    result.zero_();
  }
  return result;
}
// Functional sum. The output starts undefined so sum_out allocates it with
// the dtype chosen by get_dtype.
static Tensor sum(const Tensor& self, IntList dim, bool keepdim, optional<ScalarType> dtype) {
  Tensor out;
  native::sum_out(out, self, dim, keepdim, dtype);
  return out;
}
// Sum over every dimension (an empty dim list selects all of them).
Tensor sum(const Tensor &self, ScalarType dtype) {
  const IntList all_dims{};
  return at::native::sum(self, all_dims, /*keepdim=*/false, optional<ScalarType>(dtype));
}
Tensor sum(const Tensor &self) {
  const IntList all_dims{};
  return at::native::sum(self, all_dims, /*keepdim=*/false, c10::nullopt);
}
// Product of `self` over `dim` into `result`. Integral inputs accumulate in
// Long unless `opt_dtype` or a defined `result` dictates otherwise. When
// there are no elements to visit the kernel is skipped and `result` is
// filled with 1.
static Tensor& prod_out(Tensor& result, const Tensor& self, IntList dim,
                        bool keepdim, optional<ScalarType> opt_dtype) {
  const ScalarType dtype = get_dtype(result, self, opt_dtype, /*promote_integers=*/true);
  auto iter = make_reduction("prod", result, self, dim, keepdim, dtype);
  if (iter->numel() != 0) {
    prod_stub(iter->device_type(), *iter);
  } else {
    result.fill_(1);
  }
  return result;
}
// Functional product. The output starts undefined so prod_out allocates it
// with the dtype chosen by get_dtype.
static Tensor prod(const Tensor& self, IntList dim, bool keepdim, optional<ScalarType> dtype) {
  Tensor out;
  native::prod_out(out, self, dim, keepdim, dtype);
  return out;
}
// Product over every dimension (an empty dim list selects all of them).
Tensor prod(const Tensor &self, ScalarType dtype) {
  const IntList all_dims{};
  return at::native::prod(self, all_dims, /*keepdim=*/false, optional<ScalarType>(dtype));
}
Tensor prod(const Tensor &self) {
  const IntList all_dims{};
  return at::native::prod(self, all_dims, /*keepdim=*/false, c10::nullopt);
}
// \ALL REDUCE ################################################################
// DIM REDUCE #################################################################
// Mean of `self` along `dim`, written into `result`.
//
// The reduction is computed in `result`'s scalar type, which must be
// floating; `self` is converted to it first. `dtype` cannot override that
// type. Like cumsum_out/cumprod_out, an explicit `dtype` is therefore
// rejected unless it matches `result`'s type, instead of being silently
// ignored. If dimension `dim` has size 0, the non-empty output is filled
// with NaN (NumPy equivalent).
static inline Tensor &mean_out(Tensor &result, const Tensor &self, int64_t dim,
                               bool keepdim, optional<ScalarType> dtype) {
  ScalarType scalarType = result.type().scalarType();
  AT_CHECK(
      at::isFloatingType(scalarType),
      "Can only calculate the mean of floating types. Got ",
      at::toString(scalarType),
      " instead.");
  AT_CHECK(
      !dtype.has_value() || scalarType == dtype.value(),
      "provided dtype must match dtype of result in mean. Got ",
      at::toString(scalarType),
      " and ",
      at::toString(dtype.value()),
      ".");
  at::native::sum_out(result, self.toType(scalarType), dim, keepdim);
  if (result.numel() > 0 && self.ndimension() > 0) {
    int64_t numel = self.size(dim);
    if (numel > 0) {
      result.div_(numel);
    } else {
      // NumPy equivalent
      result.fill_(std::numeric_limits<double>::quiet_NaN());
    }
  }
  return result;
}
// Public mean_out overloads. Each forwards to the optional-dtype
// implementation; keepdim defaults to false when it is omitted.
Tensor& mean_out(Tensor& result, const Tensor& self, int64_t dim, bool keepdim, ScalarType dtype) {
  const c10::optional<ScalarType> requested(dtype);
  return at::native::mean_out(result, self, dim, keepdim, requested);
}
Tensor& mean_out(Tensor& result, const Tensor& self, int64_t dim, bool keepdim) {
  const c10::optional<ScalarType> requested = c10::nullopt;
  return at::native::mean_out(result, self, dim, keepdim, requested);
}
Tensor& mean_out(Tensor& result, const Tensor& self, int64_t dim, ScalarType dtype) {
  const c10::optional<ScalarType> requested(dtype);
  return at::native::mean_out(result, self, dim, /*keepdim=*/false, requested);
}
// Public sum_out overloads. Each forwards to the optional-dtype
// implementation; keepdim defaults to false when it is omitted.
Tensor& sum_out(Tensor& result, const Tensor& self, IntList dim, bool keepdim, ScalarType dtype) {
  const c10::optional<ScalarType> requested(dtype);
  return at::native::sum_out(result, self, dim, keepdim, requested);
}
Tensor& sum_out(Tensor& result, const Tensor& self, IntList dim, bool keepdim) {
  const c10::optional<ScalarType> requested = c10::nullopt;
  return at::native::sum_out(result, self, dim, keepdim, requested);
}
Tensor& sum_out(Tensor& result, const Tensor& self, IntList dim, ScalarType dtype) {
  const c10::optional<ScalarType> requested(dtype);
  return at::native::sum_out(result, self, dim, /*keepdim=*/false, requested);
}
// Public single-dim prod_out overloads. Each forwards to the optional-dtype
// implementation; keepdim defaults to false when it is omitted.
Tensor& prod_out(Tensor& result, const Tensor& self, int64_t dim, bool keepdim, ScalarType dtype) {
  const c10::optional<ScalarType> requested(dtype);
  return at::native::prod_out(result, self, dim, keepdim, requested);
}
Tensor& prod_out(Tensor& result, const Tensor& self, int64_t dim, bool keepdim) {
  const c10::optional<ScalarType> requested = c10::nullopt;
  return at::native::prod_out(result, self, dim, keepdim, requested);
}
Tensor& prod_out(Tensor& result, const Tensor& self, int64_t dim, ScalarType dtype) {
  const c10::optional<ScalarType> requested(dtype);
  return at::native::prod_out(result, self, dim, /*keepdim=*/false, requested);
}
// Mean of `self` along `dim`, returned as a new tensor.
//
// `dtype`, when given, is the type the reduction is computed in; the
// optional-dtype sum casts the input to it. Otherwise `self`'s own type is
// used. This effective type must be floating. If dimension `dim` has size 0,
// the non-empty output is filled with NaN (NumPy equivalent).
static inline Tensor mean(const Tensor &self, int64_t dim, bool keepdim, optional<ScalarType> dtype) {
  ScalarType scalarType = dtype.value_or(self.type().scalarType());
  AT_CHECK(
      at::isFloatingType(scalarType),
      "Can only calculate the mean of floating types. Got ",
      at::toString(scalarType),
      " instead.");
  // Forward `dtype` to the reduction instead of silently ignoring it.
  Tensor result = at::native::sum(self, dim, keepdim, dtype);
  if (result.numel() > 0 && self.ndimension() > 0) {
    int64_t numel = self.size(dim);
    if (numel > 0) {
      result.div_(numel);
    } else {
      // NumPy equivalent
      result.fill_(std::numeric_limits<double>::quiet_NaN());
    }
  }
  return result;
}
// Public single-dim mean overloads. Each forwards to the optional-dtype
// implementation; keepdim defaults to false when it is omitted.
Tensor mean(const Tensor& self, int64_t dim, bool keepdim, ScalarType dtype) {
  const c10::optional<ScalarType> requested(dtype);
  return at::native::mean(self, dim, keepdim, requested);
}
Tensor mean(const Tensor& self, int64_t dim, bool keepdim) {
  const c10::optional<ScalarType> requested = c10::nullopt;
  return at::native::mean(self, dim, keepdim, requested);
}
Tensor mean(const Tensor& self, int64_t dim, ScalarType dtype) {
  const c10::optional<ScalarType> requested(dtype);
  return at::native::mean(self, dim, /*keepdim=*/false, requested);
}
// Public dim-list sum overloads. Each forwards to the optional-dtype
// implementation; keepdim defaults to false when it is omitted.
Tensor sum(const Tensor& self, IntList dim, bool keepdim, ScalarType dtype) {
  const c10::optional<ScalarType> requested(dtype);
  return at::native::sum(self, dim, keepdim, requested);
}
Tensor sum(const Tensor& self, IntList dim, bool keepdim) {
  const c10::optional<ScalarType> requested = c10::nullopt;
  return at::native::sum(self, dim, keepdim, requested);
}
Tensor sum(const Tensor& self, IntList dim, ScalarType dtype) {
  const c10::optional<ScalarType> requested(dtype);
  return at::native::sum(self, dim, /*keepdim=*/false, requested);
}
// Public single-dim prod overloads. Each forwards to the optional-dtype
// implementation; keepdim defaults to false when it is omitted.
Tensor prod(const Tensor& self, int64_t dim, bool keepdim, ScalarType dtype) {
  const c10::optional<ScalarType> requested(dtype);
  return at::native::prod(self, dim, keepdim, requested);
}
Tensor prod(const Tensor& self, int64_t dim, bool keepdim) {
  const c10::optional<ScalarType> requested = c10::nullopt;
  return at::native::prod(self, dim, keepdim, requested);
}
Tensor prod(const Tensor& self, int64_t dim, ScalarType dtype) {
  const c10::optional<ScalarType> requested(dtype);
  return at::native::prod(self, dim, /*keepdim=*/false, requested);
}
// Computes log(sum(exp(self))) along `dim_` into `result`.
//
// For non-empty input, the maximum along the dimension is subtracted before
// exp() and added back after log_(), presumably to avoid overflow in exp().
// dim_: dimension to reduce; negative values are wrapped by maybe_wrap_dim.
// keepdim: whether the reduced dimension is kept with size 1.
// Returns `result`.
Tensor& logsumexp_out(Tensor& result, const Tensor &self, int64_t dim_, bool keepdim) {
int64_t dim = maybe_wrap_dim(dim_, self.dim());
// can't take max of empty tensor
if (self.numel() != 0) {
// Maxima with the reduced dimension kept (size 1), so they line up with
// `self` in the subtraction below.
auto maxes = at::max_values(self, dim, true);
// Maxima shaped like `result`: `maxes` itself when keepdim, else squeezed.
// NOTE(review): this relies on squeeze() returning a view of `maxes`, so the
// masked_fill_ below also zeroes the same entries of `maxes` used in
// `self - maxes` — confirm.
auto maxes_squeezed = (keepdim ? maxes : maxes.squeeze(dim));
// Replace infinite maxima with 0, presumably so that inf - inf does not
// turn into NaN in the subtraction.
maxes_squeezed.masked_fill_(maxes_squeezed.abs() == INFINITY, 0);
at::sum_out(result, at::exp(self - maxes), dim, keepdim);
result.log_().add_(maxes_squeezed);
} else {
// Empty input: skip the max shift and reduce directly.
at::sum_out(result, at::exp(self), dim, keepdim);
result.log_();
}
return result;
}
// Functional logsumexp along `dim_` (negative values are wrapped). Starts
// from an empty tensor with `self`'s options that logsumexp_out fills in.
Tensor logsumexp(const Tensor &self, int64_t dim_, bool keepdim) {
  const int64_t wrapped_dim = maybe_wrap_dim(dim_, self.dim());
  auto out = at::empty({0}, self.options());
  return at::native::logsumexp_out(out, self, wrapped_dim, keepdim);
}
// CPU p-norm of `self` along `dim_` into `result`.
//
// Cases that _dimreduce_return_trivial handles (with fill value 0) return
// immediately. Contiguous input and output use the native norm_kernel,
// followed by a squeeze when !keepdim. Everything else falls back to TH.
Tensor& _norm_out_cpu(Tensor& result, const Tensor& self, Scalar p, int64_t dim_, bool keepdim) {
  const int64_t dim = maybe_wrap_dim(dim_, self.dim());
  if (_dimreduce_return_trivial(result, self, 0, dim, keepdim)) {
    return result;
  }
  if (!self.is_contiguous() || !result.is_contiguous()) {
    return at::_th_norm_out(result, self, p, dim, keepdim);
  }
  _dimreduce_setup(result, self, dim);
  norm_kernel(kCPU, result, self, p, dim);
  if (!keepdim) {
    result.squeeze_(dim);
  }
  return result;
}
// p-norm of `self` along `dim` into `result`. Only CPU/CUDA floating-point
// tensors are accepted. CUDA goes straight to TH; CPU goes through
// _norm_out_cpu, which tries the native kernel first.
Tensor& norm_out(Tensor &result, const Tensor &self, Scalar p, int64_t dim, bool keepdim) {
  const auto backend = self.type().backend();
  AT_CHECK(backend == Backend::CPU || backend == Backend::CUDA,
           "norm only supports CPU AND CUDA backend, got: ", at::toString(backend));
  AT_CHECK(at::isFloatingType(self.type().scalarType()), "norm only supports floating-point dtypes");
  dim = maybe_wrap_dim(dim, self.dim());
  if (_dimreduce_return_trivial(result, self, 0, dim, keepdim)) {
    return result;
  }
  return self.is_cuda() ? at::_th_norm_out(result, self, p, dim, keepdim)
                        : _norm_out_cpu(result, self, p, dim, keepdim);
}
// p-norm over all elements of `self`.
//
// Sparse tensors go to native_norm. Dense tensors must be CPU/CUDA and
// floating point. Contiguous CPU tensors use norm_kernel with no dim, writing
// into a 0-dim tensor of self's type. CUDA and non-contiguous CPU tensors use
// TH.
Tensor _norm(const Tensor &self, Scalar p) {
  if (self.type().is_sparse()) {
    return at::native_norm(self, p);
  }
  const auto backend = self.type().backend();
  AT_CHECK(backend == Backend::CPU || backend == Backend::CUDA,
           "norm only supports CPU AND CUDA backend, got: ", at::toString(backend));
  AT_CHECK(at::isFloatingType(self.type().scalarType()), "norm only supports floating-point dtypes");
  if (self.is_cuda() || !self.is_contiguous()) {
    return at::_th_norm(self, p);
  }
  Tensor result = CPU(kFloat).scalarTensor(0).toType(self.type());
  norm_kernel(kCPU, result, self, p, c10::nullopt);
  return result;
}
// Functional p-norm along `dim`. Starts from an empty tensor with `self`'s
// options that norm_out fills in.
Tensor norm(const Tensor& self, Scalar p, int64_t dim, bool keepdim) {
  auto out = at::empty({0}, self.options());
  return at::native::norm_out(out, self, p, dim, keepdim);
}
// p-norm over all elements; see _norm for the backend dispatch.
Tensor norm(const Tensor& self, Scalar p) {
  return at::native::_norm(self, p);
}
// Functional `all` along `dim`. Starts from an empty tensor with `self`'s
// options that all_out fills in.
Tensor all(const Tensor& self, int64_t dim, bool keepdim) {
  auto out = at::empty({0}, self.options());
  return at::native::all_out(out, self, dim, keepdim);
}
// `all` along `dim` into `result`, for CPU/CUDA uint8 tensors only. Cases
// handled by _dimreduce_return_trivial use fill value 1; the rest go to TH.
Tensor &all_out(Tensor &result, const Tensor &self, int64_t dim, bool keepdim) {
  const auto backend = self.type().backend();
  AT_CHECK(backend == Backend::CPU || backend == Backend::CUDA,
           "all only supports CPU AND CUDA backend, got: ", at::toString(backend));
  AT_CHECK(self.type().scalarType() == at::ScalarType::Byte, "all only supports torch.uint8 dtype");
  dim = maybe_wrap_dim(dim, self.dim());
  if (_dimreduce_return_trivial(result, self, 1, dim, keepdim)) {
    return result;
  }
  return at::_th_all_out(result, self, dim, keepdim);
}
// Functional `any` along `dim`. Starts from an empty tensor with `self`'s
// options that any_out fills in.
Tensor any(const Tensor& self, int64_t dim, bool keepdim) {
  auto out = at::empty({0}, self.options());
  return at::native::any_out(out, self, dim, keepdim);
}
// `any` along `dim` into `result`, for CPU/CUDA uint8 tensors only. Cases
// handled by _dimreduce_return_trivial use fill value 0; the rest go to TH.
Tensor &any_out(Tensor &result, const Tensor &self, int64_t dim, bool keepdim) {
  const auto backend = self.type().backend();
  AT_CHECK(backend == Backend::CPU || backend == Backend::CUDA,
           "any only supports CPU AND CUDA backend, got: ", at::toString(backend));
  AT_CHECK(self.type().scalarType() == at::ScalarType::Byte, "any only supports torch.uint8 dtype");
  dim = maybe_wrap_dim(dim, self.dim());
  if (_dimreduce_return_trivial(result, self, 0, dim, keepdim)) {
    return result;
  }
  return at::_th_any_out(result, self, dim, keepdim);
}
// Variance over all elements, for CPU/CUDA floating-point tensors. Inputs
// that _allreduce_return_trivial short-circuits get its NaN-valued result;
// all others go to TH.
Tensor var(const Tensor& self, bool unbiased) {
  const auto backend = self.type().backend();
  AT_CHECK(backend == Backend::CPU || backend == Backend::CUDA,
           "var only supports CPU AND CUDA backend, got: ", at::toString(backend));
  AT_CHECK(at::isFloatingType(self.type().scalarType()), "var only supports floating-point dtypes");
  auto trivial = _allreduce_return_trivial(self, std::numeric_limits<double>::quiet_NaN());
  if (trivial.has_value()) {
    return trivial.value();
  }
  return at::_th_var(self, unbiased);
}
// Functional variance along `dim`. Starts from an empty tensor with `self`'s
// options that var_out fills in.
Tensor var(const Tensor& self, int64_t dim, bool unbiased, bool keepdim) {
  auto out = at::empty({0}, self.options());
  return at::native::var_out(out, self, dim, unbiased, keepdim);
}
// Variance along `dim` into `result`. Cases handled by
// _dimreduce_return_trivial use fill value NaN; the rest go to TH.
Tensor &var_out(Tensor &result, const Tensor &self, int64_t dim, bool unbiased, bool keepdim) {
  const auto backend = self.type().backend();
  AT_CHECK(backend == Backend::CPU || backend == Backend::CUDA,
           "var only supports CPU AND CUDA backend, got: ", at::toString(backend));
  AT_CHECK(at::isFloatingType(self.type().scalarType()), "var only supports floating-point dtypes");
  dim = maybe_wrap_dim(dim, self.dim());
  if (_dimreduce_return_trivial(result, self, std::numeric_limits<double>::quiet_NaN(), dim, keepdim)) {
    return result;
  }
  return at::_th_var_out(result, self, dim, unbiased, keepdim);
}
// Standard deviation over all elements, for CPU/CUDA floating-point tensors.
// Inputs that _allreduce_return_trivial short-circuits get its NaN-valued
// result; all others go to TH.
Tensor std(const Tensor& self, bool unbiased) {
  const auto backend = self.type().backend();
  AT_CHECK(backend == Backend::CPU || backend == Backend::CUDA,
           "std only supports CPU AND CUDA backend, got: ", at::toString(backend));
  AT_CHECK(at::isFloatingType(self.type().scalarType()), "std only supports floating-point dtypes");
  auto trivial = _allreduce_return_trivial(self, std::numeric_limits<double>::quiet_NaN());
  if (trivial.has_value()) {
    return trivial.value();
  }
  return at::_th_std(self, unbiased);
}
// Functional standard deviation along `dim`. Starts from an empty tensor
// with `self`'s options that std_out fills in.
Tensor std(const Tensor& self, int64_t dim, bool unbiased, bool keepdim) {
  auto out = at::empty({0}, self.options());
  return at::native::std_out(out, self, dim, unbiased, keepdim);
}
// Standard deviation along `dim` into `result`. Cases handled by
// _dimreduce_return_trivial use fill value NaN; the rest go to TH.
Tensor &std_out(Tensor &result, const Tensor &self, int64_t dim, bool unbiased, bool keepdim) {
  const auto backend = self.type().backend();
  AT_CHECK(backend == Backend::CPU || backend == Backend::CUDA,
           "std only supports CPU AND CUDA backend, got: ", at::toString(backend));
  AT_CHECK(at::isFloatingType(self.type().scalarType()), "std only supports floating-point dtypes");
  dim = maybe_wrap_dim(dim, self.dim());
  if (_dimreduce_return_trivial(result, self, std::numeric_limits<double>::quiet_NaN(), dim, keepdim)) {
    return result;
  }
  return at::_th_std_out(result, self, dim, unbiased, keepdim);
}
}} // namespace at::native