1 | #include <torch/csrc/autograd/variable.h> |
2 | |
3 | #include <torch/csrc/autograd/InferenceMode.h> |
4 | #include <torch/csrc/autograd/autograd.h> |
5 | #include <torch/csrc/autograd/edge.h> |
6 | #include <torch/csrc/autograd/engine.h> |
7 | #include <torch/csrc/autograd/function.h> |
8 | #include <torch/csrc/autograd/functions/accumulate_grad.h> |
9 | #include <torch/csrc/autograd/functions/tensor.h> |
10 | #include <torch/csrc/autograd/generated/Functions.h> |
11 | #include <torch/csrc/autograd/utils/error_messages.h> |
12 | |
13 | #include <ATen/core/VariableHooksInterface.h> |
14 | |
15 | #include <ATen/ATen.h> |
16 | #include <ATen/FuncTorchTLS.h> |
17 | #include <ATen/MemoryOverlap.h> |
18 | #include <c10/util/Exception.h> |
19 | |
20 | #include <iostream> |
21 | #include <list> |
22 | #include <memory> |
23 | #include <mutex> |
24 | #include <stdexcept> |
25 | #include <string> |
26 | #include <typeinfo> |
27 | #include <utility> |
28 | #include <vector> |
29 | |
30 | namespace torch { |
31 | namespace autograd { |
32 | |
33 | DifferentiableViewMeta::DifferentiableViewMeta( |
34 | at::TensorImpl* self_impl, |
35 | c10::optional<ViewInfo> backward_info, |
36 | c10::optional<ViewInfo> forward_info, |
37 | bool shared_view_info, |
38 | CreationMeta creation_meta) |
39 | : AutogradMeta(self_impl), |
40 | backward_info_(std::move(backward_info)), |
41 | forward_info_(std::move(forward_info)), |
42 | shared_view_info_(shared_view_info), |
43 | creation_meta_(creation_meta) { |
44 | is_view_ = true; |
45 | if (backward_info_.has_value()) { |
46 | self_impl->set_version_counter( |
47 | impl::version_counter(backward_info_.value().base_)); |
48 | attr_version_ = self_impl->version_counter().current_version(); |
49 | TORCH_INTERNAL_ASSERT( |
50 | backward_info_.value().base_.unsafeGetTensorImpl() != self_impl); |
51 | } |
52 | if (shared_view_info_) { |
53 | TORCH_INTERNAL_ASSERT( |
54 | backward_info_.has_value(), |
55 | "Shared view info require a backward view info." ); |
56 | TORCH_INTERNAL_ASSERT( |
57 | !forward_info_.has_value(), |
58 | "Shared view info require forward view info to be empty" ) |
59 | } |
60 | } |
61 | |
62 | // Chain this view info with the new view op between base and tensor |
63 | ViewInfo ViewInfo::chain( |
64 | const Variable& base, |
65 | const Variable& tensor, |
66 | std::function<Variable(const Variable&)> view_func) const { |
  // Set `view_func` using the root base as input.
  // `view_func` is used to recover views in backward when either as_strided
  // is not supported or the view function changes metadata that is not
  // recorded by as_strided. See Note [View + Inplace update for base tensor]
  // and Note [View + Inplace update for view tensor] for more details on how
  // this function is used in backward.
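  // Illustrative sketch of the chaining below: given a chain of view ops
  //   base --(parent view op, replayed by view_fn_)--> b
  //        --(new view op, passed in as `view_func`)--> c
  // the chained function conceptually satisfies
  //   chained(root_base) == view_func(view_fn_(root_base))
  // whenever view_fn_ is present, so replaying it on the root base
  // reconstructs `c` from scratch.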
73 | if (view_func) { |
    // Both current_view and its parent have a view_func
75 | if (view_fn_) { |
76 | // Copy parent view function to gain ownership |
77 | auto prev_fn = view_fn_; |
78 | view_func = [=](const at::Tensor& root_base) { |
79 | auto temp = prev_fn(root_base); |
80 | return view_func(temp); |
81 | }; |
82 | } else { |
      // current_view has a view_func but its parent doesn't have one
84 | if (base.unsafeGetTensorImpl()->support_as_strided()) { |
85 | auto size = base.sym_sizes().vec(); |
86 | auto stride = base.sym_strides().vec(); |
87 | auto storage_offset = base.sym_storage_offset(); |
88 | view_func = [=](const at::Tensor& root_base) { |
89 | auto temp = root_base.as_strided_symint(size, stride, storage_offset); |
90 | return view_func(temp); |
91 | }; |
92 | } else { |
        // When base is a view but doesn't carry a view_fn in
        // DifferentiableViewMeta, it's a view that doesn't support inplace
        // update, e.g. unbind. In this case we should throw an error when an
        // inplace update happens in **forward**. One would naturally think
        // this function is first called in the backward pass, but the first
        // call site is actually in the **forward** pass, when we refresh
        // `grad_fn` triggered by an inplace update. Search Note [View +
        // Inplace update for view tensor] for the call site.
101 | view_func = [=](const at::Tensor& root_base) { |
102 | TORCH_CHECK( |
103 | false, |
104 | "This view is the output of a function that returns multiple views." |
105 | "Such functions do not allow the output views to be modified inplace." |
106 | "You should replace the inplace operation by an out-of-place one" ); |
107 | return root_base; |
108 | }; |
109 | } |
110 | } |
111 | } else if (view_fn_) { |
    // If current_view doesn't have a view_func but its parent has one
113 | // Copy parent view function to gain ownership |
114 | auto prev_view_fn = view_fn_; |
115 | auto size = tensor.sym_sizes().vec(); |
116 | auto stride = tensor.sym_strides().vec(); |
117 | auto storage_offset = tensor.sym_storage_offset(); |
118 | view_func = [=](const at::Tensor& root_base) { |
119 | auto temp = prev_view_fn(root_base); |
120 | return temp.as_strided_symint(size, stride, storage_offset); |
121 | }; |
122 | } |
123 | |
124 | return ViewInfo(base_, std::move(view_func)); |
125 | } |
126 | |
127 | namespace { |
128 | |
129 | at::Tensor singleton_undefined_tensor; |
130 | |
131 | struct ConcreteAutogradMetaFactory : public c10::impl::AutogradMetaFactory { |
132 | std::unique_ptr<c10::AutogradMetaInterface> make() const override { |
133 | return std::make_unique<AutogradMeta>(); |
134 | } |
135 | const at::Tensor& undefined_tensor() const override { |
136 | return singleton_undefined_tensor; |
137 | } |
138 | }; |
139 | |
140 | ConcreteAutogradMetaFactory meta_factory; |
141 | |
142 | static c10::impl::AutogradMetaFactoryRegisterer meta_factory_registerer( |
143 | &meta_factory); |
144 | |
145 | } // namespace |
146 | |
147 | namespace impl { |
148 | |
149 | AutogradMeta* materialize_autograd_meta(const at::TensorBase& self) { |
150 | TORCH_CHECK( |
151 | self.defined(), |
152 | "cannot call materialize_autograd_meta() on undefined tensor" ); |
153 | auto p = self.unsafeGetTensorImpl(); |
154 | if (!p->autograd_meta()) { |
155 | p->set_autograd_meta(std::make_unique<AutogradMeta>()); |
156 | } |
157 | return get_autograd_meta(self); |
158 | } |
159 | |
160 | void update_tensor_hooks_on_new_gradfn( |
161 | const at::TensorBase& self, |
162 | const std::shared_ptr<torch::autograd::Node>& old_fn, |
163 | const std::shared_ptr<torch::autograd::Node>& new_fn) { |
164 | // This function is called whenever the grad_fn of the tensor is |
165 | // changed. We assume here that new_fn does not yet have hooks of |
166 | // its own. |
167 | // |
168 | // This function does two things: |
169 | // (1) reset the list when grad_fn is updated, so new hooks don't |
170 | // get erroneously registered to the old grad_fn. |
171 | // Note that the old cpp_hooks_list_ is still kept alive by the |
172 | // old grad_fn so hooks registered to the older version of the tensor |
173 | // will continue to be active. |
174 | // (2) If there is a retains_grad hook registered, move that from the |
175 | // old cpp_hooks_list_ to the new one |
176 | const auto& meta = impl::get_autograd_meta(self); |
177 | TORCH_INTERNAL_ASSERT(meta); |
178 | TORCH_INTERNAL_ASSERT(new_fn); |
179 | meta->cpp_hooks_list_ = nullptr; |
180 | const c10::impl::PyInterpreter* interp = |
181 | self.unsafeGetTensorImpl()->pyobj_slot()->pyobj_interpreter(); |
182 | if (interp) { |
183 | (*interp)->reset_backward_hooks(self.unsafeGetTensorImpl()); |
184 | } |
185 | if (self.retains_grad()) { |
186 | TORCH_INTERNAL_ASSERT(old_fn); |
187 | auto out = old_fn->pop_retains_grad_hook(self.output_nr()); |
188 | TORCH_INTERNAL_ASSERT(out != nullptr); |
189 | new_fn->add_retains_grad_hook(std::move(out), self.output_nr()); |
190 | } |
191 | } |
192 | |
193 | void rebase_history(const Variable& self, Edge gradient_edge) { |
194 | TORCH_INTERNAL_ASSERT(gradient_edge.function != nullptr); |
195 | const auto& meta = impl::get_autograd_meta(self); |
196 | auto old_fn = meta != nullptr ? meta->grad_fn_ : nullptr; |
197 | auto diff_view_meta = get_view_autograd_meta(self); |
198 | if (diff_view_meta && diff_view_meta->has_bw_view()) { |
199 | // See NOTE [ View + Inplace detection ] |
200 | auto creation_meta = diff_view_meta->get_creation_meta(); |
    // Do not use handle_view_on_rebase here as check_inplace should have been
    // called before this and would have already thrown an error if needed
    // (hence the assert below that creation_meta is DEFAULT).
203 | TORCH_INTERNAL_ASSERT(creation_meta == CreationMeta::DEFAULT); |
204 | TORCH_INTERNAL_ASSERT(gradient_edge.input_nr == 0); |
205 | TORCH_INTERNAL_ASSERT(gradient_edge.function); |
206 | TORCH_CHECK( |
207 | gradient_edge.function->num_inputs() == 1, |
208 | "Functions which modify views in-place must return a single Variable" ); |
209 | auto view_info = diff_view_meta->get_backward_view(); |
210 | diff_view_meta->output_nr_ = gradient_edge.input_nr; |
211 | auto copy_slices = std::make_shared<CopySlices>( |
212 | view_info.base_, |
213 | at::TensorGeometry(self), |
214 | view_info.view_fn_, |
215 | std::move(gradient_edge.function)); |
216 | set_gradient_edge(view_info.base_, {std::move(copy_slices), 0}); |
217 | self.grad_fn(); // trigger an update to the view's grad_fn |
218 | return; |
219 | } |
220 | |
221 | set_gradient_edge(self, std::move(gradient_edge)); |
222 | // Pass both self and its grad_fn to avoid calling into grad_fn reentrantly |
223 | torch::autograd::impl::update_tensor_hooks_on_new_gradfn( |
224 | self, old_fn, self.grad_fn()); |
225 | } |
226 | |
227 | void create_cpp_hook(const at::TensorBase& self, bool is_retains_grad_hook) { |
228 | const auto& fn = self.grad_fn(); |
229 | std::shared_ptr<hooks_list>& list = |
230 | materialize_autograd_meta(self)->cpp_hooks_list_; |
231 | // NOLINTNEXTLINE(modernize-make-shared) |
232 | list.reset(new hooks_list()); |
233 | std::unique_ptr<FunctionPreHook> hook_ptr{ |
234 | new CppFunctionTensorPreHook(list, self.output_nr())}; |
  // NB: we could potentially only update hooks_ if !fn, but it shouldn't
  // matter and this was the way before, so we keep it like this for now.
238 | clear_hooks(self); |
239 | add_hook(self, std::make_unique<CppFunctionTensorPreHook>(list, 0)); |
240 | if (fn) { |
241 | fn->add_tensor_pre_hook(std::move(hook_ptr)); |
242 | } |
243 | } |
244 | |
245 | void set_grad_accumulator( |
246 | const Variable& self, |
247 | std::weak_ptr<Node> grad_accumulator) { |
248 | materialize_autograd_meta(self)->grad_accumulator_ = |
249 | std::move(grad_accumulator); |
250 | } |
251 | |
252 | std::shared_ptr<Node> try_get_grad_accumulator(const Variable& self) { |
253 | if (get_autograd_meta(self)) { |
254 | return get_autograd_meta(self)->grad_accumulator_.lock(); |
255 | } else { |
256 | return nullptr; |
257 | } |
258 | } |
259 | |
260 | std::shared_ptr<Node> grad_accumulator(const Variable& self) { |
261 | auto autograd_meta = get_autograd_meta(self); |
262 | if (!autograd_meta) { |
263 | return nullptr; |
264 | } |
265 | if (autograd_meta->grad_fn_) { |
266 | throw std::logic_error( |
267 | "grad_accumulator() should be only called on leaf Variables" ); |
268 | } |
269 | if (!autograd_meta->requires_grad_) { |
270 | return nullptr; |
271 | } |
272 | |
273 | std::lock_guard<std::mutex> lock(autograd_meta->mutex_); |
274 | |
275 | auto result = autograd_meta->grad_accumulator_.lock(); |
276 | if (result) |
277 | return result; |
278 | |
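  // Sketch of the ownership dance below: incref followed by reclaim builds an
  // owning intrusive_ptr to this TensorImpl without consuming the caller's
  // reference, so the AccumulateGrad node holds the Variable alive, while
  // storing only a weak_ptr in grad_accumulator_ avoids a reference cycle.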
279 | c10::raw::intrusive_ptr::incref(self.unsafeGetTensorImpl()); |
280 | auto intrusive_from_this = |
281 | c10::intrusive_ptr<at::TensorImpl>::reclaim(self.unsafeGetTensorImpl()); |
282 | result = std::make_shared<AccumulateGrad>( |
283 | Variable(std::move(intrusive_from_this))); |
284 | autograd_meta->grad_accumulator_ = result; |
285 | return result; |
286 | } |
287 | |
288 | Edge gradient_edge(const Variable& self) { |
289 | // If grad_fn is null (as is the case for a leaf node), we instead |
290 | // interpret the gradient function to be a gradient accumulator, which will |
291 | // accumulate its inputs into the grad property of the variable. These |
292 | // nodes get suppressed in some situations, see "suppress gradient |
293 | // accumulation" below. Note that only variables which have `requires_grad = |
294 | // True` can have gradient accumulators. |
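  // For example (illustrative): a non-leaf `y` produced by an op gets the
  // edge {y.grad_fn(), y.output_nr()}; a leaf with requires_grad=true gets
  // {its AccumulateGrad node, 0}; and a leaf that does not require grad gets
  // an edge whose function is null.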
295 | if (const auto& gradient = self.grad_fn()) { |
296 | return Edge(gradient, self.output_nr()); |
297 | } else { |
298 | return Edge(grad_accumulator(self), 0); |
299 | } |
300 | } |
301 | |
302 | void set_gradient_edge(const Variable& self, Edge edge) { |
303 | auto* meta = materialize_autograd_meta(self); |
304 | meta->grad_fn_ = std::move(edge.function); |
305 | meta->output_nr_ = edge.input_nr; |
306 | // For views, make sure this new grad_fn_ is not overwritten unless it is |
307 | // necessary in the VariableHooks::grad_fn below. This logic is only relevant |
308 | // for custom autograd Functions for which multiple operations can happen on a |
309 | // given Tensor before its gradient edge is set when exiting the custom |
310 | // Function. |
311 | auto diff_view_meta = get_view_autograd_meta(self); |
312 | if (diff_view_meta && diff_view_meta->has_bw_view()) { |
313 | diff_view_meta->set_attr_version(self._version()); |
314 | } |
315 | } |
316 | |
317 | Node* grad_fn_unsafe(const Variable& self) { |
318 | if (get_autograd_meta(self)) { |
319 | return get_autograd_meta(self)->grad_fn_.get(); |
320 | } else { |
321 | return nullptr; |
322 | } |
323 | } |
324 | |
325 | // Versions |
326 | //~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ |
327 | |
328 | void set_version_counter( |
329 | const Variable& self, |
330 | const c10::VariableVersion& version_counter) { |
331 | TORCH_CHECK( |
332 | self.defined(), "cannot call set_version_counter() on undefined tensor" ); |
333 | self.unsafeGetTensorImpl()->set_version_counter(version_counter); |
334 | } |
335 | |
336 | void bump_version(const Variable& self) { |
337 | TORCH_CHECK(self.defined(), "cannot call bump_version() on undefined tensor" ); |
338 | self.unsafeGetTensorImpl()->bump_version(); |
339 | } |
340 | |
341 | const c10::VariableVersion& version_counter(const Variable& self) { |
342 | TORCH_CHECK( |
343 | self.defined(), "cannot call version_counter() on undefined tensor" ); |
344 | return self.unsafeGetTensorImpl()->version_counter(); |
345 | } |
346 | |
347 | // Hooks |
348 | //~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ |
349 | |
350 | void add_hook( |
351 | const at::TensorBase& self, |
352 | std::unique_ptr<FunctionPreHook> hook) { |
353 | AutogradMeta* meta = materialize_autograd_meta(self); |
354 | TORCH_INTERNAL_ASSERT(meta->hooks_.empty()); |
355 | meta->hooks_.push_back(std::move(hook)); |
356 | } |
357 | |
358 | std::vector<std::unique_ptr<FunctionPreHook>>& hooks(const Variable& self) { |
359 | TORCH_INTERNAL_ASSERT(get_autograd_meta(self)); |
360 | return get_autograd_meta(self)->hooks_; |
361 | } |
362 | |
363 | void clear_hooks(const at::TensorBase& self) { |
  // This is a little goofy, but usually this should be a no-op
365 | materialize_autograd_meta(self)->hooks_.clear(); |
366 | } |
367 | |
368 | void set_name(const Variable& self, const std::string& name) { |
369 | materialize_autograd_meta(self)->name_ = name; |
370 | } |
371 | |
372 | // Miscellaneous |
373 | //~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ |
374 | |
375 | AutogradMeta* get_autograd_meta(const at::TensorBase& self) { |
376 | // NB: could return nullptr |
377 | TORCH_CHECK( |
378 | self.defined(), "cannot call get_autograd_meta() on undefined tensor" ); |
379 | return static_cast<AutogradMeta*>( |
380 | self.unsafeGetTensorImpl()->autograd_meta()); |
381 | } |
382 | |
383 | DifferentiableViewMeta* get_view_autograd_meta(const at::TensorBase& self) { |
384 | // NB: return nullptr if self is not a view |
385 | AutogradMeta* meta = get_autograd_meta(self); |
386 | if (meta && meta->is_view_) { |
387 | return static_cast<DifferentiableViewMeta*>(meta); |
388 | } else { |
389 | return nullptr; |
390 | } |
391 | } |
392 | |
393 | } // namespace impl |
394 | |
395 | using at::Tensor; |
396 | |
397 | struct VariableHooks final : at::impl::VariableHooksInterface { |
398 | at::TensorBase tensor_data(const at::TensorBase&) const override; |
399 | at::TensorBase variable_data(const at::TensorBase&) const override; |
400 | const std::shared_ptr<torch::autograd::Node>& grad_fn( |
401 | const at::TensorBase&) const override; |
402 | unsigned _register_hook( |
403 | const at::TensorBase&, |
404 | std::function<at::TensorBase(const at::TensorBase&)> hook) const override; |
405 | void remove_hook(const at::TensorBase&, unsigned pos) const override; |
406 | bool is_view(const at::TensorBase&) const override; |
407 | const at::TensorBase& base(const at::TensorBase&) const override; |
408 | const std::string& name(const at::TensorBase&) const override; |
409 | bool is_leaf(const at::TensorBase&) const override; |
410 | int64_t output_nr(const at::TensorBase&) const override; |
411 | void set_data(const at::TensorBase& self, const at::TensorBase& new_data) |
412 | const override; |
413 | at::TensorBase data(const at::TensorBase& self) const override; |
414 | int64_t _version(const at::TensorBase& self) const override; |
415 | void retain_grad(const at::TensorBase& self) const override; |
416 | bool retains_grad(const at::TensorBase& self) const override; |
417 | void _backward( |
418 | const Tensor& self, |
419 | at::TensorList inputs, |
420 | const c10::optional<Tensor>& gradient, |
421 | c10::optional<bool> keep_graph, |
422 | bool create_graph) const override; |
423 | void requires_grad_(const at::TensorBase& self, bool _requires_grad) |
424 | const override; |
425 | }; |
426 | |
427 | VariableHooks variableHooks; |
428 | at::impl::VariableHooksRegisterer registerVariableHooks(&variableHooks); |
429 | |
430 | at::TensorBase VariableHooks::variable_data(const at::TensorBase& self) const { |
431 | TORCH_CHECK( |
432 | self.defined(), "cannot call variable_data() on undefined tensor" ); |
433 | auto self_impl_copy = self.unsafeGetTensorImpl()->shallow_copy_and_detach( |
434 | /*version_counter=*/0, |
435 | /*allow_tensor_metadata_change=*/false); |
436 | self_impl_copy->set_autograd_meta(nullptr); |
437 | return at::Tensor(self_impl_copy); |
438 | } |
439 | |
440 | at::TensorBase VariableHooks::tensor_data(const at::TensorBase& self) const { |
441 | TORCH_CHECK(self.defined(), "cannot call tensor_data() on undefined tensor" ); |
442 | auto self_impl_copy = self.unsafeGetTensorImpl()->shallow_copy_and_detach( |
443 | /*version_counter=*/self.unsafeGetTensorImpl()->version_counter(), |
444 | /*allow_tensor_metadata_change=*/ |
445 | self.unsafeGetTensorImpl()->allow_tensor_metadata_change()); |
446 | return at::Tensor(self_impl_copy); |
447 | } |
448 | |
449 | bool VariableHooks::is_leaf(const at::TensorBase& self) const { |
450 | if (impl::get_autograd_meta(self)) { |
451 | return impl::get_autograd_meta(self)->grad_fn_ == nullptr; |
452 | } else { |
453 | return true; |
454 | } |
455 | } |
456 | |
457 | int64_t VariableHooks::output_nr(const at::TensorBase& self) const { |
458 | if (impl::get_autograd_meta(self)) { |
459 | return impl::get_autograd_meta(self)->output_nr_; |
460 | } else { |
461 | return 0; |
462 | } |
463 | } |
464 | |
465 | void VariableHooks::set_data( |
466 | const at::TensorBase& self_base, |
467 | const at::TensorBase& new_data_base) const { |
468 | at::OptionalTensorRef self_ref(self_base); |
469 | const Tensor& self = *self_ref; |
470 | at::OptionalTensorRef new_data_ref(new_data_base); |
471 | const Tensor& new_data = *new_data_ref; |
472 | |
473 | // `var.set_data(new_data)` shallow-copies all non-autograd TensorImpl fields |
474 | // from `new_data` to `var`. It requires that `new_data` and `var` have |
475 | // compatible tensor type. |
476 | TORCH_CHECK( |
477 | _has_compatible_shallow_copy_type(self, new_data), |
478 | "Attempted to call `variable.set_data(tensor)`, but `variable` and `tensor` have incompatible tensor type." ); |
479 | |
480 | TORCH_CHECK( |
481 | !self.requires_grad() || |
482 | isDifferentiableType(at::typeMetaToScalarType(new_data.dtype())), |
483 | "data set to a tensor that requires gradients must be floating point or complex dtype" ); |
484 | |
485 | // Resets gradient accumulator if metadata is out of date |
486 | AutogradMeta* autograd_meta = impl::get_autograd_meta(self); |
487 | if (autograd_meta) { |
488 | std::lock_guard<std::mutex> lock(autograd_meta->mutex_); |
489 | auto prior_accumulator = autograd_meta->grad_accumulator_.lock(); |
490 | if (prior_accumulator) { |
491 | const auto prior_device = prior_accumulator->input_metadata(0).device(); |
492 | const auto new_device = new_data.device(); |
493 | |
494 | if (!new_data.options().type_equal(self.options()) || |
495 | prior_device != new_device) { |
496 | autograd_meta->grad_accumulator_.reset(); |
497 | } |
498 | } |
499 | } |
500 | |
501 | // Version counter is not shared when we replace a `Variable`'s tensor data |
502 | // by calling `set_data(...)`. The original version of the `Variable` is |
503 | // always preserved. See NOTE [ Version Counter Sharing ] for details. |
504 | // |
505 | // `var.set_data(new_data)` always ignores `var`'s |
506 | // `allow_tensor_metadata_change_`, because users need this API as an escape |
507 | // hatch for changing a tensor's metadata regardless of its |
508 | // `allow_tensor_metadata_change_` value, and the users are responsible for |
509 | // ensuring this is the behavior they want. |
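  // Rough summary of the effect (illustrative, not exhaustive): after
  // `var.set_data(new_data)`, `var` keeps its autograd metadata
  // (requires_grad, hooks, grad_fn) and its own version counter, but now
  // shares `new_data`'s storage, sizes, strides and dtype.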
510 | self.unsafeGetTensorImpl()->shallow_copy_from(new_data.getIntrusivePtr()); |
511 | } |
512 | |
513 | at::TensorBase VariableHooks::data(const at::TensorBase& self) const { |
514 | return self.variable_data(); |
515 | } |
516 | |
517 | int64_t VariableHooks::_version(const at::TensorBase& self) const { |
518 | return self.unsafeGetTensorImpl()->version_counter().current_version(); |
519 | } |
520 | |
521 | void VariableHooks::retain_grad(const at::TensorBase& self) const { |
522 | TORCH_CHECK( |
523 | self.requires_grad(), |
524 | "can't retain_grad on Tensor that has requires_grad=False" ); |
525 | |
526 | // temporary hack to improve functorch UX. |
527 | const auto& functorch_tls = at::functorch::functorchTLSAccessor(); |
528 | if (functorch_tls) { |
529 | functorch_tls->checkSupportsRetainGrad(); |
530 | } |
531 | |
532 | if (self.is_leaf()) { // no-op for leaves |
533 | return; |
534 | } |
535 | if (impl::get_autograd_meta(self)->retains_grad_) { |
536 | return; |
537 | } |
538 | c10::weak_intrusive_ptr<c10::TensorImpl> weak_self(self.getIntrusivePtr()); |
539 | |
540 | auto retain_grad_hook = [weak_self](const at::TensorBase& grad_base) { |
541 | at::Tensor grad{grad_base}; |
542 | if (!weak_self.expired() && grad.defined()) { |
543 | auto var = weak_self.lock(); |
544 | if (!var->grad().defined()) { |
545 | if (grad.is_sparse()) { |
546 | var->mutable_grad() = grad.clone(); |
547 | } else { |
548 | var->mutable_grad() = grad.clone(at::MemoryFormat::Contiguous); |
549 | } |
550 | } else { |
551 | var->mutable_grad() = var->grad() + grad; |
552 | } |
553 | } |
554 | return at::TensorBase{}; |
555 | }; |
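  // Behavior sketch for the hook above: on the first call the incoming
  // gradient is cloned into `.grad`, and on later calls it is accumulated;
  // the undefined return value means the gradient flowing onward through the
  // graph is left unchanged.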
556 | |
557 | const auto& fn = self.grad_fn(); |
558 | std::unique_ptr<FunctionPreHook> hook_ptr{new CppFunctionSingleTensorPreHook( |
559 | std::move(retain_grad_hook), self.output_nr())}; |
560 | fn->add_retains_grad_hook(std::move(hook_ptr), self.output_nr()); |
561 | impl::get_autograd_meta(self)->retains_grad_ = true; |
562 | } |
563 | |
564 | bool VariableHooks::retains_grad(const at::TensorBase& self) const { |
565 | if (impl::get_autograd_meta(self)) { |
566 | return impl::get_autograd_meta(self)->retains_grad_; |
567 | } else { |
568 | return false; |
569 | } |
570 | } |
571 | |
572 | void VariableHooks::_backward( |
573 | const Tensor& self, |
574 | at::TensorList inputs, |
575 | const c10::optional<Tensor>& gradient, |
576 | c10::optional<bool> keep_graph, |
577 | bool create_graph) const { |
578 | // TODO torch::autograd::backward should take the c10::optional<Tensor> |
579 | // gradient directly instead of us having to unwrap it to Tensor _gradient |
580 | // here. |
581 | Tensor _gradient = gradient.has_value() ? *gradient : Tensor(); |
582 | std::vector<torch::autograd::Variable> input_vars( |
583 | inputs.begin(), inputs.end()); |
584 | torch::autograd::backward( |
585 | {self}, {std::move(_gradient)}, keep_graph, create_graph, input_vars); |
586 | } |
587 | |
588 | void VariableHooks::requires_grad_( |
589 | const at::TensorBase& self, |
590 | bool _requires_grad) const { |
591 | if (!self.is_leaf() && !_requires_grad) { |
592 | throw std::runtime_error( |
593 | autograd::utils::requires_grad_leaf_error(_requires_grad)); |
594 | } |
595 | self.set_requires_grad(_requires_grad); |
596 | } |
597 | |
598 | // Backward View Variables |
599 | //~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ |
600 | |
601 | bool VariableHooks::is_view(const at::TensorBase& self) const { |
602 | auto diff_view_meta = torch::autograd::impl::get_view_autograd_meta(self); |
603 | if (diff_view_meta) { |
604 | return diff_view_meta->has_bw_view(); |
605 | } else { |
606 | return false; |
607 | } |
608 | } |
609 | |
610 | const at::TensorBase& VariableHooks::base(const at::TensorBase& self) const { |
611 | auto diff_view_meta = torch::autograd::impl::get_view_autograd_meta(self); |
612 | if (diff_view_meta) { |
613 | TORCH_CHECK( |
614 | diff_view_meta->has_bw_view(), |
615 | "Can't get base of non-backward view Tensor" ); |
616 | return diff_view_meta->get_backward_view().base_; |
617 | } else { |
    throw std::runtime_error("Can't get base of non-view Tensor");
619 | } |
620 | } |
621 | |
622 | namespace { |
623 | std::string singleton_string; |
624 | } |
625 | |
626 | const std::string& VariableHooks::name(const at::TensorBase& self) const { |
627 | TORCH_CHECK( |
628 | self.defined(), "cannot call variable_data() on undefined tensor" ); |
629 | if (torch::autograd::impl::get_autograd_meta(self)) { |
630 | return torch::autograd::impl::get_autograd_meta(self)->name_; |
631 | } else { |
632 | return singleton_string; |
633 | } |
634 | } |
635 | |
636 | namespace { |
637 | std::shared_ptr<torch::autograd::Node> singleton_shared_ptr; |
638 | } |
639 | |
640 | const std::shared_ptr<torch::autograd::Node>& VariableHooks::grad_fn( |
641 | const at::TensorBase& self) const { |
642 | auto diff_view_meta = torch::autograd::impl::get_view_autograd_meta(self); |
643 | if (diff_view_meta && diff_view_meta->has_bw_view()) { |
644 | // See NOTE [ View + Inplace detection ] |
645 | std::lock_guard<std::mutex> lock(diff_view_meta->mutex_); |
646 | auto view_info = diff_view_meta->get_backward_view(); |
647 | if (!diff_view_meta->grad_fn_ && !view_info.base_.requires_grad()) { |
648 | return diff_view_meta->grad_fn_; |
649 | } |
650 | auto current_version = self._version(); |
651 | auto old_fn = diff_view_meta->grad_fn_; |
652 | if (diff_view_meta->get_attr_version() != current_version) { |
653 | // This is an indirect rebase_history due to another view or the base |
654 | // being modified inplace |
655 | handle_view_on_rebase(diff_view_meta, /* indirect */ true); |
656 | TORCH_INTERNAL_ASSERT(diff_view_meta->output_nr_ == 0); |
657 | // Note [View + Inplace update for view tensor] |
658 | // An inplace update happened on Tensor `self` (which is a view). |
659 | // For example: |
660 | // view_1 = view_op_1(diff_view_meta->base_) |
661 | // view_2 = view_op_2(view_1) |
662 | // ... |
663 | // self = view_op_n(view_n-1) |
664 | // self = inplace_op(self) |
665 | // |
666 | // For CPU/CUDA backends, we employ one AsStridedBackward0 Node to |
      // represent the chain of view backward ops for efficiency.
668 | // |
      // However, the XLA backend doesn't have full support for
      // AsStridedBackward0; instead we run a full forward pass with a tensor
      // that requires gradient to get a proper grad_fn setup, then save it to
      // DifferentiableViewMeta for future use. This is fairly cheap for XLA's
      // lazy tensor approach (but would be really expensive for CPU/CUDA):
      // XLA tensors only run through VariableType dispatch and lower the
      // forward pass to an XLA HLO graph, and we take grad_fn without ever
      // materializing the tensor content. So we only construct the graph but
      // never execute it, which is a fairly cheap operation to do.
678 | // |
679 | // See Note [View + Inplace update for base tensor] for what we do to base |
680 | // tensor when an in-place operation happens. |
681 | // |
682 | // TODO: Potentially the following logic can be replaced by special logic |
683 | // in VariableType_x.cpp |
684 | // that would provide a way to recreate the grad_fn chain. |
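      // In short, there are two ways to rebuild this view's grad_fn below:
      // replay the saved view_fn on the base (used when as_strided cannot
      // describe the view, e.g. on XLA), or build an AsStridedBackward0
      // directly from the view's current geometry (the common CPU/CUDA path).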
685 | if (view_info.has_view_fn()) { |
686 | auto view_fn = view_info.view_fn(); |
687 | Tensor diff_view; |
688 | { |
          // We can reach this path with grad mode disabled, e.g. when called
          // from the engine
690 | AutoGradMode grad_mode(true); |
691 | diff_view = view_fn(view_info.base_); |
692 | } |
693 | diff_view_meta->grad_fn_ = diff_view.grad_fn(); |
694 | } else { |
695 | auto fn = |
696 | std::make_shared<torch::autograd::generated::AsStridedBackward0>(); |
697 | fn->self_geometry = at::TensorGeometry(view_info.base_); |
698 | fn->size = self.sym_sizes().vec(); |
699 | fn->stride = self.sym_strides().vec(); |
700 | fn->storage_offset = self.sym_storage_offset(); |
701 | fn->set_next_edges( |
702 | torch::autograd::collect_next_edges(view_info.base_)); |
703 | fn->add_input_metadata( |
704 | view_info.base_.options(), |
705 | self.sym_sizes(), // Note: sizes(), not base_.sizes(), is |
706 | // intentional |
707 | self.unsafeGetTensorImpl()->is_python_dispatch()); |
708 | diff_view_meta->grad_fn_ = std::move(fn); |
709 | } |
710 | diff_view_meta->set_attr_version(current_version); |
711 | |
712 | torch::autograd::impl::update_tensor_hooks_on_new_gradfn( |
713 | self, old_fn, diff_view_meta->grad_fn_); |
714 | } |
715 | return diff_view_meta->grad_fn_; |
716 | } |
717 | |
718 | if (torch::autograd::impl::get_autograd_meta(self)) { |
719 | return torch::autograd::impl::get_autograd_meta(self)->grad_fn_; |
720 | } else { |
721 | return singleton_shared_ptr; |
722 | } |
723 | } |
724 | |
725 | void VariableHooks::remove_hook(const at::TensorBase& self, unsigned pos) |
726 | const { |
727 | auto& list = |
728 | torch::autograd::impl::materialize_autograd_meta(self)->cpp_hooks_list_; |
729 | TORCH_CHECK( |
730 | list && pos < list->size(), "Invalid index, no hook at position " , pos); |
731 | // Hook will be ignored |
732 | (*list)[pos] = nullptr; |
733 | } |
734 | |
735 | unsigned VariableHooks::_register_hook( |
736 | const at::TensorBase& self, |
737 | std::function<at::TensorBase(const at::TensorBase&)> hook) const { |
738 | TORCH_CHECK( |
739 | self.requires_grad(), |
740 | "cannot register a hook on a variable that " |
741 | "doesn't require gradient" ); |
742 | // NB: materialize_autograd_meta unnecessary due to requires grad check |
743 | auto& list = torch::autograd::impl::get_autograd_meta(self)->cpp_hooks_list_; |
744 | if (!list) { |
745 | torch::autograd::impl::create_cpp_hook( |
746 | self, /*is_retains_grad_hook=*/false); |
747 | } |
748 | unsigned idx = list->size(); |
749 | list->push_back(hook); |
750 | return idx; |
751 | } |
752 | |
753 | void handle_view_on_rebase( |
754 | DifferentiableViewMeta* diff_view_meta, |
755 | bool indirect) { |
756 | /// See NOTE [ View + Inplace detection ] for justification of the logic below |
757 | auto creation_meta = diff_view_meta->get_creation_meta(); |
758 | if (creation_meta != CreationMeta::DEFAULT) { |
759 | auto grad_fn = diff_view_meta->grad_fn_.get(); |
760 | std::string msg; |
761 | std::string modified_obj; |
762 | // Create the header for the error message. |
763 | if (indirect) { |
764 | modified_obj = "its base or another view of its base has been" ; |
765 | } else { |
766 | modified_obj = "is being" ; |
767 | } |
768 | |
769 | if (creation_meta == CreationMeta::INFERENCE_MODE || |
770 | creation_meta == CreationMeta::NO_GRAD_MODE || !grad_fn) { |
771 | std::string prefix; |
772 | if (grad_fn) { |
        prefix = c10::str(
            "Output ",
            diff_view_meta->output_nr_,
            " of ",
            grad_fn->name(),
            " is a view of a view which was created in");
779 | } else { |
780 | prefix = "A view was created in" ; |
781 | } |
782 | if (creation_meta == CreationMeta::INFERENCE_MODE) { |
        msg = c10::str(
            prefix,
            " inference mode and ",
            modified_obj,
            " modified inplace in normal mode.");
788 | } else { |
        // creation_meta is not necessarily CreationMeta::NO_GRAD_MODE
        // e.g. CreationMeta::IN_CUSTOM_FUNCTION is possible, but we know that
        // if there is no grad_fn, that means that the view was performed in
        // no-grad mode
        msg = c10::str(
            prefix,
            " no_grad mode and ",
            modified_obj,
            " modified inplace with grad mode enabled.");
798 | } |
799 | } else { |
      msg = c10::str(
          "Output ",
          diff_view_meta->output_nr_,
          " of ",
          grad_fn->name(),
          " is a view and ",
          modified_obj,
          " modified inplace.");
808 | } |
809 | |
810 | if (creation_meta == CreationMeta::MULTI_OUTPUT_NODE) { |
811 | msg = c10::str( |
812 | msg, |
813 | " This view is the output of a function that returns multiple views. Such functions do not" |
814 | " allow the output views to be modified inplace. You should replace the inplace operation by an" |
815 | " out-of-place one." ); |
816 | } else if (creation_meta == CreationMeta::NO_GRAD_MODE) { |
817 | msg = c10::str( |
818 | msg, |
819 | " Given that this use case is ambiguous and error-prone, it is forbidden." |
820 | " You can clarify your code by moving both the view and the inplace either both" |
821 | " inside the no_grad block (if you don't want the inplace to be tracked) or both outside (if you want" |
822 | " the inplace to be tracked)." ); |
823 | } else if (creation_meta == CreationMeta::INFERENCE_MODE) { |
824 | msg = c10::str( |
825 | msg, |
826 | " Given that this use case is ambiguous and error-prone, it is forbidden." |
827 | " You can clarify your code by moving both the view and the inplace either both" |
828 | " inside the inference_mode block (if you don't want the inplace to be tracked) or both outside (if you want" |
829 | " the inplace to be tracked)." ); |
830 | } else if (creation_meta == CreationMeta::IN_CUSTOM_FUNCTION) { |
831 | msg = c10::str( |
832 | msg, |
833 | " This view was created inside a custom Function (or because an input was returned as-is) and the" |
834 | " autograd logic to handle view+inplace would override the custom backward associated with the custom" |
835 | " Function, leading to incorrect gradients. This behavior is forbidden. You can fix this by" |
836 | " cloning the output of the custom Function." ); |
837 | } else { |
      TORCH_INTERNAL_ASSERT(false, "Invalid CreationMeta state");
839 | } |
840 | |
841 | TORCH_CHECK(false, msg); |
842 | } |
843 | } |
844 | |
845 | } // namespace autograd |
846 | } // namespace torch |
847 | |