1 | #pragma once |
2 | |
3 | #include <c10/core/Device.h> |
4 | #include <c10/core/Layout.h> |
5 | #include <c10/core/MemoryFormat.h> |
6 | #include <c10/core/SymIntArrayRef.h> |
7 | #include <c10/macros/Macros.h> |
8 | #include <c10/util/ArrayRef.h> |
9 | #include <c10/util/intrusive_ptr.h> |
10 | #include <c10/util/python_stub.h> |
11 | #include <string> |
12 | #include <vector> |
13 | |
14 | // Forward declarations |
15 | |
// Forward declarations of c10 types this header only mentions in signatures,
// so it doesn't need to pull in their full definitions.
namespace c10 {
struct IValue;
class OperatorHandle;
struct TensorImpl;
struct SafePyObject;
} // namespace c10

namespace torch {
namespace jit {
// Boxed argument/result stack (a vector of IValues) consumed by the
// dispatch entry points declared below.
using Stack = std::vector<c10::IValue>;
}
} // namespace torch
28 | |
29 | // Actual implementation |
30 | |
31 | namespace c10 { |
32 | namespace impl { |
33 | |
34 | struct C10_API PyInterpreter; |
35 | |
36 | // Note [Python interpreter tag] |
37 | // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ |
38 | // Traditionally, PyTorch is layered such that our Python library |
39 | // (libtorch_python) references our pure C++ library (libtorch) as the |
40 | // natural order of things. However, sometimes this natural order is |
41 | // subverted: C++ objects refer to Python objects (for example, we |
42 | // store a PyObject* pointer on TensorImpl so that converting from a |
43 | // C++ Tensor to a Python Tensor is just a memory dereference). |
44 | // |
45 | // These unusual orderings must be treated with care. To start, you need to |
46 | // virtualize the destructor so that the PyObject can be decref'ed on |
47 | // destruction (because the C++ object itself doesn't know anything about |
48 | // Python--remember, layering!). This process itself is fraught, since |
49 | // acquiring the GIL could lead to deadlocks if someone is blocking on you |
50 | // while holding the GIL. Furthermore, if the C++ objects outlive the |
51 | // interpreter (which can happen if you stash them in a static global |
52 | // variable defined in libtorch), you may attempt to decref the object when |
// the Python interpreter has already been shut down.
54 | // |
55 | // BUT WAIT, IT GETS WORSE. With torchdeploy, there may be multiple Python |
56 | // interpreters in a single process. If a C++ object is accessible from |
// multiple interpreters, we must take care not to accidentally use a
// PyObject from one interpreter with another interpreter.
59 | // |
60 | // To prevent these mixups, we introduce a PyInterpreter "tag" (object with |
61 | // a vtable), which specifies a specific Python interpreter. |
62 | // |
63 | // - Any given object can be associated with AT MOST one Python interpreter. |
64 | // We represent the interpreter tag as a memory address to an instance of |
65 | // a virtual class that is allocated once per interpreter (this is so that |
66 | // we can request the interpreter to perform operations for us, if |
67 | // necessary). |
68 | // |
69 | // - It can be recorded with a PyObject (PyInterpreterObject) so that |
70 | // we know what interpreter the object is associated with, and we can |
71 | // raise an error if you try to use the PyObject from the wrong |
72 | // interpreter context. |
73 | // |
74 | // - It contains a vtable that can be used to perform various Python |
75 | // operations from ordinary C++ code that ordinarily wouldn't be accessible |
76 | // from libtorch. |
77 | // |
78 | // A simple use case is when a C++ object must be associated with a PyObject. |
79 | // However, for TensorImpl, we lazily allocate a PyObject the first time the |
80 | // object passes into Python. The invariants for this situation are more |
81 | // subtle: |
82 | // |
83 | // - A given TensorImpl's interpreter tag can only go from uninitialized to |
84 | // tagged; once tagged, this is a quiescent state (once tagged to an |
85 | // interpreter, ALWAYS tagged to that interpreter) |
86 | // |
87 | // - A thread may mutate the PyObject field of a TensorImpl if and only if it |
88 | // holds the GIL for the interpreter tagged on the TensorImpl. (If the |
89 | // TensorImpl is not tagged, it must first atomically claim its tag before it |
90 | // can validly write) |
91 | // |
92 | // WARNING: This class has to be written very carefully, because it may be |
// possible for a Tensor to have a reference to an interpreter corresponding to
94 | // a shared library that has ALREADY BEEN UNLOADED. This makes blindly calling |
95 | // virtual methods very dangerous, because the vtable may be garbage at that |
96 | // point (on a good day, you might get "pure virtual method called"). |
97 | // |
98 | // The idea to solve this problem is we always leak PyInterpreters (so they |
99 | // always stay live even after dlclose), and make sure we can disarm their |
100 | // virtual methods by indirecting through a separate PyInterpreterVTable |
101 | // object. This can be replaced with a no-op vtable from libc10.so, which |
102 | // is guaranteed to stick around until the bitter end. |
103 | // |
104 | // NB: The downside with representing PyInterpreter tags as full objects is that |
105 | // it takes an extra word on TensorImpl. If tags were instead just integer |
106 | // indices, on 64-bit architectures we could pack the tag and PyObject together |
107 | // into a single atomic word. On 32-bit architectures we could simply say that |
108 | // only one Python interpreter is supported (erroring if a nontrivial |
109 | // interpreter tag is attempted to be set). |
110 | // |
111 | // The difficulty with this scheme is we need to maintain an out-of-line table |
112 | // to get at the PyInterpreters so that we can do virtual method calls on them, |
113 | // and registration/deregistration to this table must be done in a thread safe |
114 | // manner. This can be easily done if the number of possible PyInterpreters is |
115 | // small enough (e.g., 8-bit integer) by simply preallocating an array of |
// sufficient size to hold all possible interpreters. Surely 128
// interpreters is more than enough for anyone!
118 | // |
119 | // I didn't decide to do this technique at the moment, because the extra word |
120 | // added by the PyInterpreter tag takes us to 24 words, which means that we |
121 | // still fit inside three eight word cache lines. If you need to penny pinch |
122 | // another word consider doing this! |
123 | |
// Table of operations serviced by a live Python interpreter. All C++-side
// calls into Python go through one of these virtual methods; a disarmed
// interpreter substitutes a no-op implementation of this table (see
// Note [Python interpreter tag] above).
struct C10_API PyInterpreterVTable {
  virtual ~PyInterpreterVTable() = default;

  // Report the name of this interpreter
  virtual std::string name() const = 0;

  // Run Py_DECREF on a PyObject. We DO NOT assume the GIL is held on call
  // See NOTE [PyInterpreter::decref takes an `is_tensor` arg]
  virtual void decref(PyObject* pyobj, bool is_tensor) const = 0;

  // Perform a detach by deferring to the __torch_dispatch__ implementation of
  // detach, which will also arrange for the PyObject to get copied in this
  // situation
  virtual c10::intrusive_ptr<TensorImpl> detach(
      const TensorImpl* self) const = 0;

  // Invoke the Python boxed fallback dispatch to go back into Python
  virtual void dispatch(const c10::OperatorHandle& op, torch::jit::Stack* stack)
      const = 0;

  // This is only invoked in the multipy/torchdeploy situation from
  // pythonOpRegistrationTrampoline; this lets us get to the Python
  // interpreter to actually find the appropriate Python op registration
  // entry to call.
  virtual void python_op_registration_trampoline(
      const c10::OperatorHandle& op,
      c10::DispatchKey,
      torch::jit::Stack* stack) const = 0;

  // Invoke the Python dispatcher to handle this call
  virtual void python_dispatcher(
      const c10::OperatorHandle& op,
      c10::DispatchKeySet,
      torch::jit::Stack* stack) const = 0;

  // Metadata queries on a TensorImpl that are answered by Python rather
  // than by the C++ TensorImpl itself — presumably for tensors whose
  // metadata lives Python-side (e.g. __torch_dispatch__ subclasses);
  // NOTE(review): confirm against the libtorch_python implementation.
  virtual bool is_contiguous(const TensorImpl* self, at::MemoryFormat)
      const = 0;
  virtual bool is_strides_like(const TensorImpl* self, at::MemoryFormat)
      const = 0;
  virtual bool is_non_overlapping_and_dense(const TensorImpl* self) const = 0;
  virtual c10::Device device(const TensorImpl* self) const = 0;
  virtual int64_t dim(const TensorImpl* self) const = 0;
  virtual c10::IntArrayRef strides(const TensorImpl* self) const = 0;
  virtual c10::IntArrayRef sizes(const TensorImpl* self) const = 0;
  virtual c10::SymIntArrayRef sym_sizes(const TensorImpl* self) const = 0;
  virtual c10::Layout layout(const TensorImpl* self) const = 0;
  virtual c10::SymInt sym_numel(const TensorImpl* self) const = 0;
  virtual c10::SymIntArrayRef sym_strides(const TensorImpl* self) const = 0;
  virtual c10::SymInt sym_storage_offset(const TensorImpl* self) const = 0;

  // Notifications of GPU runtime activity, delivered to Python. Events,
  // streams and pointers are passed as opaque uintptr_t handles so this
  // header stays free of any GPU runtime dependency.
  virtual void trace_gpu_event_creation(uintptr_t event) const = 0;
  virtual void trace_gpu_event_deletion(uintptr_t event) const = 0;
  virtual void trace_gpu_event_record(uintptr_t event, uintptr_t stream)
      const = 0;
  virtual void trace_gpu_event_wait(uintptr_t event, uintptr_t stream)
      const = 0;
  virtual void trace_gpu_memory_allocation(uintptr_t ptr) const = 0;
  virtual void trace_gpu_memory_deallocation(uintptr_t ptr) const = 0;
  virtual void trace_gpu_stream_creation(uintptr_t stream) const = 0;
  virtual void trace_gpu_device_synchronization() const = 0;
  virtual void trace_gpu_stream_synchronization(uintptr_t stream) const = 0;
  virtual void trace_gpu_event_synchronization(uintptr_t event) const = 0;

  // Reset the Python-side backward hooks associated with this TensorImpl
  // (semantics live in the Python implementation — verify there).
  virtual void reset_backward_hooks(const TensorImpl* self) const = 0;
};
189 | |
190 | struct C10_API PyInterpreter { |
191 | const PyInterpreterVTable* vtable_; |
192 | |
193 | PyInterpreter(const PyInterpreterVTable* vtable) : vtable_(vtable){}; |
194 | |
195 | const PyInterpreterVTable& operator*() const noexcept { |
196 | return *vtable_; |
197 | } |
198 | const PyInterpreterVTable* operator->() const noexcept { |
199 | return vtable_; |
200 | } |
201 | |
202 | // Disarm this PyInterpreter, making all of its methods noops. |
203 | // The vtable pointer is not an atomic at the moment, which means |
204 | // a disarm() invocation that is concurrent with active destructors |
205 | // is not thread safe and will trigger TSAN. My hope is that this |
206 | // situations doesn't ever actually happen; tensor destruction should |
207 | // quiesce when a dlclose happens, and any long lived tensors whose |
208 | // destructors would be disarmed here only begin the destruction process |
209 | // on process shutdown (long after the dlclose has occurred). |
210 | void disarm() noexcept; |
211 | }; |
212 | |
// PyInterpreterStatus describes what the state of its interpreter tag
// is, relative to the thread currently holding the GIL.
enum class PyInterpreterStatus {
  // We just allocated the Tensor, it hasn't escaped to other threads,
  // we know that it definitely hasn't been tagged to be associated
  // with an interpreter.
  DEFINITELY_UNINITIALIZED,
  // We queried the interpreter field and it looked uninitialized. But
  // another thread may have raced with us to tag it with some other
  // interpreter id. So we will have to do a CEX (compare-and-exchange)
  // to make sure we can actually nab it.
  MAYBE_UNINITIALIZED,
  // We queried the interpreter field and it was tagged to belong to us.
  // This means we have sole write access (as we hold the GIL for this
  // interpreter)
  TAGGED_BY_US,
  // Someone else tagged this. We can't use this TensorImpl from Python.
  TAGGED_BY_OTHER,
};
232 | |
233 | } // namespace impl |
234 | } // namespace c10 |
235 | |