1 | #include <c10/util/irange.h> |
2 | #include <torch/csrc/autograd/autograd.h> |
3 | #include <torch/csrc/autograd/custom_function.h> |
4 | #include <torch/csrc/autograd/functions/accumulate_grad.h> |
5 | |
6 | #include <utility> |
7 | |
8 | namespace torch { |
9 | namespace autograd { |
10 | |
11 | VariableInfo::VariableInfo(const Variable& var) |
12 | : layout(var.layout()), |
13 | device(var.device()), |
14 | scalar_type(var.scalar_type()), |
15 | size(var.sizes().vec()), |
16 | requires_grad(var.requires_grad()), |
17 | is_empty(false) {} |
18 | |
19 | VariableInfo::VariableInfo() : requires_grad(false), is_empty(true) {} |
20 | |
21 | Variable VariableInfo::zeros(at::OptionalDeviceGuard& device_guard) const { |
22 | if (is_empty) { |
23 | // Return undefined tensor. |
24 | return at::Tensor(); |
25 | } else { |
26 | return at::zeros( |
27 | size, at::TensorOptions(scalar_type).device(device).layout(layout)); |
28 | } |
29 | } |
30 | |
31 | // This function has two main goals: |
32 | // 1) Use the user-provided jvp function to populate the the outputs' forward |
33 | // gradient 2) Perform error checking to ensure that view and inplace ops are |
34 | // properly handled |
35 | // |
36 | // For 1) we have to: |
37 | // - Create a variable_list of grad_inputs based on the function inputs |
38 | // - Call the user jvp function with these to get the grad_outputs |
39 | // - Set the forward grad field on each output based on these grad_outputs |
40 | // |
41 | // For 2) we want to check the following: |
42 | // - If an output is a view, then the generated forward grad must be a view as |
43 | // well and |
44 | // the output's base's forward grad must be the output's forward grad's base. |
45 | // - If an input was modified inplace (it must be an output as well) we make |
46 | // sure that its |
47 | // forward grad was also modified inplace and already present on the |
48 | // corresponding output. |
49 | void _process_forward_mode_AD( |
50 | const variable_list& inputs, |
51 | std::unordered_map<at::TensorImpl*, size_t> inputs_mapping, |
52 | const at::ArrayRef<c10::optional<Variable>> raw_outputs, |
53 | const optional_variable_list& outputs, |
54 | const std::unordered_set<at::TensorImpl*>& non_differentiable, |
55 | const std::unordered_set<at::TensorImpl*>& dirty_inputs, |
56 | _jvp_fn_t jvp_user_function) { |
57 | // TODO handle multiple levels here |
58 | uint64_t level = 0; |
59 | |
60 | const auto num_inputs = inputs.size(); |
61 | const auto num_outputs = outputs.size(); |
62 | |
63 | // The tracking info below are used to perform the view and inplace checks. |
64 | // They are lazily initialized to reduce the cost of this function in the |
65 | // common case where the user is not using forward mode AD. |
66 | variable_list input_grads; |
67 | std::vector<int64_t> grad_versions; |
68 | std::vector<at::TensorImpl*> grad_impls; |
69 | std::unordered_map<at::TensorImpl*, size_t> inputs_bases; |
70 | |
71 | auto init_tracked_info = [&]() { |
72 | input_grads.resize(num_inputs); |
73 | grad_versions.resize(num_inputs); |
74 | grad_impls.resize(num_inputs); |
75 | |
76 | for (const auto i : c10::irange(num_inputs)) { |
77 | const auto& inp = inputs[i]; |
78 | if (inp.is_view() && impl::get_view_autograd_meta(inp)->has_fw_view()) { |
79 | inputs_bases.emplace( |
80 | impl::get_view_autograd_meta(inp) |
81 | ->get_forward_view() |
82 | .base_.unsafeGetTensorImpl(), |
83 | i); |
84 | } else { |
85 | inputs_bases.emplace(inp.unsafeGetTensorImpl(), i); |
86 | } |
87 | } |
88 | }; |
89 | |
90 | bool any_input_has_grad = false; |
91 | // Extract the input's forward gradients and record any info we will need |
92 | // later |
93 | for (const auto i : c10::irange(num_inputs)) { |
94 | const auto& inp = inputs[i]; |
95 | if (!inp.defined()) { |
96 | continue; |
97 | } |
98 | const auto& fw_grad = inp._fw_grad(level); |
99 | if (fw_grad.defined()) { |
100 | if (!any_input_has_grad) { |
101 | any_input_has_grad = true; |
102 | init_tracked_info(); |
103 | } |
104 | input_grads[i] = fw_grad; |
105 | grad_versions[i] = fw_grad._version(); |
106 | grad_impls[i] = fw_grad.unsafeGetTensorImpl(); |
107 | } |
108 | } |
109 | |
110 | // If no input has forward grad, nothing to do here |
111 | if (!any_input_has_grad) { |
112 | return; |
113 | } |
114 | |
115 | torch::autograd::variable_list forward_grads; |
116 | { |
117 | at::AutoFwGradMode fw_grad_mode(false); |
118 | forward_grads = jvp_user_function(inputs, std::move(input_grads)); |
119 | } |
120 | |
121 | // NOLINTNEXTLINE(cppcoreguidelines-init-variables) |
122 | const auto num_forward_grads = forward_grads.size(); |
123 | // contrary to backward mode, we don't allow returning too many gradients |
124 | TORCH_CHECK( |
125 | num_forward_grads == num_outputs, |
126 | "Function's jvp returned " |
127 | "an invalid number of forward gradients (expected " , |
128 | num_outputs, |
129 | " but got " , |
130 | num_forward_grads, |
131 | ")" ); |
132 | |
133 | for (const auto i : c10::irange(num_outputs)) { |
134 | const auto& out = |
135 | outputs[i].has_value() ? outputs[i].value() : at::Tensor(); |
136 | auto out_tensor_impl = raw_outputs[i].value().unsafeGetTensorImpl(); |
137 | bool is_differentiable = |
138 | (non_differentiable.count(out_tensor_impl) == 0 && |
139 | isDifferentiableType(raw_outputs[i].value().scalar_type())); |
140 | const auto& out_grad = forward_grads[i]; |
141 | if (!out.defined() || !is_differentiable) { |
142 | TORCH_CHECK( |
143 | !out_grad.defined(), |
144 | "Function's jvp returned a gradient at position " , |
145 | i, |
146 | ", but " |
147 | " the corresponding forward output is not a differentiable Tensor." |
148 | "You should return None at that position instead." ); |
149 | continue; |
150 | } |
151 | |
152 | TORCH_INTERNAL_ASSERT(raw_outputs[i].has_value()); |
153 | bool is_input = inputs_mapping.count(out_tensor_impl) > 0; |
154 | bool is_modified = dirty_inputs.count(out_tensor_impl) > 0; |
155 | |
156 | if (is_modified) { |
157 | TORCH_CHECK( |
158 | is_input, |
159 | "Only input Tensors should be given to ctx.mark_dirty(). If a Tensor is not an input, there" |
160 | " is no need to pass it to mark_dirty()." ); |
161 | auto inp_idx = inputs_mapping[out_tensor_impl]; |
162 | if (grad_impls[inp_idx]) { |
163 | // If there was already a forward grad for that input |
164 | // Just make sure that it is modified inplace and returned as-is |
165 | TORCH_CHECK( |
166 | out_grad._version() != grad_versions[inp_idx], |
167 | "An inplace custom Function is not modifying the " |
168 | "forward mode gradients inplace. If the forward is modifying an input inplace, then the jvp " |
169 | "function must modify the corresponding gradient inplace." ) |
170 | TORCH_CHECK( |
171 | out_grad.unsafeGetTensorImpl() == grad_impls[inp_idx], |
172 | "An inplace custom Function is not returning the " |
173 | "forward mode gradients as-is. If the forward is modifying an input inplace, then the jvp " |
174 | "function must modify the gradient inplace and return it as-is." ) |
175 | } else { |
176 | // If that Tensor didn't had gradients already, set the newly returned |
177 | // one We could also use inputs[inp_idx] here as it is the same as out |
178 | out._set_fw_grad(out_grad, level, /* is_inplace_op */ true); |
179 | } |
180 | } else { |
181 | // At this point, outputs[i] cannot be one of the input (raw_outputs[i] |
182 | // might be but was changed by the backward code) |
183 | TORCH_INTERNAL_ASSERT( |
184 | inputs_mapping.count(out.unsafeGetTensorImpl()) == 0); |
185 | |
186 | if (out.is_view() && impl::get_view_autograd_meta(out)->has_fw_view()) { |
187 | // If the output is a view |
188 | const auto& out_view_info = |
189 | impl::get_view_autograd_meta(out)->get_forward_view(); |
190 | if (inputs_bases.count(out_view_info.base_.unsafeGetTensorImpl())) { |
191 | // And it is a view of an input (either that input is its base or they |
192 | // have a common base) |
193 | const auto matching_input_idx = |
194 | inputs_bases[out_view_info.base_.unsafeGetTensorImpl()]; |
195 | const auto& matching_input = inputs[matching_input_idx]; |
196 | |
197 | const auto& matching_input_grad = matching_input._fw_grad(level); |
198 | |
199 | // If the matching input has a forward grad, the user should have |
200 | // returned a view of that Tensor |
201 | if (matching_input_grad.defined()) { |
202 | TORCH_CHECK( |
203 | out_grad.is_view() && |
204 | impl::get_view_autograd_meta(out_grad)->has_fw_view(), |
205 | "A custom Function's forward is returning a view (or an input as-is) but the jvp is not " |
206 | "returning a view." ); |
207 | const auto& out_grad_base = impl::get_view_autograd_meta(out_grad) |
208 | ->get_forward_view() |
209 | .base_; |
210 | if (matching_input_grad.is_view() && |
211 | impl::get_view_autograd_meta(matching_input_grad) |
212 | ->has_fw_view()) { |
213 | // If the matching input's grad is a view, ensure that the |
214 | // out_grad is a view of the same base |
215 | const auto& matching_input_grad_base = |
216 | impl::get_view_autograd_meta(matching_input_grad) |
217 | ->get_forward_view() |
218 | .base_; |
219 | TORCH_CHECK( |
220 | matching_input_grad_base.unsafeGetTensorImpl() == |
221 | out_grad_base.unsafeGetTensorImpl(), |
222 | "A custom Function is returning a view but the jvp is not returning a view of the same base as " |
223 | "the given grad input." ); |
224 | } else { |
225 | // If the matching input's grad is not a view, then it must be the |
226 | // output gradient's base |
227 | TORCH_CHECK( |
228 | matching_input_grad.unsafeGetTensorImpl() == |
229 | out_grad_base.unsafeGetTensorImpl(), |
230 | "A custom Function is returning a view but the jvp is not returning a view of the given grad input." ); |
231 | } |
232 | } else { |
233 | // We have a view op where the input didn't have a forward grad but |
234 | // the user returned one for the output To ensure that we maintain |
235 | // the view/inplace constraints, we consider this as an inplace op |
236 | // This case CANNOT happen in codegen as all view ops are mapping |
237 | // from one Tensor to one Tensor and so the output of the view |
238 | // cannot have a forward grad if the base does not. |
239 | out._set_fw_grad(out_grad, level, /* is_inplace_op */ true); |
240 | return; |
241 | } |
242 | } |
243 | } |
244 | |
245 | out._set_fw_grad(out_grad, level, /* is_inplace_op */ false); |
246 | } |
247 | } |
248 | } |
249 | |
250 | at::Tensor _view_as_self_with_no_grad(at::Tensor self) { |
251 | // This is called below in _process_backward_mode_ad in two places: |
252 | // |
253 | // (1) An input has been returned, but it wasn't modified. Return it as a view |
254 | // so that we can attach a new grad_fn to the Variable. |
255 | // Run in no_grad mode to mimic the behavior of the forward. |
256 | // |
257 | // (2) Though it is not necessary for the purposes of attaching grad_fn, we |
258 | // also call this function when an output is non-differentiable (and does not |
259 | // require grad). to help custom forward AD UX more consistent. We'd like to |
260 | // uniformly say that returning an input as-is is treated as if |
261 | // `self.view_as(self)` were returned for that output. |
262 | // |
263 | // Alternatively, we could have not disabled forward grad while performing |
264 | // this view, but it would mean that the user defined jvp may be silently |
265 | // ignored. |
266 | at::AutoFwGradMode fw_grad_mode(false); |
267 | AutoGradMode grad_mode(false); |
268 | return self.view_as(self); |
269 | } |
270 | |
271 | optional_variable_list _process_backward_mode_ad( |
272 | const std::unordered_map<at::TensorImpl*, size_t>& inputs_mapping, |
273 | const std::unordered_set<at::TensorImpl*>& non_differentiable, |
274 | const std::unordered_set<at::TensorImpl*>& dirty_inputs, |
275 | const at::ArrayRef<c10::optional<Variable>> raw_outputs, |
276 | const std::shared_ptr<Node>& cdata) { |
277 | int num_outputs = raw_outputs.size(); |
278 | |
279 | // Sets the grad_fn and output_nr of an output Variable. |
280 | auto set_history = [&](Variable& var, |
281 | uint32_t output_nr, |
282 | bool is_input, |
283 | bool is_modified, |
284 | bool is_differentiable) { |
285 | if (!is_differentiable) { |
286 | if (!var.requires_grad()) { |
287 | if (is_input && !is_modified) { |
288 | var = _view_as_self_with_no_grad(var); |
289 | } |
290 | return; |
291 | } |
292 | // Return detached aliases of inputs, instead of changing their |
293 | // requires_grad property. |
294 | if (is_input) { |
295 | var = var.detach(); |
296 | } else if (!var.is_view()) { |
297 | var.detach_(); |
298 | } |
299 | // If var is a view of one of the inputs of the custom autograd Function, |
300 | // we don't detach it in a no_grad block. This is so that we can mimic the |
301 | // behavior of returning a view from a no_grad block: |
302 | // x = torch.randn(3, requires_grad=True) |
303 | // with torch.no_grad(): |
304 | // y = x.view(-1) |
305 | // Here, `y` requires_grad (!). |
306 | } else if (is_modified) { |
307 | if (var.is_leaf() && var.requires_grad()) { |
308 | TORCH_CHECK( |
309 | false, |
310 | "a leaf Variable that requires grad has been used in an in-place operation." ); |
311 | } |
312 | // No need to mark as modified Tensors that are not inputs. |
313 | if (!is_input) { |
314 | TORCH_WARN( |
315 | "Only input Tensors should be given to ctx.mark_dirty(). If a Tensor is not an input, there" |
316 | " is no need to pass it to mark_dirty()." ); |
317 | } |
318 | // If the input is a view, the rebase will need to rewrite the graph and |
319 | // this only works if we have a single output to this Function. |
320 | TORCH_CHECK( |
321 | !(var.is_view() && num_outputs > 1), |
322 | "If your Function modifies inplace an input that is a view" |
323 | " of another Tensor, your Function cannot return more than one Tensor. This is not supported" |
324 | " by the current autograd engine. You should either make sure the input is not a view (using" |
325 | " .clone() for example) or make your Function only return one Tensor (potentially splitting" |
326 | " it into two Functions: one doing the inplace that returns a single Tensor and a second one" |
327 | " that does the other operations). You can ask on the forum https://discuss.pytorch.org/ if" |
328 | " you need help to do this change." ); |
329 | |
330 | // If the input was modified, transplant the grad_fn in the graph: |
331 | // grad_fn <- variable <- self ==> grad_fn <- self <- variable |
332 | var.mutable_grad().reset(); |
333 | impl::clear_hooks(var); |
334 | if (auto grad_acc_fn = impl::try_get_grad_accumulator(var)) { |
335 | auto& grad_acc = dynamic_cast<AccumulateGrad&>(*grad_acc_fn); |
336 | grad_acc.variable.reset(); |
337 | } |
338 | if (cdata) { |
339 | impl::rebase_history(var, {cdata, output_nr}); |
340 | } |
341 | } else if (is_input) { |
342 | var = _view_as_self_with_no_grad(var); |
343 | impl::set_gradient_edge(var, {cdata, output_nr}); |
344 | } else if (cdata) { |
345 | impl::set_gradient_edge(var, {cdata, output_nr}); |
346 | } |
347 | }; |
348 | |
349 | optional_variable_list outputs; |
350 | std::unordered_set<at::TensorImpl*> outputs_impl; // For dirty_inputs check |
351 | outputs.reserve(num_outputs); |
352 | int num_diff_outputs = 0; |
353 | |
354 | for (const auto i : c10::irange(num_outputs)) { |
355 | // For outputs that are not tensors, put a placeholder undefined input. |
356 | if (!raw_outputs[i].has_value()) { |
357 | if (cdata) { |
358 | auto output_nr = cdata->add_input_metadata(Node::undefined_input()); |
359 | AT_ASSERT(i == (int)output_nr); |
360 | } |
361 | outputs.emplace_back(); |
362 | continue; |
363 | } |
364 | |
365 | Variable var = raw_outputs[i].value(); |
366 | |
367 | auto out_tensor_impl = var.unsafeGetTensorImpl(); |
368 | bool is_input = inputs_mapping.count(out_tensor_impl) > 0; |
369 | bool is_modified = dirty_inputs.count(out_tensor_impl) > 0; |
370 | bool is_differentiable = cdata && |
371 | non_differentiable.count(out_tensor_impl) == 0 && |
372 | isDifferentiableType(var.scalar_type()); |
373 | |
374 | if (cdata) { |
375 | auto output_nr = cdata->add_input_metadata(var); |
376 | AT_ASSERT(i == (int)output_nr); |
377 | } |
378 | set_history(var, i, is_input, is_modified, is_differentiable); |
379 | |
380 | // For deprecation cycle. Can be removed after 1.6. In the case where we |
381 | // detected a view in no grad mode during the forward, only warn the user |
382 | // (do not change the flag if we return and input that is a view as is). See |
383 | // NOTE [ View + Inplace detection ] for why we replace everything by a |
384 | // warning. |
385 | if (!(is_input && is_modified) && var.is_view()) { |
386 | // is_view() => diff_view_meta |
387 | auto diff_view_meta = impl::get_view_autograd_meta(var); |
388 | diff_view_meta->set_creation_meta(CreationMeta::IN_CUSTOM_FUNCTION); |
389 | } |
390 | |
391 | if (is_differentiable) { |
392 | ++num_diff_outputs; |
393 | } |
394 | |
395 | outputs_impl.insert(out_tensor_impl); |
396 | outputs.emplace_back(var); |
397 | } |
398 | |
399 | // If multiple differentiable outputs are returned, we do not allow views to |
400 | // be modified inplace See NOTE [ View + Inplace detection ] for more details |
401 | if (num_diff_outputs > 1) { |
402 | for (auto& var : outputs) { |
403 | if (var.has_value()) { |
404 | auto diff_view_meta = impl::get_view_autograd_meta(var.value()); |
405 | if (diff_view_meta && diff_view_meta->has_bw_view()) { |
406 | diff_view_meta->set_creation_meta(CreationMeta::MULTI_OUTPUT_NODE); |
407 | } |
408 | } |
409 | } |
410 | } |
411 | |
412 | // All the modified Tensors must be returned as is for the rewrite to be |
413 | // valid. |
414 | for (auto& dirty_input : dirty_inputs) { |
415 | TORCH_CHECK( |
416 | outputs_impl.count(dirty_input) > 0, |
417 | "Some elements marked as dirty during the forward method were not returned as output. The" |
418 | " inputs that are modified inplace must all be outputs of the Function." ); |
419 | } |
420 | |
421 | return outputs; |
422 | } |
423 | |
424 | optional_variable_list _wrap_outputs( |
425 | const variable_list& input_vars, |
426 | const std::unordered_set<at::TensorImpl*>& non_differentiable, |
427 | const std::unordered_set<at::TensorImpl*>& dirty_inputs, |
428 | const at::ArrayRef<c10::optional<Variable>> raw_outputs, |
429 | const std::shared_ptr<Node>& cdata, |
430 | _jvp_fn_t jvp_user_function) { |
431 | std::unordered_map<at::TensorImpl*, size_t> inputs_mapping; |
432 | inputs_mapping.reserve(input_vars.size()); |
433 | for (const auto i : c10::irange(input_vars.size())) { |
434 | inputs_mapping.emplace(input_vars[i].unsafeGetTensorImpl(), i); |
435 | } |
436 | |
437 | auto outputs = _process_backward_mode_ad( |
438 | inputs_mapping, non_differentiable, dirty_inputs, raw_outputs, cdata); |
439 | |
440 | // This must happen after the backward processing as we expect the |
441 | // computations happening here to track backward mode gradients. |
442 | _process_forward_mode_AD( |
443 | input_vars, |
444 | std::move(inputs_mapping), |
445 | raw_outputs, |
446 | outputs, |
447 | non_differentiable, |
448 | dirty_inputs, |
449 | std::move(jvp_user_function)); |
450 | |
451 | return outputs; |
452 | } |
453 | |
454 | void check_variable_result( |
455 | const at::TensorBase& original, |
456 | const at::TensorBase& result, |
457 | std::string hook_name) { |
458 | if (!original.options().type_equal(result.options())) { |
459 | std::stringstream ss; |
460 | ss << "hook '" << hook_name << "' has changed the type of value (" ; |
461 | ss << "was " << original.toString() << " got " ; |
462 | ss << result.toString() << ")" ; |
463 | throw std::runtime_error(ss.str()); |
464 | } |
465 | |
466 | if (original.is_cuda() != result.is_cuda()) { |
467 | std::stringstream ss; |
468 | ss << "hook '" << hook_name << "' has changed the type of value" ; |
469 | if (original.is_cuda()) { |
470 | ss << " (was CUDA tensor got CPU tensor)" ; |
471 | } else { |
472 | ss << " (was CPU tensor got CUDA tensor)" ; |
473 | } |
474 | throw std::runtime_error(ss.str()); |
475 | } |
476 | |
477 | if (original.sizes().vec() != result.sizes().vec()) { |
478 | std::stringstream ss; |
479 | ss << "hook '" << hook_name << "' has changed the size of value" ; |
480 | throw std::runtime_error(ss.str()); |
481 | } |
482 | } |
483 | |
484 | void AutogradContext::save_for_backward(variable_list to_save) { |
485 | to_save_ = std::move(to_save); |
486 | } |
487 | |
488 | // The logic for handling saved variables here is the same as |
489 | // python_function.cpp See _save_variables() and unpack_saved_variables() |
490 | void AutogradContext::save_variables() { |
491 | saved_variables_.clear(); |
492 | auto ptr = grad_fn_.lock(); |
493 | |
494 | for (const auto& var : to_save_) { |
495 | // Allow empty variables to be saved |
496 | if (var.defined()) { |
497 | bool is_output = var.grad_fn().get() == ptr.get(); |
498 | saved_variables_.emplace_back(var, is_output); |
499 | } else { |
500 | saved_variables_.emplace_back(); |
501 | } |
502 | } |
503 | to_save_.clear(); |
504 | } |
505 | |
506 | variable_list AutogradContext::get_saved_variables() const { |
507 | TORCH_CHECK(!has_freed_buffers_, ERR_BACKWARD_TWICE); |
508 | variable_list saved; |
509 | saved.reserve(saved_variables_.size()); |
510 | auto ptr = grad_fn_.lock(); |
511 | TORCH_INTERNAL_ASSERT(ptr); |
512 | for (auto& var : saved_variables_) { |
513 | saved.push_back(var.unpack(ptr)); |
514 | } |
515 | return saved; |
516 | } |
517 | |
518 | bool AutogradContext::needs_input_grad(size_t output_edge_index) const { |
519 | auto ptr = grad_fn_.lock(); |
520 | TORCH_INTERNAL_ASSERT(ptr); |
521 | return ptr->task_should_compute_output(output_edge_index); |
522 | } |
523 | |
524 | bool AutogradContext::needs_input_grad( |
525 | std::initializer_list<IndexRange> idxs) const { |
526 | auto ptr = grad_fn_.lock(); |
527 | TORCH_INTERNAL_ASSERT(ptr); |
528 | return ptr->task_should_compute_output(idxs); |
529 | } |
530 | |
531 | void AutogradContext::mark_dirty(const variable_list& inputs) { |
532 | dirty_inputs_.clear(); |
533 | dirty_inputs_.reserve(inputs.size()); |
534 | for (auto& var : inputs) { |
535 | dirty_inputs_.insert(var.unsafeGetTensorImpl()); |
536 | } |
537 | } |
538 | |
539 | void AutogradContext::mark_non_differentiable(const variable_list& outputs) { |
540 | non_differentiable_.clear(); |
541 | non_differentiable_.reserve(outputs.size()); |
542 | for (auto& var : outputs) { |
543 | non_differentiable_.insert(var.unsafeGetTensorImpl()); |
544 | } |
545 | } |
546 | |
547 | void AutogradContext::set_materialize_grads(bool value) { |
548 | materialize_grads_ = value; |
549 | } |
550 | |
551 | const std::unordered_set<at::TensorImpl*>& AutogradContext::get_and_bump_dirty() |
552 | const { |
553 | for (auto& var : dirty_inputs_) { |
554 | var->bump_version(); |
555 | } |
556 | return dirty_inputs_; |
557 | } |
558 | |
559 | const std::unordered_set<at::TensorImpl*>& AutogradContext:: |
560 | get_non_differentiable() const { |
561 | return non_differentiable_; |
562 | } |
563 | } // namespace autograd |
564 | } // namespace torch |
565 | |