Skip to content

Commit eaa8e37

Browse files
committed
Disable higher order gradients by default.
Higher order gradients can be re-enabled by passing true as the second argument to backward.
1 parent 9b05273 commit eaa8e37

File tree

2 files changed

+21
-13
lines changed

2 files changed

+21
-13
lines changed

include/af/autograd/Variable.hpp

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -60,16 +60,15 @@ namespace af {
6060

6161
void addGrad(const Variable &child_grad);
6262

63-
void evalGrad();
63+
void calcGradInputs(bool retain_grad_graph = false);
6464

65-
void calcGradInputs();
66-
67-
void backward(const Variable &grad);
68-
69-
DAG_t build();
65+
void backward(const Variable &grad, bool retain_grad_graph = false);
7066

7167
void buildSubGraph(Cache_t &cache, DAG_t &dag);
7268
private:
69+
void evalGrad(bool retain_grad_graph = false);
70+
71+
DAG_t build();
7372
std::shared_ptr<Shared> m_shared;
7473
};
7574
}

src/autograd/Variable.cpp

Lines changed: 16 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -103,37 +103,46 @@ namespace af {
103103
}
104104
}
105105

106-
void Variable::evalGrad()
106+
void Variable::evalGrad(bool retain_grad_graph)
107107
{
108108
// Flag asking not to calculate gradients
109109
if (!m_shared->m_calc_grad) return;
110110

111111
// Best not to evaluate the JIT immediately if theres only a single gradient
112+
Variable grad = m_shared->m_grads[0];
112113
if (m_shared->m_grads.size() > 1) {
113-
Variable grad = m_shared->m_grads[0];
114114
for (unsigned i = 1; i < m_shared->m_grads.size(); i++) {
115115
grad = grad + m_shared->m_grads[i];
116116
}
117117
grad.array().eval();
118-
m_shared->m_grads.clear();
119-
m_shared->m_grads.push_back(grad);
118+
m_shared->m_grads.resize(1);
119+
}
120+
121+
// Remove the graph if not needed
122+
if (!retain_grad_graph) {
123+
// This can be done by extracting af::array and ignoring everything else
124+
auto grad_data = grad.array();
125+
// Since there's no graph leading this, set calc_grad to false
126+
grad = Variable(grad_data, false);
120127
}
128+
129+
m_shared->m_grads[0] = grad;
121130
}
122131

123-
void Variable::calcGradInputs()
132+
void Variable::calcGradInputs(bool retain_grad_graph)
124133
{
125134
evalGrad();
126135
if (m_shared->m_grad_func) {
127136
m_shared->m_grad_func(m_shared->m_inputs, m_shared->m_grads[0]);
128137
}
129138
}
130139

131-
void Variable::backward(const Variable &grad)
140+
void Variable::backward(const Variable &grad, bool retain_grad_graph)
132141
{
133142
this->addGrad(grad);
134143
Variable::DAG_t dag = this->build();
135144
for (auto iter = dag.rbegin(); iter != dag.rend(); iter++) {
136-
iter->calcGradInputs();
145+
iter->calcGradInputs(retain_grad_graph);
137146
}
138147
}
139148

0 commit comments

Comments
 (0)