#pragma once

#include <c10/core/impl/PyInterpreter.h>
#include <c10/macros/Macros.h>
#include <c10/util/Exception.h>
#include <c10/util/python_stub.h>

#include <atomic>

namespace c10 {

// A PyHandleCache represents a cached pointer from a C++ object to
// a Python object that represents that object analogously in Python.
// Upon a cache hit, the relevant object can be retrieved after a test
// and then a memory load. Two conditions must hold to be able to use this
// class:
//
// - This must truly be a cache; i.e., the caller must be able to produce
//   the object some other way if the cache misses.
//
// - This must truly be a handle; i.e., the Python object referenced by
//   this class must have static lifetime. This means we don't have to
//   maintain strong ownership or deallocate the object when the C++ object
//   dies. Static lifetime is a good idea in conjunction with the cache,
//   since if you are producing a fresh object on miss you won't be
//   maintaining object identity. If you need bidirectional ownership,
//   you will want to factor out the pattern in TensorImpl with
//   resurrection.
//
// This cache is expected not to improve perf under torchdeploy, as one
// interpreter will fill up the cache and all the other interpreters will
// be unable to use the slot. A potential improvement is to have multiple
// slots (one per interpreter), which will work in deployment scenarios
// where there is a stable, fixed number of interpreters. You can also store
// the relevant state in the Python library, rather than in the non-Python
// library (although in many cases this is not convenient, as there may
// not be a way to index based on the object).
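//
// A minimal sketch of the intended embedding (hypothetical names: Widget
// and lookup_static_py_widget are illustrative, not real PyTorch APIs, and
// the returned PyObject is assumed to have static lifetime per the
// conditions above):
//
//   struct Widget {
//     PyObject* py_handle(impl::PyInterpreter* interp) const {
//       return py_cache_.ptr_or(interp, [&] {
//         // Slow path: must be able to produce the object on any miss.
//         return lookup_static_py_widget(this);
//       });
//     }
//     PyHandleCache py_cache_;
//   };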
class PyHandleCache {
 public:
  PyHandleCache() : pyinterpreter_(nullptr), data_(nullptr) {}

  // Attempt to fetch the pointer from the cache, if the PyInterpreter
  // matches. If it doesn't exist, or the cache entry is not valid,
  // use slow_accessor to get the real pointer value and return that
  // (possibly writing it to the cache, if the cache entry is
  // available).
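  //
  // The three possible outcomes, in sketch form (hypothetical call sites;
  // cache, interp, and f stand in for the caller's names):
  //
  //   cache.ptr_or(interp, f)  // tag matches interp    -> cached pointer
  //   cache.ptr_or(interp, f)  // tag is nullptr        -> f(), claim slot
  //   cache.ptr_or(interp, f)  // tag is another interp -> f(), no caching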
  template <typename F>
  PyObject* ptr_or(impl::PyInterpreter* self_interpreter, F slow_accessor)
      const {
    // Note [Memory ordering on Python interpreter tag]
    impl::PyInterpreter* interpreter =
        pyinterpreter_.load(std::memory_order_acquire);
    if (C10_LIKELY(interpreter == self_interpreter)) {
      return data_;
    } else if (interpreter == nullptr) {
      auto* r = slow_accessor();
      impl::PyInterpreter* expected = nullptr;
      // attempt to claim this cache entry with the specified interpreter tag
      if (pyinterpreter_.compare_exchange_strong(
              expected, self_interpreter, std::memory_order_acq_rel)) {
        data_ = r;
      }
      // This shouldn't be possible, as you should be GIL protected
      TORCH_INTERNAL_ASSERT(expected != self_interpreter);
      return r;
    } else {
      // The slot is claimed by a different interpreter; we can never use
      // the cache, so always take the slow path.
      return slow_accessor();
    }
  }

 private:
  // Tag identifying which interpreter claimed the cache slot; nullptr
  // means the slot is unclaimed.
  mutable std::atomic<impl::PyInterpreter*> pyinterpreter_;
  // The cached pointer; only meaningful once pyinterpreter_ has been
  // claimed by the caller's interpreter.
  mutable PyObject* data_;
};

} // namespace c10