#include <c10/util/irange.h>
#include <torch/csrc/autograd/autograd.h>
#include <torch/csrc/autograd/custom_function.h>
#include <torch/csrc/autograd/functions/accumulate_grad.h>

#include <utility>

namespace torch {
namespace autograd {

VariableInfo::VariableInfo(const Variable& var)
    : layout(var.layout()),
      device(var.device()),
      scalar_type(var.scalar_type()),
      size(var.sizes().vec()),
      requires_grad(var.requires_grad()),
      is_empty(false) {}

VariableInfo::VariableInfo() : requires_grad(false), is_empty(true) {}

Variable VariableInfo::zeros(at::OptionalDeviceGuard& device_guard) const {
  if (is_empty) {
    // Return undefined tensor.
    return at::Tensor();
  } else {
    return at::zeros(
        size, at::TensorOptions(scalar_type).device(device).layout(layout));
  }
}
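
// As a usage sketch (hypothetical names, not code from this file):
// VariableInfo captures a Variable's metadata so that a matching zero-filled
// gradient can be materialized later, e.g.:
//
//   at::OptionalDeviceGuard guard;
//   VariableInfo info(some_var); // records layout/device/dtype/sizes
//   auto grad = info.zeros(guard); // zeros with the same metadata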

// This function has two main goals:
// 1) Use the user-provided jvp function to populate the outputs' forward
//    gradients
// 2) Perform error checking to ensure that view and inplace ops are properly
//    handled
//
// For 1) we have to:
// - Create a variable_list of grad_inputs based on the function inputs
// - Call the user jvp function with these to get the grad_outputs
// - Set the forward grad field on each output based on these grad_outputs
//
// For 2) we want to check the following:
// - If an output is a view, then the generated forward grad must be a view
//   as well, and the output's base's forward grad must be the output's
//   forward grad's base.
// - If an input was modified inplace (it must be an output as well), we make
//   sure that its forward grad was also modified inplace and is already
//   present on the corresponding output.
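//
// As a hedged sketch (not code from this file), a user-provided jvp for an
// elementwise "times two" custom Function could look like the following; the
// lambda and the doubling rule are hypothetical:
//
//   _jvp_fn_t my_jvp = [](variable_list inputs, variable_list grad_inputs) {
//     // The tangent of 2 * x is 2 * (tangent of x).
//     return variable_list{grad_inputs[0] * 2};
//   };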
void _process_forward_mode_AD(
    const variable_list& inputs,
    std::unordered_map<at::TensorImpl*, size_t> inputs_mapping,
    const at::ArrayRef<c10::optional<Variable>> raw_outputs,
    const optional_variable_list& outputs,
    const std::unordered_set<at::TensorImpl*>& non_differentiable,
    const std::unordered_set<at::TensorImpl*>& dirty_inputs,
    _jvp_fn_t jvp_user_function) {
  // TODO handle multiple levels here
  uint64_t level = 0;

  const auto num_inputs = inputs.size();
  const auto num_outputs = outputs.size();

  // The tracking info below is used to perform the view and inplace checks.
  // It is lazily initialized to reduce the cost of this function in the
  // common case where the user is not using forward mode AD.
  variable_list input_grads;
  std::vector<int64_t> grad_versions;
  std::vector<at::TensorImpl*> grad_impls;
  std::unordered_map<at::TensorImpl*, size_t> inputs_bases;

  auto init_tracked_info = [&]() {
    input_grads.resize(num_inputs);
    grad_versions.resize(num_inputs);
    grad_impls.resize(num_inputs);

    for (const auto i : c10::irange(num_inputs)) {
      const auto& inp = inputs[i];
      if (inp.is_view() && impl::get_view_autograd_meta(inp)->has_fw_view()) {
        inputs_bases.emplace(
            impl::get_view_autograd_meta(inp)
                ->get_forward_view()
                .base_.unsafeGetTensorImpl(),
            i);
      } else {
        inputs_bases.emplace(inp.unsafeGetTensorImpl(), i);
      }
    }
  };

  bool any_input_has_grad = false;
  // Extract the inputs' forward gradients and record any info we will need
  // later.
  for (const auto i : c10::irange(num_inputs)) {
    const auto& inp = inputs[i];
    if (!inp.defined()) {
      continue;
    }
    const auto& fw_grad = inp._fw_grad(level);
    if (fw_grad.defined()) {
      if (!any_input_has_grad) {
        any_input_has_grad = true;
        init_tracked_info();
      }
      input_grads[i] = fw_grad;
      grad_versions[i] = fw_grad._version();
      grad_impls[i] = fw_grad.unsafeGetTensorImpl();
    }
  }

  // If no input has forward grad, nothing to do here.
  if (!any_input_has_grad) {
    return;
  }

  torch::autograd::variable_list forward_grads;
  {
    at::AutoFwGradMode fw_grad_mode(false);
    forward_grads = jvp_user_function(inputs, std::move(input_grads));
  }

  // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
  const auto num_forward_grads = forward_grads.size();
  // Contrary to backward mode, we don't allow returning too many gradients.
  TORCH_CHECK(
      num_forward_grads == num_outputs,
      "Function's jvp returned "
      "an invalid number of forward gradients (expected ",
      num_outputs,
      " but got ",
      num_forward_grads,
      ")");

  for (const auto i : c10::irange(num_outputs)) {
    const auto& out =
        outputs[i].has_value() ? outputs[i].value() : at::Tensor();
    auto out_tensor_impl = raw_outputs[i].value().unsafeGetTensorImpl();
    bool is_differentiable =
        (non_differentiable.count(out_tensor_impl) == 0 &&
         isDifferentiableType(raw_outputs[i].value().scalar_type()));
    const auto& out_grad = forward_grads[i];
    if (!out.defined() || !is_differentiable) {
      TORCH_CHECK(
          !out_grad.defined(),
          "Function's jvp returned a gradient at position ",
          i,
          ", but the corresponding forward output is not a differentiable"
          " Tensor. You should return None at that position instead.");
      continue;
    }

    TORCH_INTERNAL_ASSERT(raw_outputs[i].has_value());
    bool is_input = inputs_mapping.count(out_tensor_impl) > 0;
    bool is_modified = dirty_inputs.count(out_tensor_impl) > 0;

    if (is_modified) {
      TORCH_CHECK(
          is_input,
          "Only input Tensors should be given to ctx.mark_dirty(). If a Tensor is not an input, there"
          " is no need to pass it to mark_dirty().");
      auto inp_idx = inputs_mapping[out_tensor_impl];
      if (grad_impls[inp_idx]) {
        // If there was already a forward grad for that input, just make sure
        // that it is modified inplace and returned as-is.
        TORCH_CHECK(
            out_grad._version() != grad_versions[inp_idx],
            "An inplace custom Function is not modifying the "
            "forward mode gradients inplace. If the forward is modifying an input inplace, then the jvp "
            "function must modify the corresponding gradient inplace.");
        TORCH_CHECK(
            out_grad.unsafeGetTensorImpl() == grad_impls[inp_idx],
            "An inplace custom Function is not returning the "
            "forward mode gradients as-is. If the forward is modifying an input inplace, then the jvp "
            "function must modify the gradient inplace and return it as-is.");
      } else {
        // If that Tensor didn't have a forward grad already, set the newly
        // returned one. We could also use inputs[inp_idx] here as it is the
        // same as out.
        out._set_fw_grad(out_grad, level, /* is_inplace_op */ true);
      }
    } else {
      // At this point, outputs[i] cannot be one of the inputs
      // (raw_outputs[i] might be, but it was changed by the backward mode AD
      // processing).
      TORCH_INTERNAL_ASSERT(
          inputs_mapping.count(out.unsafeGetTensorImpl()) == 0);

      if (out.is_view() && impl::get_view_autograd_meta(out)->has_fw_view()) {
        // If the output is a view
        const auto& out_view_info =
            impl::get_view_autograd_meta(out)->get_forward_view();
        if (inputs_bases.count(out_view_info.base_.unsafeGetTensorImpl())) {
          // And it is a view of an input (either that input is its base, or
          // they have a common base)
          const auto matching_input_idx =
              inputs_bases[out_view_info.base_.unsafeGetTensorImpl()];
          const auto& matching_input = inputs[matching_input_idx];

          const auto& matching_input_grad = matching_input._fw_grad(level);

          // If the matching input has a forward grad, the user should have
          // returned a view of that Tensor
          if (matching_input_grad.defined()) {
            TORCH_CHECK(
                out_grad.is_view() &&
                    impl::get_view_autograd_meta(out_grad)->has_fw_view(),
                "A custom Function's forward is returning a view (or an input as-is) but the jvp is not "
                "returning a view.");
            const auto& out_grad_base = impl::get_view_autograd_meta(out_grad)
                                            ->get_forward_view()
                                            .base_;
            if (matching_input_grad.is_view() &&
                impl::get_view_autograd_meta(matching_input_grad)
                    ->has_fw_view()) {
              // If the matching input's grad is a view, ensure that the
              // out_grad is a view of the same base
              const auto& matching_input_grad_base =
                  impl::get_view_autograd_meta(matching_input_grad)
                      ->get_forward_view()
                      .base_;
              TORCH_CHECK(
                  matching_input_grad_base.unsafeGetTensorImpl() ==
                      out_grad_base.unsafeGetTensorImpl(),
                  "A custom Function is returning a view but the jvp is not returning a view of the same base as "
                  "the given grad input.");
            } else {
              // If the matching input's grad is not a view, then it must be
              // the output gradient's base.
              TORCH_CHECK(
                  matching_input_grad.unsafeGetTensorImpl() ==
                      out_grad_base.unsafeGetTensorImpl(),
                  "A custom Function is returning a view but the jvp is not returning a view of the given grad input.");
            }
          } else {
            // We have a view op where the input didn't have a forward grad
            // but the user returned one for the output. To ensure that we
            // maintain the view/inplace constraints, we consider this as an
            // inplace op.
            // This case CANNOT happen in codegen as all view ops map one
            // Tensor to one Tensor, so the output of the view cannot have a
            // forward grad if the base does not.
            out._set_fw_grad(out_grad, level, /* is_inplace_op */ true);
            return;
          }
        }
      }

      out._set_fw_grad(out_grad, level, /* is_inplace_op */ false);
    }
  }
}

at::Tensor _view_as_self_with_no_grad(at::Tensor self) {
  // This is called below in _process_backward_mode_ad in two places:
  //
  // (1) An input has been returned, but it wasn't modified. Return it as a
  // view so that we can attach a new grad_fn to the Variable.
  // Run in no_grad mode to mimic the behavior of the forward.
  //
  // (2) Though it is not necessary for the purposes of attaching grad_fn, we
  // also call this function when an output is non-differentiable (and does
  // not require grad), to make custom forward AD UX more consistent. We'd
  // like to uniformly say that returning an input as-is is treated as if
  // `self.view_as(self)` were returned for that output.
  //
  // Alternatively, we could have not disabled forward grad while performing
  // this view, but it would mean that the user-defined jvp may be silently
  // ignored.
  at::AutoFwGradMode fw_grad_mode(false);
  AutoGradMode grad_mode(false);
  return self.view_as(self);
}
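
// For illustration (hypothetical snippet): returning an input as-is from a
// custom Function's forward is therefore handled as if the forward had
// returned
//   torch::Tensor out = self.view_as(self); // built with autograd disabled
// so that a fresh grad_fn can then be attached to `out` below.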

optional_variable_list _process_backward_mode_ad(
    const std::unordered_map<at::TensorImpl*, size_t>& inputs_mapping,
    const std::unordered_set<at::TensorImpl*>& non_differentiable,
    const std::unordered_set<at::TensorImpl*>& dirty_inputs,
    const at::ArrayRef<c10::optional<Variable>> raw_outputs,
    const std::shared_ptr<Node>& cdata) {
  int num_outputs = raw_outputs.size();

  // Sets the grad_fn and output_nr of an output Variable.
  auto set_history = [&](Variable& var,
                         uint32_t output_nr,
                         bool is_input,
                         bool is_modified,
                         bool is_differentiable) {
    if (!is_differentiable) {
      if (!var.requires_grad()) {
        if (is_input && !is_modified) {
          var = _view_as_self_with_no_grad(var);
        }
        return;
      }
      // Return detached aliases of inputs, instead of changing their
      // requires_grad property.
      if (is_input) {
        var = var.detach();
      } else if (!var.is_view()) {
        var.detach_();
      }
      // If var is a view of one of the inputs of the custom autograd
      // Function, we don't detach it in a no_grad block. This is so that we
      // can mimic the behavior of returning a view from a no_grad block:
      //   x = torch.randn(3, requires_grad=True)
      //   with torch.no_grad():
      //       y = x.view(-1)
      // Here, `y` requires_grad (!).
    } else if (is_modified) {
      if (var.is_leaf() && var.requires_grad()) {
        TORCH_CHECK(
            false,
            "a leaf Variable that requires grad has been used in an in-place operation.");
      }
      // No need to mark as modified Tensors that are not inputs.
      if (!is_input) {
        TORCH_WARN(
            "Only input Tensors should be given to ctx.mark_dirty(). If a Tensor is not an input, there"
            " is no need to pass it to mark_dirty().");
      }
      // If the input is a view, the rebase will need to rewrite the graph
      // and this only works if we have a single output to this Function.
      TORCH_CHECK(
          !(var.is_view() && num_outputs > 1),
          "If your Function modifies inplace an input that is a view"
          " of another Tensor, your Function cannot return more than one Tensor. This is not supported"
          " by the current autograd engine. You should either make sure the input is not a view (using"
          " .clone() for example) or make your Function only return one Tensor (potentially splitting"
          " it into two Functions: one doing the inplace that returns a single Tensor and a second one"
          " that does the other operations). You can ask on the forum https://discuss.pytorch.org/ if"
          " you need help to do this change.");

      // If the input was modified, transplant the grad_fn in the graph:
      //   grad_fn <- variable <- self  ==>  grad_fn <- self <- variable
      var.mutable_grad().reset();
      impl::clear_hooks(var);
      if (auto grad_acc_fn = impl::try_get_grad_accumulator(var)) {
        auto& grad_acc = dynamic_cast<AccumulateGrad&>(*grad_acc_fn);
        grad_acc.variable.reset();
      }
      if (cdata) {
        impl::rebase_history(var, {cdata, output_nr});
      }
    } else if (is_input) {
      var = _view_as_self_with_no_grad(var);
      impl::set_gradient_edge(var, {cdata, output_nr});
    } else if (cdata) {
      impl::set_gradient_edge(var, {cdata, output_nr});
    }
  };

  optional_variable_list outputs;
  std::unordered_set<at::TensorImpl*> outputs_impl; // For dirty_inputs check
  outputs.reserve(num_outputs);
  int num_diff_outputs = 0;

  for (const auto i : c10::irange(num_outputs)) {
    // For outputs that are not tensors, put a placeholder undefined input.
    if (!raw_outputs[i].has_value()) {
      if (cdata) {
        auto output_nr = cdata->add_input_metadata(Node::undefined_input());
        AT_ASSERT(i == (int)output_nr);
      }
      outputs.emplace_back();
      continue;
    }

    Variable var = raw_outputs[i].value();

    auto out_tensor_impl = var.unsafeGetTensorImpl();
    bool is_input = inputs_mapping.count(out_tensor_impl) > 0;
    bool is_modified = dirty_inputs.count(out_tensor_impl) > 0;
    bool is_differentiable = cdata &&
        non_differentiable.count(out_tensor_impl) == 0 &&
        isDifferentiableType(var.scalar_type());

    if (cdata) {
      auto output_nr = cdata->add_input_metadata(var);
      AT_ASSERT(i == (int)output_nr);
    }
    set_history(var, i, is_input, is_modified, is_differentiable);

    // For deprecation cycle. Can be removed after 1.6. In the case where we
    // detected a view in no_grad mode during the forward, only warn the user
    // (do not change the flag if we return an input that is a view as-is).
    // See NOTE [ View + Inplace detection ] for why we replace everything by
    // a warning.
    if (!(is_input && is_modified) && var.is_view()) {
      // is_view() => diff_view_meta
      auto diff_view_meta = impl::get_view_autograd_meta(var);
      diff_view_meta->set_creation_meta(CreationMeta::IN_CUSTOM_FUNCTION);
    }

    if (is_differentiable) {
      ++num_diff_outputs;
    }

    outputs_impl.insert(out_tensor_impl);
    outputs.emplace_back(var);
  }

  // If multiple differentiable outputs are returned, we do not allow views
  // to be modified inplace. See NOTE [ View + Inplace detection ] for more
  // details.
  if (num_diff_outputs > 1) {
    for (auto& var : outputs) {
      if (var.has_value()) {
        auto diff_view_meta = impl::get_view_autograd_meta(var.value());
        if (diff_view_meta && diff_view_meta->has_bw_view()) {
          diff_view_meta->set_creation_meta(CreationMeta::MULTI_OUTPUT_NODE);
        }
      }
    }
  }

  // All the modified Tensors must be returned as-is for the rewrite to be
  // valid.
  for (auto& dirty_input : dirty_inputs) {
    TORCH_CHECK(
        outputs_impl.count(dirty_input) > 0,
        "Some elements marked as dirty during the forward method were not returned as output. The"
        " inputs that are modified inplace must all be outputs of the Function.");
  }

  return outputs;
}

optional_variable_list _wrap_outputs(
    const variable_list& input_vars,
    const std::unordered_set<at::TensorImpl*>& non_differentiable,
    const std::unordered_set<at::TensorImpl*>& dirty_inputs,
    const at::ArrayRef<c10::optional<Variable>> raw_outputs,
    const std::shared_ptr<Node>& cdata,
    _jvp_fn_t jvp_user_function) {
  std::unordered_map<at::TensorImpl*, size_t> inputs_mapping;
  inputs_mapping.reserve(input_vars.size());
  for (const auto i : c10::irange(input_vars.size())) {
    inputs_mapping.emplace(input_vars[i].unsafeGetTensorImpl(), i);
  }

  auto outputs = _process_backward_mode_ad(
      inputs_mapping, non_differentiable, dirty_inputs, raw_outputs, cdata);

  // This must happen after the backward processing as we expect the
  // computations happening here to track backward mode gradients.
  _process_forward_mode_AD(
      input_vars,
      std::move(inputs_mapping),
      raw_outputs,
      outputs,
      non_differentiable,
      dirty_inputs,
      std::move(jvp_user_function));

  return outputs;
}
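
// As a hedged end-to-end sketch (not code from this file), _wrap_outputs is
// what ultimately runs when a C++ custom Function is applied. A minimal,
// hypothetical example:
//
//   struct ScaleByTwo : public torch::autograd::Function<ScaleByTwo> {
//     static torch::Tensor forward(AutogradContext* ctx, torch::Tensor x) {
//       return x * 2;
//     }
//     static variable_list backward(AutogradContext* ctx, variable_list go) {
//       return {go[0] * 2};
//     }
//   };
//
//   // ScaleByTwo::apply(t)'s outputs are routed through _wrap_outputs.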

void check_variable_result(
    const at::TensorBase& original,
    const at::TensorBase& result,
    std::string hook_name) {
  if (!original.options().type_equal(result.options())) {
    std::stringstream ss;
    ss << "hook '" << hook_name << "' has changed the type of value (";
    ss << "was " << original.toString() << " got ";
    ss << result.toString() << ")";
    throw std::runtime_error(ss.str());
  }

  if (original.is_cuda() != result.is_cuda()) {
    std::stringstream ss;
    ss << "hook '" << hook_name << "' has changed the type of value";
    if (original.is_cuda()) {
      ss << " (was CUDA tensor got CPU tensor)";
    } else {
      ss << " (was CPU tensor got CUDA tensor)";
    }
    throw std::runtime_error(ss.str());
  }

  if (original.sizes().vec() != result.sizes().vec()) {
    std::stringstream ss;
    ss << "hook '" << hook_name << "' has changed the size of value";
    throw std::runtime_error(ss.str());
  }
}
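
// For example (hypothetical hook), a hook that changes the dtype of the
// value it receives would fail the type check above:
//
//   auto bad_hook = [](const torch::Tensor& grad) {
//     return grad.to(torch::kFloat64); // different dtype => runtime_error
//   };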

void AutogradContext::save_for_backward(variable_list to_save) {
  to_save_ = std::move(to_save);
}

// The logic for handling saved variables here is the same as in
// python_function.cpp. See _save_variables() and unpack_saved_variables().
void AutogradContext::save_variables() {
  saved_variables_.clear();
  auto ptr = grad_fn_.lock();

  for (const auto& var : to_save_) {
    // Allow empty variables to be saved
    if (var.defined()) {
      bool is_output = var.grad_fn().get() == ptr.get();
      saved_variables_.emplace_back(var, is_output);
    } else {
      saved_variables_.emplace_back();
    }
  }
  to_save_.clear();
}

variable_list AutogradContext::get_saved_variables() const {
  TORCH_CHECK(!has_freed_buffers_, ERR_BACKWARD_TWICE);
  variable_list saved;
  saved.reserve(saved_variables_.size());
  auto ptr = grad_fn_.lock();
  TORCH_INTERNAL_ASSERT(ptr);
  for (auto& var : saved_variables_) {
    saved.push_back(var.unpack(ptr));
  }
  return saved;
}
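
// A hedged usage sketch (hypothetical Function): save tensors in forward and
// unpack them in backward through the AutogradContext:
//
//   static torch::Tensor forward(AutogradContext* ctx, torch::Tensor x) {
//     ctx->save_for_backward({x});
//     return x.exp();
//   }
//   static variable_list backward(AutogradContext* ctx, variable_list go) {
//     auto saved = ctx->get_saved_variables();
//     return {saved[0].exp() * go[0]};
//   }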

bool AutogradContext::needs_input_grad(size_t output_edge_index) const {
  auto ptr = grad_fn_.lock();
  TORCH_INTERNAL_ASSERT(ptr);
  return ptr->task_should_compute_output(output_edge_index);
}

bool AutogradContext::needs_input_grad(
    std::initializer_list<IndexRange> idxs) const {
  auto ptr = grad_fn_.lock();
  TORCH_INTERNAL_ASSERT(ptr);
  return ptr->task_should_compute_output(idxs);
}
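
// A hedged sketch of use inside a Function's backward (hypothetical): skip
// computing gradients the engine does not need:
//
//   static variable_list backward(AutogradContext* ctx, variable_list go) {
//     torch::Tensor grad_x;
//     if (ctx->needs_input_grad(0)) {
//       grad_x = go[0] * 2;
//     }
//     return {grad_x}; // undefined Tensor if the grad is not needed
//   }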

void AutogradContext::mark_dirty(const variable_list& inputs) {
  dirty_inputs_.clear();
  dirty_inputs_.reserve(inputs.size());
  for (auto& var : inputs) {
    dirty_inputs_.insert(var.unsafeGetTensorImpl());
  }
}
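
// A hedged usage sketch for an inplace forward (hypothetical): an input that
// is modified inplace must be declared dirty and returned as an output:
//
//   static torch::Tensor forward(AutogradContext* ctx, torch::Tensor x) {
//     x.mul_(2);
//     ctx->mark_dirty({x});
//     return x;
//   }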

void AutogradContext::mark_non_differentiable(const variable_list& outputs) {
  non_differentiable_.clear();
  non_differentiable_.reserve(outputs.size());
  for (auto& var : outputs) {
    non_differentiable_.insert(var.unsafeGetTensorImpl());
  }
}
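
// A hedged usage sketch (hypothetical): outputs that are not differentiable
// functions of the inputs should be declared as such in forward:
//
//   static variable_list forward(AutogradContext* ctx, torch::Tensor x) {
//     auto values = x * 2;
//     auto mask = (x > 0).to(torch::kLong); // integral, never differentiable
//     ctx->mark_non_differentiable({mask});
//     return {values, mask};
//   }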

void AutogradContext::set_materialize_grads(bool value) {
  materialize_grads_ = value;
}
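
// A hedged usage sketch (hypothetical): with materialization disabled,
// backward may receive undefined gradients instead of zero-filled tensors
// and must handle them explicitly:
//
//   static torch::Tensor forward(AutogradContext* ctx, torch::Tensor x) {
//     ctx->set_materialize_grads(false);
//     return x * 2;
//   }
//   static variable_list backward(AutogradContext* ctx, variable_list go) {
//     if (!go[0].defined()) {
//       return {torch::Tensor()}; // nothing flowed back; skip the work
//     }
//     return {go[0] * 2};
//   }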

const std::unordered_set<at::TensorImpl*>& AutogradContext::get_and_bump_dirty()
    const {
  for (auto& var : dirty_inputs_) {
    var->bump_version();
  }
  return dirty_inputs_;
}

const std::unordered_set<at::TensorImpl*>& AutogradContext::
    get_non_differentiable() const {
  return non_differentiable_;
}
} // namespace autograd
} // namespace torch