forked from pytorch/pytorch
-
Notifications
You must be signed in to change notification settings - Fork 0
/
UnaryOps.cpp
210 lines (184 loc) · 6.88 KB
/
UnaryOps.cpp
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
// define constants like M_PI and C keywords for MSVC
#ifdef _MSC_VER
#define _USE_MATH_DEFINES
#include <math.h>
#endif
#include "ATen/ATen.h"
#include "ATen/Dispatch.h"
#include "ATen/ExpandUtils.h"
#include "ATen/NativeFunctions.h"
#include "ATen/WrapDimUtils.h"
#include "ATen/CPUApplyUtils.h"
#include "ATen/Parallel.h"
#include "ATen/native/cpu/UnaryOpsKernel.h"
#include <algorithm>
#include <cmath>
#include <functional>
#include <numeric>
#include <vector>
#include <map>
// NOTE:
// YOU ARE NOT OBLIGED TO USE THE IMPLEMENT_UNARY_OP_* MACROS DEFINED BELOW.
// If you're writing something more specialized, please don't try to make them
// work for your case; just write something new instead.
namespace at {
namespace native {
// Out-of-place members of the clamp family. Each one allocates an empty
// placeholder tensor with `self`'s options and hands it to the matching
// *_out function, which (presumably) resizes and fills it.
Tensor clamp(const Tensor& self, optional<Scalar> min, optional<Scalar> max) {
  Tensor out = at::empty({0}, self.options());
  return clamp_out(out, self, min, max);
}

// Upper-bound-only clamp; see `clamp` above for the allocation pattern.
Tensor clamp_max(const Tensor& self, Scalar max) {
  Tensor out = at::empty({0}, self.options());
  return clamp_max_out(out, self, max);
}

// Lower-bound-only clamp; see `clamp` above for the allocation pattern.
Tensor clamp_min(const Tensor& self, Scalar min) {
  Tensor out = at::empty({0}, self.options());
  return clamp_min_out(out, self, min);
}
// In-place clamp of `self` on CPU.
//
// Each bound is optional. With no bounds at all, `self` is returned unchanged.
// Otherwise the work goes to the legacy `_th_clamp*_out` kernel that matches
// the bounds present, with `self` as both output and input.
Tensor& _clamp__cpu(Tensor& self, optional<Scalar> min, optional<Scalar> max) {
  if (!min && !max) {
    // Nothing to clamp against.
    return self;
  }
  if (!max) {
    // Only a lower bound.
    return _th_clamp_min_out(self, self, *min);
  }
  if (!min) {
    // Only an upper bound.
    return _th_clamp_max_out(self, self, *max);
  }
  return _th_clamp_out(self, self, *min, *max);
}
// Out variant of clamp on CPU: writes the clamped `self` into `result`.
//
// Which legacy kernel runs depends on the bounds supplied:
//   - `min` and `max`: `_th_clamp_out`
//   - only `max`:      `_th_clamp_max_out`
//   - only `min`:      `_th_clamp_min_out`
//   - neither:         no clamping applies, so `result` becomes a copy of
//                      `self`. This matches the in-place variant, which
//                      returns `self` unchanged in that case. Without this
//                      branch `result` would be returned untouched, e.g. the
//                      empty placeholder tensor that `clamp` allocates.
//
// Returns `result`.
Tensor& _clamp_out_cpu(
    Tensor& result,
    const Tensor& self,
    optional<Scalar> min,
    optional<Scalar> max) {
  if (min && max) {
    _th_clamp_out(result, self, *min, *max);
  } else if (max) {
    _th_clamp_max_out(result, self, *max);
  } else if (min) {
    _th_clamp_min_out(result, self, *min);
  } else {
    // Identity: give `result` self's shape and values.
    result.resize_(self.sizes());
    result.copy_(self);
  }
  return result;
}
// Single-bound clamp kernels for CPU.
// The out variants delegate to the legacy `_th_clamp_{max,min}_out`. The
// in-place variants reuse the out variants, passing `self` as both the output
// and the input. Out variants are defined first so the in-place ones can call
// them.
Tensor& _clamp_max_out_cpu(Tensor& result, const Tensor& self, Scalar max) {
  return _th_clamp_max_out(result, self, max);
}
Tensor& _clamp_max__cpu(Tensor& self, Scalar max) {
  return _clamp_max_out_cpu(self, self, max);
}
Tensor& _clamp_min_out_cpu(Tensor& result, const Tensor& self, Scalar min) {
  return _th_clamp_min_out(result, self, min);
}
Tensor& _clamp_min__cpu(Tensor& self, Scalar min) {
  return _clamp_min_out_cpu(self, self, min);
}
// Fills `self` with `value` (in place, per the trailing-underscore naming)
// and returns `self`'s reference as produced by `at::_th_fill_`.
Tensor& fill_(Tensor& self, Scalar value) {
  Tensor& filled = at::_th_fill_(self, value);
  return filled;
}
// Tensor-valued overload; also delegates to `at::_th_fill_`.
// NOTE(review): `value` is presumably expected to hold a single element
// (0-dim) — confirm against the `_th_fill_` contract.
Tensor& fill_(Tensor& self, const Tensor& value) {
  Tensor& filled = at::_th_fill_(self, value);
  return filled;
}
Tensor mvlgamma(const Tensor& self, int64_t p) {
AT_CHECK(at::isFloatingType(self.type().scalarType()),
"mvlgamma is not implemented for ", self.type());
AT_CHECK((self > 0.5 * (p - 1.)).all().item<uint8_t>(),
"Condition for computing multivariate log-gamma not met");
AT_CHECK(p >= 1, "p has to be greater than or equal to 1");
Tensor args = native::arange(-p / 2. + 0.5, 0.5, 0.5, self.options());
args = args.add(self.unsqueeze(-1));
return args.lgamma_().sum(-1).add_(p * (p - 1) * std::log(M_PI) / 4.);
}
Tensor& mvlgamma_(Tensor& self, int64_t p) {
AT_CHECK(at::isFloatingType(self.type().scalarType()),
"mvlgamma is not implemented for ", self.type());
AT_CHECK((self > 0.5 * (p - 1.)).all().item<uint8_t>(),
"Condition for computing multivariate log-gamma not met");
AT_CHECK(p >= 1, "p has to be greater than or equal to 1");
Tensor args = native::arange(-p / 2. + 0.5, 0.5, 0.5, self.options());
args = args.add(self.unsqueeze(-1));
return self.copy_(args.lgamma_().sum(-1).add_(p * (p - 1) * std::log(M_PI) / 4.));
}
// Defines three functions for the unary op `op`, backed by the `op##Impl`
// dispatch stub:
//   - `op(self)`: allocates an empty result with `self`'s options and
//     forwards to `at::op##_out`.
//   - `_op__cpu(self_)`: in-place variant. Skips empty tensors; otherwise it
//     runs the stub on the tensor returned by `sort_strides(self_)`
//     (presumably a stride-reordered view sharing `self_`'s storage, since
//     `self_` itself is returned) and returns `self_`.
//   - `_op_out_cpu(result, self)`: resizes `result` to `self`'s sizes and
//     runs the stub unless `result` is empty.
// Comments are kept outside the #define: a `//` comment on a
// backslash-continued line would swallow the following line.
// NB: If you use this macro, you may also need to add a CUDA forwarding
// stub in CUDAUnaryOps
#define IMPLEMENT_UNARY_OP_VEC(op)                              \
  Tensor op(const Tensor& self) {                               \
    Tensor out_tensor = at::empty({0}, self.options());         \
    return at::op##_out(out_tensor, self);                      \
  }                                                             \
  Tensor& _##op##__cpu(Tensor& self_) {                         \
    if (self_.numel() == 0) {                                   \
      return self_;                                             \
    }                                                           \
    Tensor sorted_view = sort_strides(self_);                   \
    op##Impl(kCPU, sorted_view, sorted_view);                   \
    return self_;                                               \
  }                                                             \
  Tensor& _##op##_out_cpu(Tensor& result, const Tensor& self) { \
    result.resize_(self.sizes());                               \
    if (result.numel() == 0) {                                  \
      return result;                                            \
    }                                                           \
    op##Impl(kCPU, result, self);                               \
    return result;                                              \
  }
// Defines three functions for the unary op `op`, backed by the legacy
// `at::_th_##op##_out` implementation:
//   - `op(self)`: allocates an empty result with `self`'s options and
//     forwards to `at::op##_out`.
//   - `_op__cpu(self)`: in-place variant; forwards to `at::op##_out` with
//     `self` as both output and input.
//   - `_op_out_cpu(result, self)`: resizes `result` to `self`'s sizes, then
//     calls `at::_th_##op##_out`.
// Unlike IMPLEMENT_UNARY_OP_VEC, there is no empty-tensor short-circuit here.
#define IMPLEMENT_UNARY_OP_TH(op)                               \
  Tensor op(const Tensor& self) {                               \
    Tensor out_tensor = at::empty({0}, self.options());         \
    return at::op##_out(out_tensor, self);                      \
  }                                                             \
  Tensor& _##op##__cpu(Tensor& self) {                          \
    return at::op##_out(self, self);                            \
  }                                                             \
  Tensor& _##op##_out_cpu(Tensor& result, const Tensor& self) { \
    result.resize_(self.sizes());                               \
    return at::_th_##op##_out(result, self);                    \
  }
// Instantiate the out-of-place (`op`), in-place (`_op__cpu`) and out
// (`_op_out_cpu`) functions for each unary op.
//   - IMPLEMENT_UNARY_OP_TH routes the out variant to the legacy
//     `at::_th_<op>_out` implementation.
//   - IMPLEMENT_UNARY_OP_VEC routes to the `<op>Impl` dispatch stub; those
//     stubs are defined by the DEFINE_DISPATCH list at the end of this file.
// NB: abs temporarily defaults to the TH implementation due to issues with
// Apple, even though an `absImpl` dispatch stub is still defined below.
// NOTE(review): the specific Apple issue is not recorded here.
// NOTE(review): cosh and sinh also use the TH path; no cosh/sinh dispatch
// stub is defined in this file.
IMPLEMENT_UNARY_OP_TH(abs)
IMPLEMENT_UNARY_OP_VEC(acos)
IMPLEMENT_UNARY_OP_VEC(asin)
IMPLEMENT_UNARY_OP_VEC(atan)
IMPLEMENT_UNARY_OP_VEC(ceil)
IMPLEMENT_UNARY_OP_VEC(cos)
IMPLEMENT_UNARY_OP_TH(cosh)
IMPLEMENT_UNARY_OP_VEC(erf)
IMPLEMENT_UNARY_OP_VEC(erfc)
IMPLEMENT_UNARY_OP_VEC(exp)
IMPLEMENT_UNARY_OP_VEC(expm1)
IMPLEMENT_UNARY_OP_VEC(floor)
IMPLEMENT_UNARY_OP_VEC(log)
IMPLEMENT_UNARY_OP_VEC(log10)
IMPLEMENT_UNARY_OP_VEC(log1p)
IMPLEMENT_UNARY_OP_VEC(log2)
IMPLEMENT_UNARY_OP_VEC(round)
IMPLEMENT_UNARY_OP_VEC(rsqrt)
IMPLEMENT_UNARY_OP_VEC(sigmoid)
IMPLEMENT_UNARY_OP_VEC(sin)
IMPLEMENT_UNARY_OP_TH(sinh)
IMPLEMENT_UNARY_OP_VEC(sqrt)
IMPLEMENT_UNARY_OP_VEC(tan)
IMPLEMENT_UNARY_OP_VEC(tanh)
IMPLEMENT_UNARY_OP_VEC(trunc)
// Definitions of the per-op dispatch stubs. The functions generated by
// IMPLEMENT_UNARY_OP_VEC invoke these as `<op>Impl(kCPU, ...)`.
// NOTE(review): the stubs are presumably declared in
// "ATen/native/cpu/UnaryOpsKernel.h", with kernels registered elsewhere —
// confirm.
// `absImpl` is defined even though abs currently goes through the TH path.
DEFINE_DISPATCH(absImpl);
DEFINE_DISPATCH(acosImpl);
DEFINE_DISPATCH(asinImpl);
DEFINE_DISPATCH(atanImpl);
DEFINE_DISPATCH(ceilImpl);
DEFINE_DISPATCH(cosImpl);
DEFINE_DISPATCH(erfImpl);
DEFINE_DISPATCH(erfcImpl);
DEFINE_DISPATCH(expImpl);
DEFINE_DISPATCH(expm1Impl);
DEFINE_DISPATCH(floorImpl);
DEFINE_DISPATCH(logImpl);
DEFINE_DISPATCH(log10Impl);
DEFINE_DISPATCH(log1pImpl);
DEFINE_DISPATCH(log2Impl);
DEFINE_DISPATCH(roundImpl);
DEFINE_DISPATCH(rsqrtImpl);
DEFINE_DISPATCH(sigmoidImpl);
DEFINE_DISPATCH(sinImpl);
DEFINE_DISPATCH(sqrtImpl);
DEFINE_DISPATCH(tanImpl);
DEFINE_DISPATCH(tanhImpl);
DEFINE_DISPATCH(truncImpl);
}
} // namespace at