#include <c10/core/DeviceType.h>
#include <c10/util/Exception.h>
#include <atomic>
#include <mutex>
#include <utility>
5 | |
6 | namespace c10 { |
7 | |
8 | std::string DeviceTypeName(DeviceType d, bool lower_case) { |
9 | switch (d) { |
10 | // I considered instead using ctype::tolower to lower-case the strings |
11 | // on the fly, but this seemed a bit much. |
12 | case DeviceType::CPU: |
13 | return lower_case ? "cpu" : "CPU" ; |
14 | case DeviceType::CUDA: |
15 | return lower_case ? "cuda" : "CUDA" ; |
16 | case DeviceType::OPENGL: |
17 | return lower_case ? "opengl" : "OPENGL" ; |
18 | case DeviceType::OPENCL: |
19 | return lower_case ? "opencl" : "OPENCL" ; |
20 | case DeviceType::MKLDNN: |
21 | return lower_case ? "mkldnn" : "MKLDNN" ; |
22 | case DeviceType::IDEEP: |
23 | return lower_case ? "ideep" : "IDEEP" ; |
24 | case DeviceType::HIP: |
25 | return lower_case ? "hip" : "HIP" ; |
26 | case DeviceType::VE: |
27 | return lower_case ? "ve" : "VE" ; |
28 | case DeviceType::FPGA: |
29 | return lower_case ? "fpga" : "FPGA" ; |
30 | case DeviceType::ORT: |
31 | return lower_case ? "ort" : "ORT" ; |
32 | case DeviceType::XLA: |
33 | return lower_case ? "xla" : "XLA" ; |
34 | case DeviceType::Lazy: |
35 | return lower_case ? "lazy" : "LAZY" ; |
36 | case DeviceType::MPS: |
37 | return lower_case ? "mps" : "MPS" ; |
38 | case DeviceType::Vulkan: |
39 | return lower_case ? "vulkan" : "VULKAN" ; |
40 | case DeviceType::Metal: |
41 | return lower_case ? "metal" : "METAL" ; |
42 | case DeviceType::XPU: |
43 | return lower_case ? "xpu" : "XPU" ; |
44 | case DeviceType::Meta: |
45 | return lower_case ? "meta" : "META" ; |
46 | case DeviceType::HPU: |
47 | return lower_case ? "hpu" : "HPU" ; |
48 | case DeviceType::IPU: |
49 | return lower_case ? "ipu" : "IPU" ; |
50 | case DeviceType::MTIA: |
51 | return lower_case ? "mtia" : "MTIA" ; |
52 | case DeviceType::PrivateUse1: |
53 | return get_privateuse1_backend(/*lower_case=*/lower_case); |
54 | default: |
55 | TORCH_CHECK( |
56 | false, |
57 | "Unknown device: " , |
58 | static_cast<int16_t>(d), |
59 | ". If you have recently updated the caffe2.proto file to add a new " |
60 | "device type, did you forget to update the DeviceTypeName() " |
61 | "function to reflect such recent changes?" ); |
62 | // The below code won't run but is needed to suppress some compiler |
63 | // warnings. |
64 | return "" ; |
65 | } |
66 | } |
67 | |
68 | // NB: Per the C++ standard (e.g., |
69 | // https://stackoverflow.com/questions/18195312/what-happens-if-you-static-cast-invalid-value-to-enum-class) |
70 | // as long as you cast from the same underlying type, it is always valid to cast |
71 | // into an enum class (even if the value would be invalid by the enum.) Thus, |
72 | // the caller is allowed to cast a possibly invalid int16_t to DeviceType and |
73 | // then pass it to this function. (I considered making this function take an |
74 | // int16_t directly, but that just seemed weird.) |
75 | bool isValidDeviceType(DeviceType d) { |
76 | switch (d) { |
77 | case DeviceType::CPU: |
78 | case DeviceType::CUDA: |
79 | case DeviceType::OPENGL: |
80 | case DeviceType::OPENCL: |
81 | case DeviceType::MKLDNN: |
82 | case DeviceType::IDEEP: |
83 | case DeviceType::HIP: |
84 | case DeviceType::VE: |
85 | case DeviceType::FPGA: |
86 | case DeviceType::ORT: |
87 | case DeviceType::XLA: |
88 | case DeviceType::Lazy: |
89 | case DeviceType::MPS: |
90 | case DeviceType::Vulkan: |
91 | case DeviceType::Metal: |
92 | case DeviceType::XPU: |
93 | case DeviceType::Meta: |
94 | case DeviceType::HPU: |
95 | case DeviceType::IPU: |
96 | case DeviceType::MTIA: |
97 | case DeviceType::PrivateUse1: |
98 | return true; |
99 | default: |
100 | return false; |
101 | } |
102 | } |
103 | |
104 | std::ostream& operator<<(std::ostream& stream, DeviceType type) { |
105 | stream << DeviceTypeName(type, /* lower case */ true); |
106 | return stream; |
107 | } |
108 | |
// State holding the runtime-registered name of the PrivateUse1 device type.
//
// We use both a mutex and an atomic here because:
// (1) The mutex is needed for writing: we must first check the current value
//     and possibly raise an error before setting it, without anyone else
//     racing in between. It is fine for this to be slow, since registration is
//     expected to happen once, at import time.
// (2) The atomic is needed for reading: whenever a user prints a privateuse1
//     device name, this state is read. Although unlikely, reading it while
//     another thread is setting it would be a data race. We could re-use the
//     same mutex, but reading the atomic is much faster.
static std::atomic<bool> privateuse1_backend_name_set{false};
static std::string privateuse1_backend_name;
static std::mutex privateuse1_lock;

// Returns the PrivateUse1 backend name: the registered name if one has been
// published, otherwise the default "privateuseone".
//
// When `lower_case` is false, ASCII letters are converted to upper case. The
// result then matches the upper-case spelling DeviceTypeName() uses for every
// other device (e.g. "PRIVATEUSEONE"); previously this parameter was ignored.
// When `lower_case` is true, the name is returned exactly as registered.
std::string get_privateuse1_backend(bool lower_case) {
  // Applying the same atomic read memory ordering logic as in Note [Memory
  // ordering on Python interpreter tag]. The acquire load only makes the name
  // visible if the writer publishes the flag with a release (or stronger)
  // store.
  const bool name_registered =
      privateuse1_backend_name_set.load(std::memory_order_acquire);
  // Copying without the lock relies on the invariant that the name is never
  // written once the flag is set.
  std::string backend_name = name_registered
      ? privateuse1_backend_name
      : std::string("privateuseone");
  if (!lower_case) {
    // ASCII-only on purpose: std::toupper depends on the process locale.
    for (char& ch : backend_name) {
      if (ch >= 'a' && ch <= 'z') {
        ch = static_cast<char>(ch - 'a' + 'A');
      }
    }
  }
  return backend_name;
}
136 | |
137 | void register_privateuse1_backend(std::string backend_name) { |
138 | std::lock_guard<std::mutex> guard(privateuse1_lock); |
139 | TORCH_CHECK( |
140 | !privateuse1_backend_name_set.load() || |
141 | privateuse1_backend_name == backend_name, |
142 | "torch.register_privateuse1_backend() has already been set! Current backend: " , |
143 | privateuse1_backend_name); |
144 | |
145 | privateuse1_backend_name = backend_name; |
146 | // Invariant: once this flag is set, privateuse1_backend_name is NEVER written |
147 | // to. |
148 | privateuse1_backend_name_set.store(true, std::memory_order_relaxed); |
149 | } |
150 | |
151 | } // namespace c10 |
152 | |