#pragma once

#include <cstdint>
#include <ATen/core/function_schema.h>
#include <ATen/core/jit_type.h>
#include <c10/util/Bitset.h>
#include <c10/core/DispatchKeySet.h>
#include <c10/util/irange.h>
#include <ATen/core/Variadic.h>
#include <ATen/core/stack.h>

namespace c10 {

namespace impl {

// Take a DispatchKeySet for a Tensor and determine what the actual
// DispatchKey for dispatch should be, taking TLS into account and skipping
// backends which fall through.
//
// Unlike Tensor::key_set(), the value of this on a tensor can change
// depending on TLS.
//
// NB: If there is no valid dispatch key, this will return Undefined
static inline DispatchKeySet computeDispatchKeySet(
    DispatchKeySet ks,
    // The key mask lets us eliminate (by zero entries) keys which should not
    // be considered for dispatch. There are two cases when we use this:
    //
    // - If an operator's dispatch table contains a fallthrough entry, we
    //   should bypass it entirely when finding the key
    // - If a user invokes with redispatch, the mask lets us
    //   zero out the key at which the user asked dispatch to stop.
    //
    // These excluded backends are NOT tracked in the TLS, but must be applied
    // AFTER TLS (since the backend may have been introduced for consideration
    // by the included TLS), which is why you have to pass them in to this
    // function (as opposed to just applying it to the input 'ks').
    DispatchKeySet key_mask
) {
  c10::impl::LocalDispatchKeySet local = c10::impl::tls_local_dispatch_key_set();
  // TODO: It's a bit irritating that we have to do logical ORs here; it would
  // be nice to only do one. Can always_included be folded into the TLS? Well,
  // it's a bit troublesome, because fastpath TLS access requires the type of
  // the TLS in question to be zero-initialized, so you don't actually win
  // anything in that case.
  return (((ks | local.included_) - local.excluded_) & key_mask);
}
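
// A hedged worked example (these key names are stand-ins, not a claim about
// any particular operator): if ks = {CPU, AutogradCPU}, the TLS has
// included_ = {Python} and excluded_ = {AutogradCPU}, and key_mask is FULL,
// then the computation above yields
//   (({CPU, AutogradCPU} | {Python}) - {AutogradCPU}) & FULL = {CPU, Python}
// and dispatch proceeds on the highest-priority key left in that set.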

} // namespace impl

namespace detail {
// A small gadget to extract the DispatchKeySet from types which are known
// to have it. Used to extract dispatch keys from unboxed calls.
struct MultiDispatchKeySet : at::IterArgs<MultiDispatchKeySet> {
  DispatchKeySet ts;
  void operator()(const at::Tensor& x) {
    ts = ts | x.key_set();
  }
  void operator()(const c10::optional<at::Tensor>& x) {
    if (x.has_value()) {
      ts = ts | x->key_set();
    }
  }
  void operator()(at::ArrayRef<at::Tensor> xs) {
    for (const auto& x : xs) {
      ts = ts | x.key_set();
    }
  }
  // Tensor?[] translates to this case.
  void operator()(const c10::List<c10::optional<at::Tensor>>& xs) {
    for (c10::optional<at::Tensor> x : xs) {
      if (x.has_value()) {
        ts = ts | x.value().key_set();
      }
    }
  }
  // Structured Tensor[] translates to this case.
  void operator()(at::ITensorListRef xs) {
    for (const auto& x : xs) {
      ts = ts | x.key_set();
    }
  }
  [[noreturn]] void operator()(at::ArrayRef<c10::optional<at::Tensor>>) {
    // Just checking that the handling of Tensor?[] didn't change.
    TORCH_INTERNAL_ASSERT(false);
  }
  void operator()(const at::Generator& gen) {
    if (gen.defined()) {
      ts = ts | gen.key_set();
    }
  }
  void operator()(const c10::optional<at::Generator>& gen) {
    if (gen.has_value() && gen->defined()) {
      ts = ts | gen->key_set();
    }
  }
  template <typename T>
  void operator()(const T&) {
    // do nothing
  }
};
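
// A hedged note on mechanics: at::IterArgs is a CRTP helper, so
// MultiDispatchKeySet().apply(args...) invokes operator() once per argument
// in order, and the templated catch-all overload above silently skips
// anything that carries no dispatch keys (Scalars, ints, bools, ...).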

// NB: take by const reference (Don't do universal forwarding here! You
// don't want to move into this function!)
template <typename... Args>
DispatchKeySet multi_dispatch_key_set(const Args&... args) {
  return MultiDispatchKeySet().apply(args...).ts;
}
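
// A minimal usage sketch (the argument names are hypothetical): for an
// unboxed call like add(self, other, alpha), the dispatcher can compute
//   DispatchKeySet ks = detail::multi_dispatch_key_set(self, other, alpha);
// which ORs together self.key_set() and other.key_set(); the Scalar alpha
// contributes nothing.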
} // namespace detail

/**
 * An instance of DispatchKeyExtractor knows how to get a dispatch key given
 * a list of arguments for an operator call.
 *
 * The instance is specific to a certain operator as:
 *  - In boxed dispatch, different operators have different ways to extract
 *    the dispatch key (e.g. different numbers of arguments), and we
 *    precompute the stack locations we should look at; and
 *  - In all dispatch, some backends should be excluded from dispatch because
 *    they have been registered as fallthrough. The set of excluded backends
 *    varies from operator to operator, as some operators may have overridden
 *    the fallthrough with custom behavior.
 *
 * Note - this should maintain identical impl to the py dispatcher key
 * extraction logic at pytorch/torch/dispatcher.py
 */
struct TORCH_API DispatchKeyExtractor final {
public:
  static DispatchKeyExtractor make(const FunctionSchema& schema) {
    return DispatchKeyExtractor(makeBitsetForDispatchArgs(schema));
  }

  static DispatchKeyExtractor makeUninitialized() {
    return DispatchKeyExtractor(c10::utils::bitset());
  }

  void registerSchema(const FunctionSchema& schema) {
    TORCH_INTERNAL_ASSERT(dispatch_arg_indices_reverse_.is_entirely_unset());
    dispatch_arg_indices_reverse_ = makeBitsetForDispatchArgs(schema);
  }
  void deregisterSchema() {
    dispatch_arg_indices_reverse_ = c10::utils::bitset();
  }

  DispatchKeySet getDispatchKeySetBoxed(const torch::jit::Stack* stack) const {
    DispatchKeySet ks;
    dispatch_arg_indices_reverse_.for_each_set_bit([&](size_t reverse_arg_index) {
      const auto& ivalue = torch::jit::peek(*stack, 0, reverse_arg_index + 1);
      if (C10_LIKELY(ivalue.isTensor())) {
        // NB: Take care not to introduce a refcount bump (there's
        // no safe toTensorRef method, alas)
        ks = ks | ivalue.unsafeToTensorImpl()->key_set();
      } else if (C10_UNLIKELY(ivalue.isTensorList())) {
        for (const at::Tensor& tensor : ivalue.toTensorList()) {
          ks = ks | tensor.key_set();
        }
      }
      // Tensor?[] translates to a c10::List<IValue>, so we need to peek
      // inside.
      else if (C10_UNLIKELY(ivalue.isList())) {
        for (const auto& elt : ivalue.toListRef()) {
          if (elt.isTensor()) {
            ks = ks | elt.toTensor().key_set();
          }
        }
      }
    });
    // Keys that are fallthrough should be skipped
    if (requiresBitsetPerBackend_) {
      auto backend_idx = ks.getBackendIndex();
      return impl::computeDispatchKeySet(ks, nonFallthroughKeysPerBackend_[backend_idx]);
    } else {
      return impl::computeDispatchKeySet(ks, nonFallthroughKeys_);
    }
  }
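
  // A worked example (hypothetical schema, purely illustrative): for
  // add.Tensor(Tensor self, Tensor other, Scalar alpha), the bitset has bits
  // set at reverse indices 2 (self) and 1 (other), so the loop above peeks
  // at the third and second IValues from the top of the stack and ORs in
  // their key sets; alpha (reverse index 0) carries no bit and is never
  // inspected.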

  template <class... Args>
  DispatchKeySet getDispatchKeySetUnboxed(const Args&... args) const {
    auto ks = detail::multi_dispatch_key_set(args...);
    // Keys that are fallthrough should be skipped
    if (requiresBitsetPerBackend_) {
      auto backend_idx = ks.getBackendIndex();
      return impl::computeDispatchKeySet(ks, nonFallthroughKeysPerBackend_[backend_idx]);
    } else {
      return impl::computeDispatchKeySet(ks, nonFallthroughKeys_);
    }
  }
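
  // A hedged call-site sketch (variable names are illustrative): an unboxed
  // wrapper might do
  //   DispatchKeySet ks = extractor.getDispatchKeySetUnboxed(self, other);
  //   DispatchKey k = ks.highestPriorityTypeId();
  // and then look up the kernel registered for k; fallthrough keys have
  // already been masked out of ks at this point.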

  void setOperatorHasFallthroughForKey(DispatchKey k, bool has_fallthrough);

  std::string dumpState() const;
  void checkInvariants(const FunctionSchema& schema) const;

private:
  static c10::utils::bitset makeBitsetForDispatchArgs(const FunctionSchema& schema) {
    TORCH_CHECK(schema.arguments().size() <= c10::utils::bitset::NUM_BITS(),
        "The function schema has ", schema.arguments().size(),
        " arguments but this PyTorch build only supports ", c10::utils::bitset::NUM_BITS());
    c10::utils::bitset dispatch_arg_indices_reverse;
    for (const auto index : c10::irange(schema.arguments().size())) {
      if (schema.arguments()[index].type()->isSubtypeOf(*TensorType::get()) ||
          schema.arguments()[index].type()->isSubtypeOf(
              *ListType::ofTensors()) ||
          schema.arguments()[index].type()->isSubtypeOf(
              *ListType::ofOptionalTensors()) ||
          schema.arguments()[index].type()->isSubtypeOf(
              *OptionalType::ofTensor())) {
        dispatch_arg_indices_reverse.set(schema.arguments().size() - 1 - index);
      }
    }
    return dispatch_arg_indices_reverse;
  }
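
  // For instance (hypothetical invocation, purely illustrative): for
  // index.Tensor(Tensor self, Tensor?[] indices) -> Tensor, both arguments
  // pass one of the subtype checks above, so bits 1 (self) and 0 (indices)
  // are set; a Scalar or int argument would match none of them and get no
  // bit.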

  explicit DispatchKeyExtractor(c10::utils::bitset dispatch_arg_indices_reverse)
    : dispatch_arg_indices_reverse_(dispatch_arg_indices_reverse)
    , nonFallthroughKeys_(DispatchKeySet::FULL)
    , requiresBitsetPerBackend_(false) {
    for (const auto i : c10::irange(nonFallthroughKeysPerBackend_.size())) {
      nonFallthroughKeysPerBackend_[i] = DispatchKeySet::FULL;
    }
  }

  // A bitset with a one for each argument index that has to be considered
  // for dispatch. This avoids having to iterate over the stack to find all
  // the tensors. The bits are stored in reverse order: if
  // dispatch_arg_indices_reverse_[i] == true, then the i-th argument from
  // the top of the stack (i.e. the i-th last argument of the function) is
  // relevant for dispatch.
  //
  // dispatch_arg_indices_reverse_ is allowed to have zero bits set; that
  // just means you must do the fallthrough.
  c10::utils::bitset dispatch_arg_indices_reverse_;

  // Set of functionality keys for which the operator does NOT have a
  // fallthrough kernel.
  DispatchKeySet nonFallthroughKeys_;
  // Set of functionality keys for which the operator does NOT have a
  // fallthrough kernel, defined PER BACKEND. This is only needed if we know
  // that the operator has a different set of fallthroughs defined for some
  // backends.
  std::array<DispatchKeySet, num_backends> nonFallthroughKeysPerBackend_;
  // Flag to tell us if we can use the single set of nonFallthroughKeys_
  // (fast path), or if we need to fall back to the slower path and check
  // nonFallthroughKeysPerBackend_.
  bool requiresBitsetPerBackend_;
};

} // namespace c10