1 | #include <torch/csrc/autograd/python_function.h> |
2 | |
3 | #include <ATen/ATen.h> |
4 | #include <ATen/SequenceNumber.h> |
5 | #include <c10/util/irange.h> |
6 | #include <pybind11/pybind11.h> |
7 | #include <structmember.h> |
8 | #include <torch/csrc/python_headers.h> |
9 | #include <torch/csrc/utils/pybind.h> |
10 | |
11 | #include <ATen/FuncTorchTLS.h> |
12 | #include <ATen/functorch/DynamicLayer.h> |
13 | #include <torch/csrc/DynamicTypes.h> |
14 | #include <torch/csrc/Exceptions.h> |
15 | #include <torch/csrc/THP.h> |
16 | #include <torch/csrc/autograd/functions/accumulate_grad.h> |
17 | #include <torch/csrc/autograd/functions/basic_ops.h> |
18 | #include <torch/csrc/autograd/functions/utils.h> |
19 | #include <torch/csrc/autograd/grad_mode.h> |
20 | #include <torch/csrc/autograd/graph_task.h> |
21 | #include <torch/csrc/autograd/python_anomaly_mode.h> |
22 | #include <torch/csrc/autograd/python_cpp_function.h> |
23 | #include <torch/csrc/autograd/python_hook.h> |
24 | #include <torch/csrc/autograd/saved_variable.h> |
25 | #include <torch/csrc/jit/frontend/tracer.h> |
26 | #include <torch/csrc/jit/ir/ir.h> |
27 | #include <torch/csrc/jit/python/pybind_utils.h> |
28 | #include <torch/csrc/jit/python/python_tracer.h> |
29 | #include <torch/csrc/utils/python_strings.h> |
30 | |
31 | #include <exception> |
32 | #include <functional> |
33 | #include <memory> |
34 | #include <stdexcept> |
35 | #include <string> |
36 | #include <tuple> |
37 | #include <unordered_map> |
38 | #include <unordered_set> |
39 | #include <utility> |
40 | #include <vector> |
41 | |
42 | using namespace torch; |
43 | using namespace torch::autograd; |
44 | using at::Tensor; |
45 | |
// Python class object backing torch.autograd.Function; starts null and is
// presumably assigned during module initialization (not visible in this
// chunk — confirm against the module init code).
PyObject* THPFunctionClass = nullptr;
47 | |
// Set a Python error (printf-style message) and raise python_error when
// `condition` is false. Wrapped in do { } while (0) so the macro expands to a
// single statement: this removes the dangling-else hazard of a bare `if` and
// makes `THPFunction_assert(...);` safe inside unbraced if/else chains.
#define THPFunction_assert(condition, ...) \
  do {                                     \
    if (!(condition)) {                    \
      THPUtils_setError(__VA_ARGS__);      \
      throw python_error();                \
    }                                      \
  } while (0)
53 | |
54 | // Anonymous namespace for helpful functions used in this file |
55 | namespace { |
56 | |
57 | // Throw a python_error with the PyErr state persisted, so that we |
58 | // don't lose the error state if the GIL is released when we don't |
59 | // have a PyThreadState created beforehand, this is made so that |
60 | // even for pure C++ thread without a pre-created PyThreadState could |
61 | // also capture the correct error message. |
62 | // TODO: This is a temporary approach to allow C++ thread to correctly |
63 | // capture Python Error in autograd, remove this when c10 thread pool |
64 | // allow to do one time initialization. |
65 | // see discussion in https://github.com/pytorch/pytorch/pull/34845 |
66 | // Follow up issue: https://github.com/pytorch/pytorch/issues/35006 |
67 | void throw_python_error() { |
68 | python_error err; |
69 | err.persist(); |
70 | throw err; |
71 | } |
72 | |
73 | } // namespace |
74 | |
75 | namespace torch { |
76 | namespace autograd { |
77 | |
78 | // NOTE: this function is written in a way that assumes it's only called for |
79 | // backward; it's used by engine.cpp. This is responsible for forwarding a call |
80 | // from C++'s Node::apply to a Python method "apply". |
auto PyNode::apply(variable_list&& inputs) -> variable_list {
  pybind11::gil_scoped_acquire gil;
  at::OptionalDeviceGuard _device_guard;
  THPFunction* py_fn = (THPFunction*)obj;

  // Massage a C++ variable_list into a Python arguments tuple
  auto num_inputs = inputs.size();
  THPObjectPtr pyInputs(PyTuple_New(num_inputs));
  if (!pyInputs)
    throw_python_error();
  auto& output_info = py_fn->output_info;
  for (const auto i : c10::irange(num_inputs)) {
    // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
    PyObject* input;
    // Undefined grads are materialized as zeros matching the i-th forward
    // output's metadata, unless the user opted out via
    // ctx.set_materialize_grads(False) (py_fn->materialize_grads == false).
    if (inputs[i].defined() || !py_fn->materialize_grads) {
      input = THPVariable_Wrap(inputs[i]);
    } else {
      input = THPVariable_Wrap(output_info[i].zeros(_device_guard));
    }
    if (!input)
      throw_python_error();
    // PyTuple_SET_ITEM steals the reference returned by THPVariable_Wrap.
    PyTuple_SET_ITEM(pyInputs.get(), i, input);
  }

  // Call the Python-side "apply" (which runs the user's backward).
  THPObjectPtr apply_fn(PyObject_GetAttrString(obj, "apply"));
  if (!apply_fn)
    throw_python_error();
  THPObjectPtr r(PyObject_CallObject(apply_fn, pyInputs.get()));
  if (!r)
    throw_python_error();
  // Normalize a single returned object into a 1-tuple.
  ensure_tuple(r);

  auto& is_variable_input = py_fn->is_variable_input;
  int num_outputs = PyTuple_GET_SIZE(r.get());
  int num_forward_inputs = is_variable_input.size();
  // Returning too many results is ok, but only as long as they're all None.
  // Truncate the result tuple in that case.
  if (num_outputs > num_forward_inputs) {
    bool all_none = true;
    for (const auto i : c10::irange(num_forward_inputs, num_outputs)) {
      all_none &= PyTuple_GET_ITEM(r.get(), i) == Py_None;
    }
    if (all_none) {
      num_outputs = num_forward_inputs;
      // PyTuple_GetSlice returns a new reference; THPObjectPtr takes it over.
      r = PyTuple_GetSlice(r.get(), 0, num_forward_inputs);
      if (!r)
        throw_python_error();
    }
  }

  // Now the number of gradients should match
  if (num_outputs != num_forward_inputs) {
    std::string msg("function ");
    msg += name() + " returned an incorrect number of gradients (expected ";
    msg += std::to_string(num_forward_inputs) + ", got ";
    msg += std::to_string(num_outputs) + ")";
    throw std::runtime_error(msg);
  }

  // Massage the Python results tuple back into a C++ variable_list
  variable_list results;
  results.reserve(num_outputs);
  for (int i = 0; i != num_outputs; ++i) {
    PyObject* output = PyTuple_GET_ITEM(r.get(), i);
    bool was_variable = is_variable_input[i];
    if (!was_variable) {
      // Positions that were not Variables in forward may only get None back.
      if (output != Py_None) {
        std::string msg("function ");
        msg += name() + " returned a gradient different than None at position ";
        msg += std::to_string(i + 1) +
            ", but the corresponding forward input was not a Variable";
        throw std::runtime_error(msg);
      }
      continue;
    }
    if (output == Py_None) {
      // None maps to an undefined Variable (no gradient flows).
      results.emplace_back();
    } else {
      if (!THPVariable_Check(output)) {
        std::string msg("expected Variable or None (got ");
        msg += THPUtils_typename(output);
        msg += ")";
        throw std::runtime_error(msg);
      }
      results.emplace_back(THPVariable_Unpack(output));
    }
  }

  return results;
}
171 | |
172 | auto PyNode::is_traceable() -> bool { |
173 | pybind11::gil_scoped_acquire gil; |
174 | THPObjectPtr forward_class{PyObject_GetAttrString(obj, "_forward_cls" )}; |
175 | if (!forward_class) |
176 | throw_python_error(); |
177 | THPObjectPtr traceable_py_bool{ |
178 | PyObject_GetAttrString(forward_class, "is_traceable" )}; |
179 | if (!traceable_py_bool) |
180 | throw_python_error(); |
181 | return traceable_py_bool == Py_True; |
182 | } |
183 | |
184 | auto PyNode::release_variables() -> void { |
185 | // This function is called as part of the Node destructor! |
186 | // Since this object might be kept alive by C++, it is possible |
187 | // that the python interpreter is already dead here. In that case |
188 | // we just leak the saved objects. |
189 | if (Py_IsInitialized()) { |
190 | pybind11::gil_scoped_acquire gil; |
191 | auto f = (THPFunction*)obj; |
192 | f->saved_variables.clear(); |
193 | f->has_freed_buffers = 1; |
194 | } |
195 | } |
196 | |
197 | auto PyNode::name() const -> std::string { |
198 | pybind11::gil_scoped_acquire gil; |
199 | auto f = (THPFunction*)obj; |
200 | auto name = std::string(Py_TYPE(f)->tp_name); |
201 | return name; |
202 | } |
203 | |
204 | } // namespace autograd |
205 | } // namespace torch |
206 | |
207 | // Traverse and clear are required for supporting Python's GC cycle handling. |
// tp_traverse: report every PyObject reachable from this THPFunction to the
// cyclic GC. Note Py_VISIT may return early from this function.
static int THPFunction_traverse(THPFunction* self, visitproc visit, void* arg) {
  // cdata could be null if the PyNode has already gone out of scope
  // by the time we're GC'ing this THPFunction (e.g., the user saved grad_fn
  // only).
  //
  // TODO: I'm not really sure if we're actually obligated to traverse PyObject
  // that is stored in PyNode, since we don't really own that C++ object.
  if (auto cdata = self->cdata.lock()) {
    // Only Python-implemented hooks hold Python state (their `dict`);
    // the dynamic_casts filter out purely-C++ hooks.
    for (const auto& hook : cdata->tensor_pre_hooks()) {
      if (auto pyhook = dynamic_cast<PyFunctionTensorPreHook*>(hook.get())) {
        Py_VISIT(pyhook->dict);
      }
    }
    // See NOTE [retains_grad_hook PyObject traversal]
    for (const auto& pair : cdata->retains_grad_hooks()) {
      if (auto pyhook =
              dynamic_cast<PyFunctionTensorPreHook*>(pair.second.get())) {
        Py_VISIT(pyhook->dict);
      }
    }
    for (const auto& hook : cdata->pre_hooks()) {
      if (auto pyhook = dynamic_cast<PyFunctionPreHook*>(hook.get())) {
        Py_VISIT(pyhook->dict);
      }
    }
    for (const auto& hook : cdata->post_hooks()) {
      if (auto pyhook = dynamic_cast<PyFunctionPostHook*>(hook.get())) {
        Py_VISIT(pyhook->dict);
      }
    }
  }
  // PyObjects held directly by this THPFunction.
  Py_VISIT(self->to_save);
  Py_VISIT(self->non_differentiable);
  Py_VISIT(self->dirty_tensors);
  Py_VISIT(self->saved_for_forward);
  return 0;
}
245 | |
// tp_clear: drop every reference this THPFunction holds so the cyclic GC can
// break reference cycles. Also called directly from THPFunction_dealloc.
// Returns 0 per the CPython GC protocol.
static int THPFunction_clear(THPFunction* self) {
  // Note that the cdata might not be expired yet in the case where this
  // object is part of a cycle and the GC happens to tp_clear this PyObject
  // before the other ones that trigger the de-allocation of the cdata

  Py_CLEAR(self->needs_input_grad);

  Py_CLEAR(self->to_save);
  Py_CLEAR(self->non_differentiable);
  Py_CLEAR(self->dirty_tensors);
  Py_CLEAR(self->saved_for_forward);

  // Release the C++-side containers too (these may transitively keep Python
  // objects alive, e.g. through saved variables).
  self->output_info.clear();
  self->input_info.clear();
  self->saved_variables.clear();
  self->is_variable_input.clear();

  return 0;
}
265 | |
// tp_dealloc: untrack from GC, drop all references, run the destructors of
// the placement-new'd C++ members (see THPFunction_new), then free the
// object memory.
static void THPFunction_dealloc(THPFunction* self) {
  // Why is this guaranteed to be true? Suppose that self->cdata is non-null
  // (otherwise the condition is trivially true). Then there is a PyNode
  // which contains an owning reference to this object. But we are only
  // allowed to clear if all owning references are gone! Contradiction.
  //
  // However, note that THPFunction_clear is typically called in the shared_ptr
  // destructor of PyNode; in that case, per
  // https://cplusplus.github.io/LWG/lwg-active.html#2751 it's not currently
  // specified in the standard that this is guaranteed. If you see this
  // assert triggering in the wild, feel free to comment it out. They're
  // likely to standardize that you ARE guaranteed to see the weak pointers
  // as expired in the destructor in the future, so we'll keep this for now.
  TORCH_INTERNAL_ASSERT(self->cdata.expired());

  PyObject_GC_UnTrack(self);
  THPFunction_clear(self);
  // Explicit destructor calls mirror the placement-new in THPFunction_new;
  // tp_free only releases raw memory and runs no C++ destructors.
  self->cdata.~weak_ptr<PyNode>();
  self->output_info.~vector();
  self->input_info.~vector();
  self->saved_variables.~vector();
  self->is_variable_input.~vector();
  Py_TYPE(self)->tp_free((PyObject*)self);
}
290 | |
291 | PyObject* THPFunction_new( |
292 | PyTypeObject* type, |
293 | PyObject* args, |
294 | PyObject* kwargs) { |
295 | PyObject* obj = type->tp_alloc(type, 0); |
296 | if (!obj) |
297 | return nullptr; |
298 | // Python zero-initializes the object memory, so there's no need to initialize |
299 | // most fields |
300 | THPFunction* self = (THPFunction*)obj; |
301 | // Setup the PyNode later; we can't keep it live here |
302 | new (&self->cdata) std::weak_ptr<PyNode>(); |
303 | new (&self->output_info) std::vector<VariableInfo>(); |
304 | new (&self->input_info) std::vector<VariableInfo>(); |
305 | new (&self->saved_variables) std::vector<SavedVariable>(); |
306 | new (&self->is_variable_input) std::vector<bool>(); |
307 | self->materialize_grads = true; |
308 | return obj; |
309 | } |
310 | |
311 | //////////////////////////////////////////////////////////////////////////////// |
312 | // Forward |
313 | //////////////////////////////////////////////////////////////////////////////// |
314 | |
315 | // Bump the counters of all recorded dirty input tensors, adding each of them |
316 | // into dirty_inputs. Also does some sanity checking. |
317 | static std::unordered_set<at::TensorImpl*> _mark_dirty(THPFunction* self) { |
318 | // Increase versions of modified tensors |
319 | std::unordered_set<at::TensorImpl*> dirty_inputs; |
320 | if (!self->dirty_tensors) |
321 | return dirty_inputs; |
322 | |
323 | THPFunction_assert( |
324 | PyTuple_Check(self->dirty_tensors), |
325 | "autograd " |
326 | "internal error: dirty_tensors attribute is expected to be a tuple " |
327 | "but is %s" , |
328 | THPUtils_typename(self->dirty_tensors)); |
329 | Py_ssize_t num_dirty = PyTuple_GET_SIZE(self->dirty_tensors); |
330 | dirty_inputs.reserve(num_dirty); |
331 | for (const auto i : c10::irange(num_dirty)) { |
332 | PyObject* obj = PyTuple_GET_ITEM(self->dirty_tensors, i); |
333 | THPFunction_assert( |
334 | THPVariable_Check(obj), |
335 | "mark_dirty can " |
336 | "only accept variables, but argument %d is of type %s" , |
337 | i, |
338 | THPUtils_typename(obj)); |
339 | |
340 | const auto& tensor = THPVariable_Unpack(obj); |
341 | dirty_inputs.insert(tensor.unsafeGetTensorImpl()); |
342 | torch::autograd::impl::bump_version(tensor); |
343 | } |
344 | // We're not going to ever need this so let's remove references now |
345 | Py_CLEAR(self->dirty_tensors); |
346 | return dirty_inputs; |
347 | } |
348 | |
// Forward declaration (defined further down): consumes ctx.non_differentiable
// and returns the TensorImpls of the tensors marked non-differentiable.
static std::unordered_set<at::TensorImpl*> _parse_non_differentiable(
    THPFunction* self);
351 | |
352 | // Given a Python tuple of raw output tensors (raw_output), set each of |
353 | // the corresponding entries in a different Python tuple (outputs) with |
354 | // these tensors wrapped with variables. We save the gradient function (self) |
355 | // to the variable if the output requires grad. |
356 | // |
357 | // There is a considerable amount of complexity to handle if the operation |
358 | // that produced these output tensors is inplace. A mapping of *input* |
359 | // tensors to variables (t2var) is used to test if this occurred, and |
360 | // the set of dirty tensors (dirty_inputs) is used to figure out what to |
361 | // do in this case. After this method is run, t2var is extended with |
362 | // mappings for output tensors as well. |
static void _wrap_outputs(
    const std::shared_ptr<PyNode>& cdata,
    THPFunction* self,
    const variable_list& input_vars,
    PyObject* raw_output,
    PyObject* outputs,
    bool is_executable) {
  // Only attach a grad_fn to outputs when the graph is actually being built.
  auto cdata_if_executable = is_executable ? cdata : nullptr;
  Py_ssize_t num_outputs = PyTuple_GET_SIZE(raw_output);
  if (is_executable) {
    self->output_info.clear();
    self->output_info.reserve(num_outputs);
  }

  // Both helpers consume (clear) the corresponding ctx attributes.
  auto non_differentiable = _parse_non_differentiable(self);
  auto dirty_inputs = _mark_dirty(self);

  // Tensor outputs only; non-tensor positions become nullopt so indices stay
  // aligned with the raw output tuple.
  std::vector<c10::optional<Variable>> raw_output_vars;
  raw_output_vars.reserve(num_outputs);
  for (const auto i : c10::irange(num_outputs)) {
    PyObject* obj = PyTuple_GET_ITEM(raw_output, i);
    // Only process tensors as outputs for autograd purposes.
    if (THPVariable_Check(obj)) {
      raw_output_vars.emplace_back(THPVariable_Unpack(obj));
    } else {
      raw_output_vars.emplace_back();
    }
  }

  // Forward-mode AD callback: forwards (inputs, grad_inputs) to the Python
  // "apply_jvp" method and converts the result back to a variable_list.
  // Captures `self` raw; assumes the THPFunction outlives the jvp call —
  // NOTE(review): confirm lifetime guarantees against the fw-AD machinery.
  _jvp_fn_t jvp_user_function = [self](
                                    variable_list inputs,
                                    variable_list grad_inputs) {
    pybind11::gil_scoped_acquire gil;

    // Massage a C++ variable_list into a Python arguments tuple
    // Making sure to introduce the proper None for non-Tensor inputs
    auto num_inputs = self->is_variable_input.size();
    THPObjectPtr pyInputs(PyTuple_New(num_inputs));
    if (!pyInputs)
      throw_python_error();
    // grad_inputs/inputs are indexed only over tensor positions, so track a
    // separate running index for them.
    int64_t variable_idx = 0;
    for (const auto i : c10::irange(num_inputs)) {
      PyObject* input = nullptr;
      if (self->is_variable_input[i]) {
        // Materialize an undefined tangent as zeros unless the user opted
        // out or the input dtype is not differentiable anyway.
        if (grad_inputs[variable_idx].defined() || !self->materialize_grads ||
            !isDifferentiableType(inputs[variable_idx].scalar_type())) {
          input = THPVariable_Wrap(grad_inputs[variable_idx]);
        } else {
          input = THPVariable_Wrap(at::zeros_like(inputs[variable_idx]));
        }
        if (!input) {
          throw_python_error();
        }
        variable_idx++;
      } else {
        Py_INCREF(Py_None);
        input = Py_None;
      }
      // PyTuple_SET_ITEM steals the reference held in `input`.
      PyTuple_SET_ITEM(pyInputs.get(), i, input);
    }

    THPObjectPtr apply_jvp_fn(
        PyObject_GetAttrString((PyObject*)self, "apply_jvp"));
    if (!apply_jvp_fn)
      throw_python_error();
    THPObjectPtr r(PyObject_CallObject(apply_jvp_fn, pyInputs.get()));
    if (!r)
      throw_python_error();
    ensure_tuple(r);

    // Massage the Python results tuple back into a C++ variable_list
    // Don't do any check on the number of results here as
    // it is handled by the caller
    const int num_outputs = PyTuple_GET_SIZE(r.get());
    variable_list results;
    results.reserve(num_outputs);
    for (const auto i : c10::irange(num_outputs)) {
      PyObject* output = PyTuple_GET_ITEM(r.get(), i);
      if (output == Py_None) {
        results.emplace_back();
      } else {
        TORCH_CHECK(
            THPVariable_Check(output),
            "expected Variable or None (got ",
            THPUtils_typename(output),
            ") for grad output ",
            i,
            ".")
        results.emplace_back(THPVariable_Unpack(output));
      }
    }

    return results;
  };

  // Wrap only the tensor outputs.
  auto wrapped_outputs = _wrap_outputs(
      input_vars,
      non_differentiable,
      dirty_inputs,
      raw_output_vars,
      cdata_if_executable,
      std::move(jvp_user_function));

  // Write the (possibly re-wrapped) outputs back into the `outputs` tuple.
  for (const auto i : c10::irange(num_outputs)) {
    PyObject* obj = PyTuple_GetItem(raw_output, i);
    // Keep the non-tensor outputs as is.
    if (!THPVariable_Check(obj)) {
      if (is_executable) {
        // Empty placeholder keeps output_info index-aligned with outputs.
        self->output_info.emplace_back();
      }
      Py_INCREF(obj);
      PyTuple_SetItem(outputs, i, obj);
    } else {
      if (is_executable) {
        self->output_info.emplace_back(*wrapped_outputs[i]);
      }
      PyTuple_SetItem(outputs, i, THPVariable_Wrap(*wrapped_outputs[i]));
    }
  }
}
484 | |
485 | // Save any variables that requested by to_save |
486 | static void _save_variables( |
487 | const std::shared_ptr<PyNode>& cdata_ptr, |
488 | THPFunction* self) { |
489 | if (!self->to_save) |
490 | return; |
491 | |
492 | THPFunction_assert( |
493 | PyTuple_Check(self->to_save), |
494 | "autograd internal " |
495 | "error: to_save attribute is expected to be a tuple but is %s" , |
496 | THPUtils_typename(self->to_save)); |
497 | Py_ssize_t num_saved = PyTuple_GET_SIZE(self->to_save); |
498 | self->saved_variables.clear(); |
499 | self->saved_variables.reserve(num_saved); |
500 | for (const auto i : c10::irange(num_saved)) { |
501 | PyObject* obj = PyTuple_GET_ITEM(self->to_save, i); |
502 | if (obj == Py_None) { |
503 | self->saved_variables.emplace_back(); |
504 | continue; |
505 | } else if (THPVariable_Check(obj)) { |
506 | const auto& tensor = THPVariable_Unpack(obj); |
507 | bool is_output = tensor.grad_fn().get() == cdata_ptr.get(); |
508 | self->saved_variables.emplace_back(tensor, is_output); |
509 | } else { |
510 | throw torch::TypeError( |
511 | "save_for_backward can only save variables, but argument %ld is of " |
512 | "type %s" , |
513 | i, |
514 | Py_TYPE(obj)->tp_name); |
515 | } |
516 | } |
517 | // Free .to_save |
518 | Py_CLEAR(self->to_save); |
519 | } |
520 | |
521 | // Mark requires_grad = 0 on non-differentiable variables (as per |
522 | // non_differentiable) |
523 | static std::unordered_set<at::TensorImpl*> _parse_non_differentiable( |
524 | THPFunction* self) { |
525 | std::unordered_set<at::TensorImpl*> set; |
526 | if (!self->non_differentiable) |
527 | return set; |
528 | |
529 | THPFunction_assert( |
530 | PyTuple_Check(self->non_differentiable), |
531 | "autograd " |
532 | "internal error: non_differentiable attribute is expected to be a " |
533 | "tuple but is %s" , |
534 | THPUtils_typename(self->non_differentiable)); |
535 | Py_ssize_t num_nondiff = PyTuple_GET_SIZE(self->non_differentiable); |
536 | set.reserve(num_nondiff); |
537 | for (const auto i : c10::irange(num_nondiff)) { |
538 | PyObject* t = PyTuple_GET_ITEM(self->non_differentiable, i); |
539 | THPFunction_assert( |
540 | THPVariable_Check(t), |
541 | "mark_non_differentiable " |
542 | "only accepts variable arguments, but got %s" , |
543 | THPUtils_typename(t)); |
544 | set.insert(THPVariable_Unpack(t).unsafeGetTensorImpl()); |
545 | } |
546 | Py_CLEAR(self->non_differentiable); |
547 | return set; |
548 | } |
549 | |
// The forward() arguments, unpacked once by unpack_input().
struct UnpackedInput {
  // Owning tuple holding every positional argument as passed in.
  THPObjectPtr input_tuple;
  // Only the tensor (Variable) arguments, in order of appearance.
  variable_list input_vars;
};
554 | |
// Autograd-related flags computed from the forward() arguments by
// unpack_input().
struct InputFlags {
  // True when grad mode is enabled and at least one input requires grad.
  bool is_executable = false;
  // Edges to the inputs' grad accumulators/functions (empty if not executable).
  edge_list next_edges;
  // Python tuple of bools, one per positional argument (False for non-tensors).
  THPObjectPtr needs_input_grad;
  // Per-argument flag: was this positional argument a Variable?
  std::vector<bool> is_variable_input;
};
561 | |
562 | template <bool enforce_variables> |
563 | std::pair<UnpackedInput, InputFlags> unpack_input(PyObject* args) { |
564 | UnpackedInput unpacked; |
565 | InputFlags flags; |
566 | |
567 | auto num_args = PyTuple_GET_SIZE(args); |
568 | unpacked.input_tuple = PyTuple_New(num_args); |
569 | flags.needs_input_grad = PyTuple_New(num_args); |
570 | for (const auto i : c10::irange(num_args)) { |
571 | PyObject* arg = PyTuple_GET_ITEM(args, i); |
572 | |
573 | bool is_variable = THPVariable_Check(arg); |
574 | flags.is_variable_input.push_back(is_variable); |
575 | if (!is_variable) { |
576 | // TODO: remove this code path once Variable and Tensor are merged in |
577 | // Python |
578 | if (enforce_variables) { |
579 | THPUtils_setError( |
580 | "expected a Tensor argument, but got %s" , THPUtils_typename(arg)); |
581 | throw python_error(); |
582 | } |
583 | Py_INCREF(Py_False); |
584 | PyTuple_SET_ITEM(flags.needs_input_grad.get(), i, Py_False); |
585 | } else { |
586 | const auto& tensor = THPVariable_Unpack(arg); |
587 | unpacked.input_vars.push_back(tensor); |
588 | PyObject* needs_grad = tensor.requires_grad() ? Py_True : Py_False; |
589 | Py_INCREF(needs_grad); |
590 | PyTuple_SET_ITEM(flags.needs_input_grad.get(), i, needs_grad); |
591 | } |
592 | Py_INCREF(arg); |
593 | PyTuple_SET_ITEM(unpacked.input_tuple.get(), i, arg); |
594 | } |
595 | |
596 | flags.is_executable = |
597 | GradMode::is_enabled() && any_variable_requires_grad(unpacked.input_vars); |
598 | flags.next_edges = |
599 | (flags.is_executable ? collect_next_edges(unpacked.input_vars) |
600 | : edge_list()); |
601 | return std::make_pair(std::move(unpacked), std::move(flags)); |
602 | } |
603 | |
604 | // Given a prim::PythonOp node, _append_subgraph creates a subgraph such that: |
605 | // (1) It has the same inputs as the prim::PythonOp node |
606 | // (2) The intermediate nodes used in the PythonOp are cloned and stored in the |
607 | // subgraph (3) trace_outputs stores the Value* objects, before a new trace |
608 | // value is assigned by the prim::PythonOp node and helps to eventually route |
609 | // the outputs of the subgraph correctly This newly created subgraph is then |
610 | // added to the prim::PythonOp node as a subgraph attribute |
static void _append_subgraph(
    torch::jit::Node* node,
    torch::jit::Graph* graph,
    std::vector<torch::jit::Value*> trace_outputs,
    bool unpack_output) {
  using Value = torch::jit::Value;
  // Attach a fresh, empty subgraph attribute to the PythonOp node.
  node->g_(
      torch::jit::attr::Subgraph,
      std::make_shared<torch::jit::Graph>(graph->current_scope()));
  auto subgraph = node->g(torch::jit::attr::Subgraph);

  // Maps outer-graph values to their clones inside the subgraph.
  std::unordered_map<Value*, Value*> value_map;
  auto value_map_func = [&](Value* v) { return value_map.at(v); };
  // Mirror the PythonOp's inputs as subgraph inputs (same metadata).
  for (size_t i = 0; i < node->inputs().size(); ++i) {
    auto subgraph_input = subgraph->addInput();
    subgraph_input->copyMetadata(node->inputs().at(i));
    value_map[node->inputs().at(i)] = subgraph_input;
  }
  // Find node position in owning block, all subsequent nodes after are added
  // to subgraph
  auto owning_block = node->owningBlock();
  auto it = std::find(
      owning_block->nodes().begin(), owning_block->nodes().end(), node);
  // Skip TupleUnpack node if created
  if (!unpack_output) {
    it++;
  }
  // Clone every node that follows into the subgraph, remapping its inputs
  // through value_map. (The loop header's it++ skips `node` itself.)
  for (it++; it != owning_block->nodes().end(); ++it) {
    torch::jit::Node* node = *it;
    auto* clone_node =
        subgraph->insertNode(subgraph->createClone(node, value_map_func));
    for (size_t i = 0; i < node->outputs().size(); ++i) {
      value_map[node->outputs()[i]] = clone_node->outputs()[i];
      // Values recorded in trace_outputs become outputs of the subgraph.
      auto trace_it = std::find(
          trace_outputs.begin(), trace_outputs.end(), node->outputs()[i]);
      if (trace_it != trace_outputs.end()) {
        subgraph->registerOutput(clone_node->outputs()[i]);
      }
    }
  }
}
652 | |
653 | static torch::jit::Node* _trace_pre_record( |
654 | PyObject* op_obj, |
655 | PyObject* input_objects, |
656 | const variable_list& input_vars) { |
657 | if (!jit::tracer::isTracing()) { |
658 | return nullptr; |
659 | } |
660 | |
661 | // Save scalar args and the calling convention |
662 | auto num_args = PyTuple_GET_SIZE(input_objects); |
663 | pyobj_list scalar_args; |
664 | std::string arg_types; |
665 | arg_types.reserve(num_args); |
666 | scalar_args.reserve(num_args); |
667 | for (const auto i : c10::irange(num_args)) { |
668 | PyObject* arg_object = PyTuple_GET_ITEM(input_objects, i); |
669 | if (THPVariable_Check(arg_object)) { |
670 | arg_types.push_back('d'); |
671 | } else { |
672 | arg_types.push_back('c'); |
673 | Py_INCREF(arg_object); |
674 | scalar_args.emplace_back(arg_object); |
675 | } |
676 | } |
677 | |
678 | Py_INCREF(op_obj); |
679 | auto pyobj = THPObjectPtr(op_obj); |
680 | return jit::tracer::preRecordPythonTrace( |
681 | std::move(pyobj), arg_types, input_vars, std::move(scalar_args)); |
682 | } |
683 | |
static void _trace_post_record(
    torch::jit::Node* node,
    PyObject* op_obj,
    const variable_list& input_vars,
    PyObject* output_objects,
    bool is_inplace,
    bool unpack_output) {
  if (!jit::tracer::isTracing()) {
    return;
  }

  // Annotate the traced node with inplace-ness and the defining Python module.
  node->i_(jit::attr::inplace, is_inplace);
  if (PyObject* module_name = PyDict_GetItemString(
          ((PyTypeObject*)op_obj)->tp_dict, "__module__")) {
    if (auto ptr = PyUnicode_AsUTF8(module_name)) {
      node->s_(jit::attr::module, std::string(ptr));
    }
  }

  // Isolate C variable ptrs in a vector
  int num_outputs = PyTuple_GET_SIZE(output_objects);
  auto graph = node->owningGraph();
  node->addOutput();
  // Keep a handle to the PythonOp itself; `node` may be re-pointed at the
  // TupleUnpack node below.
  auto old_node = node;
  if (!unpack_output) {
    std::vector<at::TypePtr> tuple_values(num_outputs, at::TensorType::get());
    auto tuple_type = at::TupleType::create(std::move(tuple_values));
    // Original type is tuple of tensors "without" element type and shape.
    // The missed parts will be added below.
    node->output()->setType(std::move(tuple_type));
    auto unpacked = graph->createTupleUnpack(node->output())->insertAfter(node);
    node = unpacked;
  }

  // Re-point the tracer at the new output values, remembering the previous
  // traces so they can be routed into the ONNX subgraph if needed.
  std::vector<torch::jit::Value*> trace_outputs;
  for (const auto i : c10::irange(num_outputs)) {
    PyObject* obj = PyTuple_GET_ITEM(output_objects, i);
    if (THPVariable_Check(obj)) {
      auto value = node->outputs()[i];
      const auto& tensor = THPVariable_Unpack(obj);
      if (tensor.defined()) {
        value->inferTypeFrom(tensor);
        trace_outputs.push_back(jit::tracer::getValueTrace(tensor));
        jit::tracer::setValueTrace(tensor, value);
      }
    }
  }
  // During ONNX export, the op's internals are captured as a subgraph
  // attribute on the PythonOp node.
  py::bool_ is_in_onnx_export =
      py::module::import("torch.onnx.__init__").attr("is_in_onnx_export");
  if (py::cast<bool>(is_in_onnx_export)) {
    _append_subgraph(old_node, graph, std::move(trace_outputs), unpack_output);
  }

  // If TupleUnpack operator is created, we copy its output type back
  // to the original tuple type.
  if (!unpack_output) {
    std::vector<at::TypePtr> new_tuple_values;
    for (const auto i : c10::irange(num_outputs)) {
      auto ptr = node->outputs()[i]->type();
      new_tuple_values.push_back(ptr);
    }
    auto tuple_type = at::TupleType::create(std::move(new_tuple_values));
    // The i-th tuple element receives a new tensor type with element type and
    // shape.
    old_node->output()->setType(std::move(tuple_type));
  }
}
751 | |
// Post-forward processing: wraps the raw Python outputs of forward() into
// autograd-tracked variables, records input metadata on grad_fn, runs the
// tracing post-record step, saves variables for backward, and returns the
// object handed back to Python (a single value if forward() returned one,
// otherwise a tuple).
PyObject* process_outputs(
    PyObject* op_obj,
    const std::shared_ptr<PyNode>& cdata,
    THPFunction* grad_fn,
    const UnpackedInput& unpacked,
    PyObject* inputs,
    THPObjectPtr&& raw_output,
    bool is_executable,
    torch::jit::Node* node) {
  // True when forward() returned a single object that ensure_tuple wrapped.
  bool unpack_output = ensure_tuple(raw_output);

  auto num_outputs = PyTuple_GET_SIZE(raw_output.get());

  THPObjectPtr outputs(PyTuple_New(num_outputs));
  if (!outputs)
    throw python_error();

  cdata->clear_input_metadata();

  // Record type, device, and size information about inputs
  if (is_executable) {
    grad_fn->input_info.clear();
    grad_fn->input_info.reserve(unpacked.input_vars.size());
    for (auto& var : unpacked.input_vars) {
      grad_fn->input_info.emplace_back(var);
    }
  }

  // Read the inplace flag before _wrap_outputs, which consumes (clears)
  // grad_fn->dirty_tensors.
  bool is_inplace = static_cast<bool>(grad_fn->dirty_tensors);
  _wrap_outputs(
      cdata, grad_fn, unpacked.input_vars, raw_output, outputs, is_executable);
  _trace_post_record(
      node, op_obj, unpacked.input_vars, outputs, is_inplace, unpack_output);

  // It is important that creating the SavedVariables happen after the output
  // wrapping as the outputs must have their grad_fn/fw_grad properly set
  // before we save them.
  if (is_executable) {
    _save_variables(cdata, grad_fn);
  } else {
    // Remove unnecessary attributes
    Py_XDECREF(grad_fn->to_save);
    grad_fn->to_save = nullptr;
    Py_XDECREF(grad_fn->non_differentiable);
    grad_fn->non_differentiable = nullptr;
  }

  Py_XDECREF(grad_fn->saved_for_forward);
  grad_fn->saved_for_forward = nullptr;

  // Unpack the output, unless .forward() returned a tuple
  if (unpack_output) {
    PyObject* output = PyTuple_GET_ITEM(outputs.get(), 0);
    Py_INCREF(output);
    return output;
  }

  return outputs.release();
}
811 | |
812 | PyObject* THPFunction_name(PyObject* self, PyObject* noargs) { |
813 | HANDLE_TH_ERRORS |
814 | auto cdata = ((THPFunction*)self)->cdata.lock(); |
815 | TORCH_CHECK( |
816 | cdata, |
817 | "Attribute 'name' is invalid for this instance of _C._FunctionBase. " |
818 | "Accessing this attribute directly on an instance of autograd.Function is a legacy " |
819 | "access pattern that is no longer supported. For examples on how to use new-style " |
820 | "autograd functions, see " |
821 | "https://pytorch.org/docs/stable/autograd.html#torch.autograd.Function " ); |
822 | return THPUtils_packString(cdata->name()); |
823 | END_HANDLE_TH_ERRORS |
824 | } |
825 | |
826 | PyObject* THPFunction_maybe_clear_saved_tensors( |
827 | PyObject* self, |
828 | PyObject* noargs) { |
829 | HANDLE_TH_ERRORS; |
830 | auto cdata = ((THPFunction*)self)->cdata.lock(); |
831 | if (!get_current_graph_task_keep_graph()) { |
832 | cdata->release_variables(); |
833 | } |
834 | Py_RETURN_NONE; |
835 | END_HANDLE_TH_ERRORS |
836 | } |
837 | |
838 | namespace { |
839 | |
840 | THPObjectPtr make_ctx_input_tuple( |
841 | THPFunction* ctx, |
842 | const UnpackedInput& unpacked_input, |
843 | int64_t num_args) { |
844 | THPObjectPtr ctx_input_tuple(PyTuple_New(num_args + 1)); |
845 | if (!ctx_input_tuple) |
846 | return {}; |
847 | Py_INCREF(ctx); |
848 | PyTuple_SET_ITEM(ctx_input_tuple.get(), 0, (PyObject*)ctx); |
849 | for (const auto i : c10::irange(num_args)) { |
850 | PyObject* arg = PyTuple_GET_ITEM(unpacked_input.input_tuple.get(), i); |
851 | Py_INCREF(arg); |
852 | PyTuple_SET_ITEM(ctx_input_tuple.get(), i + 1, arg); |
853 | } |
854 | return ctx_input_tuple; |
855 | } |
856 | |
857 | THPObjectPtr make_ctx_input_output_tuple( |
858 | THPFunction* ctx, |
859 | UnpackedInput& unpacked_input, |
860 | PyObject* output) { |
861 | THPObjectPtr result(PyTuple_New(3)); |
862 | if (!result) |
863 | return {}; |
864 | Py_INCREF(ctx); |
865 | Py_INCREF(unpacked_input.input_tuple.get()); |
866 | Py_INCREF(output); |
867 | PyTuple_SET_ITEM(result.get(), 0, (PyObject*)ctx); |
868 | PyTuple_SET_ITEM(result.get(), 1, unpacked_input.input_tuple.get()); |
869 | PyTuple_SET_ITEM(result.get(), 2, output); |
870 | return result; |
871 | } |
872 | |
873 | } // namespace |
874 | |
875 | static PyObject* THPFunction_setup_context = nullptr; |
876 | |
877 | static PyObject* get_base_setup_context() { |
878 | if (THPFunction_setup_context != nullptr) { |
879 | return THPFunction_setup_context; |
880 | } |
881 | |
882 | auto module = THPObjectPtr(PyImport_ImportModule("torch.autograd.function" )); |
883 | if (!module) |
884 | return nullptr; |
885 | |
886 | auto function = |
887 | THPObjectPtr(PyObject_GetAttrString(module, "_SingleLevelFunction" )); |
888 | if (!function) |
889 | return nullptr; |
890 | |
891 | // setup_context gets "leaked" - we return a new reference and hold onto it |
892 | // forever. |
893 | auto setup_context = PyObject_GetAttrString(function, "setup_context" ); |
894 | if (!setup_context) |
895 | return nullptr; |
896 | THPFunction_setup_context = setup_context; |
897 | return THPFunction_setup_context; |
898 | } |
899 | |
// Implements `apply` for torch.autograd.Function subclasses: unpacks the
// Python arguments, builds the backward graph node (PyNode), runs the
// user's forward() (with grad modes disabled), optionally runs an
// overridden setup_context, and wraps/records the outputs via
// process_outputs. `cls` is the Function subclass; `inputs` is the
// positional-args tuple.
PyObject* THPFunction_apply(PyObject* cls, PyObject* inputs) {
  HANDLE_TH_ERRORS

  // save a local copy of seq_id before it gets incremented
  int seq_id = at::sequence_number::peek();
  auto info_pair = unpack_input<false>(inputs);
  UnpackedInput& unpacked_input = info_pair.first;
  InputFlags& input_info = info_pair.second;

  // Call record function after all the inputs have been decoded, but
  // before context has been allocated.
  RECORD_FUNCTION(
      ((PyTypeObject*)cls)->tp_name,
      std::vector<c10::IValue>(
          unpacked_input.input_vars.begin(), unpacked_input.input_vars.end()),
      seq_id);

  const auto& functorch_tls = at::functorch::functorchTLSAccessor();
  if (functorch_tls) {
    // autograd.Function support for functorch is handled in Python.
    // If we have gotten here, then either we are dealing with a
    // torch.autograd.function._SingleLevelFunction, or something in
    // the implementation went wrong.
    // The following code is useful for debugging when something goes wrong
    // because it'll raise a loud error (instead of being silently incorrect).
    functorch_tls->checkSupportsSingleLevelAutogradFunction();
  }

  // Instantiate the per-call backward context object (the "ctx" the user
  // sees) from the class's generated _backward_cls.
  THPObjectPtr backward_cls(PyObject_GetAttrString(cls, "_backward_cls"));
  if (!backward_cls)
    return nullptr;
  THPObjectPtr ctx_obj(PyObject_CallFunctionObjArgs(backward_cls, nullptr));
  if (!ctx_obj)
    return nullptr;
  THPFunction* ctx = (THPFunction*)ctx_obj.get();

  // The PyNode takes ownership of ctx_obj; ctx keeps a weak back-reference.
  auto cdata =
      std::shared_ptr<PyNode>(new PyNode(std::move(ctx_obj)), deleteNode);
  ctx->cdata = cdata;

  // Record input nodes if tracing
  auto* node = _trace_pre_record(cls, inputs, unpacked_input.input_vars);

  // Initialize backward function (and ctx)
  bool is_executable = input_info.is_executable;
  cdata->set_next_edges(std::move(input_info.next_edges));
  ctx->needs_input_grad = input_info.needs_input_grad.release();
  ctx->is_variable_input = std::move(input_info.is_variable_input);

  // autograd.Function may optionally override a setup_context staticmethod.
  // In this case, autograd.Function.forward does NOT accept a ctx object.
  // Determine if this is the case.
  auto cls_setup_context =
      THPObjectPtr(PyObject_GetAttrString(cls, "setup_context"));
  if (!cls_setup_context) {
    return nullptr;
  }
  auto orig_setup_context = get_base_setup_context();
  if (!orig_setup_context) {
    return nullptr;
  }
  // Overridden iff the class attribute is a different object from the base
  // _SingleLevelFunction.setup_context (identity comparison).
  auto overridden_setup_context = cls_setup_context.get() != orig_setup_context;

  auto num_args = PyTuple_GET_SIZE(inputs);

  // Call forward
  THPObjectPtr output;
  {
    // forward() must not itself record autograd/forward-grad history;
    // graph construction is handled by process_outputs below.
    AutoGradMode grad_mode(false);
    at::AutoFwGradMode fw_grad_mode(false);
    THPObjectPtr forward_fn(PyObject_GetAttrString(cls, "forward"));
    if (!forward_fn)
      return nullptr;
    if (overridden_setup_context) {
      // call forward followed by setup_context
      output = PyObject_CallObject(forward_fn, unpacked_input.input_tuple);
      if (!output) {
        return nullptr;
      }
      // signature is setup_context(ctx, inputs, output)
      auto ctx_input_output_tuple =
          make_ctx_input_output_tuple(ctx, unpacked_input, output);
      if (!ctx_input_output_tuple) {
        return nullptr;
      }
      THPObjectPtr setup_context_fn(
          PyObject_GetAttrString(cls, "setup_context"));
      auto result =
          PyObject_CallObject(setup_context_fn, ctx_input_output_tuple);
      if (!result) {
        return nullptr;
      }
    } else {
      // call forward with (ctx, *args) — the legacy calling convention
      auto ctx_input_tuple =
          make_ctx_input_tuple(ctx, unpacked_input, num_args);
      if (!ctx_input_tuple) {
        return nullptr;
      }
      output = PyObject_CallObject(forward_fn, ctx_input_tuple);
    }
    if (!output)
      return nullptr;
  }

  return process_outputs(
      cls,
      cdata,
      ctx,
      unpacked_input,
      inputs,
      std::move(output),
      is_executable,
      node);
  END_HANDLE_TH_ERRORS
}
1016 | |
1017 | //////////////////////////////////////////////////////////////////////////////// |
1018 | // Other methods / attributes |
1019 | //////////////////////////////////////////////////////////////////////////////// |
1020 | |
1021 | PyObject* THPFunction__register_hook_dict(PyObject* _self, PyObject* _var) { |
1022 | HANDLE_TH_ERRORS |
1023 | THPUtils_assert( |
1024 | THPVariable_Check(_var), "_register_hook_dict expected a Tensor" ); |
1025 | THPVariable* var = reinterpret_cast<THPVariable*>(_var); |
1026 | const auto& tensor = THPVariable_Unpack(var); |
1027 | std::unique_ptr<FunctionPreHook> hook( |
1028 | new PyFunctionTensorPreHook(var->backward_hooks, tensor.output_nr())); |
1029 | auto self = (THPFunction*)_self; |
1030 | auto cdata = self->cdata.lock(); |
1031 | TORCH_CHECK( |
1032 | cdata, |
1033 | "Attribute '_register_hook_dict' is invalid for this instance of _C._FunctionBase. " |
1034 | "Accessing this attribute directly on an instance of autograd.Function is a legacy " |
1035 | "access pattern that is no longer supported. For examples on how to use new-style " |
1036 | "autograd functions, see " |
1037 | "https://pytorch.org/docs/stable/autograd.html#torch.autograd.Function " ); |
1038 | cdata->add_tensor_pre_hook(std::move(hook)); |
1039 | Py_RETURN_NONE; |
1040 | END_HANDLE_TH_ERRORS |
1041 | } |
1042 | |
1043 | PyObject* THPFunction_register_hook(PyObject* _self, PyObject* hook) { |
1044 | HANDLE_TH_ERRORS |
1045 | auto self = (THPFunction*)_self; |
1046 | auto cdata = self->cdata.lock(); |
1047 | TORCH_CHECK( |
1048 | cdata, |
1049 | "Attribute 'register_hook' is invalid for this instance of _C._FunctionBase. " |
1050 | "Accessing this attribute directly on an instance of autograd.Function is a legacy " |
1051 | "access pattern that is no longer supported. For examples on how to use new-style " |
1052 | "autograd functions, see " |
1053 | "https://pytorch.org/docs/stable/autograd.html#torch.autograd.Function " ); |
1054 | return torch::autograd::registerFunctionHook(*cdata, hook); |
1055 | END_HANDLE_TH_ERRORS |
1056 | } |
1057 | |
1058 | PyObject* THPFunction_register_prehook(PyObject* _self, PyObject* hook) { |
1059 | HANDLE_TH_ERRORS |
1060 | auto self = (THPFunction*)_self; |
1061 | auto cdata = self->cdata.lock(); |
1062 | TORCH_CHECK( |
1063 | cdata, |
1064 | "Attribute 'register_prehook' is invalid for this instance of _C._FunctionBase. " |
1065 | "Accessing this attribute directly on an instance of autograd.Function is a legacy " |
1066 | "access pattern that is no longer supported. For examples on how to use new-style " |
1067 | "autograd functions, see " |
1068 | "https://pytorch.org/docs/stable/autograd.html#torch.autograd.Function " ); |
1069 | return torch::autograd::registerFunctionPreHook(*cdata, hook); |
1070 | END_HANDLE_TH_ERRORS |
1071 | } |
1072 | |
1073 | int THPFunction_set_materialize_grads( |
1074 | THPFunction* self, |
1075 | PyObject* value, |
1076 | void* unused) { |
1077 | HANDLE_TH_ERRORS |
1078 | if (!PyBool_Check(value)) { |
1079 | THPUtils_invalidArguments( |
1080 | value, nullptr, "set_materialize_grads" , 1, "(bool)" ); |
1081 | return -1; |
1082 | } |
1083 | self->materialize_grads = (value == Py_True); |
1084 | return 0; |
1085 | END_HANDLE_TH_ERRORS_RET(-1) |
1086 | } |
1087 | |
1088 | static PyObject* unpack_saved_variables( |
1089 | THPFunction* self, |
1090 | const std::function<PyObject*(const Variable&)>& unpack_fn) { |
1091 | THPUtils_assert(!self->has_freed_buffers, ERR_BACKWARD_TWICE); |
1092 | auto& saved_variables = self->saved_variables; |
1093 | if (saved_variables.empty()) |
1094 | return PyTuple_New(0); |
1095 | |
1096 | int num_saved = saved_variables.size(); |
1097 | THPObjectPtr saved(PyTuple_New(num_saved)); |
1098 | if (!saved) |
1099 | return nullptr; |
1100 | auto saved_for = self->cdata.lock(); |
1101 | // This is really a true assert, because we've already tested for the |
1102 | // self->has_freed_buffers case at the beginning of this function: |
1103 | // buffers are freed when PyNode dies; if the buffers are not freed, |
1104 | // PyNode must be live. (Note that the buffers could be freed |
1105 | // even though the PyNode is live, but that doesn't matter here |
1106 | // because we will never hit this line of code if the buffers are freed-- |
1107 | // and in any case saved_for will be non-NULL.) |
1108 | TORCH_INTERNAL_ASSERT(saved_for); |
1109 | for (const auto i : c10::irange(num_saved)) { |
1110 | auto unpacked_var = saved_variables[i].unpack(saved_for); |
1111 | THPObjectPtr value; |
1112 | if (!unpacked_var.defined()) { |
1113 | Py_INCREF(Py_None); |
1114 | value = Py_None; |
1115 | } else { |
1116 | value = unpack_fn(unpacked_var); |
1117 | } |
1118 | PyTuple_SET_ITEM(saved.get(), i, value.release()); |
1119 | } |
1120 | return saved.release(); |
1121 | } |
1122 | |
1123 | PyObject* THPFunction_saved_tensors(THPFunction* self, void* _unused) { |
1124 | HANDLE_TH_ERRORS |
1125 | if (self->saved_for_forward) { |
1126 | Py_INCREF(self->saved_for_forward); |
1127 | return self->saved_for_forward; |
1128 | } else { |
1129 | return unpack_saved_variables( |
1130 | self, [](const Variable& var) { return THPVariable_Wrap(var); }); |
1131 | } |
1132 | END_HANDLE_TH_ERRORS |
1133 | } |
1134 | |
1135 | PyObject* THPFunction_saved_variables(THPFunction* self, void* _unused) { |
1136 | HANDLE_TH_ERRORS |
1137 | auto r = PyErr_WarnEx( |
1138 | PyExc_DeprecationWarning, |
1139 | "'saved_variables' is deprecated; use 'saved_tensors'" , |
1140 | 0); |
1141 | if (r != 0) |
1142 | throw python_error(); |
1143 | return unpack_saved_variables( |
1144 | self, [](const Variable& var) { return THPVariable_Wrap(var); }); |
1145 | END_HANDLE_TH_ERRORS |
1146 | } |
1147 | |
1148 | PyObject* THPFunction_raw_saved_tensors(THPFunction* self, void* _unused) { |
1149 | HANDLE_TH_ERRORS |
1150 | // User tries to access saved variables after they have been freed |
1151 | THPUtils_assert(!self->has_freed_buffers, ERR_BACKWARD_TWICE); |
1152 | const auto& saved_variables = self->saved_variables; |
1153 | if (saved_variables.empty()) |
1154 | return PyTuple_New(0); |
1155 | size_t num_saved = saved_variables.size(); |
1156 | THPObjectPtr saved(PyTuple_New(num_saved)); |
1157 | if (!saved) { |
1158 | return nullptr; |
1159 | } |
1160 | for (const auto i : c10::irange(num_saved)) { |
1161 | py::object obj = |
1162 | py::cast(saved_variables[i], py::return_value_policy::reference); |
1163 | PyTuple_SET_ITEM(saved.get(), i, obj.release().ptr()); |
1164 | } |
1165 | return saved.release(); |
1166 | END_HANDLE_TH_ERRORS |
1167 | } |
1168 | |
1169 | PyObject* THPFunction_next_functions(THPFunction* self, void* _unused) { |
1170 | HANDLE_TH_ERRORS |
1171 | auto cdata = self->cdata.lock(); |
1172 | TORCH_CHECK( |
1173 | cdata, |
1174 | "Attribute 'next_functions' is invalid for this instance of _C._FunctionBase. " |
1175 | "Accessing this attribute directly on an instance of autograd.Function is a legacy " |
1176 | "access pattern that is no longer supported. For examples on how to use new-style " |
1177 | "autograd functions, see " |
1178 | "https://pytorch.org/docs/stable/autograd.html#torch.autograd.Function " ); |
1179 | const auto num_outputs = cdata->num_outputs(); |
1180 | THPObjectPtr result(PyTuple_New(num_outputs)); |
1181 | if (!result) |
1182 | return nullptr; |
1183 | for (const auto i : c10::irange(num_outputs)) { |
1184 | THPObjectPtr fn_tuple(PyTuple_New(2)); |
1185 | if (!fn_tuple) |
1186 | return nullptr; |
1187 | const auto& edge = cdata->next_edge(i); |
1188 | PyObject* fn = functionToPyObject(edge.function); |
1189 | if (!fn) |
1190 | return nullptr; |
1191 | PyTuple_SET_ITEM(fn_tuple.get(), 0, fn); |
1192 | PyTuple_SET_ITEM(fn_tuple.get(), 1, THPUtils_packInt64(edge.input_nr)); |
1193 | PyTuple_SET_ITEM(result.get(), i, fn_tuple.release()); |
1194 | } |
1195 | return result.release(); |
1196 | END_HANDLE_TH_ERRORS |
1197 | } |
1198 | |
1199 | PyObject* THPFunction_metadata(THPFunction* self, void* _unused) { |
1200 | HANDLE_TH_ERRORS |
1201 | auto cdata = self->cdata.lock(); |
1202 | // The correct way to solve this problem is to stop exposing grad_fn |
1203 | // of PyFunctions as THPFunction; instead, we should use THPCppFunction |
1204 | // like everyone else. But this is a BC-breaking change as it would |
1205 | // mean that you no longer get the property that grad_fn is a subclass |
1206 | // of the autograd function class that you defined in the custom case, |
1207 | // so I didn't fix it here. |
1208 | TORCH_CHECK( |
1209 | cdata, |
1210 | "You attempted to access the anomaly metadata of a custom autograd function " |
1211 | "but the underlying PyNode has already been deallocated. The most likely " |
1212 | "reason this occurred is because you assigned x.grad_fn to a local variable " |
1213 | "and then let the original variable get deallocated. Don't do that! If " |
1214 | "you really have no way of restructuring your code so this is the case, " |
1215 | "please file an issue reporting that you are affected by this." ); |
1216 | auto metadata = static_cast<PyAnomalyMetadata*>(cdata->metadata())->dict(); |
1217 | |
1218 | Py_INCREF(metadata); |
1219 | return metadata; |
1220 | END_HANDLE_TH_ERRORS |
1221 | } |
1222 | |
1223 | typedef PyObject* (*getter)(PyObject*, void*); |
1224 | typedef int (*setter)(PyObject*, PyObject*, void*); |
1225 | |
namespace {

// Generic property getter for a PyObject* member of THPFunction selected
// by member pointer; returns None when the slot is unset, otherwise a new
// reference to the stored object.
template <PyObject* THPFunction::*ptr>
PyObject* getObject(PyObject* obj, void* _unused) {
  auto self = (THPFunction*)obj;
  PyObject* value = self->*ptr;
  if (!value) {
    Py_RETURN_NONE;
  }
  Py_INCREF(value);
  return value;
}

// Generic property setter for a PyObject* member of THPFunction; assigning
// None clears the slot. Releases the previous value and retains the new one.
template <PyObject* THPFunction::*ptr>
int setObject(PyObject* obj, PyObject* value, void* _unused) {
  auto self = (THPFunction*)obj;
  if (value == Py_None) {
    value = nullptr;
  }
  Py_XDECREF((self->*ptr));
  Py_XINCREF(value);
  self->*ptr = value;
  return 0;
}

// Generic getter for a non-PyObject member of THPFunction, converted to a
// Python object by `Convert`.
template <typename M, M THPFunction::*ptr, PyObject* (*Convert)(long)>
PyObject* getMember(PyObject* obj, void* _unused) {
  auto self = (THPFunction*)obj;
  return Convert(self->*ptr);
}

// NOTE(review): `cdata` is a weak_ptr, so applying a Node member pointer
// via `self->cdata.*ptr` looks suspect; this template is not referenced in
// the property table below — presumably dead code. Confirm before use.
template <typename M, M autograd::Node::*ptr, PyObject* (*Convert)(long)>
PyObject* getImplMember(PyObject* obj, void* _unused) {
  auto self = (THPFunction*)obj;
  return Convert(self->cdata.*ptr);
}

// `requires_grad` on a legacy function instance is always reported True.
PyObject* getRequiresGrad(PyObject* obj, void* _unused) {
  Py_RETURN_TRUE;
}

} // namespace
1268 | |
1269 | // NOLINTNEXTLINE(modernize-avoid-c-arrays,cppcoreguidelines-avoid-c-arrays,cppcoreguidelines-avoid-non-const-global-variables) |
// Property table for _C._FunctionBase. Entries are
// {name, getter, setter, doc, closure}; a nullptr setter makes the
// property read-only, and the table is terminated by an all-null entry.
static struct PyGetSetDef THPFunction_properties[] = {
    {"saved_tensors",
     (getter)THPFunction_saved_tensors,
     nullptr,
     nullptr,
     nullptr},
    // Deprecated alias for saved_tensors (warns on access).
    {"saved_variables",
     (getter)THPFunction_saved_variables,
     nullptr,
     nullptr,
     nullptr},
    {"_raw_saved_tensors",
     (getter)THPFunction_raw_saved_tensors,
     nullptr,
     nullptr,
     nullptr},
    {"next_functions",
     (getter)THPFunction_next_functions,
     nullptr,
     nullptr,
     nullptr},
    // The following PyObject*-slot properties share the templated
    // getObject/setObject accessors.
    {"to_save",
     &getObject<&THPFunction::to_save>,
     &setObject<&THPFunction::to_save>,
     nullptr,
     nullptr},
    {"non_differentiable",
     &getObject<&THPFunction::non_differentiable>,
     &setObject<&THPFunction::non_differentiable>,
     nullptr,
     nullptr},
    {"dirty_tensors",
     &getObject<&THPFunction::dirty_tensors>,
     &setObject<&THPFunction::dirty_tensors>,
     nullptr,
     nullptr},
    {"saved_for_forward",
     &getObject<&THPFunction::saved_for_forward>,
     &setObject<&THPFunction::saved_for_forward>,
     nullptr,
     nullptr},
    {"needs_input_grad",
     &getObject<&THPFunction::needs_input_grad>,
     nullptr,
     nullptr,
     nullptr},
    {"requires_grad", getRequiresGrad, nullptr, nullptr, nullptr},
    {"metadata", (getter)THPFunction_metadata, nullptr, nullptr, nullptr},
    // Write-only: assignment is validated by the setter, reads are not
    // exposed.
    {"materialize_grads",
     nullptr,
     (setter)THPFunction_set_materialize_grads,
     nullptr,
     nullptr},
    {nullptr}};
1324 | |
1325 | // NOLINTNEXTLINE(modernize-avoid-c-arrays,cppcoreguidelines-avoid-c-arrays,cppcoreguidelines-avoid-non-const-global-variables) |
1326 | static struct PyMethodDef THPFunction_methods[] = { |
1327 | {(char*)"name" , THPFunction_name, METH_NOARGS, nullptr}, |
1328 | {(char*)"maybe_clear_saved_tensors" , |
1329 | THPFunction_maybe_clear_saved_tensors, |
1330 | METH_NOARGS, |
1331 | nullptr}, |
1332 | {(char*)"apply" , THPFunction_apply, METH_CLASS | METH_VARARGS, nullptr}, |
1333 | {(char*)"_register_hook_dict" , |
1334 | THPFunction__register_hook_dict, |
1335 | METH_O, |
1336 | nullptr}, |
1337 | {(char*)"register_hook" , THPFunction_register_hook, METH_O, nullptr}, |
1338 | {(char*)"register_prehook" , THPFunction_register_prehook, METH_O, nullptr}, |
1339 | {nullptr}}; |
1340 | |
// Type object for torch._C._FunctionBase: a GC-tracked base type (hence
// Py_TPFLAGS_HAVE_GC with traverse/clear) that Python-side
// autograd.Function classes subclass (Py_TPFLAGS_BASETYPE).
// Fields are positional; keep the /* slot */ comments aligned with the
// CPython PyTypeObject layout.
PyTypeObject THPFunctionType = {
    PyVarObject_HEAD_INIT(nullptr, 0) "torch._C._FunctionBase", /* tp_name */
    sizeof(THPFunction), /* tp_basicsize */
    0, /* tp_itemsize */
    (destructor)THPFunction_dealloc, /* tp_dealloc */
    0, /* tp_vectorcall_offset */
    nullptr, /* tp_getattr */
    nullptr, /* tp_setattr */
    nullptr, /* tp_reserved */
    nullptr, /* tp_repr */
    nullptr, /* tp_as_number */
    nullptr, /* tp_as_sequence */
    nullptr, /* tp_as_mapping */
    nullptr, /* tp_hash */
    nullptr, /* tp_call */
    nullptr, /* tp_str */
    nullptr, /* tp_getattro */
    nullptr, /* tp_setattro */
    nullptr, /* tp_as_buffer */
    Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE |
        Py_TPFLAGS_HAVE_GC, /* tp_flags */
    nullptr, /* tp_doc */
    (traverseproc)THPFunction_traverse, /* tp_traverse */
    (inquiry)THPFunction_clear, /* tp_clear */
    nullptr, /* tp_richcompare */
    0, /* tp_weaklistoffset */
    nullptr, /* tp_iter */
    nullptr, /* tp_iternext */
    THPFunction_methods, /* tp_methods */
    nullptr, /* tp_members */
    THPFunction_properties, /* tp_getset */
    nullptr, /* tp_base */
    nullptr, /* tp_dict */
    nullptr, /* tp_descr_get */
    nullptr, /* tp_descr_set */
    0, /* tp_dictoffset */
    nullptr, /* tp_init */
    nullptr, /* tp_alloc */
    THPFunction_new /* tp_new */
};
1381 | |
1382 | bool THPFunction_initModule(PyObject* module) { |
1383 | if (PyType_Ready(&THPFunctionType) < 0) |
1384 | return false; |
1385 | Py_INCREF(&THPFunctionType); |
1386 | PyModule_AddObject(module, "_FunctionBase" , (PyObject*)&THPFunctionType); |
1387 | return true; |
1388 | } |
1389 | |