#pragma once

#include <c10/core/DeviceType.h>
#include <c10/macros/Macros.h>
#include <ostream>
#include <string>

namespace c10 {

// Semantically, each value of BackendComponent identifies a "backend" for our
// dispatch. Some functionalities that we may dispatch to are allowed to
// register different handlers for each backend. The BackendComponent is then
// used to figure out which backend implementation to dispatch to.

// In implementation terms, the backend component identifies a specific "bit" in
// a DispatchKeySet. The bits in the DispatchKeySet are split between the bottom
// ~12 "BackendComponent" bits, while the remaining upper bits are assigned to
// functionalities. When we encounter a functionality bit that is known to be
// customizable per-backend, then we also look at the lower BackendComponent
// bits and take the highest bit to determine which backend's implementation to
// use.
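//
// As an illustrative sketch of that resolution (not the actual bit-twiddling,
// which lives in DispatchKeySet): suppose a DispatchKeySet has the Dense
// functionality bit set together with the CPU and CUDA backend bits. Dense is
// customizable per-backend, so dispatch pairs the highest-priority
// functionality bit with the highest-priority backend bit:
//
//   functionality bits: Dense
//   backend bits:       CUDABit | CPUBit
//   => resolved runtime key: DispatchKey::CUDA   (CUDABit outranks CPUBit)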

// WARNING! If you add a new backend component to the end of this list,
// make sure you update PrivateUse3Bit. (But you shouldn't: private use
// keys should have higher precedence than all built-in keys)

// If you add a new (non-privateuse) backend here,
// make sure to add an Autograd<Backend> fallthrough kernel
// in aten/src/ATen/core/VariableFallbackKernel.cpp

#define C10_FORALL_BACKEND_COMPONENTS(_, extra) \
  _(CPU, extra)                                 \
  _(CUDA, extra)                                \
  _(HIP, extra)                                 \
  _(XLA, extra)                                 \
  _(MPS, extra)                                 \
  _(IPU, extra)                                 \
  _(XPU, extra)                                 \
  _(HPU, extra)                                 \
  _(VE, extra)                                  \
  _(Lazy, extra)                                \
  _(Meta, extra)                                \
  _(MTIA, extra)                                \
  _(PrivateUse1, extra)                         \
  _(PrivateUse2, extra)                         \
  _(PrivateUse3, extra)
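
// For reference, a rough sketch of how the X-macro above expands when it is
// paired with DEFINE_BACKEND_COMPONENT below (the real expansion is produced
// by the preprocessor; this is only illustrative):
//
//   #define DEFINE_BACKEND_COMPONENT(n, _) n##Bit,
//   C10_FORALL_BACKEND_COMPONENTS(DEFINE_BACKEND_COMPONENT, unused)
//   // ...expands to:
//   CPUBit, CUDABit, HIPBit, XLABit, MPSBit, IPUBit, XPUBit, HPUBit,
//   VEBit, LazyBit, MetaBit, MTIABit,
//   PrivateUse1Bit, PrivateUse2Bit, PrivateUse3Bit,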

// WARNING! If we add a new per-backend functionality key that has higher
// priority than Autograd, then make sure you update EndOfRuntimeBackendKeys

#define C10_FORALL_FUNCTIONALITY_KEYS(_) \
  _(Dense, )                             \
  _(Quantized, Quantized)                \
  _(Sparse, Sparse)                      \
  _(NestedTensor, NestedTensor)          \
  _(AutogradFunctionality, Autograd)

enum class BackendComponent : uint8_t {

  // A "backend" is colloquially used to refer to handlers for dispatch
  // which actually implement the numerics of an operation in question.
  //
  // Due to the nature of the enum, these backends are specified in
  // an ordered way, but for most backends this order is not semantically
  // meaningful (e.g., it's valid to reorder these backends without changing
  // semantics). The only situation when backend ordering is meaningful
  // is when the backend participates in multiple dispatch with another
  // backend; e.g., CPU and CUDA (cuda must have higher priority).

  // These keys don't correspond to individual kernels.
  // Instead, they represent the backends that are allowed to override specific
  // pieces of functionality:
  // - dense kernels (e.g. DispatchKey::CPU)
  // - sparse kernels (e.g. DispatchKey::SparseCPU)
  // - quantized kernels (e.g. DispatchKey::QuantizedCPU)
  // - autograd kernels (e.g. DispatchKey::AutogradCPU)
  // We reserve space in the runtime operator table for the full cross product
  // of [backends in this enum] x [keys below that are explicitly marked as
  // having per-backend functionality]
  //
  // A meta tensor is a tensor without any data associated with it. (They
  // have also colloquially been referred to as tensors on the "null" device).
  // A meta tensor can be used to dry run operators without actually doing any
  // computation, e.g., add on two meta tensors would give you another meta
  // tensor with the output shape and dtype, but wouldn't actually add anything.

  InvalidBit = 0,
#define DEFINE_BACKEND_COMPONENT(n, _) n##Bit,
  C10_FORALL_BACKEND_COMPONENTS(DEFINE_BACKEND_COMPONENT, unused)
#undef DEFINE_BACKEND_COMPONENT

  // Define an alias to represent end of backend dispatch keys.
  // If you add new backend keys after PrivateUse3, please also update it here.
  EndOfBackendKeys = PrivateUse3Bit,
};

// Semantically, a dispatch key identifies a possible "level" in our
// dispatch, for which a handler may be registered. Each handler corresponds
// to a type of functionality.
//
// In implementation terms, the dispatch key identifies a specific "bit" in a
// DispatchKeySet. Higher bit indexes get handled by dispatching first (because
// we "count leading zeros" when we extract the highest priority dispatch
// key.)
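//
// As a small illustrative sketch of that ordering (not the actual extraction
// code): if a DispatchKeySet has both the Dense and the AutogradFunctionality
// bits set, AutogradFunctionality occupies a higher bit, so its handler runs
// first; once it redispatches with autograd excluded, the Dense handler (the
// backend kernel) runs next.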
//
// Note [DispatchKey Classification]
// This enum actually contains several types of keys, which are explained
// in more detail further down:
// (1) non-customizable backends (e.g. FPGA)
// (2) non-customizable functionalities (e.g. Functionalize)
// (3) functionalities that are customizable per backend (e.g. Dense, Sparse,
//     AutogradFunctionality)
// (4) per-backend instances of customizable functionalities (e.g. CPU,
//     SparseCPU, AutogradCPU)
// (5) alias keys (e.g. CompositeImplicitAutograd)
//
// Of the categories above, it's important to note:
// (a) which keys are assigned individual bits in a DispatchKeySet
// (b) which keys are assigned individual slots in the runtime operator table
//     ("Runtime keys")
//
// (1), (2) and (3) all get their own dedicated bits in the DispatchKeySet.
// (1), (2) and (4) all get their own dedicated slots in the runtime operator
// table.

// See Note [DispatchKeySet Internal Representation] for more details.
//
// NOTE: Keep the list in sync with `DispatchKey` in torchgen/model.py
enum class DispatchKey : uint16_t {

  // ~~~~~~~~~~~~~~~~~~~~~~~~~~ UNDEFINED ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ //
  // This is not a "real" functionality, but it exists to give us a "nullopt"
  // element we can return for cases when a DispatchKeySet contains no elements.
  // You can think of a more semantically accurate definition of DispatchKey as:
  //
  //    using DispatchKey = optional<RealDispatchKey>
  //
  // and Undefined == nullopt. We didn't actually represent
  // it this way because optional<RealDispatchKey> would take two
  // words, when DispatchKey fits in sixteen bits.

  Undefined = 0,

  // Define an alias for Undefined to represent CatchAll (long term
  // this will get eliminated, but for now it's convenient)
  CatchAll = Undefined,

  // ~~~~~~~~~~~~~~~~~~~~~~~~~~ Functionality Keys ~~~~~~~~~~~~~~~~~~~~~~ //
  // Every value in the enum (up to EndOfFunctionalityKeys)
  // corresponds to an individual "functionality" that can be dispatched to.
  // This is represented in the DispatchKeySet by assigning each of these enum
  // values to one of the remaining (64 - len(BackendComponent)) bits.
  //
  // Most of these functionalities have a single handler assigned to them,
  // making them "runtime keys" that map to a single slot in the runtime
  // operator table.
  //
  // A few functionalities are allowed to be customizable per backend.
  // See [Note: Per-Backend Functionality Dispatch Keys] for details.

  // See [Note: Per-Backend Functionality Dispatch Keys]
  Dense,

  // Below are non-extensible backends.
  // These are backends that currently don't have their own overrides for
  // Autograd/Sparse/Quantized kernels,
  // and we therefore don't waste space in the runtime operator table
  // allocating space for them.
  // If any of these backends ever need to customize, e.g., Autograd, then we'll
  // need to add a DispatchKey::*Bit for them.

  // TODO: put this in BackendComponents
  FPGA, // Xilinx support lives out of tree at
        // https://gitlab.com/pytorch-complex/vitis_kernels

  // TODO: put this in BackendComponents
  // ONNX Runtime, lives out of tree at https://github.com/pytorch/ort and
  // https://github.com/microsoft/onnxruntime, and is also used to test general
  // backend/extension machinery in the core. cf:
  // - test/cpp_extensions/ort_extension.cpp
  // - test/test_torch.py
  // - aten/src/ATen/test/extension_backend_test.cpp
  ORT,

  Vulkan, // TODO: put this in BackendComponents
  Metal, // TODO: put this in BackendComponents

  // See [Note: Per-Backend Functionality Dispatch Keys]
  Quantized,

  // This backend is to support custom RNGs; it lets you go
  // to a different kernel if you pass in a generator that is not a
  // traditional CPUGeneratorImpl/CUDAGeneratorImpl. To make use of this
  // key:
  // 1) set it as a second parameter of at::Generator constructor call in
  //    the user-defined PRNG class.
  // 2) use it as a dispatch key while registering custom kernels
  //    (templatized kernels specialized for user-defined PRNG class)
  // Intended for out of tree use; tested by aten/src/ATen/test/rng_test.cpp
  CustomRNGKeyId,

  // TODO: Make Mkldnn a functionality key, so we can give it Meta support
  // Here are backends which specify more specialized operators
  // based on the layout of the tensor. Note that the sparse backends
  // are one case where ordering matters: sparse multi-dispatches with
  // the corresponding dense tensors, and must be handled before them.
  MkldnnCPU, // registered at build/aten/src/ATen/RegisterMkldnnCPU.cpp
  // NB: not to be confused with MKLDNN, which is Caffe2 only

  // See [Note: Per-Backend Functionality Dispatch Keys]
  Sparse,

  // TODO: Make SparseCsr a functionality key
  SparseCsrCPU,
  SparseCsrCUDA,

  NestedTensor,

  // In some situations, it is not immediately obvious what the correct
  // backend for a function is, because the function in question doesn't
  // have any "tensor" arguments. In this case, a BackendSelect function
  // can be registered to implement the custom determination of the
  // correct backend.
  BackendSelect,
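
  // A hedged sketch of when BackendSelect matters: factory functions such as
  // at::empty() take no Tensor arguments, so there is no input DispatchKeySet
  // to pull a backend bit from. A kernel registered to BackendSelect can
  // compute the right backend key from the non-tensor arguments (e.g. the
  // requested device in TensorOptions) and then redispatch to that backend.
  // Illustrative shape only; the names below are made up:
  //
  //   Tensor empty_backend_select(IntArrayRef size, TensorOptions options) {
  //     DispatchKeySet ks = /* derived from options */;
  //     return at::redispatch::empty(ks, size, options);  // hypothetical call
  //   }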

  Python,

  // Out-of-core key for Fake Tensor in torchdistx.
  // See https://pytorch.org/torchdistx/latest/fake_tensor.html
  // TODO: delete this in favor of Python-implemented fake tensor
  Fake,
  // See Note [Out-of-tree vmap+grad prototype]. The purpose of this key
  // is to insert code after the "autograd subsystem" runs, so this key should
  // be directly after ADInplaceOrView and all of the autograd keys.
  FuncTorchDynamicLayerBackMode,

  // Alias and mutation removal.
  // If some backends want to opt into only alias removal or only mutation
  // removal, we can consider adding separate keys dedicated to those
  // individual passes.
  // See Note [Functionalization Pass In Core] for details.
  Functionalize,

  // The named dispatch key is set for any tensors with named dimensions.
  // Although we have a dispatch key for named tensors, for historical reasons,
  // this dispatch key doesn't do any of the substantive functionality for named
  // tensors (though, hypothetically, it could!) At the moment, it's just
  // responsible for letting us give good error messages when operations
  // don't support named tensors.
  //
  // NB: If you ever consider moving named tensor functionality into
  // this dispatch key, note that it might be necessary to add another dispatch
  // key that triggers before composite operators, in case a composite operator
  // has named dimension propagation that doesn't match that of its
  // constituent parts.
  // TODO: delete this once torchdim lands in functorch
  Named,

  // The Conjugate dispatch key is set for any tensors that need to perform
  // conjugation.
  // This is implemented at a dispatch level right before any backends run.
  Conjugate,

  // The Negative dispatch key is set for any tensors that need to perform
  // negation.
  // This is implemented at a dispatch level right before any backends run.
  Negative,

  ZeroTensor, // registered at build/aten/src/ATen/RegisterZeroTensor.cpp

  // Note [ADInplaceOrView key]
  // ADInplaceOrView key is used by inplace or view ops to register a kernel
  // that does additional setup for future autograd computation.
  //
  // 1. For inplace ops this kernel does a version bump.
  // 2. For view ops this kernel does `as_view` setup where we properly set up
  //    DifferentiableViewMeta on the view tensors.
  //
  // For other ops it's a fallthrough kernel since there's no extra
  // work to do.
  //
  // Note [Dream: skip VariableType kernel when requires_grad=false]
  //
  // In an ideal world where we can skip the VariableType kernel for inputs
  // with requires_grad=false, instead of a fallthrough kernel, we'd
  // register a kernel like the one shown below to all functional ops as well:
  //   torch::Tensor my_functional_op(...) {
  //     {
  //       // Note: for every op in VariableType, you need to go through
  //       // the `AutoDispatchBelowADInplaceOrView` guard exactly once to add
  //       // the key to the TLS excluded set. If you don't go through it at
  //       // all, inplace/view ops called through `at::` inside your backend
  //       // kernel will dispatch to ADInplaceOrView kernels and do a lot
  //       // of extra work.
  //       at::AutoDispatchBelowADInplaceOrView guard;
  //       at::redispatch::my_functional_op(...);
  //     }
  //   }
  // But this work is currently blocked since it adds an extra dispatch
  // for all ops and that's non-trivial overhead at the model level (a few
  // percent). Thus our current approach takes advantage of the fact that
  // every kernel goes through the VariableType kernel first, and pulls the
  // `at::AutoDispatchBelowADInplaceOrView` guard of functional ops
  // up into the `VariableType` kernel. Thus we only add the extra dispatch
  // to view/inplace ops to minimize its perf impact on real models.
  ADInplaceOrView,
  // Note [Alias Dispatch Key : Autograd]
  // All backends are oblivious to autograd; autograd is handled as a
  // layer which happens on top of all backends. It inspects the autograd
  // metadata of all inputs, determines what autograd metadata should be
  // constructed by the output, and otherwise defers to the backend to
  // actually do the numeric computation. Autograd contains
  // the bulk of this logic.

  // Autograd is now an alias dispatch key which by default maps to all
  // backend-specific autograd keys.
  // Backend-specific keys allow backends to override the default kernel
  // registered to the Autograd key as needed.
  // For example, XLA wants to define autograd for einsum directly.
  // Registering a custom autograd implementation at the XLA key won't work
  // because we process Autograd before XLA: the autograd key has higher
  // priority and gets processed first. You generally should NOT redispatch
  // after handling autograd here (since that would result in execution of the
  // Autograd operator, which you're trying to skip). In AutogradXLA
  // implementations, you are responsible for handling autograd yourself, or
  // deferring to other operators which support autograd.
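
  // A hedged sketch of what "XLA defines autograd for einsum directly" could
  // look like at registration time (illustrative only; the real registration
  // lives in the XLA backend, and my_xla_einsum_autograd is a made-up name):
  //
  //   TORCH_LIBRARY_IMPL(aten, AutogradXLA, m) {
  //     // Responsible for autograd bookkeeping itself, and for eventually
  //     // calling into the XLA einsum implementation.
  //     m.impl("einsum", my_xla_einsum_autograd);
  //   }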

  // Currently we only have backend-specific autograd keys for CPU/CUDA/XLA and
  // reserved user-defined backends. All other in-tree backends share the
  // AutogradOther key. We can add specific autograd keys for those backends
  // upon request.
  AutogradOther,

  // See [Note: Per-Backend Functionality Dispatch Keys]
  AutogradFunctionality,

  // NestedTensor is an example of something that isn't a "real backend"
  // (because it mostly consists of redispatching kernels)
  // but it would like to override autograd functionality in C++.
  // We can handle cases like this by adding an extra functionality key
  // exclusively for handling autograd for NestedTensor.
  // lives out of tree at
  // https://github.com/pytorch/nestedtensor
  AutogradNestedTensor,

  Tracer,

  // TODO: make Autocast a functionality key
  // Autocasting precedes VariableTypeId, to ensure casts are autograd-exposed
  // and inputs are saved for backward in the post-autocast type.
  AutocastCPU,
  AutocastXPU,
  AutocastHPU,
  // Naughtily, AutocastCUDA is also being used for XLA. In the terminal state,
  // it probably should get its own Autocast key
  AutocastCUDA,

  // ~~~~~~~~~~~~~~~~~~~~~~~~~~~ WRAPPERS ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ //
  // There are a number of alternative modes which may want to run before
  // autograd; for example, error checking, tracing, profiling or vmap. They
  // go here.

  FuncTorchBatched, // See Note [Out-of-tree vmap+grad prototype]
  FuncTorchVmapMode, // See Note [Out-of-tree vmap+grad prototype]

  // This is the dispatch key for BatchedTensorImpl, which is used to implement
  // batching rules for vmap.
  Batched,

  // When we are inside a vmap, all tensors dispatch on this key.
  // See Note: [DispatchKey::VmapMode usage] for more details.
  VmapMode,

  FuncTorchGradWrapper, // See Note [Out-of-tree vmap+grad prototype]

  // Out-of-core key for Deferred Module Initialization in torchdistx.
  // See https://pytorch.org/torchdistx/latest/deferred_init.html
  DeferredInit,

  // Used by Python key logic to know the set of tls on entry to the dispatcher
  // This kernel assumes it is the top-most non-functorch-related DispatchKey.
  // If you add a key above, make sure to update the fallback implementation for
  // this.
  PythonTLSSnapshot,

  // This key should be at the very top of the dispatcher
  FuncTorchDynamicLayerFrontMode, // See Note [Out-of-tree vmap+grad prototype]

  // TESTING: This is intended to be a generic testing tensor type id.
  // Don't use it for anything real; its only acceptable use is within a single
  // process test. Use it by creating a TensorImpl with this DispatchKey, and
  // then registering operators to operate on this type id. See
  // aten/src/ATen/core/dispatch/backend_fallback_test.cpp for a usage example.
  TESTING_ONLY_GenericWrapper,

  // TESTING: This is intended to be a generic testing tensor type id.
  // Don't use it for anything real; its only acceptable use is within a single
  // process test. Use it by toggling the mode on and off via
  // TESTING_ONLY_tls_generic_mode_set_enabled and then registering operators
  // to operate on this type id. See
  // aten/src/ATen/core/dispatch/backend_fallback_test.cpp
  // for a usage example.
  TESTING_ONLY_GenericMode,

  // This is a bypass that allows you to skip running the C++ dispatcher
  // entirely.
  PythonDispatcher,

  // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ FIN ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ //
  EndOfFunctionalityKeys, // End of functionality keys.

// ~~~~~~~~~~~~~~ "Dense" Per-Backend Dispatch keys ~~~~~~~~~~~~~~~~~~~~ //
// Here are backends which you think of as traditionally specifying
// how to implement operations on some device.

#define DEFINE_PER_BACKEND_KEYS_FOR_BACKEND(n, prefix) prefix##n,

#define DEFINE_PER_BACKEND_KEYS(fullname, prefix)      \
  StartOf##fullname##Backends,                         \
      C10_FORALL_BACKEND_COMPONENTS(                   \
          DEFINE_PER_BACKEND_KEYS_FOR_BACKEND, prefix) \
          EndOf##fullname##Backends = prefix##PrivateUse3,

  C10_FORALL_FUNCTIONALITY_KEYS(DEFINE_PER_BACKEND_KEYS)

#undef DEFINE_PER_BACKEND_KEYS
#undef DEFINE_PER_BACKEND_KEYS_FOR_BACKEND

  EndOfRuntimeBackendKeys = EndOfAutogradFunctionalityBackends,

  // ~~~~~~~~~~~~~~~~~~~~~~ Alias Dispatch Keys ~~~~~~~~~~~~~~~~~~~~~~~~~~ //
  // Note [Alias Dispatch Keys]
  // Alias dispatch keys are synthetic dispatch keys which map to multiple
  // runtime dispatch keys. Alias keys have precedence, but they are always
  // lower precedence than runtime keys. You can register a kernel to an
  // alias key; the kernel may then be populated into the mapped runtime keys
  // during dispatch table computation.
  // If a runtime dispatch key has multiple kernels from alias keys, which
  // kernel wins is decided based on the precedence of alias keys (but runtime
  // keys always have precedence over alias keys).
  // Alias keys won't be directly called during runtime.
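
  // A brief illustrative sketch (the operator and kernel names are made up):
  // a kernel registered to an alias key such as CompositeImplicitAutograd,
  //
  //   TORCH_LIBRARY_IMPL(aten, CompositeImplicitAutograd, m) {
  //     m.impl("my_op", my_composite_kernel);
  //   }
  //
  // is filled into the runtime keys that the alias maps to during dispatch
  // table computation, unless a higher-precedence kernel (e.g. one registered
  // directly to a runtime key) already occupies that slot.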

  // See Note [Alias Dispatch Key : Autograd]
  Autograd,
  CompositeImplicitAutograd, // registered at
  // build/aten/src/ATen/RegisterCompositeImplicitAutograd.cpp

  // Note: The alias keyset for FuncTorchBatchedDecomposition is disjoint from
  // all other alias keysets, and so precedence order doesn't matter.
  FuncTorchBatchedDecomposition, // registered at
  // build/aten/src/ATen/RegisterFuncTorchBatchedDecomposition.cpp
  // Note: The alias keyset for CompositeImplicitAutogradNestedTensor is
  // disjoint from all other alias keysets
  CompositeImplicitAutogradNestedTensor, // registered at
  // build/aten/src/ATen/RegisterCompositeImplicitAutogradNestedTensor.cpp
  CompositeExplicitAutograd, // registered at
  // build/aten/src/ATen/RegisterCompositeExplicitAutograd.cpp
  // See Note [CompositeExplicitAutogradNonFunctional Key]
  CompositeExplicitAutogradNonFunctional, // registered at
  // build/aten/src/ATen/RegisterCompositeExplicitAutograd.cpp

  // Define an alias key to represent end of alias dispatch keys.
  // If you add new alias keys after Autograd, please also update it here.
  StartOfAliasKeys = Autograd,
  EndOfAliasKeys = CompositeExplicitAutogradNonFunctional, //

  // ~~~~~~~~~~~~~~~~~~~~~~~~~ BC ALIASES ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ //
  // These aliases exist for backwards compatibility reasons; they shouldn't
  // be used.
  CPUTensorId = CPU,
  CUDATensorId = CUDA,
  DefaultBackend = CompositeExplicitAutograd,
  PrivateUse1_PreAutograd = AutogradPrivateUse1,
  PrivateUse2_PreAutograd = AutogradPrivateUse2,
  PrivateUse3_PreAutograd = AutogradPrivateUse3,
  Autocast = AutocastCUDA,
};

// Note [Private use DispatchKey]
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~
// Private use tensor IDs are preallocated tensor type IDs for use in user
// applications. Similar to private use fields in HTTP, they can be used
// by end users for experimental or private applications, without needing
// to "standardize" the tensor ID (which would be done by submitting a PR
// to PyTorch to add your type ID).
//
// Private use tensor IDs are appropriate to use if you want to experiment
// with adding a new tensor type (without having to patch PyTorch first) or
// have a private, non-distributed application that needs to make use of a
// new tensor type. Private use tensor IDs are NOT appropriate to use for
// libraries intended to be distributed to further users: please contact
// the PyTorch developers to get a type ID registered in this case.
//
// We provide two classes of private use tensor ids: regular DispatchKeys
// and Autograd DispatchKeys. DispatchKeys serve the role of ordinary "backend"
// DispatchKeys; if you were adding support for a new type of accelerator, you
// would use a backend DispatchKey, and ideally automatically reuse
// AutogradOther definitions already defined in PyTorch. AutogradPrivateUse
// DispatchKeys serve as "wrapper" DispatchKeys: they are only necessary for
// tensors that compose multiple internal tensors, and for cases when the
// built-in autograd formulas for operators are not appropriate.
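
// As a hedged sketch of how an out-of-tree extension might use these keys
// (the kernel names are made up for illustration):
//
//   TORCH_LIBRARY_IMPL(aten, PrivateUse1, m) {
//     m.impl("add.Tensor", my_device_add);  // backend ("dense") kernel
//   }
//   TORCH_LIBRARY_IMPL(aten, AutogradPrivateUse1, m) {
//     m.impl("add.Tensor", my_device_add_autograd);  // custom autograd wrapper
//   }
//
// If the built-in autograd formulas are sufficient, the AutogradPrivateUse1
// registration can typically be omitted, per the note above.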

static_assert(
    (static_cast<uint8_t>(BackendComponent::EndOfBackendKeys) +
     static_cast<uint8_t>(DispatchKey::EndOfFunctionalityKeys)) <= 64,
    "The BackendComponent and DispatchKey enums (below EndOfFunctionalityKeys)"
    " both map to backend and functionality bits"
    " into a 64-bit bitmask; you must have at most 64 total entries between them");

// Check if a DispatchKey is an alias mapping to other runtime keys.
constexpr bool isAliasDispatchKey(DispatchKey k) {
  return k >= DispatchKey::StartOfAliasKeys && k <= DispatchKey::EndOfAliasKeys;
}

// [Note: Per-Backend Functionality Dispatch Keys]
// Check if a DispatchKey is a per-backend functionality key.
// Any functionalities that can be customized per-backend should be added here.
// These keys correspond to functionalities that can be customized individually
// per backend. While they only take up one bit in the `DispatchKeySet` bitset,
// they map to (# backends) slots in the operator table.
// Each of these keys also has a separate set of "runtime keys" in the dispatch
// key enum, per backend, which *do* map to the individual operator table slots.
// For example, the "Sparse" key maps to an individual bit in the
// DispatchKeySet, while `SparseCPU`, `SparseCUDA`, etc all map to individual
// slots in the runtime operator table.

constexpr bool isPerBackendFunctionalityKey(DispatchKey k) {
  if (k == DispatchKey::Dense || k == DispatchKey::Quantized ||
      k == DispatchKey::Sparse || k == DispatchKey::AutogradFunctionality ||
      k == DispatchKey::NestedTensor) {
    return true;
  } else {
    return false;
  }
}

// Note that this includes Undefined in the total count.
// BUT EndOfFunctionalityKeys is its own (placeholder) key.
// e.g. Undefined=0, Dense=1, Sparse=2, EndOfFunctionalityKeys=3.
// In the above example, there are 3 total functionality keys.
constexpr uint8_t num_functionality_keys =
    static_cast<uint8_t>(DispatchKey::EndOfFunctionalityKeys);

constexpr uint8_t num_backends =
    static_cast<uint8_t>(BackendComponent::EndOfBackendKeys);

// Note [No More Than 16 Backends]
// Search for this note to find places in the code where the "no more than 16
// backends" invariant is baked in.
static_assert(
    static_cast<uint8_t>(BackendComponent::EndOfBackendKeys) <= 16,
    "BackendComponent currently only supports <= 16 backends. If we really need to extend this, \
there are a few places where this invariant is baked in");

constexpr uint8_t numPerBackendFunctionalityKeys() {
  uint8_t count = 0;
  for (uint8_t k = 0; k <= num_functionality_keys; ++k) {
    if (isPerBackendFunctionalityKey(static_cast<DispatchKey>(k)))
      ++count;
  }
  return count;
}

#if defined(C10_MOBILE_TRIM_DISPATCH_KEYS)
// See [Note: Trimmed Mobile Dispatch Keys]
constexpr uint16_t num_runtime_entries = 8;
#else
constexpr uint16_t num_runtime_entries = num_functionality_keys +
    (numPerBackendFunctionalityKeys() * (num_backends - 1));
#endif
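
// A worked example of the arithmetic above, assuming the counts currently in
// this file (15 BackendComponents and 5 per-backend functionality keys): each
// per-backend functionality already occupies one slot among the functionality
// keys, so it contributes (num_backends - 1) extra runtime entries, i.e.
//
//   num_runtime_entries = num_functionality_keys + 5 * (15 - 1)
//                       = num_functionality_keys + 70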

// See Note [No More Than 16 Backends]
constexpr uint16_t full_backend_mask =
    (static_cast<uint16_t>(1) << num_backends) - 1;

C10_API const char* toString(DispatchKey);
C10_API const char* toString(BackendComponent);
C10_API std::ostream& operator<<(std::ostream&, DispatchKey);
C10_API std::ostream& operator<<(std::ostream&, BackendComponent);

C10_API DispatchKey getAutogradKeyFromBackend(BackendComponent k);

// Parses a string into a dispatch key.
// If the string cannot be correctly parsed, throws an exception.
C10_API c10::DispatchKey parseDispatchKey(const std::string& k);

// These are some convenience identifiers for dispatch keys which are
// shorter to type than their long counterparts. Note that some of these
// dispatch keys directly correspond to DeviceType; and most APIs that
// accept DispatchKey also accept DeviceType; e.g.,
// torch::dispatch(torch::kCPU, ...) is also valid.
constexpr DispatchKey kAutograd = DispatchKey::Autograd;

// See Note [The Ordering of Per-Backend Dispatch Keys Matters!]
// This function relies on the invariant that the dispatch keys between
// StartOfDenseBackends and EndOfRuntimeBackendKeys are ordered by backend
// in the same order as `BackendComponent`.
constexpr BackendComponent toBackendComponent(DispatchKey k) {
  if (k >= DispatchKey::StartOfDenseBackends &&
      k <= DispatchKey::EndOfDenseBackends) {
    return static_cast<BackendComponent>(
        static_cast<uint8_t>(k) -
        static_cast<uint8_t>(DispatchKey::StartOfDenseBackends));
  } else if (
      k >= DispatchKey::StartOfQuantizedBackends &&
      k <= DispatchKey::EndOfQuantizedBackends) {
    return static_cast<BackendComponent>(
        static_cast<uint8_t>(k) -
        static_cast<uint8_t>(DispatchKey::StartOfQuantizedBackends));
  } else if (
      k >= DispatchKey::StartOfSparseBackends &&
      k <= DispatchKey::EndOfSparseBackends) {
    return static_cast<BackendComponent>(
        static_cast<uint8_t>(k) -
        static_cast<uint8_t>(DispatchKey::StartOfSparseBackends));
  } else if (
      k >= DispatchKey::StartOfNestedTensorBackends &&
      k <= DispatchKey::EndOfNestedTensorBackends) {
    return static_cast<BackendComponent>(
        static_cast<uint8_t>(k) -
        static_cast<uint8_t>(DispatchKey::StartOfNestedTensorBackends));
  } else if (
      k >= DispatchKey::StartOfAutogradFunctionalityBackends &&
      k <= DispatchKey::EndOfAutogradFunctionalityBackends) {
    return static_cast<BackendComponent>(
        static_cast<uint8_t>(k) -
        static_cast<uint8_t>(
            DispatchKey::StartOfAutogradFunctionalityBackends));
  } else {
    return BackendComponent::InvalidBit;
  }
}

constexpr DispatchKey toFunctionalityKey(DispatchKey k) {
  if (k <= DispatchKey::EndOfFunctionalityKeys) {
    return k;
  } else if (k <= DispatchKey::EndOfDenseBackends) {
    return DispatchKey::Dense;
  } else if (k <= DispatchKey::EndOfQuantizedBackends) {
    return DispatchKey::Quantized;
  } else if (k <= DispatchKey::EndOfSparseBackends) {
    return DispatchKey::Sparse;
  } else if (k <= DispatchKey::EndOfNestedTensorBackends) {
    return DispatchKey::NestedTensor;
  } else if (k <= DispatchKey::EndOfAutogradFunctionalityBackends) {
    return DispatchKey::AutogradFunctionality;
  } else {
    return DispatchKey::Undefined;
  }
}
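
// A few illustrative examples of the two mappings above (these follow from
// the per-backend key layout defined earlier; stated here as a comment rather
// than asserted):
//
//   toBackendComponent(DispatchKey::SparseCUDA)  == BackendComponent::CUDABit
//   toBackendComponent(DispatchKey::Sparse)      == BackendComponent::InvalidBit
//   toFunctionalityKey(DispatchKey::SparseCUDA)  == DispatchKey::Sparse
//   toFunctionalityKey(DispatchKey::AutogradCPU) == DispatchKey::AutogradFunctionality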

BackendComponent toBackendComponent(DeviceType device_type);

// Given (DispatchKey::Dense, BackendComponent::CUDABit), returns
// DispatchKey::CUDA.
// See Note [The Ordering of Per-Backend Dispatch Keys Matters!]
// This function relies on the invariant that the dispatch keys between
// StartOfDenseBackends and EndOfRuntimeBackendKeys are ordered by backend
// in the same order as `BackendComponent`.
constexpr DispatchKey toRuntimePerBackendFunctionalityKey(
    DispatchKey functionality_k,
    BackendComponent backend_k) {
  if (functionality_k == DispatchKey::Dense) {
    return static_cast<DispatchKey>(
        static_cast<uint8_t>(DispatchKey::StartOfDenseBackends) +
        static_cast<uint8_t>(backend_k));
  }
  if (functionality_k == DispatchKey::Sparse) {
    return static_cast<DispatchKey>(
        static_cast<uint8_t>(DispatchKey::StartOfSparseBackends) +
        static_cast<uint8_t>(backend_k));
  }
  if (functionality_k == DispatchKey::Quantized) {
    return static_cast<DispatchKey>(
        static_cast<uint8_t>(DispatchKey::StartOfQuantizedBackends) +
        static_cast<uint8_t>(backend_k));
  }
  if (functionality_k == DispatchKey::NestedTensor) {
    return static_cast<DispatchKey>(
        static_cast<uint8_t>(DispatchKey::StartOfNestedTensorBackends) +
        static_cast<uint8_t>(backend_k));
  }
  if (functionality_k == DispatchKey::AutogradFunctionality) {
    return static_cast<DispatchKey>(
        static_cast<uint8_t>(
            DispatchKey::StartOfAutogradFunctionalityBackends) +
        static_cast<uint8_t>(backend_k));
  }
  return DispatchKey::Undefined;
}
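
// For runtime per-backend keys, these helpers are intended to round-trip; a
// hedged sketch of that invariant (not asserted here):
//
//   // e.g. with k = DispatchKey::QuantizedCPU
//   toRuntimePerBackendFunctionalityKey(
//       toFunctionalityKey(k), toBackendComponent(k)) == k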

} // namespace c10

namespace torch {
// Expose the constant, but not the TYPE (DispatchKey is an implementation
// detail!)
using c10::kAutograd;
} // namespace torch

// NB: You really shouldn't use this instance; this enum is guaranteed
// to be pretty small so a regular array should be acceptable.
namespace std {
template <>
struct hash<c10::DispatchKey> {
  typedef size_t result_type;
  typedef c10::DispatchKey argument_type;

  size_t operator()(c10::DispatchKey x) const {
    return static_cast<size_t>(x);
  }
};
} // namespace std