@@ -103,37 +103,46 @@ namespace af {
103
103
}
104
104
}
105
105
106
- void Variable::evalGrad ()
106
+ void Variable::evalGrad (bool retain_grad_graph )
107
107
{
108
108
// Flag asking not to calculate gradients
109
109
if (!m_shared->m_calc_grad ) return ;
110
110
111
111
// Best not to evaluate the JIT immediately if theres only a single gradient
112
+ Variable grad = m_shared->m_grads [0 ];
112
113
if (m_shared->m_grads .size () > 1 ) {
113
- Variable grad = m_shared->m_grads [0 ];
114
114
for (unsigned i = 1 ; i < m_shared->m_grads .size (); i++) {
115
115
grad = grad + m_shared->m_grads [i];
116
116
}
117
117
grad.array ().eval ();
118
- m_shared->m_grads .clear ();
119
- m_shared->m_grads .push_back (grad);
118
+ m_shared->m_grads .resize (1 );
119
+ }
120
+
121
+ // Remove the graph if not needed
122
+ if (!retain_grad_graph) {
123
+ // This can be done by extracting af::array and ignoring everything else
124
+ auto grad_data = grad.array ();
125
+ // Since there's no graph leading this, set calc_grad to false
126
+ grad = Variable (grad_data, false );
120
127
}
128
+
129
+ m_shared->m_grads [0 ] = grad;
121
130
}
122
131
123
- void Variable::calcGradInputs ()
132
+ void Variable::calcGradInputs (bool retain_grad_graph )
124
133
{
125
134
evalGrad ();
126
135
if (m_shared->m_grad_func ) {
127
136
m_shared->m_grad_func (m_shared->m_inputs , m_shared->m_grads [0 ]);
128
137
}
129
138
}
130
139
131
- void Variable::backward (const Variable &grad)
140
+ void Variable::backward (const Variable &grad, bool retain_grad_graph )
132
141
{
133
142
this ->addGrad (grad);
134
143
Variable::DAG_t dag = this ->build ();
135
144
for (auto iter = dag.rbegin (); iter != dag.rend (); iter++) {
136
- iter->calcGradInputs ();
145
+ iter->calcGradInputs (retain_grad_graph );
137
146
}
138
147
}
139
148
0 commit comments