1 | /* Copyright 2015 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | |
16 | #ifndef TENSORFLOW_CORE_COMMON_RUNTIME_DEVICE_SET_H_ |
17 | #define TENSORFLOW_CORE_COMMON_RUNTIME_DEVICE_SET_H_ |
18 | |
19 | #include <memory> |
20 | #include <unordered_map> |
21 | #include <vector> |
22 | |
23 | #include "tensorflow/core/common_runtime/device.h" |
24 | #include "tensorflow/core/platform/macros.h" |
25 | #include "tensorflow/core/platform/types.h" |
26 | #include "tensorflow/core/util/device_name_utils.h" |
27 | |
28 | namespace tensorflow { |
29 | |
30 | typedef std::vector<std::pair<Device*, int32>> PrioritizedDeviceVector; |
31 | |
32 | // DeviceSet is a container class for managing the various types of |
33 | // devices used by a model. |
34 | class DeviceSet { |
35 | public: |
36 | DeviceSet(); |
37 | ~DeviceSet(); |
38 | |
39 | // Does not take ownership of 'device'. |
40 | void AddDevice(Device* device) TF_LOCKS_EXCLUDED(devices_mu_); |
41 | |
42 | // Set the device designated as the "client". This device |
43 | // must also be registered via AddDevice(). |
44 | void set_client_device(Device* device) { |
45 | DCHECK(client_device_ == nullptr); |
46 | client_device_ = device; |
47 | } |
48 | |
49 | // Returns a pointer to the device designated as the "client". |
50 | Device* client_device() const { return client_device_; } |
51 | |
52 | // Return the list of devices in this set. |
53 | const std::vector<Device*>& devices() const { return devices_; } |
54 | |
55 | // Given a DeviceNameUtils::ParsedName (which may have some |
56 | // wildcards for different components), fills "*devices" with all |
57 | // devices in "*this" that match "spec". |
58 | void FindMatchingDevices(const DeviceNameUtils::ParsedName& spec, |
59 | std::vector<Device*>* devices) const; |
60 | |
61 | // Finds the device with the given "fullname". Returns nullptr if |
62 | // not found. |
63 | Device* FindDeviceByName(const string& fullname) const; |
64 | |
65 | // Return the list of unique device types in this set, ordered |
66 | // with more preferable devices earlier. |
67 | std::vector<DeviceType> PrioritizedDeviceTypeList() const; |
68 | |
69 | // Return the prioritized list of devices in this set. |
70 | // Devices are prioritized first by `DeviceTypeOrder`, then by name. |
71 | const PrioritizedDeviceVector& prioritized_devices() const |
72 | TF_LOCKS_EXCLUDED(devices_mu_); |
73 | |
74 | // Return the prioritized list of unique device types in this set. |
75 | // |
76 | // The list will be ordered by decreasing priority. The priorities (the second |
77 | // element in the list's `std::pair<DeviceType, int32>`) will be initialized |
78 | // to the value of `DeviceTypeOrder` for the device types. |
79 | const PrioritizedDeviceTypeVector& prioritized_device_types() const |
80 | TF_LOCKS_EXCLUDED(devices_mu_); |
81 | |
82 | // An order to sort by device types according to system-determined |
83 | // priority. |
84 | // |
85 | // Higher result implies higher priority. |
86 | static int DeviceTypeOrder(const DeviceType& d); |
87 | |
88 | // Sorts a PrioritizedDeviceVector according to devices and explicit |
89 | // priorities. |
90 | // |
91 | // After a call to this function, the argument vector will be sorted by |
92 | // explicit priority (the second element in the `std::pair<DeviceType, |
93 | // int32>`), then by `DeviceTypeOrder` of the device type, then by device |
94 | // locality, and lastly by device name. |
95 | static void SortPrioritizedDeviceVector(PrioritizedDeviceVector* vector); |
96 | |
97 | // Sorts a PrioritizedDeviceTypeVector according to types and explicit |
98 | // priorities. |
99 | // |
100 | // After a call to this function, the argument vector will be sorted by |
101 | // explicit priority (the second element in the `std::pair<DeviceType, |
102 | // int32>`), then by `DeviceTypeOrder` of the device type. |
103 | static void SortPrioritizedDeviceTypeVector( |
104 | PrioritizedDeviceTypeVector* vector); |
105 | |
106 | private: |
107 | mutable mutex devices_mu_; |
108 | |
109 | // Not owned. |
110 | std::vector<Device*> devices_; |
111 | |
112 | // Cached prioritized vector, created on-the-fly when |
113 | // prioritized_devices() is called. |
114 | mutable PrioritizedDeviceVector prioritized_devices_ |
115 | TF_GUARDED_BY(devices_mu_); |
116 | |
117 | // Cached prioritized vector, created on-the-fly when |
118 | // prioritized_device_types() is called. |
119 | mutable PrioritizedDeviceTypeVector prioritized_device_types_ |
120 | TF_GUARDED_BY(devices_mu_); |
121 | |
122 | // Fullname -> device* for device in devices_. |
123 | std::unordered_map<string, Device*> device_by_name_; |
124 | |
125 | // client_device_ points to an element of devices_ that we consider |
126 | // to be the client device (in this local process). |
127 | Device* client_device_ = nullptr; |
128 | |
129 | TF_DISALLOW_COPY_AND_ASSIGN(DeviceSet); |
130 | }; |
131 | |
132 | } // namespace tensorflow |
133 | |
134 | #endif // TENSORFLOW_CORE_COMMON_RUNTIME_DEVICE_SET_H_ |
135 | |