# (APPENDIX) Appendices {-}
# Reference C++ Implementations
## Forward-mode automatic differentiation in C++
Forward-mode automatic differentiation can be implemented directly in
C++ following the pattern established in the standard library for
complex numbers. A forward-mode autodiff variable is represented by a
dual number class holding two scalar values, with a constructor whose
arguments default to zero and getters for the value and tangent.
```
namespace autodiff {
template <typename T>
class dual {
  T val_;  // value
  T tan_;  // tangent
public:
  dual(const T& val = 0, const T& tan = 0)
      : val_(val), tan_(tan) { }
  const T& val() const { return val_; }
  const T& tan() const { return tan_; }
};
}  // namespace autodiff
```
The type of value is templated to support nested forward mode.
The default copy constructor `dual(const dual<T>&)`,
destructor `~dual()`, and assignment operator `operator=(const
dual<T>&)` are sufficient.
Dual numbers are constructed from either a value and a tangent, or
just a value with default zero tangent, or with neither and default
zero values and tangent.
```
using autodiff::dual;
dual<double> a(2.2, -3.1);
dual<double> b(5);
dual<double> c;
```
The first has a value of 2.2 and a tangent of -3.1, whereas the second
has a value of 5 and a tangent of 0 (the default) and the last a value
and tangent of 0.
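The constructor is not declared `explicit` and thus will support the
initialization of dual numbers from primitives, giving them default 0
tangent values. The declarations
```
dual<double> c = 5;
```
and
```
dual<double> d = dual<double>(5);
```
are both initializations rather than assignments, so neither invokes
the copy assignment operator `operator=(const dual<double>&)`; the
former promotes `5` to `dual<double>(5)` using the non-explicit
constructor. An assignment to an existing variable, such as `c = 7`,
promotes the primitive in the same way before invoking the default
copy assignment operator.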
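To make the promotion concrete, the following is a minimal sketch, assuming C++17 and repeating the `dual` class so that the block is self-contained. The helper `promoted_tangent()` is a hypothetical name introduced only for illustration; the static assertions check that primitives convert implicitly to `dual<double>`.

```cpp
#include <type_traits>

namespace autodiff {
template <typename T>
class dual {
  T val_;
  T tan_;
public:
  // Not declared explicit, so primitives may be promoted implicitly.
  dual(const T& val = 0, const T& tan = 0)
      : val_(val), tan_(tan) { }
  const T& val() const { return val_; }
  const T& tan() const { return tan_; }
};
}  // namespace autodiff

// Both int and double convert implicitly to dual<double>.
static_assert(std::is_convertible_v<int, autodiff::dual<double>>);
static_assert(std::is_convertible_v<double, autodiff::dual<double>>);

// Hypothetical helper: tangent of a dual number promoted from a primitive.
inline double promoted_tangent() {
  autodiff::dual<double> c = 5;  // initialization through the constructor
  return c.tan();                // the tangent takes its default of zero
}
```

Declaring the constructor `explicit` would make both static assertions fail and would reject the declaration of `c`.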
Functions and operators are coded following the dual arithmetic. All
autodiff functionality will be declared in the namespace `autodiff`,
though the namespace qualification will not be shown. For example,
exponentiation is defined following the dual arithmetic rule.
```
#include <cmath>
template <typename T>
inline dual<T> exp(const dual<T>& x) {
  using std::exp;
  T y = exp(x.val());
  return dual(y, x.tan() * y);
}
```
So is the logarithm function.
```
template <typename T>
inline dual<T> log(const dual<T>& x) {
  using std::log;
  T y = log(x.val());
  // the derivative of log(x) is 1 / x, so divide by the input value
  return dual<T>(y, x.tan() / x.val());
}
```
The standard library header `<cmath>` is required for definitions of
`exp` and `log` for primitive values. The function definitions begin
with using declarations, e.g., `using std::exp`. This allows the `exp`
defined in the standard library to be used for primitives and the
`exp` defined for type `T` to be found by argument-dependent lookup
for autodiff value types `T`.
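A derivative is read off the tangent once the input's tangent has been seeded with one. The following is a minimal sketch of that usage, assuming the `dual` class above; the helper names `d_log_forward` and `d_log_exp` are hypothetical and introduced only for illustration. The tangent of the logarithm divides by the input value because $\frac{d}{dx} \log x = 1 / x$.

```cpp
#include <cmath>

namespace autodiff {
template <typename T>
class dual {
  T val_;
  T tan_;
public:
  dual(const T& val = 0, const T& tan = 0)
      : val_(val), tan_(tan) { }
  const T& val() const { return val_; }
  const T& tan() const { return tan_; }
};

template <typename T>
inline dual<T> exp(const dual<T>& x) {
  using std::exp;
  T y = exp(x.val());
  return dual<T>(y, x.tan() * y);
}

template <typename T>
inline dual<T> log(const dual<T>& x) {
  using std::log;
  T y = log(x.val());
  return dual<T>(y, x.tan() / x.val());
}
}  // namespace autodiff

// Hypothetical helper: derivative of log(x) at x0.
inline double d_log_forward(double x0) {
  autodiff::dual<double> x(x0, 1.0);  // seed the tangent: dx/dx = 1
  return log(x).tan();                // autodiff::log is found by ADL
}

// Hypothetical helper: derivative of log(exp(x)) at x0, which is 1.
inline double d_log_exp(double x0) {
  autodiff::dual<double> x(x0, 1.0);
  return log(exp(x)).tan();
}
```

Calling `d_log_forward(2.0)` evaluates the tangent as `1.0 / 2.0`.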
Binary operations can be implemented following the dual number
definitions.
```
template <typename T>
inline dual<T> operator*(const dual<T>& x1, const dual<T>& x2) {
  return dual(x1.val() * x2.val(),
              x1.tan() * x2.val() + x1.val() * x2.tan());
}
```
This suffices for the case where both arguments are autodiff variables,
```
dual<double> x1 = -1.3;
dual<double> x2 = 2.1;
dual<double> u = x1 * x2;
```
There is no need to explicitly bring in `autodiff::operator*` because it is
included implicitly by argument-dependent lookup.
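Because the value type is a template parameter, the same product rule can be applied to duals of duals to obtain second derivatives. The following is a minimal sketch under stated assumptions: `operator+` is not defined in the text and is added here by the sum rule, and the helpers `d_cube` and `d2_cube` are hypothetical names for the first and second derivatives of $f(x) = x^3$.

```cpp
namespace autodiff {
template <typename T>
class dual {
  T val_;
  T tan_;
public:
  dual(const T& val = 0, const T& tan = 0)
      : val_(val), tan_(tan) { }
  const T& val() const { return val_; }
  const T& tan() const { return tan_; }
};

// Assumption (not shown in the text): addition by the sum rule.
template <typename T>
inline dual<T> operator+(const dual<T>& x1, const dual<T>& x2) {
  return dual<T>(x1.val() + x2.val(), x1.tan() + x2.tan());
}

// Multiplication by the product rule.
template <typename T>
inline dual<T> operator*(const dual<T>& x1, const dual<T>& x2) {
  return dual<T>(x1.val() * x2.val(),
                 x1.tan() * x2.val() + x1.val() * x2.tan());
}
}  // namespace autodiff

// Hypothetical helper: f'(x0) for f(x) = x * x * x.
inline double d_cube(double x0) {
  autodiff::dual<double> x(x0, 1.0);  // seed the tangent with one
  return (x * x * x).tan();
}

// Hypothetical helper: f''(x0) using nested dual numbers.
inline double d2_cube(double x0) {
  using autodiff::dual;
  // Seed both the inner and the outer tangent directions with one.
  dual<dual<double>> x(dual<double>(x0, 1.0), dual<double>(1.0, 0.0));
  return (x * x * x).tan().tan();
}
```

At `x0 = 2`, the first derivative is $3 x_0^2 = 12$ and the second derivative is $6 x_0 = 12$.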
The following statements, which mix autodiff variables and primitives,
will not match the templated `operator*` because the primitive argument
types do not match the template.
```
dual<double> u = 1.2;
dual<double> v = u * 3.2;
dual<double> w = 2 * u;
```
The multiplication operator (`operator*`) can be further overloaded in
order to support these mixed types.
```
#include <type_traits>
// dual * primitive; dual<T> is spelled out so that the result type
// does not depend on deduction from mixed argument types
template <typename T, typename U,
          typename = std::enable_if_t<std::is_arithmetic_v<U>>>
inline dual<T> operator*(const dual<T>& x1, const U& c2) {
  return dual<T>(x1.val() * c2, x1.tan() * c2);
}
// primitive * dual
template <typename T, typename U,
          typename = std::enable_if_t<std::is_arithmetic_v<U>>>
inline dual<T> operator*(const U& c1, const dual<T>& x2) {
  return dual<T>(c1 * x2.val(), c1 * x2.tan());
}
```
The third template argument involves a C++ idiom that requires the
template parameter `U` to be a primitive,^[Arithmetic types are the
built-in integral and floating-point types, which include `bool`, the
character types, the signed and unsigned integer types such as `int`
and `long`, and `float`, `double`, and `long double`.]
preventing a match of the template function unless
`U` is a primitive; the templates `enable_if_t` and `is_arithmetic_v`
are declared in the standard library header `<type_traits>`.
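With the mixed overloads in place, expressions that combine primitives and autodiff variables resolve without ambiguity. The following is a minimal sketch, assuming C++17 and repeating the definitions so that the block is self-contained; the helper `d_scaled` is a hypothetical name for the derivative of $f(x) = 2 \cdot x \cdot 3.5$.

```cpp
#include <type_traits>

namespace autodiff {
template <typename T>
class dual {
  T val_;
  T tan_;
public:
  dual(const T& val = 0, const T& tan = 0)
      : val_(val), tan_(tan) { }
  const T& val() const { return val_; }
  const T& tan() const { return tan_; }
};

// dual * dual
template <typename T>
inline dual<T> operator*(const dual<T>& x1, const dual<T>& x2) {
  return dual<T>(x1.val() * x2.val(),
                 x1.tan() * x2.val() + x1.val() * x2.tan());
}

// dual * primitive
template <typename T, typename U,
          typename = std::enable_if_t<std::is_arithmetic_v<U>>>
inline dual<T> operator*(const dual<T>& x1, const U& c2) {
  return dual<T>(x1.val() * c2, x1.tan() * c2);
}

// primitive * dual
template <typename T, typename U,
          typename = std::enable_if_t<std::is_arithmetic_v<U>>>
inline dual<T> operator*(const U& c1, const dual<T>& x2) {
  return dual<T>(c1 * x2.val(), c1 * x2.tan());
}
}  // namespace autodiff

// Hypothetical helper: derivative of f(x) = 2 * x * 3.5 at x0.
inline double d_scaled(double x0) {
  autodiff::dual<double> x(x0, 1.0);       // seed the tangent with one
  autodiff::dual<double> y = 2 * x * 3.5;  // int * dual, then dual * double
  return y.tan();
}
```

The derivative is the constant 7 regardless of `x0`, because each mixed overload scales the tangent by the primitive factor.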
## Reverse-mode automatic differentiation in C++
Like forward mode, reverse-mode automatic differentiation can be
implemented through operator overloading in C++. As with forward
mode, argument-dependent lookup means that templated code will just
work with autodiff variables as long as all primitive functions
invoked are defined for autodiff types.
An alias template will reduce the boilerplate in requiring
arithmetic arguments.
```
#include <type_traits>
template <typename T>
using enable_if_arithmetic_t
= std::enable_if_t<std::is_arithmetic_v<T>>;
```
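The effect of the alias can be checked at compile time. The following is a minimal sketch, assuming C++17; the function `twice` and the detection trait `can_twice` are hypothetical names introduced only to show that a template guarded by `enable_if_arithmetic_t` accepts arithmetic arguments and drops out of overload resolution otherwise.

```cpp
#include <string>
#include <type_traits>
#include <utility>

template <typename T>
using enable_if_arithmetic_t
    = std::enable_if_t<std::is_arithmetic_v<T>>;

// Hypothetical function that is only viable for arithmetic arguments.
template <typename T, typename = enable_if_arithmetic_t<T>>
inline double twice(T x) {
  return 2.0 * x;
}

// Hypothetical detection trait: true if twice(T) is a valid call.
template <typename T, typename = void>
struct can_twice : std::false_type { };

template <typename T>
struct can_twice<T, std::void_t<decltype(twice(std::declval<T>()))>>
    : std::true_type { };

static_assert(can_twice<int>::value);
static_assert(can_twice<double>::value);
static_assert(!can_twice<std::string>::value);  // removed by the guard
```

The failed substitution for `std::string` is not an error; the template is simply removed from the candidate set.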
The core code for reverse-mode autodiff defines a class `adj` used to
store values and an index that will be unique for each subexpression.
```
#include <cstddef>
std::size_t next_ = 0;
class adj {
  double val_;
  std::size_t idx_;
public:
  // T defaults to double so that adj can be default constructed
  template <typename T = double, typename = enable_if_arithmetic_t<T>>
  adj(T val = 0, std::size_t idx = next_++)
      : val_(val), idx_(idx) { }
  double val() const { return val_; }
  std::size_t idx() const { return idx_; }
};
```
The global counter `next_` is used to assign unique identifiers in
sequence to each autodiff variable as it is constructed, so it must be
initialized to zero before any autodiff calculations. The autodiff
variable class `adj` holds a double-precision value and a unique
index. The constructor is responsible for generating indexes and
storing values. The default copy constructor, assignment operator,
and destructor suffice here.
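The index bookkeeping can be illustrated with a minimal sketch, assuming C++17 and a version of the `adj` class with getters named `val()` and `idx()`; the helper `demo_indexes` is a hypothetical name. Fresh variables draw consecutive indexes from the counter, whereas a copy made by the default copy constructor shares the index of its source.

```cpp
#include <cstddef>
#include <type_traits>
#include <vector>

namespace autodiff {
template <typename T>
using enable_if_arithmetic_t
    = std::enable_if_t<std::is_arithmetic_v<T>>;

std::size_t next_ = 0;  // global index counter

class adj {
  double val_;
  std::size_t idx_;
public:
  template <typename T = double, typename = enable_if_arithmetic_t<T>>
  adj(T val = 0, std::size_t idx = next_++)
      : val_(val), idx_(idx) { }
  double val() const { return val_; }
  std::size_t idx() const { return idx_; }
};
}  // namespace autodiff

// Hypothetical helper: indexes of two fresh variables and one copy.
inline std::vector<std::size_t> demo_indexes() {
  autodiff::next_ = 0;    // reset the counter
  autodiff::adj x(3.7);   // fresh variable, takes index 0
  autodiff::adj y = 2.9;  // fresh variable, takes index 1
  autodiff::adj z = x;    // copy, shares the index of x
  return {x.idx(), y.idx(), z.idx()};
}
```

After the call, the counter holds the next unused index.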
Usage is similar to that of forward-mode autodiff variables.
```
using autodiff::adj;
autodiff::next_ = 0; // initialize stack before starting
adj x(3.7); // construct from value
```
The constructor call for `x` allocates a unique index and increments
the global index counter. Assignment of arithmetic values works by
promotion using the implicit constructor, so that
```
adj y = 2.9; // assignment works by promoting
```
is equivalent to
```
adj y = adj(2.9);
```
In order to carry out reverse-mode automatic differentiation, each
expression must create and store a continuation used to propagate
adjoints from the result to the operands in the reverse sweep. In the
reference implementation, these continuations are pushed onto a global
stack as they are created.
```
#include <vector>
#include <functional>
std::vector<std::function<void(std::vector<double>&)>> stack_;
```
The reverse sweep is implemented by the `chain()` function, which
takes the variable `y` from which derivatives should be propagated.
```
std::vector<double> chain(const adj& y) {
  std::vector<double> adjoints(y.idx() + 1, 0);
  adjoints[y.idx()] = 1;
  for (auto chain_f = stack_.crbegin();
       chain_f != stack_.crend();
       ++chain_f)
    (*chain_f)(adjoints);
  return adjoints;
}
```
First, the vector `adjoints` of adjoint values is allocated at
size `y.idx() + 1` so that it's large enough to hold the adjoints of
every expression involved in the calculation of `y`; this is guaranteed
to be enough because every expression involved in the calculation of
`y` has an index lower than `y`'s. The initial values are set to zero
in the constructor for `adjoints`. To begin the reverse sweep, the
adjoint for `y`, namely `adjoints[y.idx()]`, is set to one. Then the
stack of continuations is traversed in reverse, from `y` down to the
independent variables, applying each continuation on the stack to the
adjoint vector. Finally, the function returns the calculated adjoints
so that derivatives may be retrieved.
A simple operation like addition is overloaded as follows.
```
inline adj operator+(const adj& x1, const adj& x2) {
  adj y(x1.val() + x2.val());
  auto f = [=](std::vector<double>& adj) {
    adj[x1.idx()] += adj[y.idx()];
    adj[x2.idx()] += adj[y.idx()];
  };
  stack_.emplace_back(f);
  return y;
}
```
First, the result `y` is constructed with value equal to adding the
values of the arguments, `x1.val()` and `x2.val()`. Then a continuation
`f` for the chain rule is defined as an anonymous function using a
lambda. The notation `[=]` indicates that the lambda captures the
values of variables for later execution by copying. Here, the
variables captured are `x1`, `x2`, and `y`. The continuation is
declared to take a mutable reference to a vector of double-precision
floating point values as an argument; these hold the adjoints of all
the subexpressions as declared in the `chain()` function. The body of
the continuation follows the reverse-mode adjoint rule for addition,
namely adding the adjoint of the result `y` to the adjoint of each of
the operands, `x1` and `x2`. After the continuation is defined, it is
pushed back onto the global stack. Finally, the value `y` is
returned.
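The pieces introduced so far are enough for a complete calculation. The following is a minimal sketch, assuming C++17 and getters named `val()` and `idx()`; the helper `grad_sum` is a hypothetical name for the gradient of $f(x_1, x_2) = (x_1 + x_2) + x_1$, and the lambda parameter is named `a` rather than `adj` only to avoid hiding the class name.

```cpp
#include <cstddef>
#include <functional>
#include <type_traits>
#include <vector>

namespace autodiff {
template <typename T>
using enable_if_arithmetic_t
    = std::enable_if_t<std::is_arithmetic_v<T>>;

std::size_t next_ = 0;
std::vector<std::function<void(std::vector<double>&)>> stack_;

class adj {
  double val_;
  std::size_t idx_;
public:
  template <typename T = double, typename = enable_if_arithmetic_t<T>>
  adj(T val = 0, std::size_t idx = next_++)
      : val_(val), idx_(idx) { }
  double val() const { return val_; }
  std::size_t idx() const { return idx_; }
};

inline adj operator+(const adj& x1, const adj& x2) {
  adj y(x1.val() + x2.val());
  stack_.emplace_back([=](std::vector<double>& a) {
    a[x1.idx()] += a[y.idx()];  // pass the result's adjoint to x1
    a[x2.idx()] += a[y.idx()];  // and to x2
  });
  return y;
}

inline std::vector<double> chain(const adj& y) {
  std::vector<double> adjoints(y.idx() + 1, 0.0);
  adjoints[y.idx()] = 1;
  for (auto f = stack_.crbegin(); f != stack_.crend(); ++f)
    (*f)(adjoints);
  return adjoints;
}
}  // namespace autodiff

// Hypothetical helper: gradient of f(x1, x2) = (x1 + x2) + x1.
inline std::vector<double> grad_sum(double v1, double v2) {
  autodiff::next_ = 0;       // reset the index counter
  autodiff::stack_.clear();  // and the continuation stack
  autodiff::adj x1(v1);
  autodiff::adj x2(v2);
  autodiff::adj y = (x1 + x2) + x1;
  std::vector<double> adjoints = autodiff::chain(y);
  return {adjoints[x1.idx()], adjoints[x2.idx()]};
}
```

Because `x1` appears twice in the expression, its adjoint accumulates two contributions, giving the gradient $(2, 1)$.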
While the above code will work by promoting arithmetic values to
the adjoint class, it is more efficient to define further overloads
that are more specific and avoid the redundant work on the stack.
```
template <typename T, typename = enable_if_arithmetic_t<T>>
inline adj operator+(const adj& x1, T x2) {
  adj y(x1.val() + x2);
  stack_.emplace_back([=](std::vector<double>& adj) {
    adj[x1.idx()] += adj[y.idx()];
  });
  return y;
}
template <typename T, typename = enable_if_arithmetic_t<T>>
inline adj operator+(T x1, const adj& x2) {
  adj y(x1 + x2.val());
  stack_.emplace_back([=](std::vector<double>& adj) {
    adj[x2.idx()] += adj[y.idx()];
  });
  return y;
}
```
Rather than defining a temporary for the continuation, it is pushed
directly onto the stack. The value is computed using the value from
the adjoint variables and the primitives directly, and the only
propagation is to the adjoint operand.
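The same pattern extends to other linear operations. The following is a minimal sketch for subtraction, which is not defined in the text and is included here as an assumption about how the pattern would carry over; the helper `grad_diff` is a hypothetical name for the gradient of $f(x_1, x_2) = ((x_1 - x_2) - (3 - x_1)) - 1$.

```cpp
#include <cstddef>
#include <functional>
#include <type_traits>
#include <vector>

namespace autodiff {
template <typename T>
using enable_if_arithmetic_t
    = std::enable_if_t<std::is_arithmetic_v<T>>;

std::size_t next_ = 0;
std::vector<std::function<void(std::vector<double>&)>> stack_;

class adj {
  double val_;
  std::size_t idx_;
public:
  template <typename T = double, typename = enable_if_arithmetic_t<T>>
  adj(T val = 0, std::size_t idx = next_++)
      : val_(val), idx_(idx) { }
  double val() const { return val_; }
  std::size_t idx() const { return idx_; }
};

// adj - adj: the second operand receives the negated adjoint.
inline adj operator-(const adj& x1, const adj& x2) {
  adj y(x1.val() - x2.val());
  stack_.emplace_back([=](std::vector<double>& a) {
    a[x1.idx()] += a[y.idx()];
    a[x2.idx()] -= a[y.idx()];
  });
  return y;
}

// adj - primitive: only the adjoint operand is updated.
template <typename T, typename = enable_if_arithmetic_t<T>>
inline adj operator-(const adj& x1, T x2) {
  adj y(x1.val() - x2);
  stack_.emplace_back([=](std::vector<double>& a) {
    a[x1.idx()] += a[y.idx()];
  });
  return y;
}

// primitive - adj: the adjoint operand is decremented.
template <typename T, typename = enable_if_arithmetic_t<T>>
inline adj operator-(T x1, const adj& x2) {
  adj y(x1 - x2.val());
  stack_.emplace_back([=](std::vector<double>& a) {
    a[x2.idx()] -= a[y.idx()];
  });
  return y;
}

inline std::vector<double> chain(const adj& y) {
  std::vector<double> adjoints(y.idx() + 1, 0.0);
  adjoints[y.idx()] = 1;
  for (auto f = stack_.crbegin(); f != stack_.crend(); ++f)
    (*f)(adjoints);
  return adjoints;
}
}  // namespace autodiff

// Hypothetical helper: gradient of ((x1 - x2) - (3 - x1)) - 1.
inline std::vector<double> grad_diff(double v1, double v2) {
  autodiff::next_ = 0;
  autodiff::stack_.clear();
  autodiff::adj x1(v1);
  autodiff::adj x2(v2);
  autodiff::adj y = ((x1 - x2) - (3 - x1)) - 1;
  std::vector<double> adjoints = autodiff::chain(y);
  return {adjoints[x1.idx()], adjoints[x2.idx()]};
}
```

The gradient is $(2, -1)$; the order in which the two inner subexpressions are evaluated does not affect the result because their continuations are independent.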
Multiplication is defined similarly, with both the values and the
indexes of the captured operands being used.
```
inline adj operator*(const adj& x1, const adj& x2) {
  adj y(x1.val() * x2.val());
  stack_.emplace_back([=](std::vector<double>& adj) {
    adj[x1.idx()] += x2.val() * adj[y.idx()];
    adj[x2.idx()] += x1.val() * adj[y.idx()];
  });
  return y;
}
template <typename T, typename = enable_if_arithmetic_t<T>>
inline adj operator*(const adj& x1, T x2) {
  adj y(x1.val() * x2);
  stack_.emplace_back([=](std::vector<double>& adj) {
    adj[x1.idx()] += x2 * adj[y.idx()];
  });
  return y;
}
template <typename T, typename = enable_if_arithmetic_t<T>>
inline adj operator*(T x1, const adj& x2) {
  adj y(x1 * x2.val());
  stack_.emplace_back([=](std::vector<double>& adj) {
    adj[x2.idx()] += x1 * adj[y.idx()];
  });
  return y;
}
```
Non-linear functions like exponentiation also follow their
definitions. We need the `<cmath>` library for a definition of the
exponential function.
```
#include <cmath>
namespace autodiff {
inline adj exp(const adj& x) {
  adj y(std::exp(x.val()));
  auto f = [=](std::vector<double>& adj) {
    adj[x.idx()] += y.val() * adj[y.idx()];
  };
  stack_.emplace_back(f);
  return y;
}
}  // namespace autodiff
```
The constructor defines the value of `y` to be the value of `x`
exponentiated. The adjoint is incremented using the captured value of
`y`, namely `exp(x.val())`, which is the derivative of `y` with respect
to `x`, multiplied by the adjoint of `y`.
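Other non-linear functions follow the same template, with only the derivative in the continuation changing. The following is a minimal sketch for the natural logarithm, which the text does not define for reverse mode and which is included here as an assumption about how the pattern would carry over; the helper `d_log_reverse` is a hypothetical name.

```cpp
#include <cmath>
#include <cstddef>
#include <functional>
#include <type_traits>
#include <vector>

namespace autodiff {
template <typename T>
using enable_if_arithmetic_t
    = std::enable_if_t<std::is_arithmetic_v<T>>;

std::size_t next_ = 0;
std::vector<std::function<void(std::vector<double>&)>> stack_;

class adj {
  double val_;
  std::size_t idx_;
public:
  template <typename T = double, typename = enable_if_arithmetic_t<T>>
  adj(T val = 0, std::size_t idx = next_++)
      : val_(val), idx_(idx) { }
  double val() const { return val_; }
  std::size_t idx() const { return idx_; }
};

// The derivative of log(x) is 1 / x, so the adjoint of the result is
// divided by the captured value of the operand.
inline adj log(const adj& x) {
  adj y(std::log(x.val()));
  stack_.emplace_back([=](std::vector<double>& a) {
    a[x.idx()] += a[y.idx()] / x.val();
  });
  return y;
}

inline std::vector<double> chain(const adj& y) {
  std::vector<double> adjoints(y.idx() + 1, 0.0);
  adjoints[y.idx()] = 1;
  for (auto f = stack_.crbegin(); f != stack_.crend(); ++f)
    (*f)(adjoints);
  return adjoints;
}
}  // namespace autodiff

// Hypothetical helper: derivative of log(x) at x0 by reverse mode.
inline double d_log_reverse(double x0) {
  autodiff::next_ = 0;
  autodiff::stack_.clear();
  autodiff::adj x(x0);
  autodiff::adj y = log(x);  // autodiff::log is found by ADL
  return autodiff::chain(y)[x.idx()];
}
```

Unlike `exp`, the continuation here captures the operand's value rather than the result's, because the derivative is expressed in terms of `x`.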
The following code computes $\nabla f(10.3, -1.1)$, where $f(x_1, x_2)
= x_1 \cdot \exp(x_2 \cdot 2) + 7$.
```
#include <iostream>
int main() {
  using autodiff::adj;
  autodiff::next_ = 0;       // reset the index counter
  autodiff::stack_.clear();  // reset the continuation stack
  adj x1 = 10.3;
  adj x2 = -1.1;
  adj y = x1 * exp(x2 * 2) + 7;
  std::vector<double> adjoints = chain(y);
  double dy_dx1 = adjoints[x1.idx()];
  double dy_dx2 = adjoints[x2.idx()];
  std::cout << "grad f = [" << dy_dx1 << ", " << dy_dx2 << "]" << std::endl;
  return 0;
}
```
First, the index counter and stack are reset. Then the independent
variables `x1` and `x2` are initialized. The resulting dependent
variable `y` is computed as a single expression and is also of
autodiff variable type. The definitions of `operator*`, `operator+`,
and `exp()` are found through argument-dependent lookup. Next, the
reverse sweep is carried out starting from the result `y` using the
`chain()` function. The resulting adjoints for `x1` and `x2` are
found by indexing the vector of adjoints returned by `chain()`. These
are then printed and the default success code (zero) is returned.
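The computed gradient can be checked against the analytic derivatives. As a worked example, differentiating $f(x_1, x_2) = x_1 \cdot \exp(x_2 \cdot 2) + 7$ by hand gives

$$
\frac{\partial f}{\partial x_1} = \exp(2 \cdot x_2),
\qquad
\frac{\partial f}{\partial x_2} = 2 \cdot x_1 \cdot \exp(2 \cdot x_2).
$$

Evaluating at $(x_1, x_2) = (10.3, -1.1)$ yields $\exp(-2.2) \approx 0.1108$ for the first component and $20.6 \cdot \exp(-2.2) \approx 2.2825$ for the second, which the program's output should match up to printing precision.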
## References
The reverse-mode autodiff implementation is based on
[@carpenter:2018]. Matrices are implemented with the Eigen C++
library [@gunnebaud:2020]. A thorough and precise introduction to
modern C++ template programming is [@vandevoorde:2017].