#pragma once

#include <torch/csrc/python_headers.h>

#include <torch/csrc/Exceptions.h>
#include <torch/csrc/autograd/custom_function.h>
#include <torch/csrc/autograd/function.h>
#include <torch/csrc/autograd/saved_variable.h>
#include <torch/csrc/autograd/variable.h>
#include <torch/csrc/utils/object_ptr.h>

#include <c10/core/DeviceGuard.h>
#include <c10/util/Optional.h>

#include <memory>
#include <utility>
#include <vector>

namespace torch {
namespace jit {
struct Graph;
} // namespace jit
} // namespace torch
namespace torch {
namespace autograd {

// A Node which is implemented by a Python object (i.e., a THPFunction).
// Calls to 'apply' are forwarded to the Python method implementation.
struct PyNode : public Node {
  PyNode(THPObjectPtr obj) : obj(obj.release()) {}

  variable_list apply(variable_list&& inputs) override;

  void release_variables() override;
  std::string name() const override;
  bool is_traceable() override;

  // The THPFunction this Node wraps. This is an owning reference,
  // released in the destructor below.
  PyObject* obj;

  ~PyNode() override {
    // We can't use a THPObjectPtr for this field: its destructor doesn't
    // acquire the GIL before decrefing. When the manual DECREF below was
    // once forgotten, TestAutograd.test_inplace_view_python caught the
    // mistake.
    // If the Python interpreter has already shut down, leak the wrapped
    // Python object instead of touching it.
    if (Py_IsInitialized()) {
      pybind11::gil_scoped_acquire gil;
      Py_DECREF(obj);
    }
  }
};
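
// A minimal sketch of the forwarding performed by PyNode::apply (the real
// implementation in python_function.cpp also packs the C++ inputs into a
// Python tuple, unpacks and validates the Python outputs, and handles
// tracing; 'py_inputs' is an illustrative name for the packed tuple):
//
//   pybind11::gil_scoped_acquire gil;
//   THPObjectPtr apply_fn(PyObject_GetAttrString(obj, "apply"));
//   if (!apply_fn)
//     throw python_error();
//   THPObjectPtr r(PyObject_CallObject(apply_fn, py_inputs.get()));
//   if (!r)
//     throw python_error();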

/**
 * Wrap an object in a one-element tuple if it is not a tuple already.
 * Returns true if the object was wrapped (i.e., the original object was
 * not a tuple), so callers know whether to unwrap the result later.
 */
inline bool ensure_tuple(THPObjectPtr& obj) {
  if (PyTuple_Check(obj.get()))
    return false;

  PyObject* tuple = PyTuple_New(1);
  if (!tuple)
    throw python_error();
  PyTuple_SET_ITEM(tuple, 0, obj.release());
  obj = tuple;
  return true;
}
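
// Usage sketch (hypothetical caller; 'fn' and 'args' stand in for valid
// PyObject* values):
//
//   THPObjectPtr result(PyObject_CallObject(fn, args));
//   if (!result)
//     throw python_error();
//   bool wrapped = ensure_tuple(result);
//   // 'result' is now guaranteed to be a tuple; if 'wrapped' is true,
//   // the caller created the 1-tuple and may want to unwrap it later.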

} // namespace autograd
} // namespace torch

// NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
struct THPFunction {
  PyObject_HEAD

  // Python tuple of booleans, one per forward() input, indicating whether
  // that input requires grad; exposed to Python as 'ctx.needs_input_grad'.
  PyObject* needs_input_grad;

  // Python tuple of tensors whose variables we should save. Set
  // by Python with 'save_for_backward'. If nullptr, no tensors were
  // saved.
  PyObject* to_save;
  // Python tuple of tensors which are not differentiable. Set by
  // Python with 'mark_non_differentiable'. If nullptr, no tensors were
  // non-differentiable.
  PyObject* non_differentiable;
  // Python tuple of tensors that were modified in place in the forward()
  // pass. Set by Python with 'mark_dirty'. If nullptr, no tensors were
  // modified in place.
  PyObject* dirty_tensors;

  // Boolean indicating whether to materialize undefined output grad tensors
  // as tensors full of zeros before they are passed to backward. Set by
  // Python with 'set_materialize_grads'. Default is true.
  bool materialize_grads;
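
  // A sketch of what materialization means, assuming the metadata in
  // 'output_info' below is what backs the zero tensor (as the backward
  // implementation does); 'i' and 'grad' are illustrative names:
  //
  //   at::OptionalDeviceGuard device_guard;
  //   if (materialize_grads && !grad.defined())
  //     grad = output_info[i].zeros(device_guard);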

  // Type and shape metadata for each output / input of forward(); used to
  // check gradients and to materialize zero-filled grads (see above).
  std::vector<torch::autograd::VariableInfo> output_info;
  std::vector<torch::autograd::VariableInfo> input_info;
  // Variables saved via 'save_for_backward', in SavedVariable form.
  std::vector<torch::autograd::SavedVariable> saved_variables;
  // For each input, true if the input is a THPVariable
  std::vector<bool> is_variable_input;
  // Nonzero once the saved variables have been released (e.g., after
  // backward has run without retain_graph).
  char has_freed_buffers;

  // Python tuple of tensors saved with 'save_for_forward' for use in
  // forward-mode AD. If nullptr, no tensors were saved.
  PyObject* saved_for_forward;
  // The actual PyNode (in the autograd graph) that this data was
  // saved for. This field may be empty (because a user can construct
  // a THPFunction directly from Python), but when it is set, it is
  // guaranteed that cdata.lock()->obj == this.
  //
  // In most ordinary use this field is set; e.g., when we allocate a
  // THPFunction because we are running Node.apply, we allocate a PyNode
  // for it immediately afterwards. We can't enforce this directly in
  // THPFunction's constructor, though, because there is no way to keep
  // the PyNode alive long enough to save an owning reference to it into
  // the grad_fn of a Variable.
  std::weak_ptr<torch::autograd::PyNode> cdata;
};
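
// A sketch of checking the invariant documented on 'cdata' above ('self'
// stands in for an illustrative THPFunction*):
//
//   if (auto fn = self->cdata.lock()) {
//     TORCH_INTERNAL_ASSERT(fn->obj == (PyObject*)self);
//   }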

bool THPFunction_initModule(PyObject* module);
extern PyTypeObject THPFunctionType;
extern PyObject* THPFunctionClass;

inline bool THPFunction_Check(PyObject* obj) {
  return PyObject_IsInstance(obj, (PyObject*)&THPFunctionType);
}
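
// Usage sketch ('obj' stands in for a borrowed PyObject*; note that
// PyObject_IsInstance can also return -1 on error, which converts to
// true here):
//
//   if (THPFunction_Check(obj)) {
//     auto* fn = (THPFunction*)obj;
//     if (auto node = fn->cdata.lock()) {
//       // ... inspect the PyNode corresponding to this THPFunction ...
//     }
//   }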