#pragma once
#include <c10/core/DispatchKey.h>
#include <c10/util/Exception.h>
#include <c10/util/Metaprogramming.h>
#include <c10/util/llvmMathExtras.h>
#include <ostream>

namespace c10 {

struct FunctionalityOffsetAndMask {
  // empty constructor shouldn't be used; only needed to initialize
  // the array before populating it.
  FunctionalityOffsetAndMask() = default;
  FunctionalityOffsetAndMask(uint16_t offset, uint16_t mask)
      : offset(offset), mask(mask) {}
  // This needs to be big enough to cover the size of the operator table.
  uint16_t offset{};
  // See Note [No More Than 16 Backends]
  // This mask needs to be big enough to mask all of the backend bits.
  // We probably don't ever want to have more than 16 backend bits, so uint16_t
  // should be enough.
  uint16_t mask{};
};
static_assert(
    c10::num_runtime_entries < 65536,
    "The dispatcher currently only supports up to 2^16 runtime entries");

C10_API std::array<FunctionalityOffsetAndMask, num_functionality_keys>
initializeFunctionalityOffsetsAndMasks();

C10_ALWAYS_INLINE static const std::
    array<FunctionalityOffsetAndMask, num_functionality_keys>&
    offsetsAndMasks() {
  static auto offsets_and_masks_ = initializeFunctionalityOffsetsAndMasks();
  return offsets_and_masks_;
}

// A representation of a set of DispatchKeys. A DispatchKeySet contains both
// "functionality" bits and "backend" bits, and every tensor holds its own
// DispatchKeySet. The Dispatcher implements multiple dispatch by grabbing the
// keyset on every input tensor, or'ing them together, and dispatching to a
// specific piece of functionality. The functionality bits are *ordered*. When
// multiple functionality bits are set, we use the highest priority
// functionality. Similarly, multiple backend bits can theoretically be set if
// you call an operator with multiple tensors from different devices (e.g. CPU
// and CUDA), although support for mixed device dispatch is limited (the only
// kernels that gracefully handle mixed device inputs for now are cuda kernels
// that take in a scalar cpu tensor).

// A representation of a set of DispatchKeys. A tensor may have multiple
// tensor type ids, e.g., a Variable tensor can also be a CPU tensor; the
// DispatchKeySet specifies what type ids apply. The internal representation is
// as a 64-bit bit set (this means only 64 tensor type ids are supported).
//
// As mentioned above, DispatchKeys are ordered; thus, we can ask questions like
// "what is the highest priority DispatchKey in the set"? (The set itself is
// not ordered; two sets with the same ids will always have the ids ordered in
// the same way.)
//
// Note [DispatchKeySet Internal Representation]
// Internally, dispatch keys are packed into 64-bit DispatchKeySet objects
// that get passed around at runtime.
// However, there isn't necessarily a 1-to-1 mapping between bits in the keyset
// and individual dispatch keys.
//
// First: why do we have this distinction, and why not map every dispatch key
// directly to a bit? This is mostly because we have several types of
// functionalities that different backends would like to customize. For example,
// we have:
// - "Dense": CPU, CUDA, XLA, ... (~12 keys)
// - "Sparse": SparseCPU, SparseCUDA, ...
// - "Quantized": QuantizedCPU, QuantizedCUDA, QuantizedXLA, ...
// - "Autograd": AutogradCPU, AutogradCUDA, AutogradXLA, ...
// The problem is that the total number of keys grows quadratically with [#
// backends] x [# functionalities], making it very difficult to map each key
// directly to a bit in a bitset without dramatically increasing the size of the
// bitset over time.
//
// The two enums (BackendComponent and DispatchKey) can be divided roughly into
// the following categories.
//
// (1) "Building block" keys
//    (a) backends: everything in the BackendComponent enum (e.g. CPUBit,
//        CUDABit)
//    (b) functionalities: (per-backend) functionality-bit DispatchKeys
//        (e.g. AutogradFunctionality, Sparse, Dense)
// (2) "Runtime" keys
//    (a) "non-customizable backends" (e.g. FPGA)
//    (b) "non-customizable functionalities" (e.g. Functionalize)
//    (c) "per-backend instances of customizable functionalities" (e.g. CPU,
//        SparseCPU, AutogradCPU)
// (3) "Alias" DispatchKeys (see Note [Alias Dispatch Keys])
//
// (1) Building block keys always correspond to individual bits in a
// DispatchKeySet. They can also be combined in a DispatchKeySet to form actual
// runtime keys. e.g.
//     auto dense_cpu_ks = DispatchKeySet({DispatchKey::CPUBit,
//     DispatchKey::Dense});
//     // The keyset has the runtime dense-cpu key.
//     dense_cpu_ks.has(DispatchKey::CPU);
//     // And it contains the building block keys too.
//     dense_cpu_ks.has(DispatchKey::CPUBit);
//     dense_cpu_ks.has(DispatchKey::Dense);
//
// Not every backend and not every functionality counts as a "building block
// key". This is mostly to give us more levers to pull in the design space.
// Backend keys and functionality keys that count as "building blocks" will
// contribute to a full cross product of functionality that can be overridden.
//
// For example, right now we have at least 12 "backend" building blocks (CPU,
// CUDA, XLA, ...) and at least 4 "functionality" building blocks (Dense,
// Sparse, Quantized, AutogradFunctionality, ...). These keys together allow
// every dispatcher operator to be customized in up to 12*4 different ways. Each
// of those requires a slot in the operator table of every dispatcher operator.
// Not every piece of functionality necessarily needs to be customizable
// per-backend, and not every backend necessarily needs to be able to customize
// every type of functionality.
//
//
// (2) Every runtime key corresponds directly to a slot in an operator's runtime
// dispatch table, and you can directly register kernels to a runtime dispatch
// key.
//
// For per-backend functionalities like "Dense" or "AutogradFunctionality",
// you can think of the corresponding runtime dispatch keys as "instances" of
// that functionality, per backend. E.g. "CPU", "CUDA", "XLA", etc. are all
// runtime instances of the "Dense" building block key.
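//
// A minimal sketch of that decomposition (using the toFunctionalityKey /
// toBackendComponent helpers referenced throughout this file):
//     toFunctionalityKey(DispatchKey::CUDA) == DispatchKey::Dense
//     toBackendComponent(DispatchKey::CUDA) == BackendComponent::CUDABit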

// (2a) and (2b) are represented identically in the DispatchKeySet logic:
// - backend-agnostic functionalities (e.g. FuncTorchBatched) are NOT
//   customizable per backend.
//   In order to do so, we'd need to promote it to a per-backend functionality
//   "building block" key.
// - non-customizable backends (e.g. FPGA) can NOT customize existing
//   functionality like Sparse, Autograd, etc.
//   In order to do so, we'd need to promote it to a backend "building block"
//   key.
//
// In both cases, these keys directly correspond to runtime slots in the
// operator table.
//
//
// (3) "Alias" keys
// See Note [Alias Dispatch Keys]
//
// Final note: for anyone making future changes to the Dispatcher +
// DispatchKeySet internals, there's a closed PR with a basic
// python-implementation of the Dispatcher that might be useful in quickly
// testing out and validating changes. See it at
// https://github.com/pytorch/pytorch/pull/68743

// An undefined tensor is one with an empty tensor type set.
class DispatchKeySet final {
 public:
  enum Full { FULL };
  enum FullAfter { FULL_AFTER };
  enum Raw { RAW };

  // NB: default constructor representation as zero is MANDATORY as
  // use of DispatchKeySet in TLS requires this.
  constexpr DispatchKeySet() = default;

  constexpr DispatchKeySet(Full)
      : repr_((1ULL << (num_backends + num_functionality_keys - 1)) - 1) {}

  constexpr DispatchKeySet(FullAfter, DispatchKey t)
      // LSB after t are OK, but not t itself.
      // "functionalities" have a notion of ordering (e.g. Autograd > Sparse >
      // Quantized > Dense). But backends don't really have an ordering.
      // Therefore, we're enforcing that FullAfter can only be used on
      // "functionality" keys.
      : repr_(
            (1ULL
             << (num_backends + static_cast<uint8_t>(toFunctionalityKey(t)) -
                 1)) -
            1) {
    *this = add(DispatchKey::PythonDispatcher);
  }
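
  // e.g. (a sketch of FULL_AFTER; see after_autograd_keyset below):
  //   DispatchKeySet(FULL_AFTER, DispatchKey::AutogradOther) contains every
  //   functionality bit strictly below AutogradOther plus all of the backend
  //   bits (they occupy the lowest bits of the repr), and the constructor
  //   then adds PythonDispatcher back explicitly.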

  // Public version of DispatchKeySet(uint64_t) API; external users
  // must be explicit when they do this!
  constexpr DispatchKeySet(Raw, uint64_t x) : repr_(x) {}

  constexpr explicit DispatchKeySet(BackendComponent k) {
    if (k == BackendComponent::InvalidBit) {
      repr_ = 0;
    } else {
      repr_ = 1ULL << (static_cast<uint8_t>(k) - 1);
    }
  }

  constexpr explicit DispatchKeySet(DispatchKey k) {
    if (k == DispatchKey::Undefined) {
      // Case 1: handle Undefined specifically
      repr_ = 0;
    } else if (k <= DispatchKey::EndOfFunctionalityKeys) {
      // Case 2: handle "functionality-only" keys
      // These keys have a functionality bit set, but no backend bits.
      // These can technically be either:
      // - valid runtime keys (e.g. DispatchKey::AutogradOther,
      //   DispatchKey::FuncTorchBatched, etc)
      // - "building block" keys that aren't actual runtime keys (e.g.
      //   DispatchKey::Dense or Sparse)
      uint64_t functionality_val = 1ULL
          << (num_backends + static_cast<uint8_t>(k) - 1);
      repr_ = functionality_val;
    } else if (k <= DispatchKey::EndOfRuntimeBackendKeys) {
      // Case 3: "runtime" keys that have a functionality bit AND a backend
      // bit. These are the runtime instances of "per-backend functionality"
      // keys. For example, given DispatchKey::CPU, we should set:
      // - the Dense functionality bit
      // - the CPUBit backend bit
      // First compute which bit to flip for the functionality.
      auto functionality_k = toFunctionalityKey(k);
      // The - 1 is because Undefined is technically a "functionality" that
      // doesn't show up in the bitset. So e.g. Dense is technically the second
      // functionality, but the lowest functionality bit.
      uint64_t functionality_val = 1ULL
          << (num_backends + static_cast<uint8_t>(functionality_k) - 1);

      // Then compute which bit to flip for the backend.
      auto backend_k = toBackendComponent(k);
      uint64_t backend_val = backend_k == BackendComponent::InvalidBit
          ? 0
          : 1ULL << (static_cast<uint8_t>(backend_k) - 1);
      repr_ = functionality_val + backend_val;
    } else {
      // At this point, we should have covered every case except for alias
      // keys. Technically it would be possible to add alias dispatch keys to
      // a DispatchKeySet, but the semantics are a little confusing and this
      // currently isn't needed anywhere.
      repr_ = 0;
    }
  }
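
  // Resulting bit-layout sketch (backend bits are the low bits, functionality
  // bits sit above them):
  //   DispatchKeySet(DispatchKey::Undefined).repr_ == 0
  //   DispatchKeySet(DispatchKey::CPU).repr_ ==
  //       (1ULL << (num_backends + static_cast<uint8_t>(DispatchKey::Dense) - 1))
  //     | (1ULL << (static_cast<uint8_t>(BackendComponent::CPUBit) - 1))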

  constexpr uint64_t keys_to_repr(std::initializer_list<DispatchKey> ks) {
    uint64_t repr = 0;
    for (auto k : ks) {
      repr |= DispatchKeySet(k).repr_;
    }
    return repr;
  }

  constexpr uint64_t backend_bits_to_repr(
      std::initializer_list<BackendComponent> ks) {
    uint64_t repr = 0;
    for (auto k : ks) {
      repr |= DispatchKeySet(k).repr_;
    }
    return repr;
  }

  explicit constexpr DispatchKeySet(std::initializer_list<DispatchKey> ks)
      : repr_(keys_to_repr(ks)) {}

  explicit constexpr DispatchKeySet(std::initializer_list<BackendComponent> ks)
      // Note: for some reason, putting this logic directly in the constructor
      // appears to fail to compile on CUDA 10.1.
      // See an example internal failure at
      // https://www.internalfb.com/intern/skycastle/run/76561193669136035/artifact/actionlog.76561193742069401.stderr
      : repr_(backend_bits_to_repr(ks)) {}

  // Test if a DispatchKey is in the set
  inline bool has(DispatchKey t) const {
    TORCH_INTERNAL_ASSERT_DEBUG_ONLY(t != DispatchKey::Undefined);
    return has_all(DispatchKeySet(t));
  }
  constexpr bool has_backend(BackendComponent t) const {
    return has_all(DispatchKeySet(t));
  }

  // Given a DispatchKeySet of functionality keys and (potentially) backend
  // keys, tests if all of them are in the current set.
  constexpr bool has_all(DispatchKeySet ks) const {
    return static_cast<bool>((repr_ & ks.repr_) == ks.repr_);
  }

  // Given a DispatchKeySet of functionality keys and (potentially) backend
  // keys, tests if any of them are in the current set. This could technically
  // be pretty easily implemented using has(). It is strictly a perf
  // optimization though. There are many places in the code base where we want
  // to test for multiple functionality keys together. HOWEVER, runtime
  // per-backend functionality keys aren't allowed to be used with this
  // function, because you can end up with weird results. e.g.
  // DispatchKeySet(DispatchKey::AutogradCPU).has_any(DispatchKeySet(DispatchKey::CPU))
  // would return true.
  inline bool has_any(DispatchKeySet ks) const {
    TORCH_INTERNAL_ASSERT_DEBUG_ONLY(
        // Either there are no backend bits in the input keyset
        ((ks.repr_ & full_backend_mask) == 0) ||
        // or there are no per-backend-functionality bits
        // See [Note: Per-Backend Functionality Dispatch Keys]
        ((ks &
          DispatchKeySet({
              DispatchKey::Dense,
              DispatchKey::Quantized,
              DispatchKey::Sparse,
              DispatchKey::AutogradFunctionality,
          })
              .repr_) == 0));
    return static_cast<bool>((repr_ & ks.repr_) != 0);
  }
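
  // e.g. (sketch): has_any(autocast_dispatch_keyset) (defined later in this
  // file) is a valid use: the autocast keys are functionality-only bits with
  // no backend bits attached.
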
  // Test if DispatchKeySet is a superset of ks.
  bool isSupersetOf(DispatchKeySet ks) const {
    return (repr_ & ks.repr_) == ks.repr_;
  }
  // Perform set union
  constexpr DispatchKeySet operator|(DispatchKeySet other) const {
    return DispatchKeySet(repr_ | other.repr_);
  }
  // Perform set intersection
  constexpr DispatchKeySet operator&(DispatchKeySet other) const {
    return DispatchKeySet(repr_ & other.repr_);
  }
  // Compute the set difference self - other,
  // but ONLY for the functionality keys.
  // Any backend bits set on self will remain unchanged.
  // See Note [Removing keys from DispatchKeySet Only Affects Functionality
  // Keys]
  constexpr DispatchKeySet operator-(DispatchKeySet other) const {
    return DispatchKeySet(repr_ & (full_backend_mask | ~other.repr_));
  }

  // Compute self ^ other
  constexpr DispatchKeySet operator^(DispatchKeySet other) const {
    return DispatchKeySet(repr_ ^ other.repr_);
  }
  bool operator==(DispatchKeySet other) const {
    return repr_ == other.repr_;
  }
  bool operator!=(DispatchKeySet other) const {
    return repr_ != other.repr_;
  }
  // Add a DispatchKey to the DispatchKey set. Does NOT mutate,
  // returns the extended DispatchKeySet!
  C10_NODISCARD constexpr DispatchKeySet add(DispatchKey t) const {
    return *this | DispatchKeySet(t);
  }
  C10_NODISCARD constexpr DispatchKeySet add(DispatchKeySet ks) const {
    return *this | ks;
  }

  // Remove a DispatchKey from the DispatchKey set.
  // This is generally not an operation you should be doing
  // (it's used to implement the printing overload, operator<<)
  //
  // Note [Removing keys from DispatchKeySet Only Affects Functionality Keys]
  // For now, we're only allowing removal of "functionality bits" from the
  // keyset, which is specifically needed by the fallthrough key calculation
  // logic. Why is removing backend bits problematic? Consider this example:
  //
  //   DispatchKeySet([DispatchKey.CPU, DispatchKey.AutogradCUDA,
  //                   DispatchKey.CUDA]).remove(DispatchKey.AutogradCUDA)
  //   DispatchKeySet([DispatchKey.CPU,
  //                   DispatchKey.AutogradCUDA]).remove(DispatchKey.AutogradCUDA)
  //
  // What do we want to happen?
  // Technically, we'd like it to be true that after removal,
  // the first keyset still has the CUDA dispatch key while the second doesn't.
  // Unfortunately there's no way to represent that, because the two keysets are
  // represented the same way internally:
  //   functionality bits: Autograd, Dense
  //   backend bits: CPU, CUDA
  //
  // Instead, remove(DispatchKey.AutogradCPU) will only remove the "Autograd"
  // bit from the bitset.
  C10_NODISCARD constexpr DispatchKeySet remove(DispatchKey t) const {
    return DispatchKeySet(
        repr_ & ~(DispatchKeySet(t).repr_ & ~full_backend_mask));
  }
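
  // e.g. (a sketch of the semantics from the note above):
  //   auto ks = DispatchKeySet({DispatchKey::AutogradCUDA, DispatchKey::CUDA});
  //   ks = ks.remove(DispatchKey::AutogradCUDA);
  //   ks.has(DispatchKey::CUDA);         // true: the CUDA backend bit survives
  //   ks.has(DispatchKey::AutogradCUDA); // false: the Autograd bit is gone
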
  // You're allowed to remove a backend bit from a DispatchKeySet,
  // but you have to be explicit about it (remove_backend() instead of
  // remove()).
  constexpr DispatchKeySet remove_backend(BackendComponent b) const {
    return DispatchKeySet(repr_ & ~(DispatchKeySet(b).repr_));
  }
  // Is the set empty? (AKA undefined tensor)
  bool empty() const {
    return repr_ == 0;
  }
  uint64_t raw_repr() {
    return repr_;
  }

  DispatchKey highestFunctionalityKey() const {
    auto functionality_idx = indexOfHighestBit();
    // This means that none of the functionality bits were set.
    if (functionality_idx < num_backends)
      return DispatchKey::Undefined;
    // The first num_backends bits in the keyset don't correspond to real
    // dispatch keys.
    return static_cast<DispatchKey>(functionality_idx - num_backends);
  }

  // This is similar to toBackendComponent(DispatchKey), but less restrictive.
  // toBackendComponent() errors out if the key that it was passed has no
  // backend bits, which is useful for error checking. We need a version of that
  // here that can also handle "fake" backends like FPGA, because they need to
  // map to the AutogradOther key. For those backends, we return
  // BackendComponent::InvalidBit.
  BackendComponent highestBackendKey() const {
    // mask to mask out functionality bits
    auto backend_idx =
        DispatchKeySet(repr_ & full_backend_mask).indexOfHighestBit();
    // all zeros across the backend bits means that no backend bits are set.
    if (backend_idx == 0)
      return BackendComponent::InvalidBit;
    return static_cast<BackendComponent>(backend_idx);
  }

  // returns the DispatchKey of highest priority in the set.
  DispatchKey highestPriorityTypeId() const {
    auto functionality_k = highestFunctionalityKey();
    if (isPerBackendFunctionalityKey(functionality_k)) {
      return toRuntimePerBackendFunctionalityKey(
          functionality_k, highestBackendKey());
    }
    return functionality_k;
  }
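
  // e.g. (sketch):
  //   DispatchKeySet({DispatchKey::AutogradCUDA, DispatchKey::CUDA})
  //       .highestPriorityTypeId() == DispatchKey::AutogradCUDA
  //   (Autograd outranks Dense, and CUDABit is the highest backend bit set).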

  // Returns the index of the most-significant bit in the keyset.
  // This is used as part of the calculation of the index into the operator
  // table, to get:
  // - the highest "functionality" bit in the keyset.
  // - the highest "backend" bit in the keyset.
  uint8_t indexOfHighestBit() const {
    return 64 - llvm::countLeadingZeros(repr_);
  }

#if defined(C10_MOBILE_TRIM_DISPATCH_KEYS)
  // [Note: Trimmed Mobile Dispatch Keys]
  /**
   * The method below maps the dispatch key in the enum DispatchKey to an
   * integer index in the dispatchTable_ array in OperatorEntry. The array
   * is trimmed for mobile to reduce peak memory usage since it's
   * unnecessary to reserve additional space for dispatch keys that will
   * never be used on mobile.
   */
  int getDispatchTableIndexForDispatchKeySet() const {
    auto dk = highestPriorityTypeId();
    switch (dk) {
      case DispatchKey::Undefined:
        return 0;
      case DispatchKey::CPU:
        return 1;
      case DispatchKey::QuantizedCPU:
        return 2;
      case DispatchKey::SparseCPU:
        return 3;
      case DispatchKey::BackendSelect:
        return 4;
      case DispatchKey::ADInplaceOrView:
        return 5;
      case DispatchKey::AutogradOther:
        return 6;
      case DispatchKey::AutogradCPU:
        return 7;
      default:
        return -1;
    }
  }
#else
  // Returns the index in the operator table of the highest priority key in
  // the keyset. Note that we could in theory implement this using
  // highestPriorityTypeId(), but this code is very hotpath and we can do it
  // faster without it.
  int getDispatchTableIndexForDispatchKeySet() const {
    auto functionality_idx =
        DispatchKeySet(repr_ >> num_backends).indexOfHighestBit();
    auto offset_and_mask = offsetsAndMasks()[functionality_idx];
    // Mask the functionality bits out first, then right-shift by 1.
    // right-shifting by 1 because everything is zero-indexed.
    // E.g. 000001 (CPU) should give us an offset of 0, 000010 (CUDA) should
    // give us an offset of 1, etc.
    auto backend_idx =
        DispatchKeySet((repr_ & offset_and_mask.mask) >> 1).indexOfHighestBit();
    return offset_and_mask.offset + backend_idx;
  }
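
  // Worked example (a sketch): suppose the keyset is {Dense, CPUBit, CUDABit}.
  // functionality_idx picks out the Dense bit, offset_and_mask.offset is the
  // start of the Dense block in the operator table, and offset_and_mask.mask
  // keeps only the backend bits. backend_idx then indexes the highest backend
  // bit set (CUDABit), so we land on the CUDA slot within the Dense block.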
#endif

  // returns the "index" of the highest priority backend in the keyset.
  // This is pretty similar to highestBackendKey(), but:
  // - It's hotpath code (part of the runtime bitset calculation)
  // - It returns an integer index, not an enum value
  // - Everything is shifted to the right by 1.
  //   BackendComponent::InvalidBit is technically the lowest enum value,
  //   but it isn't included in the runtime table. So CPUBit = 1, CUDABit = 2,
  //   etc.
  uint64_t getBackendIndex() const {
    return DispatchKeySet((repr_ & full_backend_mask) >> 1).indexOfHighestBit();
  }

 private:
  constexpr DispatchKeySet(uint64_t repr) : repr_(repr) {}
  uint64_t repr_ = 0;

 public:
  // STL iterator for DispatchKeySet. Iterates through all runtime DispatchKeys
  // in the set. The iterator is only invalidated by the destruction of the
  // underlying DispatchKeySet, as the iterator stores a pointer to the raw
  // representation of the DispatchKeySet.
  // Note: When we encounter a per-backend functionality (e.g. Dense or
  // Sparse), we will iterate through EVERY backend in the keyset for that
  // functionality. For example, if the next functionality key to iterate over
  // is Autograd, and the backend bits in the keyset correspond to
  // [BackendComponent::CPUBit, BackendComponent::CUDABit], then the next two
  // keys we return will be DispatchKey::AutogradCPU, DispatchKey::AutogradCUDA
  // (CPU first because it has lower precedence than CUDA in DispatchKey.h).
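  //
  // e.g. (sketch):
  //   auto ks = DispatchKeySet(
  //       {DispatchKey::AutogradCPU, DispatchKey::AutogradCUDA});
  //   for (DispatchKey k : ks) {
  //     // yields DispatchKey::AutogradCPU, then DispatchKey::AutogradCUDA
  //   }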
  class iterator {
   public:
    using self_type = iterator;
    using iterator_category = std::input_iterator_tag;
    using value_type = DispatchKey;
    using difference_type = ptrdiff_t;
    using reference = value_type&;
    using pointer = value_type*;
    // final mask value should mask out the entire keyset
    static const uint8_t end_iter_mask_val =
        num_backends + num_functionality_keys;
    // final key value should be the last DispatchKey
    static const uint8_t end_iter_key_val = num_functionality_keys;

    // current_dispatchkey_idx_ will iterate through all functionality bits.
    // current_backendcomponent_idx_ will iterate through all backend bits.
    explicit iterator(
        const uint64_t* data_ptr,
        uint8_t next_functionality = num_backends,
        uint8_t next_backend = 0)
        : data_ptr_(data_ptr),
          next_functionality_(next_functionality),
          next_backend_(next_backend),
          // These are in an invalid state at construction time, and set by the
          // first increment call
          current_dispatchkey_idx_(end_iter_key_val),
          current_backendcomponent_idx_(end_iter_key_val) {
      // Go to the first key in the set
      TORCH_INTERNAL_ASSERT(
          next_functionality_ >= num_backends,
          "num_backends=",
          static_cast<uint32_t>(num_backends),
          "next_functionality_=",
          static_cast<uint32_t>(next_functionality_));
      ++(*this);
    }

    C10_API self_type& operator++();

    self_type operator++(int) {
      self_type previous_iterator = *this;
      ++(*this);
      return previous_iterator;
    }

    bool operator==(const self_type& rhs) const {
      return next_functionality_ == rhs.next_functionality_ &&
          current_dispatchkey_idx_ == rhs.current_dispatchkey_idx_ &&
          next_backend_ == rhs.next_backend_ &&
          current_backendcomponent_idx_ == rhs.current_backendcomponent_idx_;
    }
    bool operator!=(const self_type& rhs) const {
      return next_functionality_ != rhs.next_functionality_ ||
          current_dispatchkey_idx_ != rhs.current_dispatchkey_idx_ ||
          next_backend_ != rhs.next_backend_ ||
          current_backendcomponent_idx_ != rhs.current_backendcomponent_idx_;
    }
    DispatchKey operator*() const {
      auto functionality_key =
          static_cast<DispatchKey>(current_dispatchkey_idx_);
      if (isPerBackendFunctionalityKey(functionality_key)) {
        auto next_key = toRuntimePerBackendFunctionalityKey(
            functionality_key,
            static_cast<BackendComponent>(current_backendcomponent_idx_));
        // We expect all of the Dense, Sparse, Quantized, and Autograd keys to
        // be ordered the same way with respect to their backends
        TORCH_INTERNAL_ASSERT(
            toBackendComponent(next_key) ==
                static_cast<BackendComponent>(current_backendcomponent_idx_),
            "Tried to map functionality key ",
            toString(functionality_key),
            " and backend bit ",
            toString(
                static_cast<BackendComponent>(current_backendcomponent_idx_)),
            " to a runtime key, but ended up with ",
            toString(next_key),
            ". This can happen if the order of the backend dispatch keys in DispatchKey.h isn't consistent.",
            " Please double check that enum for inconsistencies.");
        return next_key;
      } else {
        return functionality_key;
      }
    }

   private:
    const uint64_t* data_ptr_;
    uint8_t next_functionality_;
    uint8_t next_backend_;
    uint8_t current_dispatchkey_idx_;
    uint8_t current_backendcomponent_idx_;
  };

 public:
  // Returns iterator to the first key in the set. If no keys are in the
  // set, then will return the end iterator.
  iterator begin() const {
    return iterator(&repr_);
  }

  // We do not need to iterate beyond EndOfFunctionalityKeys so we will treat
  // this as the end iterator.
  iterator end() const {
    return iterator(&repr_, iterator::end_iter_mask_val);
  }
};

C10_API std::string toString(DispatchKeySet);
C10_API std::ostream& operator<<(std::ostream&, DispatchKeySet);

C10_API inline int getDispatchTableIndexForDispatchKey(DispatchKey k) {
  return DispatchKeySet(k).getDispatchTableIndexForDispatchKeySet();
}

// Alias key DispatchKey::Autograd maps to
// (autograd_dispatch_keyset x full_backend_mask)
// NB: keys in this set also get associated with CompositeImplicitAutograd
//
// Note [autograd_dispatch_keyset Does Not Include Backend Bits]
// We don't want to include any backend bits (BackendComponent::CPUBit, etc)
// directly in autograd_dispatch_keyset.
// Why? keysets like autograd_dispatch_keyset are commonly used to remove
// autograd keys from a DispatchKeySet throughout the code base. However, you
// are only allowed to remove functionality bits from a keyset, not backend
// bits. See Note [Removing keys from DispatchKeySet Only Affects Functionality
// Keys] for details. To be consistent and avoid confusion, we're explicitly
// setting up autograd_dispatch_keyset to not have any backend bits.
constexpr DispatchKeySet autograd_dispatch_keyset = DispatchKeySet({
    DispatchKey::AutogradFunctionality,
    DispatchKey::AutogradOther,
    DispatchKey::AutogradNestedTensor,
});
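
// e.g. (sketch): because of the note above, it's safe to write
//   ks - autograd_dispatch_keyset
// to strip all autograd functionality bits from a keyset ks while leaving
// its backend bits intact.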

constexpr DispatchKeySet autocast_dispatch_keyset = DispatchKeySet({
    DispatchKey::AutocastCPU,
    DispatchKey::AutocastCUDA,
    DispatchKey::AutocastXPU,
    DispatchKey::AutocastHPU,
});

// See Note [TLS Initialization]
constexpr DispatchKeySet default_included_set = DispatchKeySet({
    DispatchKey::BackendSelect,
    DispatchKey::ADInplaceOrView,
});

constexpr DispatchKeySet default_excluded_set = DispatchKeySet({
    DispatchKey::AutocastCPU,
    DispatchKey::AutocastCUDA,
    DispatchKey::AutocastXPU,
    DispatchKey::AutocastHPU,
});

constexpr DispatchKeySet autograd_dispatch_keyset_with_ADInplaceOrView =
    autograd_dispatch_keyset | DispatchKeySet(DispatchKey::ADInplaceOrView);

constexpr DispatchKeySet python_ks = DispatchKeySet({
    DispatchKey::Python,
    DispatchKey::PythonTLSSnapshot,
});

constexpr DispatchKeySet sparse_ks = DispatchKeySet(DispatchKey::Sparse);

constexpr DispatchKeySet sparse_csr_ks =
    DispatchKeySet({DispatchKey::SparseCsrCPU, DispatchKey::SparseCsrCUDA});

constexpr DispatchKeySet mkldnn_ks = DispatchKeySet(DispatchKey::MkldnnCPU);

// backend dispatch keys that map to DispatchKey::AutogradOther
// NB: keys in this set also get associated with CompositeImplicitAutograd
constexpr DispatchKeySet autogradother_backends =
    DispatchKeySet(
        // HIP and VE aren't in this list: they now have their own backend bits
        // which means that they can now have their own Autograd keys.
        // Technically, HIP will now redispatch to its own custom AutogradHIP
        // slot in the runtime table.
        {DispatchKey::FPGA,
         DispatchKey::ORT,
         DispatchKey::Vulkan,
         DispatchKey::Metal,
         DispatchKey::SparseCsrCPU,
         DispatchKey::SparseCsrCUDA,
         DispatchKey::CustomRNGKeyId,
         DispatchKey::MkldnnCPU,
         // Sparse and Quantized backends also live here.
         DispatchKey::Sparse,
         DispatchKey::Quantized})
    // Including the backend bits because this keyset is used during op
    // registration, which requires looping over all runtime autogradother
    // backend keys.
    | DispatchKeySet(DispatchKeySet::RAW, full_backend_mask);

// The set of dispatch keys that come after autograd
// n.b. this relies on the fact that AutogradOther is currently the lowest
// Autograd key
constexpr DispatchKeySet after_autograd_keyset =
    DispatchKeySet(DispatchKeySet::FULL_AFTER, c10::DispatchKey::AutogradOther);

// The set of dispatch keys that come after ADInplaceOrView
constexpr DispatchKeySet after_ADInplaceOrView_keyset = DispatchKeySet(
    DispatchKeySet::FULL_AFTER,
    c10::DispatchKey::ADInplaceOrView);

// The set of dispatch keys that come after Functionalize
constexpr DispatchKeySet after_func_keyset =
    DispatchKeySet(DispatchKeySet::FULL_AFTER, c10::DispatchKey::Functionalize)
        .remove(
            // NOTE: we also need to remove ADInplaceOrView from the keyset when
            // redispatching after the func kernels. This is because we're not
            // calling the same op; we originally called an inplace op, and now
            // we aren't. The original key calculation figured out which keys
            // were Fallthrough based on the inplace op. That means that it did
            // not include the ADInplaceOrView kernel as a fallthrough key.
            // However, we WANT the ADInplaceOrView kernel to be ignored now
            // that we're calling an out-of-place op. Re-invoking
            // Dispatcher::call would re-run the Fallthrough key calculation
            // and get us that, but at::redispatch is more performant. We can
            // get away with it by explicitly removing the key here.
            c10::DispatchKey::ADInplaceOrView);

constexpr DispatchKeySet backend_bitset_mask =
    DispatchKeySet(DispatchKeySet::RAW, (1ULL << num_backends) - 1);

constexpr auto inplace_or_view_ks =
    DispatchKeySet(DispatchKey::ADInplaceOrView);
constexpr auto autograd_cpu_ks = DispatchKeySet(DispatchKey::AutogradCPU);
constexpr auto autograd_ipu_ks = DispatchKeySet(DispatchKey::AutogradIPU);
constexpr auto autograd_xpu_ks = DispatchKeySet(DispatchKey::AutogradXPU);
constexpr auto autograd_cuda_ks = DispatchKeySet(DispatchKey::AutogradCUDA);
constexpr auto autograd_xla_ks = DispatchKeySet(DispatchKey::AutogradXLA);
constexpr auto autograd_lazy_ks = DispatchKeySet(DispatchKey::AutogradLazy);
constexpr auto autograd_meta_ks = DispatchKeySet(DispatchKey::AutogradMeta);
constexpr auto autograd_mps_ks = DispatchKeySet(DispatchKey::AutogradMPS);
constexpr auto autograd_hpu_ks = DispatchKeySet(DispatchKey::AutogradHPU);
constexpr auto autograd_privateuse1_ks =
    DispatchKeySet(DispatchKey::AutogradPrivateUse1);
constexpr auto autograd_privateuse2_ks =
    DispatchKeySet(DispatchKey::AutogradPrivateUse2);
constexpr auto autograd_privateuse3_ks =
    DispatchKeySet(DispatchKey::AutogradPrivateUse3);
constexpr auto autograd_other_ks = DispatchKeySet(DispatchKey::AutogradOther);
constexpr auto autograd_nested =
    DispatchKeySet(DispatchKey::AutogradNestedTensor);
// keyset corresponding to functorch keys that have their own dedicated
// TensorImpl subclass.
constexpr auto functorch_transforms_ks = DispatchKeySet(
    {DispatchKey::FuncTorchBatched,
     DispatchKey::FuncTorchVmapMode,
     DispatchKey::Batched,
     DispatchKey::VmapMode,
     DispatchKey::FuncTorchGradWrapper});

constexpr auto functorch_batched_ks =
    DispatchKeySet({DispatchKey::FuncTorchBatched});

// This keyset has:
// (1) the functionality bits corresponding to backends (dense, sparse,
//     quantized)
// (2) all of the backend bits set
constexpr DispatchKeySet backend_functionality_keys =
    DispatchKeySet({
        DispatchKey::Dense,
        DispatchKey::Quantized,
        DispatchKey::Sparse,
    }) |
    DispatchKeySet(DispatchKeySet::RAW, full_backend_mask);

struct OpTableOffsetAndMask {
  uint16_t offset;
  uint16_t backend_mask;
};

static_assert(
    num_backends <= 16,
    "Right now we expect the number of backends not to exceed 16. In the (unlikely) event"
    " that this changes, the size of OpTableOffsetAndMask::backend_mask needs to be increased too.");

// true if t is a backend dispatch key
C10_API bool isBackendDispatchKey(DispatchKey t);

// Resolve alias dispatch key to DispatchKeySet if applicable
C10_API DispatchKeySet getRuntimeDispatchKeySet(DispatchKey t);

// Resolve alias dispatch key to DispatchKeySet if applicable,
// and check if k is a part of that set
C10_API bool runtimeDispatchKeySetHas(DispatchKey t, DispatchKey k);

// Returns a DispatchKeySet of all backend keys mapped to Autograd dispatch key
// t; the DispatchKeySet is empty if t is not an alias of DispatchKey::Autograd.
C10_API DispatchKeySet getBackendKeySetFromAutograd(DispatchKey t);

// Returns a DispatchKeySet of autograd related keys mapped to backend.
// For a given backend key, use the associated autograd key;
// for non-backend keys, use AutogradOther as a default.
// Note: it's convenient and fast to return a default here rather than (say)
// returning an optional<DispatchKey>, or throwing. But it makes callers
// responsible for either a) enforcing the invariant that only backend keys
// be passed as arguments, or b) interpreting our return value carefully.
inline DispatchKeySet getAutogradRelatedKeySetFromBackend(BackendComponent t) {
  switch (t) {
    case BackendComponent::CPUBit:
      return inplace_or_view_ks | autograd_cpu_ks;
    case BackendComponent::IPUBit:
      return inplace_or_view_ks | autograd_ipu_ks;
    case BackendComponent::XPUBit:
      return inplace_or_view_ks | autograd_xpu_ks;
    case BackendComponent::CUDABit:
      return inplace_or_view_ks | autograd_cuda_ks;
    case BackendComponent::XLABit:
      return inplace_or_view_ks | autograd_xla_ks;
    case BackendComponent::LazyBit:
      return inplace_or_view_ks | autograd_lazy_ks;
    case BackendComponent::MetaBit:
      return inplace_or_view_ks | autograd_meta_ks;
    case BackendComponent::MPSBit:
      return inplace_or_view_ks | autograd_mps_ks;
    case BackendComponent::HPUBit:
      return inplace_or_view_ks | autograd_hpu_ks;
    case BackendComponent::PrivateUse1Bit:
      return inplace_or_view_ks | autograd_privateuse1_ks;
    case BackendComponent::PrivateUse2Bit:
      return inplace_or_view_ks | autograd_privateuse2_ks;
    case BackendComponent::PrivateUse3Bit:
      return inplace_or_view_ks | autograd_privateuse3_ks;
    default:
      return inplace_or_view_ks | autograd_other_ks;
  }
}

// Returns a DispatchKeySet of autocast related keys mapped to backend.
inline DispatchKeySet getAutocastRelatedKeySetFromBackend(BackendComponent t) {
  constexpr auto autocast_cpu_ks = DispatchKeySet(DispatchKey::AutocastCPU);
  constexpr auto autocast_xpu_ks = DispatchKeySet(DispatchKey::AutocastXPU);
  constexpr auto autocast_hpu_ks = DispatchKeySet(DispatchKey::AutocastHPU);
  constexpr auto autocast_cuda_ks = DispatchKeySet(DispatchKey::AutocastCUDA);
  switch (t) {
    case BackendComponent::CPUBit:
      return autocast_cpu_ks;
    case BackendComponent::XPUBit:
      return autocast_xpu_ks;
    case BackendComponent::HPUBit:
      return autocast_hpu_ks;
    case BackendComponent::CUDABit:
    case BackendComponent::XLABit:
      return autocast_cuda_ks;
    default:
      return DispatchKeySet();
  }
}

// returns the "backend" DispatchKey of highest priority in the set.
// This is basically like highestBackendKey(), except that we have some
// "functionality" bits that correspond to backends (Sparse, Quantized)
inline DispatchKey highestPriorityBackendTypeId(DispatchKeySet ks) {
  return (ks & backend_functionality_keys).highestPriorityTypeId();
}

// This API exists because we have a use case for checking
// getRuntimeDispatchKeySet(alias).has(DispatchKey::Undefined)
// in OperatorEntry.cpp but we disallow it in the has() API.
C10_API bool isIncludedInAlias(DispatchKey k, DispatchKey alias);

// Historically, every tensor only had a single DispatchKey, and it was always
// something like CPU, and there wasn't any of this business where TLS
// could cause the DispatchKey of a tensor to change. But we still have some
// legacy code that is still using DispatchKey for things like instanceof
// checks; if at all possible, refactor the code to stop using DispatchKey in
// those cases.
static inline DispatchKey legacyExtractDispatchKey(DispatchKeySet s) {
  // NB: If you add any extra keys that can be stored in TensorImpl on
  // top of existing "backend" keys like CPU/CUDA, you need to add it
  // here. At the moment, autograd keys and the ADInplaceOrView key need this
  // treatment.
  return (s - autograd_dispatch_keyset_with_ADInplaceOrView -
          autocast_dispatch_keyset -
          DispatchKeySet(
              {DispatchKey::Functionalize,
               DispatchKey::PythonTLSSnapshot,
               DispatchKey::Python}))
      .highestPriorityTypeId();
}
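
// e.g. (a sketch; the exact keyset depends on the tensor): a CPU tensor with
// autograd enabled typically carries {Dense, AutogradFunctionality,
// ADInplaceOrView, CPUBit}, so the subtraction above strips the
// autograd-related functionality bits and legacyExtractDispatchKey()
// returns DispatchKey::CPU.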

template <class T>
using is_not_DispatchKeySet = guts::negation<std::is_same<DispatchKeySet, T>>;

// Given a function type, constructs a function_traits type that drops the first
// parameter type if the first parameter is of type DispatchKeySet. NB:
// DispatchKeySet is currently explicitly hidden from JIT (mainly to avoid
// pushing unnecessary arguments on the stack - see Note [ Plumbing Keys Through
// the Dispatcher] for details). If at any point in the future we need to expose
// this type to JIT, revisit the usage of this type alias.
template <class FuncType>
using remove_DispatchKeySet_arg_from_func = guts::make_function_traits_t<
    typename guts::infer_function_traits_t<FuncType>::return_type,
    typename std::conditional_t<
        std::is_same<
            DispatchKeySet,
            typename guts::typelist::head_with_default_t<
                void,
                typename guts::infer_function_traits_t<
                    FuncType>::parameter_types>>::value,
        guts::typelist::drop_if_nonempty_t<
            typename guts::infer_function_traits_t<FuncType>::parameter_types,
            1>,
        typename guts::infer_function_traits_t<FuncType>::parameter_types>>;
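
// e.g. (sketch):
//   using F1 = int(DispatchKeySet, float);
//   using F2 = int(float);
//   // remove_DispatchKeySet_arg_from_func<F1> describes the same signature
//   // as guts::infer_function_traits_t<F2>: the leading DispatchKeySet
//   // parameter is dropped. For F2 itself, nothing changes.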
} // namespace c10