1 | #pragma once |
2 | #include <c10/core/DispatchKey.h> |
3 | #include <c10/util/Exception.h> |
4 | #include <c10/util/Metaprogramming.h> |
5 | #include <c10/util/llvmMathExtras.h> |
6 | #include <ostream> |
7 | |
8 | namespace c10 { |
9 | |
10 | struct FunctionalityOffsetAndMask { |
11 | // empty constructor shouldn't be used; only needed to initialize |
12 | // the array before populating it. |
13 | FunctionalityOffsetAndMask() = default; |
14 | FunctionalityOffsetAndMask(uint16_t offset, uint16_t mask) |
15 | : offset(offset), mask(mask) {} |
  // This needs to be big enough to cover the size of the operator table.
17 | uint16_t offset{}; |
18 | // See Note [No More Than 16 Backends] |
19 | // This mask needs to be big enough to mask all of the backend bits. |
20 | // We probably don't ever want to have more than 16 backend bits, so uint16_t |
21 | // should be enough. |
22 | uint16_t mask{}; |
23 | }; |
static_assert(
    c10::num_runtime_entries < 65536,
    "The dispatcher currently only supports up to 2^16 runtime entries");
27 | |
28 | C10_API std::array<FunctionalityOffsetAndMask, num_functionality_keys> |
29 | initializeFunctionalityOffsetsAndMasks(); |
30 | |
31 | C10_ALWAYS_INLINE static const std:: |
32 | array<FunctionalityOffsetAndMask, num_functionality_keys>& |
33 | offsetsAndMasks() { |
34 | static auto offsets_and_masks_ = initializeFunctionalityOffsetsAndMasks(); |
35 | return offsets_and_masks_; |
36 | } |
37 | |
38 | // A representation of a set of DispatchKeys. A DispatchKeySet contains both |
39 | // "functionality" bits and "backend bits", and every tensor holds its own |
40 | // DispatchKeySet. The Dispatcher implements multiple dispatch by grabbing the |
41 | // keyset on every input tensor, or’ing them together, and dispatching to a |
42 | // specific piece of functionality. The functionality bits are *ordered*. When |
43 | // multiple functionality bits are set, we use the highest priority |
44 | // functionality. Similarly, multiple backend bits can theoretically be set if |
// you call an operator with multiple tensors from different devices (e.g. CPU
// and CUDA), although support for mixed device dispatch is limited (the only
// kernels that gracefully handle mixed device inputs for now are CUDA kernels
// that take in a scalar CPU tensor).
49 | |
50 | // A representation of a set of DispatchKeys. A tensor may have multiple |
51 | // tensor type ids, e.g., a Variable tensor can also be a CPU tensor; the |
52 | // DispatchKeySet specifies what type ids apply. The internal representation is |
53 | // as a 64-bit bit set (this means only 64 tensor type ids are supported). |
54 | // |
55 | // As mentioned above, DispatchKeys are ordered; thus, we can ask questions like |
56 | // "what is the highest priority DispatchKey in the set"? (The set itself is |
57 | // not ordered; two sets with the same ids will always have the ids ordered in |
58 | // the same way.) |
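//
// For illustration, a minimal sketch of the priority query described above
// (key names assumed from DispatchKey.h):
//   DispatchKeySet ks({DispatchKey::AutogradCPU, DispatchKey::CPU});
//   ks.highestPriorityTypeId();  // DispatchKey::AutogradCPU, since Autograd
//                                // outranks Dense.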
59 | // |
60 | // Note [DispatchKeySet Internal Representation] |
61 | // Internally, dispatch keys are packed into 64-bit DispatchKeySet objects |
62 | // that get passed around at runtime. |
63 | // However, there isn't necessarily a 1-to-1 mapping between bits in the keyset |
64 | // and individual dispatch keys. |
65 | // |
66 | // First: why do we have this distinction, and why not map every dispatch key |
67 | // directly to a bit? This is mostly because we have several types of |
68 | // functionalities that different backends would like to customize. For example, |
69 | // we have: |
70 | // - "Dense": CPU, CUDA, XLA, ... (~12 keys) |
71 | // - "Sparse": SparseCPU, SparseCUDA, ... |
72 | // - "Quantized": QuantizedCPU, QuantizedCUDA, QuantizedXLA, ... |
73 | // - "Autograd": AutogradCPU, AutogradCUDA, Autograd XLA, ... |
// The problem is that the total number of keys grows with the product [#
// backends] x [# functionalities], making it very difficult to map each key
// directly to a bit in a bitset without dramatically increasing the size of
// the bitset over time.
78 | // |
// The two enums (BackendComponent and DispatchKey) can be divided roughly into
// 3 categories.
81 | // |
82 | // (1) "Building block" keys |
//    (a) backends: Everything in the BackendComponent enum (e.g. CPUBit,
//        CUDABit)
//    (b) functionalities: (per-backend) functionality-bit DispatchKeys
//        (e.g. AutogradFunctionality, Sparse, Dense)
86 | // (2) "Runtime" keys |
87 | // (a) "non-customizable backends" (e.g. FPGA) |
88 | // (b) "non-customizable functionalities" (e.g. Functionalize) |
89 | // (c) "per-backend instances of customizable functionalities" (e.g. CPU, |
90 | // SparseCPU, AutogradCPU) |
91 | // (3) "Alias" DispatchKeys (see Note [Alias Dispatch Keys]) |
92 | // |
93 | // (1) Building block keys always correspond to individual bits in a |
94 | // DispatchKeySet. They can also be combined in a DispatchKeySet to form actual |
95 | // runtime keys. e.g. |
//    auto dense_cpu_ks = DispatchKeySet(BackendComponent::CPUBit) |
//        DispatchKeySet(DispatchKey::Dense);
//    // The keyset has the runtime dense-cpu key.
//    dense_cpu_ks.has(DispatchKey::CPU);
//    // And it contains the building block keys too.
//    dense_cpu_ks.has_backend(BackendComponent::CPUBit);
102 | // dense_cpu_ks.has(DispatchKey::Dense); |
103 | // |
104 | // Not every backend and not every functionality counts as a "building block |
105 | // key". This is mostly to give us more levers to pull in the design space. |
106 | // Backend keys and functionality keys that count as "building blocks" will |
// contribute to a full cross product of functionality that can be overridden.
108 | // |
109 | // For example, right now we have at least 12 "backend" building blocks (CPU, |
110 | // CUDA, XLA, ...) and at least 4 "functionality" building blocks (Dense, |
111 | // Sparse, Quantized, AutogradFunctionality, ...). These keys together allow |
112 | // every dispatcher operator to be customized in up to 12*4 different ways. Each |
113 | // of those requires a slot in the operator table of every dispatcher operator. |
// Not every piece of functionality necessarily needs to be customizable
115 | // per-backend, and not every backend necessarily needs to be able to customize |
116 | // every type of functionality. |
117 | // |
118 | // |
119 | // (2) Every runtime key corresponds directly to a slot in an operator's runtime |
120 | // dispatch table, and you can directly register kernels to a runtime dispatch |
121 | // key. |
122 | // |
123 | // For per-backend functionalities like "Dense" or "AutogradFunctionality", |
124 | // you can think of the corresponding runtime dispatch keys as "instances" of |
125 | // that functionality, per backend. E.g. "CPU", "CUDA", "XLA", etc. are all |
126 | // runtime instances of the "Dense" building block key. |
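//
// For illustration, a minimal sketch (key names assumed from DispatchKey.h):
// constructing a keyset from the runtime key CPU decomposes it into its two
// building blocks.
//    auto cpu_ks = DispatchKeySet(DispatchKey::CPU);
//    cpu_ks.has(DispatchKey::Dense);                // true (functionality bit)
//    cpu_ks.has_backend(BackendComponent::CPUBit);  // true (backend bit)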
127 | |
128 | // (2a) and (2b) are represented identically in the DispatchKeySet logic: |
// - backend-agnostic functionalities (e.g. FuncTorchBatched) are NOT
//   customizable per backend.
//   In order to do so, we'd need to promote it to a per-backend functionality
//   "building block" key.
// - non-customizable backends (e.g. FPGA) can NOT customize existing
134 | // functionality like Sparse, Autograd, etc. |
135 | // In order to do so, we'd need to promote it to a backend "building block" |
136 | // key. |
137 | // |
138 | // In both cases, these keys directly correspond to runtime slots in the |
139 | // operator table. |
140 | // |
141 | // |
142 | // (3) "Alias" keys |
143 | // See Note [Alias Dispatch Keys] |
144 | // |
145 | // Final note: for anyone making future changes to the Dispatcher + |
// DispatchKeySet internals, there's a closed PR with a basic
// Python implementation of the Dispatcher that might be useful for quickly
// testing out and validating changes. See it at
149 | // https://github.com/pytorch/pytorch/pull/68743 |
150 | |
151 | // An undefined tensor is one with an empty tensor type set. |
152 | class DispatchKeySet final { |
153 | public: |
154 | enum Full { FULL }; |
155 | enum FullAfter { FULL_AFTER }; |
156 | enum Raw { RAW }; |
157 | |
158 | // NB: default constructor representation as zero is MANDATORY as |
159 | // use of DispatchKeySet in TLS requires this. |
160 | constexpr DispatchKeySet() = default; |
161 | |
162 | constexpr DispatchKeySet(Full) |
163 | : repr_((1ULL << (num_backends + num_functionality_keys - 1)) - 1) {} |
164 | |
165 | constexpr DispatchKeySet(FullAfter, DispatchKey t) |
166 | // LSB after t are OK, but not t itself. |
167 | // "functionalities" have a notion of ordering (e.g. Autograd > Sparse > |
168 | // Quantized > Dense). But backends don't really have an ordering. |
169 | // Therefore, we're enforcing that FullAfter can only be used on |
170 | // "functionality" keys. |
171 | : repr_( |
172 | (1ULL |
173 | << (num_backends + static_cast<uint8_t>(toFunctionalityKey(t)) - |
174 | 1)) - |
175 | 1) { |
176 | *this = add(DispatchKey::PythonDispatcher); |
177 | } |
178 | |
179 | // Public version of DispatchKeySet(uint64_t) API; external users |
180 | // must be explicit when they do this! |
181 | constexpr DispatchKeySet(Raw, uint64_t x) : repr_(x) {} |
182 | |
183 | constexpr explicit DispatchKeySet(BackendComponent k) { |
184 | if (k == BackendComponent::InvalidBit) { |
185 | repr_ = 0; |
186 | } else { |
187 | repr_ = 1ULL << (static_cast<uint8_t>(k) - 1); |
188 | } |
189 | } |
190 | |
191 | constexpr explicit DispatchKeySet(DispatchKey k) { |
192 | if (k == DispatchKey::Undefined) { |
193 | // Case 1: handle Undefined specifically |
194 | repr_ = 0; |
195 | } else if (k <= DispatchKey::EndOfFunctionalityKeys) { |
196 | // Case 2: handle "functionality-only" keys |
197 | // These keys have a functionality bit set, but no backend bits |
198 | // These can technically be either: |
199 | // - valid runtime keys (e.g. DispatchKey::AutogradOther, |
200 | // DispatchKey::FuncTorchBatched, etc) |
201 | // - "building block" keys that aren't actual runtime keys (e.g. |
202 | // DispatchKey::Dense or Sparse) |
203 | uint64_t functionality_val = 1ULL |
204 | << (num_backends + static_cast<uint8_t>(k) - 1); |
205 | repr_ = functionality_val; |
206 | } else if (k <= DispatchKey::EndOfRuntimeBackendKeys) { |
207 | // Case 3: "runtime" keys that have a functionality bit AND a backend bit. |
208 | // First compute which bit to flip for the functionality. |
209 | auto functionality_k = toFunctionalityKey(k); |
210 | // The - 1 is because Undefined is technically a "functionality" that |
211 | // doesn't show up in the bitset. So e.g. Dense is technically the second |
212 | // functionality, but the lowest functionality bit. |
213 | uint64_t functionality_val = 1ULL |
214 | << (num_backends + static_cast<uint8_t>(functionality_k) - 1); |
215 | |
      // Then compute which bit to flip for the backend.
      // This handles the runtime instances of "per-backend functionality"
      // keys. For example, given DispatchKey::CPU, we should set:
      // - the Dense functionality bit
      // - the CPUBit backend bit
222 | auto backend_k = toBackendComponent(k); |
223 | uint64_t backend_val = backend_k == BackendComponent::InvalidBit |
224 | ? 0 |
225 | : 1ULL << (static_cast<uint8_t>(backend_k) - 1); |
226 | repr_ = functionality_val + backend_val; |
227 | } else { |
228 | // At this point, we should have covered every case except for alias keys. |
229 | // Technically it would be possible to add alias dispatch keys to a |
230 | // DispatchKeySet, but the semantics are a little confusing and this |
231 | // currently isn't needed anywhere. |
232 | repr_ = 0; |
233 | } |
234 | } |
235 | |
236 | constexpr uint64_t keys_to_repr(std::initializer_list<DispatchKey> ks) { |
237 | uint64_t repr = 0; |
238 | for (auto k : ks) { |
239 | repr |= DispatchKeySet(k).repr_; |
240 | } |
241 | return repr; |
242 | } |
243 | |
244 | constexpr uint64_t backend_bits_to_repr( |
245 | std::initializer_list<BackendComponent> ks) { |
246 | uint64_t repr = 0; |
247 | for (auto k : ks) { |
248 | repr |= DispatchKeySet(k).repr_; |
249 | } |
250 | return repr; |
251 | } |
252 | |
253 | explicit constexpr DispatchKeySet(std::initializer_list<DispatchKey> ks) |
254 | : repr_(keys_to_repr(ks)) {} |
255 | |
256 | explicit constexpr DispatchKeySet(std::initializer_list<BackendComponent> ks) |
257 | // Note: for some reason, putting this logic directly in the constructor |
258 | // appears to fail to compile on CUDA 10.1. |
259 | // See an example internal failure at |
260 | // https://www.internalfb.com/intern/skycastle/run/76561193669136035/artifact/actionlog.76561193742069401.stderr |
261 | : repr_(backend_bits_to_repr(ks)) {} |
262 | |
263 | // Test if a DispatchKey is in the set |
264 | inline bool has(DispatchKey t) const { |
265 | TORCH_INTERNAL_ASSERT_DEBUG_ONLY(t != DispatchKey::Undefined); |
266 | return has_all(DispatchKeySet(t)); |
267 | } |
268 | constexpr bool has_backend(BackendComponent t) const { |
269 | return has_all(DispatchKeySet(t)); |
270 | } |
271 | |
273 | // Given a DispatchKeySet of functionality keys and (potentially) backend |
274 | // keys, tests if all of them are in the current set. |
275 | constexpr bool has_all(DispatchKeySet ks) const { |
276 | return static_cast<bool>((repr_ & ks.repr_) == ks.repr_); |
277 | } |
278 | |
279 | // Given a DispatchKeySet of functionality keys and (potentially) backend |
280 | // keys, tests if any of them are in the current set. This could technically |
281 | // be pretty easily implemented using has(). It is strictly a perf |
282 | // optimization though. There are many places in the code base where we want |
283 | // to test for multiple functionality keys together. HOWEVER, runtime |
284 | // per-backend functionality keys aren't allowed to be used with this |
285 | // function, because you can end up with weird results. e.g. |
286 | // DispatchKeySet(DispatchKey::AutogradCPU).has_any(DispatchKeySet(DispatchKey::CPU)) |
287 | // would return true. |
288 | inline bool has_any(DispatchKeySet ks) const { |
289 | TORCH_INTERNAL_ASSERT_DEBUG_ONLY( |
290 | // Either there are no backend bits in the input keyset |
291 | ((ks.repr_ & full_backend_mask) == 0) || |
292 | // or there are no per-backend-functionality bits |
293 | // See [Note: Per-Backend Functionality Dispatch Keys] |
294 | ((ks & |
295 | DispatchKeySet({ |
296 | DispatchKey::Dense, |
297 | DispatchKey::Quantized, |
298 | DispatchKey::Sparse, |
299 | DispatchKey::AutogradFunctionality, |
300 | }) |
301 | .repr_) == 0)); |
302 | return static_cast<bool>((repr_ & ks.repr_) != 0); |
303 | } |
304 | // Test if DispatchKeySet is a superset of ks. |
305 | bool isSupersetOf(DispatchKeySet ks) const { |
306 | return (repr_ & ks.repr_) == ks.repr_; |
307 | } |
308 | // Perform set union |
309 | constexpr DispatchKeySet operator|(DispatchKeySet other) const { |
310 | return DispatchKeySet(repr_ | other.repr_); |
311 | } |
312 | // Perform set intersection |
313 | constexpr DispatchKeySet operator&(DispatchKeySet other) const { |
314 | return DispatchKeySet(repr_ & other.repr_); |
315 | } |
316 | // Compute the set difference self - other, |
317 | // but ONLY for the functionality keys. |
318 | // Any backend bits set on self will remain unchanged. |
319 | // See Note [Removing keys from DispatchKeySet Only Affects Functionality |
320 | // Keys] |
321 | constexpr DispatchKeySet operator-(DispatchKeySet other) const { |
322 | return DispatchKeySet(repr_ & (full_backend_mask | ~other.repr_)); |
323 | } |
324 | |
325 | // Compute self ^ other |
326 | constexpr DispatchKeySet operator^(DispatchKeySet other) const { |
327 | return DispatchKeySet(repr_ ^ other.repr_); |
328 | } |
329 | bool operator==(DispatchKeySet other) const { |
330 | return repr_ == other.repr_; |
331 | } |
332 | bool operator!=(DispatchKeySet other) const { |
333 | return repr_ != other.repr_; |
334 | } |
335 | // Add a DispatchKey to the DispatchKey set. Does NOT mutate, |
336 | // returns the extended DispatchKeySet! |
337 | C10_NODISCARD constexpr DispatchKeySet add(DispatchKey t) const { |
338 | return *this | DispatchKeySet(t); |
339 | } |
340 | C10_NODISCARD constexpr DispatchKeySet add(DispatchKeySet ks) const { |
341 | return *this | ks; |
342 | } |
343 | |
344 | // Remove a DispatchKey from the DispatchKey set. |
345 | // This is generally not an operation you should be doing |
346 | // (it's used to implement the printing overload, operator<<) |
347 | // |
348 | // Note [Removing keys from DispatchKeySet Only Affects Functionality Keys] |
  // For now, we're only allowing removal of "functionality bits" from the
  // keyset, which is specifically needed by the fallthrough key calculation
  // logic. Why is removing backend bits problematic? Consider this example:
353 | // |
354 | // DispatchKeySet([DispatchKey.CPU, DispatchKey.AutogradCUDA, |
355 | // DispatchKey.CUDA]).remove(DispatchKey.AutogradCUDA) |
356 | // DispatchKeySet([DispatchKey.CPU, |
357 | // DispatchKey.AutogradCUDA]).remove(DispatchKey.AutogradCUDA) |
358 | // |
359 | // What do we want to happen? |
360 | // Technically, we'd like it to be true that after removal, |
361 | // the first keyset still has the CUDA dispatch key while the second doesn't. |
  // Unfortunately there's no way to represent that, because the two keysets
  // are represented the same way internally:
  //   functionality bits: Autograd, Dense
  //   backend bits: CPU, CUDA
365 | // |
  // Instead, remove(DispatchKey::AutogradCUDA) will only remove the "Autograd"
  // bit from the bitset.
368 | C10_NODISCARD constexpr DispatchKeySet remove(DispatchKey t) const { |
369 | return DispatchKeySet( |
370 | repr_ & ~(DispatchKeySet(t).repr_ & ~full_backend_mask)); |
371 | } |
372 | // You're allowed to remove a backend bit from a DispatchKeySet, |
373 | // but you have to be explicit about it (remove_backend() instead of |
374 | // remove()). |
375 | constexpr DispatchKeySet remove_backend(BackendComponent b) const { |
376 | return DispatchKeySet(repr_ & ~(DispatchKeySet(b).repr_)); |
377 | } |
378 | // Is the set empty? (AKA undefined tensor) |
379 | bool empty() const { |
380 | return repr_ == 0; |
381 | } |
382 | uint64_t raw_repr() { |
383 | return repr_; |
384 | } |
385 | |
386 | DispatchKey highestFunctionalityKey() const { |
387 | auto functionality_idx = indexOfHighestBit(); |
388 | // This means that none of the functionality bits were set. |
389 | if (functionality_idx < num_backends) |
390 | return DispatchKey::Undefined; |
391 | // The first num_backend bits in the keyset don't correspond to real |
392 | // dispatch keys. |
393 | return static_cast<DispatchKey>(functionality_idx - num_backends); |
394 | } |
395 | |
  // This is similar to toBackendComponent(DispatchKey), but less restrictive.
397 | // toBackendComponent() errors out if the key that it was passed has no |
398 | // backend bits, which is useful for error checking. We need a version of that |
399 | // here that can also handle "fake" backends like FPGA, because they need to |
400 | // map to the AutogradOther key. For those backends, we return |
401 | // BackendComponent::InvalidBit. |
402 | BackendComponent highestBackendKey() const { |
403 | // mask to mask out functionality bits |
404 | auto backend_idx = |
405 | DispatchKeySet(repr_ & full_backend_mask).indexOfHighestBit(); |
406 | // all zeros across the backend bits means that no backend bits are set. |
407 | if (backend_idx == 0) |
408 | return BackendComponent::InvalidBit; |
409 | return static_cast<BackendComponent>(backend_idx); |
410 | } |
411 | |
412 | // returns the DispatchKey of highest priority in the set. |
413 | DispatchKey highestPriorityTypeId() const { |
414 | auto functionality_k = highestFunctionalityKey(); |
415 | if (isPerBackendFunctionalityKey(functionality_k)) { |
416 | return toRuntimePerBackendFunctionalityKey( |
417 | functionality_k, highestBackendKey()); |
418 | } |
419 | return functionality_k; |
420 | } |
421 | |
422 | // Returns the index of the most-significant bit in the keyset. |
  // This is used as part of the index calculation into the operator table to
  // get:
424 | // - the highest "functionality" bit in the keyset. |
425 | // - the highest "backend" bit in the keyset. |
426 | uint8_t indexOfHighestBit() const { |
427 | return 64 - llvm::countLeadingZeros(repr_); |
428 | } |
429 | |
430 | #if defined(C10_MOBILE_TRIM_DISPATCH_KEYS) |
431 | // [Note: Trimmed Mobile Dispatch Keys] |
432 | /** |
433 | * The method below maps the dispatch key in the enum DispatchKey to an |
434 | * integer index in the dispatchTable_ array in OperatorEntry. The array |
435 | * is trimmed for mobile to reduce peak memory usage since it's |
436 | * unnecessary to reserve additional space for dispatch keys that will |
437 | * never be used on mobile. |
438 | */ |
439 | int getDispatchTableIndexForDispatchKeySet() const { |
440 | auto dk = highestPriorityTypeId(); |
441 | switch (dk) { |
442 | case DispatchKey::Undefined: |
443 | return 0; |
444 | case DispatchKey::CPU: |
445 | return 1; |
446 | case DispatchKey::QuantizedCPU: |
447 | return 2; |
448 | case DispatchKey::SparseCPU: |
449 | return 3; |
450 | case DispatchKey::BackendSelect: |
451 | return 4; |
452 | case DispatchKey::ADInplaceOrView: |
453 | return 5; |
454 | case DispatchKey::AutogradOther: |
455 | return 6; |
456 | case DispatchKey::AutogradCPU: |
457 | return 7; |
458 | default: |
459 | return -1; |
460 | } |
461 | } |
462 | #else |
  // returns the index in the operator table of the highest priority key in the
  // keyset. Note that we could in theory implement this using
  // highestPriorityTypeId(), but this code is on a very hot path and we can do
  // it faster without it.
467 | int getDispatchTableIndexForDispatchKeySet() const { |
468 | auto functionality_idx = |
469 | DispatchKeySet(repr_ >> num_backends).indexOfHighestBit(); |
470 | auto offset_and_mask = offsetsAndMasks()[functionality_idx]; |
471 | // Mask the functionality bits out first, then right-shift by 1. |
472 | // right-shifting by 1 because everything is zero-indexed. |
473 | // E.g. 000001 (CPU) should give us an offset of 0, 000010 (CUDA) should |
474 | // give us an offset of 1, etc. |
475 | auto backend_idx = |
476 | DispatchKeySet((repr_ & offset_and_mask.mask) >> 1).indexOfHighestBit(); |
477 | return offset_and_mask.offset + backend_idx; |
478 | } |
479 | #endif |
480 | |
  // returns the "index" of the highest priority backend in the keyset.
  // This is pretty similar to highestBackendKey(), but:
  // - It's hotpath code (part of the runtime bitset calculation)
  // - It returns an integer index, not an enum value
  // - Everything is shifted to the right by 1.
  //   BackendComponent::InvalidBit is technically the lowest enum value,
  //   but it isn't included in the runtime table. So CPUBit = 1, CUDABit = 2,
  //   etc.
489 | uint64_t getBackendIndex() const { |
490 | return DispatchKeySet((repr_ & full_backend_mask) >> 1).indexOfHighestBit(); |
491 | } |
492 | |
493 | private: |
494 | constexpr DispatchKeySet(uint64_t repr) : repr_(repr) {} |
495 | uint64_t repr_ = 0; |
496 | |
497 | public: |
498 | // STL iterator for DispatchKeySet. Iterates through all runtime DispatchKeys |
499 | // in the set. The iterator is only invalidated by the destruction of the |
500 | // underlying DispatchKeySet as the iterator stores a pointer to the raw |
501 | // representation of the DispatchKeySet. Note: When we encounter a per-backend |
502 | // functionality (e.g. Dense or Sparse), we will iterate through EVERY backend |
503 | // in the keyset, for that functionality. For example, if the next |
504 | // functionality key to iterate over is Autograd, and the backend bits in the |
505 | // keyset correspond to [BackendComponent::CPUBit, BackendComponent::CUDABit], |
506 | // then the next two keys we return will be DispatchKey::AutogradCPU, |
507 | // DispatchKey::AutogradCUDA (CPU first because it has lower precedence than |
508 | // CUDA in DispatchKey.h). |
509 | class iterator { |
510 | public: |
511 | using self_type = iterator; |
512 | using iterator_category = std::input_iterator_tag; |
513 | using value_type = DispatchKey; |
514 | using difference_type = ptrdiff_t; |
515 | using reference = value_type&; |
516 | using pointer = value_type*; |
517 | // final mask value should mask out the entire keyset |
518 | static const uint8_t end_iter_mask_val = |
519 | num_backends + num_functionality_keys; |
520 | // final key value should be the last DispatchKey |
521 | static const uint8_t end_iter_key_val = num_functionality_keys; |
522 | |
523 | // current_dispatchkey_idx_ will iterate through all functionality bits. |
524 | // current_backendcomponent_idx_ will iterate through all backend bits. |
525 | explicit iterator( |
526 | const uint64_t* data_ptr, |
527 | uint8_t next_functionality = num_backends, |
528 | uint8_t next_backend = 0) |
529 | : data_ptr_(data_ptr), |
530 | next_functionality_(next_functionality), |
531 | next_backend_(next_backend), |
532 | // These are in an invalid state at construction time, and set by the |
533 | // first increment call |
534 | current_dispatchkey_idx_(end_iter_key_val), |
535 | current_backendcomponent_idx_(end_iter_key_val) { |
536 | // Go to the first key in the set |
537 | TORCH_INTERNAL_ASSERT( |
          next_functionality_ >= num_backends,
          "num_backends=",
          static_cast<uint32_t>(num_backends),
          "next_functionality_=",
          static_cast<uint32_t>(next_functionality_));
543 | ++(*this); |
544 | } |
545 | |
546 | C10_API self_type& operator++(); |
547 | |
548 | self_type operator++(int) { |
549 | self_type previous_iterator = *this; |
550 | ++(*this); |
551 | return previous_iterator; |
552 | } |
553 | |
554 | bool operator==(const self_type& rhs) const { |
555 | return next_functionality_ == rhs.next_functionality_ && |
556 | current_dispatchkey_idx_ == rhs.current_dispatchkey_idx_ && |
557 | next_backend_ == rhs.next_backend_ && |
558 | current_backendcomponent_idx_ == rhs.current_backendcomponent_idx_; |
559 | } |
560 | bool operator!=(const self_type& rhs) const { |
561 | return next_functionality_ != rhs.next_functionality_ || |
562 | current_dispatchkey_idx_ != rhs.current_dispatchkey_idx_ || |
563 | next_backend_ != rhs.next_backend_ || |
564 | current_backendcomponent_idx_ != rhs.current_backendcomponent_idx_; |
565 | } |
566 | DispatchKey operator*() const { |
567 | auto functionality_key = |
568 | static_cast<DispatchKey>(current_dispatchkey_idx_); |
569 | if (isPerBackendFunctionalityKey(functionality_key)) { |
570 | auto next_key = toRuntimePerBackendFunctionalityKey( |
571 | functionality_key, |
572 | static_cast<BackendComponent>(current_backendcomponent_idx_)); |
573 | // We expect all of the Dense, Sparse, Quantized, and Autograd keys to |
574 | // be ordered the same way with respect to their backends |
        TORCH_INTERNAL_ASSERT(
            toBackendComponent(next_key) ==
                static_cast<BackendComponent>(current_backendcomponent_idx_),
            "Tried to map functionality key ",
            toString(functionality_key),
            " and backend bit ",
            toString(
                static_cast<BackendComponent>(current_backendcomponent_idx_)),
            " to a runtime key, but ended up with ",
            toString(next_key),
            ". This can happen if the order of the backend dispatch keys in DispatchKey.h isn't consistent.",
            " Please double check that enum for inconsistencies.");
587 | return next_key; |
588 | } else { |
589 | return functionality_key; |
590 | } |
591 | } |
592 | |
593 | private: |
594 | const uint64_t* data_ptr_; |
595 | uint8_t next_functionality_; |
596 | uint8_t next_backend_; |
597 | uint8_t current_dispatchkey_idx_; |
598 | uint8_t current_backendcomponent_idx_; |
599 | }; |
600 | |
601 | public: |
  // Returns an iterator to the first key in the set. If no keys are in the
  // set, then it will return the end iterator.
604 | iterator begin() const { |
605 | return iterator(&repr_); |
606 | } |
607 | |
608 | // We do not need to iterate beyond EndOfFunctionalityKeys so we will treat |
609 | // this as the end iterator. |
610 | iterator end() const { |
611 | return iterator(&repr_, iterator::end_iter_mask_val); |
612 | } |
613 | }; |
614 | |
615 | C10_API std::string toString(DispatchKeySet); |
616 | C10_API std::ostream& operator<<(std::ostream&, DispatchKeySet); |
617 | |
618 | C10_API inline int getDispatchTableIndexForDispatchKey(DispatchKey k) { |
619 | return DispatchKeySet(k).getDispatchTableIndexForDispatchKeySet(); |
620 | } |
621 | |
622 | // Alias key DispatchKey::Autograd maps to |
623 | // (autograd_dispatch_keyset x full_backend_mask) |
624 | // NB: keys in this set also get associated with CompositeImplicitAutograd |
625 | // |
626 | // Note [autograd_dispatch_keyset Does Not Include Backend Bits] |
627 | // We don't want to include any backend bits (BackendComponent::CPUBit, etc) |
628 | // directly in autograd_dispatch_keyset. |
629 | // Why? keysets like autograd_dispatch_keyset are commonly used to remove |
630 | // autograd keys from a DispatchKeySet throughout the code base. However, you |
631 | // are only allowed to remove functionality bits from a keyset, not backend |
632 | // bits. See Note [Removing keys from DispatchKeySet Only Affects Functionality |
633 | // Keys] for details. To be consistent and avoid confusion, we're explicitly |
634 | // setting up autograd_dispatch_keyset to not have any backend bits. |
635 | constexpr DispatchKeySet autograd_dispatch_keyset = DispatchKeySet({ |
636 | DispatchKey::AutogradFunctionality, |
637 | DispatchKey::AutogradOther, |
638 | DispatchKey::AutogradNestedTensor, |
639 | }); |
640 | |
641 | constexpr DispatchKeySet autocast_dispatch_keyset = DispatchKeySet({ |
642 | DispatchKey::AutocastCPU, |
643 | DispatchKey::AutocastCUDA, |
644 | DispatchKey::AutocastXPU, |
645 | DispatchKey::AutocastHPU, |
646 | }); |
647 | |
648 | // See Note [TLS Initialization] |
649 | constexpr DispatchKeySet default_included_set = DispatchKeySet({ |
650 | DispatchKey::BackendSelect, |
651 | DispatchKey::ADInplaceOrView, |
652 | }); |
653 | |
654 | constexpr DispatchKeySet default_excluded_set = DispatchKeySet({ |
655 | DispatchKey::AutocastCPU, |
656 | DispatchKey::AutocastCUDA, |
657 | DispatchKey::AutocastXPU, |
658 | DispatchKey::AutocastHPU, |
659 | }); |
660 | |
661 | constexpr DispatchKeySet autograd_dispatch_keyset_with_ADInplaceOrView = |
662 | autograd_dispatch_keyset | DispatchKeySet(DispatchKey::ADInplaceOrView); |
663 | |
664 | constexpr DispatchKeySet python_ks = DispatchKeySet({ |
665 | DispatchKey::Python, |
666 | DispatchKey::PythonTLSSnapshot, |
667 | }); |
668 | |
669 | constexpr DispatchKeySet sparse_ks = DispatchKeySet(DispatchKey::Sparse); |
670 | |
671 | constexpr DispatchKeySet sparse_csr_ks = |
672 | DispatchKeySet({DispatchKey::SparseCsrCPU, DispatchKey::SparseCsrCUDA}); |
673 | |
674 | constexpr DispatchKeySet mkldnn_ks = DispatchKeySet(DispatchKey::MkldnnCPU); |
675 | |
676 | // backend dispatch keys that map to DispatchKey::AutogradOther |
677 | // NB: keys in this set also get associated with CompositeImplicitAutograd |
678 | constexpr DispatchKeySet autogradother_backends = |
679 | DispatchKeySet( |
680 | // HIP and VE aren't in this list: they now have their own backend bits |
681 | // which means that they can now have their own Autograd keys. |
682 | // Technically, HIP will now redispatch to its own custom AutogradHIP |
683 | // slot in the runtime table. |
684 | {DispatchKey::FPGA, |
685 | DispatchKey::ORT, |
686 | DispatchKey::Vulkan, |
687 | DispatchKey::Metal, |
688 | DispatchKey::SparseCsrCPU, |
689 | DispatchKey::SparseCsrCUDA, |
690 | DispatchKey::CustomRNGKeyId, |
691 | DispatchKey::MkldnnCPU, |
692 | // Sparse and Quantized backends also live here. |
693 | DispatchKey::Sparse, |
694 | DispatchKey::Quantized}) |
695 | // Including the backend bits because this keyset is used during op |
696 | // registration, which requires looping over all runtime autogradother |
697 | // backend keys. |
698 | | DispatchKeySet(DispatchKeySet::RAW, full_backend_mask); |
699 | |
700 | // The set of dispatch keys that come after autograd |
701 | // n.b. this relies on the fact that AutogradOther is currently the lowest |
702 | // Autograd key |
703 | constexpr DispatchKeySet after_autograd_keyset = |
704 | DispatchKeySet(DispatchKeySet::FULL_AFTER, c10::DispatchKey::AutogradOther); |
705 | |
706 | // The set of dispatch keys that come after ADInplaceOrView |
707 | constexpr DispatchKeySet after_ADInplaceOrView_keyset = DispatchKeySet( |
708 | DispatchKeySet::FULL_AFTER, |
709 | c10::DispatchKey::ADInplaceOrView); |
710 | |
711 | // The set of dispatch keys that come after Functionalize |
712 | constexpr DispatchKeySet after_func_keyset = |
713 | DispatchKeySet(DispatchKeySet::FULL_AFTER, c10::DispatchKey::Functionalize) |
714 | .remove( |
715 | // NOTE: we also need to remove ADInplaceOrView from the keyset when |
716 | // redispatching after the func kernels. This is because we're not |
717 | // calling the same op; we originally called an inplace op, and now |
718 | // we aren't. The original key calculation figured out which keys |
719 | // were Fallthrough based on the inplace op. That means that it did |
720 | // not include the ADInPlaceOrView kernel as a fallthrough key. |
721 | // However, we WANT the ADInPlaceOrView kernel to be ignored now |
722 | // that we're calling an out-of-place op. Re-invoking |
723 | // Dispatcher::call would re-run the Fallthrough key calculation and |
            // get us that, but at::redispatch is more performant. We can get
725 | // away with it by explicitly removing the key here. |
726 | c10::DispatchKey::ADInplaceOrView); |
727 | |
728 | constexpr DispatchKeySet backend_bitset_mask = |
729 | DispatchKeySet(DispatchKeySet::RAW, (1ULL << num_backends) - 1); |
730 | |
731 | constexpr auto inplace_or_view_ks = |
732 | DispatchKeySet(DispatchKey::ADInplaceOrView); |
733 | constexpr auto autograd_cpu_ks = DispatchKeySet(DispatchKey::AutogradCPU); |
734 | constexpr auto autograd_ipu_ks = DispatchKeySet(DispatchKey::AutogradIPU); |
735 | constexpr auto autograd_xpu_ks = DispatchKeySet(DispatchKey::AutogradXPU); |
736 | constexpr auto autograd_cuda_ks = DispatchKeySet(DispatchKey::AutogradCUDA); |
737 | constexpr auto autograd_xla_ks = DispatchKeySet(DispatchKey::AutogradXLA); |
738 | constexpr auto autograd_lazy_ks = DispatchKeySet(DispatchKey::AutogradLazy); |
739 | constexpr auto autograd_meta_ks = DispatchKeySet(DispatchKey::AutogradMeta); |
740 | constexpr auto autograd_mps_ks = DispatchKeySet(DispatchKey::AutogradMPS); |
741 | constexpr auto autograd_hpu_ks = DispatchKeySet(DispatchKey::AutogradHPU); |
742 | constexpr auto autograd_privateuse1_ks = |
743 | DispatchKeySet(DispatchKey::AutogradPrivateUse1); |
744 | constexpr auto autograd_privateuse2_ks = |
745 | DispatchKeySet(DispatchKey::AutogradPrivateUse2); |
746 | constexpr auto autograd_privateuse3_ks = |
747 | DispatchKeySet(DispatchKey::AutogradPrivateUse3); |
748 | constexpr auto autograd_other_ks = DispatchKeySet(DispatchKey::AutogradOther); |
749 | constexpr auto autograd_nested = |
750 | DispatchKeySet(DispatchKey::AutogradNestedTensor); |
// keyset corresponding to functorch keys that have their own dedicated
752 | // TensorImpl subclass. |
753 | constexpr auto functorch_transforms_ks = DispatchKeySet( |
754 | {DispatchKey::FuncTorchBatched, |
755 | DispatchKey::FuncTorchVmapMode, |
756 | DispatchKey::Batched, |
757 | DispatchKey::VmapMode, |
758 | DispatchKey::FuncTorchGradWrapper}); |
759 | |
760 | constexpr auto functorch_batched_ks = |
761 | DispatchKeySet({DispatchKey::FuncTorchBatched}); |
762 | |
// This keyset has:
// (1) the functionality bits corresponding to backends (dense, sparse,
//     quantized)
// (2) all of the backend bits set
766 | constexpr DispatchKeySet backend_functionality_keys = |
767 | DispatchKeySet({ |
768 | DispatchKey::Dense, |
769 | DispatchKey::Quantized, |
770 | DispatchKey::Sparse, |
771 | }) | |
772 | DispatchKeySet(DispatchKeySet::RAW, full_backend_mask); |
773 | |
774 | struct OpTableOffsetAndMask { |
775 | uint16_t offset; |
776 | uint16_t backend_mask; |
777 | }; |
778 | |
static_assert(
    num_backends <= 16,
    "Right now we expect the number of backends not to exceed 16. In the (unlikely) event"
    " that this changes, the size of OpTableOffsetAndMask::backend_mask needs to be increased too.");
783 | |
784 | // true if t is a backend dispatch key |
785 | C10_API bool isBackendDispatchKey(DispatchKey t); |
786 | |
787 | // Resolve alias dispatch key to DispatchKeySet if applicable |
788 | C10_API DispatchKeySet getRuntimeDispatchKeySet(DispatchKey t); |
789 | |
790 | // Resolve alias dispatch key to DispatchKeySet if applicable, |
// and check if k is a part of that set
792 | C10_API bool runtimeDispatchKeySetHas(DispatchKey t, DispatchKey k); |
793 | |
// Returns a DispatchKeySet of all backend keys mapped to Autograd dispatch key
// t. The DispatchKeySet is empty if t is not an alias of DispatchKey::Autograd.
796 | C10_API DispatchKeySet getBackendKeySetFromAutograd(DispatchKey t); |
797 | |
// Returns a DispatchKeySet of autograd-related keys mapped to backend.
// For a given backend key, use the associated autograd key.
// For non-backend keys, use AutogradOther as a default.
801 | // Note: it's convenient and fast to return a default here rather than (say) |
802 | // returning an optional<DispatchKey>, or throwing. But it makes callers |
803 | // responsible for either a) enforcing the invariant that only backend keys |
804 | // be passed as arguments, or b) interpreting our return value carefully. |
805 | inline DispatchKeySet getAutogradRelatedKeySetFromBackend(BackendComponent t) { |
806 | switch (t) { |
807 | case BackendComponent::CPUBit: |
808 | return inplace_or_view_ks | autograd_cpu_ks; |
809 | case BackendComponent::IPUBit: |
810 | return inplace_or_view_ks | autograd_ipu_ks; |
811 | case BackendComponent::XPUBit: |
812 | return inplace_or_view_ks | autograd_xpu_ks; |
813 | case BackendComponent::CUDABit: |
814 | return inplace_or_view_ks | autograd_cuda_ks; |
815 | case BackendComponent::XLABit: |
816 | return inplace_or_view_ks | autograd_xla_ks; |
817 | case BackendComponent::LazyBit: |
818 | return inplace_or_view_ks | autograd_lazy_ks; |
819 | case BackendComponent::MetaBit: |
820 | return inplace_or_view_ks | autograd_meta_ks; |
821 | case BackendComponent::MPSBit: |
822 | return inplace_or_view_ks | autograd_mps_ks; |
823 | case BackendComponent::HPUBit: |
824 | return inplace_or_view_ks | autograd_hpu_ks; |
825 | case BackendComponent::PrivateUse1Bit: |
826 | return inplace_or_view_ks | autograd_privateuse1_ks; |
827 | case BackendComponent::PrivateUse2Bit: |
828 | return inplace_or_view_ks | autograd_privateuse2_ks; |
829 | case BackendComponent::PrivateUse3Bit: |
830 | return inplace_or_view_ks | autograd_privateuse3_ks; |
831 | default: |
832 | return inplace_or_view_ks | autograd_other_ks; |
833 | } |
834 | } |
835 | |
836 | // Returns a DispatchKeySet of autocast related keys mapped to backend. |
837 | inline DispatchKeySet getAutocastRelatedKeySetFromBackend(BackendComponent t) { |
838 | constexpr auto autocast_cpu_ks = DispatchKeySet(DispatchKey::AutocastCPU); |
839 | constexpr auto autocast_xpu_ks = DispatchKeySet(DispatchKey::AutocastXPU); |
840 | constexpr auto autocast_hpu_ks = DispatchKeySet(DispatchKey::AutocastHPU); |
841 | constexpr auto autocast_cuda_ks = DispatchKeySet(DispatchKey::AutocastCUDA); |
842 | switch (t) { |
843 | case BackendComponent::CPUBit: |
844 | return autocast_cpu_ks; |
845 | case BackendComponent::XPUBit: |
846 | return autocast_xpu_ks; |
847 | case BackendComponent::HPUBit: |
848 | return autocast_hpu_ks; |
849 | case BackendComponent::CUDABit: |
850 | case BackendComponent::XLABit: |
851 | return autocast_cuda_ks; |
852 | default: |
853 | return DispatchKeySet(); |
854 | } |
855 | } |
856 | |
857 | // returns the "backend" DispatchKey of highest priority in the set. |
858 | // This is basically like highestBackendKey(), except that we have some |
859 | // "functionality" bits that correspond to backends (Sparse, Quantized) |
860 | inline DispatchKey highestPriorityBackendTypeId(DispatchKeySet ks) { |
861 | return (ks & backend_functionality_keys).highestPriorityTypeId(); |
862 | } |
863 | |
864 | // This API exists because we have a use case for checking |
865 | // getRuntimeDispatchKeySet(alias).has(DispatchKey::Undefined) |
866 | // in OperatorEntry.cpp but we disallow it in has() API. |
867 | C10_API bool isIncludedInAlias(DispatchKey k, DispatchKey alias); |
868 | |
869 | // Historically, every tensor only had a single DispatchKey, and it was always |
870 | // something like CPU, and there wasn't any of this business where TLS |
871 | // could cause the DispatchKey of a tensor to change. But we still have some |
872 | // legacy code that is still using DispatchKey for things like instanceof |
873 | // checks; if at all possible, refactor the code to stop using DispatchKey in |
874 | // those cases. |
static inline DispatchKey legacyExtractDispatchKey(DispatchKeySet s) {
876 | // NB: If you add any extra keys that can be stored in TensorImpl on |
877 | // top of existing "backend" keys like CPU/CUDA, you need to add it |
878 | // here. At the moment, autograd keys and ADInplaceOrView key need this |
  // treatment.
880 | return (s - autograd_dispatch_keyset_with_ADInplaceOrView - |
881 | autocast_dispatch_keyset - |
882 | DispatchKeySet( |
883 | {DispatchKey::Functionalize, |
884 | DispatchKey::PythonTLSSnapshot, |
885 | DispatchKey::Python})) |
886 | .highestPriorityTypeId(); |
887 | } |
888 | |
889 | template <class T> |
890 | using is_not_DispatchKeySet = guts::negation<std::is_same<DispatchKeySet, T>>; |
891 | |
892 | // Given a function type, constructs a function_traits type that drops the first |
893 | // parameter type if the first parameter is of type DispatchKeySet. NB: |
894 | // DispatchKeySet is currently explicitly hidden from JIT (mainly to avoid |
895 | // pushing unnecessary arguments on the stack - see Note [ Plumbing Keys Through |
896 | // the Dispatcher] for details). If at any point in the future we need to expose |
897 | // this type to JIT, revisit the usage of this type alias. |
898 | template <class FuncType> |
899 | using remove_DispatchKeySet_arg_from_func = guts::make_function_traits_t< |
900 | typename guts::infer_function_traits_t<FuncType>::return_type, |
901 | typename std::conditional_t< |
902 | std::is_same< |
903 | DispatchKeySet, |
904 | typename guts::typelist::head_with_default_t< |
905 | void, |
906 | typename guts::infer_function_traits_t< |
907 | FuncType>::parameter_types>>::value, |
908 | guts::typelist::drop_if_nonempty_t< |
909 | typename guts::infer_function_traits_t<FuncType>::parameter_types, |
910 | 1>, |
911 | typename guts::infer_function_traits_t<FuncType>::parameter_types>>; |
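
// For illustration, a hypothetical sketch of the transformation (at::Tensor is
// just a stand-in argument type):
//   using Traits = remove_DispatchKeySet_arg_from_func<
//       at::Tensor(DispatchKeySet, const at::Tensor&)>;
//   // Traits::func_type == at::Tensor(const at::Tensor&): the leading
//   // DispatchKeySet parameter is dropped. A function type whose first
//   // parameter is not DispatchKeySet is left unchanged.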
912 | } // namespace c10 |
913 | |