forked from pytorch/pytorch
-
Notifications
You must be signed in to change notification settings - Fork 0
/
custom_function.cpp
270 lines (237 loc) · 9.72 KB
/
custom_function.cpp
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
#include <torch/csrc/autograd/custom_function.h>
#include <torch/csrc/autograd/functions/accumulate_grad.h>
#include <torch/csrc/autograd/autograd.h>
namespace torch { namespace autograd {
// Captures the metadata of `var`: layout, device, dtype, a copy of its shape
// and its requires_grad flag. Only copies are stored (no reference to `var`),
// and zeros() can later build a matching zero-filled tensor from them.
VariableInfo::VariableInfo(const Variable& var)
    : layout(var.layout()),
      device(var.device()),
      scalar_type(var.scalar_type()),
      size(var.sizes().vec()),
      requires_grad(var.requires_grad()),
      is_empty(false) {}
// An "empty" VariableInfo: is_empty is set, so zeros() returns an undefined
// tensor for it. requires_grad is explicitly false.
VariableInfo::VariableInfo()
    : requires_grad(false),
      is_empty(true) {}
// Builds a zero-filled tensor with the recorded shape, dtype, device and
// layout. For an empty VariableInfo an undefined tensor is returned instead.
// Note: the recorded requires_grad flag is not applied to the result.
// NOTE(review): `device_guard` is not used by this implementation.
Variable VariableInfo::zeros(at::OptionalDeviceGuard& device_guard) const {
  if (!is_empty) {
    const auto options =
        at::TensorOptions(scalar_type).device(device).layout(layout);
    return at::zeros(size, options);
  }
  // Nothing was recorded: hand back an undefined tensor.
  return at::Tensor();
}
// Post-processes the raw outputs of a custom autograd Function's forward and
// wires them into the autograd graph rooted at `cdata`.
//
// For every output slot, in order:
//  - a missing (nullopt) output registers an undefined input on `cdata` (when
//    `cdata` is non-null) and is returned as an empty optional;
//  - otherwise the tensor is registered via `cdata->add_input_metadata` and
//    its history is set by the `set_history` lambda, depending on whether it
//    is one of `input_vars`, was marked dirty, and is differentiable.
// Afterwards, view outputs are tagged with a CreationMeta so that later
// in-place modification can be detected. Finally, every dirty input is checked
// to have been returned.
//
// Params:
//  input_vars         - tensors passed to forward; an output with the same
//                       TensorImpl as one of these is treated as an input.
//  non_differentiable - TensorImpls the user marked non-differentiable.
//  dirty_inputs       - TensorImpls the user marked as modified in place.
//  raw_outputs        - forward's outputs; nullopt entries are placeholders
//                       for non-tensor outputs.
//  cdata              - the backward node; may be null, in which case no
//                       output is considered differentiable.
// Returns one (possibly rebound) entry per raw output, in the same order.
// Throws std::runtime_error if a leaf that requires grad was modified in
// place. TORCH_CHECK fails for an in-place-modified view input combined with
// more than one output, or for a dirty input that was not returned.
std::vector<c10::optional<Variable>> _wrap_outputs(const variable_list &input_vars,
  const std::unordered_set<at::TensorImpl*> &non_differentiable,
  const std::unordered_set<at::TensorImpl*> &dirty_inputs,
  const at::ArrayRef<c10::optional<Variable>> raw_outputs,
  const std::shared_ptr<Node> &cdata) {

  // Identity set (by TensorImpl) of the Function's inputs, used below to
  // recognize outputs that are literally one of the inputs.
  std::unordered_set<at::TensorImpl*> inputs;
  inputs.reserve(input_vars.size());
  for (auto& var : input_vars) {
    inputs.emplace(var.unsafeGetTensorImpl());
  }

  int num_outputs = raw_outputs.size();

  // Sets the grad_fn and output_nr of an output Variable.
  // Cases, checked in order: non-differentiable, modified in place (dirty),
  // an unmodified input returned as-is, and a freshly created output.
  auto set_history = [&](Variable& var, uint32_t output_nr, bool is_input, bool is_modified,
                         bool is_differentiable) {
    if (!is_differentiable) {
      // Outputs that don't require grad are left exactly as they are.
      if (!var.requires_grad()) {
        return;
      }
      // Return detached aliases of inputs, instead of changing their requires_grad
      // property.
      if (is_input) {
        var = var.detach();
      } else if (!var.is_view()) {
        // A non-input output that is not a view can be detached in place.
        var.detach_();
      }
      // If var is a view of one of the inputs of the custom autograd Function,
      // we don't detach it in a no_grad block. This is so that we can mimic the
      // behavior of returning a view from a no_grad block:
      //   x = torch.randn(3, requires_grad=True)
      //   with torch.no_grad():
      //       y = x.view(-1)
      // Here, `y` requires_grad (!).
    } else if (is_modified) {
      if (var.is_leaf() && var.requires_grad()) {
        throw std::runtime_error("a leaf Variable that requires grad has been used in an in-place operation.");
      }
      // No need to mark as modified Tensors that are not inputs.
      if (!is_input) {
        TORCH_WARN("Only input Tensors should be given to ctx.mark_dirty(). If a Tensor is not an input, there"
                   " is no need to pass it to mark_dirty().");
      }
      // If the input is a view, the rebase will need to rewrite the graph and this only works if we have a single
      // output to this Function.
      TORCH_CHECK(!(var.is_view() && num_outputs > 1), "If your Function modifies inplace an input that is a view"
                  " of another Tensor, your Function cannot return more than one Tensor. This is not supported"
                  " by the current autograd engine. You should either make sure the input is not a view (using"
                  " .clone() for example) or make your Function only return one Tensor (potentially splitting"
                  " it into two Functions: one doing the inplace that returns a single Tensor and a second one"
                  " that does the other operations). You can ask on the forum https://discuss.pytorch.org/ if"
                  " you need help to do this change.");
      // If the input was modified, transplant the grad_fn in the graph:
      //   grad_fn <- variable <- self  ==>  grad_fn <- self <- variable
      // First drop the tensor's existing .grad and its hooks.
      var.mutable_grad().reset();
      impl::clear_hooks(var);
      // If var already has a gradient accumulator, reset that accumulator's
      // `variable` member so it no longer refers to var.
      // NOTE(review): the dynamic_cast result is dereferenced without a null
      // check, i.e. this assumes the accumulator is always an AccumulateGrad.
      if (auto grad_acc_fn = impl::try_get_grad_accumulator(var)) {
        auto grad_acc = dynamic_cast<AccumulateGrad*>(grad_acc_fn.get());
        grad_acc->variable.reset();
      }
      if (cdata) {
        impl::rebase_history(var, {cdata, output_nr});
      }
    } else if (is_input) {
      // An input has been returned, but it wasn't modified. Return it as a view
      // so that we can attach a new grad_fn to the Variable.
      // Run in no_grad mode to mimic the behavior of the forward.
      {
        AutoGradMode grad_mode(false);
        var = var.view_as(var);
      }
      impl::set_gradient_edge(var, {cdata, output_nr});
    } else if (cdata) {
      // A new differentiable output: its gradient flows into cdata at output_nr.
      impl::set_gradient_edge(var, {cdata, output_nr});
    }
  };

  std::vector<c10::optional<Variable>> outputs;
  std::unordered_set<at::TensorImpl*> outputs_impl; // For dirty_inputs check
  outputs.reserve(num_outputs);
  int num_diff_outputs = 0;

  for (auto i = 0; i < num_outputs; ++i) {
    // For outputs that are not tensors, put a placeholder undefined input.
    if (!raw_outputs[i].has_value()) {
      if (cdata) {
        auto output_nr = cdata->add_input_metadata(Node::undefined_input());
        // Slots are registered in order, so the node's index must equal i.
        AT_ASSERT(i == (int)output_nr);
      }
      outputs.emplace_back();
      continue;
    }

    Variable var = raw_outputs[i].value();

    // Classify the output by TensorImpl identity. out_tensor_impl is taken
    // before set_history, which may rebind `var` (detach / view_as); the
    // dirty-input check at the end relies on this original impl.
    auto out_tensor_impl = var.unsafeGetTensorImpl();
    bool is_input = inputs.count(out_tensor_impl) > 0;
    bool is_modified = dirty_inputs.count(out_tensor_impl) > 0;
    bool is_differentiable = cdata && non_differentiable.count(out_tensor_impl) == 0
                          && isDifferentiableType(var.scalar_type());

    if (cdata) {
      auto output_nr = cdata->add_input_metadata(var);
      AT_ASSERT(i == (int)output_nr);
    }
    set_history(var, i, is_input, is_modified, is_differentiable);

    // For deprecation cycle. Can be removed after 1.6. In the case where we detected a view
    // in no grad mode during the forward, only warn the user (do not change the flag if we
    // return an input that is a view as is).
    // See NOTE [ View + Inplace detection ] for why we replace everything by a warning.
    if (!(is_input && is_modified) && var.is_view()) {
      // is_view() => diff_view_meta
      auto diff_view_meta = impl::get_view_autograd_meta(var);
      diff_view_meta->set_creation_meta(CreationMeta::IN_CUSTOM_FUNCTION);
    }

    if (is_differentiable) {
      ++num_diff_outputs;
    }

    outputs_impl.insert(out_tensor_impl);
    outputs.emplace_back(var);
  }

  // If multiple differentiable outputs are returned, we do not allow views to be modified inplace
  // See NOTE [ View + Inplace detection ] for more details
  if (num_diff_outputs > 1) {
    for (auto& var: outputs) {
      if (var.has_value()) {
        auto diff_view_meta = impl::get_view_autograd_meta(var.value());
        if (diff_view_meta && diff_view_meta->has_bw_view()) {
          diff_view_meta->set_creation_meta(CreationMeta::MULTI_OUTPUT_NODE);
        }
      }
    }
  }

  // All the modified Tensors must be returned as is for the rewrite to be valid.
  for (auto& dirty_input : dirty_inputs) {
    TORCH_CHECK(outputs_impl.count(dirty_input) > 0,
                "Some elements marked as dirty during the forward method were not returned as output. The"
                " inputs that are modified inplace must all be outputs of the Function.");
  }

  return outputs;
}
// Validates that the value returned by the hook `hook_name` is compatible
// with `original`. Three checks run in order:
//   1. the tensor options have the same type (options().type_equal);
//   2. both tensors are on CUDA, or both are not;
//   3. both tensors have the same shape.
// Throws std::runtime_error describing the first mismatch found.
void check_variable_result(const Variable& original, const Variable& result, std::string hook_name) {
  if (!original.options().type_equal(result.options())) {
    std::stringstream ss;
    ss << "hook '" << hook_name << "' has changed the type of value (";
    ss << "was " << original.toString() << " got ";
    ss << result.toString() << ")";
    throw std::runtime_error(ss.str());
  }

  if (original.is_cuda() != result.is_cuda()) {
    std::stringstream ss;
    ss << "hook '" << hook_name << "' has changed the type of value";
    if (original.is_cuda()) {
      ss << " (was CUDA tensor got CPU tensor)";
    } else {
      ss << " (was CPU tensor got CUDA tensor)";
    }
    throw std::runtime_error(ss.str());
  }

  // Compare the IntArrayRef size views element-wise. Materializing them with
  // .vec() would allocate two temporary vectors just for this comparison.
  if (!original.sizes().equals(result.sizes())) {
    std::stringstream ss;
    ss << "hook '" << hook_name << "' has changed the size of value";
    throw std::runtime_error(ss.str());
  }
}
// Queues the tensors the user wants available during backward. They stay in
// to_save_ until save_variables() turns them into SavedVariables.
void AutogradContext::save_for_backward(variable_list to_save) {
  // `to_save` is owned by value here, so swapping hands its contents over
  // without copying the list.
  to_save_.swap(to_save);
}
// Converts the tensors queued by save_for_backward() into SavedVariables and
// empties the queue. The handling matches python_function.cpp; see
// _save_variables() and unpack_saved_variables() there.
void AutogradContext::save_variables() {
  saved_variables_.clear();
  const auto self_node = grad_fn_.lock();

  for (const auto& tensor : to_save_) {
    if (!tensor.defined()) {
      // Undefined tensors may be saved; keep their slot as an empty entry.
      saved_variables_.emplace_back();
      continue;
    }
    // A tensor whose grad_fn is this very node is one of this node's outputs.
    const bool is_output = tensor.grad_fn().get() == self_node.get();
    saved_variables_.emplace_back(tensor, is_output);
  }

  to_save_.clear();
}
// Unpacks every SavedVariable against this context's grad_fn, preserving
// order. Fails with ERR_BACKWARD_TWICE once has_freed_buffers_ is set. It also
// asserts (internal) that the grad_fn is still alive.
variable_list AutogradContext::get_saved_variables() const {
  TORCH_CHECK(!has_freed_buffers_, ERR_BACKWARD_TWICE);

  const auto node = grad_fn_.lock();
  TORCH_INTERNAL_ASSERT(node);

  variable_list unpacked;
  unpacked.reserve(saved_variables_.size());
  for (const auto& saved_var : saved_variables_) {
    unpacked.push_back(saved_var.unpack(node));
  }
  return unpacked;
}
// Records which inputs the forward modified in place, identified by their
// TensorImpl pointer. Any previously recorded set is discarded.
void AutogradContext::mark_dirty(const variable_list &inputs) {
  dirty_inputs_.clear();
  dirty_inputs_.reserve(inputs.size());
  for (const auto& tensor : inputs) {
    dirty_inputs_.emplace(tensor.unsafeGetTensorImpl());
  }
}
// Records which outputs must not be treated as differentiable, identified by
// their TensorImpl pointer. Any previously recorded set is discarded.
void AutogradContext::mark_non_differentiable(const variable_list &outputs) {
  non_differentiable_.clear();
  non_differentiable_.reserve(outputs.size());
  for (const auto& tensor : outputs) {
    non_differentiable_.emplace(tensor.unsafeGetTensorImpl());
  }
}
void AutogradContext::set_materialize_grads(bool value) {
materialize_grads_ = value;
}
const std::unordered_set<at::TensorImpl*>& AutogradContext::get_and_bump_dirty() const {
for (auto& var : dirty_inputs_) {
var->bump_version();
}
return dirty_inputs_;
}
const std::unordered_set<at::TensorImpl*>& AutogradContext::get_non_differentiable() const {
return non_differentiable_;
}
}} // namespace torch::autograd