1 | #include <c10/core/DispatchKey.h> |
2 | #include <c10/core/DispatchKeySet.h> |
3 | |
4 | #include <unordered_map> |
5 | |
6 | namespace c10 { |
7 | |
8 | const char* toString(BackendComponent t) { |
9 | switch (t) { |
10 | case BackendComponent::CPUBit: |
11 | return "CPUBit" ; |
12 | case BackendComponent::CUDABit: |
13 | return "CUDABit" ; |
14 | case BackendComponent::HIPBit: |
15 | return "HIPBit" ; |
16 | case BackendComponent::XLABit: |
17 | return "XLABit" ; |
18 | case BackendComponent::LazyBit: |
19 | return "LazyBit" ; |
20 | case BackendComponent::MetaBit: |
21 | return "MetaBit" ; |
22 | case BackendComponent::XPUBit: |
23 | return "XPUBit" ; |
24 | case BackendComponent::IPUBit: |
25 | return "IPUBit" ; |
26 | case BackendComponent::MPSBit: |
27 | return "MPSBit" ; |
28 | case BackendComponent::HPUBit: |
29 | return "HPUBit" ; |
30 | case BackendComponent::VEBit: |
31 | return "VEBit" ; |
32 | case BackendComponent::MTIABit: |
33 | return "MTIA" ; |
34 | case BackendComponent::PrivateUse1Bit: |
35 | return "PrivateUse1Bit" ; |
36 | case BackendComponent::PrivateUse2Bit: |
37 | return "PrivateUse2Bit" ; |
38 | case BackendComponent::PrivateUse3Bit: |
39 | return "PrivateUse3Bit" ; |
40 | case BackendComponent::InvalidBit: |
41 | return "InvalidBit" ; |
42 | default: |
43 | return "UNKNOWN_BACKEND_BIT" ; |
44 | } |
45 | } |
46 | |
47 | BackendComponent toBackendComponent(DeviceType device_type) { |
48 | switch (device_type) { |
49 | #define DO_CASE(device, _) \ |
50 | case DeviceType::device: { \ |
51 | return toBackendComponent(DispatchKey::device); \ |
52 | } |
53 | C10_FORALL_BACKEND_DEVICE_TYPES(DO_CASE, unused) |
54 | #undef DO_CASE |
55 | default: |
56 | return BackendComponent::InvalidBit; |
57 | } |
58 | } |
59 | |
60 | const char* toString(DispatchKey t) { |
61 | switch (t) { |
62 | case DispatchKey::Undefined: |
63 | return "Undefined" ; |
64 | |
65 | case DispatchKey::Dense: |
66 | return "Dense" ; |
67 | case DispatchKey::FPGA: |
68 | return "FPGA" ; |
69 | case DispatchKey::ORT: |
70 | return "ORT" ; |
71 | case DispatchKey::Vulkan: |
72 | return "Vulkan" ; |
73 | case DispatchKey::Metal: |
74 | return "Metal" ; |
75 | |
76 | case DispatchKey::Lazy: |
77 | return "Lazy" ; |
78 | case DispatchKey::MPS: |
79 | return "MPS" ; |
80 | case DispatchKey::HPU: |
81 | return "HPU" ; |
82 | case DispatchKey::MTIA: |
83 | return "MTIA" ; |
84 | |
85 | case DispatchKey::Quantized: |
86 | return "Quantized" ; |
87 | case DispatchKey::CustomRNGKeyId: |
88 | return "CustomRNGKeyId" ; |
89 | case DispatchKey::MkldnnCPU: |
90 | return "MkldnnCPU" ; |
91 | |
92 | case DispatchKey::Sparse: |
93 | return "Sparse" ; |
94 | case DispatchKey::SparseCsrCPU: |
95 | return "SparseCsrCPU" ; |
96 | case DispatchKey::SparseCsrCUDA: |
97 | return "SparseCsrCUDA" ; |
98 | |
99 | case DispatchKey::NestedTensor: |
100 | return "NestedTensor" ; |
101 | |
102 | case DispatchKey::BackendSelect: |
103 | return "BackendSelect" ; |
104 | |
105 | case DispatchKey::Python: |
106 | return "Python" ; |
107 | |
108 | case DispatchKey::Fake: |
109 | return "Fake" ; |
110 | case DispatchKey::FuncTorchDynamicLayerBackMode: |
111 | return "FuncTorchDynamicLayerBackMode" ; |
112 | |
113 | case DispatchKey::Functionalize: |
114 | return "Functionalize" ; |
115 | |
116 | case DispatchKey::Named: |
117 | return "Named" ; |
118 | |
119 | case DispatchKey::Conjugate: |
120 | return "Conjugate" ; |
121 | case DispatchKey::Negative: |
122 | return "Negative" ; |
123 | case DispatchKey::ZeroTensor: |
124 | return "ZeroTensor" ; |
125 | |
126 | case DispatchKey::ADInplaceOrView: |
127 | return "ADInplaceOrView" ; |
128 | |
129 | case DispatchKey::AutogradOther: |
130 | return "AutogradOther" ; |
131 | case DispatchKey::AutogradFunctionality: |
132 | return "AutogradFunctionality" ; |
133 | case DispatchKey::AutogradNestedTensor: |
134 | return "AutogradNestedTensor" ; |
135 | |
136 | case DispatchKey::Tracer: |
137 | return "Tracer" ; |
138 | |
139 | case DispatchKey::AutocastCPU: |
140 | return "AutocastCPU" ; |
141 | case DispatchKey::AutocastXPU: |
142 | return "AutocastXPU" ; |
143 | case DispatchKey::AutocastHPU: |
144 | return "AutocastHPU" ; |
145 | case DispatchKey::AutocastCUDA: |
146 | return "AutocastCUDA" ; |
147 | |
148 | case DispatchKey::FuncTorchBatched: |
149 | return "FuncTorchBatched" ; |
150 | case DispatchKey::FuncTorchVmapMode: |
151 | return "FuncTorchVmapMode" ; |
152 | |
153 | case DispatchKey::Batched: |
154 | return "Batched" ; |
155 | case DispatchKey::VmapMode: |
156 | return "VmapMode" ; |
157 | |
158 | case DispatchKey::FuncTorchGradWrapper: |
159 | return "FuncTorchGradWrapper" ; |
160 | |
161 | case DispatchKey::DeferredInit: |
162 | return "DeferredInit" ; |
163 | case DispatchKey::PythonTLSSnapshot: |
164 | return "PythonTLSSnapshot" ; |
165 | |
166 | // Note [Out-of-tree vmap+grad prototype] |
167 | // The following keys are used in the implementation of the out-of-tree |
168 | // composable functions transforms (vmap+grad) prototype that lives at |
169 | // https://github.com/zou3519/functorch |
170 | // We plan on eventually upstreaming the prototype into core, at which |
171 | // point it will have a different design that should use fewer keys. |
172 | case DispatchKey::FuncTorchDynamicLayerFrontMode: |
173 | return "FuncTorchDynamicLayerFrontMode" ; |
174 | |
175 | case DispatchKey::TESTING_ONLY_GenericWrapper: |
176 | return "TESTING_ONLY_GenericWrapper" ; |
177 | |
178 | case DispatchKey::TESTING_ONLY_GenericMode: |
179 | return "TESTING_ONLY_GenericMode" ; |
180 | |
181 | case DispatchKey::PythonDispatcher: |
182 | return "PythonDispatcher" ; |
183 | |
184 | // Aliases |
185 | |
186 | case DispatchKey::Autograd: |
187 | return "Autograd" ; |
188 | case DispatchKey::CompositeImplicitAutograd: |
189 | return "CompositeImplicitAutograd" ; |
190 | case DispatchKey::CompositeImplicitAutogradNestedTensor: |
191 | return "CompositeImplicitAutogradNestedTensor" ; |
192 | case DispatchKey::CompositeExplicitAutograd: |
193 | return "CompositeExplicitAutograd" ; |
194 | case DispatchKey::CompositeExplicitAutogradNonFunctional: |
195 | return "CompositeExplicitAutogradNonFunctional" ; |
196 | case DispatchKey::FuncTorchBatchedDecomposition: |
197 | return "FuncTorchBatchedDecomposition" ; |
198 | |
199 | // Per-backend dispatch keys |
200 | |
201 | default: |
202 | auto bc = toBackendComponent(t); |
203 | auto fk = toFunctionalityKey(t); |
204 | |
205 | switch (fk) { |
206 | #define ENTRY(backend, functionality) \ |
207 | case BackendComponent::backend##Bit: \ |
208 | return #functionality #backend; |
209 | |
210 | #define FORALL_BC(dkname, prefix) \ |
211 | case DispatchKey::dkname: \ |
212 | switch (bc) { \ |
213 | C10_FORALL_BACKEND_COMPONENTS(ENTRY, prefix) \ |
214 | default: \ |
215 | return #prefix "Undefined"; \ |
216 | } |
217 | |
218 | C10_FORALL_FUNCTIONALITY_KEYS(FORALL_BC) |
219 | |
220 | default: |
221 | switch (bc) { |
222 | C10_FORALL_BACKEND_COMPONENTS(ENTRY, Unknown) |
223 | default: |
224 | return "UnknownUnknown" ; |
225 | } |
226 | |
227 | #undef FORALL_BC |
228 | #undef ENTRY |
229 | } |
230 | } |
231 | } |
232 | |
233 | std::ostream& operator<<(std::ostream& str, DispatchKey rhs) { |
234 | return str << toString(rhs); |
235 | } |
236 | std::ostream& operator<<(std::ostream& str, BackendComponent rhs) { |
237 | return str << toString(rhs); |
238 | } |
239 | |
240 | DispatchKey getAutogradKeyFromBackend(BackendComponent k) { |
241 | // We want this to return an autograd key. We're relying on the fact that |
242 | // getAutogradRelatedKeySetFromBackend returns an autograd key + |
243 | // ADInplaceOrView, and autograd has higher precedence. The core mapping from |
244 | // backend -> autograd key lives in `getAutogradRelatedKeySetFromBackend` |
245 | // instead of here for performance. `getAutogradRelatedKeySetFromBackend` is a |
246 | // hotpath function, and we want to make sure that it doesn't have to |
247 | // construct any DispatchKeySets at runtime. |
248 | return getAutogradRelatedKeySetFromBackend(k).highestPriorityTypeId(); |
249 | } |
250 | |
251 | c10::DispatchKey parseDispatchKey(const std::string& k) { |
252 | static std::unordered_map<std::string, c10::DispatchKey> key_map = { |
253 | {"Undefined" , c10::DispatchKey::Undefined}, |
254 | {"Dense" , c10::DispatchKey::Dense}, |
255 | {"FPGA" , c10::DispatchKey::FPGA}, |
256 | {"ORT" , c10::DispatchKey::ORT}, |
257 | {"MPS" , c10::DispatchKey::MPS}, |
258 | {"Vulkan" , c10::DispatchKey::Vulkan}, |
259 | {"Metal" , c10::DispatchKey::Metal}, |
260 | {"VE" , c10::DispatchKey::VE}, |
261 | {"Meta" , c10::DispatchKey::Meta}, |
262 | {"Quantized" , c10::DispatchKey::Quantized}, |
263 | {"CustomRNGKeyId" , c10::DispatchKey::CustomRNGKeyId}, |
264 | {"MkldnnCPU" , c10::DispatchKey::MkldnnCPU}, |
265 | {"Sparse" , c10::DispatchKey::Sparse}, |
266 | {"SparseCsrCPU" , c10::DispatchKey::SparseCsrCPU}, |
267 | {"SparseCsrCUDA" , c10::DispatchKey::SparseCsrCUDA}, |
268 | {"BackendSelect" , c10::DispatchKey::BackendSelect}, |
269 | {"Python" , c10::DispatchKey::Python}, |
270 | {"PythonTLSSnapshot" , c10::DispatchKey::PythonTLSSnapshot}, |
271 | {"Fake" , c10::DispatchKey::Fake}, |
272 | {"Named" , c10::DispatchKey::Named}, |
273 | {"Conjugate" , c10::DispatchKey::Conjugate}, |
274 | {"Negative" , c10::DispatchKey::Negative}, |
275 | {"ZeroTensor" , c10::DispatchKey::ZeroTensor}, |
276 | {"FuncTorchDynamicLayerBackMode" , |
277 | c10::DispatchKey::FuncTorchDynamicLayerBackMode}, |
278 | {"Functionalize" , c10::DispatchKey::Functionalize}, |
279 | {"ADInplaceOrView" , c10::DispatchKey::ADInplaceOrView}, |
280 | {"AutogradOther" , c10::DispatchKey::AutogradOther}, |
281 | {"AutogradFunctionality" , c10::DispatchKey::AutogradFunctionality}, |
282 | {"AutogradNestedTensor" , c10::DispatchKey::AutogradNestedTensor}, |
283 | {"Tracer" , c10::DispatchKey::Tracer}, |
284 | {"AutocastCPU" , c10::DispatchKey::AutocastCPU}, |
285 | {"AutocastXPU" , c10::DispatchKey::AutocastXPU}, |
286 | {"AutocastHPU" , c10::DispatchKey::AutocastHPU}, |
287 | {"AutocastCUDA" , c10::DispatchKey::AutocastCUDA}, |
288 | {"FuncTorchBatched" , c10::DispatchKey::FuncTorchBatched}, |
289 | {"FuncTorchVmapMode" , c10::DispatchKey::FuncTorchVmapMode}, |
290 | {"Batched" , c10::DispatchKey::Batched}, |
291 | {"VmapMode" , c10::DispatchKey::VmapMode}, |
292 | {"DeferredInit" , c10::DispatchKey::DeferredInit}, |
293 | {"FuncTorchGradWrapper" , c10::DispatchKey::FuncTorchGradWrapper}, |
294 | {"FuncTorchDynamicLayerFrontMode" , |
295 | c10::DispatchKey::FuncTorchDynamicLayerFrontMode}, |
296 | {"TESTING_ONLY_GenericWrapper" , |
297 | c10::DispatchKey::TESTING_ONLY_GenericWrapper}, |
298 | {"TESTING_ONLY_GenericMode" , c10::DispatchKey::TESTING_ONLY_GenericMode}, |
299 | {"PythonDispatcher" , c10::DispatchKey::PythonDispatcher}, |
300 | |
301 | {"CPU" , c10::DispatchKey::CPU}, |
302 | {"CUDA" , c10::DispatchKey::CUDA}, |
303 | {"HIP" , c10::DispatchKey::HIP}, |
304 | {"XLA" , c10::DispatchKey::XLA}, |
305 | {"MPS" , c10::DispatchKey::MPS}, |
306 | {"XPU" , c10::DispatchKey::XPU}, |
307 | {"IPU" , c10::DispatchKey::IPU}, |
308 | {"HPU" , c10::DispatchKey::HPU}, |
309 | {"Lazy" , c10::DispatchKey::Lazy}, |
310 | {"MTIA" , c10::DispatchKey::MTIA}, |
311 | {"NestedTensor" , c10::DispatchKey::NestedTensor}, |
312 | {"NestedTensorCPU" , c10::DispatchKey::NestedTensorCPU}, |
313 | {"NestedTensorCUDA" , c10::DispatchKey::NestedTensorCUDA}, |
314 | {"NestedTensorMeta" , c10::DispatchKey::NestedTensorMeta}, |
315 | {"PrivateUse1" , c10::DispatchKey::PrivateUse1}, |
316 | {"PrivateUse2" , c10::DispatchKey::PrivateUse2}, |
317 | {"PrivateUse3" , c10::DispatchKey::PrivateUse3}, |
318 | |
319 | {"QuantizedCPU" , c10::DispatchKey::QuantizedCPU}, |
320 | {"QuantizedCUDA" , c10::DispatchKey::QuantizedCUDA}, |
321 | {"QuantizedXPU" , c10::DispatchKey::QuantizedXPU}, |
322 | |
323 | {"SparseCPU" , c10::DispatchKey::SparseCPU}, |
324 | {"SparseCUDA" , c10::DispatchKey::SparseCUDA}, |
325 | {"SparseHIP" , c10::DispatchKey::SparseHIP}, |
326 | {"SparseXPU" , c10::DispatchKey::SparseXPU}, |
327 | {"SparseVE" , c10::DispatchKey::SparseVE}, |
328 | {"SparseMeta" , c10::DispatchKey::SparseMeta}, |
329 | |
330 | {"AutogradCPU" , c10::DispatchKey::AutogradCPU}, |
331 | {"AutogradCUDA" , c10::DispatchKey::AutogradCUDA}, |
332 | {"AutogradXLA" , c10::DispatchKey::AutogradXLA}, |
333 | {"AutogradLazy" , c10::DispatchKey::AutogradLazy}, |
334 | {"AutogradMeta" , c10::DispatchKey::AutogradMeta}, |
335 | {"AutogradIPU" , c10::DispatchKey::AutogradIPU}, |
336 | {"AutogradXPU" , c10::DispatchKey::AutogradXPU}, |
337 | {"AutogradMPS" , c10::DispatchKey::AutogradMPS}, |
338 | {"AutogradHPU" , c10::DispatchKey::AutogradHPU}, |
339 | {"AutogradPrivateUse1" , c10::DispatchKey::AutogradPrivateUse1}, |
340 | {"AutogradPrivateUse2" , c10::DispatchKey::AutogradPrivateUse2}, |
341 | {"AutogradPrivateUse3" , c10::DispatchKey::AutogradPrivateUse3}, |
342 | |
343 | {"Autograd" , c10::DispatchKey::Autograd}, |
344 | {"CompositeImplicitAutograd" , |
345 | c10::DispatchKey::CompositeImplicitAutograd}, |
346 | {"CompositeImplicitAutogradNestedTensor" , |
347 | c10::DispatchKey::CompositeImplicitAutogradNestedTensor}, |
348 | {"CompositeExplicitAutograd" , |
349 | c10::DispatchKey::CompositeExplicitAutograd}, |
350 | {"CompositeExplicitAutogradNonFunctional" , |
351 | c10::DispatchKey::CompositeExplicitAutogradNonFunctional}, |
352 | {"FuncTorchBatchedDecomposition" , |
353 | c10::DispatchKey::FuncTorchBatchedDecomposition}, |
354 | }; |
355 | auto it = key_map.find(k); |
356 | TORCH_CHECK(it != key_map.end(), "could not parse dispatch key: " , k); |
357 | return it->second; |
358 | } |
359 | |
360 | } // namespace c10 |
361 | |