1 | #pragma once |
2 | |
3 | // This is directly synchronized with caffe2/proto/caffe2.proto, but |
4 | // doesn't require me to figure out how to get Protobuf headers into |
5 | // ATen/core (which would require a lot more build system hacking.) |
6 | // If you modify me, keep me synchronized with that file. |
7 | |
8 | #include <c10/macros/Macros.h> |
9 | |
10 | #include <functional> |
11 | #include <ostream> |
12 | |
13 | namespace c10 { |
14 | |
// This macro enumerates all device types that also have a BackendComponent
// and therefore participate in per-backend functionality dispatch keys.
// This is most backends except PrivateUse2 and PrivateUse3.
// Invoked as C10_FORALL_BACKEND_DEVICE_TYPES(MACRO, extra): MACRO is applied
// once per device type as MACRO(DeviceTypeName, extra).
#define C10_FORALL_BACKEND_DEVICE_TYPES(_, extra) \
  _(CPU, extra)                                   \
  _(CUDA, extra)                                  \
  _(HIP, extra)                                   \
  _(XLA, extra)                                   \
  _(MPS, extra)                                   \
  _(IPU, extra)                                   \
  _(XPU, extra)                                   \
  _(HPU, extra)                                   \
  _(VE, extra)                                    \
  _(Lazy, extra)                                  \
  _(Meta, extra)                                  \
  _(MTIA, extra)                                  \
  _(PrivateUse1, extra)
32 | |
// Device type tags. The numeric values mirror caffe2/proto/caffe2.proto
// (see the note at the top of this file), so existing entries must never be
// renumbered or removed; new entries are appended just before
// COMPILE_TIME_MAX_DEVICE_TYPES.
enum class DeviceType : int8_t {
  CPU = 0,
  CUDA = 1, // CUDA.
  MKLDNN = 2, // Reserved for explicit MKLDNN
  OPENGL = 3, // OpenGL
  OPENCL = 4, // OpenCL
  IDEEP = 5, // IDEEP.
  HIP = 6, // AMD HIP
  FPGA = 7, // FPGA
  ORT = 8, // ONNX Runtime / Microsoft
  XLA = 9, // XLA / TPU
  Vulkan = 10, // Vulkan
  Metal = 11, // Metal
  XPU = 12, // XPU
  MPS = 13, // MPS
  Meta = 14, // Meta (tensors with no data)
  HPU = 15, // HPU / HABANA
  VE = 16, // SX-Aurora / NEC
  Lazy = 17, // Lazy Tensors
  IPU = 18, // Graphcore IPU
  MTIA = 19, // Meta training and inference devices
  PrivateUse1 = 20, // PrivateUse1 device
  // NB: If you add more devices:
  //  - Change the implementations of DeviceTypeName and isValidDeviceType
  //    in DeviceType.cpp
  //  - Change the number below (and keep the static_assert further down in
  //    this file satisfied)
  COMPILE_TIME_MAX_DEVICE_TYPES = 21,
};
61 | |
62 | constexpr DeviceType kCPU = DeviceType::CPU; |
63 | constexpr DeviceType kCUDA = DeviceType::CUDA; |
64 | constexpr DeviceType kHIP = DeviceType::HIP; |
65 | constexpr DeviceType kFPGA = DeviceType::FPGA; |
66 | constexpr DeviceType kORT = DeviceType::ORT; |
67 | constexpr DeviceType kXLA = DeviceType::XLA; |
68 | constexpr DeviceType kMPS = DeviceType::MPS; |
69 | constexpr DeviceType kMeta = DeviceType::Meta; |
70 | constexpr DeviceType kVulkan = DeviceType::Vulkan; |
71 | constexpr DeviceType kMetal = DeviceType::Metal; |
72 | constexpr DeviceType kXPU = DeviceType::XPU; |
73 | constexpr DeviceType kHPU = DeviceType::HPU; |
74 | constexpr DeviceType kVE = DeviceType::VE; |
75 | constexpr DeviceType kLazy = DeviceType::Lazy; |
76 | constexpr DeviceType kIPU = DeviceType::IPU; |
77 | constexpr DeviceType kMTIA = DeviceType::MTIA; |
78 | constexpr DeviceType kPrivateUse1 = DeviceType::PrivateUse1; |
79 | |
80 | // define explicit int constant |
81 | constexpr int COMPILE_TIME_MAX_DEVICE_TYPES = |
82 | static_cast<int>(DeviceType::COMPILE_TIME_MAX_DEVICE_TYPES); |
83 | |
84 | static_assert( |
85 | COMPILE_TIME_MAX_DEVICE_TYPES <= 21, |
86 | "Hey! You seem to be adding a lot of new DeviceTypes. The intent was " |
87 | "for this constant to reflect the actual number of DeviceTypes we support " |
88 | "in PyTorch; it's important that this number is not too large as we " |
89 | "use this to allocate stack arrays in some places in our code. If you " |
90 | "are indeed just adding the 20th device type, feel free to change " |
91 | "the check to 32; but if you are adding some sort of extensible device " |
92 | "types registration, please be aware that you are affecting code that " |
93 | "this number is small. Try auditing uses of this constant." ); |
94 | |
// Returns the printable name of a device type, upper-case by default and
// lower-case when `lower_case` is true. Implemented in DeviceType.cpp.
C10_API std::string DeviceTypeName(DeviceType d, bool lower_case = false);

// True iff `d` is a device type this build recognizes as usable.
// Implemented in DeviceType.cpp.
C10_API bool isValidDeviceType(DeviceType d);

// Stream insertion so DeviceType can be printed / logged directly.
C10_API std::ostream& operator<<(std::ostream& stream, DeviceType type);

// Set / query the user-facing backend name associated with the PrivateUse1
// device type (used by out-of-tree backends).
// NOTE(review): exact registration semantics (e.g. whether re-registering a
// different name is permitted) live in DeviceType.cpp — confirm there.
C10_API void register_privateuse1_backend(std::string backend_name);
C10_API std::string get_privateuse1_backend(bool lower_case = true);
103 | |
104 | } // namespace c10 |
105 | |
106 | namespace std { |
107 | template <> |
108 | struct hash<c10::DeviceType> { |
109 | std::size_t operator()(c10::DeviceType k) const { |
110 | return std::hash<int>()(static_cast<int>(k)); |
111 | } |
112 | }; |
113 | } // namespace std |
114 | |
// Re-export DeviceType into the torch namespace so callers can write
// torch::DeviceType instead of c10::DeviceType.
namespace torch {
using c10::DeviceType;
}
118 | |