#pragma once

// This is directly synchronized with caffe2/proto/caffe2.proto, but
// doesn't require me to figure out how to get Protobuf headers into
// ATen/core (which would require a lot more build system hacking).
// If you modify me, keep me synchronized with that file.

#include <c10/macros/Macros.h>

#include <cstdint>
#include <functional>
#include <ostream>
#include <string>

namespace c10 {

// This contains all device types that also have a BackendComponent
// and therefore participate in per-backend functionality dispatch keys.
// This is most backends except PrivateUse2 and PrivateUse3.
#define C10_FORALL_BACKEND_DEVICE_TYPES(_, extra) \
  _(CPU, extra) \
  _(CUDA, extra) \
  _(HIP, extra) \
  _(XLA, extra) \
  _(MPS, extra) \
  _(IPU, extra) \
  _(XPU, extra) \
  _(HPU, extra) \
  _(VE, extra) \
  _(Lazy, extra) \
  _(Meta, extra) \
  _(MTIA, extra) \
  _(PrivateUse1, extra)
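
// Usage sketch (illustrative only): the X-macro above expands its first
// argument once per backend device type, which is how per-backend
// boilerplate is typically stamped out, e.g.
//
//   #define DEFINE_CASE(t, _) case DeviceType::t:
//   switch (d) {
//     C10_FORALL_BACKEND_DEVICE_TYPES(DEFINE_CASE, unused)
//       return true;
//     default:
//       return false;
//   }
//   #undef DEFINE_CASE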

enum class DeviceType : int8_t {
  CPU = 0,
  CUDA = 1, // CUDA.
  MKLDNN = 2, // Reserved for explicit MKLDNN
  OPENGL = 3, // OpenGL
  OPENCL = 4, // OpenCL
  IDEEP = 5, // IDEEP.
  HIP = 6, // AMD HIP
  FPGA = 7, // FPGA
  ORT = 8, // ONNX Runtime / Microsoft
  XLA = 9, // XLA / TPU
  Vulkan = 10, // Vulkan
  Metal = 11, // Metal
  XPU = 12, // XPU
  MPS = 13, // MPS
  Meta = 14, // Meta (tensors with no data)
  HPU = 15, // HPU / HABANA
  VE = 16, // SX-Aurora / NEC
  Lazy = 17, // Lazy Tensors
  IPU = 18, // Graphcore IPU
  MTIA = 19, // Meta training and inference devices
  PrivateUse1 = 20, // PrivateUse1 device
  // NB: If you add more devices:
  //  - Change the implementations of DeviceTypeName and isValidDeviceType
  //    in DeviceType.cpp
  //  - Change the number below
  COMPILE_TIME_MAX_DEVICE_TYPES = 21,
};

constexpr DeviceType kCPU = DeviceType::CPU;
constexpr DeviceType kCUDA = DeviceType::CUDA;
constexpr DeviceType kHIP = DeviceType::HIP;
constexpr DeviceType kFPGA = DeviceType::FPGA;
constexpr DeviceType kORT = DeviceType::ORT;
constexpr DeviceType kXLA = DeviceType::XLA;
constexpr DeviceType kMPS = DeviceType::MPS;
constexpr DeviceType kMeta = DeviceType::Meta;
constexpr DeviceType kVulkan = DeviceType::Vulkan;
constexpr DeviceType kMetal = DeviceType::Metal;
constexpr DeviceType kXPU = DeviceType::XPU;
constexpr DeviceType kHPU = DeviceType::HPU;
constexpr DeviceType kVE = DeviceType::VE;
constexpr DeviceType kLazy = DeviceType::Lazy;
constexpr DeviceType kIPU = DeviceType::IPU;
constexpr DeviceType kMTIA = DeviceType::MTIA;
constexpr DeviceType kPrivateUse1 = DeviceType::PrivateUse1;

// define explicit int constant
constexpr int COMPILE_TIME_MAX_DEVICE_TYPES =
    static_cast<int>(DeviceType::COMPILE_TIME_MAX_DEVICE_TYPES);

static_assert(
    COMPILE_TIME_MAX_DEVICE_TYPES <= 21,
    "Hey! You seem to be adding a lot of new DeviceTypes. The intent was "
    "for this constant to reflect the actual number of DeviceTypes we support "
    "in PyTorch; it's important that this number is not too large as we "
    "use this to allocate stack arrays in some places in our code. If you "
    "are indeed just adding the 22nd device type, feel free to change "
    "the check to 32; but if you are adding some sort of extensible device "
    "types registration, please be aware that you are affecting code that "
    "assumes this number is small. Try auditing uses of this constant.");

C10_API std::string DeviceTypeName(DeviceType d, bool lower_case = false);

C10_API bool isValidDeviceType(DeviceType d);

C10_API std::ostream& operator<<(std::ostream& stream, DeviceType type);
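
// Usage sketch (assumes typical behavior of the helpers above; see
// DeviceType.cpp for the authoritative implementation):
//
//   std::string upper = DeviceTypeName(kCUDA);       // e.g. "CUDA"
//   std::string lower = DeviceTypeName(kCUDA, true); // e.g. "cuda"
//   std::cout << kCUDA;                              // streams the name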

C10_API void register_privateuse1_backend(std::string backend_name);
C10_API std::string get_privateuse1_backend(bool lower_case = true);
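
// Usage sketch (hypothetical backend name, for illustration only):
//
//   register_privateuse1_backend("my_device");
//   // get_privateuse1_backend() then reports the registered name instead
//   // of the default PrivateUse1 name.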

} // namespace c10

namespace std {
template <>
struct hash<c10::DeviceType> {
  std::size_t operator()(c10::DeviceType k) const {
    return std::hash<int>()(static_cast<int>(k));
  }
};
} // namespace std
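
// With the std::hash specialization above, DeviceType can be used directly
// as a key in unordered containers, e.g. (illustrative only):
//
//   std::unordered_map<c10::DeviceType, int> counts;
//   counts[c10::DeviceType::CPU] += 1;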

namespace torch {
using c10::DeviceType;
}