From 7e7cc69b103b70461605b8a80ed3674f21957104 Mon Sep 17 00:00:00 2001 From: Pavan Yalamanchili Date: Wed, 2 Aug 2017 15:49:50 -0700 Subject: [PATCH] overload Module::operator() to call Module::forward --- examples/perceptron.cpp | 6 +++--- include/af/nn/Modules/Loss.hpp | 3 +++ include/af/nn/Modules/Module.hpp | 2 ++ src/nn/Modules/Loss.cpp | 6 ++++++ src/nn/Modules/Module.cpp | 5 +++++ 5 files changed, 19 insertions(+), 3 deletions(-) diff --git a/examples/perceptron.cpp b/examples/perceptron.cpp index 0854369..dd3e3d3 100644 --- a/examples/perceptron.cpp +++ b/examples/perceptron.cpp @@ -51,10 +51,10 @@ int main() af::array out_j = out(af::span, j); // Forward propagation - result = perceptron.forward(nn::input(in_j)); + result = perceptron(nn::input(in_j)); // Calculate loss - l = loss.forward(result, nn::noGrad(out_j)); + l = loss(result, nn::noGrad(out_j)); // Backward propagation l.backward(); @@ -71,7 +71,7 @@ int main() perceptron.eval(); // Forward propagation - result = perceptron.forward(nn::input(in)); + result = perceptron(nn::input(in)); // Calculate loss // TODO: Use loss function diff --git a/include/af/nn/Modules/Loss.hpp b/include/af/nn/Modules/Loss.hpp index 44b743a..dd5d1e8 100644 --- a/include/af/nn/Modules/Loss.hpp +++ b/include/af/nn/Modules/Loss.hpp @@ -23,6 +23,9 @@ namespace af const autograd::Variable &targets) = 0; autograd::Variable forward(const autograd::Variable &inputs); + + autograd::Variable operator()(const autograd::Variable &inputs, + const autograd::Variable &targets); }; class MeanSquaredError : public Loss diff --git a/include/af/nn/Modules/Module.hpp b/include/af/nn/Modules/Module.hpp index a6189ae..9666ee3 100644 --- a/include/af/nn/Modules/Module.hpp +++ b/include/af/nn/Modules/Module.hpp @@ -42,6 +42,8 @@ namespace af void eval(); virtual autograd::Variable forward(const autograd::Variable &input) = 0; + + autograd::Variable operator()(const autograd::Variable &input); }; } } diff --git a/src/nn/Modules/Loss.cpp b/src/nn/Modules/Loss.cpp index ab7f80a..226f8e4 100644 --- a/src/nn/Modules/Loss.cpp +++ b/src/nn/Modules/Loss.cpp @@ -21,6 +21,12 @@ namespace af throw af::exception("Loss module requires both inputs and targets"); } + autograd::Variable Loss::operator()(const autograd::Variable &inputs, + const autograd::Variable &targets) + { + return this->forward(inputs, targets); + } + autograd::Variable MeanSquaredError::forward(const autograd::Variable &inputs, const autograd::Variable &targets) { diff --git a/src/nn/Modules/Module.cpp b/src/nn/Modules/Module.cpp index 9cf547f..c769a02 100644 --- a/src/nn/Modules/Module.cpp +++ b/src/nn/Modules/Module.cpp @@ -60,5 +60,10 @@ namespace af parameter.zeroGrad(); } } + + Variable Module::operator()(const Variable &input) + { + return this->forward(input); + } } }