#pragma once

#include <torch/csrc/Export.h>
#include <torch/csrc/autograd/forward_grad.h>
#include <torch/csrc/autograd/saved_variable_hooks.h>

#include <ATen/core/Tensor.h>

#include <cstdint>
#include <memory>

namespace torch {
namespace autograd {

using Variable = at::Tensor;
struct Node;

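// Error message used when a saved variable is accessed after its data has
// already been released, e.g. when backward is run through the graph a second
// time without retaining it.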
TORCH_API extern const char* ERR_BACKWARD_TWICE;

/// A snapshot of a variable at a certain version. A `SavedVariable` stores
/// enough information to reconstruct a variable from a certain point in time.
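///
/// Illustrative usage sketch (not part of this header; names like
/// `saved_input` are hypothetical):
///
///   // At forward time, snapshot an input of the Node being built:
///   SavedVariable saved_input(input, /*is_output=*/false);
///   ...
///   // During backward, recover a Variable equivalent to the snapshot.
///   // Unpacking fails if `input` was modified in-place since it was saved.
///   at::Tensor input_at_save_time = saved_input.unpack();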
class TORCH_API SavedVariable {
 public:
  SavedVariable() = default;
  SavedVariable(
      const Variable& variable,
      bool is_output,
      bool is_inplace_on_view = false);
  SavedVariable(
      const c10::optional<Variable>& variable,
      bool is_output,
      bool is_inplace_on_view = false);
  SavedVariable(SavedVariable&&) = default;
  SavedVariable& operator=(SavedVariable&&) = default;
  ~SavedVariable() {
    if (fw_grad_) {
      // See note [ Using ForwardGrad ]
      fw_grad_->clear();
    }
  }

  /// Reconstructs the saved variable. Pass `saved_for` as the gradient
  /// function if constructing the `SavedVariable` with it would have caused a
  /// circular reference.
  Variable unpack(std::shared_ptr<Node> saved_for = nullptr) const;
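
  // Illustrative sketch of the circular-reference case (names like `MyNode`
  // and `result_` are hypothetical): a SavedVariable created with
  // is_output = true for an output of the node that owns it is unpacked by
  // passing that node back in, instead of storing a shared_ptr to it:
  //
  //   variable_list MyNode::apply(variable_list&& grads) {
  //     auto result = result_.unpack(shared_from_this());
  //     ...
  //   }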

  void register_hooks(std::unique_ptr<SavedVariableHooks>&& hooks);
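
  // Illustrative sketch of a hooks object (assuming the
  // call_pack_hook/call_unpack_hook virtual interface declared in
  // saved_variable_hooks.h); `KeepAliveHooks` is hypothetical:
  //
  //   struct KeepAliveHooks : SavedVariableHooks {
  //     void call_pack_hook(const at::Tensor& t) override { t_ = t; }
  //     at::Tensor call_unpack_hook() override { return t_; }
  //     at::Tensor t_;
  //   };
  //
  //   saved.register_hooks(std::make_unique<KeepAliveHooks>());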

  void reset_data();

 private:
  // This field contains either:
  // 1. the variable to save
  // 2. or its tensor_data.
  // If storing the variable itself would create a circular reference,
  // we fall into the second case and its metadata is also saved separately.
  // In that case, the grad_fn must be passed in to the unpack function when
  // reconstructing the Variable (except when we are doing an inplace operation
  // on a view, see below). The field saved_original_ below reflects the two
  // cases: its value is true in the first case and false in the second case.
  // The value data_.defined() can be false in three cases:
  // 1. The SavedVariable was constructed without a Tensor (the value to save
  //    is None); in that case was_default_constructed_ is kept at true
  // 2. The saved variable has been released by calling
  //    SavedVariable::reset_data(), typically during the backward pass
  // 3. Hooks have been registered. In that case, hooks_ is defined instead.
  // Note that the value of saved_original_ only reflects what happened during
  // the construction of the SavedVariable: if saved_original_ is true, we
  // saved the original tensor in data_, but if the user later registers
  // hooks, we will no longer have it (even though saved_original_ stays true).
  at::Tensor data_;

  // This field is used to store the forward AD gradients associated with
  // the saved Tensor. Note that this shared_ptr must never be shared with
  // either the saved Tensor or the unpacked Tensor. See note [ Using
  // ForwardGrad ]
  std::shared_ptr<ForwardGrad> fw_grad_;

  // Weak version of grad_fn_ that prevents leaks in rebase_history() for
  // inplace views.
  // This variable is used when the user chooses to create a SavedVariable with
  // is_inplace_on_view = true.
  // In that case, the grad_fn passed in to the unpack function at unwrapping
  // time is unused.
  std::weak_ptr<Node> weak_grad_fn_;
  c10::VariableVersion version_counter_;

  uint32_t saved_version_ = 0;
  uint32_t output_nr_ = 0;
  bool was_default_constructed_ = true;
  bool is_inplace_on_view_ = false;
  bool saved_original_ = false;
  bool is_leaf_ = false;
  bool is_output_ = false;
  // Hooks are a pair of functions pack_hook/unpack_hook that provide
  // fine-grained control over how the SavedVariable should save its data.
  // pack_hook is called upon registration, while unpack_hook is called when
  // unpacking.
  std::unique_ptr<SavedVariableHooks> hooks_;
  // Fields grad_fn_, grad_accumulator_, and requires_grad_ are only used if
  // hooks are defined. They are set before pack_hook is called and used after
  // unpack_hook is called.
  std::shared_ptr<Node> grad_fn_;
  std::weak_ptr<Node> grad_accumulator_;
  bool requires_grad_ = false;

  void save_metadata(const Variable& data);
  static std::unique_ptr<SavedVariableHooks> get_default_hooks();
  void set_hooks_and_pack_data(
      std::unique_ptr<SavedVariableHooks>&& hooks,
      const Variable& data);
};
} // namespace autograd
} // namespace torch