#include <torch/csrc/autograd/variable.h>

#include <torch/csrc/autograd/InferenceMode.h>
#include <torch/csrc/autograd/autograd.h>
#include <torch/csrc/autograd/edge.h>
#include <torch/csrc/autograd/engine.h>
#include <torch/csrc/autograd/function.h>
#include <torch/csrc/autograd/functions/accumulate_grad.h>
#include <torch/csrc/autograd/functions/tensor.h>
#include <torch/csrc/autograd/generated/Functions.h>
#include <torch/csrc/autograd/utils/error_messages.h>

#include <ATen/core/VariableHooksInterface.h>

#include <ATen/ATen.h>
#include <ATen/FuncTorchTLS.h>
#include <ATen/MemoryOverlap.h>
#include <c10/util/Exception.h>

#include <iostream>
#include <list>
#include <memory>
#include <mutex>
#include <stdexcept>
#include <string>
#include <typeinfo>
#include <utility>
#include <vector>

namespace torch {
namespace autograd {

DifferentiableViewMeta::DifferentiableViewMeta(
    at::TensorImpl* self_impl,
    c10::optional<ViewInfo> backward_info,
    c10::optional<ViewInfo> forward_info,
    bool shared_view_info,
    CreationMeta creation_meta)
    : AutogradMeta(self_impl),
      backward_info_(std::move(backward_info)),
      forward_info_(std::move(forward_info)),
      shared_view_info_(shared_view_info),
      creation_meta_(creation_meta) {
  is_view_ = true;
  if (backward_info_.has_value()) {
    self_impl->set_version_counter(
        impl::version_counter(backward_info_.value().base_));
    attr_version_ = self_impl->version_counter().current_version();
    TORCH_INTERNAL_ASSERT(
        backward_info_.value().base_.unsafeGetTensorImpl() != self_impl);
  }
  if (shared_view_info_) {
    TORCH_INTERNAL_ASSERT(
        backward_info_.has_value(),
        "Shared view info requires a backward view info.");
    TORCH_INTERNAL_ASSERT(
        !forward_info_.has_value(),
        "Shared view info requires the forward view info to be empty.");
  }
}

// Chain this view info with the new view op between base and tensor
ViewInfo ViewInfo::chain(
    const Variable& base,
    const Variable& tensor,
    std::function<Variable(const Variable&)> view_func) const {
  // Set `view_func` using the root base as input.
  // `view_func` is used to recover views in backward when either as_strided is
  // not supported or the view function changes metadata that is not recorded
  // by as_strided. See Note [View + Inplace update on base tensor] and Note
  // [View + Inplace update on view tensor] for more details on how this
  // function is used in backward.
  if (view_func) {
    // Both current_view and its parent have a view_func.
    if (view_fn_) {
      // Copy parent view function to gain ownership
      auto prev_fn = view_fn_;
      view_func = [=](const at::Tensor& root_base) {
        auto temp = prev_fn(root_base);
        return view_func(temp);
      };
    } else {
      // current_view has a view_func but its parent doesn't have one.
      if (base.unsafeGetTensorImpl()->support_as_strided()) {
        auto size = base.sym_sizes().vec();
        auto stride = base.sym_strides().vec();
        auto storage_offset = base.sym_storage_offset();
        view_func = [=](const at::Tensor& root_base) {
          auto temp = root_base.as_strided_symint(size, stride, storage_offset);
          return view_func(temp);
        };
      } else {
        // When base is a view but doesn't carry a view_fn in
        // DifferentiableViewMeta, it's a view that doesn't support inplace
        // update, e.g. unbind. In this case we should throw an error when an
        // inplace update happens in **forward**. One would naturally think the
        // following function will first be called in the backward pass. But
        // the first call site is actually in the **forward** pass, when we
        // refresh `grad_fn` triggered by an inplace update. Search Note
        // [View + Inplace update for view tensor] for the call site.
        view_func = [=](const at::Tensor& root_base) {
          TORCH_CHECK(
              false,
              "This view is the output of a function that returns multiple views. "
              "Such functions do not allow the output views to be modified inplace. "
              "You should replace the inplace operation by an out-of-place one.");
          return root_base;
        };
      }
    }
  } else if (view_fn_) {
    // current_view doesn't have a view_func but its parent has one.
    // Copy parent view function to gain ownership
    auto prev_view_fn = view_fn_;
    auto size = tensor.sym_sizes().vec();
    auto stride = tensor.sym_strides().vec();
    auto storage_offset = tensor.sym_storage_offset();
    view_func = [=](const at::Tensor& root_base) {
      auto temp = prev_view_fn(root_base);
      return temp.as_strided_symint(size, stride, storage_offset);
    };
  }

  return ViewInfo(base_, std::move(view_func));
}
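
// A minimal sketch of how chained view_funcs compose, kept as a comment so it
// is not compiled into this translation unit. The tensors and view ops below
// are illustrative assumptions, not taken from this file:
//
//   at::Tensor base = torch::rand({2, 3, 4}, torch::requires_grad());
//   at::Tensor v1 = base.select(0, 1);   // first view op
//   at::Tensor v2 = v1.narrow(1, 0, 2);  // second view op, chained on v1
//
// Conceptually, v2's ViewInfo stores the *root* base (here `base`) and a
// view_func equivalent to
//
//   [](const at::Tensor& root) { return root.select(0, 1).narrow(1, 0, 2); };
//
// so that after an in-place update the autograd machinery can replay the whole
// view chain from the root base, either through this composed closure or
// through a single as_strided call when the backend supports it.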

namespace {

at::Tensor singleton_undefined_tensor;

struct ConcreteAutogradMetaFactory : public c10::impl::AutogradMetaFactory {
  std::unique_ptr<c10::AutogradMetaInterface> make() const override {
    return std::make_unique<AutogradMeta>();
  }
  const at::Tensor& undefined_tensor() const override {
    return singleton_undefined_tensor;
  }
};

ConcreteAutogradMetaFactory meta_factory;

static c10::impl::AutogradMetaFactoryRegisterer meta_factory_registerer(
    &meta_factory);

} // namespace

namespace impl {

AutogradMeta* materialize_autograd_meta(const at::TensorBase& self) {
  TORCH_CHECK(
      self.defined(),
      "cannot call materialize_autograd_meta() on undefined tensor");
  auto p = self.unsafeGetTensorImpl();
  if (!p->autograd_meta()) {
    p->set_autograd_meta(std::make_unique<AutogradMeta>());
  }
  return get_autograd_meta(self);
}

void update_tensor_hooks_on_new_gradfn(
    const at::TensorBase& self,
    const std::shared_ptr<torch::autograd::Node>& old_fn,
    const std::shared_ptr<torch::autograd::Node>& new_fn) {
  // This function is called whenever the grad_fn of the tensor is
  // changed. We assume here that new_fn does not yet have hooks of
  // its own.
  //
  // This function does two things:
  // (1) Reset the list when grad_fn is updated, so that new hooks don't get
  //     erroneously registered to the old grad_fn.
  //     Note that the old cpp_hooks_list_ is still kept alive by the old
  //     grad_fn, so hooks registered to the older version of the tensor will
  //     continue to be active.
  // (2) If there is a retains_grad hook registered, move it from the old
  //     cpp_hooks_list_ to the new one.
  const auto& meta = impl::get_autograd_meta(self);
  TORCH_INTERNAL_ASSERT(meta);
  TORCH_INTERNAL_ASSERT(new_fn);
  meta->cpp_hooks_list_ = nullptr;
  const c10::impl::PyInterpreter* interp =
      self.unsafeGetTensorImpl()->pyobj_slot()->pyobj_interpreter();
  if (interp) {
    (*interp)->reset_backward_hooks(self.unsafeGetTensorImpl());
  }
  if (self.retains_grad()) {
    TORCH_INTERNAL_ASSERT(old_fn);
    auto out = old_fn->pop_retains_grad_hook(self.output_nr());
    TORCH_INTERNAL_ASSERT(out != nullptr);
    new_fn->add_retains_grad_hook(std::move(out), self.output_nr());
  }
}

void rebase_history(const Variable& self, Edge gradient_edge) {
  TORCH_INTERNAL_ASSERT(gradient_edge.function != nullptr);
  const auto& meta = impl::get_autograd_meta(self);
  auto old_fn = meta != nullptr ? meta->grad_fn_ : nullptr;
  auto diff_view_meta = get_view_autograd_meta(self);
  if (diff_view_meta && diff_view_meta->has_bw_view()) {
    // See NOTE [ View + Inplace detection ]
    auto creation_meta = diff_view_meta->get_creation_meta();
    // Do not use handle_view_on_rebase here: check_inplace should have been
    // called before this and would already have thrown an error if the
    // creation_meta was not DEFAULT.
    TORCH_INTERNAL_ASSERT(creation_meta == CreationMeta::DEFAULT);
    TORCH_INTERNAL_ASSERT(gradient_edge.input_nr == 0);
    TORCH_INTERNAL_ASSERT(gradient_edge.function);
    TORCH_CHECK(
        gradient_edge.function->num_inputs() == 1,
        "Functions which modify views in-place must return a single Variable");
    auto view_info = diff_view_meta->get_backward_view();
    diff_view_meta->output_nr_ = gradient_edge.input_nr;
    auto copy_slices = std::make_shared<CopySlices>(
        view_info.base_,
        at::TensorGeometry(self),
        view_info.view_fn_,
        std::move(gradient_edge.function));
    set_gradient_edge(view_info.base_, {std::move(copy_slices), 0});
    self.grad_fn(); // trigger an update to the view's grad_fn
    return;
  }

  set_gradient_edge(self, std::move(gradient_edge));
  // Pass both self and its grad_fn to avoid calling into grad_fn reentrantly
  torch::autograd::impl::update_tensor_hooks_on_new_gradfn(
      self, old_fn, self.grad_fn());
}
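
// A commented sketch of when rebase_history takes the view branch above. The
// tensors and ops are illustrative assumptions about user code, not derived
// from this file:
//
//   at::Tensor leaf = torch::rand({4, 4}, torch::requires_grad());
//   at::Tensor base = leaf.clone();          // non-leaf, requires grad
//   at::Tensor view = base.narrow(0, 0, 2);  // differentiable view of `base`
//   view.mul_(2);                            // in-place op on the view
//
// The in-place op calls rebase_history(view, edge-to-MulBackward0). Because
// `view` has a backward ViewInfo, the MulBackward0 node is wrapped in a
// CopySlices node that is installed on `base`, and `view.grad_fn()` is then
// refreshed lazily (see VariableHooks::grad_fn below).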

void create_cpp_hook(const at::TensorBase& self, bool is_retains_grad_hook) {
  const auto& fn = self.grad_fn();
  std::shared_ptr<hooks_list>& list =
      materialize_autograd_meta(self)->cpp_hooks_list_;
  // NOLINTNEXTLINE(modernize-make-shared)
  list.reset(new hooks_list());
  std::unique_ptr<FunctionPreHook> hook_ptr{
      new CppFunctionTensorPreHook(list, self.output_nr())};
  // NB: we could potentially only update hooks_ if !fn, but it shouldn't
  // matter and this was the way before, so we keep it like this for now.
  clear_hooks(self);
  add_hook(self, std::make_unique<CppFunctionTensorPreHook>(list, 0));
  if (fn) {
    fn->add_tensor_pre_hook(std::move(hook_ptr));
  }
}

void set_grad_accumulator(
    const Variable& self,
    std::weak_ptr<Node> grad_accumulator) {
  materialize_autograd_meta(self)->grad_accumulator_ =
      std::move(grad_accumulator);
}

std::shared_ptr<Node> try_get_grad_accumulator(const Variable& self) {
  if (get_autograd_meta(self)) {
    return get_autograd_meta(self)->grad_accumulator_.lock();
  } else {
    return nullptr;
  }
}

std::shared_ptr<Node> grad_accumulator(const Variable& self) {
  auto autograd_meta = get_autograd_meta(self);
  if (!autograd_meta) {
    return nullptr;
  }
  if (autograd_meta->grad_fn_) {
    throw std::logic_error(
        "grad_accumulator() should be only called on leaf Variables");
  }
  if (!autograd_meta->requires_grad_) {
    return nullptr;
  }

  std::lock_guard<std::mutex> lock(autograd_meta->mutex_);

  auto result = autograd_meta->grad_accumulator_.lock();
  if (result)
    return result;

  c10::raw::intrusive_ptr::incref(self.unsafeGetTensorImpl());
  auto intrusive_from_this =
      c10::intrusive_ptr<at::TensorImpl>::reclaim(self.unsafeGetTensorImpl());
  result = std::make_shared<AccumulateGrad>(
      Variable(std::move(intrusive_from_this)));
  autograd_meta->grad_accumulator_ = result;
  return result;
}

Edge gradient_edge(const Variable& self) {
  // If grad_fn is null (as is the case for a leaf node), we instead
  // interpret the gradient function to be a gradient accumulator, which will
  // accumulate its inputs into the grad property of the variable. These
  // nodes get suppressed in some situations, see "suppress gradient
  // accumulation" below. Note that only variables which have `requires_grad =
  // True` can have gradient accumulators.
  if (const auto& gradient = self.grad_fn()) {
    return Edge(gradient, self.output_nr());
  } else {
    return Edge(grad_accumulator(self), 0);
  }
}
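
// A commented illustration of the leaf / non-leaf split handled above; the
// tensors and ops are assumptions chosen for the example:
//
//   at::Tensor leaf = torch::rand({3}, torch::requires_grad());
//   at::Tensor out = leaf * 2;
//
//   // gradient_edge(leaf) -> Edge(AccumulateGrad-for-leaf, /*input_nr=*/0)
//   //   (the accumulator writes incoming gradients into leaf.grad())
//   // gradient_edge(out)  -> Edge(out.grad_fn(), out.output_nr())
//   //   (points at MulBackward0, the node that produced `out`)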

void set_gradient_edge(const Variable& self, Edge edge) {
  auto* meta = materialize_autograd_meta(self);
  meta->grad_fn_ = std::move(edge.function);
  meta->output_nr_ = edge.input_nr;
  // For views, make sure this new grad_fn_ is not overwritten unless it is
  // necessary in the VariableHooks::grad_fn below. This logic is only relevant
  // for custom autograd Functions, for which multiple operations can happen on
  // a given Tensor before its gradient edge is set when exiting the custom
  // Function.
  auto diff_view_meta = get_view_autograd_meta(self);
  if (diff_view_meta && diff_view_meta->has_bw_view()) {
    diff_view_meta->set_attr_version(self._version());
  }
}

Node* grad_fn_unsafe(const Variable& self) {
  if (get_autograd_meta(self)) {
    return get_autograd_meta(self)->grad_fn_.get();
  } else {
    return nullptr;
  }
}

// Versions
//~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

void set_version_counter(
    const Variable& self,
    const c10::VariableVersion& version_counter) {
  TORCH_CHECK(
      self.defined(), "cannot call set_version_counter() on undefined tensor");
  self.unsafeGetTensorImpl()->set_version_counter(version_counter);
}

void bump_version(const Variable& self) {
  TORCH_CHECK(self.defined(), "cannot call bump_version() on undefined tensor");
  self.unsafeGetTensorImpl()->bump_version();
}

const c10::VariableVersion& version_counter(const Variable& self) {
  TORCH_CHECK(
      self.defined(), "cannot call version_counter() on undefined tensor");
  return self.unsafeGetTensorImpl()->version_counter();
}

// Hooks
//~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

void add_hook(
    const at::TensorBase& self,
    std::unique_ptr<FunctionPreHook> hook) {
  AutogradMeta* meta = materialize_autograd_meta(self);
  TORCH_INTERNAL_ASSERT(meta->hooks_.empty());
  meta->hooks_.push_back(std::move(hook));
}

std::vector<std::unique_ptr<FunctionPreHook>>& hooks(const Variable& self) {
  TORCH_INTERNAL_ASSERT(get_autograd_meta(self));
  return get_autograd_meta(self)->hooks_;
}

void clear_hooks(const at::TensorBase& self) {
  // This is a little goofy, but usually this should be a no-op
  materialize_autograd_meta(self)->hooks_.clear();
}

void set_name(const Variable& self, const std::string& name) {
  materialize_autograd_meta(self)->name_ = name;
}

// Miscellaneous
//~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

AutogradMeta* get_autograd_meta(const at::TensorBase& self) {
  // NB: could return nullptr
  TORCH_CHECK(
      self.defined(), "cannot call get_autograd_meta() on undefined tensor");
  return static_cast<AutogradMeta*>(
      self.unsafeGetTensorImpl()->autograd_meta());
}

DifferentiableViewMeta* get_view_autograd_meta(const at::TensorBase& self) {
  // NB: returns nullptr if self is not a view
  AutogradMeta* meta = get_autograd_meta(self);
  if (meta && meta->is_view_) {
    return static_cast<DifferentiableViewMeta*>(meta);
  } else {
    return nullptr;
  }
}

} // namespace impl

using at::Tensor;

struct VariableHooks final : at::impl::VariableHooksInterface {
  at::TensorBase tensor_data(const at::TensorBase&) const override;
  at::TensorBase variable_data(const at::TensorBase&) const override;
  const std::shared_ptr<torch::autograd::Node>& grad_fn(
      const at::TensorBase&) const override;
  unsigned _register_hook(
      const at::TensorBase&,
      std::function<at::TensorBase(const at::TensorBase&)> hook) const override;
  void remove_hook(const at::TensorBase&, unsigned pos) const override;
  bool is_view(const at::TensorBase&) const override;
  const at::TensorBase& base(const at::TensorBase&) const override;
  const std::string& name(const at::TensorBase&) const override;
  bool is_leaf(const at::TensorBase&) const override;
  int64_t output_nr(const at::TensorBase&) const override;
  void set_data(const at::TensorBase& self, const at::TensorBase& new_data)
      const override;
  at::TensorBase data(const at::TensorBase& self) const override;
  int64_t _version(const at::TensorBase& self) const override;
  void retain_grad(const at::TensorBase& self) const override;
  bool retains_grad(const at::TensorBase& self) const override;
  void _backward(
      const Tensor& self,
      at::TensorList inputs,
      const c10::optional<Tensor>& gradient,
      c10::optional<bool> keep_graph,
      bool create_graph) const override;
  void requires_grad_(const at::TensorBase& self, bool _requires_grad)
      const override;
};

VariableHooks variableHooks;
at::impl::VariableHooksRegisterer registerVariableHooks(&variableHooks);

at::TensorBase VariableHooks::variable_data(const at::TensorBase& self) const {
  TORCH_CHECK(
      self.defined(), "cannot call variable_data() on undefined tensor");
  auto self_impl_copy = self.unsafeGetTensorImpl()->shallow_copy_and_detach(
      /*version_counter=*/0,
      /*allow_tensor_metadata_change=*/false);
  self_impl_copy->set_autograd_meta(nullptr);
  return at::Tensor(self_impl_copy);
}

at::TensorBase VariableHooks::tensor_data(const at::TensorBase& self) const {
  TORCH_CHECK(self.defined(), "cannot call tensor_data() on undefined tensor");
  auto self_impl_copy = self.unsafeGetTensorImpl()->shallow_copy_and_detach(
      /*version_counter=*/self.unsafeGetTensorImpl()->version_counter(),
      /*allow_tensor_metadata_change=*/
      self.unsafeGetTensorImpl()->allow_tensor_metadata_change());
  return at::Tensor(self_impl_copy);
}

bool VariableHooks::is_leaf(const at::TensorBase& self) const {
  if (impl::get_autograd_meta(self)) {
    return impl::get_autograd_meta(self)->grad_fn_ == nullptr;
  } else {
    return true;
  }
}

int64_t VariableHooks::output_nr(const at::TensorBase& self) const {
  if (impl::get_autograd_meta(self)) {
    return impl::get_autograd_meta(self)->output_nr_;
  } else {
    return 0;
  }
}

void VariableHooks::set_data(
    const at::TensorBase& self_base,
    const at::TensorBase& new_data_base) const {
  at::OptionalTensorRef self_ref(self_base);
  const Tensor& self = *self_ref;
  at::OptionalTensorRef new_data_ref(new_data_base);
  const Tensor& new_data = *new_data_ref;

  // `var.set_data(new_data)` shallow-copies all non-autograd TensorImpl fields
  // from `new_data` to `var`. It requires that `new_data` and `var` have
  // compatible tensor type.
  TORCH_CHECK(
      _has_compatible_shallow_copy_type(self, new_data),
      "Attempted to call `variable.set_data(tensor)`, but `variable` and `tensor` have incompatible tensor type.");

  TORCH_CHECK(
      !self.requires_grad() ||
          isDifferentiableType(at::typeMetaToScalarType(new_data.dtype())),
      "data set to a tensor that requires gradients must be floating point or complex dtype");

  // Resets gradient accumulator if metadata is out of date
  AutogradMeta* autograd_meta = impl::get_autograd_meta(self);
  if (autograd_meta) {
    std::lock_guard<std::mutex> lock(autograd_meta->mutex_);
    auto prior_accumulator = autograd_meta->grad_accumulator_.lock();
    if (prior_accumulator) {
      const auto prior_device = prior_accumulator->input_metadata(0).device();
      const auto new_device = new_data.device();

      if (!new_data.options().type_equal(self.options()) ||
          prior_device != new_device) {
        autograd_meta->grad_accumulator_.reset();
      }
    }
  }

  // Version counter is not shared when we replace a `Variable`'s tensor data
  // by calling `set_data(...)`. The original version of the `Variable` is
  // always preserved. See NOTE [ Version Counter Sharing ] for details.
  //
  // `var.set_data(new_data)` always ignores `var`'s
  // `allow_tensor_metadata_change_`, because users need this API as an escape
  // hatch for changing a tensor's metadata regardless of its
  // `allow_tensor_metadata_change_` value, and the users are responsible for
  // ensuring this is the behavior they want.
  self.unsafeGetTensorImpl()->shallow_copy_from(new_data.getIntrusivePtr());
}
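
// A commented usage sketch for set_data as exposed on at::Tensor; the shapes
// and dtype below are assumptions for illustration:
//
//   at::Tensor var = torch::rand({2, 2}, torch::requires_grad());
//   at::Tensor replacement = torch::zeros({4, 4});
//   var.set_data(replacement);
//
// Afterwards `var` aliases the storage and metadata of `replacement` (sizes,
// strides, dtype, device) while keeping its own autograd metadata and its
// original version counter. If the dtype or device changed, the cached
// gradient accumulator is dropped, as handled above.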

at::TensorBase VariableHooks::data(const at::TensorBase& self) const {
  return self.variable_data();
}

int64_t VariableHooks::_version(const at::TensorBase& self) const {
  return self.unsafeGetTensorImpl()->version_counter().current_version();
}

void VariableHooks::retain_grad(const at::TensorBase& self) const {
  TORCH_CHECK(
      self.requires_grad(),
      "can't retain_grad on Tensor that has requires_grad=False");

  // temporary hack to improve functorch UX.
  const auto& functorch_tls = at::functorch::functorchTLSAccessor();
  if (functorch_tls) {
    functorch_tls->checkSupportsRetainGrad();
  }

  if (self.is_leaf()) { // no-op for leaves
    return;
  }
  if (impl::get_autograd_meta(self)->retains_grad_) {
    return;
  }
  c10::weak_intrusive_ptr<c10::TensorImpl> weak_self(self.getIntrusivePtr());

  auto retain_grad_hook = [weak_self](const at::TensorBase& grad_base) {
    at::Tensor grad{grad_base};
    if (!weak_self.expired() && grad.defined()) {
      auto var = weak_self.lock();
      if (!var->grad().defined()) {
        if (grad.is_sparse()) {
          var->mutable_grad() = grad.clone();
        } else {
          var->mutable_grad() = grad.clone(at::MemoryFormat::Contiguous);
        }
      } else {
        var->mutable_grad() = var->grad() + grad;
      }
    }
    return at::TensorBase{};
  };

  const auto& fn = self.grad_fn();
  std::unique_ptr<FunctionPreHook> hook_ptr{new CppFunctionSingleTensorPreHook(
      std::move(retain_grad_hook), self.output_nr())};
  fn->add_retains_grad_hook(std::move(hook_ptr), self.output_nr());
  impl::get_autograd_meta(self)->retains_grad_ = true;
}
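
// A commented sketch of the behavior implemented above, using the public
// libtorch API; the particular ops are assumptions for illustration:
//
//   at::Tensor leaf = torch::rand({3}, torch::requires_grad());
//   at::Tensor mid = leaf * 2;  // non-leaf: its grad is normally not kept
//   mid.retain_grad();          // installs the retains_grad hook on mid's
//                               // grad_fn at mid's output_nr
//   mid.sum().backward();
//   // mid.grad() is now defined (here a tensor of ones); without
//   // retain_grad() it would have stayed undefined.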

bool VariableHooks::retains_grad(const at::TensorBase& self) const {
  if (impl::get_autograd_meta(self)) {
    return impl::get_autograd_meta(self)->retains_grad_;
  } else {
    return false;
  }
}

void VariableHooks::_backward(
    const Tensor& self,
    at::TensorList inputs,
    const c10::optional<Tensor>& gradient,
    c10::optional<bool> keep_graph,
    bool create_graph) const {
  // TODO torch::autograd::backward should take the c10::optional<Tensor>
  // gradient directly instead of us having to unwrap it to Tensor _gradient
  // here.
  Tensor _gradient = gradient.has_value() ? *gradient : Tensor();
  std::vector<torch::autograd::Variable> input_vars(
      inputs.begin(), inputs.end());
  torch::autograd::backward(
      {self}, {std::move(_gradient)}, keep_graph, create_graph, input_vars);
}

void VariableHooks::requires_grad_(
    const at::TensorBase& self,
    bool _requires_grad) const {
  if (!self.is_leaf() && !_requires_grad) {
    throw std::runtime_error(
        autograd::utils::requires_grad_leaf_error(_requires_grad));
  }
  self.set_requires_grad(_requires_grad);
}

// Backward View Variables
//~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

bool VariableHooks::is_view(const at::TensorBase& self) const {
  auto diff_view_meta = torch::autograd::impl::get_view_autograd_meta(self);
  if (diff_view_meta) {
    return diff_view_meta->has_bw_view();
  } else {
    return false;
  }
}

const at::TensorBase& VariableHooks::base(const at::TensorBase& self) const {
  auto diff_view_meta = torch::autograd::impl::get_view_autograd_meta(self);
  if (diff_view_meta) {
    TORCH_CHECK(
        diff_view_meta->has_bw_view(),
        "Can't get base of non-backward view Tensor");
    return diff_view_meta->get_backward_view().base_;
  } else {
    throw std::runtime_error("Can't get base of non-view Tensor");
  }
}

namespace {
std::string singleton_string;
}

const std::string& VariableHooks::name(const at::TensorBase& self) const {
  TORCH_CHECK(self.defined(), "cannot call name() on undefined tensor");
  if (torch::autograd::impl::get_autograd_meta(self)) {
    return torch::autograd::impl::get_autograd_meta(self)->name_;
  } else {
    return singleton_string;
  }
}

namespace {
std::shared_ptr<torch::autograd::Node> singleton_shared_ptr;
}

const std::shared_ptr<torch::autograd::Node>& VariableHooks::grad_fn(
    const at::TensorBase& self) const {
  auto diff_view_meta = torch::autograd::impl::get_view_autograd_meta(self);
  if (diff_view_meta && diff_view_meta->has_bw_view()) {
    // See NOTE [ View + Inplace detection ]
    std::lock_guard<std::mutex> lock(diff_view_meta->mutex_);
    auto view_info = diff_view_meta->get_backward_view();
    if (!diff_view_meta->grad_fn_ && !view_info.base_.requires_grad()) {
      return diff_view_meta->grad_fn_;
    }
    auto current_version = self._version();
    auto old_fn = diff_view_meta->grad_fn_;
    if (diff_view_meta->get_attr_version() != current_version) {
      // This is an indirect rebase_history due to another view or the base
      // being modified inplace
      handle_view_on_rebase(diff_view_meta, /* indirect */ true);
      TORCH_INTERNAL_ASSERT(diff_view_meta->output_nr_ == 0);
      // Note [View + Inplace update for view tensor]
      // An inplace update happened on Tensor `self` (which is a view).
      // For example:
      //   view_1 = view_op_1(diff_view_meta->base_)
      //   view_2 = view_op_2(view_1)
      //   ...
      //   self = view_op_n(view_n-1)
      //   self = inplace_op(self)
      //
      // For CPU/CUDA backends, we employ one AsStridedBackward0 Node to
      // represent the chain of view backward ops for efficiency.
      //
      // However, the XLA backend does not fully support AsStridedBackward0.
      // Instead we run a full forward pass with a tensor that requires
      // gradient to get a proper grad_fn setup, then save it to
      // DifferentiableViewMeta for future use. This is fairly cheap for XLA's
      // lazy tensor approach (but would be really expensive for CPU/CUDA).
      // XLA Tensors only run through the VariableType dispatch and lower the
      // forward pass to an XLA HLO graph; we then take grad_fn and never
      // materialize the tensor content. So we only construct the graph but do
      // not execute it, which is a fairly cheap operation to do.
      //
      // See Note [View + Inplace update for base tensor] for what we do to the
      // base tensor when an in-place operation happens.
      //
      // TODO: Potentially the following logic can be replaced by special logic
      // in VariableType_x.cpp that would provide a way to recreate the grad_fn
      // chain.
      if (view_info.has_view_fn()) {
        auto view_fn = view_info.view_fn();
        Tensor diff_view;
        {
          // We can reach this path with grad_mode disabled, e.g. engine
          AutoGradMode grad_mode(true);
          diff_view = view_fn(view_info.base_);
        }
        diff_view_meta->grad_fn_ = diff_view.grad_fn();
      } else {
        auto fn =
            std::make_shared<torch::autograd::generated::AsStridedBackward0>();
        fn->self_geometry = at::TensorGeometry(view_info.base_);
        fn->size = self.sym_sizes().vec();
        fn->stride = self.sym_strides().vec();
        fn->storage_offset = self.sym_storage_offset();
        fn->set_next_edges(
            torch::autograd::collect_next_edges(view_info.base_));
        fn->add_input_metadata(
            view_info.base_.options(),
            self.sym_sizes(), // Note: sizes(), not base_.sizes(), is
                              // intentional
            self.unsafeGetTensorImpl()->is_python_dispatch());
        diff_view_meta->grad_fn_ = std::move(fn);
      }
      diff_view_meta->set_attr_version(current_version);

      torch::autograd::impl::update_tensor_hooks_on_new_gradfn(
          self, old_fn, diff_view_meta->grad_fn_);
    }
    return diff_view_meta->grad_fn_;
  }

  if (torch::autograd::impl::get_autograd_meta(self)) {
    return torch::autograd::impl::get_autograd_meta(self)->grad_fn_;
  } else {
    return singleton_shared_ptr;
  }
}
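
// A commented walk-through of the refresh logic above; the ops are assumed
// user code, not taken from this file:
//
//   at::Tensor leaf = torch::rand({4, 4}, torch::requires_grad());
//   at::Tensor base = leaf * 1;          // non-leaf tensor requiring grad
//   at::Tensor view = base.view({16});   // as_strided-compatible view
//   base.add_(1);                        // in-place update on the *base*
//   auto fn = view.grad_fn();            // lands in the branch above
//
// Because the in-place op bumped the shared version counter, the stored
// attr_version no longer matches, so a fresh AsStridedBackward0 (or the
// replayed view_fn for backends without as_strided support) is built and
// cached as the view's grad_fn before being returned.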

void VariableHooks::remove_hook(const at::TensorBase& self, unsigned pos)
    const {
  auto& list =
      torch::autograd::impl::materialize_autograd_meta(self)->cpp_hooks_list_;
  TORCH_CHECK(
      list && pos < list->size(), "Invalid index, no hook at position ", pos);
  // Hook will be ignored
  (*list)[pos] = nullptr;
}

unsigned VariableHooks::_register_hook(
    const at::TensorBase& self,
    std::function<at::TensorBase(const at::TensorBase&)> hook) const {
  TORCH_CHECK(
      self.requires_grad(),
      "cannot register a hook on a variable that "
      "doesn't require gradient");
  // NB: materialize_autograd_meta unnecessary due to requires grad check
  auto& list = torch::autograd::impl::get_autograd_meta(self)->cpp_hooks_list_;
  if (!list) {
    torch::autograd::impl::create_cpp_hook(
        self, /*is_retains_grad_hook=*/false);
  }
  unsigned idx = list->size();
  list->push_back(hook);
  return idx;
}
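
// A commented usage sketch for the hook machinery above, via the public
// at::Tensor::register_hook API; the lambda body is an assumption chosen for
// illustration:
//
//   at::Tensor x = torch::rand({3}, torch::requires_grad());
//   at::Tensor y = (x * 3).sum();
//   unsigned pos = x.register_hook(
//       [](const at::Tensor& grad) { return grad * 2; });  // scales the grad
//   y.backward();
//   // x.grad() now holds 6s instead of 3s; x.remove_hook(pos) would disable
//   // the hook again (the slot is simply nulled out, as seen above).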

void handle_view_on_rebase(
    DifferentiableViewMeta* diff_view_meta,
    bool indirect) {
  /// See NOTE [ View + Inplace detection ] for justification of the logic below
  auto creation_meta = diff_view_meta->get_creation_meta();
  if (creation_meta != CreationMeta::DEFAULT) {
    auto grad_fn = diff_view_meta->grad_fn_.get();
    std::string msg;
    std::string modified_obj;
    // Create the header for the error message.
    if (indirect) {
      modified_obj = "its base or another view of its base has been";
    } else {
      modified_obj = "is being";
    }

    if (creation_meta == CreationMeta::INFERENCE_MODE ||
        creation_meta == CreationMeta::NO_GRAD_MODE || !grad_fn) {
      std::string prefix;
      if (grad_fn) {
        prefix = c10::str(
            "Output ",
            diff_view_meta->output_nr_,
            " of ",
            grad_fn->name(),
            " is a view of a view which was created in");
      } else {
        prefix = "A view was created in";
      }
      if (creation_meta == CreationMeta::INFERENCE_MODE) {
        msg = c10::str(
            prefix,
            " inference mode and ",
            modified_obj,
            " modified inplace in normal mode.");
      } else {
        // creation_meta is not necessarily CreationMeta::NO_GRAD_MODE here,
        // e.g. CreationMeta::IN_CUSTOM_FUNCTION is possible, but we know that
        // if there is no grad_fn, the view was performed in no-grad mode.
        msg = c10::str(
            prefix,
            " no_grad mode and ",
            modified_obj,
            " modified inplace with grad mode enabled.");
      }
    } else {
      msg = c10::str(
          "Output ",
          diff_view_meta->output_nr_,
          " of ",
          grad_fn->name(),
          " is a view and ",
          modified_obj,
          " modified inplace.");
    }

    if (creation_meta == CreationMeta::MULTI_OUTPUT_NODE) {
      msg = c10::str(
          msg,
          " This view is the output of a function that returns multiple views. Such functions do not"
          " allow the output views to be modified inplace. You should replace the inplace operation by an"
          " out-of-place one.");
    } else if (creation_meta == CreationMeta::NO_GRAD_MODE) {
      msg = c10::str(
          msg,
          " Given that this use case is ambiguous and error-prone, it is forbidden."
          " You can clarify your code by moving both the view and the inplace either both"
          " inside the no_grad block (if you don't want the inplace to be tracked) or both outside (if you want"
          " the inplace to be tracked).");
    } else if (creation_meta == CreationMeta::INFERENCE_MODE) {
      msg = c10::str(
          msg,
          " Given that this use case is ambiguous and error-prone, it is forbidden."
          " You can clarify your code by moving both the view and the inplace either both"
          " inside the inference_mode block (if you don't want the inplace to be tracked) or both outside (if you want"
          " the inplace to be tracked).");
    } else if (creation_meta == CreationMeta::IN_CUSTOM_FUNCTION) {
      msg = c10::str(
          msg,
          " This view was created inside a custom Function (or because an input was returned as-is) and the"
          " autograd logic to handle view+inplace would override the custom backward associated with the custom"
          " Function, leading to incorrect gradients. This behavior is forbidden. You can fix this by"
          " cloning the output of the custom Function.");
    } else {
      TORCH_INTERNAL_ASSERT(false, "Invalid CreationMeta state");
    }

    TORCH_CHECK(false, msg);
  }
}
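
// A commented sketch of user code that ends up in the MULTI_OUTPUT_NODE branch
// above; the shapes and ops are assumptions for illustration:
//
//   at::Tensor leaf = torch::rand({3, 4}, torch::requires_grad());
//   at::Tensor base = leaf * 1;                       // non-leaf base
//   std::vector<at::Tensor> pieces = base.unbind(0);  // multi-output view op
//   // pieces[0].add_(1); // would raise: "This view is the output of a
//   //                    // function that returns multiple views..."
//
// check_inplace detects the non-DEFAULT creation_meta on the view and calls
// handle_view_on_rebase(..., /*indirect=*/false), which assembles the message
// above and throws via TORCH_CHECK(false, msg).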

} // namespace autograd
} // namespace torch