#pragma once

#include <c10/core/DeviceType.h>
#include <c10/core/DispatchKey.h>
#include <c10/core/DispatchKeySet.h>
#include <c10/util/Exception.h>

#include <stdexcept>

namespace c10 {

/**
 * This legacy enum class defines the set of backends supported by old school,
 * code generated Type-based ATen. A "backend" in this sense roughly
 * corresponds to the cartesian product of (device type, layout), but
 * restricted only to combinations which we actually have kernels for. Backend
 * does NOT include dtype.
 *
 * The reason we are sunsetting this enum class is because it doesn't allow for
 * open registration; e.g., if you want to add SparseXLA, you'd have to
 * edit this enum; you wouldn't be able to do it out of tree. DispatchKey is
 * the replacement for Backend which supports open registration.
 *
 * NB: The concept of 'Backend' here disagrees with the notion of backend
 * exposed to users in torch.backends. Backend here is something like "CPU"
 * or "SparseCUDA"; backend in torch.backends is something like "MKL" or
 * "CUDNN".
 */
enum class Backend {
  // NB: enumerator values are implicit (0, 1, 2, ...), so the order below
  // is significant: do not reorder existing entries, and keep NumOptions
  // last so it counts the enumerators above it.
  // Dense (strided) backends on the core device types.
  CPU,
  CUDA,
  HIP,
  VE,
  FPGA,
  IPU,
  XPU,
  // Sparse COO backends.
  SparseCPU,
  SparseCUDA,
  // Sparse compressed-row (CSR) backends.
  SparseCsrCPU,
  SparseCsrCUDA,
  // Further sparse COO backends.
  SparseHIP,
  SparseVE,
  SparseXPU,
  // Additional dense backends.
  ORT,
  XLA,
  Vulkan,
  Metal,
  Meta,
  // Quantized backends.
  QuantizedCPU,
  QuantizedCUDA,
  QuantizedXPU,
  // Sentinel meaning "no backend"; see dispatchKeyToBackend /
  // backendToDeviceType for how it is handled.
  Undefined,
  MkldnnCPU,
  MPS,
  HPU,
  Lazy,
  MTIA,
  // Reserved for out-of-tree / external backends.
  PrivateUse1,
  // Number of enumerators above; must remain last and is not a backend.
  NumOptions
};
61
62static inline Backend dispatchKeyToBackend(DispatchKey t) {
63 if (t == DispatchKey::CPU || t == DispatchKey::AutogradCPU) {
64 return Backend::CPU;
65 } else if (t == DispatchKey::CUDA || t == DispatchKey::AutogradCUDA) {
66 return Backend::CUDA;
67 } else if (t == DispatchKey::HIP) {
68 return Backend::HIP;
69 } else if (t == DispatchKey::VE) {
70 return Backend::VE;
71 } else if (t == DispatchKey::FPGA) {
72 return Backend::FPGA;
73 } else if (t == DispatchKey::ORT) {
74 return Backend::ORT;
75 } else if (t == DispatchKey::XLA || t == DispatchKey::AutogradXLA) {
76 return Backend::XLA;
77 } else if (t == DispatchKey::Lazy || t == DispatchKey::AutogradLazy) {
78 return Backend::Lazy;
79 } else if (t == DispatchKey::MPS || t == DispatchKey::AutogradMPS) {
80 return Backend::MPS;
81 } else if (t == DispatchKey::Vulkan) {
82 return Backend::Vulkan;
83 } else if (t == DispatchKey::Metal) {
84 return Backend::Metal;
85 } else if (t == DispatchKey::Meta) {
86 return Backend::Meta;
87 } else if (t == DispatchKey::SparseCPU) {
88 return Backend::SparseCPU;
89 } else if (t == DispatchKey::SparseCUDA) {
90 return Backend::SparseCUDA;
91 } else if (t == DispatchKey::SparseHIP) {
92 return Backend::SparseHIP;
93 } else if (t == DispatchKey::SparseVE) {
94 return Backend::SparseVE;
95 } else if (t == DispatchKey::SparseCsrCPU) {
96 return Backend::SparseCsrCPU;
97 } else if (t == DispatchKey::SparseCsrCUDA) {
98 return Backend::SparseCsrCUDA;
99 } else if (t == DispatchKey::MkldnnCPU) {
100 return Backend::MkldnnCPU;
101 } else if (t == DispatchKey::QuantizedCPU) {
102 return Backend::QuantizedCPU;
103 } else if (t == DispatchKey::QuantizedCUDA) {
104 return Backend::QuantizedCUDA;
105 } else if (t == DispatchKey::IPU || t == DispatchKey::AutogradIPU) {
106 return Backend::IPU;
107 } else if (t == DispatchKey::XPU || t == DispatchKey::AutogradXPU) {
108 return Backend::XPU;
109 } else if (t == DispatchKey::SparseXPU) {
110 return Backend::SparseXPU;
111 } else if (t == DispatchKey::QuantizedXPU) {
112 return Backend::QuantizedXPU;
113 } else if (t == DispatchKey::HPU || t == DispatchKey::AutogradHPU) {
114 return Backend::HPU;
115 } else if (t == DispatchKey::MTIA) {
116 return Backend::MTIA;
117 } else if (t == DispatchKey::PrivateUse1) {
118 return Backend::PrivateUse1;
119 } else if (t == DispatchKey::Undefined) {
120 return Backend::Undefined;
121 } else {
122 TORCH_CHECK(false, "Unrecognized tensor type ID: ", t);
123 }
124}
125
126static inline DispatchKey backendToDispatchKey(Backend b) {
127 switch (b) {
128 case Backend::CPU:
129 return DispatchKey::CPU;
130 case Backend::CUDA:
131 return DispatchKey::CUDA;
132 case Backend::HIP:
133 return DispatchKey::HIP;
134 case Backend::VE:
135 return DispatchKey::VE;
136 case Backend::FPGA:
137 return DispatchKey::FPGA;
138 case Backend::ORT:
139 return DispatchKey::ORT;
140 case Backend::XLA:
141 return DispatchKey::XLA;
142 case Backend::Lazy:
143 return DispatchKey::Lazy;
144 case Backend::IPU:
145 return DispatchKey::IPU;
146 case Backend::XPU:
147 return DispatchKey::XPU;
148 case Backend::SparseXPU:
149 return DispatchKey::SparseXPU;
150 case Backend::SparseCPU:
151 return DispatchKey::SparseCPU;
152 case Backend::SparseCUDA:
153 return DispatchKey::SparseCUDA;
154 case Backend::SparseHIP:
155 return DispatchKey::SparseHIP;
156 case Backend::SparseVE:
157 return DispatchKey::SparseVE;
158 case Backend::SparseCsrCPU:
159 return DispatchKey::SparseCsrCPU;
160 case Backend::SparseCsrCUDA:
161 return DispatchKey::SparseCsrCUDA;
162 case Backend::MkldnnCPU:
163 return DispatchKey::MkldnnCPU;
164 case Backend::Vulkan:
165 return DispatchKey::Vulkan;
166 case Backend::Metal:
167 return DispatchKey::Metal;
168 case Backend::Meta:
169 return DispatchKey::Meta;
170 case Backend::QuantizedCPU:
171 return DispatchKey::QuantizedCPU;
172 case Backend::QuantizedCUDA:
173 return DispatchKey::QuantizedCUDA;
174 case Backend::Undefined:
175 return DispatchKey::Undefined;
176 case Backend::MPS:
177 return DispatchKey::MPS;
178 case Backend::HPU:
179 return DispatchKey::HPU;
180 case Backend::MTIA:
181 return DispatchKey::MTIA;
182 case Backend::PrivateUse1:
183 return DispatchKey::PrivateUse1;
184 default:
185 throw std::runtime_error("Unknown backend");
186 }
187}
188
189static inline DeviceType backendToDeviceType(Backend b) {
190 switch (b) {
191 case Backend::CPU:
192 return DeviceType::CPU;
193 case Backend::CUDA:
194 return DeviceType::CUDA;
195 case Backend::HIP:
196 return DeviceType::HIP;
197 case Backend::VE:
198 return DeviceType::VE;
199 case Backend::FPGA:
200 return DeviceType::FPGA;
201 case Backend::ORT:
202 return DeviceType::ORT;
203 case Backend::XLA:
204 return DeviceType::XLA;
205 case Backend::Lazy:
206 return DeviceType::Lazy;
207 case Backend::SparseCPU:
208 return DeviceType::CPU;
209 case Backend::SparseCUDA:
210 return DeviceType::CUDA;
211 case Backend::SparseHIP:
212 return DeviceType::HIP;
213 case Backend::SparseVE:
214 return DeviceType::VE;
215 case Backend::SparseCsrCPU:
216 return DeviceType::CPU;
217 case Backend::SparseCsrCUDA:
218 return DeviceType::CUDA;
219 case Backend::IPU:
220 return DeviceType::IPU;
221 case Backend::XPU:
222 case Backend::SparseXPU:
223 case Backend::QuantizedXPU:
224 return DeviceType::XPU;
225 case Backend::MkldnnCPU:
226 case Backend::QuantizedCPU:
227 return DeviceType::CPU;
228 case Backend::QuantizedCUDA:
229 return DeviceType::CUDA;
230 case Backend::Vulkan:
231 return DeviceType::Vulkan;
232 case Backend::Metal:
233 return DeviceType::Metal;
234 case Backend::Meta:
235 return DeviceType::Meta;
236 case Backend::MPS:
237 return DeviceType::MPS;
238 case Backend::HPU:
239 return DeviceType::HPU;
240 case Backend::MTIA:
241 return DeviceType::MTIA;
242 case Backend::PrivateUse1:
243 return DeviceType::PrivateUse1;
244 case Backend::Undefined:
245 TORCH_CHECK(false, "Undefined backend is not a valid device type");
246 default:
247 TORCH_CHECK(false, "Unknown backend");
248 }
249}
250
251// TODO: This probably shouldn't actually be static inline
252static inline const char* toString(Backend b) {
253 switch (b) {
254 case Backend::CPU:
255 return "CPU";
256 case Backend::CUDA:
257 return "CUDA";
258 case Backend::HIP:
259 return "HIP";
260 case Backend::VE:
261 return "VE";
262 case Backend::FPGA:
263 return "FPGA";
264 case Backend::XPU:
265 return "XPU";
266 case Backend::IPU:
267 return "IPU";
268 case Backend::ORT:
269 return "ORT";
270 case Backend::XLA:
271 return "XLA";
272 case Backend::Lazy:
273 return "Lazy";
274 case Backend::MPS:
275 return "MPS";
276 case Backend::SparseCPU:
277 return "SparseCPU";
278 case Backend::SparseCUDA:
279 return "SparseCUDA";
280 case Backend::SparseHIP:
281 return "SparseHIP";
282 case Backend::SparseVE:
283 return "SparseVE";
284 case Backend::SparseXPU:
285 return "SparseXPU";
286 case Backend::SparseCsrCPU:
287 return "SparseCsrCPU";
288 case Backend::SparseCsrCUDA:
289 return "SparseCsrCUDA";
290 case Backend::MkldnnCPU:
291 return "MkldnnCPU";
292 case Backend::Vulkan:
293 return "Vulkan";
294 case Backend::Metal:
295 return "Metal";
296 case Backend::Meta:
297 return "Meta";
298 case Backend::QuantizedCPU:
299 return "QuantizedCPU";
300 case Backend::QuantizedCUDA:
301 return "QuantizedCUDA";
302 case Backend::QuantizedXPU:
303 return "QuantizedXPU";
304 case Backend::HPU:
305 return "HPU";
306 case Backend::MTIA:
307 return "MTIA";
308 case Backend::PrivateUse1:
309 return "PrivateUseOne";
310 default:
311 return "UNKNOWN_BACKEND";
312 }
313}
314
315static inline bool isSparse(Backend b) {
316 switch (b) {
317 case Backend::SparseXPU:
318 case Backend::SparseCPU:
319 case Backend::SparseCUDA:
320 case Backend::SparseHIP:
321 case Backend::SparseVE:
322 return true;
323 default:
324 return false;
325 }
326}
327
328static inline bool isSparseCsr(Backend b) {
329 switch (b) {
330 case Backend::SparseCsrCPU:
331 case Backend::SparseCsrCUDA:
332 return true;
333 default:
334 return false;
335 }
336}
337
} // namespace c10