Add option to explicitly request higher order gradients.
- Disabled by default
- Can be enabled by passing true as the second argument to backward
pavanky committed on Jul 5, 2017 (parent: 9b05273, commit: 49b8917)
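The effect is easiest to see from the caller's side. The sketch below is illustrative only, not code from this commit: it assumes the Variable(af::array, bool calc_grad) constructor and the array()/grad() accessors used elsewhere in this repository, and the names x, y, and dy are hypothetical.

    #include <arrayfire.h>
    #include <af/autograd.h>

    using af::autograd::Variable;

    int main()
    {
        Variable x(af::randu(5), true);            // leaf that tracks gradients
        Variable y = x * x;

        Variable dy(af::constant(1.0, 5), false);  // seed gradient of ones

        // Default (second argument false): the gradient graph is discarded
        // after the pass. Passing true retains it, so x.grad() is itself a
        // differentiable Variable that can be fed through backward() again
        // for higher order derivatives.
        y.backward(dy, true);

        Variable dx = x.grad();                    // dx = 2 * x, graph retained
        return 0;
    }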
Show file tree
Hide file tree
Showing 2 changed files with 21 additions and 13 deletions.
11 changes: 5 additions & 6 deletions include/af/autograd/Variable.hpp
@@ -60,16 +60,15 @@ namespace af {
 
         void addGrad(const Variable &child_grad);
 
-        void evalGrad();
+        void calcGradInputs(bool retain_grad_graph = false);
 
-        void calcGradInputs();
-
-        void backward(const Variable &grad);
-
-        DAG_t build();
+        void backward(const Variable &grad, bool retain_grad_graph = false);
 
         void buildSubGraph(Cache_t &cache, DAG_t &dag);
     private:
+        void evalGrad(bool retain_grad_graph = false);
+
+        DAG_t build();
         std::shared_ptr<Shared> m_shared;
     };
 }
23 changes: 16 additions & 7 deletions src/autograd/Variable.cpp
@@ -103,37 +103,46 @@ namespace af {
         }
     }
 
-    void Variable::evalGrad()
+    void Variable::evalGrad(bool retain_grad_graph)
     {
         // Flag asking not to calculate gradients
         if (!m_shared->m_calc_grad) return;
 
+        // Best not to evaluate the JIT immediately if there's only a single gradient
+        Variable grad = m_shared->m_grads[0];
         if (m_shared->m_grads.size() > 1) {
-            Variable grad = m_shared->m_grads[0];
             for (unsigned i = 1; i < m_shared->m_grads.size(); i++) {
                 grad = grad + m_shared->m_grads[i];
             }
             grad.array().eval();
-            m_shared->m_grads.clear();
-            m_shared->m_grads.push_back(grad);
+            m_shared->m_grads.resize(1);
         }
 
+        // Remove the graph if not needed
+        if (!retain_grad_graph) {
+            // This can be done by extracting af::array and ignoring everything else
+            auto grad_data = grad.array();
+            // Since there's no graph leading this, set calc_grad to false
+            grad = Variable(grad_data, false);
+        }
+
+        m_shared->m_grads[0] = grad;
     }
 
-    void Variable::calcGradInputs()
+    void Variable::calcGradInputs(bool retain_grad_graph)
     {
         evalGrad();
         if (m_shared->m_grad_func) {
             m_shared->m_grad_func(m_shared->m_inputs, m_shared->m_grads[0]);
         }
     }
 
-    void Variable::backward(const Variable &grad)
+    void Variable::backward(const Variable &grad, bool retain_grad_graph)
    {
         this->addGrad(grad);
         Variable::DAG_t dag = this->build();
         for (auto iter = dag.rbegin(); iter != dag.rend(); iter++) {
-            iter->calcGradInputs();
+            iter->calcGradInputs(retain_grad_graph);
         }
     }
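For reference, the !retain_grad_graph branch above amounts to detaching the accumulated gradient from its graph. A standalone sketch of the same idea, using only the Variable(af::array, bool) constructor and array() accessor visible in the diff (the detach helper itself is hypothetical, not part of this commit):

    // Rewrap the evaluated data in a fresh Variable that has no inputs and
    // no gradient function. calc_grad is false because nothing upstream of
    // a detached value should receive gradients.
    af::autograd::Variable detach(const af::autograd::Variable &v)
    {
        return af::autograd::Variable(v.array(), false);
    }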

