1 | /* Copyright 2015 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | |
16 | #ifndef TENSORFLOW_CORE_FRAMEWORK_DEVICE_FACTORY_H_ |
17 | #define TENSORFLOW_CORE_FRAMEWORK_DEVICE_FACTORY_H_ |
18 | |
#include <memory>
#include <string>
#include <unordered_map>
#include <vector>

#include "absl/base/attributes.h"
#include "tensorflow/core/platform/status.h"
#include "tensorflow/core/platform/types.h"
25 | |
26 | namespace tensorflow { |
27 | |
28 | class Device; |
29 | struct SessionOptions; |
30 | |
31 | class DeviceFactory { |
32 | public: |
33 | virtual ~DeviceFactory() {} |
34 | static void Register(const std::string& device_type, |
35 | std::unique_ptr<DeviceFactory> factory, int priority, |
36 | bool is_pluggable_device); |
37 | ABSL_DEPRECATED("Use the `Register` function above instead" ) |
38 | static void Register(const std::string& device_type, DeviceFactory* factory, |
39 | int priority, bool is_pluggable_device) { |
40 | Register(device_type, std::unique_ptr<DeviceFactory>(factory), priority, |
41 | is_pluggable_device); |
42 | } |
43 | static DeviceFactory* GetFactory(const std::string& device_type); |
44 | |
45 | // Append to "*devices" CPU devices. |
46 | static Status AddCpuDevices(const SessionOptions& options, |
47 | const std::string& name_prefix, |
48 | std::vector<std::unique_ptr<Device>>* devices); |
49 | |
50 | // Append to "*devices" all suitable devices, respecting |
51 | // any device type specific properties/counts listed in "options". |
52 | // |
53 | // CPU devices are added first. |
54 | static Status AddDevices(const SessionOptions& options, |
55 | const std::string& name_prefix, |
56 | std::vector<std::unique_ptr<Device>>* devices); |
57 | |
58 | // Helper for tests. Create a single device of type "type". The |
59 | // returned device is always numbered zero, so if creating multiple |
60 | // devices of the same type, supply distinct name_prefix arguments. |
61 | static std::unique_ptr<Device> NewDevice(const string& type, |
62 | const SessionOptions& options, |
63 | const string& name_prefix); |
64 | |
65 | // Iterate through all device factories and build a list of all of the |
66 | // possible physical devices. |
67 | // |
68 | // CPU is are added first. |
69 | static Status ListAllPhysicalDevices(std::vector<string>* devices); |
70 | |
71 | // Iterate through all device factories and build a list of all of the |
72 | // possible pluggable physical devices. |
73 | static Status ListPluggablePhysicalDevices(std::vector<string>* devices); |
74 | |
75 | // Get details for a specific device among all device factories. |
76 | // 'device_index' indexes into devices from ListAllPhysicalDevices. |
77 | static Status GetAnyDeviceDetails( |
78 | int device_index, std::unordered_map<string, string>* details); |
79 | |
80 | // For a specific device factory list all possible physical devices. |
81 | virtual Status ListPhysicalDevices(std::vector<string>* devices) = 0; |
82 | |
83 | // Get details for a specific device for a specific factory. Subclasses |
84 | // can store arbitrary device information in the map. 'device_index' indexes |
85 | // into devices from ListPhysicalDevices. |
86 | virtual Status GetDeviceDetails(int device_index, |
87 | std::unordered_map<string, string>* details) { |
88 | return OkStatus(); |
89 | } |
90 | |
91 | // Most clients should call AddDevices() instead. |
92 | virtual Status CreateDevices( |
93 | const SessionOptions& options, const std::string& name_prefix, |
94 | std::vector<std::unique_ptr<Device>>* devices) = 0; |
95 | |
96 | // Return the device priority number for a "device_type" string. |
97 | // |
98 | // Higher number implies higher priority. |
99 | // |
100 | // In standard TensorFlow distributions, GPU device types are |
101 | // preferred over CPU, and by default, custom devices that don't set |
102 | // a custom priority during registration will be prioritized lower |
103 | // than CPU. Custom devices that want a higher priority can set the |
104 | // 'priority' field when registering their device to something |
105 | // higher than the packaged devices. See calls to |
106 | // REGISTER_LOCAL_DEVICE_FACTORY to see the existing priorities used |
107 | // for built-in devices. |
108 | static int32 DevicePriority(const std::string& device_type); |
109 | |
110 | // Returns true if 'device_type' is registered from plugin. Returns false if |
111 | // 'device_type' is a first-party device. |
112 | static bool IsPluggableDevice(const std::string& device_type); |
113 | }; |
114 | |
115 | namespace dfactory { |
116 | |
117 | template <class Factory> |
118 | class Registrar { |
119 | public: |
120 | // Multiple registrations for the same device type with different priorities |
121 | // are allowed. Priorities are used in two different ways: |
122 | // |
123 | // 1) When choosing which factory (that is, which device |
124 | // implementation) to use for a specific 'device_type', the |
125 | // factory registered with the highest priority will be chosen. |
126 | // For example, if there are two registrations: |
127 | // |
128 | // Registrar<CPUFactory1>("CPU", 125); |
129 | // Registrar<CPUFactory2>("CPU", 150); |
130 | // |
131 | // then CPUFactory2 will be chosen when |
132 | // DeviceFactory::GetFactory("CPU") is called. |
133 | // |
134 | // 2) When choosing which 'device_type' is preferred over other |
135 | // DeviceTypes in a DeviceSet, the ordering is determined |
136 | // by the 'priority' set during registration. For example, if there |
137 | // are two registrations: |
138 | // |
139 | // Registrar<CPUFactory>("CPU", 100); |
140 | // Registrar<GPUFactory>("GPU", 200); |
141 | // |
142 | // then DeviceType("GPU") will be prioritized higher than |
143 | // DeviceType("CPU"). |
144 | // |
145 | // The default priority values for built-in devices is: |
146 | // GPU: 210 |
147 | // GPUCompatibleCPU: 70 |
148 | // ThreadPoolDevice: 60 |
149 | // Default: 50 |
150 | explicit Registrar(const std::string& device_type, int priority = 50) { |
151 | DeviceFactory::Register(device_type, std::make_unique<Factory>(), priority, |
152 | /*is_pluggable_device*/ false); |
153 | } |
154 | }; |
155 | |
156 | } // namespace dfactory |
157 | |
// Registers `device_factory` as a first-party factory for `device_type`.
//
// The macro defines a `static ::tensorflow::dfactory::Registrar` object with a
// unique name; constructing that object performs the registration. Any extra
// arguments (e.g. a priority) are forwarded to the Registrar constructor.
//
// Example:
//   REGISTER_LOCAL_DEVICE_FACTORY("CPU", MyCpuFactory, 60);
#define REGISTER_LOCAL_DEVICE_FACTORY(device_type, device_factory, ...) \
  INTERNAL_REGISTER_LOCAL_DEVICE_FACTORY(device_type, device_factory, \
                                         __COUNTER__, ##__VA_ARGS__)

// __COUNTER__ must go through this extra macro level to be properly expanded.
// Here `ctr` is an ordinary macro argument, so it is fully expanded to an
// integer literal before substitution. Operands of `##` are pasted
// unexpanded, so passing __COUNTER__ straight to the `##` would not work.
#define INTERNAL_REGISTER_LOCAL_DEVICE_FACTORY(device_type, device_factory, \
                                               ctr, ...)                    \
  static ::tensorflow::dfactory::Registrar<device_factory>                  \
      INTERNAL_REGISTER_LOCAL_DEVICE_FACTORY_NAME(ctr)(device_type,         \
                                                       ##__VA_ARGS__)

// Builds the per-use variable name from the (already expanded) counter value.
// The name deliberately avoids double underscores: identifiers containing
// `__` are reserved to the implementation in C++.
#define INTERNAL_REGISTER_LOCAL_DEVICE_FACTORY_NAME(ctr) \
  tf_local_device_factory_registrar_##ctr##_object
169 | |
170 | } // namespace tensorflow |
171 | |
172 | #endif // TENSORFLOW_CORE_FRAMEWORK_DEVICE_FACTORY_H_ |
173 | |