#include <c10/core/Device.h>
#include <c10/macros/Macros.h>
#include <c10/util/Exception.h>

#include <algorithm>
#include <array>
#include <cctype>
#include <exception>
#include <limits>
#include <ostream>
#include <string>
#include <vector>
12 | |
13 | namespace c10 { |
14 | namespace { |
15 | DeviceType parse_type(const std::string& device_string) { |
16 | static const std::array< |
17 | std::pair<const char*, DeviceType>, |
18 | static_cast<size_t>(DeviceType::COMPILE_TIME_MAX_DEVICE_TYPES)> |
19 | types = {{ |
20 | {"cpu" , DeviceType::CPU}, |
21 | {"cuda" , DeviceType::CUDA}, |
22 | {"ipu" , DeviceType::IPU}, |
23 | {"xpu" , DeviceType::XPU}, |
24 | {"mkldnn" , DeviceType::MKLDNN}, |
25 | {"opengl" , DeviceType::OPENGL}, |
26 | {"opencl" , DeviceType::OPENCL}, |
27 | {"ideep" , DeviceType::IDEEP}, |
28 | {"hip" , DeviceType::HIP}, |
29 | {"ve" , DeviceType::VE}, |
30 | {"fpga" , DeviceType::FPGA}, |
31 | {"ort" , DeviceType::ORT}, |
32 | {"xla" , DeviceType::XLA}, |
33 | {"lazy" , DeviceType::Lazy}, |
34 | {"vulkan" , DeviceType::Vulkan}, |
35 | {"mps" , DeviceType::MPS}, |
36 | {"meta" , DeviceType::Meta}, |
37 | {"hpu" , DeviceType::HPU}, |
38 | {"mtia" , DeviceType::MTIA}, |
39 | {"privateuseone" , DeviceType::PrivateUse1}, |
40 | }}; |
41 | auto device = std::find_if( |
42 | types.begin(), |
43 | types.end(), |
44 | [&device_string](const std::pair<const char*, DeviceType>& p) { |
45 | return p.first && p.first == device_string; |
46 | }); |
47 | if (device != types.end()) { |
48 | return device->second; |
49 | } |
50 | if (device_string == get_privateuse1_backend()) { |
51 | return DeviceType::PrivateUse1; |
52 | } |
53 | std::vector<const char*> device_names; |
54 | for (const auto& it : types) { |
55 | if (it.first) { |
56 | device_names.push_back(it.first); |
57 | } |
58 | } |
59 | TORCH_CHECK( |
60 | false, |
61 | "Expected one of " , |
62 | c10::Join(", " , device_names), |
63 | " device type at start of device string: " , |
64 | device_string); |
65 | } |
// States of the hand-written matcher used by Device::Device(const std::string&).
// A scoped enum keeps the generic enumerator names (START, ERROR, ...) out of
// the enclosing namespace and forbids implicit conversion to int; every use
// site already spells them as DeviceStringParsingState::<name>.
enum class DeviceStringParsingState { START, INDEX_START, INDEX_REST, ERROR };
67 | |
68 | } // namespace |
69 | |
70 | Device::Device(const std::string& device_string) : Device(Type::CPU) { |
71 | TORCH_CHECK(!device_string.empty(), "Device string must not be empty" ); |
72 | |
73 | std::string device_name, device_index_str; |
74 | DeviceStringParsingState pstate = DeviceStringParsingState::START; |
75 | |
76 | // The code below tries to match the string in the variable |
77 | // device_string against the regular expression: |
78 | // ([a-zA-Z_]+)(?::([1-9]\\d*|0))? |
79 | for (size_t i = 0; |
80 | pstate != DeviceStringParsingState::ERROR && i < device_string.size(); |
81 | ++i) { |
82 | const char ch = device_string.at(i); |
83 | switch (pstate) { |
84 | case DeviceStringParsingState::START: |
85 | if (ch != ':') { |
86 | if (isalpha(ch) || ch == '_') { |
87 | device_name.push_back(ch); |
88 | } else { |
89 | pstate = DeviceStringParsingState::ERROR; |
90 | } |
91 | } else { |
92 | pstate = DeviceStringParsingState::INDEX_START; |
93 | } |
94 | break; |
95 | |
96 | case DeviceStringParsingState::INDEX_START: |
97 | if (isdigit(ch)) { |
98 | device_index_str.push_back(ch); |
99 | pstate = DeviceStringParsingState::INDEX_REST; |
100 | } else { |
101 | pstate = DeviceStringParsingState::ERROR; |
102 | } |
103 | break; |
104 | |
105 | case DeviceStringParsingState::INDEX_REST: |
106 | if (device_index_str.at(0) == '0') { |
107 | pstate = DeviceStringParsingState::ERROR; |
108 | break; |
109 | } |
110 | if (isdigit(ch)) { |
111 | device_index_str.push_back(ch); |
112 | } else { |
113 | pstate = DeviceStringParsingState::ERROR; |
114 | } |
115 | break; |
116 | |
117 | case DeviceStringParsingState::ERROR: |
118 | // Execution won't reach here. |
119 | break; |
120 | } |
121 | } |
122 | |
123 | const bool has_error = device_name.empty() || |
124 | pstate == DeviceStringParsingState::ERROR || |
125 | (pstate == DeviceStringParsingState::INDEX_START && |
126 | device_index_str.empty()); |
127 | |
128 | TORCH_CHECK(!has_error, "Invalid device string: '" , device_string, "'" ); |
129 | |
130 | try { |
131 | if (!device_index_str.empty()) { |
132 | index_ = static_cast<c10::DeviceIndex>(c10::stoi(device_index_str)); |
133 | } |
134 | } catch (const std::exception&) { |
135 | TORCH_CHECK( |
136 | false, |
137 | "Could not parse device index '" , |
138 | device_index_str, |
139 | "' in device string '" , |
140 | device_string, |
141 | "'" ); |
142 | } |
143 | type_ = parse_type(device_name); |
144 | validate(); |
145 | } |
146 | |
147 | std::string Device::str() const { |
148 | std::string str = DeviceTypeName(type(), /* lower case */ true); |
149 | if (has_index()) { |
150 | str.push_back(':'); |
151 | str.append(to_string(index())); |
152 | } |
153 | return str; |
154 | } |
155 | |
156 | std::ostream& operator<<(std::ostream& stream, const Device& device) { |
157 | stream << device.str(); |
158 | return stream; |
159 | } |
160 | |
161 | } // namespace c10 |
162 | |