Use references while iterating when possible
pavanky authored and umar456 committed Jul 6, 2017
Parent: 82d77dd · Commit: 8129b47
Showing 5 changed files with 10 additions and 10 deletions.
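The change is mechanical across all five files: range-based for loops (and two accessors) that previously produced a copy of each element now bind a reference instead. As a general illustration of the cost being avoided, here is a minimal standalone C++ sketch that counts copies; it is not ArrayFire code and the names are hypothetical.

#include <cstdio>
#include <vector>

// Hypothetical element type that counts how often it is copied.
struct Tracked {
    static int copies;
    Tracked() = default;
    Tracked(const Tracked &) { ++copies; }
};
int Tracked::copies = 0;

int main() {
    std::vector<Tracked> items(1000);

    for (auto item : items) { (void)item; }   // by value: one copy per element
    std::printf("by value:     %d copies\n", Tracked::copies);   // 1000

    Tracked::copies = 0;
    for (auto &item : items) { (void)item; }  // by reference: no copies
    std::printf("by reference: %d copies\n", Tracked::copies);   // 0
    return 0;
}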
examples/perceptron.cpp: 1 addition & 1 deletion
@@ -61,7 +61,7 @@ int main()

// Update parameters
// TODO: Should use optimizer
- for (auto param : perceptron.parameters()) {
+ for (auto &param : perceptron.parameters()) {
param.array() += lr * param.grad().array();
param.array().eval();
}
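A hedged note on why this loop still worked before the change: Variable appears to keep its state behind a shared pointer (the m_shared member in src/autograd/Variable.cpp), so even a by-value copy of param updated the real parameter; the copy only added reference-count traffic on every iteration. A standalone sketch of that handle pattern, with hypothetical names:

#include <cstdio>
#include <memory>
#include <vector>

// Hypothetical handle type: copies share the same underlying storage,
// roughly how a shared-pointer-backed Variable would behave.
struct Param {
    std::shared_ptr<float> value = std::make_shared<float>(0.0f);
    void update(float delta) { *value += delta; }
};

int main() {
    std::vector<Param> params(3);

    // By value: each iteration copies the handle (reference-count bump),
    // but the update still reaches the shared storage.
    for (auto p : params) { p.update(1.0f); }

    // By reference: same effect, with no handle copies at all.
    for (auto &p : params) { p.update(1.0f); }

    std::printf("params[0] = %.1f\n", *params[0].value);  // prints 2.0
    return 0;
}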
include/af/autograd/Variable.hpp: 2 additions & 2 deletions
@@ -52,7 +52,7 @@ namespace af {

af::array& array() const;

- Variable grad() const;
+ Variable& grad() const;

std::ptrdiff_t id() const;

@@ -74,7 +74,7 @@
private:
void evalGrad(bool retain_grad_graph = false);

- std::vector<Variable> getInputs() const;
+ std::vector<Variable>& getInputs() const;

static void buildSubGraph(Cache_t &cache, DAG_t &dag, const Variable &var);
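Alongside the loop changes, grad() and getInputs() now return references. The usual trade-off applies: returning T& avoids a copy on every call, but the result is only valid while the owning object is alive. A generic sketch with hypothetical names, not the actual Variable API:

#include <vector>

// Hypothetical node type illustrating an accessor that returns a reference.
class Node {
public:
    // Reference into this Node; valid only for the Node's lifetime, no copy.
    std::vector<int>& inputs() { return m_inputs; }

    // Returning by value would copy the whole vector on every call instead.
    std::vector<int> inputsCopy() const { return m_inputs; }

private:
    std::vector<int> m_inputs{1, 2, 3};
};

int main() {
    Node n;
    n.inputs().push_back(4);          // mutates the stored vector, no copy made
    auto snapshot = n.inputsCopy();   // independent copy, safe to outlive n
    (void)snapshot;
    return 0;
}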

src/autograd/Variable.cpp: 4 additions & 4 deletions
@@ -55,7 +55,7 @@ namespace af {
m_shared(nullptr)
{
bool calc_grad = false;
- for (auto input : inputs) {
+ for (const auto &input : inputs) {
calc_grad |= input.isCalcGrad();
}
if (calc_grad) {
@@ -70,7 +70,7 @@
return m_shared->m_data;
}

- Variable Variable::grad() const
+ Variable& Variable::grad() const
{
if (!m_shared->m_calc_grad) {
throw af::exception("Gradient calclation disabled.");
@@ -86,7 +86,7 @@
return (std::ptrdiff_t)m_shared.get();
}

- std::vector<Variable> Variable::getInputs() const
+ std::vector<Variable>& Variable::getInputs() const
{
return m_shared->m_inputs;
}
@@ -181,7 +181,7 @@ namespace af {
if (cache.find(id) != cache.end()) {
return;
}
- for (auto input : var.getInputs()) {
+ for (const auto &input : var.getInputs()) {
Variable::buildSubGraph(cache, dag, input);
}
cache[id] = true;
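In this hunk the inputs are only read (each one is passed on to buildSubGraph), so const auto & is enough and no Variable handle is copied per input. A generic read-only traversal sketch, with hypothetical names:

#include <cstdio>
#include <vector>

// Hypothetical graph node; the walk below only reads children,
// so const references suffice and no element is copied.
struct GraphNode {
    int id;
    std::vector<GraphNode> children;
};

void visit(const GraphNode &node) {
    for (const auto &child : node.children) {   // read-only, no copies
        visit(child);
    }
    std::printf("visited %d\n", node.id);       // post-order visit
}

int main() {
    GraphNode leaf{3, {}};
    GraphNode mid{2, {leaf}};
    GraphNode root{0, {GraphNode{1, {}}, mid}};
    visit(root);
    return 0;
}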
src/nn/Modules/Container.cpp: 1 addition & 1 deletion
@@ -33,7 +33,7 @@ namespace af
Variable Sequential::forward(const Variable &input)
{
Variable output = input;
- for(auto module : m_modules) {
+ for (auto &module : m_modules) {
output = module->forward(output);
}
return output;
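The elements of m_modules are dereferenced with ->, so they are presumably pointer-like; smart pointers are assumed here, since the container's element type is not shown in this diff. Copying a std::shared_ptr on every iteration costs atomic reference-count updates; binding a reference does not. A generic sketch under that assumption:

#include <cstdio>
#include <memory>
#include <vector>

// Hypothetical stand-in for a module with a forward-like method.
struct Layer {
    int apply(int x) const { return x + 1; }
};

int main() {
    std::vector<std::shared_ptr<Layer>> layers;
    for (int i = 0; i < 3; ++i) layers.push_back(std::make_shared<Layer>());

    int x = 0;
    // auto &layer binds to the stored shared_ptr: no reference-count traffic.
    // auto layer would copy the shared_ptr, i.e. two atomic ops per iteration.
    for (auto &layer : layers) {
        x = layer->apply(x);
    }
    std::printf("x = %d\n", x);  // 3
    return 0;
}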
src/nn/Modules/Module.cpp: 2 additions & 2 deletions
@@ -34,14 +34,14 @@ namespace af

void Module::train()
{
- for (auto parameter : m_parameters) {
+ for (auto &parameter : m_parameters) {
parameter.setCalcGrad(true);
}
}

void Module::eval()
{
- for (auto parameter : m_parameters) {
+ for (auto &parameter : m_parameters) {
parameter.setCalcGrad(false);
}
}
