1#pragma once
2
3#include <c10/core/impl/PyInterpreter.h>
4#include <c10/macros/Macros.h>
5#include <c10/util/python_stub.h>
6
7namespace c10 {
8
9// This is an safe owning holder for a PyObject, akin to pybind11's
10// py::object, with two major differences:
11//
12// - It is in c10/core; i.e., you can use this type in contexts where
13// you do not have a libpython dependency
14//
15// - It is multi-interpreter safe (ala torchdeploy); when you fetch
16// the underlying PyObject* you are required to specify what the current
17// interpreter context is and we will check that you match it.
18//
19// It is INVALID to store a reference to a Tensor object in this way;
20// you should just use TensorImpl directly in that case!
21struct C10_API SafePyObject {
22 // Steals a reference to data
23 SafePyObject(PyObject* data, c10::impl::PyInterpreter* pyinterpreter)
24 : data_(data), pyinterpreter_(pyinterpreter) {}
25
26 // In principle this could be copyable if we add an incref to PyInterpreter
27 // but for now it's easier to just disallow it.
28 SafePyObject(SafePyObject const&) = delete;
29 SafePyObject& operator=(SafePyObject const&) = delete;
30
31 ~SafePyObject() {
32 (*pyinterpreter_)->decref(data_, /*is_tensor*/ false);
33 }
34
35 c10::impl::PyInterpreter& pyinterpreter() const {
36 return *pyinterpreter_;
37 }
38 PyObject* ptr(const c10::impl::PyInterpreter*) const;
39
40 private:
41 PyObject* data_;
42 c10::impl::PyInterpreter* pyinterpreter_;
43};
44
45// Like SafePyObject, but non-owning. Good for references to global PyObjects
46// that will be leaked on interpreter exit. You get a copy constructor/assign
47// this way.
48struct C10_API SafePyHandle {
49 SafePyHandle() : data_(nullptr), pyinterpreter_(nullptr) {}
50 SafePyHandle(PyObject* data, c10::impl::PyInterpreter* pyinterpreter)
51 : data_(data), pyinterpreter_(pyinterpreter) {}
52
53 c10::impl::PyInterpreter& pyinterpreter() const {
54 return *pyinterpreter_;
55 }
56 PyObject* ptr(const c10::impl::PyInterpreter*) const;
57 void reset() {
58 data_ = nullptr;
59 pyinterpreter_ = nullptr;
60 }
61 operator bool() {
62 return data_;
63 }
64
65 private:
66 PyObject* data_;
67 c10::impl::PyInterpreter* pyinterpreter_;
68};
69
70} // namespace c10
71