1 | #pragma once |
2 | |
3 | #include <c10/core/DeviceType.h> |
4 | #include <c10/macros/Macros.h> |
5 | #include <ostream> |
6 | #include <string> |
7 | |
8 | namespace c10 { |
9 | |
10 | // Semantically, each value of BackendComponent identifies a "backend" for our |
11 | // dispatch. Some functionalities that we may dispatch to are allowed to |
12 | // register different handlers for each backend. The BackendComponent is then |
13 | // used to figure out which backend implementation to dispatch to. |
14 | |
15 | // In implementation terms, the backend component identifies a specific "bit" in |
16 | // a DispatchKeySet. The bits in the DispatchKeySet are split between the bottom |
17 | // ~12 "BackendComponent" bits, while the remaining upper bits are assigned to |
18 | // functionalities. When we encounter a functionality bit that is known to be |
// customizable per-backend, then we also look at the lower BackendComponent
20 | // bits and take the highest bit to determine which backend's implementation to |
21 | // use. |
22 | |
23 | // WARNING! If you add a new backend component to the end of this list, |
24 | // make sure you update PrivateUse3Bit. (But you shouldn't: private use |
25 | // keys should have higher precedence than all built-in keys) |
26 | |
27 | // If you add a new (non-privateuse) backend here, |
28 | // make sure to add an Autograd<Backend> fallthrough kernel |
29 | // in aten/src/ATen/core/VariableFallbackKernel.cpp |
30 | |
// X-macro listing every backend component, in bit order (lowest bit first).
// `_` is invoked as _(Name, extra) for each backend; `extra` is forwarded
// unchanged, so callers can thread a prefix/suffix through the expansion.
#define C10_FORALL_BACKEND_COMPONENTS(_, extra) \
  _(CPU, extra)                                 \
  _(CUDA, extra)                                \
  _(HIP, extra)                                 \
  _(XLA, extra)                                 \
  _(MPS, extra)                                 \
  _(IPU, extra)                                 \
  _(XPU, extra)                                 \
  _(HPU, extra)                                 \
  _(VE, extra)                                  \
  _(Lazy, extra)                                \
  _(Meta, extra)                                \
  _(MTIA, extra)                                \
  _(PrivateUse1, extra)                         \
  _(PrivateUse2, extra)                         \
  _(PrivateUse3, extra)
47 | |
// WARNING! If we add a new per-backend functionality key that has higher
// priority than Autograd, then make sure you update EndOfRuntimeBackendKeys

// X-macro listing every per-backend-customizable functionality. The second
// argument is the prefix used to generate the per-backend runtime key names
// (e.g. Sparse -> SparseCPU, SparseCUDA, ...; Dense uses an empty prefix so
// its runtime keys are just CPU, CUDA, ...).
#define C10_FORALL_FUNCTIONALITY_KEYS(_) \
  _(Dense, )                             \
  _(Quantized, Quantized)                \
  _(Sparse, Sparse)                      \
  _(NestedTensor, NestedTensor)          \
  _(AutogradFunctionality, Autograd)
57 | |
enum class BackendComponent : uint8_t {

  // A "backend" is colloquially used to refer to handlers for dispatch
  // which actually implement the numerics of an operation in question.
  //
  // Due to the nature of the enum, these backends are specified in
  // an ordered way, but for most backends this order is not semantically
  // meaningful (e.g., it's valid to reorder these backends without changing
  // semantics). The only situation when backend ordering is meaningful
  // is when the backend participates in multiple dispatch with another
  // backend; e.g., CPU and CUDA (cuda must have higher priority).

  // These keys don't correspond to individual kernels.
  // Instead, they represent the backends that are allowed to override specific
  // pieces of functionality:
  // - dense kernels (e.g. DispatchKey::CPU)
  // - sparse kernels (e.g. DispatchKey::SparseCPU)
  // - quantized kernels (e.g. DispatchKey::QuantizedCPU)
  // - autograd kernels (e.g. DispatchKey::AutogradCPU)
  // We reserve space in the runtime operator table for this full cross product
  // of
  // [backends in this enum] x [keys below that are explicitly marked as having
  // per-backend functionality]
  //
  // A meta tensor is a tensor without any data associated with it. (They
  // have also colloquially been referred to as tensors on the "null" device).
  // A meta tensor can be used to dry run operators without actually doing any
  // computation, e.g., add on two meta tensors would give you another meta
  // tensor with the output shape and dtype, but wouldn't actually add anything.

  InvalidBit = 0,
  // Expands to CPUBit, CUDABit, ..., PrivateUse3Bit (values 1..15). The
  // "Bit" suffix emphasizes that each backend occupies one bit of a
  // DispatchKeySet.
#define DEFINE_BACKEND_COMPONENT(n, _) n##Bit,
  C10_FORALL_BACKEND_COMPONENTS(DEFINE_BACKEND_COMPONENT, unused)
#undef DEFINE_BACKEND_COMPONENT

  // Define an alias to represent end of backend dispatch keys.
  // If you add new backend keys after PrivateUse3, please also update it here.
  EndOfBackendKeys = PrivateUse3Bit,
};
97 | |
98 | // Semantically, a dispatch key identifies a possible "level" in our |
99 | // dispatch, for which a handler may be registered. Each handler corresponds |
100 | // to a type of functionality. |
101 | // |
102 | // In implementation terms, the dispatch key identifies a specific "bit" in a |
103 | // DispatchKeySet. Higher bit indexes get handled by dispatching first (because |
104 | // we "count leading zeros" when we extract the highest priority dispatch |
105 | // key.) |
106 | // |
107 | // Note [DispatchKey Classification] |
108 | // This enum actually contains several types of keys, which are explained |
109 | // in more detail further down: |
// (1) non-customizable backends (e.g. FPGA)
// (2) non-customizable functionalities (e.g. Functionalize)
// (3) functionalities that are customizable per backend
//     (e.g. Dense, Sparse, AutogradFunctionality)
// (4) per-backend instances of customizable functionalities
//     (e.g. CPU, SparseCPU, AutogradCPU)
// (5) alias keys (e.g. CompositeImplicitAutograd)
116 | // |
117 | // Of the categories above, it's important to note: |
118 | // (a) which keys are assigned individual bits in a DispatchKeySet |
119 | // (b) which keys are assigned individual slots in the runtime operator table |
120 | // ("Runtime keys") |
121 | // |
122 | // (1), (2) and (3) all get their own dedicated bits in the DispatchKeySet. |
123 | // (1), (2) and (4) all get their own dedicated slots in the runtime operator |
124 | // table. |
125 | |
126 | // See Note [DispatchKeySet Internal Representation] for more details. |
127 | // |
128 | // NOTE: Keep the list in sync with `DispatchKey` in torchgen/model.py |
enum class DispatchKey : uint16_t {

  // ~~~~~~~~~~~~~~~~~~~~~~~~~~ UNDEFINED ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ //
  // This is not a "real" functionality, but it exists to give us a "nullopt"
  // element we can return for cases when a DispatchKeySet contains no elements.
  // You can think a more semantically accurate definition of DispatchKey is:
  //
  //    using DispatchKey = optional<RealDispatchKey>
  //
  // and Undefined == nullopt. We didn't actually represent
  // it this way because optional<RealDispatchKey> would take two
  // words, when DispatchKey fits in eight bits.
  // (NB: the enum's underlying type is now uint16_t; the "eight bits" above
  // predates the per-backend runtime keys, but the size argument still holds.)

  Undefined = 0,

  // Define an alias for Undefined to represent CatchAll (long term
  // this will get eliminated, but for now it's convenient)
  CatchAll = Undefined,

  // ~~~~~~~~~~~~~~~~~~~~~~~~~~ Functionality Keys ~~~~~~~~~~~~~~~~~~~~~~ //
  // Every value in the enum (up to EndOfFunctionalityKeys)
  // corresponds to an individual "functionality" that can be dispatched to.
  // This is represented in the DispatchKeySet by assigning each of these enum
  // values
  // to each of the remaining (64 - len(BackendComponent)) bits.
  //
  // Most of these functionalities have a single handler assigned to them,
  // making them "runtime keys".
  // They map to a single slot in the runtime operator table.
  //
  // A few functionalities are allowed to be customizable per backend.
  // See [Note: Per-Backend Functionality Dispatch Keys] for details.

  // See [Note: Per-Backend Functionality Dispatch Keys]
  Dense,

  // Below are non-extensible backends.
  // These are backends that currently don't have their own overrides for
  // Autograd/Sparse/Quantized kernels,
  // and we therefore don't waste space in the runtime operator table allocating
  // space for them.
  // If any of these backends ever need to customize, e.g., Autograd, then we'll
  // need to add a DispatchKey::*Bit for them.

  // TODO: put this in BackendComponents
  FPGA, // Xilinx support lives out of tree at
        // https://gitlab.com/pytorch-complex/vitis_kernels

  // TODO: put this in BackendComponents
  // ONNX Runtime, lives out of tree at https://github.com/pytorch/ort and
  // https://github.com/microsoft/onnxruntime, and is also used to test general
  // backend/extension machinery in the core. cf:
  // - test/cpp_extensions/ort_extension.cpp
  // - test/test_torch.py
  // - aten/src/ATen/test/extension_backend_test.cpp
  ORT,

  Vulkan, // TODO: put this in BackendComponents
  Metal, // TODO: put this in BackendComponents

  // See [Note: Per-Backend Functionality Dispatch Keys]
  Quantized,

  // This backend is to support custom RNGs; it lets you go
  // to a different kernel if you pass in a generator that is not a
  // traditional CPUGeneratorImpl/CUDAGeneratorImpl. To make use of this
  // key:
  //  1) set it as a second parameter of at::Generator constructor call in
  //     the user-defined PRNG class.
  //  2) use it as a dispatch key while registering custom kernels
  //     (templatized kernels specialized for user-defined PRNG class)
  // intended for out of tree use; tested by aten/src/ATen/test/rng_test.cpp
  CustomRNGKeyId,

  // TODO: Make Mkldnn a functionality key, so we can give it Meta
  // support
  // Here are backends which specify more specialized operators
  // based on the layout of the tensor. Note that the sparse backends
  // are one case where ordering matters: sparse multi-dispatches with
  // the corresponding dense tensors, and must be handled before them.
  MkldnnCPU, // registered at build/aten/src/ATen/RegisterMkldnnCPU.cpp
  // NB: not to be confused with MKLDNN, which is Caffe2 only

  // See [Note: Per-Backend Functionality Dispatch Keys]
  Sparse,

  // TODO: Make SparseCsr a functionality key
  SparseCsrCPU,
  SparseCsrCUDA,

  NestedTensor,

  // In some situations, it is not immediately obvious what the correct
  // backend for function is, because the function in question doesn't
  // have any "tensor" arguments. In this case, a BackendSelect function
  // can be registered to implement the custom determination of the
  // correct backend.
  BackendSelect,

  Python,

  // Out-of-core key for Fake Tensor in torchdistx.
  // See https://pytorch.org/torchdistx/latest/fake_tensor.html
  // TODO: delete this in favor of Python-implemented fake tensor
  Fake,
  // See Note [Out-of-tree vmap+grad prototype]. The purpose of this key
  // is to insert code after the "autograd subsystem" runs, so this key should
  // be directly after ADInplaceOrView and all of the autograd keys.
  FuncTorchDynamicLayerBackMode,

  // Alias and mutation removal.
  // If some backends want to opt into only alias removal or only mutation
  // removal,
  // we can consider adding separate keys dedicated to those individual passes.
  // See Note [Functionalization Pass In Core] for details.
  Functionalize,

  // The named dispatch key is set for any tensors with named dimensions.
  // Although we have a dispatch key for named tensors, for historical reasons,
  // this dispatch key doesn't do any of the substantive functionality for named
  // tensor (though, hypothetically, it could!) At the moment, it's just
  // responsible for letting us give good error messages when operations
  // don't support named tensors.
  //
  // NB: If you ever consider moving named tensor functionality into
  // this dispatch key, note that it might be necessary add another dispatch
  // key that triggers before composite operators, in case a composite operator
  // has named dimension propagation that doesn't match that of its
  // constituent parts.
  // TODO: delete this once torchdim lands in functorch
  Named,

  // The Conjugate dispatch key is set for any tensors that need to perform
  // conjugation
  // This is implemented at a dispatch level right before any backends run
  Conjugate,

  // The Negative dispatch key is set for any tensors that need to perform
  // negation
  // This is implemented at a dispatch level right before any backends run
  Negative,

  ZeroTensor, // registered at build/aten/src/ATen/RegisterZeroTensor.cpp

  // Note [ADInplaceOrView key]
  // ADInplaceOrView key is used by inplace or view ops to register a kernel
  // that does additional setup for future autograd computation.
  //
  // 1. For inplace ops this kernel does version bump
  // 2. For view ops this kernel does `as_view` setup where we properly setup
  //    DifferentiableViewMeta on the view tensors.
  //
  // For other ops it's fallthrough kernel since there's no extra
  // work to do.
  //
  // Note [Dream: skip VariableType kernel when requires_grad=false]
  //
  // In an ideal world where we can skip VariableType kernel for inputs
  // with requires_grad=false, instead of a fallthrough kernel, we'll
  // register a kernel shown below to all functional ops as well:
  // torch::Tensor my_functional_op(...) {
  //   {
  //     // Note for every op in VariableType, you need to go through
  //     // `AutoDispatchBelowADInplaceOrView` guard exactly once to add the
  //     // key to TLS excluded set. If you don't go through it at all,
  //     // inplace/view ops called through `at::` inside your backend
  //     // kernel will dispatch to ADInplaceOrView kernels and do a lot
  //     // of extra work.
  //     at::AutoDispatchBelowADInplaceOrView guard;
  //     at::redispatch::my_functional_op(...);
  //   }
  // }
  // But this work is currently blocked since it adds an extra dispatch
  // for all ops and it's non-trivial overhead at model level(a few percents).
  // Thus our current approach takes advantage of the fact every kernel goes
  // through VariableType kernel first and pulls the
  // `at::AutoDispatchBelowADInplaceOrView` guard of functional ops
  // up to the `VariableType` kernel. Thus we only add the extra dispatch
  // to view/inplace ops to minimize its perf impact to real models.
  ADInplaceOrView,
  // Note [Alias Dispatch Key : Autograd]
  // All backends are oblivious to autograd; autograd is handled as a
  // layer which happens on top of all backends. It inspects the autograd
  // metadata of all inputs, determines what autograd metadata should be
  // constructed by the output, and otherwise defers to the backend to
  // actually do the numeric computation. Autograd contains
  // the bulk of this logic.

  // Autograd is now an alias dispatch key which by default maps to all
  // backend-specific autograd keys.
  // Backend-specific keys allow backends to override the default kernel
  // registered to the Autograd key as needed.
  // For example, XLA wants to define autograd for einsum directly.
  // Registering a custom autograd implementation at the XLA key won't work
  // because we process Autograd before XLA. This key has higher priority and
  // gets processed first. You generally should NOT redispatch after handling
  // autograd here (since that would result in execution of the Autograd
  // operator, which you're trying to skip). In AutogradXLA implementations,
  // you are responsible for handling autograd yourself, or deferring to other
  // operators which support autograd.

  // Currently we only have backend-specific autograd keys for CPU/CUDA/XLA and
  // reserved user-defined backends. All other in-tree backends share the
  // AutogradOther key. We can add specific autograd key for those backends
  // upon request.
  AutogradOther,

  // See [Note: Per-Backend Functionality Dispatch Keys]
  AutogradFunctionality,

  // NestedTensor is an example of something that isn't a "real backend"
  // (because it mostly consists of redispatching kernels)
  // but it would like to override autograd functionality in C++.
  // We can handle cases like this by adding an extra functionality key
  // exclusively for handling autograd for NestedTensor.
  // lives out of tree at
  // https://github.com/pytorch/nestedtensor
  AutogradNestedTensor,

  Tracer,

  // TODO: make Autocast a functionality key
  // Autocasting precedes VariableTypeId, to ensure casts are autograd-exposed
  // and inputs are saved for backward in the post-autocast type.
  AutocastCPU,
  AutocastXPU,
  AutocastHPU,
  // Naughtily, AutocastCUDA is also being used for XLA. In the terminal state,
  // it probably should get its own Autocast key
  AutocastCUDA,

  // ~~~~~~~~~~~~~~~~~~~~~~~~~~~ WRAPPERS ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ //
  // There are a number of alternative modes which may want to handle before
  // autograd; for example, error checking, tracing, profiling or vmap. They
  // go here.

  FuncTorchBatched, // See Note [Out-of-tree vmap+grad prototype]
  FuncTorchVmapMode, // See Note [Out-of-tree vmap+grad prototype]

  // This is the dispatch key for BatchedTensorImpl, which is used to implement
  // batching rules for vmap.
  Batched,

  // When we are inside a vmap, all tensors dispatch on this key.
  // See Note: [DispatchKey::VmapMode usage] for more details.
  VmapMode,

  FuncTorchGradWrapper, // See Note [Out-of-tree vmap+grad prototype]

  // Out-of-core key for Deferred Module Initialization in torchdistx.
  // See https://pytorch.org/torchdistx/latest/deferred_init.html
  DeferredInit,

  // Used by Python key logic to know the set of tls on entry to the dispatcher
  // This kernel assumes it is the top-most non-functorch-related DispatchKey.
  // If you add a key above, make sure to update the fallback implementation for
  // this.
  PythonTLSSnapshot,

  // This key should be at the very top of the dispatcher
  FuncTorchDynamicLayerFrontMode, // See Note [Out-of-tree vmap+grad prototype]

  // TESTING: This is intended to be a generic testing tensor type id.
  // Don't use it for anything real; its only acceptable use is within a single
  // process test. Use it by creating a TensorImpl with this DispatchKey, and
  // then registering operators to operate on this type id. See
  // aten/src/ATen/core/dispatch/backend_fallback_test.cpp for a usage example.
  TESTING_ONLY_GenericWrapper,

  // TESTING: This is intended to be a generic testing tensor type id.
  // Don't use it for anything real; its only acceptable use is within a single
  // process test. Use it by toggling the mode on and off via
  // TESTING_ONLY_tls_generic_mode_set_enabled and then registering operators
  // to operate on this type id. See
  // aten/src/ATen/core/dispatch/backend_fallback_test.cpp
  // for a usage example
  TESTING_ONLY_GenericMode,

  // This is a bypass that allows you to skip running the C++ dispatcher
  // entirely
  PythonDispatcher,

  // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ FIN ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ //
  EndOfFunctionalityKeys, // End of functionality keys.

  // ~~~~~~~~~~~~~~ "Dense" Per-Backend Dispatch keys ~~~~~~~~~~~~~~~~~~~~ //
  // Here are backends which you think of as traditionally specifying
  // how to implement operations on some device.

  // The two macros below expand, for each per-backend functionality F in
  // C10_FORALL_FUNCTIONALITY_KEYS (Dense, Quantized, Sparse, NestedTensor,
  // AutogradFunctionality), into:
  //   StartOfFBackends, <prefix>CPU, <prefix>CUDA, ..., <prefix>PrivateUse3,
  //   EndOfFBackends = <prefix>PrivateUse3,
  // i.e. one runtime key per (functionality, backend) pair, ordered exactly
  // like BackendComponent within each functionality group.

#define DEFINE_PER_BACKEND_KEYS_FOR_BACKEND(n, prefix) prefix##n,

#define DEFINE_PER_BACKEND_KEYS(fullname, prefix)      \
  StartOf##fullname##Backends,                         \
      C10_FORALL_BACKEND_COMPONENTS(                   \
          DEFINE_PER_BACKEND_KEYS_FOR_BACKEND, prefix) \
          EndOf##fullname##Backends = prefix##PrivateUse3,

  C10_FORALL_FUNCTIONALITY_KEYS(DEFINE_PER_BACKEND_KEYS)

#undef DEFINE_PER_BACKEND_KEYS
#undef DEFINE_PER_BACKEND_KEYS_FOR_BACKEND

  EndOfRuntimeBackendKeys = EndOfAutogradFunctionalityBackends,

  // ~~~~~~~~~~~~~~~~~~~~~~ Alias Dispatch Keys ~~~~~~~~~~~~~~~~~~~~~~~~~~ //
  // Note [Alias Dispatch Keys]
  // Alias dispatch keys are synthetic dispatch keys which map to multiple
  // runtime dispatch keys. Alias keys have precedence, but they are always
  // lower precedence than runtime keys. You can register a kernel to an
  // alias key, the kernel might be populated to the mapped runtime keys
  // during dispatch table computation.
  // If a runtime dispatch key has multiple kernels from alias keys, which
  // kernel wins is done based on the precedence of alias keys (but runtime
  // keys always have precedence over alias keys).
  // Alias keys won't be directly called during runtime.

  // See Note [Alias Dispatch Key : Autograd]
  Autograd,
  CompositeImplicitAutograd, // registered at
  // build/aten/src/ATen/RegisterCompositeImplicitAutograd.cpp

  // Note: The alias keyset for FuncTorchBatchedDecomposition is disjoint from
  // all
  // other alias keysets
  // and so precedence order doesn't matter
  FuncTorchBatchedDecomposition, // registered at
  // build/aten/src/ATen/RegisterFuncTorchBatchedDecomposition.cpp
  // Note: The alias keyset for CompositeImplicitAutogradNestedTensor is
  // disjoint from all other alias keysets
  CompositeImplicitAutogradNestedTensor, // registered at
  // build/aten/src/ATen/RegisterCompositeImplicitAutogradNestedTensor.cpp
  CompositeExplicitAutograd, // registered at
  // build/aten/src/ATen/RegisterCompositeExplicitAutograd.cpp
  // See Note [CompositeExplicitAutogradNonFunctional Key]
  CompositeExplicitAutogradNonFunctional, // registered at
  // build/aten/src/ATen/RegisterCompositeExplicitAutograd.cpp

  // Define an alias key to represent end of alias dispatch keys.
  // If you add new alias keys after Autograd, please also update it here.
  StartOfAliasKeys = Autograd,
  EndOfAliasKeys = CompositeExplicitAutogradNonFunctional, //

  // ~~~~~~~~~~~~~~~~~~~~~~~~~ BC ALIASES ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ //
  // The aliases exist for backwards compatibility reasons, they shouldn't
  // be used
  CPUTensorId = CPU,
  CUDATensorId = CUDA,
  DefaultBackend = CompositeExplicitAutograd,
  PrivateUse1_PreAutograd = AutogradPrivateUse1,
  PrivateUse2_PreAutograd = AutogradPrivateUse2,
  PrivateUse3_PreAutograd = AutogradPrivateUse3,
  Autocast = AutocastCUDA,
};
482 | |
483 | // Note [Private use DispatchKey] |
484 | // ~~~~~~~~~~~~~~~~~~~~~~~~~~~ |
485 | // Private use tensor IDs are preallocated tensor type IDs for use in user |
486 | // applications. Similar to private use fields in HTTP, they can be used |
487 | // by end users for experimental or private applications, without needing |
488 | // to "standardize" the tensor ID (which would be done by submitting a PR |
489 | // to PyTorch to add your type ID). |
490 | // |
491 | // Private use tensor IDs are appropriate to use if you want to experiment |
492 | // with adding a new tensor type (without having to patch PyTorch first) or |
493 | // have a private, non-distributed application that needs to make use of a |
494 | // new tensor type. Private use tensor IDs are NOT appropriate to use for |
495 | // libraries intended to be distributed to further users: please contact |
496 | // the PyTorch developers to get a type ID registered in this case. |
497 | // |
498 | // We provide two classes of private user tensor id: regular DispatchKeys |
499 | // and Autograd DispatchKeys. DispatchKeys serve the role of ordinary "backend" |
500 | // DispatchKeys; if you were adding support for a new type of accelerator, you |
501 | // would use a backend DispatchKey, and ideally automatically reuse |
502 | // AutogradOther definitions already defined in PyTorch. AutogradPrivateUse |
503 | // DispatchKeys serve as "wrapper" DispatchKeys: they are only necessary for |
504 | // tensors that compose multiple internal tensors, and for cases when the |
505 | // built-in autograd formulas for operators are not appropriate. |
506 | |
507 | static_assert( |
508 | (static_cast<uint8_t>(BackendComponent::EndOfBackendKeys) + |
509 | static_cast<uint8_t>(DispatchKey::EndOfFunctionalityKeys)) <= 64, |
510 | "The BackendComponent and DispatchKey enums (below EndOfFunctionalityKeys)" |
511 | " both map to backend and functionality bits" |
512 | " into a 64-bit bitmask; you must have less than 64 total entries between them" ); |
513 | |
514 | // Check if a DispatchKey is an alias mapping to other runtime keys. |
515 | constexpr bool isAliasDispatchKey(DispatchKey k) { |
516 | return k >= DispatchKey::StartOfAliasKeys && k <= DispatchKey::EndOfAliasKeys; |
517 | } |
518 | |
519 | // [Note: Per-Backend Functionality Dispatch Keys] |
520 | // Check if a DispatchKey is a per-backend functionality key |
521 | // Any functionalities that can be customized per-backend should be added here. |
// These keys correspond to functionalities that can be customized individually
523 | // per backend. While they only take up one bit in the `DispatchKeySet` bitset, |
524 | // they map to (# backends) slots in the operator table. |
525 | // Each of these keys also has a separate set of "runtime keys" in the dispatch |
526 | // key enum, per backend, which *do* map to the individual operator table slots. |
527 | // For example, the "Sparse" key maps to an individual bit in the |
528 | // DispatchKeySet, while `SparseCPU`, `SparseCUDA`, etc all map to individual |
529 | // slots in the runtime operator table. |
530 | |
531 | constexpr bool isPerBackendFunctionalityKey(DispatchKey k) { |
532 | if (k == DispatchKey::Dense || k == DispatchKey::Quantized || |
533 | k == DispatchKey::Sparse || k == DispatchKey::AutogradFunctionality || |
534 | k == DispatchKey::NestedTensor) { |
535 | return true; |
536 | } else { |
537 | return false; |
538 | } |
539 | } |
540 | |
// Note that this includes Undefined in the total count.
// BUT EndOfFunctionalityKeys is its own (placeholder) key.
// e.g. Undefined=0, Dense=1, Sparse=2, EndOfFunctionalityKeys=3.
// In the above example, there are 3 total functionality keys.
constexpr uint8_t num_functionality_keys =
    static_cast<uint8_t>(DispatchKey::EndOfFunctionalityKeys);

// Number of backend bits; EndOfBackendKeys aliases PrivateUse3Bit, the
// highest backend bit index (InvalidBit = 0 is excluded from the count).
constexpr uint8_t num_backends =
    static_cast<uint8_t>(BackendComponent::EndOfBackendKeys);

// Note [No More Than 16 Backends]
// Search for this note to find places in the code where the "no more than 16
// backends" invariant is baked in.
static_assert(
    static_cast<uint8_t>(BackendComponent::EndOfBackendKeys) <= 16,
    "BackendComponent currently only supports <= 16 backends. If we really need to extend this, \
there are a few places where this invariant is baked in" );
558 | |
559 | constexpr uint8_t numPerBackendFunctionalityKeys() { |
560 | uint8_t count = 0; |
561 | for (uint8_t k = 0; k <= num_functionality_keys; ++k) { |
562 | if (isPerBackendFunctionalityKey(static_cast<DispatchKey>(k))) |
563 | ++count; |
564 | } |
565 | return count; |
566 | } |
567 | |
#if defined(C10_MOBILE_TRIM_DISPATCH_KEYS)
// See [Note: Trimmed Mobile Dispatch Keys]
constexpr uint16_t num_runtime_entries = 8;
#else
// Total number of slots in the runtime operator table: one per functionality
// key, plus (num_backends - 1) extra slots for each per-backend
// functionality — the functionality key itself already accounts for one of
// its num_backends slots in num_functionality_keys.
constexpr uint16_t num_runtime_entries = num_functionality_keys +
    (numPerBackendFunctionalityKeys() * (num_backends - 1));
#endif

// See Note [No More Than 16 Backends]
// A mask with the low num_backends bits set; selects the backend-component
// bits of a DispatchKeySet.
constexpr uint16_t full_backend_mask =
    (static_cast<uint16_t>(1) << num_backends) - 1;
579 | |
// Human-readable names for debugging/logging (defined out of line).
C10_API const char* toString(DispatchKey);
C10_API const char* toString(BackendComponent);
C10_API std::ostream& operator<<(std::ostream&, DispatchKey);
C10_API std::ostream& operator<<(std::ostream&, BackendComponent);

// Returns the autograd dispatch key corresponding to a backend component
// (defined out of line).
C10_API DispatchKey getAutogradKeyFromBackend(BackendComponent k);

// Parses a string into a dispatch key.
// If the string cannot be correctly parsed, throws an exception.
C10_API c10::DispatchKey parseDispatchKey(const std::string& k);

// These are some convenience identifiers for dispatch keys which are
// shorter to type than their long counterparts. Note that some of these
// dispatch keys directly correspond to DeviceType; and most APIs that
// accept DispatchKey also accept DeviceType; e.g.,
// torch::dispatch(torch::kCPU, ...) is also valid.
constexpr DispatchKey kAutograd = DispatchKey::Autograd;
597 | |
598 | // See Note [The Ordering of Per-Backend Dispatch Keys Matters!] |
599 | // This function relies on the invariant that the dispatch keys between |
600 | // StartOfDenseBackends and EndOfRuntimeBackendKeys are ordered by backend |
601 | // in the same order as `BackendComponent`. |
602 | constexpr BackendComponent toBackendComponent(DispatchKey k) { |
603 | if (k >= DispatchKey::StartOfDenseBackends && |
604 | k <= DispatchKey::EndOfDenseBackends) { |
605 | return static_cast<BackendComponent>( |
606 | static_cast<uint8_t>(k) - |
607 | static_cast<uint8_t>(DispatchKey::StartOfDenseBackends)); |
608 | } else if ( |
609 | k >= DispatchKey::StartOfQuantizedBackends && |
610 | k <= DispatchKey::EndOfQuantizedBackends) { |
611 | return static_cast<BackendComponent>( |
612 | static_cast<uint8_t>(k) - |
613 | static_cast<uint8_t>(DispatchKey::StartOfQuantizedBackends)); |
614 | } else if ( |
615 | k >= DispatchKey::StartOfSparseBackends && |
616 | k <= DispatchKey::EndOfSparseBackends) { |
617 | return static_cast<BackendComponent>( |
618 | static_cast<uint8_t>(k) - |
619 | static_cast<uint8_t>(DispatchKey::StartOfSparseBackends)); |
620 | } else if ( |
621 | k >= DispatchKey::StartOfNestedTensorBackends && |
622 | k <= DispatchKey::EndOfNestedTensorBackends) { |
623 | return static_cast<BackendComponent>( |
624 | static_cast<uint8_t>(k) - |
625 | static_cast<uint8_t>(DispatchKey::StartOfNestedTensorBackends)); |
626 | } else if ( |
627 | k >= DispatchKey::StartOfAutogradFunctionalityBackends && |
628 | k <= DispatchKey::EndOfAutogradFunctionalityBackends) { |
629 | return static_cast<BackendComponent>( |
630 | static_cast<uint8_t>(k) - |
631 | static_cast<uint8_t>( |
632 | DispatchKey::StartOfAutogradFunctionalityBackends)); |
633 | } else { |
634 | return BackendComponent::InvalidBit; |
635 | } |
636 | } |
637 | |
638 | constexpr DispatchKey toFunctionalityKey(DispatchKey k) { |
639 | if (k <= DispatchKey::EndOfFunctionalityKeys) { |
640 | return k; |
641 | } else if (k <= DispatchKey::EndOfDenseBackends) { |
642 | return DispatchKey::Dense; |
643 | } else if (k <= DispatchKey::EndOfQuantizedBackends) { |
644 | return DispatchKey::Quantized; |
645 | } else if (k <= DispatchKey::EndOfSparseBackends) { |
646 | return DispatchKey::Sparse; |
647 | } else if (k <= DispatchKey::EndOfNestedTensorBackends) { |
648 | return DispatchKey::NestedTensor; |
649 | } else if (k <= DispatchKey::EndOfAutogradFunctionalityBackends) { |
650 | return DispatchKey::AutogradFunctionality; |
651 | } else { |
652 | return DispatchKey::Undefined; |
653 | } |
654 | } |
655 | |
// Maps a DeviceType to its BackendComponent bit (defined out of line).
BackendComponent toBackendComponent(DeviceType device_type);
657 | |
658 | // Given (DispatchKey::Dense, BackendComponent::CUDABit), returns |
659 | // DispatchKey::CUDA. |
660 | // See Note [The Ordering of Per-Backend Dispatch Keys Matters!] |
661 | // This function relies on the invariant that the dispatch keys between |
662 | // StartOfDenseBackends and EndOfRuntimeBackendKeys are ordered by backend |
663 | // in the same order as `BackendComponent`. |
664 | constexpr DispatchKey toRuntimePerBackendFunctionalityKey( |
665 | DispatchKey functionality_k, |
666 | BackendComponent backend_k) { |
667 | if (functionality_k == DispatchKey::Dense) { |
668 | return static_cast<DispatchKey>( |
669 | static_cast<uint8_t>(DispatchKey::StartOfDenseBackends) + |
670 | static_cast<uint8_t>(backend_k)); |
671 | } |
672 | if (functionality_k == DispatchKey::Sparse) { |
673 | return static_cast<DispatchKey>( |
674 | static_cast<uint8_t>(DispatchKey::StartOfSparseBackends) + |
675 | static_cast<uint8_t>(backend_k)); |
676 | } |
677 | if (functionality_k == DispatchKey::Quantized) { |
678 | return static_cast<DispatchKey>( |
679 | static_cast<uint8_t>(DispatchKey::StartOfQuantizedBackends) + |
680 | static_cast<uint8_t>(backend_k)); |
681 | } |
682 | if (functionality_k == DispatchKey::NestedTensor) { |
683 | return static_cast<DispatchKey>( |
684 | static_cast<uint8_t>(DispatchKey::StartOfNestedTensorBackends) + |
685 | static_cast<uint8_t>(backend_k)); |
686 | } |
687 | if (functionality_k == DispatchKey::AutogradFunctionality) { |
688 | return static_cast<DispatchKey>( |
689 | static_cast<uint8_t>( |
690 | DispatchKey::StartOfAutogradFunctionalityBackends) + |
691 | static_cast<uint8_t>(backend_k)); |
692 | } |
693 | return DispatchKey::Undefined; |
694 | } |
695 | |
696 | } // namespace c10 |
697 | |
namespace torch {
// Expose the constant, but not the TYPE (DispatchKey is an implementation
// detail!)
// Only the kAutograd constant is re-exported; the DispatchKey type itself
// stays in c10.
using c10::kAutograd;
} // namespace torch
703 | |
704 | // NB: You really shouldn't use this instance; this enum is guaranteed |
705 | // to be pretty small so a regular array should be acceptable. |
706 | namespace std { |
707 | template <> |
708 | struct hash<c10::DispatchKey> { |
709 | typedef size_t result_type; |
710 | typedef c10::DispatchKey argument_type; |
711 | |
712 | size_t operator()(c10::DispatchKey x) const { |
713 | return static_cast<size_t>(x); |
714 | } |
715 | }; |
716 | } // namespace std |
717 | |