1 | /* Copyright 2015 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | |
#include "tensorflow/core/framework/device_factory.h"

#include <algorithm>
#include <memory>
#include <string>
#include <unordered_map>
#include <vector>

#include "absl/container/flat_hash_set.h"
#include "tensorflow/core/framework/device.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/public/session_options.h"
#include "tensorflow/core/util/device_name_utils.h"
#include "tensorflow/core/util/env_var.h"
34 | |
35 | namespace tensorflow { |
36 | |
37 | namespace { |
38 | |
39 | static mutex* get_device_factory_lock() { |
40 | static mutex device_factory_lock(LINKER_INITIALIZED); |
41 | return &device_factory_lock; |
42 | } |
43 | |
44 | struct FactoryItem { |
45 | std::unique_ptr<DeviceFactory> factory; |
46 | int priority; |
47 | bool is_pluggable_device; |
48 | }; |
49 | |
50 | std::unordered_map<string, FactoryItem>& device_factories() { |
51 | static std::unordered_map<string, FactoryItem>* factories = |
52 | new std::unordered_map<string, FactoryItem>; |
53 | return *factories; |
54 | } |
55 | |
56 | bool IsDeviceFactoryEnabled(const string& device_type) { |
57 | std::vector<string> enabled_devices; |
58 | TF_CHECK_OK(tensorflow::ReadStringsFromEnvVar( |
59 | /*env_var_name=*/"TF_ENABLED_DEVICE_TYPES" , /*default_val=*/"" , |
60 | &enabled_devices)); |
61 | if (enabled_devices.empty()) { |
62 | return true; |
63 | } |
64 | return std::find(enabled_devices.begin(), enabled_devices.end(), |
65 | device_type) != enabled_devices.end(); |
66 | } |
67 | } // namespace |
68 | |
69 | // static |
70 | int32 DeviceFactory::DevicePriority(const string& device_type) { |
71 | tf_shared_lock l(*get_device_factory_lock()); |
72 | std::unordered_map<string, FactoryItem>& factories = device_factories(); |
73 | auto iter = factories.find(device_type); |
74 | if (iter != factories.end()) { |
75 | return iter->second.priority; |
76 | } |
77 | |
78 | return -1; |
79 | } |
80 | |
81 | bool DeviceFactory::IsPluggableDevice(const string& device_type) { |
82 | tf_shared_lock l(*get_device_factory_lock()); |
83 | std::unordered_map<string, FactoryItem>& factories = device_factories(); |
84 | auto iter = factories.find(device_type); |
85 | if (iter != factories.end()) { |
86 | return iter->second.is_pluggable_device; |
87 | } |
88 | return false; |
89 | } |
90 | |
91 | // static |
92 | void DeviceFactory::Register(const string& device_type, |
93 | std::unique_ptr<DeviceFactory> factory, |
94 | int priority, bool is_pluggable_device) { |
95 | if (!IsDeviceFactoryEnabled(device_type)) { |
96 | LOG(INFO) << "Device factory '" << device_type << "' disabled by " |
97 | << "TF_ENABLED_DEVICE_TYPES environment variable." ; |
98 | return; |
99 | } |
100 | mutex_lock l(*get_device_factory_lock()); |
101 | std::unordered_map<string, FactoryItem>& factories = device_factories(); |
102 | auto iter = factories.find(device_type); |
103 | if (iter == factories.end()) { |
104 | factories[device_type] = {std::move(factory), priority, |
105 | is_pluggable_device}; |
106 | } else { |
107 | if (iter->second.priority < priority) { |
108 | iter->second = {std::move(factory), priority, is_pluggable_device}; |
109 | } else if (iter->second.priority == priority) { |
110 | LOG(FATAL) << "Duplicate registration of device factory for type " |
111 | << device_type << " with the same priority " << priority; |
112 | } |
113 | } |
114 | } |
115 | |
116 | DeviceFactory* DeviceFactory::GetFactory(const string& device_type) { |
117 | tf_shared_lock l(*get_device_factory_lock()); |
118 | auto it = device_factories().find(device_type); |
119 | if (it == device_factories().end()) { |
120 | return nullptr; |
121 | } else if (!IsDeviceFactoryEnabled(device_type)) { |
122 | LOG(FATAL) << "Device type " << device_type // Crash OK |
123 | << " had factory registered but was explicitly disabled by " |
124 | << "`TF_ENABLED_DEVICE_TYPES`. This environment variable needs " |
125 | << "to be set at program startup." ; |
126 | } |
127 | return it->second.factory.get(); |
128 | } |
129 | |
130 | Status DeviceFactory::ListAllPhysicalDevices(std::vector<string>* devices) { |
131 | // CPU first. A CPU device is required. |
132 | // TODO(b/183974121): Consider merge the logic into the loop below. |
133 | auto cpu_factory = GetFactory("CPU" ); |
134 | if (!cpu_factory) { |
135 | return errors::NotFound( |
136 | "CPU Factory not registered. Did you link in threadpool_device?" ); |
137 | } |
138 | |
139 | size_t init_size = devices->size(); |
140 | TF_RETURN_IF_ERROR(cpu_factory->ListPhysicalDevices(devices)); |
141 | if (devices->size() == init_size) { |
142 | return errors::NotFound("No CPU devices are available in this process" ); |
143 | } |
144 | |
145 | // Then the rest (including GPU). |
146 | tf_shared_lock l(*get_device_factory_lock()); |
147 | for (auto& p : device_factories()) { |
148 | auto factory = p.second.factory.get(); |
149 | if (factory != cpu_factory) { |
150 | TF_RETURN_IF_ERROR(factory->ListPhysicalDevices(devices)); |
151 | } |
152 | } |
153 | |
154 | return OkStatus(); |
155 | } |
156 | |
157 | Status DeviceFactory::ListPluggablePhysicalDevices( |
158 | std::vector<string>* devices) { |
159 | tf_shared_lock l(*get_device_factory_lock()); |
160 | for (auto& p : device_factories()) { |
161 | if (p.second.is_pluggable_device) { |
162 | auto factory = p.second.factory.get(); |
163 | TF_RETURN_IF_ERROR(factory->ListPhysicalDevices(devices)); |
164 | } |
165 | } |
166 | return OkStatus(); |
167 | } |
168 | |
169 | Status DeviceFactory::GetAnyDeviceDetails( |
170 | int device_index, std::unordered_map<string, string>* details) { |
171 | if (device_index < 0) { |
172 | return errors::InvalidArgument("Device index out of bounds: " , |
173 | device_index); |
174 | } |
175 | const int orig_device_index = device_index; |
176 | |
177 | // Iterate over devices in the same way as in ListAllPhysicalDevices. |
178 | auto cpu_factory = GetFactory("CPU" ); |
179 | if (!cpu_factory) { |
180 | return errors::NotFound( |
181 | "CPU Factory not registered. Did you link in threadpool_device?" ); |
182 | } |
183 | |
184 | // TODO(b/183974121): Consider merge the logic into the loop below. |
185 | std::vector<string> devices; |
186 | TF_RETURN_IF_ERROR(cpu_factory->ListPhysicalDevices(&devices)); |
187 | if (device_index < devices.size()) { |
188 | return cpu_factory->GetDeviceDetails(device_index, details); |
189 | } |
190 | device_index -= devices.size(); |
191 | |
192 | // Then the rest (including GPU). |
193 | tf_shared_lock l(*get_device_factory_lock()); |
194 | for (auto& p : device_factories()) { |
195 | auto factory = p.second.factory.get(); |
196 | if (factory != cpu_factory) { |
197 | devices.clear(); |
198 | // TODO(b/146009447): Find the factory size without having to allocate a |
199 | // vector with all the physical devices. |
200 | TF_RETURN_IF_ERROR(factory->ListPhysicalDevices(&devices)); |
201 | if (device_index < devices.size()) { |
202 | return factory->GetDeviceDetails(device_index, details); |
203 | } |
204 | device_index -= devices.size(); |
205 | } |
206 | } |
207 | |
208 | return errors::InvalidArgument("Device index out of bounds: " , |
209 | orig_device_index); |
210 | } |
211 | |
212 | Status DeviceFactory::AddCpuDevices( |
213 | const SessionOptions& options, const string& name_prefix, |
214 | std::vector<std::unique_ptr<Device>>* devices) { |
215 | auto cpu_factory = GetFactory("CPU" ); |
216 | if (!cpu_factory) { |
217 | return errors::NotFound( |
218 | "CPU Factory not registered. Did you link in threadpool_device?" ); |
219 | } |
220 | size_t init_size = devices->size(); |
221 | TF_RETURN_IF_ERROR(cpu_factory->CreateDevices(options, name_prefix, devices)); |
222 | if (devices->size() == init_size) { |
223 | return errors::NotFound("No CPU devices are available in this process" ); |
224 | } |
225 | |
226 | return OkStatus(); |
227 | } |
228 | |
229 | Status DeviceFactory::AddDevices( |
230 | const SessionOptions& options, const string& name_prefix, |
231 | std::vector<std::unique_ptr<Device>>* devices) { |
232 | // CPU first. A CPU device is required. |
233 | // TODO(b/183974121): Consider merge the logic into the loop below. |
234 | TF_RETURN_IF_ERROR(AddCpuDevices(options, name_prefix, devices)); |
235 | |
236 | absl::flat_hash_set<std::string> allowed_device_types; |
237 | for (const auto& device_filter : options.config.device_filters()) { |
238 | DeviceNameUtils::ParsedName parsed; |
239 | if (!DeviceNameUtils::ParseFullOrLocalName(device_filter, &parsed)) { |
240 | return errors::InvalidArgument( |
241 | absl::StrCat("Invalid device filter: " , device_filter)); |
242 | } |
243 | if (parsed.has_type) { |
244 | allowed_device_types.insert(parsed.type); |
245 | } |
246 | } |
247 | |
248 | auto cpu_factory = GetFactory("CPU" ); |
249 | // Then the rest (including GPU). |
250 | mutex_lock l(*get_device_factory_lock()); |
251 | for (auto& p : device_factories()) { |
252 | if (!allowed_device_types.empty() && |
253 | !allowed_device_types.contains(p.first)) { |
254 | continue; // Skip if the device type is not found from the device filter. |
255 | } |
256 | auto factory = p.second.factory.get(); |
257 | if (factory != cpu_factory) { |
258 | TF_RETURN_IF_ERROR(factory->CreateDevices(options, name_prefix, devices)); |
259 | } |
260 | } |
261 | |
262 | return OkStatus(); |
263 | } |
264 | |
265 | std::unique_ptr<Device> DeviceFactory::NewDevice(const string& type, |
266 | const SessionOptions& options, |
267 | const string& name_prefix) { |
268 | auto device_factory = GetFactory(type); |
269 | if (!device_factory) { |
270 | return nullptr; |
271 | } |
272 | SessionOptions opt = options; |
273 | (*opt.config.mutable_device_count())[type] = 1; |
274 | std::vector<std::unique_ptr<Device>> devices; |
275 | TF_CHECK_OK(device_factory->CreateDevices(opt, name_prefix, &devices)); |
276 | int expected_num_devices = 1; |
277 | auto iter = options.config.device_count().find(type); |
278 | if (iter != options.config.device_count().end()) { |
279 | expected_num_devices = iter->second; |
280 | } |
281 | DCHECK_EQ(devices.size(), static_cast<size_t>(expected_num_devices)); |
282 | return std::move(devices[0]); |
283 | } |
284 | |
285 | } // namespace tensorflow |
286 | |