1 | #pragma once |
2 | |
3 | #include <c10/core/impl/HermeticPyObjectTLS.h> |
4 | #include <c10/core/impl/PyInterpreter.h> |
5 | #include <c10/util/Optional.h> |
6 | #include <c10/util/python_stub.h> |
7 | |
8 | #include <atomic> |
9 | |
10 | namespace c10 { |
11 | namespace impl { |
12 | |
// A slot (stored on TensorImpl) that associates the C++ object with at most
// one Python interpreter's PyObject. It holds two pieces of state:
//   - pyobj_interpreter_: an atomic "interpreter tag" recording which Python
//     interpreter (if any) has claimed this object; it only ever transitions
//     from nullptr to a single interpreter and never changes afterwards.
//   - pyobj_: the PyObject pointer itself, tagged with an ownership bit;
//     only valid to touch while holding that interpreter's GIL (or during
//     destruction of the tensor).
struct C10_API PyObjectSlot {
 public:
  PyObjectSlot();

  // Releases the PyObject reference if this slot owns it. Defined in the
  // .cpp; presumably consults owns_pyobj() and decrefs via the tagged
  // interpreter — confirm against the implementation.
  void destroy_pyobj_if_needed();

  // Associate the TensorImpl with the specified PyObject, and, if necessary,
  // also tag the interpreter.
  //
  // self_interpreter: the interpreter claiming this object.
  // pyobj: the PyObject to associate (stored with ownership tag initially
  //        clear, i.e. the PyObject owns the C++ object).
  // status: the caller's knowledge about the current tag state; selects how
  //         much synchronization is needed to claim the slot.
  //
  // Raises (via TORCH_CHECK) if the object is already tagged by a different
  // interpreter.
  //
  // NB: This lives in a header so that we can inline away the switch on status
  //
  // NB: THIS FUNCTION CAN RAISE AN EXCEPTION. Make sure to clean up after
  // PyObject if necessary!
  void init_pyobj(
      PyInterpreter* self_interpreter,
      PyObject* pyobj,
      PyInterpreterStatus status) {
    impl::PyInterpreter* expected = nullptr;
    switch (status) {
      case impl::PyInterpreterStatus::DEFINITELY_UNINITIALIZED:
        // caller guarantees there is no multithreaded access; if there is
        // no data race OK to do a relaxed store
        pyobj_interpreter_.store(self_interpreter, std::memory_order_relaxed);
        break;
      case impl::PyInterpreterStatus::TAGGED_BY_US:
        // no tagging is necessary, the tag is already correct
        break;
      case impl::PyInterpreterStatus::MAYBE_UNINITIALIZED:
        // attempt to claim this TensorImpl with the specified interpreter
        // tag
        if (pyobj_interpreter_.compare_exchange_strong(
                expected, self_interpreter, std::memory_order_acq_rel)) {
          break;
        }
        // test if, actually, it was already tagged by us! this situation can't
        // be caused by a race, but it could be caused by a situation
        // where someone conservatively tagged the tensor as MAYBE_UNINITIALIZED
        // (because they didn't pre-check the tag) when actually it was
        // owned by the interpreter
        if (expected == self_interpreter) {
          break;
        }
        // fallthrough, we lost the race. We are guaranteed not to lose the
        // race with ourself, as calls to init_pyobj with the same interpreter
        // ID must be sequentialized by the GIL
        C10_FALLTHROUGH;
      case impl::PyInterpreterStatus::TAGGED_BY_OTHER:
        TORCH_CHECK(
            false,
            "cannot allocate PyObject for Tensor on interpreter " ,
            self_interpreter,
            " that has already been used by another torch deploy interpreter " ,
            pyobj_interpreter_.load());
    }

    // we are the ONLY thread that can have gotten to this point. It is not
    // possible to conflict with another zero interpreter as access is protected
    // by GIL
    // NB: owns_pyobj tag is initially false
    pyobj_ = pyobj;
  }

  // Query the PyObject interpreter. This may return null if there is no
  // interpreter. This is racy!
  PyInterpreter* pyobj_interpreter();

  // Raw access to pyobj_ without checking the interpreter tag; strips the
  // ownership tag bit before returning (see the pyobj_ field comment below).
  // Defined in the .cpp.
  PyObject* _unchecked_untagged_pyobj() const;

  // Test the interpreter tag. If tagged for the current interpreter, return
  // a non-nullopt (but possibly null) PyObject. If (possibly) untagged,
  // returns a nullopt. If it is definitely invalid, raises an error.
  //
  // NB: this lives in header so that we can avoid actually creating the
  // c10::optional
  c10::optional<PyObject*> check_pyobj(PyInterpreter* self_interpreter) const {
    // Note [Memory ordering on Python interpreter tag]
    impl::PyInterpreter* interpreter =
        pyobj_interpreter_.load(std::memory_order_acquire);
    if (interpreter == nullptr) {
      // NB: This never returns DEFINITELY_UNINITIALIZED because there is
      // always the possibility that another thread races to initialize
      // after we query here. The only time when we can conclude a tensor
      // is definitely uninitialized is when we have just allocated it and
      // it cannot have escaped to other threads yet
      return c10::nullopt;
    } else if (interpreter == self_interpreter) {
      // NB: pyobj_ could still be null!
      // In hermetic mode, pretend there is no PyObject so a fresh one is
      // created rather than reusing the one owned by the outer interpreter.
      if (c10::impl::HermeticPyObjectTLS::get_state()) {
        return c10::nullopt;
      } else {
        return c10::make_optional(_unchecked_untagged_pyobj());
      }
    } else {
      TORCH_CHECK(
          false,
          "cannot access PyObject for Tensor on interpreter " ,
          (*self_interpreter)->name(),
          " that has already been used by another torch deploy interpreter " ,
          (*pyobj_interpreter_.load())->name());
    }
  }

  // Clear the PyObject field for an interpreter, in situations where we
  // statically know the tensor is tagged with our interpreter.
  void unchecked_clear_pyobj(PyInterpreter* interpreter);

  // Returns a reference to the tagged interpreter. Defined in the .cpp;
  // presumably requires the tag to be set (non-null) — confirm against the
  // implementation before calling on an untagged slot.
  PyInterpreter& load_pyobj_interpreter() const;

  // Query/set the ownership tag bit stored in pyobj_: true means the C++
  // object owns the PyObject; false (the common case, and the initial state
  // set by init_pyobj) means the PyObject owns the C++ object.
  bool owns_pyobj();

  void set_owns_pyobj(bool b);

 private:
  // This field contains the interpreter tag for this object. See
  // Note [Python interpreter tag] for general context
  //
  // Note [Memory ordering on Python interpreter tag]
  // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
  // What memory_order do we need when accessing this atomic? We don't
  // need a single total modification order (as provided by
  // memory_order_seq_cst) as pyobj_interpreter_ is monotonic: it can only
  // transition from nullptr to some interpreter pointer and never changes
  // afterwards. Because there is only one modification, it trivially already
  // has a total modification order (e.g., we don't need fences or locked
  // instructions on x86)
  //
  // In fact, one could make a reasonable argument that relaxed reads are OK,
  // due to the presence of external locking (GIL) to ensure that interactions
  // with other data structures are still correctly synchronized, so that
  // we fall in the "Single-Location Data Structures" case as described in
  // http://www.open-std.org/jtc1/sc22/wg21/docs/papers/2020/p2055r0.pdf
  // However, on x86, it doesn't matter if I use acquire or relaxed on the load
  // as I get the same assembly in both cases. So I just use the more
  // conservative acquire (which will impede compiler optimizations but I don't
  // care)
  std::atomic<PyInterpreter*> pyobj_interpreter_;

  // This field contains a reference to a PyObject representing this Tensor.
  // If pyobj is nullptr, when we transfer Tensor to Python, we allocate a new
  // PyObject for it and set this field. This field does not have to be
  // protected by an atomic as it is only allowed to be accessed when you hold
  // the GIL, or during destruction of the tensor.
  //
  // When a PyObject dies, you are obligated to clear this field
  // (otherwise, you will try to use-after-free the pyobj); this currently
  // occurs in THPVariable_clear in torch/csrc/autograd/python_variable.cpp
  //
  // NB: Ordinarily, this should not be a strong reference, as if the
  // PyObject owns the Tensor, this would create a reference cycle.
  // However, sometimes this ownership flips. To track who owns
  // who, this has a single pointer tag indicating whether or not the
  // C++ object owns the PyObject (the common case, zero, means PyObject
  // owns the C++ object); see _unchecked_untagged_pyobj for raw access
  // or check_pyobj for checked access. See references to PyObject
  // resurrection in torch/csrc/autograd/python_variable.cpp
  PyObject* pyobj_;
};
170 | |
171 | } // namespace impl |
172 | } // namespace c10 |
173 | |