1/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2
3Licensed under the Apache License, Version 2.0 (the "License");
4you may not use this file except in compliance with the License.
5You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9Unless required by applicable law or agreed to in writing, software
10distributed under the License is distributed on an "AS IS" BASIS,
11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12See the License for the specific language governing permissions and
13limitations under the License.
14==============================================================================*/
15
16#include "tensorflow/core/framework/device_factory.h"
17
18#include <memory>
19#include <string>
20#include <unordered_map>
21#include <vector>
22
23#include "absl/container/flat_hash_set.h"
24#include "tensorflow/core/framework/device.h"
25#include "tensorflow/core/lib/core/errors.h"
26#include "tensorflow/core/lib/strings/strcat.h"
27#include "tensorflow/core/platform/errors.h"
28#include "tensorflow/core/platform/logging.h"
29#include "tensorflow/core/platform/mutex.h"
30#include "tensorflow/core/platform/types.h"
31#include "tensorflow/core/public/session_options.h"
32#include "tensorflow/core/util/device_name_utils.h"
33#include "tensorflow/core/util/env_var.h"
34
35namespace tensorflow {
36
37namespace {
38
39static mutex* get_device_factory_lock() {
40 static mutex device_factory_lock(LINKER_INITIALIZED);
41 return &device_factory_lock;
42}
43
44struct FactoryItem {
45 std::unique_ptr<DeviceFactory> factory;
46 int priority;
47 bool is_pluggable_device;
48};
49
50std::unordered_map<string, FactoryItem>& device_factories() {
51 static std::unordered_map<string, FactoryItem>* factories =
52 new std::unordered_map<string, FactoryItem>;
53 return *factories;
54}
55
56bool IsDeviceFactoryEnabled(const string& device_type) {
57 std::vector<string> enabled_devices;
58 TF_CHECK_OK(tensorflow::ReadStringsFromEnvVar(
59 /*env_var_name=*/"TF_ENABLED_DEVICE_TYPES", /*default_val=*/"",
60 &enabled_devices));
61 if (enabled_devices.empty()) {
62 return true;
63 }
64 return std::find(enabled_devices.begin(), enabled_devices.end(),
65 device_type) != enabled_devices.end();
66}
67} // namespace
68
69// static
70int32 DeviceFactory::DevicePriority(const string& device_type) {
71 tf_shared_lock l(*get_device_factory_lock());
72 std::unordered_map<string, FactoryItem>& factories = device_factories();
73 auto iter = factories.find(device_type);
74 if (iter != factories.end()) {
75 return iter->second.priority;
76 }
77
78 return -1;
79}
80
81bool DeviceFactory::IsPluggableDevice(const string& device_type) {
82 tf_shared_lock l(*get_device_factory_lock());
83 std::unordered_map<string, FactoryItem>& factories = device_factories();
84 auto iter = factories.find(device_type);
85 if (iter != factories.end()) {
86 return iter->second.is_pluggable_device;
87 }
88 return false;
89}
90
91// static
92void DeviceFactory::Register(const string& device_type,
93 std::unique_ptr<DeviceFactory> factory,
94 int priority, bool is_pluggable_device) {
95 if (!IsDeviceFactoryEnabled(device_type)) {
96 LOG(INFO) << "Device factory '" << device_type << "' disabled by "
97 << "TF_ENABLED_DEVICE_TYPES environment variable.";
98 return;
99 }
100 mutex_lock l(*get_device_factory_lock());
101 std::unordered_map<string, FactoryItem>& factories = device_factories();
102 auto iter = factories.find(device_type);
103 if (iter == factories.end()) {
104 factories[device_type] = {std::move(factory), priority,
105 is_pluggable_device};
106 } else {
107 if (iter->second.priority < priority) {
108 iter->second = {std::move(factory), priority, is_pluggable_device};
109 } else if (iter->second.priority == priority) {
110 LOG(FATAL) << "Duplicate registration of device factory for type "
111 << device_type << " with the same priority " << priority;
112 }
113 }
114}
115
116DeviceFactory* DeviceFactory::GetFactory(const string& device_type) {
117 tf_shared_lock l(*get_device_factory_lock());
118 auto it = device_factories().find(device_type);
119 if (it == device_factories().end()) {
120 return nullptr;
121 } else if (!IsDeviceFactoryEnabled(device_type)) {
122 LOG(FATAL) << "Device type " << device_type // Crash OK
123 << " had factory registered but was explicitly disabled by "
124 << "`TF_ENABLED_DEVICE_TYPES`. This environment variable needs "
125 << "to be set at program startup.";
126 }
127 return it->second.factory.get();
128}
129
130Status DeviceFactory::ListAllPhysicalDevices(std::vector<string>* devices) {
131 // CPU first. A CPU device is required.
132 // TODO(b/183974121): Consider merge the logic into the loop below.
133 auto cpu_factory = GetFactory("CPU");
134 if (!cpu_factory) {
135 return errors::NotFound(
136 "CPU Factory not registered. Did you link in threadpool_device?");
137 }
138
139 size_t init_size = devices->size();
140 TF_RETURN_IF_ERROR(cpu_factory->ListPhysicalDevices(devices));
141 if (devices->size() == init_size) {
142 return errors::NotFound("No CPU devices are available in this process");
143 }
144
145 // Then the rest (including GPU).
146 tf_shared_lock l(*get_device_factory_lock());
147 for (auto& p : device_factories()) {
148 auto factory = p.second.factory.get();
149 if (factory != cpu_factory) {
150 TF_RETURN_IF_ERROR(factory->ListPhysicalDevices(devices));
151 }
152 }
153
154 return OkStatus();
155}
156
157Status DeviceFactory::ListPluggablePhysicalDevices(
158 std::vector<string>* devices) {
159 tf_shared_lock l(*get_device_factory_lock());
160 for (auto& p : device_factories()) {
161 if (p.second.is_pluggable_device) {
162 auto factory = p.second.factory.get();
163 TF_RETURN_IF_ERROR(factory->ListPhysicalDevices(devices));
164 }
165 }
166 return OkStatus();
167}
168
169Status DeviceFactory::GetAnyDeviceDetails(
170 int device_index, std::unordered_map<string, string>* details) {
171 if (device_index < 0) {
172 return errors::InvalidArgument("Device index out of bounds: ",
173 device_index);
174 }
175 const int orig_device_index = device_index;
176
177 // Iterate over devices in the same way as in ListAllPhysicalDevices.
178 auto cpu_factory = GetFactory("CPU");
179 if (!cpu_factory) {
180 return errors::NotFound(
181 "CPU Factory not registered. Did you link in threadpool_device?");
182 }
183
184 // TODO(b/183974121): Consider merge the logic into the loop below.
185 std::vector<string> devices;
186 TF_RETURN_IF_ERROR(cpu_factory->ListPhysicalDevices(&devices));
187 if (device_index < devices.size()) {
188 return cpu_factory->GetDeviceDetails(device_index, details);
189 }
190 device_index -= devices.size();
191
192 // Then the rest (including GPU).
193 tf_shared_lock l(*get_device_factory_lock());
194 for (auto& p : device_factories()) {
195 auto factory = p.second.factory.get();
196 if (factory != cpu_factory) {
197 devices.clear();
198 // TODO(b/146009447): Find the factory size without having to allocate a
199 // vector with all the physical devices.
200 TF_RETURN_IF_ERROR(factory->ListPhysicalDevices(&devices));
201 if (device_index < devices.size()) {
202 return factory->GetDeviceDetails(device_index, details);
203 }
204 device_index -= devices.size();
205 }
206 }
207
208 return errors::InvalidArgument("Device index out of bounds: ",
209 orig_device_index);
210}
211
212Status DeviceFactory::AddCpuDevices(
213 const SessionOptions& options, const string& name_prefix,
214 std::vector<std::unique_ptr<Device>>* devices) {
215 auto cpu_factory = GetFactory("CPU");
216 if (!cpu_factory) {
217 return errors::NotFound(
218 "CPU Factory not registered. Did you link in threadpool_device?");
219 }
220 size_t init_size = devices->size();
221 TF_RETURN_IF_ERROR(cpu_factory->CreateDevices(options, name_prefix, devices));
222 if (devices->size() == init_size) {
223 return errors::NotFound("No CPU devices are available in this process");
224 }
225
226 return OkStatus();
227}
228
229Status DeviceFactory::AddDevices(
230 const SessionOptions& options, const string& name_prefix,
231 std::vector<std::unique_ptr<Device>>* devices) {
232 // CPU first. A CPU device is required.
233 // TODO(b/183974121): Consider merge the logic into the loop below.
234 TF_RETURN_IF_ERROR(AddCpuDevices(options, name_prefix, devices));
235
236 absl::flat_hash_set<std::string> allowed_device_types;
237 for (const auto& device_filter : options.config.device_filters()) {
238 DeviceNameUtils::ParsedName parsed;
239 if (!DeviceNameUtils::ParseFullOrLocalName(device_filter, &parsed)) {
240 return errors::InvalidArgument(
241 absl::StrCat("Invalid device filter: ", device_filter));
242 }
243 if (parsed.has_type) {
244 allowed_device_types.insert(parsed.type);
245 }
246 }
247
248 auto cpu_factory = GetFactory("CPU");
249 // Then the rest (including GPU).
250 mutex_lock l(*get_device_factory_lock());
251 for (auto& p : device_factories()) {
252 if (!allowed_device_types.empty() &&
253 !allowed_device_types.contains(p.first)) {
254 continue; // Skip if the device type is not found from the device filter.
255 }
256 auto factory = p.second.factory.get();
257 if (factory != cpu_factory) {
258 TF_RETURN_IF_ERROR(factory->CreateDevices(options, name_prefix, devices));
259 }
260 }
261
262 return OkStatus();
263}
264
265std::unique_ptr<Device> DeviceFactory::NewDevice(const string& type,
266 const SessionOptions& options,
267 const string& name_prefix) {
268 auto device_factory = GetFactory(type);
269 if (!device_factory) {
270 return nullptr;
271 }
272 SessionOptions opt = options;
273 (*opt.config.mutable_device_count())[type] = 1;
274 std::vector<std::unique_ptr<Device>> devices;
275 TF_CHECK_OK(device_factory->CreateDevices(opt, name_prefix, &devices));
276 int expected_num_devices = 1;
277 auto iter = options.config.device_count().find(type);
278 if (iter != options.config.device_count().end()) {
279 expected_num_devices = iter->second;
280 }
281 DCHECK_EQ(devices.size(), static_cast<size_t>(expected_num_devices));
282 return std::move(devices[0]);
283}
284
285} // namespace tensorflow
286