diff --git a/include/af/autograd/Functions.hpp b/include/af/autograd/Functions.hpp index 17a190e..5ca6ff7 100644 --- a/include/af/autograd/Functions.hpp +++ b/include/af/autograd/Functions.hpp @@ -36,7 +36,11 @@ namespace af { Variable cos(const Variable &input); Variable tanh(const Variable &input); Variable sigmoid(const Variable &input); - + + Variable max(const Variable &lhs, const Variable &rhs); + Variable max(const Variable &lhs, const double &rhs); + Variable max(const double &lhs, const Variable &rhs); + Variable transpose(const Variable &input); Variable expandAs(const Variable &input, const Variable &reference); Variable reduceAs(const Variable &input, const Variable &reference); @@ -44,5 +48,7 @@ namespace af { Variable matmul(const Variable &lhs, const Variable &rhs); Variable matmulTN(const Variable &lhs, const Variable &rhs); Variable matmulNT(const Variable &lhs, const Variable &rhs); + + } } diff --git a/include/af/nn/Modules/Activations.hpp b/include/af/nn/Modules/Activations.hpp index 1530cd9..95beab9 100644 --- a/include/af/nn/Modules/Activations.hpp +++ b/include/af/nn/Modules/Activations.hpp @@ -30,5 +30,23 @@ namespace af autograd::Variable forward(const autograd::Variable &input); }; + + class ReLU : public Module + { + public: + ReLU(); + + autograd::Variable forward(const autograd::Variable &input); + }; + + class LeakyReLU : public Module + { + private: + double m_slope; + public: + LeakyReLU(double slope = 0.0); + + autograd::Variable forward(const autograd::Variable &input); + }; } } diff --git a/src/autograd/Functions.cpp b/src/autograd/Functions.cpp index 71048b6..87e0f7f 100644 --- a/src/autograd/Functions.cpp +++ b/src/autograd/Functions.cpp @@ -54,6 +54,18 @@ namespace af { }; return Variable(result, {lhs, rhs}, grad_func); } + + Variable operator >(const Variable &lhs, const Variable &rhs) + { + auto result = lhs.array() > rhs.array(); + return Variable(result, false); + } + + Variable operator <=(const Variable &lhs, const Variable &rhs) + { + auto result = lhs.array() <= rhs.array(); + return Variable(result, false); + } #define INSTANTIATE_OPERATOR(OP) \ Variable operator OP(const double &lhs_val, const Variable &rhs) \ @@ -78,10 +90,54 @@ namespace af { INSTANTIATE_OPERATOR(-) INSTANTIATE_OPERATOR(*) INSTANTIATE_OPERATOR(/) + INSTANTIATE_OPERATOR(>) + INSTANTIATE_OPERATOR(<=) #undef INSTANTIATE_OPERATOR - Variable negate(const Variable &input) + Variable operator !(const Variable &input) + { + auto result = !input.array(); + return Variable(result, false); + } + + Variable max(const Variable &lhs, const Variable &rhs) + { + auto mask = lhs > rhs; + auto result = max(lhs.array(), rhs.array()); + + auto grad_func = [](std::vector<Variable> &inputs, const Variable &grad_output) { + inputs[0].addGrad( inputs[2] * grad_output); + inputs[1].addGrad(!inputs[2] * grad_output); + }; + return Variable(result, {lhs, rhs, mask}, grad_func); + } + +#define INSTANTIATE_FUNCTION(FN) \ + Variable FN(const double &lhs_val, const Variable &rhs) \ + { \ + auto lhs = Variable( \ + af::constant(lhs_val, \ + rhs.array().dims(), \ + rhs.array().type()), \ + false); \ + return FN(lhs,rhs); \ + } \ + Variable FN(const Variable &lhs, const double &rhs_val) \ + { \ + auto rhs = Variable( \ + af::constant(rhs_val, \ + lhs.array().dims(), lhs.array().type()), \ + false); \ + return FN(lhs, rhs); \ + } + + + INSTANTIATE_FUNCTION(max); + +#undef INSTANTIATE_FUNCTION + + Variable negate(const Variable &input) { auto result = 0.0 - input.array(); auto grad_func = [](std::vector<Variable> &inputs, const Variable &grad_output) { diff --git a/src/nn/Modules/Activations.cpp b/src/nn/Modules/Activations.cpp index 0d1ca6e..05b9510 100644 --- a/src/nn/Modules/Activations.cpp +++ b/src/nn/Modules/Activations.cpp @@ -29,5 +29,22 @@ namespace af { return tanh(input); } + + ReLU::ReLU() {} + + Variable ReLU::forward(const Variable &input) + { + return max(input, 0.0); + } + + LeakyReLU::LeakyReLU(double slope) : + m_slope(slope) + { + } + + Variable LeakyReLU::forward(const Variable &input) + { + return max(input, m_slope * input); + } } }