1#pragma once
2
3#include <torch/csrc/python_headers.h>
4
5#include <torch/csrc/Exceptions.h>
6#include <torch/csrc/autograd/custom_function.h>
7#include <torch/csrc/autograd/function.h>
8#include <torch/csrc/autograd/saved_variable.h>
9#include <torch/csrc/autograd/variable.h>
10#include <torch/csrc/utils/object_ptr.h>
11
12#include <c10/core/DeviceGuard.h>
13#include <c10/util/Optional.h>
14
15#include <memory>
16#include <utility>
17#include <vector>
18
// Forward declaration: only the name torch::jit::Graph is introduced here;
// no definition of the type is pulled into this header.
namespace torch { namespace jit { struct Graph; } } // namespace torch::jit
24namespace torch {
25namespace autograd {
26
27// A Function which is implemented by a Python object (i.e., a THPFunction).
28// Calls to 'apply' are forwarded to the Python method implementation.
29struct PyNode : public Node {
30 PyNode(THPObjectPtr obj) : obj(obj.release()) {}
31
32 variable_list apply(variable_list&& inputs) override;
33
34 void release_variables() override;
35 std::string name() const override;
36 bool is_traceable() override;
37
38 // THPFunction this Function is wrapping. Owning!
39 PyObject* obj;
40
41 ~PyNode() override {
42 // Can't use THPObjectPtr as a field in this class; destructor won't take
43 // out GIL! When I forgot to do this by hand
44 // TestAutograd.test_inplace_view_python called me out about it.
45 // If python is already dead, leak the wrapped python objects
46 if (Py_IsInitialized()) {
47 pybind11::gil_scoped_acquire gil;
48 Py_DECREF(obj);
49 }
50 }
51};
52
53/**
54 * Cast an object into a tuple, if it is not a tuple already. Returns true
55 * if the original object was not a tuple.
56 */
57inline bool ensure_tuple(THPObjectPtr& obj) {
58 if (PyTuple_Check(obj.get()))
59 return false;
60
61 PyObject* tuple = PyTuple_New(1);
62 if (!tuple)
63 throw python_error();
64 PyTuple_SET_ITEM(tuple, 0, obj.release());
65 obj = tuple;
66 return true;
67}
68
69} // namespace autograd
70} // namespace torch
71
// C-level layout of a THPFunction Python object. It starts with the standard
// CPython object header, so member order and types are part of the object
// layout. Do not reorder fields without checking every user of this struct.
// NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
struct THPFunction {
  // Standard CPython object header; must remain the first member.
  PyObject_HEAD

  // NOTE(review): presumably a Python sequence of per-input flags saying
  // which forward inputs require gradients — confirm against the .cpp.
  PyObject* needs_input_grad;

  // Python tuple of tensors whose variables we should save. Set
  // by Python with 'save_for_backward'. If nullptr, no tensors were
  // saved.
  PyObject* to_save;
  // Python tuple of tensors which are not differentiable. Set by
  // Python with 'mark_non_differentiable'. If nullptr, no tensors were
  // non-differentiable.
  PyObject* non_differentiable;
  // Python tuple of tensors which had inplace updates in the forward()
  // pass. Set by Python with 'mark_dirty'. If nullptr, no tensors were
  // modified inplace.
  PyObject* dirty_tensors;

  // boolean indicating whether to materialize undefined output grad tensors
  // into tensors full of zeros. Set by Python with 'set_materialize_grads'.
  // Default is true.
  bool materialize_grads;

  // Per-output / per-input metadata and the variables saved for backward.
  std::vector<torch::autograd::VariableInfo> output_info;
  std::vector<torch::autograd::VariableInfo> input_info;
  std::vector<torch::autograd::SavedVariable> saved_variables;
  // For each input, true if the input is a THPVariable
  std::vector<bool> is_variable_input;
  // Flag stored as a char. Judging by the name, it is set once the saved
  // buffers have been freed — TODO confirm where it is written.
  char has_freed_buffers;

  // NOTE(review): Python object; the name suggests values saved for use in
  // forward-mode AD — confirm against the .cpp.
  PyObject* saved_for_forward;
  // The actual PyNode (in the autograd graph) that this data was
  // saved for. This field may be NULL (because a user can construct
  // a THPFunction directly from Python), but when this field is non-NULL,
  // it is guaranteed that cdata.lock()->obj == this
  //
  // In most ordinary use, this field should always be non-NULL; e.g.,
  // when we allocate a THPFunction because we are running Node.apply,
  // after constructing a THPFunction, we immediately allocate a PyNode
  // for it. We can't enforce this directly in the constructor of
  // THPFunction though, because there's no way to keep it live long enough
  // to save an owning reference to PyNode into the grad_fn of a Variable.
  // Held as a weak_ptr, so this struct does not keep the PyNode alive.
  std::weak_ptr<torch::autograd::PyNode> cdata;
};
117
118bool THPFunction_initModule(PyObject* module);
119extern PyTypeObject THPFunctionType;
120extern PyObject* THPFunctionClass;
121
122inline bool THPFunction_Check(PyObject* obj) {
123 return PyObject_IsInstance(obj, (PyObject*)&THPFunctionType);
124}
125