Commit eaa8e37
Disable higher order gradients by default.
This can be enabled by passing true as the second argument to backward().
pavanky committed Jul 5, 2017
1 parent 9b05273 commit eaa8e37
Showing 2 changed files with 21 additions and 13 deletions.
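
For orientation, a minimal usage sketch of the changed API follows. The variables, the squaring op, and the printing are illustrative assumptions, not part of this commit; only backward() and its new retain_grad_graph argument come from this change, and Variable::grad(), the Variable(array, bool) constructor, and the operators in Functions.hpp are assumed to behave as elsewhere in this repository.

    #include <arrayfire.h>
    #include <af/autograd/Variable.hpp>
    #include <af/autograd/Functions.hpp> // assumed header for Variable operators

    using af::autograd::Variable;

    int main()
    {
        Variable x(af::randu(5), true);           // leaf variable, calc_grad = true
        auto y = x * x;                           // assumes operator* from Functions.hpp
        Variable dy(af::constant(1.0, 5), false); // seed gradient, not tracked

        y.backward(dy);          // new default: the gradient graph is discarded
        // y.backward(dy, true); // opt back in to higher order gradients

        af_print(x.grad().array()); // dy/dx = 2x; grad() assumed as in this repo
        return 0;
    }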
11 changes: 5 additions & 6 deletions include/af/autograd/Variable.hpp

@@ -60,16 +60,15 @@ namespace af {
 
             void addGrad(const Variable &child_grad);
 
-            void evalGrad();
+            void calcGradInputs(bool retain_grad_graph = false);
 
-            void calcGradInputs();
-
-            void backward(const Variable &grad);
-
-            DAG_t build();
+            void backward(const Variable &grad, bool retain_grad_graph = false);
 
             void buildSubGraph(Cache_t &cache, DAG_t &dag);
         private:
+            void evalGrad(bool retain_grad_graph = false);
+
+            DAG_t build();
             std::shared_ptr<Shared> m_shared;
         };
     }
23 changes: 16 additions & 7 deletions src/autograd/Variable.cpp

@@ -103,37 +103,46 @@ namespace af {
         }
     }
 
-    void Variable::evalGrad()
+    void Variable::evalGrad(bool retain_grad_graph)
     {
         // Flag asking not to calculate gradients
         if (!m_shared->m_calc_grad) return;
 
+        // Best not to evaluate the JIT immediately if theres only a single gradient
+        Variable grad = m_shared->m_grads[0];
         if (m_shared->m_grads.size() > 1) {
-            Variable grad = m_shared->m_grads[0];
             for (unsigned i = 1; i < m_shared->m_grads.size(); i++) {
                 grad = grad + m_shared->m_grads[i];
             }
             grad.array().eval();
-            m_shared->m_grads.clear();
-            m_shared->m_grads.push_back(grad);
+            m_shared->m_grads.resize(1);
         }
 
+        // Remove the graph if not needed
+        if (!retain_grad_graph) {
+            // This can be done by extracting af::array and ignoring everything else
+            auto grad_data = grad.array();
+            // Since there's no graph leading this, set calc_grad to false
+            grad = Variable(grad_data, false);
+        }
+
+        m_shared->m_grads[0] = grad;
     }
 
-    void Variable::calcGradInputs()
+    void Variable::calcGradInputs(bool retain_grad_graph)
     {
         evalGrad();
         if (m_shared->m_grad_func) {
             m_shared->m_grad_func(m_shared->m_inputs, m_shared->m_grads[0]);
         }
     }
 
-    void Variable::backward(const Variable &grad)
+    void Variable::backward(const Variable &grad, bool retain_grad_graph)
     {
         this->addGrad(grad);
         Variable::DAG_t dag = this->build();
         for (auto iter = dag.rbegin(); iter != dag.rend(); iter++) {
-            iter->calcGradInputs();
+            iter->calcGradInputs(retain_grad_graph);
         }
     }
 
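The detach step above is the whole mechanism: when retain_grad_graph is false, evalGrad() rebuilds the stored gradient as Variable(grad.array(), false), a leaf with no inputs and no gradient function, so nothing can backpropagate through it. The sketch below, under the same assumptions as the example above (illustrative names; operator*, grad(), and the Variable constructor assumed from this repository), shows what retaining the graph permits.

    #include <arrayfire.h>
    #include <af/autograd/Variable.hpp>
    #include <af/autograd/Functions.hpp> // assumed header for Variable operators

    using af::autograd::Variable;

    int main()
    {
        Variable x(af::randu(5), true);
        auto y = x * x;
        Variable ones(af::constant(1.0, 5), false);

        // Keep the gradient graph alive during the first backward pass.
        y.backward(ones, true);

        // dy/dx = 2x is still attached to the graph, so a second backward
        // pass through it is possible. Under the new default, evalGrad()
        // would have replaced it with Variable(grad.array(), false) and
        // this call would have no graph to traverse.
        auto dx = x.grad();
        dx.backward(ones);
        return 0;
    }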
