#include <c10/core/DeviceType.h>
#include <c10/util/Exception.h>

#include <algorithm>
#include <atomic>
#include <cctype>
#include <mutex>
#include <string>
#include <utility>
5
6namespace c10 {
7
8std::string DeviceTypeName(DeviceType d, bool lower_case) {
9 switch (d) {
10 // I considered instead using ctype::tolower to lower-case the strings
11 // on the fly, but this seemed a bit much.
12 case DeviceType::CPU:
13 return lower_case ? "cpu" : "CPU";
14 case DeviceType::CUDA:
15 return lower_case ? "cuda" : "CUDA";
16 case DeviceType::OPENGL:
17 return lower_case ? "opengl" : "OPENGL";
18 case DeviceType::OPENCL:
19 return lower_case ? "opencl" : "OPENCL";
20 case DeviceType::MKLDNN:
21 return lower_case ? "mkldnn" : "MKLDNN";
22 case DeviceType::IDEEP:
23 return lower_case ? "ideep" : "IDEEP";
24 case DeviceType::HIP:
25 return lower_case ? "hip" : "HIP";
26 case DeviceType::VE:
27 return lower_case ? "ve" : "VE";
28 case DeviceType::FPGA:
29 return lower_case ? "fpga" : "FPGA";
30 case DeviceType::ORT:
31 return lower_case ? "ort" : "ORT";
32 case DeviceType::XLA:
33 return lower_case ? "xla" : "XLA";
34 case DeviceType::Lazy:
35 return lower_case ? "lazy" : "LAZY";
36 case DeviceType::MPS:
37 return lower_case ? "mps" : "MPS";
38 case DeviceType::Vulkan:
39 return lower_case ? "vulkan" : "VULKAN";
40 case DeviceType::Metal:
41 return lower_case ? "metal" : "METAL";
42 case DeviceType::XPU:
43 return lower_case ? "xpu" : "XPU";
44 case DeviceType::Meta:
45 return lower_case ? "meta" : "META";
46 case DeviceType::HPU:
47 return lower_case ? "hpu" : "HPU";
48 case DeviceType::IPU:
49 return lower_case ? "ipu" : "IPU";
50 case DeviceType::MTIA:
51 return lower_case ? "mtia" : "MTIA";
52 case DeviceType::PrivateUse1:
53 return get_privateuse1_backend(/*lower_case=*/lower_case);
54 default:
55 TORCH_CHECK(
56 false,
57 "Unknown device: ",
58 static_cast<int16_t>(d),
59 ". If you have recently updated the caffe2.proto file to add a new "
60 "device type, did you forget to update the DeviceTypeName() "
61 "function to reflect such recent changes?");
62 // The below code won't run but is needed to suppress some compiler
63 // warnings.
64 return "";
65 }
66}
67
68// NB: Per the C++ standard (e.g.,
69// https://stackoverflow.com/questions/18195312/what-happens-if-you-static-cast-invalid-value-to-enum-class)
70// as long as you cast from the same underlying type, it is always valid to cast
71// into an enum class (even if the value would be invalid by the enum.) Thus,
72// the caller is allowed to cast a possibly invalid int16_t to DeviceType and
73// then pass it to this function. (I considered making this function take an
74// int16_t directly, but that just seemed weird.)
75bool isValidDeviceType(DeviceType d) {
76 switch (d) {
77 case DeviceType::CPU:
78 case DeviceType::CUDA:
79 case DeviceType::OPENGL:
80 case DeviceType::OPENCL:
81 case DeviceType::MKLDNN:
82 case DeviceType::IDEEP:
83 case DeviceType::HIP:
84 case DeviceType::VE:
85 case DeviceType::FPGA:
86 case DeviceType::ORT:
87 case DeviceType::XLA:
88 case DeviceType::Lazy:
89 case DeviceType::MPS:
90 case DeviceType::Vulkan:
91 case DeviceType::Metal:
92 case DeviceType::XPU:
93 case DeviceType::Meta:
94 case DeviceType::HPU:
95 case DeviceType::IPU:
96 case DeviceType::MTIA:
97 case DeviceType::PrivateUse1:
98 return true;
99 default:
100 return false;
101 }
102}
103
104std::ostream& operator<<(std::ostream& stream, DeviceType type) {
105 stream << DeviceTypeName(type, /* lower case */ true);
106 return stream;
107}
108
// We use both a mutex and an atomic here because:
// (1) Mutex is needed during writing:
//     We need to first check the value and potentially error,
//     before setting the value (without anyone else racing in the middle).
//     It's also totally fine for this to be slow, since it happens exactly once
//     at import time.
// (2) Atomic is needed during reading:
//     Whenever a user prints a privateuse1 device name, they need to read this
//     variable. Although unlikely, we would have a data race if one thread set
//     this variable while another thread was printing the device name. We
//     could re-use the same mutex, but reading the atomic will be much faster.
// State backing the PrivateUse1 backend name:
//  - privateuse1_backend_name_set: publication flag, read lock-free by
//    get_privateuse1_backend().
//  - privateuse1_backend_name: the registered name; only meaningful once the
//    flag is set.
//  - privateuse1_lock: serializes writers in register_privateuse1_backend().
static std::atomic<bool> privateuse1_backend_name_set{false};
static std::string privateuse1_backend_name;
static std::mutex privateuse1_lock;

// Returns the name of the PrivateUse1 backend: the registered name if one has
// been registered, otherwise "privateuseone".
//
// lower_case: when true, the name is returned as stored (the default name is
//   already lower-case; a registered name is returned verbatim). When false,
//   the name is upper-cased, matching the upper-case spellings DeviceTypeName()
//   uses for every built-in device type (e.g. "CUDA").
std::string get_privateuse1_backend(bool lower_case) {
  // Applying the same atomic read memory ordering logic as in Note [Memory
  // ordering on Python interpreter tag].
  const bool name_registered =
      privateuse1_backend_name_set.load(std::memory_order_acquire);
  // Once the flag is set, privateuse1_backend_name is treated as immutable, so
  // it is read here without taking privateuse1_lock.
  std::string backend_name =
      name_registered ? privateuse1_backend_name : "privateuseone";
  if (!lower_case) {
    // Honor the upper-case request instead of silently ignoring the flag.
    // Go through unsigned char: std::toupper is undefined for negative chars.
    std::transform(
        backend_name.begin(),
        backend_name.end(),
        backend_name.begin(),
        [](unsigned char c) { return static_cast<char>(std::toupper(c)); });
  }
  return backend_name;
}
136
137void register_privateuse1_backend(std::string backend_name) {
138 std::lock_guard<std::mutex> guard(privateuse1_lock);
139 TORCH_CHECK(
140 !privateuse1_backend_name_set.load() ||
141 privateuse1_backend_name == backend_name,
142 "torch.register_privateuse1_backend() has already been set! Current backend: ",
143 privateuse1_backend_name);
144
145 privateuse1_backend_name = backend_name;
146 // Invariant: once this flag is set, privateuse1_backend_name is NEVER written
147 // to.
148 privateuse1_backend_name_set.store(true, std::memory_order_relaxed);
149}
150
151} // namespace c10
152