1 | #pragma once |
2 | |
3 | #include <torch/csrc/utils/python_stub.h> |
4 | |
5 | #include <torch/csrc/Export.h> |
6 | #include <torch/csrc/autograd/cpp_hook.h> |
7 | #include <torch/csrc/autograd/edge.h> |
8 | #include <torch/csrc/autograd/forward_grad.h> |
9 | #include <torch/csrc/autograd/function_hook.h> |
10 | |
11 | #include <ATen/NamedTensorUtils.h> |
12 | #include <ATen/core/Tensor.h> |
13 | #include <c10/util/Exception.h> |
14 | |
15 | #include <cstdint> |
16 | #include <memory> |
17 | #include <mutex> |
18 | #include <stdexcept> |
19 | #include <string> |
20 | #include <utility> |
21 | #include <vector> |
22 | |
23 | namespace torch { |
24 | namespace autograd { |
25 | |
26 | /// `Variable` is exactly the same as `Tensor` (i.e. we have `using Variable = |
27 | /// at::Tensor`). This means you can perform all the usual mathematical and |
28 | /// other operations you can perform on `Tensor`s also on `Variable`s. |
29 | /// |
/// The only reason we are keeping the `Variable` class is backward
/// compatibility with external users' legacy C++ frontend code. Our intention
/// is to eliminate the `Variable` class in the near future.
33 | using Variable = at::Tensor; |
34 | |
35 | } // namespace autograd |
36 | } // namespace torch |
37 | |
38 | // The following are all internal APIs and should not be shown in libtorch docs. |
39 | // Therefore, we wrap the following code with `#ifndef DOXYGEN_SHOULD_SKIP_THIS |
40 | // ... #endif` |
41 | |
42 | #ifndef DOXYGEN_SHOULD_SKIP_THIS |
43 | |
44 | namespace torch { |
45 | namespace autograd { |
46 | |
47 | /// Check if this type is supported by the autograd engine. |
48 | /// If you change this, update the doc at the top of the |
49 | /// torch/autograd/__init__.py file and |
50 | /// "test_set_requires_grad_only_for_continuous_types" in test/test_autograd.py |
51 | static inline bool isDifferentiableType(at::ScalarType t) { |
52 | return isFloatingType(t) || isComplexType(t); |
53 | } |
54 | |
55 | struct Node; |
56 | |
57 | ///~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ |
58 | /// Variable |
59 | ///~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ |
60 | /// A `Variable` augments a `Tensor` with the ability to interact in our |
61 | /// autograd machinery. Conceptually, `Variable`s travel along `Edge`s between |
62 | /// `Node`s in the autograd graph. A `Variable` can either be a leaf, like a |
63 | /// weight in a neural network, or an interior variable, when it is the result |
64 | /// of an operation between variables. Every `Variable` also stores another |
65 | /// `Variable` called its `grad` (gradient). If the variable is a leaf, its |
66 | /// gradient will be accumulated into this variable. |
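///
/// For example, a minimal sketch using the C++ frontend:
///
///   auto w = torch::randn({3, 3}, torch::requires_grad()); // leaf
///   auto y = (w * w).sum();                                 // interior
///   y.backward();
///   // w.grad() now holds the accumulated gradient (here, 2 * w).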
67 | /// |
68 | /// Every Tensor is a Variable, but sometimes we colloquially refer to Variables |
69 | /// that don't require gradients as Tensors (since none of the autograd |
70 | /// machinery for Variables applies). Historically, Variables and Tensors |
71 | /// were separate concepts, but now they are exactly the same (i.e. we have |
72 | /// `using Variable = at::Tensor`). |
73 | /// |
74 | /// Gradient Edges |
75 | ///~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ |
76 | /// Furthermore, `Variable`s have the notion of a `gradient_edge`, which is the |
77 | /// edge in the autograd graph that connects the variable to a particular input |
78 | /// of the gradient function that will be invoked with the variable during the |
79 | /// backward pass. More precisely, this gradient function can be one of two |
80 | /// things: |
81 | /// 1. A `grad_fn`, if the variable is in the interior of the graph. This is the |
82 | /// gradient of the function that produced the variable. |
83 | /// 2. A `grad_accumulator`, if the variable is a leaf, which accumulates a |
84 | /// scalar gradient value into its `grad` variable. |
85 | /// |
86 | /// Versioning |
87 | ///~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ |
/// Another major feature of `Variable`s is *versions*. Versions are
89 | /// incremented when an in-place mutation of a variable occurs. Versions are |
90 | /// useful when constructing `SavedVariable`s, which take a snapshot of a |
91 | /// `Variable` at a certain version. You can retrieve a `Variable`'s version |
92 | /// through its `current_version()` method. |
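///
/// For example, a minimal sketch (using the `impl::` helpers declared below):
///
///   auto var = torch::ones({2, 2});
///   // impl::version_counter(var).current_version() == 0
///   var.add_(1); // in-place mutation bumps the version to 1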
93 | /// |
94 | /// Views |
95 | ///~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ |
96 | /// It is possible for a `Variable` to be a *view* of another `Variable`, in |
97 | /// which case it tracks that `Variable`'s data and autograd history. Beyond |
98 | /// construction, the interface of a view is identical to that of a regular |
/// `Variable`. You can determine whether a `Variable` is in fact a view by
100 | /// probing its `is_view()` method. Note that the *view* semantics are only |
101 | /// meaningful for `Variable` relations that are relevant to autograd. |
102 | /// See NOTE [ Autograd View Variables ] for more details. |
103 | ///~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ |
104 | |
105 | struct AutogradMeta; |
106 | struct DifferentiableViewMeta; |
107 | |
108 | // Private-ish functions for manipulating variables; we don't want to put them |
109 | // on Tensor proper |
110 | namespace impl { |
111 | |
112 | // WARNING: This may return a nullptr. If you require AutogradMeta to return |
113 | // a materialized structure, use materialize_autograd_meta instead. |
114 | TORCH_API AutogradMeta* get_autograd_meta(const at::TensorBase&); |
115 | |
116 | // WARNING: This will return a nullptr if the Tensor is not a view. |
117 | TORCH_API DifferentiableViewMeta* get_view_autograd_meta(const at::TensorBase&); |
118 | |
119 | // Returns the current autograd meta, materializing it if it was previously |
120 | // none. This counts as a *mutating* operation, so do not call it on |
121 | // "read-only" operators; in particular, this is NOT thread safe |
122 | TORCH_API AutogradMeta* materialize_autograd_meta(const at::TensorBase&); |
123 | |
124 | /// Set the gradient accumulator of the `Variable`. This is only applicable to |
125 | /// leaf variables. Interior variables should call `set_gradient_edge()`. |
126 | TORCH_API void set_grad_accumulator( |
127 | const Variable&, |
128 | std::weak_ptr<Node> grad_accumulator); |
129 | |
130 | /// Attempts to get a pointer to the gradient accumulator of the `Variable`, |
131 | /// if it still exists. If the gradient accumulator function has been |
132 | /// destroyed, returns a `nullptr`. |
133 | TORCH_API std::shared_ptr<Node> try_get_grad_accumulator(const Variable&); |
134 | |
/// Gets the gradient accumulator of the `Variable` if it has one, or else
/// creates one on the fly and returns it.
137 | TORCH_API std::shared_ptr<Node> grad_accumulator(const Variable&); |
138 | |
139 | /// Returns the "canonical" gradient edge of this `Variable`, i.e. either the |
140 | /// gradient function if this is an interior `Variable`, or the gradient |
141 | /// accumulator otherwise. If the `Variable` is interior, the returned `Edge` |
142 | /// will store the input index of the `Node` to which this variable is |
143 | /// connected in its `input_nr` field. For leaves, the `input_nr` is always |
144 | /// zero. Note that `set_gradient_edge` and `gradient_edge` are not |
145 | /// symmetric. You must use `set_gradient_edge` to set the `grad_fn` and |
146 | /// `set_grad_accumulator` to set the accumulator. |
147 | TORCH_API Edge gradient_edge(const Variable&); |
148 | |
149 | /// Set the gradient edge -- i.e. `grad_fn` and `input_nr` -- of the |
150 | /// `Variable`. |
151 | /// NOTE: This will always set the `grad_fn`, even if this is a leaf variable, |
152 | /// and never the `grad_accumulator`. For the latter, use |
153 | /// `set_grad_accumulator`. This allows late construction of an interior |
154 | /// `Variable`. |
155 | TORCH_API void set_gradient_edge(const Variable&, Edge edge); |
156 | |
157 | // Autograd Graph Interaction |
158 | //~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ |
159 | |
160 | /// Update the `grad_fn` of an existing Variable. Called after in-place |
161 | /// modifications. |
162 | /// |
163 | /// For View Variables: |
164 | /// Called after in-place modifications. Modifies the grad_fn of the base |
165 | /// Variable. |
166 | TORCH_API void rebase_history(const Variable&, Edge gradient_edge); |
167 | |
168 | /// Gets the raw gradient function pointer, whatever it currently is. |
169 | TORCH_API Node* grad_fn_unsafe(const Variable&); |
170 | |
171 | /// Increments the version count of this `Variable`. |
172 | TORCH_API void bump_version(const Variable&); |
173 | TORCH_API void set_version_counter( |
174 | const Variable&, |
175 | const c10::VariableVersion& version_counter); |
176 | |
/// Retrieves this `Variable`'s version counter.
178 | TORCH_API const c10::VariableVersion& version_counter(const Variable&); |
179 | |
180 | TORCH_API void set_name(const Variable&, const std::string& name); |
181 | |
182 | TORCH_API void add_hook( |
183 | const at::TensorBase&, |
184 | std::unique_ptr<FunctionPreHook> hook); |
185 | TORCH_API std::vector<std::unique_ptr<FunctionPreHook>>& hooks(const Variable&); |
186 | TORCH_API void clear_hooks(const at::TensorBase&); |
187 | |
188 | TORCH_API void create_cpp_hook( |
189 | const at::TensorBase&, |
190 | bool is_retains_grad_hooks = false); |
191 | } // namespace impl |
192 | |
193 | //~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ |
194 | // AutogradMeta |
195 | //~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ |
196 | |
197 | /// Each `Variable` has one unique `AutogradMeta` struct, which stores autograd |
198 | /// metadata fields that are necessary for tracking the Variable's autograd |
/// history. As an optimization, a Variable may store a nullptr in lieu of a
/// default-constructed AutogradMeta.
201 | |
202 | struct TORCH_API AutogradMeta : public c10::AutogradMetaInterface { |
203 | std::string name_; |
204 | |
205 | Variable grad_; |
206 | std::shared_ptr<Node> grad_fn_; |
207 | std::weak_ptr<Node> grad_accumulator_; |
208 | |
209 | // This field is used to store all the forward AD gradients |
210 | // associated with this AutogradMeta (and the Tensor it corresponds to) |
211 | // There is a semantic 1:1 correspondence between AutogradMeta and |
212 | // ForwardGrad but: |
213 | // - This field is lazily populated. |
214 | // - This field is a shared_ptr but it must never be |
215 | // shared by multiple Tensors. See Note [ Using ForwardGrad ] |
216 | // Any transition from not_initialized to initialized |
217 | // must be protected by mutex_ |
218 | std::shared_ptr<ForwardGrad> fw_grad_; |
219 | |
// The hooks_ field is reused by both the Python and C++ logic.
// In both cases there is one canonical data structure holding the hooks:
// cpp_hooks_list_ (C++) or a dict (Python).
// Then, in both cases, we always register a single hook to
// hooks_ which wraps all the hooks in the list/dict.
// And, again in both cases, if the grad_fn exists on that tensor
// we will additionally register a single hook to the grad_fn.
//
// Note that the C++ and Python use cases aren't actually aware of
// each other, so mixing them is undefined behavior.
230 | std::vector<std::unique_ptr<FunctionPreHook>> hooks_; |
231 | std::shared_ptr<hooks_list> cpp_hooks_list_; |
232 | |
233 | // Only meaningful on leaf variables (must be false otherwise) |
234 | bool requires_grad_{false}; |
235 | |
236 | // Only meaningful on non-leaf variables (must be false otherwise) |
237 | bool retains_grad_{false}; |
238 | |
239 | bool is_view_{false}; |
240 | |
241 | // The "output number" of this variable; e.g., if this variable |
242 | // was the second output of a function, then output_nr == 1. |
243 | // We use this to make sure we can setup the backwards trace |
244 | // correctly when this variable is passed to another function. |
245 | uint32_t output_nr_; |
246 | |
247 | // Mutex to ensure that concurrent read operations that modify internal |
248 | // state are still thread-safe. Used by grad_fn(), grad_accumulator(), |
249 | // fw_grad() and set_fw_grad() |
250 | // This is mutable because we need to be able to acquire this from const |
251 | // version of this class for the functions above |
252 | mutable std::mutex mutex_; |
253 | |
254 | /// Sets the `requires_grad` property of `Variable`. This should be true for |
255 | /// leaf variables that want to accumulate gradients, and false for all other |
256 | /// variables. |
257 | void set_requires_grad(bool requires_grad, at::TensorImpl* self_impl) |
258 | override { |
259 | TORCH_CHECK( |
260 | !requires_grad || |
261 | isDifferentiableType(at::typeMetaToScalarType(self_impl->dtype())), |
262 | "Only Tensors of floating point and complex dtype can require gradients" ); |
263 | requires_grad_ = requires_grad; |
264 | } |
265 | |
266 | bool requires_grad() const override { |
267 | return requires_grad_ || grad_fn_; |
268 | } |
269 | |
270 | /// Accesses the gradient `Variable` of this `Variable`. |
271 | Variable& mutable_grad() override { |
272 | return grad_; |
273 | } |
274 | |
275 | const Variable& grad() const override { |
276 | return grad_; |
277 | } |
278 | |
279 | const Variable& fw_grad(uint64_t level, const at::TensorBase& self) |
280 | const override; |
281 | |
282 | void set_fw_grad( |
283 | const at::TensorBase& new_grad, |
284 | const at::TensorBase& self, |
285 | uint64_t level, |
286 | bool is_inplace_op) override; |
287 | |
288 | AutogradMeta( |
289 | at::TensorImpl* self_impl = nullptr, |
290 | bool requires_grad = false, |
291 | Edge gradient_edge = Edge()) |
292 | : grad_fn_(std::move(gradient_edge.function)), |
293 | |
294 | output_nr_(gradient_edge.input_nr) { |
295 | // set_requires_grad also checks error conditions. |
296 | if (requires_grad) { |
297 | TORCH_INTERNAL_ASSERT(self_impl); |
298 | // NOLINTNEXTLINE(clang-analyzer-optin.cplusplus.VirtualCall) |
299 | set_requires_grad(requires_grad, self_impl); |
300 | } |
301 | TORCH_CHECK( |
302 | !grad_fn_ || !requires_grad_, |
303 | "requires_grad should be false if grad_fn is set" ); |
304 | } |
305 | |
306 | ~AutogradMeta() override { |
307 | // If AutogradMeta is being destroyed, it means that there is no other |
308 | // reference to its corresponding Tensor. It implies that no other thread |
309 | // can be using this object and so there is no need to lock mutex_ here to |
310 | // guard the check if fw_grad_ is populated. |
311 | if (fw_grad_) { |
312 | // See note [ Using ForwardGrad ] |
313 | fw_grad_->clear(); |
314 | } |
315 | } |
316 | }; |
317 | |
318 | struct TORCH_API ViewInfo { |
319 | /// The base `Variable` |
320 | /// If this ViewInfo represents a forward (respectively backward) AD gradient, |
321 | /// then this Tensor cannot be a forward (respectively backward) view. |
322 | Variable base_; |
323 | |
/// By default we use as_strided to recover views, which is more efficient.
/// view_fn is only saved when as_strided is not supported.
/// If view_fn has a value, we use it to recover views in backward.
327 | std::function<Variable(const Variable&)> view_fn_; |
328 | |
329 | /// Accessors for the view function |
330 | bool has_view_fn() const { |
331 | return view_fn_ != nullptr; |
332 | } |
333 | |
334 | std::function<Variable(const Variable&)> view_fn() const { |
335 | TORCH_CHECK( |
336 | has_view_fn(), "Can only access the view function if it exists." ); |
337 | return view_fn_; |
338 | } |
339 | |
340 | /// The chain function can be used to build a new ViewInfo for a |
341 | /// differentiable view function. It will return a new view info that |
342 | /// accurately represents how "tensor" is a view of this instance's "base_". |
343 | /// The "base" and "tensor" are respectively the input and output of the |
/// differentiable view function that happened. They are required to properly
/// set the optional view_fn_ when it is not provided. The "view_func", if
/// provided, should be a function that re-does the view between "base" and
/// "tensor".
348 | ViewInfo chain( |
349 | const Variable& base, |
350 | const Variable& tensor, |
351 | std::function<Variable(const Variable&)> view_func = nullptr) const; |
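
  /// For example, a sketch with a hypothetical view function (real callers
  /// are generated by the autograd codegen):
  ///
  ///   ViewInfo chained = info.chain(
  ///       base, tensor,
  ///       [](const Variable& root) { return root.narrow(0, 0, 2); });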
352 | |
353 | ViewInfo(Variable base, std::function<Variable(const Variable&)> view_fn) |
354 | : base_(std::move(base)), view_fn_(std::move(view_fn)) { |
355 | TORCH_CHECK(base_.defined(), "base is undefined" ); |
356 | } |
357 | }; |
358 | |
359 | //~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ |
360 | // DifferentiableViewMeta |
361 | //~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ |
362 | |
363 | /// NOTE [ Autograd View Variables ] |
364 | /// |
/// Many operations return a Variable that shares storage with an input Variable.
366 | /// The returned Variable is called a **view** Variable on the input **base** |
367 | /// Variable. |
368 | /// |
369 | /// In PyTorch, we have two types of views: differentiable views, and |
370 | /// non-differentiable views. In either type, to support proper version |
371 | /// checking, the base and view Variables must always share the same |
372 | /// version_counter. |
373 | /// |
374 | /// |
375 | /// Differentiable Views |
376 | /// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ |
/// This class tracks both forward and backward AD differentiable views.
/// These views can have different bases, because an operation may be a
/// differentiable view for forward mode AD but not for backward mode AD
/// (or vice versa).
380 | /// |
/// Most functions are either both forward and backward differentiable views
/// (for example: view, select, narrow, transpose, etc) or neither forward nor
/// backward differentiable views (for example: indices, values, eq, lt, etc).
/// But there are also functions that are forward but not backward
/// differentiable views (only detach for now) or functions that are backward
/// but not forward differentiable views (only make_dual and unpack_dual for
/// now).
388 | /// |
/// A concrete example of two views with different bases is as follows:
390 | /// |
391 | /// # Have: |
392 | /// # dual is a dual Tensor that is neither a forward or backward view |
393 | /// detached_dual = dual.detach() |
394 | /// view = detached_dual.view_as(dual) |
395 | /// # The forward base of view is dual |
396 | /// # The backward base of view is detached_dual |
397 | /// |
398 | /// - Backward Mode View |
399 | /// Differentiable views are the view variables where you want gradients to flow |
400 | /// back to the base variables. Out-of-place operations on views are quite |
401 | /// straightforward, but in-place ones are very tricky. Even if the base |
402 | /// variable may not require grad when we create the view, we still need to |
/// track the view relation because future in-place ops may require
/// back-propagating through it. For example, we need to support
405 | /// |
406 | /// (1) in-place operation on view, e.g., |
407 | /// |
408 | /// # Have: |
409 | /// # base.requires_grad = False |
410 | /// # var.requires_grad = True |
411 | /// base[1] = var # i.e., base[1].copy_(var) |
412 | /// torch.autograd.grad(base.sum(), var) <- should return an all ones |
413 | /// tensor |
414 | /// |
415 | /// (2) in-place operation on base after view is created, e.g., |
416 | /// |
417 | /// # Have: |
418 | /// # base.requires_grad = False |
419 | /// # var.requires_grad = True |
420 | /// view = base[1] |
421 | /// base.copy_(var) |
422 | /// torch.autograd.grad(view.sum(), var) <- should return a tensor with |
423 | /// var[1] filled with all ones and |
424 | /// zeros everywhere else |
425 | /// |
426 | /// - Forward Mode View |
427 | /// Forward differentiable views follow the same semantic as backward ones but |
428 | /// show up differently as they are computed along with the forward evaluation. |
429 | /// The hard examples above are thus very similar |
430 | /// |
431 | /// (1) in-place operation on view, e.g., |
432 | /// |
433 | /// # Have: |
434 | /// # base is a regular Tensor |
435 | /// # var is a dual Tensor whose tangent is all ones |
436 | /// base[1] = var # i.e., base[1].copy_(var) |
437 | /// # Now, base is a dual Tensor |
438 | /// _, fw_grad = fwAD.unpack_dual(base) <- fw_grad should be a tensor with |
439 | /// fw_grad[1] filled with all ones |
440 | /// and zeros everywhere else |
441 | /// |
442 | /// (2) in-place operation on base after view is created, e.g., |
443 | /// |
444 | /// # Have: |
445 | /// # base is a regular Tensor |
446 | /// # var is a dual Tensor whose tangent is all ones |
447 | /// view = base[1] |
448 | /// base.copy_(var) |
449 | /// _, fw_grad = fwAD.unpack_dual(view) <- fw_grad should be an all ones |
450 | /// tensor |
451 | /// |
452 | /// See Note [Forward Grad View/inplace] for more details on how we handle these |
453 | /// hard cases. |
454 | /// |
455 | /// |
456 | /// DifferentiableViewMeta is created to support gradient tracking of |
457 | /// such **in-place** operations. In particular, |
458 | /// + if an in-place op is done on base, the grad_fn field of the view may |
459 | /// become stale. So accesses should always go through grad_fn(), which |
460 | /// reconstructs an updated grad_fn if the version_counter has incremented. |
461 | /// All other fields are always valid. |
462 | /// + if an in-place op is done on view, in rebase_history() of view, which is |
463 | /// called after every in-place op in VariableType.cpp, the grad_fn of base |
464 | /// is updated. |
/// + if a single autograd Node returns multiple differentiable views, and if
/// any output is modified by an inplace operation, the autograd engine will
/// replace the original grad_fn with an equivalent graph corresponding to the
/// view operations, where each output is treated as if it were produced by a
/// distinct view operation. This discards the original (e.g., user provided)
/// grad_fn. If the provided grad_fn does more than the backward of the view,
/// then the DifferentiableViewMeta must be created with
/// creation_meta=CreationMeta::MULTI_OUTPUT_NODE to prevent the engine from
/// discarding the provided grad_fn.
474 | /// |
475 | /// Interaction with GradMode: |
476 | /// The particular case that we consider here is: |
477 | /// |
478 | /// # Have: |
479 | /// # base.requires_grad = True or False |
480 | /// with torch.no_grad(): |
481 | /// view = base[1] |
482 | /// base.requires_grad_() |
483 | /// view.copy_(var) |
484 | /// torch.autograd.grad(base.sum(), var) <- what should it return? |
485 | /// |
/// Given that this particular code example is ambiguous and can easily be
/// replaced by either moving both inside the no_grad block or both outside, we
/// explicitly forbid it. For now, this only raises a deprecation warning. This is
489 | /// achieved by setting creation_meta=CreationMeta::NO_GRAD_MODE for all |
490 | /// differentiable views created in no_grad mode. |
491 | /// |
492 | /// See Note [View + Inplace update for base tensor] |
493 | /// and Note [View + Inplace update for view tensor] for the details how |
494 | /// autograd handles inplace update with view ops. |
495 | /// |
496 | /// Non-Differentiable Views |
497 | /// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ |
498 | /// In certain cases, although function outputs share storage with inputs, they |
/// will **never** require gradient history tracking. Instead of registering the
/// view relation via DifferentiableViewMeta in autograd, the views will use the
/// usual AutogradMeta and just share the version counters with the base
/// Variables.
503 | /// Such views include: |
504 | /// 1. Views created from .detach() |
505 | /// 2. Views that are non-differentiable by its nature. |
506 | /// E.g., `sparse_tensor.indices()` is a integral view on a (possibly) |
507 | /// floating point tensor. |
508 | /// See top of `derivatives.yaml` on how to specify that outputs of a |
509 | /// function are non-differentiable. |
510 | /// These are called non-differentiable views as the gradients do not flow |
511 | /// through the view relation. |
512 | /// |
513 | /// Relevant logic for both differentiable and non-differentiable views is |
514 | /// implemented in make_variable_(non_)differentiable_view below, and |
515 | /// wrap_output of gen_variable_type.py. |
516 | |
517 | /// NOTE [ View + Inplace detection ] |
518 | /// |
519 | /// We want to detect views followed by inplace as they are often forbidden to |
520 | /// ensure correctness of the computed gradients. But since we want to only |
521 | /// notify the user when both happen, we tag the DifferentiableViewMeta when the |
522 | /// view is created via the `make_variable_*_view()` functions. This tag is then |
/// checked by the `check_inplace()` function from `VariableTypeUtils.h`, which
/// should be called before every inplace operation. To detect cases where
/// other views are modified and this one is rebased by side effect, we also
/// check in `VariableHooks::grad_fn()`.
527 | |
528 | /// Flag that gives more information about when this view was created: |
/// - IN_CUSTOM_FUNCTION should be set when the view is created inside a custom
/// autograd Function and is returned from it.
/// - NO_GRAD_MODE should be set when a view is created while GradMode is
/// disabled.
/// - MULTI_OUTPUT_NODE should be set when a Node created by codegen code
/// returns multiple differentiable views.
/// - INFERENCE_MODE should be set when a view of a normal tensor is created in
/// InferenceMode.
538 | /// - DEFAULT is for all other cases |
539 | enum class CreationMeta : uint8_t { |
540 | DEFAULT, |
541 | IN_CUSTOM_FUNCTION, |
542 | MULTI_OUTPUT_NODE, |
543 | NO_GRAD_MODE, |
544 | INFERENCE_MODE |
545 | }; |
546 | |
547 | /// Handles correctly propagating CreationMeta when a new view is created from a |
548 | /// previous view. In general, we don't want the new view to be _less_ |
549 | /// restrictive than the previous view (it's okay to be _more_ restrictive). A |
550 | /// CreationMeta value of DEFAULT is currently the least restrictive, as the |
551 | /// behavior for all other CreationMeta values is to error out for in-place ops. |
552 | /// A CreationMeta value of INFERENCE_MODE is currently the most restrictive, so |
553 | /// it takes precedence in propagation. If this changes, the logic here will |
554 | /// need to be updated to properly handle the new semantics. |
555 | inline CreationMeta propagate_creation_meta( |
556 | CreationMeta prev_view_creation_meta, |
557 | CreationMeta new_view_creation_meta) { |
558 | return (new_view_creation_meta == CreationMeta::DEFAULT) |
559 | ? prev_view_creation_meta |
560 | : (prev_view_creation_meta == CreationMeta::INFERENCE_MODE |
561 | ? prev_view_creation_meta |
562 | : new_view_creation_meta); |
563 | } |
564 | |
565 | /// Unified function to handle error checking when rebase happens |
566 | /// indirect=true means that the caller is not doing the inplace, but the |
567 | /// inplace happened somewhere else. |
568 | TORCH_API void handle_view_on_rebase( |
569 | DifferentiableViewMeta* diff_view_meta, |
570 | bool indirect = false); |
571 | |
572 | struct TORCH_API DifferentiableViewMeta : public AutogradMeta { |
573 | private: |
/// Information about the views
575 | c10::optional<ViewInfo> backward_info_; |
576 | c10::optional<ViewInfo> forward_info_; |
577 | |
578 | // Optimization to reduce the number of ViewInfo we create. |
579 | // In the (very common) case where backward_info_ == forward_info_, we only |
580 | // populate backward_info_ (that should be used as both the forward and |
581 | // backward view information) and set shared_view_info_ = true. Invariants: |
// - If shared_view_info_ is false, there are no special constraints on
583 | // backward_info_ and forward_info_ |
584 | // - If shared_view_info_ is true, we must have: |
585 | // - backward_info_.has_value() == true |
586 | // - forward_info_.has_value() == false |
587 | bool shared_view_info_; |
588 | |
589 | /// The two following fields are extra information that we track to ensure |
590 | /// that any operation on this backward view is valid. |
591 | |
592 | /// The value of the version_counter at the time grad_fn was created. The |
593 | /// grad_fn field is stale if attr_version_ != |
594 | /// version_counter.current_version(). |
595 | uint32_t attr_version_; |
596 | CreationMeta creation_meta_; |
597 | |
598 | public: |
599 | /// requires_grad is a backward AD field so we only use the view specific |
600 | /// logic for backward differentiable views |
601 | bool requires_grad() const override { |
602 | return requires_grad_ || grad_fn_ || |
603 | (has_bw_view() && get_backward_view().base_.requires_grad()); |
604 | } |
605 | |
606 | bool shared_view_info() const { |
607 | return shared_view_info_; |
608 | } |
609 | |
610 | bool has_bw_view() const { |
611 | return backward_info_.has_value(); |
612 | } |
613 | |
614 | const ViewInfo& get_backward_view() const { |
615 | TORCH_CHECK( |
616 | has_bw_view(), "backward view info can only exist for backward views." ); |
617 | return backward_info_.value(); |
618 | } |
619 | |
620 | uint32_t get_attr_version() const { |
621 | TORCH_CHECK( |
622 | has_bw_view(), "attr_version can only exist for backward views." ); |
623 | return attr_version_; |
624 | } |
625 | |
626 | void set_attr_version(uint32_t new_attr_version) { |
627 | TORCH_CHECK( |
628 | has_bw_view(), "attr_version can only exist for backward views." ); |
629 | attr_version_ = new_attr_version; |
630 | } |
631 | |
632 | CreationMeta get_creation_meta() const { |
633 | TORCH_CHECK( |
634 | has_bw_view(), "creation_meta can only exist for backward views." ); |
635 | return creation_meta_; |
636 | } |
637 | |
638 | void set_creation_meta(CreationMeta new_creation_meta) { |
639 | TORCH_CHECK( |
640 | has_bw_view(), "creation_meta can only exist for backward views." ); |
641 | creation_meta_ = new_creation_meta; |
642 | } |
643 | |
644 | bool has_fw_view() const { |
645 | return shared_view_info_ || forward_info_.has_value(); |
646 | } |
647 | |
648 | const ViewInfo& get_forward_view() const { |
649 | TORCH_CHECK( |
650 | has_fw_view(), "forward view info can only exist for forward views." ); |
TORCH_CHECK(
!shared_view_info_ || has_bw_view(),
"shared forward view info requires backward view info to exist." );
654 | return shared_view_info_ ? backward_info_.value() : forward_info_.value(); |
655 | } |
656 | |
657 | DifferentiableViewMeta( |
658 | at::TensorImpl* self_impl, |
659 | c10::optional<ViewInfo> backward_info, |
660 | c10::optional<ViewInfo> forward_info, |
661 | bool shared_view_info, |
662 | CreationMeta creation_meta = CreationMeta::DEFAULT); |
663 | }; |
664 | |
665 | //~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ |
666 | // Variable Implementation |
667 | //~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ |
668 | |
669 | // Factory Functions |
670 | //~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ |
671 | |
672 | /// Creates a `Variable` that is a *view* of another (*base*) variable. |
673 | /// The `gradient_edge` is an optional (gradient_function, input_number) pair. |
674 | /// `is_differentiable` is a bool that specifies whether this view is |
675 | /// differentiable, i.e., whether the relation should be tracked by autograd. |
676 | /// See NOTE [ Autograd View Variables ] for details. |
677 | |
678 | /// NOTE: `allow_tensor_metadata_change` is set to true by default, because |
679 | /// there are a lot of call sites to these factory functions that need to change |
680 | /// the variable's size or storage afterwards, and they don't expect the |
681 | /// original tensor (where the variable is created from) to be updated. Setting |
682 | /// `allow_tensor_metadata_change_` to false by default would unnecessarily |
683 | /// prevent those changes from happening and is undesirable. |
684 | |
685 | // See NOTE [ Autograd View Variables ] for details. |
686 | // Differentiable view. Track history with DifferentiableViewMeta. |
687 | inline Variable make_variable_differentiable_view( |
688 | const at::Tensor& data, |
689 | c10::optional<ViewInfo> backward_info, |
690 | c10::optional<ViewInfo> forward_info, |
691 | bool shared_view_info, |
692 | CreationMeta creation_meta, |
693 | bool allow_tensor_metadata_change = true) { |
694 | if (data.defined()) { |
695 | TORCH_CHECK( |
696 | data.getIntrusivePtr()->autograd_meta() == nullptr, |
697 | "Attempted to make a tensor into a differentiable view, but the " |
698 | "tensor already had autograd metadata associated with it. If you are " |
699 | "using a __torch_dispatch__ mode, the most common cause for this " |
700 | "problem is that you used torch.overrides.enable_reentrant_dispatch() " |
701 | "improperly; tensors created within the extent of reentrant dispatch " |
702 | "MUST NOT be directly returned from __torch_dispatch__; instead, they " |
703 | "must be wrapped into fresh tensors that serve as the output. If you " |
704 | "are not using wrappers, you probably don't need reentrant dispatch. " |
705 | "If this doesn't seem applicable, please file a bug to PyTorch." ); |
706 | at::TensorImpl* data_impl = data.unsafeGetTensorImpl(); |
707 | data_impl->set_allow_tensor_metadata_change(allow_tensor_metadata_change); |
708 | data_impl->set_autograd_meta(std::make_unique<DifferentiableViewMeta>( |
709 | data_impl, |
710 | std::move(backward_info), |
711 | std::move(forward_info), |
712 | shared_view_info, |
713 | creation_meta)); |
714 | return data; |
715 | } |
716 | return Variable(); |
717 | } |
718 | |
719 | // See NOTE [ Autograd View Variables ] for details. |
720 | // Non-differentiable view. Just share version counter. |
721 | inline Variable make_variable_non_differentiable_view( |
722 | Variable base, |
723 | const at::Tensor& data, |
724 | bool allow_tensor_metadata_change = true) { |
725 | if (data.defined()) { |
// Currently all non-differentiable view ops (detach/_indices/_values)
// share the same TensorImpl as their base Tensor. Thus a new TensorImpl
// allocation is required here.
729 | auto data_impl_copy = data.getIntrusivePtr()->shallow_copy_and_detach( |
730 | /*version_counter=*/impl::version_counter(base), |
731 | /*allow_tensor_metadata_change=*/allow_tensor_metadata_change); |
732 | data_impl_copy->set_autograd_meta(nullptr); |
733 | return Variable(data_impl_copy); |
734 | } |
735 | return Variable(); |
736 | } |
737 | |
738 | /// Creates a `Variable` from the given `Tensor`, copying its underlying |
739 | /// `TensorImpl`. `requires_grad` should be set only for leaves, and determines |
740 | /// whether the `Variable` will accumulate gradients. NOTE: `data` must *not* be |
741 | /// a `Variable` already. Its dynamic type *must* be `Tensor`. |
742 | /// |
743 | /// TODO: Eliminate this function as much as possible, as it can be expressed |
744 | /// more clearly as detach() or a no-op in most call sites (especially when |
745 | /// there is only one use of the variable). |
746 | inline Variable make_variable( |
747 | at::Tensor data, |
748 | bool requires_grad = false, |
749 | bool allow_tensor_metadata_change = true) { |
750 | if (data.defined()) { |
751 | if (data.getIntrusivePtr().use_count() == 1 && |
752 | data.getIntrusivePtr()->unique_version()) { |
753 | auto data_impl = data.unsafeReleaseIntrusivePtr(); |
754 | data_impl->set_allow_tensor_metadata_change(allow_tensor_metadata_change); |
755 | // NOLINTNEXTLINE(bugprone-branch-clone) |
756 | if (requires_grad) { |
757 | data_impl->set_autograd_meta( |
758 | std::make_unique<AutogradMeta>(data_impl.get(), requires_grad)); |
759 | } else { |
760 | data_impl->set_autograd_meta(nullptr); |
761 | } |
762 | return Variable(std::move(data_impl)); |
763 | } else { |
764 | auto data_impl_copy = data.getIntrusivePtr()->shallow_copy_and_detach( |
765 | /*version_counter=*/0, |
766 | /*allow_tensor_metadata_change=*/allow_tensor_metadata_change); |
767 | // NOLINTNEXTLINE(bugprone-branch-clone) |
768 | if (requires_grad) { |
769 | data_impl_copy->set_autograd_meta(std::make_unique<AutogradMeta>( |
770 | data_impl_copy.get(), requires_grad)); |
771 | } else { |
772 | data_impl_copy->set_autograd_meta(nullptr); |
773 | } |
774 | return Variable(data_impl_copy); |
775 | } |
776 | } |
777 | return Variable(); |
778 | } |
779 | |
780 | /// Creates a `Variable` from the given `Tensor`, copying its underlying |
/// `TensorImpl`. `gradient_edge` should be a (function, input_nr) pair
/// specifying the function in the autograd graph, and which particular input
/// of that function this variable is connected to.
784 | inline Variable make_variable( |
785 | at::Tensor data, |
786 | Edge gradient_edge, |
787 | bool allow_tensor_metadata_change = true) { |
788 | if (data.defined()) { |
789 | auto data_impl_copy = data.getIntrusivePtr()->shallow_copy_and_detach( |
790 | /*version_counter=*/0, |
791 | /*allow_tensor_metadata_change=*/allow_tensor_metadata_change); |
792 | data_impl_copy->set_autograd_meta(std::make_unique<AutogradMeta>( |
793 | data_impl_copy.get(), false, std::move(gradient_edge))); |
794 | return Variable(data_impl_copy); |
795 | } |
796 | return Variable(); |
797 | } |
798 | |
799 | namespace utils { |
800 | |
801 | TORCH_API bool has_same_meta(const Variable& base, const Variable& other); |
802 | |
803 | } // namespace utils |
804 | } // namespace autograd |
805 | } // namespace torch |
806 | |
807 | #endif /* DOXYGEN_SHOULD_SKIP_THIS */ |
808 | |