#include <torch/csrc/autograd/python_function.h>

#include <ATen/ATen.h>
#include <ATen/SequenceNumber.h>
#include <c10/util/irange.h>
#include <pybind11/pybind11.h>
#include <structmember.h>
#include <torch/csrc/python_headers.h>
#include <torch/csrc/utils/pybind.h>

#include <ATen/FuncTorchTLS.h>
#include <ATen/functorch/DynamicLayer.h>
#include <torch/csrc/DynamicTypes.h>
#include <torch/csrc/Exceptions.h>
#include <torch/csrc/THP.h>
#include <torch/csrc/autograd/functions/accumulate_grad.h>
#include <torch/csrc/autograd/functions/basic_ops.h>
#include <torch/csrc/autograd/functions/utils.h>
#include <torch/csrc/autograd/grad_mode.h>
#include <torch/csrc/autograd/graph_task.h>
#include <torch/csrc/autograd/python_anomaly_mode.h>
#include <torch/csrc/autograd/python_cpp_function.h>
#include <torch/csrc/autograd/python_hook.h>
#include <torch/csrc/autograd/saved_variable.h>
#include <torch/csrc/jit/frontend/tracer.h>
#include <torch/csrc/jit/ir/ir.h>
#include <torch/csrc/jit/python/pybind_utils.h>
#include <torch/csrc/jit/python/python_tracer.h>
#include <torch/csrc/utils/python_strings.h>

#include <exception>
#include <functional>
#include <memory>
#include <stdexcept>
#include <string>
#include <tuple>
#include <unordered_map>
#include <unordered_set>
#include <utility>
#include <vector>

using namespace torch;
using namespace torch::autograd;
using at::Tensor;

PyObject* THPFunctionClass = nullptr;

#define THPFunction_assert(condition, ...) \
  if (!(condition)) {                      \
    THPUtils_setError(__VA_ARGS__);        \
    throw python_error();                  \
  }

// Anonymous namespace for helper functions used in this file
namespace {

// Throw a python_error with the PyErr state persisted, so that we
// don't lose the error state if the GIL is released before a
// PyThreadState has been created. This way, even a pure C++ thread
// without a pre-created PyThreadState can capture the correct error
// message.
// TODO: This is a temporary approach to let C++ threads correctly
// capture Python errors in autograd; remove this once the c10 thread
// pool allows one-time initialization.
// see discussion in https://github.com/pytorch/pytorch/pull/34845
// Follow-up issue: https://github.com/pytorch/pytorch/issues/35006
void throw_python_error() {
  python_error err;
  err.persist();
  throw err;
}

} // namespace

namespace torch {
namespace autograd {

// NOTE: this function is written in a way that assumes it's only called for
// backward; it's used by engine.cpp. This is responsible for forwarding a call
// from C++'s Node::apply to a Python method "apply".
auto PyNode::apply(variable_list&& inputs) -> variable_list {
  pybind11::gil_scoped_acquire gil;
  at::OptionalDeviceGuard _device_guard;
  THPFunction* py_fn = (THPFunction*)obj;

  // Massage a C++ variable_list into a Python arguments tuple
  auto num_inputs = inputs.size();
  THPObjectPtr pyInputs(PyTuple_New(num_inputs));
  if (!pyInputs)
    throw_python_error();
  auto& output_info = py_fn->output_info;
  for (const auto i : c10::irange(num_inputs)) {
    // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
    PyObject* input;
    if (inputs[i].defined() || !py_fn->materialize_grads) {
      input = THPVariable_Wrap(inputs[i]);
    } else {
      input = THPVariable_Wrap(output_info[i].zeros(_device_guard));
    }
    if (!input)
      throw_python_error();
    PyTuple_SET_ITEM(pyInputs.get(), i, input);
  }

  THPObjectPtr apply_fn(PyObject_GetAttrString(obj, "apply"));
  if (!apply_fn)
    throw_python_error();
  THPObjectPtr r(PyObject_CallObject(apply_fn, pyInputs.get()));
  if (!r)
    throw_python_error();
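  // The Python "apply" may return a single Tensor rather than a tuple;
  // normalize it so the code below can iterate uniformly.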
  ensure_tuple(r);

  auto& is_variable_input = py_fn->is_variable_input;
  int num_outputs = PyTuple_GET_SIZE(r.get());
  int num_forward_inputs = is_variable_input.size();
  // Returning too many results is ok, but only as long as they're all None.
  // Truncate the result tuple in that case.
  if (num_outputs > num_forward_inputs) {
    bool all_none = true;
    for (const auto i : c10::irange(num_forward_inputs, num_outputs)) {
      all_none &= PyTuple_GET_ITEM(r.get(), i) == Py_None;
    }
    if (all_none) {
      num_outputs = num_forward_inputs;
      r = PyTuple_GetSlice(r.get(), 0, num_forward_inputs);
      if (!r)
        throw_python_error();
    }
  }

  // Now the number of gradients should match
  if (num_outputs != num_forward_inputs) {
    std::string msg("function ");
    msg += name() + " returned an incorrect number of gradients (expected ";
    msg += std::to_string(num_forward_inputs) + ", got ";
    msg += std::to_string(num_outputs) + ")";
    throw std::runtime_error(msg);
  }

  // Massage the Python results tuple back into a C++ variable_list
  variable_list results;
  results.reserve(num_outputs);
  for (int i = 0; i != num_outputs; ++i) {
    PyObject* output = PyTuple_GET_ITEM(r.get(), i);
    bool was_variable = is_variable_input[i];
    if (!was_variable) {
      if (output != Py_None) {
        std::string msg("function ");
        msg += name() + " returned a gradient different than None at position ";
        msg += std::to_string(i + 1) +
            ", but the corresponding forward input was not a Variable";
        throw std::runtime_error(msg);
      }
      continue;
    }
    if (output == Py_None) {
      results.emplace_back();
    } else {
      if (!THPVariable_Check(output)) {
        std::string msg("expected Variable or None (got ");
        msg += THPUtils_typename(output);
        msg += ")";
        throw std::runtime_error(msg);
      }
      results.emplace_back(THPVariable_Unpack(output));
    }
  }

  return results;
}

auto PyNode::is_traceable() -> bool {
  pybind11::gil_scoped_acquire gil;
  THPObjectPtr forward_class{PyObject_GetAttrString(obj, "_forward_cls")};
  if (!forward_class)
    throw_python_error();
  THPObjectPtr traceable_py_bool{
      PyObject_GetAttrString(forward_class, "is_traceable")};
  if (!traceable_py_bool)
    throw_python_error();
  return traceable_py_bool == Py_True;
}

auto PyNode::release_variables() -> void {
  // This function is called as part of the Node destructor!
  // Since this object might be kept alive by C++, it is possible
  // that the python interpreter is already dead here. In that case
  // we just leak the saved objects.
  if (Py_IsInitialized()) {
    pybind11::gil_scoped_acquire gil;
    auto f = (THPFunction*)obj;
    f->saved_variables.clear();
    f->has_freed_buffers = 1;
  }
}

auto PyNode::name() const -> std::string {
  pybind11::gil_scoped_acquire gil;
  auto f = (THPFunction*)obj;
  auto name = std::string(Py_TYPE(f)->tp_name);
  return name;
}

} // namespace autograd
} // namespace torch

// Traverse and clear are required for supporting Python's GC cycle handling.
static int THPFunction_traverse(THPFunction* self, visitproc visit, void* arg) {
  // cdata could be null if the PyNode has already gone out of scope
  // by the time we're GC'ing this THPFunction (e.g., the user saved grad_fn
  // only).
  //
  // TODO: I'm not really sure if we're actually obligated to traverse PyObject
  // that is stored in PyNode, since we don't really own that C++ object.
  if (auto cdata = self->cdata.lock()) {
    for (const auto& hook : cdata->tensor_pre_hooks()) {
      if (auto pyhook = dynamic_cast<PyFunctionTensorPreHook*>(hook.get())) {
        Py_VISIT(pyhook->dict);
      }
    }
    // See NOTE [retains_grad_hook PyObject traversal]
    for (const auto& pair : cdata->retains_grad_hooks()) {
      if (auto pyhook =
              dynamic_cast<PyFunctionTensorPreHook*>(pair.second.get())) {
        Py_VISIT(pyhook->dict);
      }
    }
    for (const auto& hook : cdata->pre_hooks()) {
      if (auto pyhook = dynamic_cast<PyFunctionPreHook*>(hook.get())) {
        Py_VISIT(pyhook->dict);
      }
    }
    for (const auto& hook : cdata->post_hooks()) {
      if (auto pyhook = dynamic_cast<PyFunctionPostHook*>(hook.get())) {
        Py_VISIT(pyhook->dict);
      }
    }
  }
  Py_VISIT(self->to_save);
  Py_VISIT(self->non_differentiable);
  Py_VISIT(self->dirty_tensors);
  Py_VISIT(self->saved_for_forward);
  return 0;
}

static int THPFunction_clear(THPFunction* self) {
  // Note that the cdata might not be expired yet in the case where this
  // object is part of a cycle and the GC happens to tp_clear this PyObject
  // before the other ones that trigger the de-allocation of the cdata

  Py_CLEAR(self->needs_input_grad);

  Py_CLEAR(self->to_save);
  Py_CLEAR(self->non_differentiable);
  Py_CLEAR(self->dirty_tensors);
  Py_CLEAR(self->saved_for_forward);

  self->output_info.clear();
  self->input_info.clear();
  self->saved_variables.clear();
  self->is_variable_input.clear();

  return 0;
}

static void THPFunction_dealloc(THPFunction* self) {
  // Why is this guaranteed to be true? Suppose that self->cdata is non-null
  // (otherwise the condition is trivially true). Then there is a PyNode
  // which contains an owning reference to this object. But we are only
  // allowed to clear if all owning references are gone! Contradiction.
  //
  // However, note that THPFunction_clear is typically called in the shared_ptr
  // destructor of PyNode; in that case, per
  // https://cplusplus.github.io/LWG/lwg-active.html#2751 it's not currently
  // specified in the standard that this is guaranteed. If you see this
  // assert triggering in the wild, feel free to comment it out. They're
  // likely to standardize that you ARE guaranteed to see the weak pointers
  // as expired in the destructor in the future, so we'll keep this for now.
  TORCH_INTERNAL_ASSERT(self->cdata.expired());

  PyObject_GC_UnTrack(self);
  THPFunction_clear(self);
  self->cdata.~weak_ptr<PyNode>();
  self->output_info.~vector();
  self->input_info.~vector();
  self->saved_variables.~vector();
  self->is_variable_input.~vector();
  Py_TYPE(self)->tp_free((PyObject*)self);
}

PyObject* THPFunction_new(
    PyTypeObject* type,
    PyObject* args,
    PyObject* kwargs) {
  PyObject* obj = type->tp_alloc(type, 0);
  if (!obj)
    return nullptr;
  // Python zero-initializes the object memory, so there's no need to
  // initialize most fields
  THPFunction* self = (THPFunction*)obj;
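  // C++ members with non-trivial constructors must still be constructed in
  // place: zeroed memory is not a valid state for std::weak_ptr or
  // std::vector.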
  // Setup the PyNode later; we can't keep it live here
  new (&self->cdata) std::weak_ptr<PyNode>();
  new (&self->output_info) std::vector<VariableInfo>();
  new (&self->input_info) std::vector<VariableInfo>();
  new (&self->saved_variables) std::vector<SavedVariable>();
  new (&self->is_variable_input) std::vector<bool>();
  self->materialize_grads = true;
  return obj;
}

////////////////////////////////////////////////////////////////////////////////
// Forward
////////////////////////////////////////////////////////////////////////////////

// Bump the counters of all recorded dirty input tensors, adding each of them
// into dirty_inputs. Also does some sanity checking.
static std::unordered_set<at::TensorImpl*> _mark_dirty(THPFunction* self) {
  // Increase versions of modified tensors
  std::unordered_set<at::TensorImpl*> dirty_inputs;
  if (!self->dirty_tensors)
    return dirty_inputs;

  THPFunction_assert(
      PyTuple_Check(self->dirty_tensors),
      "autograd "
      "internal error: dirty_tensors attribute is expected to be a tuple "
      "but is %s",
      THPUtils_typename(self->dirty_tensors));
  Py_ssize_t num_dirty = PyTuple_GET_SIZE(self->dirty_tensors);
  dirty_inputs.reserve(num_dirty);
  for (const auto i : c10::irange(num_dirty)) {
    PyObject* obj = PyTuple_GET_ITEM(self->dirty_tensors, i);
    THPFunction_assert(
        THPVariable_Check(obj),
        "mark_dirty can "
        "only accept variables, but argument %d is of type %s",
        i,
        THPUtils_typename(obj));

    const auto& tensor = THPVariable_Unpack(obj);
    dirty_inputs.insert(tensor.unsafeGetTensorImpl());
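    // Bumping the version counter lets any SavedVariable that captured this
    // tensor detect the in-place modification when it is later unpacked.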
    torch::autograd::impl::bump_version(tensor);
  }
  // We're never going to need this again, so release the references now
  Py_CLEAR(self->dirty_tensors);
  return dirty_inputs;
}

static std::unordered_set<at::TensorImpl*> _parse_non_differentiable(
    THPFunction* self);

// Given a Python tuple of raw output tensors (raw_output), set each of
// the corresponding entries in a different Python tuple (outputs) with
// these tensors wrapped with variables. We save the gradient function (self)
// to the variable if the output requires grad.
//
// There is a considerable amount of complexity to handle if the operation
// that produced these output tensors is inplace. A mapping of *input*
// tensors to variables (t2var) is used to test if this occurred, and
// the set of dirty tensors (dirty_inputs) is used to figure out what to
// do in this case. After this method is run, t2var is extended with
// mappings for output tensors as well.
static void _wrap_outputs(
    const std::shared_ptr<PyNode>& cdata,
    THPFunction* self,
    const variable_list& input_vars,
    PyObject* raw_output,
    PyObject* outputs,
    bool is_executable) {
  auto cdata_if_executable = is_executable ? cdata : nullptr;
  Py_ssize_t num_outputs = PyTuple_GET_SIZE(raw_output);
  if (is_executable) {
    self->output_info.clear();
    self->output_info.reserve(num_outputs);
  }

  auto non_differentiable = _parse_non_differentiable(self);
  auto dirty_inputs = _mark_dirty(self);

  std::vector<c10::optional<Variable>> raw_output_vars;
  raw_output_vars.reserve(num_outputs);
  for (const auto i : c10::irange(num_outputs)) {
    PyObject* obj = PyTuple_GET_ITEM(raw_output, i);
    // Only process tensors as outputs for autograd purposes.
    if (THPVariable_Check(obj)) {
      raw_output_vars.emplace_back(THPVariable_Unpack(obj));
    } else {
      raw_output_vars.emplace_back();
    }
  }

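  // This lambda forwards a forward-mode AD (jvp) call into the Python
  // "apply_jvp" method, mirroring what PyNode::apply does for backward.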
  _jvp_fn_t jvp_user_function = [self](
                                    variable_list inputs,
                                    variable_list grad_inputs) {
    pybind11::gil_scoped_acquire gil;

    // Massage a C++ variable_list into a Python arguments tuple
    // Making sure to introduce the proper None for non-Tensor inputs
    auto num_inputs = self->is_variable_input.size();
    THPObjectPtr pyInputs(PyTuple_New(num_inputs));
    if (!pyInputs)
      throw_python_error();
    int64_t variable_idx = 0;
    for (const auto i : c10::irange(num_inputs)) {
      PyObject* input = nullptr;
      if (self->is_variable_input[i]) {
        if (grad_inputs[variable_idx].defined() || !self->materialize_grads ||
            !isDifferentiableType(inputs[variable_idx].scalar_type())) {
          input = THPVariable_Wrap(grad_inputs[variable_idx]);
        } else {
          input = THPVariable_Wrap(at::zeros_like(inputs[variable_idx]));
        }
        if (!input) {
          throw_python_error();
        }
        variable_idx++;
      } else {
        Py_INCREF(Py_None);
        input = Py_None;
      }
      PyTuple_SET_ITEM(pyInputs.get(), i, input);
    }

    THPObjectPtr apply_jvp_fn(
        PyObject_GetAttrString((PyObject*)self, "apply_jvp"));
    if (!apply_jvp_fn)
      throw_python_error();
    THPObjectPtr r(PyObject_CallObject(apply_jvp_fn, pyInputs.get()));
    if (!r)
      throw_python_error();
    ensure_tuple(r);

    // Massage the Python results tuple back into a C++ variable_list
    // Don't do any check on the number of results here as
    // it is handled by the caller
    const int num_outputs = PyTuple_GET_SIZE(r.get());
    variable_list results;
    results.reserve(num_outputs);
    for (const auto i : c10::irange(num_outputs)) {
      PyObject* output = PyTuple_GET_ITEM(r.get(), i);
      if (output == Py_None) {
        results.emplace_back();
      } else {
        TORCH_CHECK(
            THPVariable_Check(output),
            "expected Variable or None (got ",
            THPUtils_typename(output),
            ") for grad output ",
            i,
            ".")
        results.emplace_back(THPVariable_Unpack(output));
      }
    }

    return results;
  };

  // Wrap only the tensor outputs.
  auto wrapped_outputs = _wrap_outputs(
      input_vars,
      non_differentiable,
      dirty_inputs,
      raw_output_vars,
      cdata_if_executable,
      std::move(jvp_user_function));

  for (const auto i : c10::irange(num_outputs)) {
    PyObject* obj = PyTuple_GetItem(raw_output, i);
    // Keep the non-tensor outputs as is.
    if (!THPVariable_Check(obj)) {
      if (is_executable) {
        self->output_info.emplace_back();
      }
      Py_INCREF(obj);
      PyTuple_SetItem(outputs, i, obj);
    } else {
      if (is_executable) {
        self->output_info.emplace_back(*wrapped_outputs[i]);
      }
      PyTuple_SetItem(outputs, i, THPVariable_Wrap(*wrapped_outputs[i]));
    }
  }
}

// Save any variables that were requested via to_save
static void _save_variables(
    const std::shared_ptr<PyNode>& cdata_ptr,
    THPFunction* self) {
  if (!self->to_save)
    return;

  THPFunction_assert(
      PyTuple_Check(self->to_save),
      "autograd internal "
      "error: to_save attribute is expected to be a tuple but is %s",
      THPUtils_typename(self->to_save));
  Py_ssize_t num_saved = PyTuple_GET_SIZE(self->to_save);
  self->saved_variables.clear();
  self->saved_variables.reserve(num_saved);
  for (const auto i : c10::irange(num_saved)) {
    PyObject* obj = PyTuple_GET_ITEM(self->to_save, i);
    if (obj == Py_None) {
      self->saved_variables.emplace_back();
      continue;
    } else if (THPVariable_Check(obj)) {
      const auto& tensor = THPVariable_Unpack(obj);
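      // Saving an output of this node would create a reference cycle
      // (output -> grad_fn -> saved output); SavedVariable breaks the cycle
      // when is_output is true.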
      bool is_output = tensor.grad_fn().get() == cdata_ptr.get();
      self->saved_variables.emplace_back(tensor, is_output);
    } else {
      throw torch::TypeError(
          "save_for_backward can only save variables, but argument %ld is of "
          "type %s",
          i,
          Py_TYPE(obj)->tp_name);
    }
  }
  // Free .to_save
  Py_CLEAR(self->to_save);
}

// Mark requires_grad = 0 on non-differentiable variables (as per
// non_differentiable)
static std::unordered_set<at::TensorImpl*> _parse_non_differentiable(
    THPFunction* self) {
  std::unordered_set<at::TensorImpl*> set;
  if (!self->non_differentiable)
    return set;

  THPFunction_assert(
      PyTuple_Check(self->non_differentiable),
      "autograd "
      "internal error: non_differentiable attribute is expected to be a "
      "tuple but is %s",
      THPUtils_typename(self->non_differentiable));
  Py_ssize_t num_nondiff = PyTuple_GET_SIZE(self->non_differentiable);
  set.reserve(num_nondiff);
  for (const auto i : c10::irange(num_nondiff)) {
    PyObject* t = PyTuple_GET_ITEM(self->non_differentiable, i);
    THPFunction_assert(
        THPVariable_Check(t),
        "mark_non_differentiable "
        "only accepts variable arguments, but got %s",
        THPUtils_typename(t));
    set.insert(THPVariable_Unpack(t).unsafeGetTensorImpl());
  }
  Py_CLEAR(self->non_differentiable);
  return set;
}

struct UnpackedInput {
  THPObjectPtr input_tuple;
  variable_list input_vars;
};

struct InputFlags {
  bool is_executable = false;
  edge_list next_edges;
  THPObjectPtr needs_input_grad;
  std::vector<bool> is_variable_input;
};

template <bool enforce_variables>
std::pair<UnpackedInput, InputFlags> unpack_input(PyObject* args) {
  UnpackedInput unpacked;
  InputFlags flags;

  auto num_args = PyTuple_GET_SIZE(args);
  unpacked.input_tuple = PyTuple_New(num_args);
  flags.needs_input_grad = PyTuple_New(num_args);
  for (const auto i : c10::irange(num_args)) {
    PyObject* arg = PyTuple_GET_ITEM(args, i);

    bool is_variable = THPVariable_Check(arg);
    flags.is_variable_input.push_back(is_variable);
    if (!is_variable) {
      // TODO: remove this code path once Variable and Tensor are merged in
      // Python
      if (enforce_variables) {
        THPUtils_setError(
            "expected a Tensor argument, but got %s", THPUtils_typename(arg));
        throw python_error();
      }
      Py_INCREF(Py_False);
      PyTuple_SET_ITEM(flags.needs_input_grad.get(), i, Py_False);
    } else {
      const auto& tensor = THPVariable_Unpack(arg);
      unpacked.input_vars.push_back(tensor);
      PyObject* needs_grad = tensor.requires_grad() ? Py_True : Py_False;
      Py_INCREF(needs_grad);
      PyTuple_SET_ITEM(flags.needs_input_grad.get(), i, needs_grad);
    }
    Py_INCREF(arg);
    PyTuple_SET_ITEM(unpacked.input_tuple.get(), i, arg);
  }

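  // A backward graph is only needed when grad mode is enabled and at least
  // one input requires grad; otherwise we skip collecting edges entirely.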
  flags.is_executable =
      GradMode::is_enabled() && any_variable_requires_grad(unpacked.input_vars);
  flags.next_edges =
      (flags.is_executable ? collect_next_edges(unpacked.input_vars)
                           : edge_list());
  return std::make_pair(std::move(unpacked), std::move(flags));
}

// Given a prim::PythonOp node, _append_subgraph creates a subgraph such that:
// (1) It has the same inputs as the prim::PythonOp node.
// (2) The intermediate nodes used in the PythonOp are cloned and stored in
//     the subgraph.
// (3) trace_outputs stores the Value* objects before a new trace value is
//     assigned by the prim::PythonOp node, and helps to eventually route the
//     outputs of the subgraph correctly.
// This newly created subgraph is then added to the prim::PythonOp node as a
// subgraph attribute.
static void _append_subgraph(
    torch::jit::Node* node,
    torch::jit::Graph* graph,
    std::vector<torch::jit::Value*> trace_outputs,
    bool unpack_output) {
  using Value = torch::jit::Value;
  node->g_(
      torch::jit::attr::Subgraph,
      std::make_shared<torch::jit::Graph>(graph->current_scope()));
  auto subgraph = node->g(torch::jit::attr::Subgraph);

  std::unordered_map<Value*, Value*> value_map;
  auto value_map_func = [&](Value* v) { return value_map.at(v); };
  for (size_t i = 0; i < node->inputs().size(); ++i) {
    auto subgraph_input = subgraph->addInput();
    subgraph_input->copyMetadata(node->inputs().at(i));
    value_map[node->inputs().at(i)] = subgraph_input;
  }
  // Find the node's position in its owning block; all subsequent nodes are
  // added to the subgraph
  auto owning_block = node->owningBlock();
  auto it = std::find(
      owning_block->nodes().begin(), owning_block->nodes().end(), node);
  // Skip TupleUnpack node if created
  if (!unpack_output) {
    it++;
  }
  for (it++; it != owning_block->nodes().end(); ++it) {
    torch::jit::Node* node = *it;
    auto* clone_node =
        subgraph->insertNode(subgraph->createClone(node, value_map_func));
    for (size_t i = 0; i < node->outputs().size(); ++i) {
      value_map[node->outputs()[i]] = clone_node->outputs()[i];
      auto trace_it = std::find(
          trace_outputs.begin(), trace_outputs.end(), node->outputs()[i]);
      if (trace_it != trace_outputs.end()) {
        subgraph->registerOutput(clone_node->outputs()[i]);
      }
    }
  }
}

static torch::jit::Node* _trace_pre_record(
    PyObject* op_obj,
    PyObject* input_objects,
    const variable_list& input_vars) {
  if (!jit::tracer::isTracing()) {
    return nullptr;
  }

  // Save scalar args and the calling convention
  auto num_args = PyTuple_GET_SIZE(input_objects);
  pyobj_list scalar_args;
  std::string arg_types;
  arg_types.reserve(num_args);
  scalar_args.reserve(num_args);
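  // Encode the calling convention as a string: 'd' marks a tensor argument,
  // 'c' marks a constant (non-tensor) argument, whose PyObject is kept alive
  // in scalar_args.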
  for (const auto i : c10::irange(num_args)) {
    PyObject* arg_object = PyTuple_GET_ITEM(input_objects, i);
    if (THPVariable_Check(arg_object)) {
      arg_types.push_back('d');
    } else {
      arg_types.push_back('c');
      Py_INCREF(arg_object);
      scalar_args.emplace_back(arg_object);
    }
  }

  Py_INCREF(op_obj);
  auto pyobj = THPObjectPtr(op_obj);
  return jit::tracer::preRecordPythonTrace(
      std::move(pyobj), arg_types, input_vars, std::move(scalar_args));
}

static void _trace_post_record(
    torch::jit::Node* node,
    PyObject* op_obj,
    const variable_list& input_vars,
    PyObject* output_objects,
    bool is_inplace,
    bool unpack_output) {
  if (!jit::tracer::isTracing()) {
    return;
  }

  node->i_(jit::attr::inplace, is_inplace);
  if (PyObject* module_name = PyDict_GetItemString(
          ((PyTypeObject*)op_obj)->tp_dict, "__module__")) {
    if (auto ptr = PyUnicode_AsUTF8(module_name)) {
      node->s_(jit::attr::module, std::string(ptr));
    }
  }

  // Set up the node's outputs, creating a TupleUnpack to expose individual
  // output Values when forward returned a tuple
  int num_outputs = PyTuple_GET_SIZE(output_objects);
  auto graph = node->owningGraph();
  node->addOutput();
  auto old_node = node;
  if (!unpack_output) {
    std::vector<at::TypePtr> tuple_values(num_outputs, at::TensorType::get());
    auto tuple_type = at::TupleType::create(std::move(tuple_values));
    // Original type is a tuple of tensors "without" element type and shape.
    // The missing parts will be added below.
    node->output()->setType(std::move(tuple_type));
    auto unpacked = graph->createTupleUnpack(node->output())->insertAfter(node);
    node = unpacked;
  }

  std::vector<torch::jit::Value*> trace_outputs;
  for (const auto i : c10::irange(num_outputs)) {
    PyObject* obj = PyTuple_GET_ITEM(output_objects, i);
    if (THPVariable_Check(obj)) {
      auto value = node->outputs()[i];
      const auto& tensor = THPVariable_Unpack(obj);
      if (tensor.defined()) {
        value->inferTypeFrom(tensor);
        trace_outputs.push_back(jit::tracer::getValueTrace(tensor));
        jit::tracer::setValueTrace(tensor, value);
      }
    }
  }
  py::bool_ is_in_onnx_export =
      py::module::import("torch.onnx.__init__").attr("is_in_onnx_export");
  if (py::cast<bool>(is_in_onnx_export)) {
    _append_subgraph(old_node, graph, std::move(trace_outputs), unpack_output);
  }

  // If TupleUnpack operator is created, we copy its output type back
  // to the original tuple type.
  if (!unpack_output) {
    std::vector<at::TypePtr> new_tuple_values;
    for (const auto i : c10::irange(num_outputs)) {
      auto ptr = node->outputs()[i]->type();
      new_tuple_values.push_back(ptr);
    }
    auto tuple_type = at::TupleType::create(std::move(new_tuple_values));
    // The i-th tuple element receives a new tensor type with element type and
    // shape.
    old_node->output()->setType(std::move(tuple_type));
  }
}

PyObject* process_outputs(
    PyObject* op_obj,
    const std::shared_ptr<PyNode>& cdata,
    THPFunction* grad_fn,
    const UnpackedInput& unpacked,
    PyObject* inputs,
    THPObjectPtr&& raw_output,
    bool is_executable,
    torch::jit::Node* node) {
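  // forward() may have returned a single object rather than a tuple; wrap it
  // and remember whether to unwrap the result again at the end.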
  bool unpack_output = ensure_tuple(raw_output);

  auto num_outputs = PyTuple_GET_SIZE(raw_output.get());

  THPObjectPtr outputs(PyTuple_New(num_outputs));
  if (!outputs)
    throw python_error();

  cdata->clear_input_metadata();

  // Record type, device, and size information about inputs
  if (is_executable) {
    grad_fn->input_info.clear();
    grad_fn->input_info.reserve(unpacked.input_vars.size());
    for (auto& var : unpacked.input_vars) {
      grad_fn->input_info.emplace_back(var);
    }
  }

  bool is_inplace = static_cast<bool>(grad_fn->dirty_tensors);
  _wrap_outputs(
      cdata, grad_fn, unpacked.input_vars, raw_output, outputs, is_executable);
  _trace_post_record(
      node, op_obj, unpacked.input_vars, outputs, is_inplace, unpack_output);

  // It is important that the SavedVariables are created after the output
  // wrapping, as the outputs must have their grad_fn/fw_grad properly set
  // before we save them.
  if (is_executable) {
    _save_variables(cdata, grad_fn);
  } else {
    // Remove unnecessary attributes
    Py_XDECREF(grad_fn->to_save);
    grad_fn->to_save = nullptr;
    Py_XDECREF(grad_fn->non_differentiable);
    grad_fn->non_differentiable = nullptr;
  }

  Py_XDECREF(grad_fn->saved_for_forward);
  grad_fn->saved_for_forward = nullptr;

  // Unpack the output, unless .forward() returned a tuple
  if (unpack_output) {
    PyObject* output = PyTuple_GET_ITEM(outputs.get(), 0);
    Py_INCREF(output);
    return output;
  }

  return outputs.release();
}

PyObject* THPFunction_name(PyObject* self, PyObject* noargs) {
  HANDLE_TH_ERRORS
  auto cdata = ((THPFunction*)self)->cdata.lock();
  TORCH_CHECK(
      cdata,
      "Attribute 'name' is invalid for this instance of _C._FunctionBase. "
      "Accessing this attribute directly on an instance of autograd.Function is a legacy "
      "access pattern that is no longer supported. For examples on how to use new-style "
      "autograd functions, see "
      "https://pytorch.org/docs/stable/autograd.html#torch.autograd.Function ");
  return THPUtils_packString(cdata->name());
  END_HANDLE_TH_ERRORS
}

PyObject* THPFunction_maybe_clear_saved_tensors(
    PyObject* self,
    PyObject* noargs) {
  HANDLE_TH_ERRORS;
  auto cdata = ((THPFunction*)self)->cdata.lock();
  if (!get_current_graph_task_keep_graph()) {
    cdata->release_variables();
  }
  Py_RETURN_NONE;
  END_HANDLE_TH_ERRORS
}

namespace {

THPObjectPtr make_ctx_input_tuple(
    THPFunction* ctx,
    const UnpackedInput& unpacked_input,
    int64_t num_args) {
  THPObjectPtr ctx_input_tuple(PyTuple_New(num_args + 1));
  if (!ctx_input_tuple)
    return {};
  Py_INCREF(ctx);
  PyTuple_SET_ITEM(ctx_input_tuple.get(), 0, (PyObject*)ctx);
  for (const auto i : c10::irange(num_args)) {
    PyObject* arg = PyTuple_GET_ITEM(unpacked_input.input_tuple.get(), i);
    Py_INCREF(arg);
    PyTuple_SET_ITEM(ctx_input_tuple.get(), i + 1, arg);
  }
  return ctx_input_tuple;
}

THPObjectPtr make_ctx_input_output_tuple(
    THPFunction* ctx,
    UnpackedInput& unpacked_input,
    PyObject* output) {
  THPObjectPtr result(PyTuple_New(3));
  if (!result)
    return {};
  Py_INCREF(ctx);
  Py_INCREF(unpacked_input.input_tuple.get());
  Py_INCREF(output);
  PyTuple_SET_ITEM(result.get(), 0, (PyObject*)ctx);
  PyTuple_SET_ITEM(result.get(), 1, unpacked_input.input_tuple.get());
  PyTuple_SET_ITEM(result.get(), 2, output);
  return result;
}

} // namespace

static PyObject* THPFunction_setup_context = nullptr;

static PyObject* get_base_setup_context() {
  if (THPFunction_setup_context != nullptr) {
    return THPFunction_setup_context;
  }

  auto module = THPObjectPtr(PyImport_ImportModule("torch.autograd.function"));
  if (!module)
    return nullptr;

  auto function =
      THPObjectPtr(PyObject_GetAttrString(module, "_SingleLevelFunction"));
  if (!function)
    return nullptr;

  // setup_context gets "leaked" - we return a new reference and hold onto it
  // forever.
  auto setup_context = PyObject_GetAttrString(function, "setup_context");
  if (!setup_context)
    return nullptr;
  THPFunction_setup_context = setup_context;
  return THPFunction_setup_context;
}

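// THPFunction_apply backs torch.autograd.Function.apply. Roughly, the code
// below corresponds to this Python-level sketch (illustrative only, not an
// actual API):
//
//   ctx = cls._backward_cls()              # paired with a fresh PyNode
//   with torch.no_grad():
//       output = cls.forward(ctx, *inputs) # or forward + setup_context
//   return process_outputs(...)            # wrap outputs, save variables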
PyObject* THPFunction_apply(PyObject* cls, PyObject* inputs) {
  HANDLE_TH_ERRORS

  // save a local copy of seq_id before it gets incremented
  int seq_id = at::sequence_number::peek();
  auto info_pair = unpack_input<false>(inputs);
  UnpackedInput& unpacked_input = info_pair.first;
  InputFlags& input_info = info_pair.second;

  // Call record function after all the inputs have been decoded, but
  // before context has been allocated.
  RECORD_FUNCTION(
      ((PyTypeObject*)cls)->tp_name,
      std::vector<c10::IValue>(
          unpacked_input.input_vars.begin(), unpacked_input.input_vars.end()),
      seq_id);

  const auto& functorch_tls = at::functorch::functorchTLSAccessor();
  if (functorch_tls) {
    // autograd.Function support for functorch is handled in Python.
    // If we have gotten here, then either we are dealing with a
    // torch.autograd.function._SingleLevelFunction, or something in
    // the implementation went wrong.
    // The following code is useful for debugging when something goes wrong
    // because it'll raise a loud error (instead of being silently incorrect).
    functorch_tls->checkSupportsSingleLevelAutogradFunction();
  }

  THPObjectPtr backward_cls(PyObject_GetAttrString(cls, "_backward_cls"));
  if (!backward_cls)
    return nullptr;
  THPObjectPtr ctx_obj(PyObject_CallFunctionObjArgs(backward_cls, nullptr));
  if (!ctx_obj)
    return nullptr;
  THPFunction* ctx = (THPFunction*)ctx_obj.get();

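  // The PyNode takes ownership of ctx (via ctx_obj) and keeps it alive; ctx
  // holds only a weak_ptr back to the PyNode to avoid a reference cycle.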
  auto cdata =
      std::shared_ptr<PyNode>(new PyNode(std::move(ctx_obj)), deleteNode);
  ctx->cdata = cdata;

  // Record input nodes if tracing
  auto* node = _trace_pre_record(cls, inputs, unpacked_input.input_vars);

  // Initialize backward function (and ctx)
  bool is_executable = input_info.is_executable;
  cdata->set_next_edges(std::move(input_info.next_edges));
  ctx->needs_input_grad = input_info.needs_input_grad.release();
  ctx->is_variable_input = std::move(input_info.is_variable_input);

  // autograd.Function may optionally override a setup_context staticmethod.
  // In this case, autograd.Function.forward does NOT accept a ctx object.
  // Determine if this is the case.
  auto cls_setup_context =
      THPObjectPtr(PyObject_GetAttrString(cls, "setup_context"));
  if (!cls_setup_context) {
    return nullptr;
  }
  auto orig_setup_context = get_base_setup_context();
  if (!orig_setup_context) {
    return nullptr;
  }
  auto overridden_setup_context = cls_setup_context.get() != orig_setup_context;

  auto num_args = PyTuple_GET_SIZE(inputs);

  // Call forward
  THPObjectPtr output;
  {
    AutoGradMode grad_mode(false);
    at::AutoFwGradMode fw_grad_mode(false);
    THPObjectPtr forward_fn(PyObject_GetAttrString(cls, "forward"));
    if (!forward_fn)
      return nullptr;
    if (overridden_setup_context) {
      // call forward followed by setup_context
      output = PyObject_CallObject(forward_fn, unpacked_input.input_tuple);
      if (!output) {
        return nullptr;
      }
      // signature is setup_context(ctx, inputs, output)
      auto ctx_input_output_tuple =
          make_ctx_input_output_tuple(ctx, unpacked_input, output);
      if (!ctx_input_output_tuple) {
        return nullptr;
      }
      THPObjectPtr setup_context_fn(
          PyObject_GetAttrString(cls, "setup_context"));
      auto result =
          PyObject_CallObject(setup_context_fn, ctx_input_output_tuple);
      if (!result) {
        return nullptr;
      }
    } else {
      // call forward
      auto ctx_input_tuple =
          make_ctx_input_tuple(ctx, unpacked_input, num_args);
      if (!ctx_input_tuple) {
        return nullptr;
      }
      output = PyObject_CallObject(forward_fn, ctx_input_tuple);
    }
    if (!output)
      return nullptr;
  }

  return process_outputs(
      cls,
      cdata,
      ctx,
      unpacked_input,
      inputs,
      std::move(output),
      is_executable,
      node);
  END_HANDLE_TH_ERRORS
}

////////////////////////////////////////////////////////////////////////////////
// Other methods / attributes
////////////////////////////////////////////////////////////////////////////////

PyObject* THPFunction__register_hook_dict(PyObject* _self, PyObject* _var) {
  HANDLE_TH_ERRORS
  THPUtils_assert(
      THPVariable_Check(_var), "_register_hook_dict expected a Tensor");
  THPVariable* var = reinterpret_cast<THPVariable*>(_var);
  const auto& tensor = THPVariable_Unpack(var);
  std::unique_ptr<FunctionPreHook> hook(
      new PyFunctionTensorPreHook(var->backward_hooks, tensor.output_nr()));
  auto self = (THPFunction*)_self;
  auto cdata = self->cdata.lock();
  TORCH_CHECK(
      cdata,
      "Attribute '_register_hook_dict' is invalid for this instance of _C._FunctionBase. "
      "Accessing this attribute directly on an instance of autograd.Function is a legacy "
      "access pattern that is no longer supported. For examples on how to use new-style "
      "autograd functions, see "
      "https://pytorch.org/docs/stable/autograd.html#torch.autograd.Function ");
  cdata->add_tensor_pre_hook(std::move(hook));
  Py_RETURN_NONE;
  END_HANDLE_TH_ERRORS
}

PyObject* THPFunction_register_hook(PyObject* _self, PyObject* hook) {
  HANDLE_TH_ERRORS
  auto self = (THPFunction*)_self;
  auto cdata = self->cdata.lock();
  TORCH_CHECK(
      cdata,
      "Attribute 'register_hook' is invalid for this instance of _C._FunctionBase. "
      "Accessing this attribute directly on an instance of autograd.Function is a legacy "
      "access pattern that is no longer supported. For examples on how to use new-style "
      "autograd functions, see "
      "https://pytorch.org/docs/stable/autograd.html#torch.autograd.Function ");
  return torch::autograd::registerFunctionHook(*cdata, hook);
  END_HANDLE_TH_ERRORS
}

PyObject* THPFunction_register_prehook(PyObject* _self, PyObject* hook) {
  HANDLE_TH_ERRORS
  auto self = (THPFunction*)_self;
  auto cdata = self->cdata.lock();
  TORCH_CHECK(
      cdata,
      "Attribute 'register_prehook' is invalid for this instance of _C._FunctionBase. "
      "Accessing this attribute directly on an instance of autograd.Function is a legacy "
      "access pattern that is no longer supported. For examples on how to use new-style "
      "autograd functions, see "
      "https://pytorch.org/docs/stable/autograd.html#torch.autograd.Function ");
  return torch::autograd::registerFunctionPreHook(*cdata, hook);
  END_HANDLE_TH_ERRORS
}

int THPFunction_set_materialize_grads(
    THPFunction* self,
    PyObject* value,
    void* unused) {
  HANDLE_TH_ERRORS
  if (!PyBool_Check(value)) {
    THPUtils_invalidArguments(
        value, nullptr, "set_materialize_grads", 1, "(bool)");
    return -1;
  }
  self->materialize_grads = (value == Py_True);
  return 0;
  END_HANDLE_TH_ERRORS_RET(-1)
}

static PyObject* unpack_saved_variables(
    THPFunction* self,
    const std::function<PyObject*(const Variable&)>& unpack_fn) {
  THPUtils_assert(!self->has_freed_buffers, ERR_BACKWARD_TWICE);
  auto& saved_variables = self->saved_variables;
  if (saved_variables.empty())
    return PyTuple_New(0);

  int num_saved = saved_variables.size();
  THPObjectPtr saved(PyTuple_New(num_saved));
  if (!saved)
    return nullptr;
  auto saved_for = self->cdata.lock();
  // This is really a true assert, because we've already tested for the
  // self->has_freed_buffers case at the beginning of this function:
  // buffers are freed when PyNode dies; if the buffers are not freed,
  // PyNode must be live. (Note that the buffers could be freed
  // even though the PyNode is live, but that doesn't matter here
  // because we will never hit this line of code if the buffers are freed--
  // and in any case saved_for will be non-NULL.)
  TORCH_INTERNAL_ASSERT(saved_for);
  for (const auto i : c10::irange(num_saved)) {
    auto unpacked_var = saved_variables[i].unpack(saved_for);
    THPObjectPtr value;
    if (!unpacked_var.defined()) {
      Py_INCREF(Py_None);
      value = Py_None;
    } else {
      value = unpack_fn(unpacked_var);
    }
    PyTuple_SET_ITEM(saved.get(), i, value.release());
  }
  return saved.release();
}

PyObject* THPFunction_saved_tensors(THPFunction* self, void* _unused) {
  HANDLE_TH_ERRORS
  if (self->saved_for_forward) {
    Py_INCREF(self->saved_for_forward);
    return self->saved_for_forward;
  } else {
    return unpack_saved_variables(
        self, [](const Variable& var) { return THPVariable_Wrap(var); });
  }
  END_HANDLE_TH_ERRORS
}

PyObject* THPFunction_saved_variables(THPFunction* self, void* _unused) {
  HANDLE_TH_ERRORS
  auto r = PyErr_WarnEx(
      PyExc_DeprecationWarning,
      "'saved_variables' is deprecated; use 'saved_tensors'",
      0);
  if (r != 0)
    throw python_error();
  return unpack_saved_variables(
      self, [](const Variable& var) { return THPVariable_Wrap(var); });
  END_HANDLE_TH_ERRORS
}

PyObject* THPFunction_raw_saved_tensors(THPFunction* self, void* _unused) {
  HANDLE_TH_ERRORS
  // User tries to access saved variables after they have been freed
  THPUtils_assert(!self->has_freed_buffers, ERR_BACKWARD_TWICE);
  const auto& saved_variables = self->saved_variables;
  if (saved_variables.empty())
    return PyTuple_New(0);
  size_t num_saved = saved_variables.size();
  THPObjectPtr saved(PyTuple_New(num_saved));
  if (!saved) {
    return nullptr;
  }
  for (const auto i : c10::irange(num_saved)) {
    py::object obj =
        py::cast(saved_variables[i], py::return_value_policy::reference);
    PyTuple_SET_ITEM(saved.get(), i, obj.release().ptr());
  }
  return saved.release();
  END_HANDLE_TH_ERRORS
}

PyObject* THPFunction_next_functions(THPFunction* self, void* _unused) {
  HANDLE_TH_ERRORS
  auto cdata = self->cdata.lock();
  TORCH_CHECK(
      cdata,
      "Attribute 'next_functions' is invalid for this instance of _C._FunctionBase. "
      "Accessing this attribute directly on an instance of autograd.Function is a legacy "
      "access pattern that is no longer supported. For examples on how to use new-style "
      "autograd functions, see "
      "https://pytorch.org/docs/stable/autograd.html#torch.autograd.Function ");
  const auto num_outputs = cdata->num_outputs();
  THPObjectPtr result(PyTuple_New(num_outputs));
  if (!result)
    return nullptr;
  for (const auto i : c10::irange(num_outputs)) {
    THPObjectPtr fn_tuple(PyTuple_New(2));
    if (!fn_tuple)
      return nullptr;
    const auto& edge = cdata->next_edge(i);
    PyObject* fn = functionToPyObject(edge.function);
    if (!fn)
      return nullptr;
    PyTuple_SET_ITEM(fn_tuple.get(), 0, fn);
    PyTuple_SET_ITEM(fn_tuple.get(), 1, THPUtils_packInt64(edge.input_nr));
    PyTuple_SET_ITEM(result.get(), i, fn_tuple.release());
  }
  return result.release();
  END_HANDLE_TH_ERRORS
}

PyObject* THPFunction_metadata(THPFunction* self, void* _unused) {
  HANDLE_TH_ERRORS
  auto cdata = self->cdata.lock();
  // The correct way to solve this problem is to stop exposing grad_fn
  // of PyFunctions as THPFunction; instead, we should use THPCppFunction
  // like everyone else. But this is a BC-breaking change as it would
  // mean that you no longer get the property that grad_fn is a subclass
  // of the autograd function class that you defined in the custom case,
  // so I didn't fix it here.
  TORCH_CHECK(
      cdata,
      "You attempted to access the anomaly metadata of a custom autograd function "
      "but the underlying PyNode has already been deallocated. The most likely "
      "reason this occurred is because you assigned x.grad_fn to a local variable "
      "and then let the original variable get deallocated. Don't do that! If "
      "you really have no way of restructuring your code so this is the case, "
      "please file an issue reporting that you are affected by this.");
  auto metadata = static_cast<PyAnomalyMetadata*>(cdata->metadata())->dict();

  Py_INCREF(metadata);
  return metadata;
  END_HANDLE_TH_ERRORS
}

typedef PyObject* (*getter)(PyObject*, void*);
typedef int (*setter)(PyObject*, PyObject*, void*);

namespace {

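// Generic getter/setter templates over PyObject* members of THPFunction;
// they back the entries in the THPFunction_properties table below.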
template <PyObject* THPFunction::*ptr>
PyObject* getObject(PyObject* obj, void* _unused) {
  auto self = (THPFunction*)obj;
  PyObject* value = self->*ptr;
  if (!value) {
    Py_RETURN_NONE;
  }
  Py_INCREF(value);
  return value;
}

template <PyObject* THPFunction::*ptr>
int setObject(PyObject* obj, PyObject* value, void* _unused) {
  auto self = (THPFunction*)obj;
  if (value == Py_None) {
    value = nullptr;
  }
  Py_XDECREF((self->*ptr));
  Py_XINCREF(value);
  self->*ptr = value;
  return 0;
}

template <typename M, M THPFunction::*ptr, PyObject* (*Convert)(long)>
PyObject* getMember(PyObject* obj, void* _unused) {
  auto self = (THPFunction*)obj;
  return Convert(self->*ptr);
}

template <typename M, M autograd::Node::*ptr, PyObject* (*Convert)(long)>
PyObject* getImplMember(PyObject* obj, void* _unused) {
  auto self = (THPFunction*)obj;
  return Convert(self->cdata.*ptr);
}

PyObject* getRequiresGrad(PyObject* obj, void* _unused) {
  Py_RETURN_TRUE;
}

} // namespace

// NOLINTNEXTLINE(modernize-avoid-c-arrays,cppcoreguidelines-avoid-c-arrays,cppcoreguidelines-avoid-non-const-global-variables)
static struct PyGetSetDef THPFunction_properties[] = {
    {"saved_tensors",
     (getter)THPFunction_saved_tensors,
     nullptr,
     nullptr,
     nullptr},
    {"saved_variables",
     (getter)THPFunction_saved_variables,
     nullptr,
     nullptr,
     nullptr},
    {"_raw_saved_tensors",
     (getter)THPFunction_raw_saved_tensors,
     nullptr,
     nullptr,
     nullptr},
    {"next_functions",
     (getter)THPFunction_next_functions,
     nullptr,
     nullptr,
     nullptr},
    {"to_save",
     &getObject<&THPFunction::to_save>,
     &setObject<&THPFunction::to_save>,
     nullptr,
     nullptr},
    {"non_differentiable",
     &getObject<&THPFunction::non_differentiable>,
     &setObject<&THPFunction::non_differentiable>,
     nullptr,
     nullptr},
    {"dirty_tensors",
     &getObject<&THPFunction::dirty_tensors>,
     &setObject<&THPFunction::dirty_tensors>,
     nullptr,
     nullptr},
    {"saved_for_forward",
     &getObject<&THPFunction::saved_for_forward>,
     &setObject<&THPFunction::saved_for_forward>,
     nullptr,
     nullptr},
    {"needs_input_grad",
     &getObject<&THPFunction::needs_input_grad>,
     nullptr,
     nullptr,
     nullptr},
    {"requires_grad", getRequiresGrad, nullptr, nullptr, nullptr},
    {"metadata", (getter)THPFunction_metadata, nullptr, nullptr, nullptr},
    {"materialize_grads",
     nullptr,
     (setter)THPFunction_set_materialize_grads,
     nullptr,
     nullptr},
    {nullptr}};

// NOLINTNEXTLINE(modernize-avoid-c-arrays,cppcoreguidelines-avoid-c-arrays,cppcoreguidelines-avoid-non-const-global-variables)
static struct PyMethodDef THPFunction_methods[] = {
    {(char*)"name", THPFunction_name, METH_NOARGS, nullptr},
    {(char*)"maybe_clear_saved_tensors",
     THPFunction_maybe_clear_saved_tensors,
     METH_NOARGS,
     nullptr},
    {(char*)"apply", THPFunction_apply, METH_CLASS | METH_VARARGS, nullptr},
    {(char*)"_register_hook_dict",
     THPFunction__register_hook_dict,
     METH_O,
     nullptr},
    {(char*)"register_hook", THPFunction_register_hook, METH_O, nullptr},
    {(char*)"register_prehook", THPFunction_register_prehook, METH_O, nullptr},
    {nullptr}};

PyTypeObject THPFunctionType = {
    PyVarObject_HEAD_INIT(nullptr, 0) "torch._C._FunctionBase", /* tp_name */
    sizeof(THPFunction), /* tp_basicsize */
    0, /* tp_itemsize */
    (destructor)THPFunction_dealloc, /* tp_dealloc */
    0, /* tp_vectorcall_offset */
    nullptr, /* tp_getattr */
    nullptr, /* tp_setattr */
    nullptr, /* tp_reserved */
    nullptr, /* tp_repr */
    nullptr, /* tp_as_number */
    nullptr, /* tp_as_sequence */
    nullptr, /* tp_as_mapping */
    nullptr, /* tp_hash */
    nullptr, /* tp_call */
    nullptr, /* tp_str */
    nullptr, /* tp_getattro */
    nullptr, /* tp_setattro */
    nullptr, /* tp_as_buffer */
    Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE |
        Py_TPFLAGS_HAVE_GC, /* tp_flags */
    nullptr, /* tp_doc */
    (traverseproc)THPFunction_traverse, /* tp_traverse */
    (inquiry)THPFunction_clear, /* tp_clear */
    nullptr, /* tp_richcompare */
    0, /* tp_weaklistoffset */
    nullptr, /* tp_iter */
    nullptr, /* tp_iternext */
    THPFunction_methods, /* tp_methods */
    nullptr, /* tp_members */
    THPFunction_properties, /* tp_getset */
    nullptr, /* tp_base */
    nullptr, /* tp_dict */
    nullptr, /* tp_descr_get */
    nullptr, /* tp_descr_set */
    0, /* tp_dictoffset */
    nullptr, /* tp_init */
    nullptr, /* tp_alloc */
    THPFunction_new /* tp_new */
};

bool THPFunction_initModule(PyObject* module) {
  if (PyType_Ready(&THPFunctionType) < 0)
    return false;
  Py_INCREF(&THPFunctionType);
  PyModule_AddObject(module, "_FunctionBase", (PyObject*)&THPFunctionType);
  return true;
}