#include <torch/csrc/autograd/saved_variable.h>

#include <torch/csrc/autograd/anomaly_mode.h>
#include <torch/csrc/autograd/edge.h>
#include <torch/csrc/autograd/engine.h>
#include <torch/csrc/autograd/function.h>
#include <torch/csrc/autograd/grad_mode.h>
#include <torch/csrc/autograd/variable.h>

#include <ATen/Tensor.h>

#include <cstdint>
#include <list>
#include <memory>
#include <sstream>

namespace torch {
namespace autograd {

SavedVariable::SavedVariable(
    const Variable& variable,
    bool is_output,
    bool is_inplace_on_view) {
  if (variable.defined()) {
    // Note [Inference tensor cannot be saved for backward]
    // Invariant:
    //   You can't save an inference tensor for backward.
    // If an inference tensor was saved for backward in an autograd session
    // and then you reenter inference mode and make an in-place update to the
    // tensor without bumping version_counter, it will lead to silently wrong
    // results when you run backward() for the previous autograd session.
    // Technically we don't have to check here, since it would fail when
    // querying `current_version` on the inference tensor, but we can give a
    // much better error message here.
    //
    // Note that in the documentation we say "inference tensors cannot
    // participate in autograd", which is more restrictive than this
    // invariant. In practice the check is more permissive and only errors
    // out when an inference tensor is saved for backward. Whether a tensor
    // is saved for backward is determined by its derivative formula and thus
    // varies op by op, so saying "no inference tensors in autograd" is easier
    // for users to understand and follow.
    TORCH_CHECK(
        !variable.is_inference(),
        "Inference tensors cannot be saved for backward. To work around "
        "you can make a clone to get a normal tensor and use it in autograd.")

    was_default_constructed_ = false;
    const auto& version_counter = impl::version_counter(variable);
    saved_version_ = version_counter.current_version();
    is_leaf_ = variable.is_leaf();
    is_output_ = is_output;
    is_inplace_on_view_ = is_inplace_on_view;

    if (is_inplace_on_view) {
      TORCH_INTERNAL_ASSERT(!is_leaf_ && is_output);
      weak_grad_fn_ = variable.grad_fn();
    }

    auto maybe_hooks = get_default_hooks();

    // Prevent wrapped numbers from being leaked to the user
    if (maybe_hooks && !variable.unsafeGetTensorImpl()->is_wrapped_number()) {
      save_metadata(variable);
      set_hooks_and_pack_data(std::move(maybe_hooks), variable);
      return;
    }

    // If the variable is a leaf or is not an output, we can safely save the
    // original variable without running the risk of reference cycles:
    // 1. If the variable is not an output, its grad_fn has already been fully
    //    created and, in particular, is a different Node from the one we are
    //    currently constructing (the one that owns this SavedVariable).
    // 2. If the variable is a leaf, it only holds a weak reference to its
    //    grad_accumulator, which cannot create a cycle.
    // In those cases, we save the original variable and don't need further
    // processing.
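    // For example (an illustrative sketch, not an exhaustive rule): for
    // `y = x.exp()`, the derivative formula needs the output `y` itself.
    // Since `y` is a non-leaf output of the node being constructed, saving it
    // directly would create the cycle
    // y.grad_fn() -> ExpBackward -> SavedVariable -> y, so we skip the branch
    // below and fall through to the `tensor_data()` path instead.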
    if (!is_output || is_leaf_) {
      saved_original_ = true;
      data_ = variable;
      return;
    }

    save_metadata(variable);

    // Only do this if we actually need to.
    data_ = variable.tensor_data();
  }
}

void SavedVariable::save_metadata(const Variable& data) {
  // Save the output number, version counter and fw_grad if needed

  output_nr_ = data.output_nr();
  version_counter_ = impl::version_counter(data);

  if (is_leaf_) {
    grad_accumulator_ = impl::grad_accumulator(data);
    requires_grad_ = data.requires_grad();
  } else if (!is_output_) {
    grad_fn_ = data.grad_fn();
  }

  // TODO(albanD) This needs to be updated when moving to multiple levels
  const auto& fw_grad = data._fw_grad(/* level */ 0);
  if (fw_grad.defined()) {
    fw_grad_ = std::make_shared<ForwardGrad>();
    fw_grad_->set_value(fw_grad, /* level */ 0);
  }
}

std::unique_ptr<SavedVariableHooks> SavedVariable::get_default_hooks() {
  return Engine::get_default_engine().get_default_saved_variable_hooks();
}
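
// Note (informational, hedged): "default" hooks are installed globally for
// the current thread rather than on a particular SavedVariable. On the
// Python side they typically come from context managers such as
// torch.autograd.graph.saved_tensors_hooks(pack_hook, unpack_hook) or
// torch.autograd.graph.save_on_cpu(), which push/pop a default hooks pair
// that the engine hands back here while the forward graph is being recorded.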

void SavedVariable::reset_data() {
  hooks_.reset();
  grad_fn_.reset();
  data_.reset();
}

SavedVariable::SavedVariable(
    const c10::optional<Variable>& variable,
    bool is_output,
    bool is_inplace_on_view)
    : SavedVariable(
          variable.has_value() ? *variable : Variable(),
          is_output,
          is_inplace_on_view) {}

Variable SavedVariable::unpack(std::shared_ptr<Node> saved_for) const {
  if (was_default_constructed_) {
    return Variable();
  }

  if (!data_.defined()) {
    TORCH_CHECK(hooks_, ERR_BACKWARD_TWICE);
  }

  // We want grad_fn here to provide the most helpful debug message to the
  // user if versions don't match.

  auto grad_fn = is_inplace_on_view_ ? weak_grad_fn_.lock()
      : !hooks_ ? saved_original_ ? data_.grad_fn() : nullptr
                : grad_fn_;

  if (!is_leaf_ && !grad_fn) {
    // This issue was introduced when we added logic to save the original
    // variable: we now rely on data_.grad_fn(), which can be unreliable if
    // the autograd_meta of that saved tensor was cleared by an in-place
    // detach. As a simple fix, we choose to disallow that behavior here, even
    // though it makes the behavior inconsistent depending on whether you are
    // saving an input or an output.
    TORCH_CHECK(
        saved_for,
        "Trying to use a saved tensor that has been detached in-place, i.e. with .detach_(). "
        "This is not supported, please use out-of-place `.detach()` instead");
    grad_fn = std::move(saved_for);
  }

  // Only check the version counter when there are no hooks: if the user
  // provides hooks, we can't track versions through them.
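  //
  // Illustrative trigger of the version check (a sketch using the C++
  // frontend; variable names are made up):
  //   auto x = torch::randn({3}, torch::requires_grad());
  //   auto y = x.sigmoid(); // sigmoid's derivative formula saves its output
  //   y.mul_(2);            // in-place op bumps y's version counter
  //   y.sum().backward();   // unpack() below detects the version mismatch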
  if (!hooks_) {
    auto current_version = saved_original_
        ? impl::version_counter(data_).current_version()
        : version_counter_.current_version();

    if (saved_version_ != current_version) {
      std::stringstream message;
      message
          << "one of the variables needed for gradient computation has been "
             "modified by an inplace operation: ["
          << data_.toString() << " ";
      if (data_.is_nested()) {
        message << data_._nested_tensor_size() << "]";
      } else {
        message << data_.sizes() << "]";
      }
      if (grad_fn) {
        message << ", which is output " << output_nr_ << " of "
                << grad_fn->name() << ",";
      }
      message << " is at version " << current_version << "; expected version "
              << saved_version_ << " instead.";
      if (!AnomalyMode::is_enabled()) {
        message << " Hint: enable anomaly detection to find the operation "
                   "that failed to compute its gradient, with torch.autograd."
                   "set_detect_anomaly(True).";
      } else {
        message
            << " Hint: the backtrace further above shows the operation "
               "that failed to compute its gradient. The variable in question "
               "was changed in there or anywhere later. Good luck!";
      }
      TORCH_CHECK(false, message.str());
    }
  }

  // The version counter is correct.
  // Additionally, if we deal with a non-leaf variable, we have its correct
  // grad_fn.

  // If we have the original variable, we simply return it.
  if (!hooks_ && saved_original_) {
    return data_;
  }

  const auto data = hooks_ ? hooks_->call_unpack_hook() : data_;

  // NB: saved views are unpacked as normal Variables (not views) even though
  // they still share the same storage. This works only because we never call
  // in-place functions on unpacked variables.
  Variable var;
  if (grad_fn) {
    var = make_variable(data, Edge(std::move(grad_fn), output_nr_));
  } else {
    var = make_variable(data, requires_grad_);
  }

  impl::set_version_counter(var, version_counter_);

  // If a Variable is a leaf (no grad_fn saved) and it requires_grad, then we
  // should have saved the grad accumulator. Even if the Variable is no longer
  // alive, the accumulator should be kept alive by the references in the
  // graph.
  if (is_leaf_ && requires_grad_) {
    TORCH_INTERNAL_ASSERT(
        !grad_accumulator_.expired(), "No grad accumulator for a saved leaf");
  }
  impl::set_grad_accumulator(var, grad_accumulator_);

  // NB: var here is never a view, so there is no need to do anything special
  // for the case where the saved Tensor was a view. This whole argument
  // relies on the fact that the Tensor returned by this function is never
  // modified in-place.
  if (fw_grad_ && !fw_grad_->empty()) {
    // TODO(albanD) This needs to be updated when moving to multiple levels
    auto new_fw_grad = fw_grad_->value(/* level */ 0);
    var._set_fw_grad(new_fw_grad, /* level */ 0, /* is_inplace_op */ false);
  }

  return var;
}

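// Note (informational, hedged): when hooks are attached, the tensor payload
// is handed to the pack hook at save time (below) and the packed result is
// owned by the hooks object rather than stored in data_; unpack() above later
// calls call_unpack_hook() to reconstruct the tensor when backward needs it.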
void SavedVariable::set_hooks_and_pack_data(
    std::unique_ptr<SavedVariableHooks>&& hooks,
    const Variable& data) {
  hooks_ = std::move(hooks);
  at::NoGradGuard guard;
  const auto version = impl::version_counter(data).current_version();
  hooks_->call_pack_hook(saved_original_ ? data.detach() : data);
  TORCH_CHECK(
      version == impl::version_counter(data).current_version(),
      "A saved tensor pack hook is modifying its input in place. "
      "Tensors provided as input to a pack hook cannot be modified by "
      "in-place operations as this can lead to unexpected side-effects. "
      "Please open an issue if you need to perform in-place operations on "
      "the input to a pack hook.");
}

void SavedVariable::register_hooks(
    std::unique_ptr<SavedVariableHooks>&& hooks) {
  TORCH_INTERNAL_ASSERT(hooks);
  TORCH_CHECK(
      !hooks_,
      "Calling register_hooks on a saved tensor whose hooks have already been set. "
      "Hint: only one pair of hooks is allowed at a time.");
  if (!data_.defined()) {
    if (!was_default_constructed_) {
      TORCH_CHECK(
          false,
          "Calling register_hooks on a saved tensor after it has been freed. "
          "Saved intermediate values of the graph are freed when you call "
          ".backward() or autograd.grad(). Specify retain_graph=True if you "
          "need to backward through the graph a second time or if you need to "
          "access saved variables after calling backward.");
    } else {
      TORCH_CHECK(
          false,
          "Calling register_hooks on a saved tensor with value None is forbidden");
    }
  }
  // If we didn't save the original variable, the metadata was already saved
  // at construction time; otherwise save it now.
  if (saved_original_) {
    save_metadata(data_);
  }
  set_hooks_and_pack_data(std::move(hooks), data_);
  data_.reset();
}

const char* ERR_BACKWARD_TWICE =
    "Trying to backward through the graph a second time (or directly access saved "
    "tensors after they have already been freed). Saved intermediate values "
    "of the graph are freed when you call .backward() or autograd.grad(). Specify "
    "retain_graph=True if you need to backward through the graph a second time or "
    "if you need to access saved tensors after calling backward.";

} // namespace autograd
} // namespace torch