#include <torch/csrc/autograd/saved_variable.h>

#include <torch/csrc/autograd/anomaly_mode.h>
#include <torch/csrc/autograd/edge.h>
#include <torch/csrc/autograd/engine.h>
#include <torch/csrc/autograd/function.h>
#include <torch/csrc/autograd/grad_mode.h>
#include <torch/csrc/autograd/variable.h>

#include <ATen/Tensor.h>

#include <cstdint>
#include <list>
#include <memory>
#include <sstream>

namespace torch {
namespace autograd {

SavedVariable::SavedVariable(
    const Variable& variable,
    bool is_output,
    bool is_inplace_on_view) {
  if (variable.defined()) {
    // Note [Inference tensor cannot be saved for backward]
    // Invariant:
    //   You can't save an inference tensor for backward.
    // If an inference tensor were saved for backward in an autograd session,
    // and you then re-entered inference mode and made an in-place update to
    // the tensor without bumping its version_counter, you would get silently
    // wrong results when calling backward() for the previous autograd
    // session. Technically we don't have to check here, since the query of
    // `current_version` on the inference tensor would fail anyway, but we can
    // give a much better error message here.
    //
    // Note that in the documentation we say "inference tensors cannot
    // participate in autograd", which is more restrictive than this
    // invariant. In practice the check is more permissive and only errors out
    // when an inference tensor is saved for backward. Whether a tensor is
    // saved for backward is determined by the derivative formula and thus
    // varies op by op, so saying "no inference tensors in autograd" is easier
    // for users to understand and follow.
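    //
    // A rough sketch of the failure mode this guards against, at the Python
    // level (illustrative only, names arbitrary; imagine the check below were
    // absent):
    //
    //   with torch.inference_mode():
    //       t = torch.ones(2, 2)       # inference tensor
    //   x = torch.ones(2, 2, requires_grad=True)
    //   y = x * t                      # would save `t` for backward
    //   with torch.inference_mode():
    //       t.add_(1)                  # in-place, no version bump on `t`
    //   y.sum().backward()             # would silently use the updated `t`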
    TORCH_CHECK(
        !variable.is_inference(),
        "Inference tensors cannot be saved for backward. To work around "
        "this, you can make a clone to get a normal tensor and use it in "
        "autograd.");

    was_default_constructed_ = false;
    const auto& version_counter = impl::version_counter(variable);
    saved_version_ = version_counter.current_version();
    is_leaf_ = variable.is_leaf();
    is_output_ = is_output;
    is_inplace_on_view_ = is_inplace_on_view;

    if (is_inplace_on_view) {
      TORCH_INTERNAL_ASSERT(!is_leaf_ && is_output);
      weak_grad_fn_ = variable.grad_fn();
    }

    auto maybe_hooks = get_default_hooks();

    // Avoid leaking wrapped numbers to the user
    if (maybe_hooks && !variable.unsafeGetTensorImpl()->is_wrapped_number()) {
      save_metadata(variable);
      set_hooks_and_pack_data(std::move(maybe_hooks), variable);
      return;
    }

    // If the variable is a leaf or is not an output, we can safely save the
    // original variable without running the risk of reference cycles:
    // 1. If the variable is not an output, its grad_fn has already been fully
    //    created and, in particular, will be a different Node than the one we
    //    are currently constructing (the one that owns this SavedVariable).
    // 2. If the variable is a leaf, it only holds a weak reference to the
    //    grad_accumulator, which cannot create a cycle.
    // In those cases, we save the original variable and don't need further
    // processing.
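    //
    // For contrast, an illustrative case where holding the original would be
    // risky (assuming the usual derivative formula for exp): for
    // `y = x.exp()`, the backward node saves its own output `y`. Keeping `y`
    // here would hold a strong reference back to that node through
    // `y.grad_fn()` and form a reference cycle. That is why non-leaf outputs
    // fall through to `variable.tensor_data()` below instead of saving the
    // original variable.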
    if (!is_output || is_leaf_) {
      saved_original_ = true;
      data_ = variable;
      return;
    }

    save_metadata(variable);

    // Only do this if we actually need to.
    data_ = variable.tensor_data();
  }
}

void SavedVariable::save_metadata(const Variable& data) {
  // Save output number, version counter and fw_grad if needed

  output_nr_ = data.output_nr();
  version_counter_ = impl::version_counter(data);

  if (is_leaf_) {
    grad_accumulator_ = impl::grad_accumulator(data);
    requires_grad_ = data.requires_grad();
  } else if (!is_output_) {
    grad_fn_ = data.grad_fn();
  }

  // TODO(albanD) This needs to be updated when moving to multiple levels
  const auto& fw_grad = data._fw_grad(/* level */ 0);
  if (fw_grad.defined()) {
    fw_grad_ = std::make_shared<ForwardGrad>();
    fw_grad_->set_value(fw_grad, /* level */ 0);
  }
}

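// Default hooks are the ones registered with the autograd engine rather than
// on a specific SavedVariable. At the Python level this typically corresponds
// to the torch.autograd.graph.saved_tensors_hooks context manager
// (illustrative sketch only, names arbitrary):
//
//   def pack(t):
//       return t.cpu()
//
//   def unpack(packed):
//       return packed.cuda()
//
//   with torch.autograd.graph.saved_tensors_hooks(pack, unpack):
//       y = x.pow(2)      # pack() runs on the tensors saved for backward
//   y.sum().backward()    # unpack() runs when they are needed again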
std::unique_ptr<SavedVariableHooks> SavedVariable::get_default_hooks() {
  return Engine::get_default_engine().get_default_saved_variable_hooks();
}

void SavedVariable::reset_data() {
  hooks_.reset();
  grad_fn_.reset();
  data_.reset();
}

SavedVariable::SavedVariable(
    const c10::optional<Variable>& variable,
    bool is_output,
    bool is_inplace_on_view)
    : SavedVariable(
          variable.has_value() ? *variable : Variable(),
          is_output,
          is_inplace_on_view) {}

Variable SavedVariable::unpack(std::shared_ptr<Node> saved_for) const {
  if (was_default_constructed_) {
    return Variable();
  }

  if (!data_.defined()) {
    TORCH_CHECK(hooks_, ERR_BACKWARD_TWICE);
  }

  // We want grad_fn here to provide the most helpful debug message to the
  // user if the versions don't match.

  auto grad_fn = is_inplace_on_view_ ? weak_grad_fn_.lock()
      : !hooks_ ? saved_original_ ? data_.grad_fn() : nullptr
                : grad_fn_;
  if (!is_leaf_ && !grad_fn) {
    // This issue was introduced when we added logic to save the original
    // variable: we now rely on data_.grad_fn(), but that can be unreliable if
    // the autograd_meta of the saved tensor is cleared by an in-place detach.
    // As a simple fix, we choose to disallow that behavior here, even though
    // it makes behavior inconsistent depending on whether you are saving an
    // input or an output.
    TORCH_CHECK(
        saved_for,
        "Trying to use a saved tensor that has been detached in-place, i.e. with .detach_(). "
        "This is not supported, please use out-of-place `.detach()` instead");
    grad_fn = std::move(saved_for);
  }

  // Only check the version counter in the case without hooks.
  // If the user provides hooks, we can't track versions through the hooks.
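  //
  // Illustrative Python-level trigger for the version check below (sketch
  // only; assumes sigmoid's derivative formula saves its output):
  //
  //   x = torch.randn(3, requires_grad=True)
  //   y = x.sigmoid()     # sigmoid's backward saves its output `y`
  //   y.mul_(2)           # bumps y's version counter
  //   y.sum().backward()  # unpack() sees the version mismatch and raises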
  if (!hooks_) {
    auto current_version = saved_original_
        ? impl::version_counter(data_).current_version()
        : version_counter_.current_version();

    if (saved_version_ != current_version) {
      std::stringstream message;
      message
          << "one of the variables needed for gradient computation has been "
             "modified by an inplace operation: ["
          << data_.toString() << " ";
      if (data_.is_nested()) {
        message << data_._nested_tensor_size() << "]";
      } else {
        message << data_.sizes() << "]";
      }
      if (grad_fn) {
        message << ", which is output " << output_nr_ << " of "
                << grad_fn->name() << ",";
      }
      message << " is at version " << current_version << "; expected version "
              << saved_version_ << " instead.";
      if (!AnomalyMode::is_enabled()) {
        message << " Hint: enable anomaly detection to find the operation "
                   "that failed to compute its gradient, with torch.autograd."
                   "set_detect_anomaly(True).";
      } else {
        message
            << " Hint: the backtrace further above shows the operation "
               "that failed to compute its gradient. The variable in question "
               "was changed in there or anywhere later. Good luck!";
      }
      TORCH_CHECK(false, message.str());
    }
  }

  // The version counter is correct.
  // Additionally, if we deal with a non-leaf variable, we have its correct
  // grad_fn.

  // If we have the original variable, we simply return it.
  if (!hooks_ && saved_original_) {
    return data_;
  }

  const auto data = hooks_ ? hooks_->call_unpack_hook() : data_;

  // NB: saved views are unpacked as normal Variables (not views) even though
  // they still share the same storage. This works only because we never call
  // in-place functions on unpacked variables.
  Variable var;
  if (grad_fn) {
    var = make_variable(data, Edge(std::move(grad_fn), output_nr_));
  } else {
    var = make_variable(data, requires_grad_);
  }

  impl::set_version_counter(var, version_counter_);

  // If a Variable is a leaf (no grad_fn saved), and it requires_grad, then we
  // should have saved the grad accumulator. Even if the Variable is no longer
  // alive, the accumulator should be kept alive by the references in the
  // graph.
  if (is_leaf_ && requires_grad_) {
    TORCH_INTERNAL_ASSERT(
        !grad_accumulator_.expired(), "No grad accumulator for a saved leaf");
  }
  impl::set_grad_accumulator(var, grad_accumulator_);
  // NB: var here is never a view, so there is no need to do anything special
  // for the case where the saved Tensor was a view. This whole argument
  // relies on the fact that the Tensor returned by this function is never
  // modified in-place.
  if (fw_grad_ && !fw_grad_->empty()) {
    // TODO(albanD) This needs to be updated when moving to multiple levels
    auto new_fw_grad = fw_grad_->value(/* level */ 0);
    var._set_fw_grad(new_fw_grad, /* level */ 0, /* is_inplace_op */ false);
  }

  return var;
}

void SavedVariable::set_hooks_and_pack_data(
    std::unique_ptr<SavedVariableHooks>&& hooks,
    const Variable& data) {
  hooks_ = std::move(hooks);
  at::NoGradGuard guard;
  const auto version = impl::version_counter(data).current_version();
  hooks_->call_pack_hook(saved_original_ ? data.detach() : data);
  TORCH_CHECK(
      version == impl::version_counter(data).current_version(),
      "A saved tensor pack hook is modifying its input in place. "
      "Tensors provided as input to a pack hook cannot be modified by "
      "in-place operations, as this can lead to unexpected side effects. "
      "Please open an issue if you need to perform in-place operations on "
      "the input to a pack hook.");
}

void SavedVariable::register_hooks(
    std::unique_ptr<SavedVariableHooks>&& hooks) {
  TORCH_INTERNAL_ASSERT(hooks);
  TORCH_CHECK(
      !hooks_,
      "Calling register_hooks on a saved tensor whose hooks have already been set. "
      "Hint: only one pair of hooks is allowed at a time.");
  if (!data_.defined()) {
    if (!was_default_constructed_) {
      TORCH_CHECK(
          false,
          "Calling register_hooks on a saved tensor after it has been freed. "
          "Saved intermediate values of the graph are freed when you call "
          ".backward() or autograd.grad(). Specify retain_graph=True if you "
          "need to backward through the graph a second time or if you need to "
          "access saved variables after calling backward.");
    } else {
      TORCH_CHECK(
          false,
          "Calling register_hooks on a saved tensor with value None is forbidden");
    }
  }
  // If we didn't save the original variable, we already saved metadata
  if (saved_original_) {
    save_metadata(data_);
  }
  set_hooks_and_pack_data(std::move(hooks), data_);
  data_.reset();
}

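// Illustrative Python-level way to hit this message (sketch only; names are
// arbitrary):
//
//   loss = model(inputs).sum()
//   loss.backward()   # frees the saved tensors in the graph
//   loss.backward()   # raises ERR_BACKWARD_TWICE unless the first call
//                     # passed retain_graph=True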
const char* ERR_BACKWARD_TWICE =
    "Trying to backward through the graph a second time (or directly access saved "
    "tensors after they have already been freed). Saved intermediate values "
    "of the graph are freed when you call .backward() or autograd.grad(). Specify "
    "retain_graph=True if you need to backward through the graph a second time or "
    "if you need to access saved tensors after calling backward.";

} // namespace autograd
} // namespace torch