#pragma once

#include <c10/core/Device.h>
#include <c10/core/Layout.h>
#include <c10/core/MemoryFormat.h>
#include <c10/core/SymIntArrayRef.h>
#include <c10/macros/Macros.h>
#include <c10/util/ArrayRef.h>
#include <c10/util/intrusive_ptr.h>
#include <c10/util/python_stub.h>
#include <string>
#include <vector>

// Forward declarations

namespace c10 {
struct IValue;
class OperatorHandle;
struct TensorImpl;
struct SafePyObject;
} // namespace c10

namespace torch {
namespace jit {
using Stack = std::vector<c10::IValue>;
}
} // namespace torch

// Actual implementation

namespace c10 {
namespace impl {

struct C10_API PyInterpreter;

// Note [Python interpreter tag]
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
// Traditionally, PyTorch is layered such that our Python library
// (libtorch_python) references our pure C++ library (libtorch) as the
// natural order of things. However, sometimes this natural order is
// subverted: C++ objects refer to Python objects (for example, we
// store a PyObject* pointer on TensorImpl so that converting from a
// C++ Tensor to a Python Tensor is just a memory dereference).
//
// These unusual orderings must be treated with care. To start, you need to
// virtualize the destructor so that the PyObject can be decref'ed on
// destruction (because the C++ object itself doesn't know anything about
// Python--remember, layering!). This process itself is fraught, since
// acquiring the GIL could lead to deadlocks if someone is blocking on you
// while holding the GIL. Furthermore, if the C++ objects outlive the
// interpreter (which can happen if you stash them in a static global
// variable defined in libtorch), you may attempt to decref the object when
// the Python interpreter has already been shut down.
//
// BUT WAIT, IT GETS WORSE. With torchdeploy, there may be multiple Python
// interpreters in a single process. If a C++ object is accessible from
// multiple interpreters, we must take care not to accidentally pass a
// PyObject from one interpreter to another interpreter.
//
// To prevent these mixups, we introduce a PyInterpreter "tag" (object with
// a vtable), which specifies a specific Python interpreter.
//
// - Any given object can be associated with AT MOST one Python interpreter.
//   We represent the interpreter tag as a memory address to an instance of
//   a virtual class that is allocated once per interpreter (this is so that
//   we can request the interpreter to perform operations for us, if
//   necessary).
//
// - It can be recorded with a PyObject (PyInterpreterObject) so that
//   we know what interpreter the object is associated with, and we can
//   raise an error if you try to use the PyObject from the wrong
//   interpreter context.
//
// - It contains a vtable that can be used to perform various Python
//   operations from ordinary C++ code that ordinarily wouldn't be accessible
//   from libtorch.
//
// A simple use case is when a C++ object must be associated with a PyObject.
// However, for TensorImpl, we lazily allocate a PyObject the first time the
// object passes into Python. The invariants for this situation are more
// subtle:
//
// - A given TensorImpl's interpreter tag can only go from uninitialized to
//   tagged; once tagged, this is a quiescent state (once tagged to an
//   interpreter, ALWAYS tagged to that interpreter)
//
// - A thread may mutate the PyObject field of a TensorImpl if and only if it
//   holds the GIL for the interpreter tagged on the TensorImpl. (If the
//   TensorImpl is not tagged, it must first atomically claim its tag before
//   it can validly write; see the sketch below.)
//
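// For illustration, a minimal sketch of the atomic tag-claiming step
// described above (hypothetical helper; the real TensorImpl logic differs
// in detail):
//
//   std::atomic<PyInterpreter*> pyobj_interpreter_{nullptr};
//
//   void maybe_claim_tag(PyInterpreter* self_interpreter) {
//     PyInterpreter* expected = nullptr;
//     // Only the first interpreter to arrive wins the tag; if the exchange
//     // fails, the tag must already equal self_interpreter for a write to
//     // be legal.
//     pyobj_interpreter_.compare_exchange_strong(expected, self_interpreter);
//   }
//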
// WARNING: This class has to be written very carefully, because it may be
// possible for a Tensor to have a reference to an interpreter corresponding
// to a shared library that has ALREADY BEEN UNLOADED. This makes blindly
// calling virtual methods very dangerous, because the vtable may be garbage
// at that point (on a good day, you might get "pure virtual method called").
//
// The idea to solve this problem is we always leak PyInterpreters (so they
// always stay live even after dlclose), and make sure we can disarm their
// virtual methods by indirecting through a separate PyInterpreterVTable
// object. This can be replaced with a no-op vtable from libc10.so, which
// is guaranteed to stick around until the bitter end.
//
// NB: The downside with representing PyInterpreter tags as full objects is
// that it takes an extra word on TensorImpl. If tags were instead just
// integer indices, on 64-bit architectures we could pack the tag and
// PyObject together into a single atomic word. On 32-bit architectures we
// could simply say that only one Python interpreter is supported (erroring
// if a nontrivial interpreter tag is attempted to be set).
//
// The difficulty with this scheme is we need to maintain an out-of-line
// table to get at the PyInterpreters so that we can do virtual method calls
// on them, and registration/deregistration to this table must be done in a
// thread safe manner. This can be easily done if the number of possible
// PyInterpreters is small enough (e.g., 8-bit integer) by simply
// preallocating an array of
// sufficient size to hold all possible interpreters. Surely 128
// interpreters is more than enough for anyone!
//
// I decided not to use this technique at the moment, because the extra word
// added by the PyInterpreter tag takes us to 24 words, which means that we
// still fit inside three eight-word cache lines. If you need to penny pinch
// another word consider doing this!
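//
// As a rough illustration of the rejected packing scheme (hypothetical
// layout, not what PyTorch does): an 8-bit interpreter index lives in the
// top byte of a single atomic word, with an out-of-line table mapping
// indices back to interpreters:
//
//   // One slot per possible interpreter; registration claims a free slot.
//   static std::atomic<PyInterpreterVTable*> interpreter_table[128];
//
//   // Relies on userspace pointers not occupying the top byte, which is
//   // true on common 64-bit platforms but is exactly the sort of
//   // fragility being traded against one extra word.
//   uint64_t pack(uint8_t idx, PyObject* obj) {
//     return (uint64_t(idx) << 56) | reinterpret_cast<uint64_t>(obj);
//   }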

struct C10_API PyInterpreterVTable {
  virtual ~PyInterpreterVTable() = default;

  // Report the name of this interpreter
  virtual std::string name() const = 0;

  // Run Py_DECREF on a PyObject. We DO NOT assume the GIL is held on call.
  // See NOTE [PyInterpreter::decref takes an `is_tensor` arg]
  virtual void decref(PyObject* pyobj, bool is_tensor) const = 0;

  // Perform a detach by deferring to the __torch_dispatch__ implementation of
  // detach, which will also arrange for the PyObject to get copied in this
  // situation
  virtual c10::intrusive_ptr<TensorImpl> detach(
      const TensorImpl* self) const = 0;

  // Invoke the Python boxed fallback dispatch to go back into Python
  virtual void dispatch(const c10::OperatorHandle& op, torch::jit::Stack* stack)
      const = 0;
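
  // For illustration, a hedged sketch of how a caller might drive this
  // boxed fallback (`interp` and the argument are hypothetical; real call
  // sites live in the dispatcher):
  //
  //   torch::jit::Stack stack;
  //   stack.push_back(some_ivalue_argument); // boxed args go on the stack
  //   (*interp)->dispatch(op, &stack);       // Python pops the args...
  //   c10::IValue result = stack.back();     // ...and pushes the results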

  // This is only invoked in the multipy/torchdeploy situation from
  // pythonOpRegistrationTrampoline; this lets us get to the Python
  // interpreter to actually find the appropriate Python op registration
  // entry to call.
  virtual void python_op_registration_trampoline(
      const c10::OperatorHandle& op,
      c10::DispatchKey,
      torch::jit::Stack* stack) const = 0;

  // Invoke the Python dispatcher to handle this call
  virtual void python_dispatcher(
      const c10::OperatorHandle& op,
      c10::DispatchKeySet,
      torch::jit::Stack* stack) const = 0;

  virtual bool is_contiguous(const TensorImpl* self, at::MemoryFormat)
      const = 0;
  virtual bool is_strides_like(const TensorImpl* self, at::MemoryFormat)
      const = 0;
  virtual bool is_non_overlapping_and_dense(const TensorImpl* self) const = 0;
  virtual c10::Device device(const TensorImpl* self) const = 0;
  virtual int64_t dim(const TensorImpl* self) const = 0;
  virtual c10::IntArrayRef strides(const TensorImpl* self) const = 0;
  virtual c10::IntArrayRef sizes(const TensorImpl* self) const = 0;
  virtual c10::SymIntArrayRef sym_sizes(const TensorImpl* self) const = 0;
  virtual c10::Layout layout(const TensorImpl* self) const = 0;
  virtual c10::SymInt sym_numel(const TensorImpl* self) const = 0;
  virtual c10::SymIntArrayRef sym_strides(const TensorImpl* self) const = 0;
  virtual c10::SymInt sym_storage_offset(const TensorImpl* self) const = 0;

  virtual void trace_gpu_event_creation(uintptr_t event) const = 0;
  virtual void trace_gpu_event_deletion(uintptr_t event) const = 0;
  virtual void trace_gpu_event_record(uintptr_t event, uintptr_t stream)
      const = 0;
  virtual void trace_gpu_event_wait(uintptr_t event, uintptr_t stream)
      const = 0;
  virtual void trace_gpu_memory_allocation(uintptr_t ptr) const = 0;
  virtual void trace_gpu_memory_deallocation(uintptr_t ptr) const = 0;
  virtual void trace_gpu_stream_creation(uintptr_t stream) const = 0;
  virtual void trace_gpu_device_synchronization() const = 0;
  virtual void trace_gpu_stream_synchronization(uintptr_t stream) const = 0;
  virtual void trace_gpu_event_synchronization(uintptr_t event) const = 0;

  virtual void reset_backward_hooks(const TensorImpl* self) const = 0;
};
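
// For a sense of what implementing this interface looks like, a hedged
// sketch of a trivial subclass (hypothetical; the real implementations are
// the Python one in libtorch_python and the no-op vtable used by disarm()):
//
//   struct QuietVTable final : PyInterpreterVTable {
//     std::string name() const override {
//       return "<disarmed interpreter>";
//     }
//     void decref(PyObject*, bool) const override {} // deliberately no-op
//     // ...all remaining pure virtual methods must also be overridden,
//     // typically to error loudly on queries that cannot be answered.
//   };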

struct C10_API PyInterpreter {
  const PyInterpreterVTable* vtable_;

  PyInterpreter(const PyInterpreterVTable* vtable) : vtable_(vtable) {}

  const PyInterpreterVTable& operator*() const noexcept {
    return *vtable_;
  }
  const PyInterpreterVTable* operator->() const noexcept {
    return vtable_;
  }

  // Disarm this PyInterpreter, making all of its methods noops.
  // The vtable pointer is not an atomic at the moment, which means
  // a disarm() invocation that is concurrent with active destructors
  // is not thread safe and will trigger TSAN. My hope is that this
  // situation doesn't ever actually happen; tensor destruction should
  // quiesce when a dlclose happens, and any long lived tensors whose
  // destructors would be disarmed here only begin the destruction process
  // on process shutdown (long after the dlclose has occurred).
  void disarm() noexcept;
};
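
// Usage sketch (hedged; `interp` is a hypothetical pointer to the leaked
// per-interpreter instance described in Note [Python interpreter tag]):
//
//   PyInterpreter* interp = /* obtained from the object's tag */;
//   std::string who = (*interp)->name(); // virtual call via the vtable
//
//   // On dlclose of the interpreter's library, the leaked PyInterpreter
//   // stays alive but is pointed at a no-op vtable, so stale tags can no
//   // longer call into unloaded code:
//   interp->disarm();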

// PyInterpreterStatus describes the state of an object's interpreter tag,
// relative to the thread currently holding the GIL.
enum class PyInterpreterStatus {
  // We just allocated the Tensor, it hasn't escaped to other threads,
  // we know that it definitely hasn't been tagged to be associated
  // with an interpreter.
  DEFINITELY_UNINITIALIZED,
  // We queried the interpreter field and it looked uninitialized. But
  // another thread may have raced with us to tag it with some other
  // interpreter id. So we will have to do a compare-exchange to make sure
  // we can actually nab it.
  MAYBE_UNINITIALIZED,
  // We queried the interpreter field and it was tagged to belong to us.
  // This means we have sole write access (as we hold the GIL for this
  // interpreter).
  TAGGED_BY_US,
  // Someone else tagged this. We can't use this TensorImpl from Python.
  TAGGED_BY_OTHER,
};
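
// For illustration, a hedged sketch of acting on this status when
// installing a PyObject (hypothetical helper; the real logic lives in
// TensorImpl's pyobj handling):
//
//   void claim_for_write(PyInterpreter* self, PyInterpreterStatus status) {
//     switch (status) {
//       case PyInterpreterStatus::DEFINITELY_UNINITIALIZED:
//         break; // no concurrent writers; a plain store of the tag is fine
//       case PyInterpreterStatus::MAYBE_UNINITIALIZED:
//         // must compare-exchange the tag from nullptr to `self` first
//         break;
//       case PyInterpreterStatus::TAGGED_BY_US:
//         break; // we hold this interpreter's GIL, safe to write
//       case PyInterpreterStatus::TAGGED_BY_OTHER:
//         TORCH_INTERNAL_ASSERT(false, "cannot use PyObject here");
//     }
//   }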

} // namespace impl
} // namespace c10