1 | #pragma once |
2 | |
3 | #include <c10/core/DeviceType.h> |
4 | #include <c10/core/DispatchKey.h> |
5 | #include <c10/core/DispatchKeySet.h> |
6 | #include <c10/util/Exception.h> |
7 | |
8 | #include <stdexcept> |
9 | |
10 | namespace c10 { |
11 | |
12 | /** |
13 | * This legacy enum class defines the set of backends supported by old school, |
14 | * code generated Type-based ATen. A "backend" in this sense roughly |
15 | * corresponds to the cartesian product of (device type, layout), but restricted |
16 | * only to combinations which we actually have kernels for. Backend does NOT |
17 | * include dtype. |
18 | * |
19 | * The reason we are sunsetting this enum class is because it doesn't allow for |
20 | * open registration; e.g., if you want to add SparseXLA, you'd have to |
21 | * edit this enum; you wouldn't be able to do it out of tree. DispatchKey is |
22 | * the replacement for Backend which supports open registration. |
23 | * |
24 | * NB: The concept of 'Backend' here disagrees with the notion of backend |
25 | * exposed to users in torch.backends. Backend here is something like "CPU" |
26 | * or "SparseCUDA"; backend in torch.backends is something like "MKL" or |
27 | * "CUDNN". |
28 | */ |
enum class Backend {
  // Dense per-device backends.
  CPU,
  CUDA,
  HIP,
  VE,
  FPGA,
  IPU,
  XPU,
  // Sparse (COO) layouts, one per device that has sparse kernels.
  SparseCPU,
  SparseCUDA,
  // Sparse compressed (CSR) layouts.
  SparseCsrCPU,
  SparseCsrCUDA,
  SparseHIP,
  SparseVE,
  SparseXPU,
  ORT,
  XLA,
  Vulkan,
  Metal,
  Meta,
  // Quantized layouts.
  QuantizedCPU,
  QuantizedCUDA,
  QuantizedXPU,
  // Sentinel for "no backend" (e.g. an undefined tensor).
  Undefined,
  MkldnnCPU,
  MPS,
  HPU,
  Lazy,
  MTIA,
  // Out-of-tree / external backend slot; stringifies as "PrivateUseOne".
  PrivateUse1,
  // Count sentinel — must remain the last enumerator. NB: enumerator order
  // determines the implicit integer values, so do not reorder existing
  // entries; append new backends before NumOptions only.
  NumOptions
};
61 | |
62 | static inline Backend dispatchKeyToBackend(DispatchKey t) { |
63 | if (t == DispatchKey::CPU || t == DispatchKey::AutogradCPU) { |
64 | return Backend::CPU; |
65 | } else if (t == DispatchKey::CUDA || t == DispatchKey::AutogradCUDA) { |
66 | return Backend::CUDA; |
67 | } else if (t == DispatchKey::HIP) { |
68 | return Backend::HIP; |
69 | } else if (t == DispatchKey::VE) { |
70 | return Backend::VE; |
71 | } else if (t == DispatchKey::FPGA) { |
72 | return Backend::FPGA; |
73 | } else if (t == DispatchKey::ORT) { |
74 | return Backend::ORT; |
75 | } else if (t == DispatchKey::XLA || t == DispatchKey::AutogradXLA) { |
76 | return Backend::XLA; |
77 | } else if (t == DispatchKey::Lazy || t == DispatchKey::AutogradLazy) { |
78 | return Backend::Lazy; |
79 | } else if (t == DispatchKey::MPS || t == DispatchKey::AutogradMPS) { |
80 | return Backend::MPS; |
81 | } else if (t == DispatchKey::Vulkan) { |
82 | return Backend::Vulkan; |
83 | } else if (t == DispatchKey::Metal) { |
84 | return Backend::Metal; |
85 | } else if (t == DispatchKey::Meta) { |
86 | return Backend::Meta; |
87 | } else if (t == DispatchKey::SparseCPU) { |
88 | return Backend::SparseCPU; |
89 | } else if (t == DispatchKey::SparseCUDA) { |
90 | return Backend::SparseCUDA; |
91 | } else if (t == DispatchKey::SparseHIP) { |
92 | return Backend::SparseHIP; |
93 | } else if (t == DispatchKey::SparseVE) { |
94 | return Backend::SparseVE; |
95 | } else if (t == DispatchKey::SparseCsrCPU) { |
96 | return Backend::SparseCsrCPU; |
97 | } else if (t == DispatchKey::SparseCsrCUDA) { |
98 | return Backend::SparseCsrCUDA; |
99 | } else if (t == DispatchKey::MkldnnCPU) { |
100 | return Backend::MkldnnCPU; |
101 | } else if (t == DispatchKey::QuantizedCPU) { |
102 | return Backend::QuantizedCPU; |
103 | } else if (t == DispatchKey::QuantizedCUDA) { |
104 | return Backend::QuantizedCUDA; |
105 | } else if (t == DispatchKey::IPU || t == DispatchKey::AutogradIPU) { |
106 | return Backend::IPU; |
107 | } else if (t == DispatchKey::XPU || t == DispatchKey::AutogradXPU) { |
108 | return Backend::XPU; |
109 | } else if (t == DispatchKey::SparseXPU) { |
110 | return Backend::SparseXPU; |
111 | } else if (t == DispatchKey::QuantizedXPU) { |
112 | return Backend::QuantizedXPU; |
113 | } else if (t == DispatchKey::HPU || t == DispatchKey::AutogradHPU) { |
114 | return Backend::HPU; |
115 | } else if (t == DispatchKey::MTIA) { |
116 | return Backend::MTIA; |
117 | } else if (t == DispatchKey::PrivateUse1) { |
118 | return Backend::PrivateUse1; |
119 | } else if (t == DispatchKey::Undefined) { |
120 | return Backend::Undefined; |
121 | } else { |
122 | TORCH_CHECK(false, "Unrecognized tensor type ID: " , t); |
123 | } |
124 | } |
125 | |
126 | static inline DispatchKey backendToDispatchKey(Backend b) { |
127 | switch (b) { |
128 | case Backend::CPU: |
129 | return DispatchKey::CPU; |
130 | case Backend::CUDA: |
131 | return DispatchKey::CUDA; |
132 | case Backend::HIP: |
133 | return DispatchKey::HIP; |
134 | case Backend::VE: |
135 | return DispatchKey::VE; |
136 | case Backend::FPGA: |
137 | return DispatchKey::FPGA; |
138 | case Backend::ORT: |
139 | return DispatchKey::ORT; |
140 | case Backend::XLA: |
141 | return DispatchKey::XLA; |
142 | case Backend::Lazy: |
143 | return DispatchKey::Lazy; |
144 | case Backend::IPU: |
145 | return DispatchKey::IPU; |
146 | case Backend::XPU: |
147 | return DispatchKey::XPU; |
148 | case Backend::SparseXPU: |
149 | return DispatchKey::SparseXPU; |
150 | case Backend::SparseCPU: |
151 | return DispatchKey::SparseCPU; |
152 | case Backend::SparseCUDA: |
153 | return DispatchKey::SparseCUDA; |
154 | case Backend::SparseHIP: |
155 | return DispatchKey::SparseHIP; |
156 | case Backend::SparseVE: |
157 | return DispatchKey::SparseVE; |
158 | case Backend::SparseCsrCPU: |
159 | return DispatchKey::SparseCsrCPU; |
160 | case Backend::SparseCsrCUDA: |
161 | return DispatchKey::SparseCsrCUDA; |
162 | case Backend::MkldnnCPU: |
163 | return DispatchKey::MkldnnCPU; |
164 | case Backend::Vulkan: |
165 | return DispatchKey::Vulkan; |
166 | case Backend::Metal: |
167 | return DispatchKey::Metal; |
168 | case Backend::Meta: |
169 | return DispatchKey::Meta; |
170 | case Backend::QuantizedCPU: |
171 | return DispatchKey::QuantizedCPU; |
172 | case Backend::QuantizedCUDA: |
173 | return DispatchKey::QuantizedCUDA; |
174 | case Backend::Undefined: |
175 | return DispatchKey::Undefined; |
176 | case Backend::MPS: |
177 | return DispatchKey::MPS; |
178 | case Backend::HPU: |
179 | return DispatchKey::HPU; |
180 | case Backend::MTIA: |
181 | return DispatchKey::MTIA; |
182 | case Backend::PrivateUse1: |
183 | return DispatchKey::PrivateUse1; |
184 | default: |
185 | throw std::runtime_error("Unknown backend" ); |
186 | } |
187 | } |
188 | |
189 | static inline DeviceType backendToDeviceType(Backend b) { |
190 | switch (b) { |
191 | case Backend::CPU: |
192 | return DeviceType::CPU; |
193 | case Backend::CUDA: |
194 | return DeviceType::CUDA; |
195 | case Backend::HIP: |
196 | return DeviceType::HIP; |
197 | case Backend::VE: |
198 | return DeviceType::VE; |
199 | case Backend::FPGA: |
200 | return DeviceType::FPGA; |
201 | case Backend::ORT: |
202 | return DeviceType::ORT; |
203 | case Backend::XLA: |
204 | return DeviceType::XLA; |
205 | case Backend::Lazy: |
206 | return DeviceType::Lazy; |
207 | case Backend::SparseCPU: |
208 | return DeviceType::CPU; |
209 | case Backend::SparseCUDA: |
210 | return DeviceType::CUDA; |
211 | case Backend::SparseHIP: |
212 | return DeviceType::HIP; |
213 | case Backend::SparseVE: |
214 | return DeviceType::VE; |
215 | case Backend::SparseCsrCPU: |
216 | return DeviceType::CPU; |
217 | case Backend::SparseCsrCUDA: |
218 | return DeviceType::CUDA; |
219 | case Backend::IPU: |
220 | return DeviceType::IPU; |
221 | case Backend::XPU: |
222 | case Backend::SparseXPU: |
223 | case Backend::QuantizedXPU: |
224 | return DeviceType::XPU; |
225 | case Backend::MkldnnCPU: |
226 | case Backend::QuantizedCPU: |
227 | return DeviceType::CPU; |
228 | case Backend::QuantizedCUDA: |
229 | return DeviceType::CUDA; |
230 | case Backend::Vulkan: |
231 | return DeviceType::Vulkan; |
232 | case Backend::Metal: |
233 | return DeviceType::Metal; |
234 | case Backend::Meta: |
235 | return DeviceType::Meta; |
236 | case Backend::MPS: |
237 | return DeviceType::MPS; |
238 | case Backend::HPU: |
239 | return DeviceType::HPU; |
240 | case Backend::MTIA: |
241 | return DeviceType::MTIA; |
242 | case Backend::PrivateUse1: |
243 | return DeviceType::PrivateUse1; |
244 | case Backend::Undefined: |
245 | TORCH_CHECK(false, "Undefined backend is not a valid device type" ); |
246 | default: |
247 | TORCH_CHECK(false, "Unknown backend" ); |
248 | } |
249 | } |
250 | |
251 | // TODO: This probably shouldn't actually be static inline |
252 | static inline const char* toString(Backend b) { |
253 | switch (b) { |
254 | case Backend::CPU: |
255 | return "CPU" ; |
256 | case Backend::CUDA: |
257 | return "CUDA" ; |
258 | case Backend::HIP: |
259 | return "HIP" ; |
260 | case Backend::VE: |
261 | return "VE" ; |
262 | case Backend::FPGA: |
263 | return "FPGA" ; |
264 | case Backend::XPU: |
265 | return "XPU" ; |
266 | case Backend::IPU: |
267 | return "IPU" ; |
268 | case Backend::ORT: |
269 | return "ORT" ; |
270 | case Backend::XLA: |
271 | return "XLA" ; |
272 | case Backend::Lazy: |
273 | return "Lazy" ; |
274 | case Backend::MPS: |
275 | return "MPS" ; |
276 | case Backend::SparseCPU: |
277 | return "SparseCPU" ; |
278 | case Backend::SparseCUDA: |
279 | return "SparseCUDA" ; |
280 | case Backend::SparseHIP: |
281 | return "SparseHIP" ; |
282 | case Backend::SparseVE: |
283 | return "SparseVE" ; |
284 | case Backend::SparseXPU: |
285 | return "SparseXPU" ; |
286 | case Backend::SparseCsrCPU: |
287 | return "SparseCsrCPU" ; |
288 | case Backend::SparseCsrCUDA: |
289 | return "SparseCsrCUDA" ; |
290 | case Backend::MkldnnCPU: |
291 | return "MkldnnCPU" ; |
292 | case Backend::Vulkan: |
293 | return "Vulkan" ; |
294 | case Backend::Metal: |
295 | return "Metal" ; |
296 | case Backend::Meta: |
297 | return "Meta" ; |
298 | case Backend::QuantizedCPU: |
299 | return "QuantizedCPU" ; |
300 | case Backend::QuantizedCUDA: |
301 | return "QuantizedCUDA" ; |
302 | case Backend::QuantizedXPU: |
303 | return "QuantizedXPU" ; |
304 | case Backend::HPU: |
305 | return "HPU" ; |
306 | case Backend::MTIA: |
307 | return "MTIA" ; |
308 | case Backend::PrivateUse1: |
309 | return "PrivateUseOne" ; |
310 | default: |
311 | return "UNKNOWN_BACKEND" ; |
312 | } |
313 | } |
314 | |
315 | static inline bool isSparse(Backend b) { |
316 | switch (b) { |
317 | case Backend::SparseXPU: |
318 | case Backend::SparseCPU: |
319 | case Backend::SparseCUDA: |
320 | case Backend::SparseHIP: |
321 | case Backend::SparseVE: |
322 | return true; |
323 | default: |
324 | return false; |
325 | } |
326 | } |
327 | |
328 | static inline bool isSparseCsr(Backend b) { |
329 | switch (b) { |
330 | case Backend::SparseCsrCPU: |
331 | case Backend::SparseCsrCUDA: |
332 | return true; |
333 | default: |
334 | return false; |
335 | } |
336 | } |
337 | |
338 | } // namespace c10 |
339 | |