/*
 * optimizer_adagrad.h
 *
 * Adagrad optimizer: adapts each parameter's step size by the inverse
 * square root of its accumulated squared gradients.
 */
#ifndef OPTIMIZER_ADAGRAD_H_
#define OPTIMIZER_ADAGRAD_H_
#include "model.h"
#include "optimizer.h"
// Per-variable Adagrad state: the update buffer and the running sum of
// squared gradients, both shaped like the variable itself.
class OptimizerAdagradParams : public OptimizerParams {
public:
    cuMat ndw; // update applied to the weights: -lr * grad / sqrt(g2)
    cuMat g2;  // element-wise accumulation of squared gradients

    OptimizerAdagradParams(int output_units, int input_units) {
        ndw = cuMat(output_units, input_units);
        g2 = cuMat(output_units, input_units);
    }
};
class OptimizerAdagrad : public Optimizer {
public:
    OptimizerAdagrad(Model *model, float lr) : Optimizer(model, lr) {
    }

    OptimizerAdagrad(Model *model, float lr, float clip_grad_threshold) : Optimizer(model, lr, clip_grad_threshold) {
    }

    // Allocate Adagrad state matching the shape of the variable's data.
    OptimizerParams *createOptimizerParams(Variable *v) {
        return new OptimizerAdagradParams(v->data.rows, v->data.cols);
    }

    // Adagrad update:
    //   g2 += grad * grad
    //   w  -= lr * grad / sqrt(g2)
    // Note: a small epsilon is usually added to the denominator for numerical
    // stability; this implementation divides by sqrt(g2) directly.
    void update_param(Variable *w, OptimizerParams &opp) {
        OptimizerAdagradParams &op = (OptimizerAdagradParams &)opp;
        op.g2 += w->grad * w->grad;     // accumulate squared gradients
        cuMat tmp = op.g2.sqrt();       // element-wise sqrt of the accumulator
        tmp = w->grad / tmp;            // scale the gradient
        tmp.mul(-lr, op.ndw);           // ndw = -lr * grad / sqrt(g2)
        w->data.plus(op.ndw, w->data);  // w += ndw
    }
};
#endif /* OPTIMIZER_ADAGRAD_H_ */
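/*
 * Reference sketch of the same update on plain float arrays, for clarity only;
 * it is not part of this header and does not use cuMat. The function name,
 * signature, and the epsilon term are illustrative assumptions (epsilon is the
 * usual guard against dividing by zero, which the cuMat version above omits).
 * Requires <cmath>.
 *
 *   void adagrad_step(float *w, const float *grad, float *g2,
 *                     int n, float lr, float eps = 1e-8f) {
 *       for (int i = 0; i < n; ++i) {
 *           g2[i] += grad[i] * grad[i];                        // accumulate squared gradients
 *           w[i]  -= lr * grad[i] / (std::sqrt(g2[i]) + eps);  // scaled step
 *       }
 *   }
 */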