#pragma once

#include <c10/core/impl/HermeticPyObjectTLS.h>
#include <c10/core/impl/PyInterpreter.h>
#include <c10/util/Optional.h>
#include <c10/util/python_stub.h>

#include <atomic>

namespace c10 {
namespace impl {

struct C10_API PyObjectSlot {
 public:
  PyObjectSlot();

  void destroy_pyobj_if_needed();

  // Associate the TensorImpl with the specified PyObject, and, if necessary,
  // also tag the interpreter.
  //
  // NB: This lives in a header so that we can inline away the switch on status
  //
  // NB: THIS FUNCTION CAN RAISE AN EXCEPTION. Make sure to clean up the
  // PyObject if necessary!
  void init_pyobj(
      PyInterpreter* self_interpreter,
      PyObject* pyobj,
      PyInterpreterStatus status) {
    impl::PyInterpreter* expected = nullptr;
    switch (status) {
      case impl::PyInterpreterStatus::DEFINITELY_UNINITIALIZED:
        // the caller guarantees there is no multithreaded access; since there
        // is no data race, a relaxed store is OK
        pyobj_interpreter_.store(self_interpreter, std::memory_order_relaxed);
        break;
      case impl::PyInterpreterStatus::TAGGED_BY_US:
        // no tagging is necessary, the tag is already correct
        break;
      case impl::PyInterpreterStatus::MAYBE_UNINITIALIZED:
        // attempt to claim this TensorImpl with the specified interpreter
        // tag
        if (pyobj_interpreter_.compare_exchange_strong(
                expected, self_interpreter, std::memory_order_acq_rel)) {
          break;
        }
        // test whether it was, in fact, already tagged by us! This situation
        // can't be caused by a race, but it can arise when someone
        // conservatively tagged the tensor as MAYBE_UNINITIALIZED (because
        // they didn't pre-check the tag) when it was actually already tagged
        // by this interpreter
        if (expected == self_interpreter) {
          break;
        }
        // fallthrough: we lost the race. We are guaranteed not to lose the
        // race with ourselves, as calls to init_pyobj with the same
        // interpreter must be serialized by the GIL
        C10_FALLTHROUGH;
      case impl::PyInterpreterStatus::TAGGED_BY_OTHER:
        TORCH_CHECK(
            false,
            "cannot allocate PyObject for Tensor on interpreter ",
            self_interpreter,
            " that has already been used by another torch deploy interpreter ",
            pyobj_interpreter_.load());
    }

    // we are the ONLY thread that can have gotten to this point. It is not
    // possible to conflict with another thread racing to tag from the null
    // interpreter state, as access is protected by the GIL
    // NB: owns_pyobj tag is initially false
    pyobj_ = pyobj;
  }
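
  // For illustration only: a hypothetical caller sketch (names like `slot`,
  // `obj`, and `getPyInterpreter()` are assumptions, not part of this
  // header). A binding that has just allocated a fresh wrapper might do:
  //
  //   slot->init_pyobj(
  //       getPyInterpreter(),
  //       obj,
  //       c10::impl::PyInterpreterStatus::DEFINITELY_UNINITIALIZED);
  //
  // Per the NB above, if this throws, the caller is responsible for cleaning
  // up `obj` (e.g., decref'ing it).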

  // Query the PyObject interpreter. This may return null if there is no
  // interpreter. This is racy!
  PyInterpreter* pyobj_interpreter();

  PyObject* _unchecked_untagged_pyobj() const;

  // Test the interpreter tag. If tagged for the current interpreter, returns
  // a non-nullopt (but possibly null) PyObject. If (possibly) untagged,
  // returns a nullopt. If it is definitely invalid, raises an error.
  //
  // NB: this lives in the header so that we can avoid actually creating the
  // c10::optional
  c10::optional<PyObject*> check_pyobj(PyInterpreter* self_interpreter) const {
    // Note [Memory ordering on Python interpreter tag]
    impl::PyInterpreter* interpreter =
        pyobj_interpreter_.load(std::memory_order_acquire);
    if (interpreter == nullptr) {
      // NB: we can never conclude DEFINITELY_UNINITIALIZED here, because
      // there is always the possibility that another thread races to
      // initialize after we query. The only time we can conclude a tensor
      // is definitely uninitialized is when we have just allocated it and
      // it cannot have escaped to other threads yet
      return c10::nullopt;
    } else if (interpreter == self_interpreter) {
      // NB: pyobj_ could still be null!
      if (c10::impl::HermeticPyObjectTLS::get_state()) {
        return c10::nullopt;
      } else {
        return c10::make_optional(_unchecked_untagged_pyobj());
      }
    } else {
      TORCH_CHECK(
          false,
          "cannot access PyObject for Tensor on interpreter ",
          (*self_interpreter)->name(),
          " that has already been used by another torch deploy interpreter ",
          (*pyobj_interpreter_.load())->name());
    }
  }
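
  // For illustration only: a hypothetical caller sketch (names like `slot`
  // and `getPyInterpreter()` are assumptions, not part of this header):
  //
  //   c10::optional<PyObject*> maybe = slot->check_pyobj(getPyInterpreter());
  //   if (maybe.has_value() && *maybe != nullptr) {
  //     // tagged for this interpreter and a PyObject already exists
  //   } else if (!maybe.has_value()) {
  //     // untagged (or hermetic): the caller may allocate and init_pyobj
  //   }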

  // Clear the PyObject field for an interpreter, in situations where we
  // statically know the tensor is tagged with our interpreter.
  void unchecked_clear_pyobj(PyInterpreter* interpreter);

  PyInterpreter& load_pyobj_interpreter() const;

  bool owns_pyobj();

  void set_owns_pyobj(bool b);

 private:
  // This field contains the interpreter tag for this object. See
  // Note [Python interpreter tag] for general context
  //
  // Note [Memory ordering on Python interpreter tag]
  // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
  // What memory_order do we need when accessing this atomic? We don't
  // need a single total modification order (as provided by
  // memory_order_seq_cst) as pyobj_interpreter_ is monotonic: it can only
  // transition from nullptr to some interpreter pointer and never changes
  // afterwards. Because there is only one modification, it trivially already
  // has a total modification order (e.g., we don't need fences or locked
  // instructions on x86)
  //
  // In fact, one could make a reasonable argument that relaxed reads are OK,
  // due to the presence of external locking (GIL) to ensure that interactions
  // with other data structures are still correctly synchronized, so that
  // we fall in the "Single-Location Data Structures" case as described in
  // http://www.open-std.org/jtc1/sc22/wg21/docs/papers/2020/p2055r0.pdf
  // However, on x86, it doesn't matter if I use acquire or relaxed on the load
  // as I get the same assembly in both cases. So I just use the more
  // conservative acquire (which will impede compiler optimizations but I don't
  // care)
  std::atomic<PyInterpreter*> pyobj_interpreter_;

  // This field contains a reference to a PyObject representing this Tensor.
  // If pyobj_ is nullptr, then when we transfer the Tensor to Python, we
  // allocate a new PyObject for it and set this field. This field does not
  // have to be protected by an atomic as it is only allowed to be accessed
  // when you hold the GIL, or during destruction of the tensor.
  //
  // When a PyObject dies, you are obligated to clear this field
  // (otherwise, you will try to use-after-free the pyobj); this currently
  // occurs in THPVariable_clear in torch/csrc/autograd/python_variable.cpp
  //
  // NB: Ordinarily, this should not be a strong reference, as if the
  // PyObject owns the Tensor, this would create a reference cycle.
  // However, sometimes this ownership flips. To track who owns whom,
  // this pointer carries a single tag bit indicating whether the
  // C++ object owns the PyObject (the common case, zero, means the PyObject
  // owns the C++ object); see _unchecked_untagged_pyobj for raw access
  // or check_pyobj for checked access. See references to PyObject
  // resurrection in torch/csrc/autograd/python_variable.cpp
  PyObject* pyobj_;
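
  // For illustration only: a minimal sketch of the tagging scheme described
  // above, assuming the tag occupies the low bit of the stored pointer (see
  // _unchecked_untagged_pyobj for the actual decoding):
  //
  //   uintptr_t bits = reinterpret_cast<uintptr_t>(pyobj_);
  //   bool c10_owns_pyobj = bits & 1;  // tag bit: does C++ own the PyObject?
  //   PyObject* raw = reinterpret_cast<PyObject*>(bits & ~uintptr_t(1));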
};

} // namespace impl
} // namespace c10