1 | #pragma once |
2 | |
3 | #include <c10/core/impl/PyInterpreter.h> |
4 | #include <c10/macros/Macros.h> |
5 | #include <c10/util/python_stub.h> |
6 | |
7 | namespace c10 { |
8 | |
9 | // This is an safe owning holder for a PyObject, akin to pybind11's |
10 | // py::object, with two major differences: |
11 | // |
12 | // - It is in c10/core; i.e., you can use this type in contexts where |
13 | // you do not have a libpython dependency |
14 | // |
15 | // - It is multi-interpreter safe (ala torchdeploy); when you fetch |
16 | // the underlying PyObject* you are required to specify what the current |
17 | // interpreter context is and we will check that you match it. |
18 | // |
19 | // It is INVALID to store a reference to a Tensor object in this way; |
20 | // you should just use TensorImpl directly in that case! |
21 | struct C10_API SafePyObject { |
22 | // Steals a reference to data |
23 | SafePyObject(PyObject* data, c10::impl::PyInterpreter* pyinterpreter) |
24 | : data_(data), pyinterpreter_(pyinterpreter) {} |
25 | |
26 | // In principle this could be copyable if we add an incref to PyInterpreter |
27 | // but for now it's easier to just disallow it. |
28 | SafePyObject(SafePyObject const&) = delete; |
29 | SafePyObject& operator=(SafePyObject const&) = delete; |
30 | |
31 | ~SafePyObject() { |
32 | (*pyinterpreter_)->decref(data_, /*is_tensor*/ false); |
33 | } |
34 | |
35 | c10::impl::PyInterpreter& pyinterpreter() const { |
36 | return *pyinterpreter_; |
37 | } |
38 | PyObject* ptr(const c10::impl::PyInterpreter*) const; |
39 | |
40 | private: |
41 | PyObject* data_; |
42 | c10::impl::PyInterpreter* pyinterpreter_; |
43 | }; |
44 | |
45 | // Like SafePyObject, but non-owning. Good for references to global PyObjects |
46 | // that will be leaked on interpreter exit. You get a copy constructor/assign |
47 | // this way. |
48 | struct C10_API SafePyHandle { |
49 | SafePyHandle() : data_(nullptr), pyinterpreter_(nullptr) {} |
50 | SafePyHandle(PyObject* data, c10::impl::PyInterpreter* pyinterpreter) |
51 | : data_(data), pyinterpreter_(pyinterpreter) {} |
52 | |
53 | c10::impl::PyInterpreter& pyinterpreter() const { |
54 | return *pyinterpreter_; |
55 | } |
56 | PyObject* ptr(const c10::impl::PyInterpreter*) const; |
57 | void reset() { |
58 | data_ = nullptr; |
59 | pyinterpreter_ = nullptr; |
60 | } |
61 | operator bool() { |
62 | return data_; |
63 | } |
64 | |
65 | private: |
66 | PyObject* data_; |
67 | c10::impl::PyInterpreter* pyinterpreter_; |
68 | }; |
69 | |
70 | } // namespace c10 |
71 | |