1 | #pragma once |
2 | |
3 | #include <torch/csrc/python_headers.h> |
4 | |
5 | #include <ATen/core/Tensor.h> |
6 | #include <ATen/core/jit_type_base.h> |
7 | #include <c10/util/irange.h> |
8 | #include <c10/util/variant.h> |
9 | #include <pybind11/pybind11.h> |
10 | #include <pybind11/stl.h> |
11 | |
12 | #include <torch/csrc/Device.h> |
13 | #include <torch/csrc/DynamicTypes.h> |
14 | #include <torch/csrc/Generator.h> |
15 | #include <torch/csrc/MemoryFormat.h> |
16 | #include <torch/csrc/utils/tensor_memoryformats.h> |
17 | |
18 | #include <stdexcept> |
19 | #include <utility> |
20 | |
namespace py = pybind11;

// This makes intrusive_ptr available as a custom pybind11 holder type,
// see
// https://pybind11.readthedocs.io/en/stable/advanced/smart_ptrs.html#custom-smart-pointers
// Per the pybind11 documentation, the optional trailing `true` argument tells
// pybind11 that it is safe to construct a holder directly from a raw `T*`
// (as is the case for intrusively reference-counted pointers).
PYBIND11_DECLARE_HOLDER_TYPE(T, c10::intrusive_ptr<T>, true);

// Holder declarations for the JIT type pointer wrappers.
// NOTE(review): SingletonOrSharedTypePtr is declared without the `true` flag,
// presumably because it may wrap a non-intrusive shared pointer — confirm.
PYBIND11_DECLARE_HOLDER_TYPE(T, c10::SingletonOrSharedTypePtr<T>);
PYBIND11_DECLARE_HOLDER_TYPE(T, c10::SingletonTypePtr<T>, true);
30 | |
31 | namespace pybind11 { |
32 | namespace detail { |
33 | |
34 | // torch.Tensor <-> at::Tensor conversions (without unwrapping) |
35 | template <> |
36 | struct TORCH_PYTHON_API type_caster<at::Tensor> { |
37 | public: |
38 | // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes) |
39 | PYBIND11_TYPE_CASTER(at::Tensor, _("torch.Tensor" )); |
40 | |
41 | bool load(handle src, bool); |
42 | |
43 | static handle cast( |
44 | const at::Tensor& src, |
45 | return_value_policy /* policy */, |
46 | handle /* parent */); |
47 | }; |
48 | |
49 | // torch._StorageBase <-> at::Storage |
50 | template <> |
51 | struct type_caster<at::Storage> { |
52 | public: |
53 | // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes) |
54 | PYBIND11_TYPE_CASTER(at::Storage, _("torch.StorageBase" )); |
55 | |
56 | bool load(handle src, bool) { |
57 | PyObject* obj = src.ptr(); |
58 | if (torch::isStorage(obj)) { |
59 | value = torch::createStorage(obj); |
60 | return true; |
61 | } |
62 | return false; |
63 | } |
64 | |
65 | static handle cast( |
66 | const at::Storage& src, |
67 | return_value_policy /* policy */, |
68 | handle /* parent */) { |
69 | return handle(torch::createPyObject(src)); |
70 | } |
71 | }; |
72 | |
73 | template <> |
74 | struct type_caster<at::Generator> { |
75 | public: |
76 | // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes) |
77 | PYBIND11_TYPE_CASTER(at::Generator, _("torch.Generator" )); |
78 | |
79 | bool load(handle src, bool) { |
80 | PyObject* obj = src.ptr(); |
81 | if (THPGenerator_Check(obj)) { |
82 | value = reinterpret_cast<THPGenerator*>(obj)->cdata; |
83 | return true; |
84 | } |
85 | return false; |
86 | } |
87 | |
88 | static handle cast( |
89 | const at::Generator& src, |
90 | return_value_policy /* policy */, |
91 | handle /* parent */) { |
92 | return handle(THPGenerator_Wrap(src)); |
93 | } |
94 | }; |
95 | |
96 | template <> |
97 | struct TORCH_PYTHON_API type_caster<at::IntArrayRef> { |
98 | public: |
99 | // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes) |
100 | PYBIND11_TYPE_CASTER(at::IntArrayRef, _("Tuple[int, ...]" )); |
101 | |
102 | bool load(handle src, bool); |
103 | static handle cast( |
104 | at::IntArrayRef src, |
105 | return_value_policy /* policy */, |
106 | handle /* parent */); |
107 | |
108 | private: |
109 | std::vector<int64_t> v_value; |
110 | }; |
111 | |
112 | template <> |
113 | struct TORCH_PYTHON_API type_caster<at::SymIntArrayRef> { |
114 | public: |
115 | // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes) |
116 | PYBIND11_TYPE_CASTER(at::SymIntArrayRef, _("List[int]" )); |
117 | |
118 | bool load(handle src, bool); |
119 | static handle cast( |
120 | at::SymIntArrayRef src, |
121 | return_value_policy /* policy */, |
122 | handle /* parent */); |
123 | |
124 | private: |
125 | std::vector<c10::SymInt> v_value; |
126 | }; |
127 | |
128 | template <> |
129 | struct TORCH_PYTHON_API type_caster<at::ArrayRef<c10::SymNode>> { |
130 | public: |
131 | // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes) |
132 | PYBIND11_TYPE_CASTER(at::ArrayRef<c10::SymNode>, _("List[SymNode]" )); |
133 | |
134 | bool load(handle src, bool); |
135 | static handle cast( |
136 | at::ArrayRef<c10::SymNode> src, |
137 | return_value_policy /* policy */, |
138 | handle /* parent */); |
139 | |
140 | private: |
141 | std::vector<c10::SymNode> v_value; |
142 | }; |
143 | |
144 | template <> |
145 | struct type_caster<at::MemoryFormat> { |
146 | public: |
147 | // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes) |
148 | PYBIND11_TYPE_CASTER(at::MemoryFormat, _("torch.memory_format" )); |
149 | |
150 | bool load(handle src, bool) { |
151 | PyObject* obj = src.ptr(); |
152 | if (THPMemoryFormat_Check(obj)) { |
153 | value = reinterpret_cast<THPMemoryFormat*>(obj)->memory_format; |
154 | return true; |
155 | } |
156 | return false; |
157 | } |
158 | static handle cast( |
159 | at::MemoryFormat src, |
160 | return_value_policy /* policy */, |
161 | handle /* parent */) { |
162 | return handle(torch::utils::getTHPMemoryFormat(src)); |
163 | } |
164 | }; |
165 | |
166 | template <> |
167 | struct type_caster<at::Device> { |
168 | public: |
169 | // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes) |
170 | PYBIND11_TYPE_CASTER(at::Device, _("torch.device" )); |
171 | |
172 | // PYBIND11_TYPE_CASTER defines a member field called value. Since at::Device |
173 | // cannot be default-initialized, we provide this constructor to explicitly |
174 | // initialize that field. The value doesn't matter as it will be overwritten |
175 | // after a successful call to load. |
176 | type_caster() : value(c10::kCPU) {} |
177 | |
178 | bool load(handle src, bool) { |
179 | PyObject* obj = src.ptr(); |
180 | if (THPDevice_Check(obj)) { |
181 | value = reinterpret_cast<THPDevice*>(obj)->device; |
182 | return true; |
183 | } |
184 | return false; |
185 | } |
186 | |
187 | static handle cast( |
188 | const at::Device& src, |
189 | return_value_policy /* policy */, |
190 | handle /* parent */) { |
191 | return handle(THPDevice_New(src)); |
192 | } |
193 | }; |
194 | |
195 | template <> |
196 | struct type_caster<c10::DispatchKey> |
197 | : public type_caster_base<c10::DispatchKey> { |
198 | using base = type_caster_base<c10::DispatchKey>; |
199 | c10::DispatchKey tmp; |
200 | |
201 | public: |
202 | bool load(handle src, bool convert) { |
203 | if (base::load(src, convert)) { |
204 | return true; |
205 | } else if (py::isinstance( |
206 | src, py::module_::import("builtins" ).attr("str" ))) { |
207 | tmp = c10::parseDispatchKey(py::cast<std::string>(src)); |
208 | value = &tmp; |
209 | return true; |
210 | } |
211 | return false; |
212 | } |
213 | |
214 | static handle cast( |
215 | c10::DispatchKey src, |
216 | return_value_policy policy, |
217 | handle parent) { |
218 | return base::cast(src, policy, parent); |
219 | } |
220 | }; |
221 | |
222 | template <> |
223 | struct TORCH_PYTHON_API type_caster<c10::Scalar> { |
224 | public: |
225 | PYBIND11_TYPE_CASTER( |
226 | c10::Scalar, |
227 | _("Union[Number, torch.SymInt, torch.SymFloat]" )); |
228 | bool load(py::handle src, bool); |
229 | |
230 | static py::handle cast( |
231 | const c10::Scalar& si, |
232 | return_value_policy /* policy */, |
233 | handle /* parent */); |
234 | }; |
235 | |
236 | template <> |
237 | struct TORCH_PYTHON_API type_caster<c10::SymInt> { |
238 | public: |
239 | PYBIND11_TYPE_CASTER(c10::SymInt, _("Union[int, torch.SymInt]" )); |
240 | bool load(py::handle src, bool); |
241 | |
242 | static py::handle cast( |
243 | c10::SymInt si, |
244 | return_value_policy /* policy */, |
245 | handle /* parent */); |
246 | }; |
247 | |
248 | template <> |
249 | struct TORCH_PYTHON_API type_caster<c10::SymFloat> { |
250 | public: |
251 | PYBIND11_TYPE_CASTER(c10::SymFloat, _("float" )); |
252 | bool load(py::handle src, bool); |
253 | |
254 | static py::handle cast( |
255 | c10::SymFloat si, |
256 | return_value_policy /* policy */, |
257 | handle /* parent */); |
258 | }; |
259 | |
260 | template <typename T> |
261 | struct type_caster<c10::complex<T>> { |
262 | public: |
263 | // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes) |
264 | PYBIND11_TYPE_CASTER(c10::complex<T>, _("complex" )); |
265 | |
266 | bool load(handle src, bool) { |
267 | PyObject* obj = src.ptr(); |
268 | |
269 | // Refered from `THPUtils_unpackComplexDouble` |
270 | Py_complex py_complex = PyComplex_AsCComplex(obj); |
271 | if (py_complex.real == -1.0 && PyErr_Occurred()) { |
272 | return false; |
273 | } |
274 | |
275 | // Python's Complex is always double precision. |
276 | value = c10::complex<double>(py_complex.real, py_complex.imag); |
277 | return true; |
278 | } |
279 | |
280 | static handle cast( |
281 | const c10::complex<T>& complex, |
282 | return_value_policy /* policy */, |
283 | handle /* parent */) { |
284 | // Python only knows double precision complex. |
285 | return handle(PyComplex_FromDoubles(complex.real(), complex.imag())); |
286 | } |
287 | }; |
288 | |
289 | // Pybind11 bindings for our optional and variant types. |
290 | // http://pybind11.readthedocs.io/en/stable/advanced/cast/stl.html#c-17-library-containers |
291 | template <typename T> |
292 | struct type_caster<c10::optional<T>> : optional_caster<c10::optional<T>> {}; |
293 | |
294 | template <typename... Ts> |
295 | struct C10_MPARK_VISIBILITY_HIDDEN type_caster<c10::variant<Ts...>> |
296 | : variant_caster<c10::variant<Ts...>> {}; |
297 | } // namespace detail |
298 | } // namespace pybind11 |
299 | |
300 | namespace torch { |
301 | namespace impl { |
302 | |
303 | // Use this function if you have a C++ object that is used from both C++ |
304 | // and Python contexts, and you need its GIL to be released when you |
305 | // destruct it in the Python context. |
306 | // |
307 | // This function is a valid shared_ptr destructor and can be used to |
308 | // conveniently allocate a shared_ptr to an object whose destructor will be run |
309 | // without the GIL. Pass it as the second argument to shared_ptr, e.g., |
310 | // |
311 | // shared_ptr<T>(new T(), destroy_without_gil<T>) |
312 | // |
// Attaching the GIL release logic to the holder pointer rather than the
// actual destructor of T is helpful when T is Python-agnostic and
// shouldn't refer to the Python API.
316 | // |
317 | // Note there are limitations to the correctness of code that makes use of this. |
318 | // In particular, if a shared_ptr is constructed from C++ code without this |
319 | // destructor and then passed to pybind11, pybind11 will happily take ownership |
320 | // of the shared_ptr (and be willing to destruct it from a context where it is |
321 | // holding the GIL). unique_ptr with a type branded deleter is less prone to |
322 | // this problem, because a stock deleter unique_ptr is not convertible with it. |
323 | // I plan to mitigate this problem by adding DEBUG-only asserts to the true C++ |
324 | // destructors that the GIL is not held (using a virtual call to get to the |
325 | // Python interpreter); alternately, we could use a virtual call to simply |
// ensure we release the GIL in the C++ destructor; however, this is a layering
// violation (why would code that is ostensibly Python-agnostic be calling into
// the GIL?).
329 | // |
330 | // Adapted from |
331 | // https://github.com/pybind/pybind11/issues/1446#issuecomment-406341510 |
332 | template <typename T> |
333 | inline void destroy_without_gil(T* ptr) { |
334 | // Because the ownership of a shared_ptr is diffuse, it's not possible to |
335 | // necessarily predict whether or not the last reference to an object will |
336 | // be destructed from Python or C++. This means that in the destructor here, |
337 | // we don't necessarily know if we actually have the GIL or not; in fact, |
338 | // we don't even know if the Python interpreter still exists! Thus, we have |
339 | // to test for it before releasing the GIL. |
340 | // |
341 | // PyGILState_Check is hopefully self explanatory. But Py_IsInitialized or |
342 | // _PyIsFinalizing? Both get set at the same time during the Python |
343 | // destruction process: |
344 | // https://github.com/python/cpython/blob/d92513390a1a0da781bb08c284136f4d7abea36d/Python/pylifecycle.c#L1716-L1717 |
345 | // so the operant question is whether or not you want to release the GIL after |
346 | // finalization has completed (and there is just no Python interpreter). |
347 | // Clearly there is no need to release GIL in that state, so we want |
348 | // Py_IsInitialized. |
349 | if (Py_IsInitialized() && PyGILState_Check()) { |
350 | pybind11::gil_scoped_release nogil; |
351 | delete ptr; |
352 | } else { |
353 | delete ptr; |
354 | } |
355 | } |
356 | |
357 | } // namespace impl |
358 | } // namespace torch |
359 | |