1#include <c10/core/Device.h>
2#include <c10/macros/Macros.h>
3#include <c10/util/Exception.h>
4
5#include <algorithm>
6#include <array>
7#include <cctype>
8#include <exception>
9#include <ostream>
10#include <string>
11#include <vector>
12
13namespace c10 {
14namespace {
15DeviceType parse_type(const std::string& device_string) {
16 static const std::array<
17 std::pair<const char*, DeviceType>,
18 static_cast<size_t>(DeviceType::COMPILE_TIME_MAX_DEVICE_TYPES)>
19 types = {{
20 {"cpu", DeviceType::CPU},
21 {"cuda", DeviceType::CUDA},
22 {"ipu", DeviceType::IPU},
23 {"xpu", DeviceType::XPU},
24 {"mkldnn", DeviceType::MKLDNN},
25 {"opengl", DeviceType::OPENGL},
26 {"opencl", DeviceType::OPENCL},
27 {"ideep", DeviceType::IDEEP},
28 {"hip", DeviceType::HIP},
29 {"ve", DeviceType::VE},
30 {"fpga", DeviceType::FPGA},
31 {"ort", DeviceType::ORT},
32 {"xla", DeviceType::XLA},
33 {"lazy", DeviceType::Lazy},
34 {"vulkan", DeviceType::Vulkan},
35 {"mps", DeviceType::MPS},
36 {"meta", DeviceType::Meta},
37 {"hpu", DeviceType::HPU},
38 {"mtia", DeviceType::MTIA},
39 {"privateuseone", DeviceType::PrivateUse1},
40 }};
41 auto device = std::find_if(
42 types.begin(),
43 types.end(),
44 [&device_string](const std::pair<const char*, DeviceType>& p) {
45 return p.first && p.first == device_string;
46 });
47 if (device != types.end()) {
48 return device->second;
49 }
50 if (device_string == get_privateuse1_backend()) {
51 return DeviceType::PrivateUse1;
52 }
53 std::vector<const char*> device_names;
54 for (const auto& it : types) {
55 if (it.first) {
56 device_names.push_back(it.first);
57 }
58 }
59 TORCH_CHECK(
60 false,
61 "Expected one of ",
62 c10::Join(", ", device_names),
63 " device type at start of device string: ",
64 device_string);
65}
// States of the hand-written device-string parser in this file.
// Scoped (`enum class`) so the enumerators do not leak into the enclosing
// namespace and cannot be implicitly converted to integers.
enum class DeviceStringParsingState {
  START, // reading the device type name (before any ':')
  INDEX_START, // a ':' was just consumed; the first index digit is expected
  INDEX_REST, // at least one index digit has been read
  ERROR // the input does not match the expected grammar
};
67
68} // namespace
69
70Device::Device(const std::string& device_string) : Device(Type::CPU) {
71 TORCH_CHECK(!device_string.empty(), "Device string must not be empty");
72
73 std::string device_name, device_index_str;
74 DeviceStringParsingState pstate = DeviceStringParsingState::START;
75
76 // The code below tries to match the string in the variable
77 // device_string against the regular expression:
78 // ([a-zA-Z_]+)(?::([1-9]\\d*|0))?
79 for (size_t i = 0;
80 pstate != DeviceStringParsingState::ERROR && i < device_string.size();
81 ++i) {
82 const char ch = device_string.at(i);
83 switch (pstate) {
84 case DeviceStringParsingState::START:
85 if (ch != ':') {
86 if (isalpha(ch) || ch == '_') {
87 device_name.push_back(ch);
88 } else {
89 pstate = DeviceStringParsingState::ERROR;
90 }
91 } else {
92 pstate = DeviceStringParsingState::INDEX_START;
93 }
94 break;
95
96 case DeviceStringParsingState::INDEX_START:
97 if (isdigit(ch)) {
98 device_index_str.push_back(ch);
99 pstate = DeviceStringParsingState::INDEX_REST;
100 } else {
101 pstate = DeviceStringParsingState::ERROR;
102 }
103 break;
104
105 case DeviceStringParsingState::INDEX_REST:
106 if (device_index_str.at(0) == '0') {
107 pstate = DeviceStringParsingState::ERROR;
108 break;
109 }
110 if (isdigit(ch)) {
111 device_index_str.push_back(ch);
112 } else {
113 pstate = DeviceStringParsingState::ERROR;
114 }
115 break;
116
117 case DeviceStringParsingState::ERROR:
118 // Execution won't reach here.
119 break;
120 }
121 }
122
123 const bool has_error = device_name.empty() ||
124 pstate == DeviceStringParsingState::ERROR ||
125 (pstate == DeviceStringParsingState::INDEX_START &&
126 device_index_str.empty());
127
128 TORCH_CHECK(!has_error, "Invalid device string: '", device_string, "'");
129
130 try {
131 if (!device_index_str.empty()) {
132 index_ = static_cast<c10::DeviceIndex>(c10::stoi(device_index_str));
133 }
134 } catch (const std::exception&) {
135 TORCH_CHECK(
136 false,
137 "Could not parse device index '",
138 device_index_str,
139 "' in device string '",
140 device_string,
141 "'");
142 }
143 type_ = parse_type(device_name);
144 validate();
145}
146
147std::string Device::str() const {
148 std::string str = DeviceTypeName(type(), /* lower case */ true);
149 if (has_index()) {
150 str.push_back(':');
151 str.append(to_string(index()));
152 }
153 return str;
154}
155
156std::ostream& operator<<(std::ostream& stream, const Device& device) {
157 stream << device.str();
158 return stream;
159}
160
161} // namespace c10
162