#pragma once

#include <torch/csrc/utils/python_stub.h>

#include <torch/csrc/Export.h>
#include <torch/csrc/autograd/cpp_hook.h>
#include <torch/csrc/autograd/edge.h>
#include <torch/csrc/autograd/forward_grad.h>
#include <torch/csrc/autograd/function_hook.h>

#include <ATen/NamedTensorUtils.h>
#include <ATen/core/Tensor.h>
#include <c10/util/Exception.h>

#include <cstdint>
#include <memory>
#include <mutex>
#include <stdexcept>
#include <string>
#include <utility>
#include <vector>

namespace torch {
namespace autograd {

/// `Variable` is exactly the same as `Tensor` (i.e. we have `using Variable =
/// at::Tensor`). This means you can perform all the usual mathematical and
/// other operations you can perform on `Tensor`s also on `Variable`s.
///
/// The only reason we are keeping the `Variable` class is backward
/// compatibility with external users' legacy C++ frontend code. Our intention
/// is to eliminate the `Variable` class in the near future.
using Variable = at::Tensor;

} // namespace autograd
} // namespace torch

// The following are all internal APIs and should not be shown in libtorch docs.
// Therefore, we wrap the following code with `#ifndef DOXYGEN_SHOULD_SKIP_THIS
// ... #endif`

#ifndef DOXYGEN_SHOULD_SKIP_THIS

namespace torch {
namespace autograd {

/// Check if this type is supported by the autograd engine.
/// If you change this, update the doc at the top of the
/// torch/autograd/__init__.py file and
/// "test_set_requires_grad_only_for_continuous_types" in test/test_autograd.py
static inline bool isDifferentiableType(at::ScalarType t) {
  return isFloatingType(t) || isComplexType(t);
}
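
// A quick illustration (these follow directly from the definition above):
//
//   isDifferentiableType(at::kFloat)         // true  (floating point)
//   isDifferentiableType(at::kComplexFloat)  // true  (complex)
//   isDifferentiableType(at::kLong)          // false (integral)
//   isDifferentiableType(at::kBool)          // false (boolean)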

struct Node;

///~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
///                                Variable
///~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
/// A `Variable` augments a `Tensor` with the ability to interact in our
/// autograd machinery. Conceptually, `Variable`s travel along `Edge`s between
/// `Node`s in the autograd graph. A `Variable` can either be a leaf, like a
/// weight in a neural network, or an interior variable, when it is the result
/// of an operation between variables. Every `Variable` also stores another
/// `Variable` called its `grad` (gradient). If the variable is a leaf, its
/// gradient will be accumulated into this variable.
///
/// Every Tensor is a Variable, but sometimes we colloquially refer to Variables
/// that don't require gradients as Tensors (since none of the autograd
/// machinery for Variables applies). Historically, Variables and Tensors
/// were separate concepts, but now they are exactly the same (i.e. we have
/// `using Variable = at::Tensor`).
///
///                              Gradient Edges
///~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
/// Furthermore, `Variable`s have the notion of a `gradient_edge`, which is the
/// edge in the autograd graph that connects the variable to a particular input
/// of the gradient function that will be invoked with the variable during the
/// backward pass. More precisely, this gradient function can be one of two
/// things:
/// 1. A `grad_fn`, if the variable is in the interior of the graph. This is the
///    gradient of the function that produced the variable.
/// 2. A `grad_accumulator`, if the variable is a leaf, which accumulates a
///    scalar gradient value into its `grad` variable.
///
///                                Versioning
///~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
/// Another major feature of `Variable`s is *versions*. Versions are
/// incremented when an in-place mutation of a variable occurs. Versions are
/// useful when constructing `SavedVariable`s, which take a snapshot of a
/// `Variable` at a certain version. You can retrieve a `Variable`'s version
/// through its `current_version()` method.
///
///                                   Views
///~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
/// It is possible for a `Variable` to be a *view* of another `Variable`, in
/// which case it tracks that `Variable`'s data and autograd history. Beyond
/// construction, the interface of a view is identical to that of a regular
/// `Variable`. You can determine whether a `Variable` is in fact a view by
/// probing its `is_view()` method. Note that the *view* semantics are only
/// meaningful for `Variable` relations that are relevant to autograd.
/// See NOTE [ Autograd View Variables ] for more details.
///~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
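
// As a brief illustration of the leaf/interior distinction above, a minimal
// sketch using the public C++ frontend (assumes <torch/torch.h>; the shapes
// are arbitrary):
//
//   auto w = torch::randn({2, 2}, torch::requires_grad());  // leaf Variable
//   auto y = w * w;                    // interior: produced by an operation
//   // w.grad_fn() == nullptr; y.grad_fn() is the backward Node of the mul
//   y.sum().backward();
//   // w.grad() now holds the gradient, accumulated by w's grad_accumulator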

struct AutogradMeta;
struct DifferentiableViewMeta;

// Private-ish functions for manipulating variables; we don't want to put them
// on Tensor proper
namespace impl {

// WARNING: This may return a nullptr. If you require AutogradMeta to return
// a materialized structure, use materialize_autograd_meta instead.
TORCH_API AutogradMeta* get_autograd_meta(const at::TensorBase&);

// WARNING: This will return a nullptr if the Tensor is not a view.
TORCH_API DifferentiableViewMeta* get_view_autograd_meta(const at::TensorBase&);

// Returns the current autograd meta, materializing it if it was previously
// none. This counts as a *mutating* operation, so do not call it on
// "read-only" operators; in particular, this is NOT thread safe
TORCH_API AutogradMeta* materialize_autograd_meta(const at::TensorBase&);

/// Set the gradient accumulator of the `Variable`. This is only applicable to
/// leaf variables. Interior variables should call `set_gradient_edge()`.
TORCH_API void set_grad_accumulator(
    const Variable&,
    std::weak_ptr<Node> grad_accumulator);

/// Attempts to get a pointer to the gradient accumulator of the `Variable`,
/// if it still exists. If the gradient accumulator function has been
/// destroyed, returns a `nullptr`.
TORCH_API std::shared_ptr<Node> try_get_grad_accumulator(const Variable&);

/// Gets the gradient accumulator of the `Variable` if it has one, or else
/// creates one on the fly and returns it.
TORCH_API std::shared_ptr<Node> grad_accumulator(const Variable&);

/// Returns the "canonical" gradient edge of this `Variable`, i.e. either the
/// gradient function if this is an interior `Variable`, or the gradient
/// accumulator otherwise. If the `Variable` is interior, the returned `Edge`
/// will store the input index of the `Node` to which this variable is
/// connected in its `input_nr` field. For leaves, the `input_nr` is always
/// zero. Note that `set_gradient_edge` and `gradient_edge` are not
/// symmetric. You must use `set_gradient_edge` to set the `grad_fn` and
/// `set_grad_accumulator` to set the accumulator.
TORCH_API Edge gradient_edge(const Variable&);
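
// A hedged sketch of how the canonical edge is typically consumed (internal
// API; `var` and `e` below are purely illustrative names):
//
//   Edge e = impl::gradient_edge(var);
//   // Interior variable: e.function is var's grad_fn, and e.input_nr is the
//   //                    input index of that Node (i.e. var's output_nr).
//   // Leaf variable:     e.function is var's grad_accumulator, and
//   //                    e.input_nr is 0.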

/// Set the gradient edge -- i.e. `grad_fn` and `input_nr` -- of the
/// `Variable`.
/// NOTE: This will always set the `grad_fn`, even if this is a leaf variable,
/// and never the `grad_accumulator`. For the latter, use
/// `set_grad_accumulator`. This allows late construction of an interior
/// `Variable`.
TORCH_API void set_gradient_edge(const Variable&, Edge edge);

// Autograd Graph Interaction
//~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

/// Update the `grad_fn` of an existing Variable. Called after in-place
/// modifications.
///
/// For View Variables:
/// Called after in-place modifications. Modifies the grad_fn of the base
/// Variable.
TORCH_API void rebase_history(const Variable&, Edge gradient_edge);

/// Gets the raw gradient function pointer, whatever it currently is.
TORCH_API Node* grad_fn_unsafe(const Variable&);

/// Increments the version count of this `Variable`.
TORCH_API void bump_version(const Variable&);
TORCH_API void set_version_counter(
    const Variable&,
    const c10::VariableVersion& version_counter);

/// Retrieves this `Variable`'s version counter.
TORCH_API const c10::VariableVersion& version_counter(const Variable&);

TORCH_API void set_name(const Variable&, const std::string& name);

TORCH_API void add_hook(
    const at::TensorBase&,
    std::unique_ptr<FunctionPreHook> hook);
TORCH_API std::vector<std::unique_ptr<FunctionPreHook>>& hooks(const Variable&);
TORCH_API void clear_hooks(const at::TensorBase&);

TORCH_API void create_cpp_hook(
    const at::TensorBase&,
    bool is_retains_grad_hooks = false);
} // namespace impl

//~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
//                              AutogradMeta
//~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

/// Each `Variable` has one unique `AutogradMeta` struct, which stores autograd
/// metadata fields that are necessary for tracking the Variable's autograd
/// history. As an optimization, a Variable may store a nullptr in lieu of a
/// default-constructed AutogradMeta.

struct TORCH_API AutogradMeta : public c10::AutogradMetaInterface {
  std::string name_;

  Variable grad_;
  std::shared_ptr<Node> grad_fn_;
  std::weak_ptr<Node> grad_accumulator_;

  // This field is used to store all the forward AD gradients
  // associated with this AutogradMeta (and the Tensor it corresponds to).
  // There is a semantic 1:1 correspondence between AutogradMeta and
  // ForwardGrad but:
  //   - This field is lazily populated.
  //   - This field is a shared_ptr but it must never be
  //     shared by multiple Tensors. See Note [ Using ForwardGrad ]
  // Any transition from not_initialized to initialized
  // must be protected by mutex_
  std::shared_ptr<ForwardGrad> fw_grad_;

  // The hooks_ field is actually reused by both python and cpp logic.
  // For both cases, we have a data structure, cpp_hooks_list_ (cpp)
  // or dict (python), which is the canonical copy.
  // Then, for both cases, we always register a single hook to
  // hooks_ which wraps all the hooks in the list/dict.
  // And, again in both cases, if the grad_fn exists on that tensor
  // we will additionally register a single hook to the grad_fn.
  //
  // Note that the cpp and python use cases aren't actually aware of
  // each other, so using both at the same time is undefined behavior.
  std::vector<std::unique_ptr<FunctionPreHook>> hooks_;
  std::shared_ptr<hooks_list> cpp_hooks_list_;

  // Only meaningful on leaf variables (must be false otherwise)
  bool requires_grad_{false};

  // Only meaningful on non-leaf variables (must be false otherwise)
  bool retains_grad_{false};

  bool is_view_{false};

  // The "output number" of this variable; e.g., if this variable
  // was the second output of a function, then output_nr == 1.
  // We use this to make sure we can set up the backward trace
  // correctly when this variable is passed to another function.
  uint32_t output_nr_;

  // Mutex to ensure that concurrent read operations that modify internal
  // state are still thread-safe. Used by grad_fn(), grad_accumulator(),
  // fw_grad() and set_fw_grad().
  // This is mutable because we need to be able to acquire it from const
  // versions of this class for the functions above.
  mutable std::mutex mutex_;

  /// Sets the `requires_grad` property of `Variable`. This should be true for
  /// leaf variables that want to accumulate gradients, and false for all other
  /// variables.
  void set_requires_grad(bool requires_grad, at::TensorImpl* self_impl)
      override {
    TORCH_CHECK(
        !requires_grad ||
            isDifferentiableType(at::typeMetaToScalarType(self_impl->dtype())),
        "Only Tensors of floating point and complex dtype can require gradients");
    requires_grad_ = requires_grad;
  }

  bool requires_grad() const override {
    return requires_grad_ || grad_fn_;
  }

  /// Accesses the gradient `Variable` of this `Variable`.
  Variable& mutable_grad() override {
    return grad_;
  }

  const Variable& grad() const override {
    return grad_;
  }

  const Variable& fw_grad(uint64_t level, const at::TensorBase& self)
      const override;

  void set_fw_grad(
      const at::TensorBase& new_grad,
      const at::TensorBase& self,
      uint64_t level,
      bool is_inplace_op) override;

  AutogradMeta(
      at::TensorImpl* self_impl = nullptr,
      bool requires_grad = false,
      Edge gradient_edge = Edge())
      : grad_fn_(std::move(gradient_edge.function)),
        output_nr_(gradient_edge.input_nr) {
    // set_requires_grad also checks error conditions.
    if (requires_grad) {
      TORCH_INTERNAL_ASSERT(self_impl);
      // NOLINTNEXTLINE(clang-analyzer-optin.cplusplus.VirtualCall)
      set_requires_grad(requires_grad, self_impl);
    }
    TORCH_CHECK(
        !grad_fn_ || !requires_grad_,
        "requires_grad should be false if grad_fn is set");
  }

  ~AutogradMeta() override {
    // If AutogradMeta is being destroyed, it means that there is no other
    // reference to its corresponding Tensor. It implies that no other thread
    // can be using this object and so there is no need to lock mutex_ here to
    // guard the check if fw_grad_ is populated.
    if (fw_grad_) {
      // See note [ Using ForwardGrad ]
      fw_grad_->clear();
    }
  }
};
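
// A minimal illustration of the requires_grad() semantics above, using the
// public C++ frontend (assumes <torch/torch.h>):
//
//   auto a = torch::zeros({3});                          // requires_grad() == false
//   auto b = torch::zeros({3}, torch::requires_grad());  // marked leaf: true
//   auto c = a + b;               // has a grad_fn, so requires_grad() == true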

struct TORCH_API ViewInfo {
  /// The base `Variable`.
  /// If this ViewInfo represents a forward (respectively backward) AD gradient,
  /// then this Tensor cannot be a forward (respectively backward) view.
  Variable base_;

  /// By default we use as_strided to recover views, which is more efficient.
  /// view_fn is only saved when as_strided is not supported.
  /// If view_fn has a value, we use it to recover views in backward.
  std::function<Variable(const Variable&)> view_fn_;

  /// Accessors for the view function
  bool has_view_fn() const {
    return view_fn_ != nullptr;
  }

  std::function<Variable(const Variable&)> view_fn() const {
    TORCH_CHECK(
        has_view_fn(), "Can only access the view function if it exists.");
    return view_fn_;
  }

  /// The chain function can be used to build a new ViewInfo for a
  /// differentiable view function. It will return a new view info that
  /// accurately represents how "tensor" is a view of this instance's "base_".
  /// The "base" and "tensor" are respectively the input and output of the
  /// differentiable view function that happened. They are required to properly
  /// set the optional view_fn_ when it is not provided. The "view_func", if
  /// provided, should be a function that re-applies the view between
  /// "base" and "tensor".
  ViewInfo chain(
      const Variable& base,
      const Variable& tensor,
      std::function<Variable(const Variable&)> view_func = nullptr) const;

  ViewInfo(Variable base, std::function<Variable(const Variable&)> view_fn)
      : base_(std::move(base)), view_fn_(std::move(view_fn)) {
    TORCH_CHECK(base_.defined(), "base is undefined");
  }
};
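
// A hedged sketch of what view_fn_ looks like when as_strided cannot be used
// to recover a view (the narrow() call is purely illustrative, and `base` is
// assumed to be a defined Variable):
//
//   std::function<Variable(const Variable&)> view_fn =
//       [](const Variable& b) { return b.narrow(0, 0, 1); };
//   ViewInfo info(base, std::move(view_fn));
//   if (info.has_view_fn()) {
//     Variable recovered = info.view_fn()(base);  // re-applies the view on base
//   }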

//~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
//                          DifferentiableViewMeta
//~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

/// NOTE [ Autograd View Variables ]
///
/// Many operations return a Variable that shares storage with an input
/// Variable. The returned Variable is called a **view** Variable on the input
/// **base** Variable.
///
/// In PyTorch, we have two types of views: differentiable views, and
/// non-differentiable views. In either type, to support proper version
/// checking, the base and view Variables must always share the same
/// version_counter.
///
///
/// Differentiable Views
/// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
/// This class allows tracking both forward and backward AD differentiable
/// views. The two views can have different bases, because the set of
/// non-differentiable view functions is not the same for forward and backward
/// mode AD.
///
/// Most functions are either both forward and backward differentiable views
/// (for example: view, select, narrow, transpose, etc) or neither forward nor
/// backward differentiable views (for example: indices, values, eq, lt, etc).
/// But there are also functions that are forward but not backward
/// differentiable views (only detach for now) or functions that are backward
/// but not forward differentiable views (only make_dual and unpack_dual for
/// now).
///
/// A concrete example of two views with different bases is as follows:
///
///     # Have:
///     #   dual is a dual Tensor that is neither a forward nor a backward view
///     detached_dual = dual.detach()
///     view = detached_dual.view_as(dual)
///     # The forward base of view is dual
///     # The backward base of view is detached_dual
///
/// - Backward Mode View
/// Differentiable views are the view variables where you want gradients to
/// flow back to the base variables. Out-of-place operations on views are quite
/// straightforward, but in-place ones are very tricky. Even if the base
/// variable may not require grad when we create the view, we still need to
/// track the view relation because future in-place ops may require
/// backpropagating through it. For example, we need to support
///
///   (1) in-place operation on view, e.g.,
///
///     # Have:
///     #   base.requires_grad = False
///     #   var.requires_grad = True
///     base[1] = var  # i.e., base[1].copy_(var)
///     torch.autograd.grad(base.sum(), var)  # <- should return an all ones
///                                           #    tensor
///
///   (2) in-place operation on base after view is created, e.g.,
///
///     # Have:
///     #   base.requires_grad = False
///     #   var.requires_grad = True
///     view = base[1]
///     base.copy_(var)
///     torch.autograd.grad(view.sum(), var)  # <- should return a tensor with
///                                           #    var[1] filled with all ones
///                                           #    and zeros everywhere else
///
/// - Forward Mode View
/// Forward differentiable views follow the same semantics as backward ones but
/// show up differently, as they are computed along with the forward
/// evaluation. The hard examples above are thus very similar:
///
///   (1) in-place operation on view, e.g.,
///
///     # Have:
///     #   base is a regular Tensor
///     #   var is a dual Tensor whose tangent is all ones
///     base[1] = var  # i.e., base[1].copy_(var)
///     # Now, base is a dual Tensor
///     _, fw_grad = fwAD.unpack_dual(base)  # <- fw_grad should be a tensor
///                                          #    with fw_grad[1] filled with
///                                          #    all ones and zeros everywhere
///                                          #    else
///
///   (2) in-place operation on base after view is created, e.g.,
///
///     # Have:
///     #   base is a regular Tensor
///     #   var is a dual Tensor whose tangent is all ones
///     view = base[1]
///     base.copy_(var)
///     _, fw_grad = fwAD.unpack_dual(view)  # <- fw_grad should be an all ones
///                                          #    tensor
///
/// See Note [Forward Grad View/inplace] for more details on how we handle
/// these hard cases.
///
///
/// DifferentiableViewMeta is created to support gradient tracking of
/// such **in-place** operations. In particular,
///   + if an in-place op is done on base, the grad_fn field of the view may
///     become stale. So accesses should always go through grad_fn(), which
///     reconstructs an updated grad_fn if the version_counter has incremented.
///     All other fields are always valid.
///   + if an in-place op is done on view, in rebase_history() of view, which is
///     called after every in-place op in VariableType.cpp, the grad_fn of base
///     is updated.
///   + if a single autograd Node returns multiple differentiable views, and any
///     of those outputs is later modified in-place, the autograd engine will
///     replace the original grad_fn with an equivalent graph in which each
///     output is treated as if it were produced by a distinct view operation.
///     This discards the original (e.g., user-provided) grad_fn. If the
///     provided grad_fn does more than the backward of the view, then the
///     DifferentiableViewMeta must be created with
///     creation_meta=CreationMeta::MULTI_OUTPUT_NODE to prevent the engine
///     from ignoring the provided grad_fn.
///
/// Interaction with GradMode:
/// The particular case that we consider here is:
///
///     # Have:
///     #   base.requires_grad = True or False
///     with torch.no_grad():
///         view = base[1]
///     base.requires_grad_()
///     view.copy_(var)
///     torch.autograd.grad(base.sum(), var)  # <- what should it return?
///
/// Given that this particular code example is ambiguous and can easily be
/// replaced by either moving both operations inside the no_grad block or both
/// outside, we explicitly forbid it. For now, it is only deprecated with a
/// warning. This is achieved by setting creation_meta=CreationMeta::NO_GRAD_MODE
/// for all differentiable views created in no_grad mode.
///
/// See Note [View + Inplace update for base tensor]
/// and Note [View + Inplace update for view tensor] for details on how
/// autograd handles inplace updates with view ops.
///
/// Non-Differentiable Views
/// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
/// In certain cases, although function outputs share storage with inputs, they
/// will **never** require gradient history tracking. Instead of registering
/// the view relation via DifferentiableViewMeta in autograd, such views use
/// the usual AutogradMeta and just share the version counter with the base
/// Variable.
/// Such views include:
///   1. Views created from .detach()
///   2. Views that are non-differentiable by their nature.
///      E.g., `sparse_tensor.indices()` is an integral view on a (possibly)
///      floating point tensor.
///      See top of `derivatives.yaml` on how to specify that outputs of a
///      function are non-differentiable.
/// These are called non-differentiable views as the gradients do not flow
/// through the view relation.
///
/// Relevant logic for both differentiable and non-differentiable views is
/// implemented in make_variable_(non_)differentiable_view below, and
/// wrap_output of gen_variable_type.py.
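
// A short C++ sketch of backward-mode case (1) above, mirroring the Python
// snippets in this note (assumes <torch/torch.h>; the shapes are illustrative):
//
//   auto base = torch::zeros({2, 2});                    // requires_grad == false
//   auto var = torch::ones({2}, torch::requires_grad());
//   auto view = base.select(0, 1);       // differentiable view of base
//   view.copy_(var);                     // in-place op on the view
//   auto grads = torch::autograd::grad({base.sum()}, {var});
//   // grads[0] is a 2-element tensor of ones.
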
/// NOTE [ View + Inplace detection ]
///
/// We want to detect views followed by inplace as they are often forbidden to
/// ensure correctness of the computed gradients. But since we want to only
/// notify the user when both happen, we tag the DifferentiableViewMeta when the
/// view is created via the `make_variable_*_view()` functions. This tag is then
/// checked by the `check_inplace()` function from `VariableTypeUtils.h`, which
/// should be called before every inplace operation. To detect cases where
/// other views are modified and this one is rebased by side effect, we also
/// check in `VariableHooks::grad_fn()`.

/// Flag that gives more information about when this view was created:
/// - IN_CUSTOM_FUNCTION should be set when the view is created inside a custom
///   autograd Function and is returned from it.
/// - NO_GRAD_MODE should be set when a view is created while GradMode is
///   disabled.
/// - MULTI_OUTPUT_NODE should be set when a Node created by codegen code
///   returns multiple differentiable views.
/// - INFERENCE_MODE should be set when a view of a normal tensor is created in
///   InferenceMode.
/// - DEFAULT is for all other cases.
enum class CreationMeta : uint8_t {
  DEFAULT,
  IN_CUSTOM_FUNCTION,
  MULTI_OUTPUT_NODE,
  NO_GRAD_MODE,
  INFERENCE_MODE
};

/// Handles correctly propagating CreationMeta when a new view is created from
/// a previous view. In general, we don't want the new view to be _less_
/// restrictive than the previous view (it's okay to be _more_ restrictive). A
/// CreationMeta value of DEFAULT is currently the least restrictive, as the
/// behavior for all other CreationMeta values is to error out for in-place
/// ops. A CreationMeta value of INFERENCE_MODE is currently the most
/// restrictive, so it takes precedence in propagation. If this changes, the
/// logic here will need to be updated to properly handle the new semantics.
inline CreationMeta propagate_creation_meta(
    CreationMeta prev_view_creation_meta,
    CreationMeta new_view_creation_meta) {
  return (new_view_creation_meta == CreationMeta::DEFAULT)
      ? prev_view_creation_meta
      : (prev_view_creation_meta == CreationMeta::INFERENCE_MODE
             ? prev_view_creation_meta
             : new_view_creation_meta);
}
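
// Illustrative outcomes of the propagation rule above (these follow directly
// from the definition):
//
//   propagate_creation_meta(CreationMeta::NO_GRAD_MODE, CreationMeta::DEFAULT)
//       == CreationMeta::NO_GRAD_MODE       // the new view keeps the restriction
//   propagate_creation_meta(CreationMeta::INFERENCE_MODE, CreationMeta::NO_GRAD_MODE)
//       == CreationMeta::INFERENCE_MODE     // INFERENCE_MODE takes precedence
//   propagate_creation_meta(CreationMeta::DEFAULT, CreationMeta::MULTI_OUTPUT_NODE)
//       == CreationMeta::MULTI_OUTPUT_NODE  // non-DEFAULT new value wins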

/// Unified function to handle error checking when a rebase happens.
/// indirect=true means that the caller is not doing the inplace itself, but
/// the inplace happened somewhere else.
TORCH_API void handle_view_on_rebase(
    DifferentiableViewMeta* diff_view_meta,
    bool indirect = false);

struct TORCH_API DifferentiableViewMeta : public AutogradMeta {
 private:
  /// Information about the views
  c10::optional<ViewInfo> backward_info_;
  c10::optional<ViewInfo> forward_info_;

  // Optimization to reduce the number of ViewInfo we create.
  // In the (very common) case where backward_info_ == forward_info_, we only
  // populate backward_info_ (which should be used as both the forward and
  // backward view information) and set shared_view_info_ = true. Invariants:
  //   - If shared_view_info_ is false, there are no special constraints on
  //     backward_info_ and forward_info_.
  //   - If shared_view_info_ is true, we must have:
  //     - backward_info_.has_value() == true
  //     - forward_info_.has_value() == false
  bool shared_view_info_;

  /// The two following fields are extra information that we track to ensure
  /// that any operation on this backward view is valid.

  /// The value of the version_counter at the time grad_fn was created. The
  /// grad_fn field is stale if attr_version_ !=
  /// version_counter.current_version().
  uint32_t attr_version_;
  CreationMeta creation_meta_;

 public:
  /// requires_grad is a backward AD field so we only use the view specific
  /// logic for backward differentiable views
  bool requires_grad() const override {
    return requires_grad_ || grad_fn_ ||
        (has_bw_view() && get_backward_view().base_.requires_grad());
  }

  bool shared_view_info() const {
    return shared_view_info_;
  }

  bool has_bw_view() const {
    return backward_info_.has_value();
  }

  const ViewInfo& get_backward_view() const {
    TORCH_CHECK(
        has_bw_view(), "backward view info can only exist for backward views.");
    return backward_info_.value();
  }

  uint32_t get_attr_version() const {
    TORCH_CHECK(
        has_bw_view(), "attr_version can only exist for backward views.");
    return attr_version_;
  }

  void set_attr_version(uint32_t new_attr_version) {
    TORCH_CHECK(
        has_bw_view(), "attr_version can only exist for backward views.");
    attr_version_ = new_attr_version;
  }

  CreationMeta get_creation_meta() const {
    TORCH_CHECK(
        has_bw_view(), "creation_meta can only exist for backward views.");
    return creation_meta_;
  }

  void set_creation_meta(CreationMeta new_creation_meta) {
    TORCH_CHECK(
        has_bw_view(), "creation_meta can only exist for backward views.");
    creation_meta_ = new_creation_meta;
  }

  bool has_fw_view() const {
    return shared_view_info_ || forward_info_.has_value();
  }

  const ViewInfo& get_forward_view() const {
    TORCH_CHECK(
        has_fw_view(), "forward view info can only exist for forward views.");
    TORCH_CHECK(
        !shared_view_info_ || has_bw_view(),
        "forward view info can only exist for forward views.");
    return shared_view_info_ ? backward_info_.value() : forward_info_.value();
  }

  DifferentiableViewMeta(
      at::TensorImpl* self_impl,
      c10::optional<ViewInfo> backward_info,
      c10::optional<ViewInfo> forward_info,
      bool shared_view_info,
      CreationMeta creation_meta = CreationMeta::DEFAULT);
};
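
// A hedged sketch of how this metadata is typically queried (internal API;
// `tensor` below is assumed to be a defined at::Tensor):
//
//   if (auto* meta = impl::get_view_autograd_meta(tensor)) {
//     if (meta->has_bw_view()) {
//       const ViewInfo& info = meta->get_backward_view();
//       // info.base_ is the backward-mode base of `tensor`
//     }
//   }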

//~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
//                        Variable Implementation
//~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

// Factory Functions
//~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

/// Creates a `Variable` that is a *view* of another (*base*) variable.
/// The `gradient_edge` is an optional (gradient_function, input_number) pair.
/// `is_differentiable` is a bool that specifies whether this view is
/// differentiable, i.e., whether the relation should be tracked by autograd.
/// See NOTE [ Autograd View Variables ] for details.

/// NOTE: `allow_tensor_metadata_change` is set to true by default, because
/// there are a lot of call sites to these factory functions that need to change
/// the variable's size or storage afterwards, and they don't expect the
/// original tensor (where the variable is created from) to be updated. Setting
/// `allow_tensor_metadata_change_` to false by default would unnecessarily
/// prevent those changes from happening and is undesirable.

// See NOTE [ Autograd View Variables ] for details.
// Differentiable view. Track history with DifferentiableViewMeta.
inline Variable make_variable_differentiable_view(
    const at::Tensor& data,
    c10::optional<ViewInfo> backward_info,
    c10::optional<ViewInfo> forward_info,
    bool shared_view_info,
    CreationMeta creation_meta,
    bool allow_tensor_metadata_change = true) {
  if (data.defined()) {
    TORCH_CHECK(
        data.getIntrusivePtr()->autograd_meta() == nullptr,
        "Attempted to make a tensor into a differentiable view, but the "
        "tensor already had autograd metadata associated with it. If you are "
        "using a __torch_dispatch__ mode, the most common cause for this "
        "problem is that you used torch.overrides.enable_reentrant_dispatch() "
        "improperly; tensors created within the extent of reentrant dispatch "
        "MUST NOT be directly returned from __torch_dispatch__; instead, they "
        "must be wrapped into fresh tensors that serve as the output. If you "
        "are not using wrappers, you probably don't need reentrant dispatch. "
        "If this doesn't seem applicable, please file a bug to PyTorch.");
    at::TensorImpl* data_impl = data.unsafeGetTensorImpl();
    data_impl->set_allow_tensor_metadata_change(allow_tensor_metadata_change);
    data_impl->set_autograd_meta(std::make_unique<DifferentiableViewMeta>(
        data_impl,
        std::move(backward_info),
        std::move(forward_info),
        shared_view_info,
        creation_meta));
    return data;
  }
  return Variable();
}

// See NOTE [ Autograd View Variables ] for details.
// Non-differentiable view. Just share version counter.
inline Variable make_variable_non_differentiable_view(
    Variable base,
    const at::Tensor& data,
    bool allow_tensor_metadata_change = true) {
  if (data.defined()) {
    // Currently all of the non-differentiable view ops (detach/_indices/
    // _values) share the same TensorImpl as their base Tensor. Thus a new
    // TensorImpl allocation here is required.
    auto data_impl_copy = data.getIntrusivePtr()->shallow_copy_and_detach(
        /*version_counter=*/impl::version_counter(base),
        /*allow_tensor_metadata_change=*/allow_tensor_metadata_change);
    data_impl_copy->set_autograd_meta(nullptr);
    return Variable(data_impl_copy);
  }
  return Variable();
}

/// Creates a `Variable` from the given `Tensor`, copying its underlying
/// `TensorImpl`. `requires_grad` should be set only for leaves, and determines
/// whether the `Variable` will accumulate gradients. NOTE: `data` must *not* be
/// a `Variable` already. Its dynamic type *must* be `Tensor`.
///
/// TODO: Eliminate this function as much as possible, as it can be expressed
/// more clearly as detach() or a no-op in most call sites (especially when
/// there is only one use of the variable).
inline Variable make_variable(
    at::Tensor data,
    bool requires_grad = false,
    bool allow_tensor_metadata_change = true) {
  if (data.defined()) {
    if (data.getIntrusivePtr().use_count() == 1 &&
        data.getIntrusivePtr()->unique_version()) {
      auto data_impl = data.unsafeReleaseIntrusivePtr();
      data_impl->set_allow_tensor_metadata_change(allow_tensor_metadata_change);
      // NOLINTNEXTLINE(bugprone-branch-clone)
      if (requires_grad) {
        data_impl->set_autograd_meta(
            std::make_unique<AutogradMeta>(data_impl.get(), requires_grad));
      } else {
        data_impl->set_autograd_meta(nullptr);
      }
      return Variable(std::move(data_impl));
    } else {
      auto data_impl_copy = data.getIntrusivePtr()->shallow_copy_and_detach(
          /*version_counter=*/0,
          /*allow_tensor_metadata_change=*/allow_tensor_metadata_change);
      // NOLINTNEXTLINE(bugprone-branch-clone)
      if (requires_grad) {
        data_impl_copy->set_autograd_meta(std::make_unique<AutogradMeta>(
            data_impl_copy.get(), requires_grad));
      } else {
        data_impl_copy->set_autograd_meta(nullptr);
      }
      return Variable(data_impl_copy);
    }
  }
  return Variable();
}
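
// A minimal usage sketch of the factory above (internal API; the tensor
// creation is illustrative):
//
//   at::Tensor t = at::ones({2, 2});
//   Variable v = make_variable(std::move(t), /*requires_grad=*/true);
//   // v is a leaf Variable that will accumulate gradients into v.grad()
//   // once backward() is run on a graph that uses it.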

/// Creates a `Variable` from the given `Tensor`, copying its underlying
/// `TensorImpl`. `gradient_edge` should be a (function, input_nr) pair
/// specifying the function in the autograd graph, and which particular input
/// of that function this variable is connected to.
inline Variable make_variable(
    at::Tensor data,
    Edge gradient_edge,
    bool allow_tensor_metadata_change = true) {
  if (data.defined()) {
    auto data_impl_copy = data.getIntrusivePtr()->shallow_copy_and_detach(
        /*version_counter=*/0,
        /*allow_tensor_metadata_change=*/allow_tensor_metadata_change);
    data_impl_copy->set_autograd_meta(std::make_unique<AutogradMeta>(
        data_impl_copy.get(), false, std::move(gradient_edge)));
    return Variable(data_impl_copy);
  }
  return Variable();
}

namespace utils {

TORCH_API bool has_same_meta(const Variable& base, const Variable& other);

} // namespace utils
} // namespace autograd
} // namespace torch

#endif /* DOXYGEN_SHOULD_SKIP_THIS */