1 | #pragma once |
2 | |
#include <c10/core/DispatchKeySet.h>
#include <c10/macros/Macros.h>
#include <c10/util/Flags.h>

#include <cstdint>
#include <type_traits>
6 | |
7 | // TLS management for DispatchKeySet (the "local" DispatchKeySet(s)) |
8 | // |
9 | // This manages two thread-local DispatchKeySets: |
10 | // |
// - The included key set, which adds a dispatch key for consideration
//   in dispatch. (For example, you might add Profiling to the included
//   key set to turn on profiling for all tensor operations.)
//
// - The excluded key set, which disqualifies a dispatch key from dispatch.
//   (For example, after redispatching on a variable, we exclude
//   Autograd so that we don't attempt to handle the variable again.)
//   (Exclusion wins over inclusion.)
19 | // |
// NB: Originally, I implemented the excluded key set as storing the inverted
// set, but TLS is defined to be zero-initialized, so this doesn't actually work
// (an inverted set would need to start out all-ones, i.e. -1 initialized).
23 | |
24 | namespace c10 { |
25 | namespace impl { |
26 | |
27 | // POD version of LocalDispatchKeySet. Declared here just so that |
28 | // we can put it in the guards. |
29 | // This struct encapsulates special handling for TLS initialization |
30 | // in set_included()/included() API so that they reflect the truth. |
31 | // If you want to create PODLocalDispatchKeySet with non-zero state, |
32 | // use set_included() instead of default constructor. |
33 | struct C10_API PODLocalDispatchKeySet { |
34 | uint64_t included_; |
35 | uint64_t excluded_; |
36 | |
37 | // See Note [TLS Initialization] |
38 | DispatchKeySet included() const { |
39 | return DispatchKeySet(DispatchKeySet::RAW, included_) ^ |
40 | c10::default_included_set; |
41 | } |
42 | DispatchKeySet excluded() const { |
43 | return DispatchKeySet(DispatchKeySet::RAW, excluded_) ^ |
44 | c10::default_excluded_set; |
45 | } |
46 | |
47 | void set_included(DispatchKeySet x) { |
48 | included_ = (x ^ c10::default_included_set).raw_repr(); |
49 | } |
50 | void set_excluded(DispatchKeySet x) { |
51 | excluded_ = (x ^ c10::default_excluded_set).raw_repr(); |
52 | } |
53 | }; |
54 | static_assert( |
55 | std::is_trivial<PODLocalDispatchKeySet>::value, |
56 | "PODLocalDispatchKeySet must be a POD type." ); |
57 | |
58 | struct C10_API LocalDispatchKeySet { |
59 | /* implicit */ LocalDispatchKeySet(PODLocalDispatchKeySet x) |
60 | : included_(x.included()), excluded_(x.excluded()) {} |
61 | DispatchKeySet included_; |
62 | DispatchKeySet excluded_; |
63 | }; |
64 | |
65 | // thread_local variables cannot be C10_API on Windows. |
66 | // Inlining this seems to break AutoDispatchBelowAutograd on Android. |
67 | #if defined(_MSC_VER) || defined(C10_ANDROID) || defined(C10_IPHONE) |
68 | C10_API LocalDispatchKeySet tls_local_dispatch_key_set(); |
69 | #else // defined(_MSC_VER) || defined(C10_ANDROID) || defined(C10_IPHONE) |
70 | extern C10_API thread_local PODLocalDispatchKeySet raw_local_dispatch_key_set; |
71 | |
72 | inline C10_API LocalDispatchKeySet tls_local_dispatch_key_set() { |
73 | // Don't let people fiddle with the thread_local directly just |
74 | // because they include this header. |
75 | return raw_local_dispatch_key_set; |
76 | } |
77 | #endif // defined(_MSC_VER) || defined(C10_ANDROID) || defined(C10_IPHONE) |
78 | |
79 | // Internal, use ThreadLocalStateGuard |
80 | C10_API void _force_tls_local_dispatch_key_set(LocalDispatchKeySet key_set); |
81 | |
82 | // RAII API for manipulating the thread-local dispatch state. |
83 | |
84 | class C10_API IncludeDispatchKeyGuard { |
85 | public: |
86 | IncludeDispatchKeyGuard(DispatchKeySet); |
87 | IncludeDispatchKeyGuard(DispatchKey k) |
88 | : IncludeDispatchKeyGuard(DispatchKeySet(k)) {} |
89 | IncludeDispatchKeyGuard(const IncludeDispatchKeyGuard&) = delete; |
90 | IncludeDispatchKeyGuard operator=(const IncludeDispatchKeyGuard&) = delete; |
91 | IncludeDispatchKeyGuard(IncludeDispatchKeyGuard&&) = delete; |
92 | IncludeDispatchKeyGuard operator=(IncludeDispatchKeyGuard&&) = delete; |
93 | ~IncludeDispatchKeyGuard(); |
94 | |
95 | private: |
96 | // A little micro-optimization to save us from tls_get_addr call |
97 | // on destruction |
98 | PODLocalDispatchKeySet* tls_; |
99 | DispatchKeySet include_; |
100 | }; |
101 | |
102 | class C10_API ExcludeDispatchKeyGuard { |
103 | public: |
104 | ExcludeDispatchKeyGuard(DispatchKeySet); |
105 | ExcludeDispatchKeyGuard(DispatchKey k) |
106 | : ExcludeDispatchKeyGuard(DispatchKeySet(k)) {} |
107 | ExcludeDispatchKeyGuard(const ExcludeDispatchKeyGuard&) = delete; |
108 | ExcludeDispatchKeyGuard operator=(const ExcludeDispatchKeyGuard&) = delete; |
109 | ExcludeDispatchKeyGuard(ExcludeDispatchKeyGuard&&) = delete; |
110 | ExcludeDispatchKeyGuard operator=(ExcludeDispatchKeyGuard&&) = delete; |
111 | ~ExcludeDispatchKeyGuard(); |
112 | |
113 | private: |
114 | // A little micro-optimization to save us from tls_get_addr call |
115 | // on destruction |
116 | PODLocalDispatchKeySet* tls_; |
117 | DispatchKeySet exclude_; |
118 | }; |
119 | |
120 | struct C10_API ForceDispatchKeyGuard { |
121 | public: |
122 | ForceDispatchKeyGuard(c10::impl::LocalDispatchKeySet key_set) |
123 | : saved_keyset_(c10::impl::tls_local_dispatch_key_set()) { |
124 | c10::impl::_force_tls_local_dispatch_key_set(key_set); |
125 | } |
126 | ~ForceDispatchKeyGuard() { |
127 | c10::impl::_force_tls_local_dispatch_key_set(saved_keyset_); |
128 | } |
129 | |
130 | private: |
131 | c10::impl::LocalDispatchKeySet saved_keyset_; |
132 | }; |
133 | |
134 | // Non-RAII API for manipulating the thread-local dispatch state. |
135 | // Please prefer the RAII API. The non-RAII API may be useful when |
136 | // the included/excluded state of a given DispatchKey must span |
137 | // many calls from the Python to the C++, so you cannot conveniently |
138 | // use an RAII guard. |
139 | // |
140 | // Example use case: a Python context manager that includes a certain |
141 | // DispatchKey, to ensure ops running under the context manager dispatch |
142 | // through that DispatchKey's registered overrides. |
143 | // |
144 | // The non-RAII API is less efficient than the RAII guards because both the |
145 | // getter and setter will do a tls_getaddr lookup (the RAII struct only needs |
146 | // one!) |
147 | |
148 | C10_API bool tls_is_dispatch_key_excluded(DispatchKey x); |
149 | C10_API void tls_set_dispatch_key_excluded(DispatchKey x, bool desired_state); |
150 | C10_API bool tls_is_dispatch_key_included(DispatchKey x); |
151 | C10_API void tls_set_dispatch_key_included(DispatchKey x, bool desired_state); |
152 | C10_API bool tls_is_dispatch_keyset_excluded(DispatchKeySet ks); |
153 | C10_API bool tls_is_dispatch_keyset_included(DispatchKeySet ks); |
154 | |
155 | } // namespace impl |
156 | } // namespace c10 |
157 | |