#pragma once

#include <c10/core/impl/PyInterpreter.h>
#include <c10/macros/Macros.h>
#include <c10/util/Exception.h>
#include <c10/util/python_stub.h>

#include <atomic>
namespace c10 {

// A PyHandleCache represents a cached pointer from a C++ object to
// a Python object that represents that object analogously in Python.
// Upon a cache hit, the relevant object can be retrieved after a test
// and then a memory load. Two conditions must hold to be able to use this
// class:
//
// - This must truly be a cache; i.e., the caller must be able to produce
//   the object some other way if the cache misses.
//
// - This must truly be a handle; i.e., the Python object referenced by
//   this class must have static lifetime. This means we don't have to
//   maintain strong ownership or deallocate the object when the C++ object
//   dies. Static lifetime is a good idea in conjunction with the cache,
//   since if you are producing a fresh object on every miss you won't be
//   maintaining object identity. If you need bidirectional ownership,
//   you will want to factor out the pattern in TensorImpl with
//   resurrection.
//
// This cache is not expected to improve performance under torchdeploy, as
// one interpreter will fill up the cache, and all the other interpreters
// will be unable to use the slot. A potential improvement is to have
// multiple slots (one per interpreter), which would work in deployment
// scenarios where there is a stable, fixed number of interpreters. You can
// also store the relevant state in the Python library, rather than in the
// non-Python library (although in many cases this is not convenient, as
// there may not be a way to index based on the object).
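//
// A minimal usage sketch (hypothetical names: `cache_`, `MyImpl`, and
// `materialize_pyobj` are illustrative, not part of this header). The
// caller supplies a slow-path accessor that can always rebuild the
// Python object; the cache short-circuits it on a hit:
//
//   PyObject* MyImpl::pyobj(impl::PyInterpreter* interpreter) const {
//     return cache_.ptr_or(interpreter, [&]() -> PyObject* {
//       // Slow path; must be able to produce the object unconditionally.
//       return materialize_pyobj(this);
//     });
//   }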
class PyHandleCache {
 public:
  PyHandleCache() : pyinterpreter_(nullptr), data_(nullptr) {}

  // Attempt to fetch the pointer from the cache, if the PyInterpreter
  // matches. If there is no cached pointer, or the cache entry is not
  // valid, use slow_accessor to get the real pointer value and return
  // that (possibly writing it to the cache, if the cache entry is
  // available).
  template <typename F>
  PyObject* ptr_or(impl::PyInterpreter* self_interpreter, F slow_accessor)
      const {
    // Note [Memory ordering on Python interpreter tag]
    impl::PyInterpreter* interpreter =
        pyinterpreter_.load(std::memory_order_acquire);
    if (C10_LIKELY(interpreter == self_interpreter)) {
      return data_;
    } else if (interpreter == nullptr) {
      auto* r = slow_accessor();
      impl::PyInterpreter* expected = nullptr;
      // attempt to claim this cache entry with the specified interpreter tag
      if (pyinterpreter_.compare_exchange_strong(
              expected, self_interpreter, std::memory_order_acq_rel)) {
        data_ = r;
      }
      // This shouldn't be possible, as the caller should be holding the GIL
      TORCH_INTERNAL_ASSERT(expected != self_interpreter);
      return r;
    } else {
      return slow_accessor();
    }
  }

 private:
  mutable std::atomic<impl::PyInterpreter*> pyinterpreter_;
  mutable PyObject* data_;
};
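
// A sketch of the multi-slot variant suggested above, assuming a fixed,
// stable number of interpreters known at build time (`kMaxInterpreters`
// and the Slot layout are hypothetical, and <array> would need to be
// included). Each interpreter claims its own slot, so one interpreter
// filling the cache no longer locks the others out:
//
//   class PyHandleCacheMulti {
//    public:
//     template <typename F>
//     PyObject* ptr_or(impl::PyInterpreter* self, F slow_accessor) const {
//       for (auto& slot : slots_) {
//         auto* tag = slot.interpreter.load(std::memory_order_acquire);
//         if (tag == self) {
//           return slot.data;
//         } else if (tag == nullptr) {
//           auto* r = slow_accessor();
//           impl::PyInterpreter* expected = nullptr;
//           // Claim the free slot; losing the race is fine, since r is
//           // already the correct value and we can retry on a later call.
//           if (slot.interpreter.compare_exchange_strong(
//                   expected, self, std::memory_order_acq_rel)) {
//             slot.data = r;
//           }
//           return r;
//         }
//       }
//       // All slots are claimed by other interpreters; fall back.
//       return slow_accessor();
//     }
//
//    private:
//     struct Slot {
//       std::atomic<impl::PyInterpreter*> interpreter{nullptr};
//       PyObject* data{nullptr};
//     };
//     static constexpr size_t kMaxInterpreters = 8;
//     mutable std::array<Slot, kMaxInterpreters> slots_;
//   };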

} // namespace c10