1 | /* Copyright 2015 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | |
16 | #ifndef TENSORFLOW_CORE_COMMON_RUNTIME_DEVICE_MGR_H_ |
17 | #define TENSORFLOW_CORE_COMMON_RUNTIME_DEVICE_MGR_H_ |
18 | |
#include <atomic>
#include <cstddef>
#include <memory>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <utility>
#include <vector>

#include "absl/container/flat_hash_set.h"
#include "tensorflow/core/common_runtime/device.h"
#include "tensorflow/core/lib/core/arena.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/core/stringpiece.h"
#include "tensorflow/core/lib/gtl/inlined_vector.h"
#include "tensorflow/core/platform/macros.h"
32 | |
33 | namespace tensorflow { |
34 | |
35 | class DeviceAttributes; |
36 | |
37 | // Represents a set of devices. |
38 | class DeviceMgr { |
39 | public: |
40 | DeviceMgr() = default; |
41 | virtual ~DeviceMgr(); |
42 | |
43 | // Returns attributes of all devices. |
44 | virtual void ListDeviceAttributes( |
45 | std::vector<DeviceAttributes>* devices) const = 0; |
46 | |
47 | // Returns raw pointers to the underlying devices. |
48 | virtual std::vector<Device*> ListDevices() const = 0; |
49 | |
50 | // Returns a string listing all devices. |
51 | virtual string DebugString() const = 0; |
52 | |
53 | // Returns a string of all the device mapping. |
54 | virtual string DeviceMappingString() const = 0; |
55 | |
56 | // Assigns *device with pointer to Device of the given name. |
57 | // Accepts either a full device name, or just the replica-local suffix. |
58 | virtual Status LookupDevice(StringPiece name, Device** device) const = 0; |
59 | |
60 | // Check if the current device manager contains device with the given |
61 | // incarnation ID. Looking up by incarnation IDs because they are randomly |
62 | // generated and not intentionally reused (unlike device pointers). |
63 | virtual bool ContainsDevice(int64_t device_incarnation) const = 0; |
64 | |
65 | // Clears given containers of all devices if 'container' is |
66 | // non-empty. Otherwise, clears default containers of all devices. |
67 | virtual void ClearContainers(gtl::ArraySlice<string> containers) const = 0; |
68 | |
69 | virtual int NumDeviceType(const string& type) const = 0; |
70 | |
71 | // Returns an arbitrary CPU device if one is present, otherwise return |
72 | // nullptr. |
73 | virtual Device* HostCPU() const = 0; |
74 | |
75 | TF_DISALLOW_COPY_AND_ASSIGN(DeviceMgr); |
76 | }; |
77 | |
78 | // Represents a static set of devices. |
79 | class StaticDeviceMgr : public DeviceMgr { |
80 | public: |
81 | // Constructs a StaticDeviceMgr from a list of devices. |
82 | explicit StaticDeviceMgr(std::vector<std::unique_ptr<Device>> devices); |
83 | |
84 | // Constructs a StaticDeviceMgr managing a single device. |
85 | explicit StaticDeviceMgr(std::unique_ptr<Device> device); |
86 | |
87 | ~StaticDeviceMgr() override; |
88 | |
89 | void ListDeviceAttributes( |
90 | std::vector<DeviceAttributes>* devices) const override; |
91 | std::vector<Device*> ListDevices() const override; |
92 | string DebugString() const override; |
93 | string DeviceMappingString() const override; |
94 | Status LookupDevice(StringPiece name, Device** device) const override; |
95 | bool ContainsDevice(int64_t device_incarnation) const override; |
96 | void ClearContainers(gtl::ArraySlice<string> containers) const override; |
97 | int NumDeviceType(const string& type) const override; |
98 | Device* HostCPU() const override; |
99 | |
100 | private: |
101 | const std::vector<std::unique_ptr<Device>> devices_; |
102 | |
103 | StringPiece CopyToBackingStore(StringPiece s); |
104 | |
105 | absl::flat_hash_set<int64_t> device_incarnation_set_; |
106 | std::unordered_map<StringPiece, Device*, StringPieceHasher> device_map_; |
107 | core::Arena name_backing_store_; // Storage for keys in device_map_ |
108 | std::unordered_map<string, int> device_type_counts_; |
109 | Device* cpu_device_; |
110 | |
111 | TF_DISALLOW_COPY_AND_ASSIGN(StaticDeviceMgr); |
112 | }; |
113 | |
// Number of slots in the ring buffer DynamicDeviceMgr uses to keep removed
// ("stale") devices alive; when every slot is taken, the oldest is replaced.
constexpr size_t kStaleDeviceBufferSize = 8 * 1024;
116 | |
117 | // Represents a dynamic set of devices |
118 | class DynamicDeviceMgr : public DeviceMgr { |
119 | public: |
120 | // Constructs an empty DynamicDeviceMgr. |
121 | DynamicDeviceMgr(); |
122 | |
123 | // Constructs a DynamicDeviceMgr from a list of devices. |
124 | // TODO(b/183966398): Remove StaticDeviceMgr since there's no usage. |
125 | explicit DynamicDeviceMgr(std::vector<std::unique_ptr<Device>> devices); |
126 | |
127 | ~DynamicDeviceMgr() override; |
128 | |
129 | void ListDeviceAttributes( |
130 | std::vector<DeviceAttributes>* devices) const override; |
131 | std::vector<Device*> ListDevices() const override; |
132 | string DebugString() const override; |
133 | string DeviceMappingString() const override; |
134 | Status LookupDevice(StringPiece name, Device** device) const override; |
135 | bool ContainsDevice(int64_t device_incarnation) const override; |
136 | void ClearContainers(gtl::ArraySlice<string> containers) const override; |
137 | int NumDeviceType(const string& type) const override; |
138 | Device* HostCPU() const override; |
139 | |
140 | // Add devices to device manager. Returns error for repeated device names. |
141 | Status AddDevices(std::vector<std::unique_ptr<Device>> devices); |
142 | |
143 | // Remove devices from device manager. |
144 | // Returns error for non-existing devices or if the HostCPU() device is in the |
145 | // input list. If an error is returned, the device list is not modified. |
146 | Status RemoveDevices(const std::vector<Device*>& devices); |
147 | |
148 | // Remove devices from device manager by their names. Returns error for |
149 | // non-existing devices or if the HostCPU() device is given in the input list. |
150 | // If an error is returned, the device list is not modified. |
151 | Status RemoveDevicesByName(const std::vector<string>& device_names); |
152 | |
153 | private: |
154 | mutable mutex devices_mu_; |
155 | |
156 | std::vector<std::unique_ptr<Device>> dynamic_devices_ |
157 | TF_GUARDED_BY(devices_mu_); |
158 | |
159 | absl::flat_hash_set<int64_t> device_incarnation_set_ |
160 | TF_GUARDED_BY(devices_mu_); |
161 | std::unordered_map<string, Device*> device_map_ TF_GUARDED_BY(devices_mu_); |
162 | |
163 | std::unordered_map<string, int> device_type_counts_ |
164 | TF_GUARDED_BY(devices_mu_); |
165 | |
166 | mutable std::atomic<Device*> cpu_device_; // memoize `HostCPU` result |
167 | |
168 | class DeviceCircularBuffer { |
169 | public: |
170 | DeviceCircularBuffer() : index_(0) { |
171 | devices_.resize(kStaleDeviceBufferSize); |
172 | } |
173 | void add(std::unique_ptr<Device> device) { |
174 | devices_[index_] = std::move(device); |
175 | index_ = (index_ + 1) % kStaleDeviceBufferSize; |
176 | } |
177 | |
178 | private: |
179 | int index_; |
180 | std::vector<std::unique_ptr<Device>> devices_; |
181 | }; |
182 | |
183 | // Buffer to temporarily store the removed devices. Raw device pointers are |
184 | // accessible to DeviceSet, and if the function instantiation process directly |
185 | // access fields through the device set, the underlying device object must |
186 | // still be available to avoid segmentation fault. We keep the devices in this |
187 | // buffer only for that purpose. |
188 | DeviceCircularBuffer stale_devices_ TF_GUARDED_BY(devices_mu_); |
189 | |
190 | TF_DISALLOW_COPY_AND_ASSIGN(DynamicDeviceMgr); |
191 | }; |
192 | } // namespace tensorflow |
193 | |
194 | #endif // TENSORFLOW_CORE_COMMON_RUNTIME_DEVICE_MGR_H_ |
195 | |