Skip to content

Commit 0f170a9

Browse files
fohx13 authored and pavanky committed
Add ReLU and LeakyReLU modules
This commit adds ReLU and LeakyReLU modules, as well as a few helper functions: max and operator!.
1 parent 8129b47 commit 0f170a9

File tree

4 files changed

+99
-2
lines changed

4 files changed

+99
-2
lines changed

include/af/autograd/Functions.hpp

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,13 +36,19 @@ namespace af {
3636
Variable cos(const Variable &input);
3737
Variable tanh(const Variable &input);
3838
Variable sigmoid(const Variable &input);
39-
39+
40+
/// Elementwise maximum of two Variables; gradients flow to whichever
/// operand supplied each output element.
Variable max(const Variable &lhs, const Variable &rhs);
/// Elementwise maximum against a scalar, broadcast to lhs's shape/type.
Variable max(const Variable &lhs, const double &rhs);
/// Elementwise maximum against a scalar, broadcast to rhs's shape/type.
Variable max(const double &lhs, const Variable &rhs);
43+
4044
Variable transpose(const Variable &input);
4145
Variable expandAs(const Variable &input, const Variable &reference);
4246
Variable reduceAs(const Variable &input, const Variable &reference);
4347

4448
Variable matmul(const Variable &lhs, const Variable &rhs);
4549
Variable matmulTN(const Variable &lhs, const Variable &rhs);
4650
Variable matmulNT(const Variable &lhs, const Variable &rhs);
51+
52+
4753
}
4854
}

include/af/nn/Modules/Activations.hpp

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,5 +30,23 @@ namespace af
3030

3131
autograd::Variable forward(const autograd::Variable &input);
3232
};
33+
34+
class ReLU : public Module
35+
{
36+
public:
37+
ReLU();
38+
39+
autograd::Variable forward(const autograd::Variable &input);
40+
};
41+
42+
class LeakyReLU : public Module
43+
{
44+
private:
45+
double m_slope;
46+
public:
47+
LeakyReLU(double slope = 0.0);
48+
49+
autograd::Variable forward(const autograd::Variable &input);
50+
};
3351
}
3452
}

src/autograd/Functions.cpp

Lines changed: 57 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,18 @@ namespace af {
5454
};
5555
return Variable(result, {lhs, rhs}, grad_func);
5656
}
57+
58+
// Elementwise greater-than. The result is a constant (non-differentiable)
// mask Variable — no gradient function is attached.
Variable operator >(const Variable &lhs, const Variable &rhs)
{
    return Variable(lhs.array() > rhs.array(), false);
}
63+
64+
// Elementwise less-than-or-equal. The result is a constant
// (non-differentiable) mask Variable — no gradient function is attached.
Variable operator <=(const Variable &lhs, const Variable &rhs)
{
    return Variable(lhs.array() <= rhs.array(), false);
}
5769

5870
#define INSTANTIATE_OPERATOR(OP) \
5971
Variable operator OP(const double &lhs_val, const Variable &rhs) \
@@ -78,10 +90,54 @@ namespace af {
7890
INSTANTIATE_OPERATOR(-)
7991
INSTANTIATE_OPERATOR(*)
8092
INSTANTIATE_OPERATOR(/)
93+
INSTANTIATE_OPERATOR(>)
94+
INSTANTIATE_OPERATOR(<=)
8195

8296
#undef INSTANTIATE_OPERATOR
8397

84-
Variable negate(const Variable &input)
98+
// Elementwise logical negation, used to invert selection masks.
// The result is a constant (non-differentiable) Variable.
Variable operator !(const Variable &input)
{
    return Variable(!input.array(), false);
}
103+
104+
// Elementwise maximum of two Variables.
//
// The selection mask (lhs > rhs) is saved as a third input so the backward
// pass can route each gradient element to the operand that produced the
// corresponding output element. Ties (lhs == rhs) route to rhs.
Variable max(const Variable &lhs, const Variable &rhs)
{
    auto lhsWins = lhs > rhs;  // non-differentiable mask
    auto out = max(lhs.array(), rhs.array());

    auto gradFunc = [](std::vector<Variable> &in, const Variable &gradOut) {
        in[0].addGrad( in[2] * gradOut);  // elements taken from lhs
        in[1].addGrad(!in[2] * gradOut);  // elements taken from rhs (incl. ties)
    };
    return Variable(out, {lhs, rhs, lhsWins}, gradFunc);
}
115+
116+
// Generates the two scalar overloads of a binary Variable function FN by
// lifting the double into a constant (non-differentiable) Variable that
// matches the other operand's dimensions and type, then delegating to the
// Variable/Variable overload.
#define INSTANTIATE_FUNCTION(FN)                                    \
    Variable FN(const double &lhs_val, const Variable &rhs)         \
    {                                                               \
        auto lhs = Variable(af::constant(lhs_val,                   \
                                         rhs.array().dims(),        \
                                         rhs.array().type()),       \
                            false);                                 \
        return FN(lhs, rhs);                                        \
    }                                                               \
    Variable FN(const Variable &lhs, const double &rhs_val)         \
    {                                                               \
        auto rhs = Variable(af::constant(rhs_val,                   \
                                         lhs.array().dims(),        \
                                         lhs.array().type()),       \
                            false);                                 \
        return FN(lhs, rhs);                                        \
    }

INSTANTIATE_FUNCTION(max);

#undef INSTANTIATE_FUNCTION
139+
140+
Variable negate(const Variable &input)
85141
{
86142
auto result = 0.0 - input.array();
87143
auto grad_func = [](std::vector<Variable> &inputs, const Variable &grad_output) {

src/nn/Modules/Activations.cpp

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,5 +29,22 @@ namespace af
2929
{
3030
return tanh(input);
3131
}
32+
33+
// ReLU is stateless; nothing to initialize.
ReLU::ReLU() {}
34+
35+
// ReLU(x) = max(x, 0): negative entries become zero, positives pass through.
Variable ReLU::forward(const Variable &input)
{
    return max(input, 0.0);
}
39+
40+
// Stores the negative-region slope used by forward().
LeakyReLU::LeakyReLU(double slope) : m_slope(slope) {}
44+
45+
// LeakyReLU(x) = max(x, slope * x). For slope in [0, 1) this passes
// positive entries through unchanged and scales negative entries by slope.
Variable LeakyReLU::forward(const Variable &input)
{
    return max(input, m_slope * input);
}
3249
}
3350
}

0 commit comments

Comments
 (0)