1/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2
3Licensed under the Apache License, Version 2.0 (the "License");
4you may not use this file except in compliance with the License.
5You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9Unless required by applicable law or agreed to in writing, software
10distributed under the License is distributed on an "AS IS" BASIS,
11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12See the License for the specific language governing permissions and
13limitations under the License.
14==============================================================================*/
15
16#ifndef TENSORFLOW_CORE_FRAMEWORK_DEVICE_FACTORY_H_
17#define TENSORFLOW_CORE_FRAMEWORK_DEVICE_FACTORY_H_
18
#include <memory>
#include <string>
#include <unordered_map>
#include <vector>

#include "absl/base/attributes.h"
#include "tensorflow/core/platform/status.h"
#include "tensorflow/core/platform/types.h"
25
26namespace tensorflow {
27
28class Device;
29struct SessionOptions;
30
31class DeviceFactory {
32 public:
33 virtual ~DeviceFactory() {}
34 static void Register(const std::string& device_type,
35 std::unique_ptr<DeviceFactory> factory, int priority,
36 bool is_pluggable_device);
37 ABSL_DEPRECATED("Use the `Register` function above instead")
38 static void Register(const std::string& device_type, DeviceFactory* factory,
39 int priority, bool is_pluggable_device) {
40 Register(device_type, std::unique_ptr<DeviceFactory>(factory), priority,
41 is_pluggable_device);
42 }
43 static DeviceFactory* GetFactory(const std::string& device_type);
44
45 // Append to "*devices" CPU devices.
46 static Status AddCpuDevices(const SessionOptions& options,
47 const std::string& name_prefix,
48 std::vector<std::unique_ptr<Device>>* devices);
49
50 // Append to "*devices" all suitable devices, respecting
51 // any device type specific properties/counts listed in "options".
52 //
53 // CPU devices are added first.
54 static Status AddDevices(const SessionOptions& options,
55 const std::string& name_prefix,
56 std::vector<std::unique_ptr<Device>>* devices);
57
58 // Helper for tests. Create a single device of type "type". The
59 // returned device is always numbered zero, so if creating multiple
60 // devices of the same type, supply distinct name_prefix arguments.
61 static std::unique_ptr<Device> NewDevice(const string& type,
62 const SessionOptions& options,
63 const string& name_prefix);
64
65 // Iterate through all device factories and build a list of all of the
66 // possible physical devices.
67 //
68 // CPU is are added first.
69 static Status ListAllPhysicalDevices(std::vector<string>* devices);
70
71 // Iterate through all device factories and build a list of all of the
72 // possible pluggable physical devices.
73 static Status ListPluggablePhysicalDevices(std::vector<string>* devices);
74
75 // Get details for a specific device among all device factories.
76 // 'device_index' indexes into devices from ListAllPhysicalDevices.
77 static Status GetAnyDeviceDetails(
78 int device_index, std::unordered_map<string, string>* details);
79
80 // For a specific device factory list all possible physical devices.
81 virtual Status ListPhysicalDevices(std::vector<string>* devices) = 0;
82
83 // Get details for a specific device for a specific factory. Subclasses
84 // can store arbitrary device information in the map. 'device_index' indexes
85 // into devices from ListPhysicalDevices.
86 virtual Status GetDeviceDetails(int device_index,
87 std::unordered_map<string, string>* details) {
88 return OkStatus();
89 }
90
91 // Most clients should call AddDevices() instead.
92 virtual Status CreateDevices(
93 const SessionOptions& options, const std::string& name_prefix,
94 std::vector<std::unique_ptr<Device>>* devices) = 0;
95
96 // Return the device priority number for a "device_type" string.
97 //
98 // Higher number implies higher priority.
99 //
100 // In standard TensorFlow distributions, GPU device types are
101 // preferred over CPU, and by default, custom devices that don't set
102 // a custom priority during registration will be prioritized lower
103 // than CPU. Custom devices that want a higher priority can set the
104 // 'priority' field when registering their device to something
105 // higher than the packaged devices. See calls to
106 // REGISTER_LOCAL_DEVICE_FACTORY to see the existing priorities used
107 // for built-in devices.
108 static int32 DevicePriority(const std::string& device_type);
109
110 // Returns true if 'device_type' is registered from plugin. Returns false if
111 // 'device_type' is a first-party device.
112 static bool IsPluggableDevice(const std::string& device_type);
113};
114
115namespace dfactory {
116
117template <class Factory>
118class Registrar {
119 public:
120 // Multiple registrations for the same device type with different priorities
121 // are allowed. Priorities are used in two different ways:
122 //
123 // 1) When choosing which factory (that is, which device
124 // implementation) to use for a specific 'device_type', the
125 // factory registered with the highest priority will be chosen.
126 // For example, if there are two registrations:
127 //
128 // Registrar<CPUFactory1>("CPU", 125);
129 // Registrar<CPUFactory2>("CPU", 150);
130 //
131 // then CPUFactory2 will be chosen when
132 // DeviceFactory::GetFactory("CPU") is called.
133 //
134 // 2) When choosing which 'device_type' is preferred over other
135 // DeviceTypes in a DeviceSet, the ordering is determined
136 // by the 'priority' set during registration. For example, if there
137 // are two registrations:
138 //
139 // Registrar<CPUFactory>("CPU", 100);
140 // Registrar<GPUFactory>("GPU", 200);
141 //
142 // then DeviceType("GPU") will be prioritized higher than
143 // DeviceType("CPU").
144 //
145 // The default priority values for built-in devices is:
146 // GPU: 210
147 // GPUCompatibleCPU: 70
148 // ThreadPoolDevice: 60
149 // Default: 50
150 explicit Registrar(const std::string& device_type, int priority = 50) {
151 DeviceFactory::Register(device_type, std::make_unique<Factory>(), priority,
152 /*is_pluggable_device*/ false);
153 }
154};
155
156} // namespace dfactory
157
// Registers `device_factory` (a DeviceFactory subclass) for `device_type` by
// defining a namespace-scope `static ::tensorflow::dfactory::Registrar`
// object; the Registrar constructor performs the registration. Any extra
// arguments (e.g. a priority) are forwarded to the Registrar constructor.
// `##__VA_ARGS__` (a GNU/Clang/MSVC extension) drops the trailing comma when
// no extra arguments are supplied.
//
// Example:
//   REGISTER_LOCAL_DEVICE_FACTORY("CPU", MyCPUFactory, 60);
#define REGISTER_LOCAL_DEVICE_FACTORY(device_type, device_factory, ...) \
  INTERNAL_REGISTER_LOCAL_DEVICE_FACTORY(device_type, device_factory,   \
                                         __COUNTER__, ##__VA_ARGS__)

// Implementation detail of REGISTER_LOCAL_DEVICE_FACTORY; do not use directly.
// `ctr` is a unique integer used to give each Registrar object a distinct
// variable name.
#define INTERNAL_REGISTER_LOCAL_DEVICE_FACTORY(device_type, device_factory, \
                                               ctr, ...)                    \
  static ::tensorflow::dfactory::Registrar<device_factory>                  \
      INTERNAL_REGISTER_LOCAL_DEVICE_FACTORY_NAME(ctr)(device_type,         \
                                                       ##__VA_ARGS__)

// __COUNTER__ must go through another macro to be properly expanded:
// operands of `##` are not macro-expanded, so `ctr` must already be a number
// by the time it reaches the token paste below. That is why
// REGISTER_LOCAL_DEVICE_FACTORY routes __COUNTER__ through
// INTERNAL_REGISTER_LOCAL_DEVICE_FACTORY instead of pasting it directly.
#define INTERNAL_REGISTER_LOCAL_DEVICE_FACTORY_NAME(ctr) ___##ctr##__object_
169
170} // namespace tensorflow
171
172#endif // TENSORFLOW_CORE_FRAMEWORK_DEVICE_FACTORY_H_
173