1 | #pragma once |
2 | |
3 | #include <torch/csrc/Export.h> |
4 | #include <torch/csrc/autograd/forward_grad.h> |
5 | #include <torch/csrc/autograd/saved_variable_hooks.h> |
6 | |
7 | #include <ATen/core/Tensor.h> |
8 | |
9 | #include <cstdint> |
10 | #include <memory> |
11 | |
12 | namespace torch { |
13 | namespace autograd { |
14 | |
15 | using Variable = at::Tensor; |
16 | struct Node; |
17 | |
18 | TORCH_API extern const char* ERR_BACKWARD_TWICE; |
19 | |
20 | /// A snapshot of a variable at a certain version. A `SavedVariable` stores |
21 | /// enough information to reconstruct a variable from a certain point in time. |
22 | class TORCH_API SavedVariable { |
23 | public: |
24 | SavedVariable() = default; |
25 | SavedVariable( |
26 | const Variable& variable, |
27 | bool is_output, |
28 | bool is_inplace_on_view = false); |
29 | SavedVariable( |
30 | const c10::optional<Variable>& variable, |
31 | bool is_output, |
32 | bool is_inplace_on_view = false); |
33 | SavedVariable(SavedVariable&&) = default; |
34 | SavedVariable& operator=(SavedVariable&&) = default; |
35 | ~SavedVariable() { |
36 | if (fw_grad_) { |
37 | // See note [ Using ForwardGrad ] |
38 | fw_grad_->clear(); |
39 | } |
40 | } |
41 | |
42 | /// Reconstructs the saved variable. Pass `saved_for` as the gradient |
43 | /// function if constructing the `SavedVariable` with it would have caused a |
44 | /// circular reference. |
45 | Variable unpack(std::shared_ptr<Node> saved_for = nullptr) const; |
46 | |
47 | void register_hooks(std::unique_ptr<SavedVariableHooks>&& hooks); |
48 | |
49 | void reset_data(); |
50 | |
51 | private: |
52 | // This field contains either: |
53 | // 1. the variable to save |
54 | // 2. or its tensor_data. |
55 | // If storing the variable itself would create a circular reference, |
56 | // we fall into the second case and its metadata is also saved separately. |
57 | // In that case, the grad_fn must be passed in to the unpack function when |
58 | // reconstructing the Variable (except when we are doing an inplace operation |
59 | // on a view, see below). The field saved_orignal_ below reflects the two |
60 | // cases: its value is true in the first case and false in the second case. |
61 | // The value data_.defined() can be false in three cases: |
62 | // 1. SavedVariable was constructed without a Tensor (the value to save is |
63 | // None), in that case was_default_constructed_ will be kept at true |
64 | // 2. The saved variable has been released by calling |
65 | // SavedVariable::reset_data(), typically during the backward pass |
66 | // 3. Hooks have been registered. In that case, hooks_ will be defined |
67 | // instead. Note that the value of saved_original_ only reflects what happened |
68 | // during the construction of the SavedVariable. If saved_original_ is true, |
69 | // we saved the original tensor in data_, but if the user registers hooks, we |
70 | // will no longer have it (despite the saved_original_ still being true) |
71 | at::Tensor data_; |
72 | |
73 | // This field is used to store the forward AD gradients associated with |
74 | // the saved Tensor. Note that this shared_ptr must never be shared with |
75 | // either the saved Tensor or the unpacked Tensor. See note [ Using |
76 | // ForwardGrad ] |
77 | std::shared_ptr<ForwardGrad> fw_grad_; |
78 | |
79 | // Weak version of grad_fn_ that prevents leaks in rebase_history() for |
80 | // inplace views. |
81 | // This variable is used when the user chooses to create a SavedVariable with |
82 | // is_inplace_on_view = true. |
83 | // In that case, the grad_fn passed in to the unpack function at unwrapping |
84 | // time is unused. |
85 | std::weak_ptr<Node> weak_grad_fn_; |
86 | c10::VariableVersion version_counter_; |
87 | |
88 | uint32_t saved_version_ = 0; |
89 | uint32_t output_nr_ = 0; |
90 | bool was_default_constructed_ = true; |
91 | bool is_inplace_on_view_ = false; |
92 | bool saved_original_ = false; |
93 | bool is_leaf_ = false; |
94 | bool is_output_ = false; |
95 | |
96 | // Hooks are a pair of functions pack_hook/unpack_hook that provides |
97 | // fine-grained control over how the SavedVariable should save its data. |
98 | // pack_hook is called upon registration, while unpack_hook is called when |
99 | // unpacking. |
100 | std::unique_ptr<SavedVariableHooks> hooks_; |
101 | // Fields grad_fn_, grad_accumulator_, and requires_grad_ are only used if |
102 | // hooks are defined. They are set before pack_hook is called and used after |
103 | // unpack_hook is called. |
104 | std::shared_ptr<Node> grad_fn_; |
105 | std::weak_ptr<Node> grad_accumulator_; |
106 | bool requires_grad_ = false; |
107 | |
108 | void save_metadata(const Variable& data); |
109 | static std::unique_ptr<SavedVariableHooks> get_default_hooks(); |
110 | void set_hooks_and_pack_data( |
111 | std::unique_ptr<SavedVariableHooks>&& hooks, |
112 | const Variable& data); |
113 | }; |
114 | } // namespace autograd |
115 | } // namespace torch |
116 | |