1#include <c10/core/DispatchKey.h>
2#include <c10/core/DispatchKeySet.h>
3
4#include <unordered_map>
5
6namespace c10 {
7
8const char* toString(BackendComponent t) {
9 switch (t) {
10 case BackendComponent::CPUBit:
11 return "CPUBit";
12 case BackendComponent::CUDABit:
13 return "CUDABit";
14 case BackendComponent::HIPBit:
15 return "HIPBit";
16 case BackendComponent::XLABit:
17 return "XLABit";
18 case BackendComponent::LazyBit:
19 return "LazyBit";
20 case BackendComponent::MetaBit:
21 return "MetaBit";
22 case BackendComponent::XPUBit:
23 return "XPUBit";
24 case BackendComponent::IPUBit:
25 return "IPUBit";
26 case BackendComponent::MPSBit:
27 return "MPSBit";
28 case BackendComponent::HPUBit:
29 return "HPUBit";
30 case BackendComponent::VEBit:
31 return "VEBit";
32 case BackendComponent::MTIABit:
33 return "MTIA";
34 case BackendComponent::PrivateUse1Bit:
35 return "PrivateUse1Bit";
36 case BackendComponent::PrivateUse2Bit:
37 return "PrivateUse2Bit";
38 case BackendComponent::PrivateUse3Bit:
39 return "PrivateUse3Bit";
40 case BackendComponent::InvalidBit:
41 return "InvalidBit";
42 default:
43 return "UNKNOWN_BACKEND_BIT";
44 }
45}
46
47BackendComponent toBackendComponent(DeviceType device_type) {
48 switch (device_type) {
49#define DO_CASE(device, _) \
50 case DeviceType::device: { \
51 return toBackendComponent(DispatchKey::device); \
52 }
53 C10_FORALL_BACKEND_DEVICE_TYPES(DO_CASE, unused)
54#undef DO_CASE
55 default:
56 return BackendComponent::InvalidBit;
57 }
58}
59
60const char* toString(DispatchKey t) {
61 switch (t) {
62 case DispatchKey::Undefined:
63 return "Undefined";
64
65 case DispatchKey::Dense:
66 return "Dense";
67 case DispatchKey::FPGA:
68 return "FPGA";
69 case DispatchKey::ORT:
70 return "ORT";
71 case DispatchKey::Vulkan:
72 return "Vulkan";
73 case DispatchKey::Metal:
74 return "Metal";
75
76 case DispatchKey::Lazy:
77 return "Lazy";
78 case DispatchKey::MPS:
79 return "MPS";
80 case DispatchKey::HPU:
81 return "HPU";
82 case DispatchKey::MTIA:
83 return "MTIA";
84
85 case DispatchKey::Quantized:
86 return "Quantized";
87 case DispatchKey::CustomRNGKeyId:
88 return "CustomRNGKeyId";
89 case DispatchKey::MkldnnCPU:
90 return "MkldnnCPU";
91
92 case DispatchKey::Sparse:
93 return "Sparse";
94 case DispatchKey::SparseCsrCPU:
95 return "SparseCsrCPU";
96 case DispatchKey::SparseCsrCUDA:
97 return "SparseCsrCUDA";
98
99 case DispatchKey::NestedTensor:
100 return "NestedTensor";
101
102 case DispatchKey::BackendSelect:
103 return "BackendSelect";
104
105 case DispatchKey::Python:
106 return "Python";
107
108 case DispatchKey::Fake:
109 return "Fake";
110 case DispatchKey::FuncTorchDynamicLayerBackMode:
111 return "FuncTorchDynamicLayerBackMode";
112
113 case DispatchKey::Functionalize:
114 return "Functionalize";
115
116 case DispatchKey::Named:
117 return "Named";
118
119 case DispatchKey::Conjugate:
120 return "Conjugate";
121 case DispatchKey::Negative:
122 return "Negative";
123 case DispatchKey::ZeroTensor:
124 return "ZeroTensor";
125
126 case DispatchKey::ADInplaceOrView:
127 return "ADInplaceOrView";
128
129 case DispatchKey::AutogradOther:
130 return "AutogradOther";
131 case DispatchKey::AutogradFunctionality:
132 return "AutogradFunctionality";
133 case DispatchKey::AutogradNestedTensor:
134 return "AutogradNestedTensor";
135
136 case DispatchKey::Tracer:
137 return "Tracer";
138
139 case DispatchKey::AutocastCPU:
140 return "AutocastCPU";
141 case DispatchKey::AutocastXPU:
142 return "AutocastXPU";
143 case DispatchKey::AutocastHPU:
144 return "AutocastHPU";
145 case DispatchKey::AutocastCUDA:
146 return "AutocastCUDA";
147
148 case DispatchKey::FuncTorchBatched:
149 return "FuncTorchBatched";
150 case DispatchKey::FuncTorchVmapMode:
151 return "FuncTorchVmapMode";
152
153 case DispatchKey::Batched:
154 return "Batched";
155 case DispatchKey::VmapMode:
156 return "VmapMode";
157
158 case DispatchKey::FuncTorchGradWrapper:
159 return "FuncTorchGradWrapper";
160
161 case DispatchKey::DeferredInit:
162 return "DeferredInit";
163 case DispatchKey::PythonTLSSnapshot:
164 return "PythonTLSSnapshot";
165
166 // Note [Out-of-tree vmap+grad prototype]
167 // The following keys are used in the implementation of the out-of-tree
168 // composable functions transforms (vmap+grad) prototype that lives at
169 // https://github.com/zou3519/functorch
170 // We plan on eventually upstreaming the prototype into core, at which
171 // point it will have a different design that should use fewer keys.
172 case DispatchKey::FuncTorchDynamicLayerFrontMode:
173 return "FuncTorchDynamicLayerFrontMode";
174
175 case DispatchKey::TESTING_ONLY_GenericWrapper:
176 return "TESTING_ONLY_GenericWrapper";
177
178 case DispatchKey::TESTING_ONLY_GenericMode:
179 return "TESTING_ONLY_GenericMode";
180
181 case DispatchKey::PythonDispatcher:
182 return "PythonDispatcher";
183
184 // Aliases
185
186 case DispatchKey::Autograd:
187 return "Autograd";
188 case DispatchKey::CompositeImplicitAutograd:
189 return "CompositeImplicitAutograd";
190 case DispatchKey::CompositeImplicitAutogradNestedTensor:
191 return "CompositeImplicitAutogradNestedTensor";
192 case DispatchKey::CompositeExplicitAutograd:
193 return "CompositeExplicitAutograd";
194 case DispatchKey::CompositeExplicitAutogradNonFunctional:
195 return "CompositeExplicitAutogradNonFunctional";
196 case DispatchKey::FuncTorchBatchedDecomposition:
197 return "FuncTorchBatchedDecomposition";
198
199 // Per-backend dispatch keys
200
201 default:
202 auto bc = toBackendComponent(t);
203 auto fk = toFunctionalityKey(t);
204
205 switch (fk) {
206#define ENTRY(backend, functionality) \
207 case BackendComponent::backend##Bit: \
208 return #functionality #backend;
209
210#define FORALL_BC(dkname, prefix) \
211 case DispatchKey::dkname: \
212 switch (bc) { \
213 C10_FORALL_BACKEND_COMPONENTS(ENTRY, prefix) \
214 default: \
215 return #prefix "Undefined"; \
216 }
217
218 C10_FORALL_FUNCTIONALITY_KEYS(FORALL_BC)
219
220 default:
221 switch (bc) {
222 C10_FORALL_BACKEND_COMPONENTS(ENTRY, Unknown)
223 default:
224 return "UnknownUnknown";
225 }
226
227#undef FORALL_BC
228#undef ENTRY
229 }
230 }
231}
232
233std::ostream& operator<<(std::ostream& str, DispatchKey rhs) {
234 return str << toString(rhs);
235}
236std::ostream& operator<<(std::ostream& str, BackendComponent rhs) {
237 return str << toString(rhs);
238}
239
240DispatchKey getAutogradKeyFromBackend(BackendComponent k) {
241 // We want this to return an autograd key. We're relying on the fact that
242 // getAutogradRelatedKeySetFromBackend returns an autograd key +
243 // ADInplaceOrView, and autograd has higher precedence. The core mapping from
244 // backend -> autograd key lives in `getAutogradRelatedKeySetFromBackend`
245 // instead of here for performance. `getAutogradRelatedKeySetFromBackend` is a
246 // hotpath function, and we want to make sure that it doesn't have to
247 // construct any DispatchKeySets at runtime.
248 return getAutogradRelatedKeySetFromBackend(k).highestPriorityTypeId();
249}
250
251c10::DispatchKey parseDispatchKey(const std::string& k) {
252 static std::unordered_map<std::string, c10::DispatchKey> key_map = {
253 {"Undefined", c10::DispatchKey::Undefined},
254 {"Dense", c10::DispatchKey::Dense},
255 {"FPGA", c10::DispatchKey::FPGA},
256 {"ORT", c10::DispatchKey::ORT},
257 {"MPS", c10::DispatchKey::MPS},
258 {"Vulkan", c10::DispatchKey::Vulkan},
259 {"Metal", c10::DispatchKey::Metal},
260 {"VE", c10::DispatchKey::VE},
261 {"Meta", c10::DispatchKey::Meta},
262 {"Quantized", c10::DispatchKey::Quantized},
263 {"CustomRNGKeyId", c10::DispatchKey::CustomRNGKeyId},
264 {"MkldnnCPU", c10::DispatchKey::MkldnnCPU},
265 {"Sparse", c10::DispatchKey::Sparse},
266 {"SparseCsrCPU", c10::DispatchKey::SparseCsrCPU},
267 {"SparseCsrCUDA", c10::DispatchKey::SparseCsrCUDA},
268 {"BackendSelect", c10::DispatchKey::BackendSelect},
269 {"Python", c10::DispatchKey::Python},
270 {"PythonTLSSnapshot", c10::DispatchKey::PythonTLSSnapshot},
271 {"Fake", c10::DispatchKey::Fake},
272 {"Named", c10::DispatchKey::Named},
273 {"Conjugate", c10::DispatchKey::Conjugate},
274 {"Negative", c10::DispatchKey::Negative},
275 {"ZeroTensor", c10::DispatchKey::ZeroTensor},
276 {"FuncTorchDynamicLayerBackMode",
277 c10::DispatchKey::FuncTorchDynamicLayerBackMode},
278 {"Functionalize", c10::DispatchKey::Functionalize},
279 {"ADInplaceOrView", c10::DispatchKey::ADInplaceOrView},
280 {"AutogradOther", c10::DispatchKey::AutogradOther},
281 {"AutogradFunctionality", c10::DispatchKey::AutogradFunctionality},
282 {"AutogradNestedTensor", c10::DispatchKey::AutogradNestedTensor},
283 {"Tracer", c10::DispatchKey::Tracer},
284 {"AutocastCPU", c10::DispatchKey::AutocastCPU},
285 {"AutocastXPU", c10::DispatchKey::AutocastXPU},
286 {"AutocastHPU", c10::DispatchKey::AutocastHPU},
287 {"AutocastCUDA", c10::DispatchKey::AutocastCUDA},
288 {"FuncTorchBatched", c10::DispatchKey::FuncTorchBatched},
289 {"FuncTorchVmapMode", c10::DispatchKey::FuncTorchVmapMode},
290 {"Batched", c10::DispatchKey::Batched},
291 {"VmapMode", c10::DispatchKey::VmapMode},
292 {"DeferredInit", c10::DispatchKey::DeferredInit},
293 {"FuncTorchGradWrapper", c10::DispatchKey::FuncTorchGradWrapper},
294 {"FuncTorchDynamicLayerFrontMode",
295 c10::DispatchKey::FuncTorchDynamicLayerFrontMode},
296 {"TESTING_ONLY_GenericWrapper",
297 c10::DispatchKey::TESTING_ONLY_GenericWrapper},
298 {"TESTING_ONLY_GenericMode", c10::DispatchKey::TESTING_ONLY_GenericMode},
299 {"PythonDispatcher", c10::DispatchKey::PythonDispatcher},
300
301 {"CPU", c10::DispatchKey::CPU},
302 {"CUDA", c10::DispatchKey::CUDA},
303 {"HIP", c10::DispatchKey::HIP},
304 {"XLA", c10::DispatchKey::XLA},
305 {"MPS", c10::DispatchKey::MPS},
306 {"XPU", c10::DispatchKey::XPU},
307 {"IPU", c10::DispatchKey::IPU},
308 {"HPU", c10::DispatchKey::HPU},
309 {"Lazy", c10::DispatchKey::Lazy},
310 {"MTIA", c10::DispatchKey::MTIA},
311 {"NestedTensor", c10::DispatchKey::NestedTensor},
312 {"NestedTensorCPU", c10::DispatchKey::NestedTensorCPU},
313 {"NestedTensorCUDA", c10::DispatchKey::NestedTensorCUDA},
314 {"NestedTensorMeta", c10::DispatchKey::NestedTensorMeta},
315 {"PrivateUse1", c10::DispatchKey::PrivateUse1},
316 {"PrivateUse2", c10::DispatchKey::PrivateUse2},
317 {"PrivateUse3", c10::DispatchKey::PrivateUse3},
318
319 {"QuantizedCPU", c10::DispatchKey::QuantizedCPU},
320 {"QuantizedCUDA", c10::DispatchKey::QuantizedCUDA},
321 {"QuantizedXPU", c10::DispatchKey::QuantizedXPU},
322
323 {"SparseCPU", c10::DispatchKey::SparseCPU},
324 {"SparseCUDA", c10::DispatchKey::SparseCUDA},
325 {"SparseHIP", c10::DispatchKey::SparseHIP},
326 {"SparseXPU", c10::DispatchKey::SparseXPU},
327 {"SparseVE", c10::DispatchKey::SparseVE},
328 {"SparseMeta", c10::DispatchKey::SparseMeta},
329
330 {"AutogradCPU", c10::DispatchKey::AutogradCPU},
331 {"AutogradCUDA", c10::DispatchKey::AutogradCUDA},
332 {"AutogradXLA", c10::DispatchKey::AutogradXLA},
333 {"AutogradLazy", c10::DispatchKey::AutogradLazy},
334 {"AutogradMeta", c10::DispatchKey::AutogradMeta},
335 {"AutogradIPU", c10::DispatchKey::AutogradIPU},
336 {"AutogradXPU", c10::DispatchKey::AutogradXPU},
337 {"AutogradMPS", c10::DispatchKey::AutogradMPS},
338 {"AutogradHPU", c10::DispatchKey::AutogradHPU},
339 {"AutogradPrivateUse1", c10::DispatchKey::AutogradPrivateUse1},
340 {"AutogradPrivateUse2", c10::DispatchKey::AutogradPrivateUse2},
341 {"AutogradPrivateUse3", c10::DispatchKey::AutogradPrivateUse3},
342
343 {"Autograd", c10::DispatchKey::Autograd},
344 {"CompositeImplicitAutograd",
345 c10::DispatchKey::CompositeImplicitAutograd},
346 {"CompositeImplicitAutogradNestedTensor",
347 c10::DispatchKey::CompositeImplicitAutogradNestedTensor},
348 {"CompositeExplicitAutograd",
349 c10::DispatchKey::CompositeExplicitAutograd},
350 {"CompositeExplicitAutogradNonFunctional",
351 c10::DispatchKey::CompositeExplicitAutogradNonFunctional},
352 {"FuncTorchBatchedDecomposition",
353 c10::DispatchKey::FuncTorchBatchedDecomposition},
354 };
355 auto it = key_map.find(k);
356 TORCH_CHECK(it != key_map.end(), "could not parse dispatch key: ", k);
357 return it->second;
358}
359
360} // namespace c10
361