1#pragma once
2
3#include <c10/core/DispatchKeySet.h>
4#include <c10/macros/Macros.h>
5#include <c10/util/Flags.h>
6
7// TLS management for DispatchKeySet (the "local" DispatchKeySet(s))
8//
9// This manages two thread-local DispatchKeySets:
10//
11// - The included type set, which adds a tensor type for consideration
12// in dispatch. (For example, you might add Profiling to
13// the included type set to turn on profiling on all tensor operations.)
14//
15// - The excluded type set, which disqualifies a tensor type from dispatch.
16// (For example, after redispatching on variable, we disqualify
17// Autograd so we don't attempt to handle variable again.)
18// (Exclusion wins over inclusion.)
19//
20// NB: Originally, I implemented the excluded type set as storing the inverted
21// set, but TLS is defined to be zero-initialized, so this doesn't actually work
22// (if it's inverted, you want the set to be -1 initialized).
23
24namespace c10 {
25namespace impl {
26
27// POD version of LocalDispatchKeySet. Declared here just so that
28// we can put it in the guards.
29// This struct encapsulates special handling for TLS initialization
30// in set_included()/included() API so that they reflect the truth.
31// If you want to create PODLocalDispatchKeySet with non-zero state,
32// use set_included() instead of default constructor.
33struct C10_API PODLocalDispatchKeySet {
34 uint64_t included_;
35 uint64_t excluded_;
36
37 // See Note [TLS Initialization]
38 DispatchKeySet included() const {
39 return DispatchKeySet(DispatchKeySet::RAW, included_) ^
40 c10::default_included_set;
41 }
42 DispatchKeySet excluded() const {
43 return DispatchKeySet(DispatchKeySet::RAW, excluded_) ^
44 c10::default_excluded_set;
45 }
46
47 void set_included(DispatchKeySet x) {
48 included_ = (x ^ c10::default_included_set).raw_repr();
49 }
50 void set_excluded(DispatchKeySet x) {
51 excluded_ = (x ^ c10::default_excluded_set).raw_repr();
52 }
53};
54static_assert(
55 std::is_trivial<PODLocalDispatchKeySet>::value,
56 "PODLocalDispatchKeySet must be a POD type.");
57
58struct C10_API LocalDispatchKeySet {
59 /* implicit */ LocalDispatchKeySet(PODLocalDispatchKeySet x)
60 : included_(x.included()), excluded_(x.excluded()) {}
61 DispatchKeySet included_;
62 DispatchKeySet excluded_;
63};
64
65// thread_local variables cannot be C10_API on Windows.
66// Inlining this seems to break AutoDispatchBelowAutograd on Android.
67#if defined(_MSC_VER) || defined(C10_ANDROID) || defined(C10_IPHONE)
68C10_API LocalDispatchKeySet tls_local_dispatch_key_set();
69#else // defined(_MSC_VER) || defined(C10_ANDROID) || defined(C10_IPHONE)
70extern C10_API thread_local PODLocalDispatchKeySet raw_local_dispatch_key_set;
71
72inline C10_API LocalDispatchKeySet tls_local_dispatch_key_set() {
73 // Don't let people fiddle with the thread_local directly just
74 // because they include this header.
75 return raw_local_dispatch_key_set;
76}
77#endif // defined(_MSC_VER) || defined(C10_ANDROID) || defined(C10_IPHONE)
78
79// Internal, use ThreadLocalStateGuard
80C10_API void _force_tls_local_dispatch_key_set(LocalDispatchKeySet key_set);
81
82// RAII API for manipulating the thread-local dispatch state.
83
84class C10_API IncludeDispatchKeyGuard {
85 public:
86 IncludeDispatchKeyGuard(DispatchKeySet);
87 IncludeDispatchKeyGuard(DispatchKey k)
88 : IncludeDispatchKeyGuard(DispatchKeySet(k)) {}
89 IncludeDispatchKeyGuard(const IncludeDispatchKeyGuard&) = delete;
90 IncludeDispatchKeyGuard operator=(const IncludeDispatchKeyGuard&) = delete;
91 IncludeDispatchKeyGuard(IncludeDispatchKeyGuard&&) = delete;
92 IncludeDispatchKeyGuard operator=(IncludeDispatchKeyGuard&&) = delete;
93 ~IncludeDispatchKeyGuard();
94
95 private:
96 // A little micro-optimization to save us from tls_get_addr call
97 // on destruction
98 PODLocalDispatchKeySet* tls_;
99 DispatchKeySet include_;
100};
101
102class C10_API ExcludeDispatchKeyGuard {
103 public:
104 ExcludeDispatchKeyGuard(DispatchKeySet);
105 ExcludeDispatchKeyGuard(DispatchKey k)
106 : ExcludeDispatchKeyGuard(DispatchKeySet(k)) {}
107 ExcludeDispatchKeyGuard(const ExcludeDispatchKeyGuard&) = delete;
108 ExcludeDispatchKeyGuard operator=(const ExcludeDispatchKeyGuard&) = delete;
109 ExcludeDispatchKeyGuard(ExcludeDispatchKeyGuard&&) = delete;
110 ExcludeDispatchKeyGuard operator=(ExcludeDispatchKeyGuard&&) = delete;
111 ~ExcludeDispatchKeyGuard();
112
113 private:
114 // A little micro-optimization to save us from tls_get_addr call
115 // on destruction
116 PODLocalDispatchKeySet* tls_;
117 DispatchKeySet exclude_;
118};
119
120struct C10_API ForceDispatchKeyGuard {
121 public:
122 ForceDispatchKeyGuard(c10::impl::LocalDispatchKeySet key_set)
123 : saved_keyset_(c10::impl::tls_local_dispatch_key_set()) {
124 c10::impl::_force_tls_local_dispatch_key_set(key_set);
125 }
126 ~ForceDispatchKeyGuard() {
127 c10::impl::_force_tls_local_dispatch_key_set(saved_keyset_);
128 }
129
130 private:
131 c10::impl::LocalDispatchKeySet saved_keyset_;
132};
133
134// Non-RAII API for manipulating the thread-local dispatch state.
// Please prefer the RAII API. The non-RAII API may be useful when
// the included/excluded state of a given DispatchKey must span
// many calls from Python into C++, so you cannot conveniently
// use an RAII guard.
139//
140// Example use case: a Python context manager that includes a certain
141// DispatchKey, to ensure ops running under the context manager dispatch
142// through that DispatchKey's registered overrides.
143//
// The non-RAII API is less efficient than the RAII guards because both the
// getter and setter will do a tls_get_addr lookup (the RAII struct only needs
// one!).
147
148C10_API bool tls_is_dispatch_key_excluded(DispatchKey x);
149C10_API void tls_set_dispatch_key_excluded(DispatchKey x, bool desired_state);
150C10_API bool tls_is_dispatch_key_included(DispatchKey x);
151C10_API void tls_set_dispatch_key_included(DispatchKey x, bool desired_state);
152C10_API bool tls_is_dispatch_keyset_excluded(DispatchKeySet ks);
153C10_API bool tls_is_dispatch_keyset_included(DispatchKeySet ks);
154
155} // namespace impl
156} // namespace c10
157