forked from takezo5096/DNN
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathvariable.h
112 lines (68 loc) · 1.72 KB
/
variable.h
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
#ifndef _VARIABLE_
#define _VARIABLE_
#include <list>
#include <memory>
#include <random>
#include <string>
#include <vector>
#include <boost/intrusive_ptr.hpp>
#include "cuMat.h"
#include "cuMatSparse.h"
using namespace std;
class Function;
/**
 * @brief A value in the computation graph together with its gradient.
 *
 * Holds a dense matrix (`data`) or a sparse one (`data_sparse`, flagged by
 * `isSparse`), a gradient matrix (`grad`), and a pointer to the Function that
 * produced it (`creator`). Every non-template member is defined out of line,
 * so the per-member notes below are inferred from names and signatures —
 * NOTE(review): confirm them against the implementation.
 *
 * Standard-library names are written with `std::` so this class does not
 * depend on the header-scope `using namespace std;`.
 */
class Variable {
public:
    // --- identity / graph bookkeeping --------------------------------------
    int id = 0;
    int opt = 0;                       // NOTE(review): meaning not visible here
    int *last_opt = nullptr;           // ownership not visible here — confirm
    bool *is_last_backward = nullptr;  // ownership not visible here — confirm
    int forward_count = 0;
    Function *creator = nullptr;       // producing Function (presumably null for leaf inputs)
    std::string name;

    // --- payload -----------------------------------------------------------
    cuMat data;               // dense value
    cuMatSparse data_sparse;  // sparse value (presumably used when isSparse is true)
    cuMat grad;               // gradient matrix
    cuMat seed;
    int grad_num = -999;      // -999 looks like a "not yet set" sentinel — confirm
    bool isGetGrad = true;    // presumably: whether this node wants a gradient
    bool isSparse = false;

    // --- construction / assignment -----------------------------------------
    Variable();
    Variable(const Variable &a);
    Variable(int rows, int cols);
    Variable(int rows, int cols, bool is_get_grad);
    Variable(Function *f, int rows, int cols);
    // NOTE(review): not `explicit`, so a cuMat lvalue converts implicitly to a
    // Variable; kept as-is because existing callers may rely on it.
    Variable(cuMat &input);
    Variable(Function *f, cuMat &input);
    Variable(std::vector<float> &ids, int nums);
    ~Variable();

    void creatorSet(Function *f);
    Variable &operator=(const Variable &a);

    // --- element-wise operations -------------------------------------------
    Variable sin();
    Variable log();

    // --- backward pass / gradient management -------------------------------
    void backward();
    void backward(Variable *v);
    void zero_grads();
    void zero_grads(Variable *v);
    /*
    void truncate();
    void truncate(Variable *v);
    */

    // --- in-place fill / utilities -----------------------------------------
    void ones();
    void zeros();
    void unchain();
    void zero_grad();
    void randoms(float m, float a);
    void binominal_randoms(float ratio);  // (sic) "binomial"; public name kept for compatibility
    float val();

private:
    friend class boost::serialization::access;

    /**
     * @brief Boost.Serialization hook, invoked through
     *        boost::serialization::access.
     *
     * Archives id, data, grad, seed and isGetGrad, in that order; with Boost's
     * `operator&` the same body serves both saving and loading. No other
     * member is archived (name, opt, data_sparse, isSparse, creator, ...).
     * NOTE(review): confirm that omitting data_sparse/isSparse is intended —
     * a sparse Variable would round-trip without its sparse payload.
     *
     * @param ar  Boost input or output archive.
     * The second parameter (class version) is unused and left unnamed.
     */
    template<class Archive> void serialize(Archive &ar, const unsigned int /*version*/) {
        ar & id;
        ar & data;
        ar & grad;
        ar & seed;
        ar & isGetGrad;
    }
};
using PVariable = shared_ptr<Variable>;
Variable *variable_construct(int rows, int cols);
void variable_destroy(Variable *ptr);
#endif