1/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2
3Licensed under the Apache License, Version 2.0 (the "License");
4you may not use this file except in compliance with the License.
5You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9Unless required by applicable law or agreed to in writing, software
10distributed under the License is distributed on an "AS IS" BASIS,
11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12See the License for the specific language governing permissions and
13limitations under the License.
14==============================================================================*/
15
16#include "tensorflow/core/common_runtime/device_set.h"
17
18#include <set>
19#include <utility>
20#include <vector>
21
22#include "tensorflow/core/common_runtime/device.h"
23#include "tensorflow/core/common_runtime/device_factory.h"
24#include "tensorflow/core/lib/core/stringpiece.h"
25#include "tensorflow/core/lib/gtl/map_util.h"
26
27namespace tensorflow {
28
29DeviceSet::DeviceSet() {}
30
31DeviceSet::~DeviceSet() {}
32
33void DeviceSet::AddDevice(Device* device) {
34 mutex_lock l(devices_mu_);
35 devices_.push_back(device);
36 prioritized_devices_.clear();
37 prioritized_device_types_.clear();
38 for (const string& name :
39 DeviceNameUtils::GetNamesForDeviceMappings(device->parsed_name())) {
40 device_by_name_.insert({name, device});
41 }
42}
43
44void DeviceSet::FindMatchingDevices(const DeviceNameUtils::ParsedName& spec,
45 std::vector<Device*>* devices) const {
46 // TODO(jeff): If we are going to repeatedly lookup the set of devices
47 // for the same spec, maybe we should have a cache of some sort
48 devices->clear();
49 for (Device* d : devices_) {
50 if (DeviceNameUtils::IsCompleteSpecification(spec, d->parsed_name())) {
51 devices->push_back(d);
52 }
53 }
54}
55
56Device* DeviceSet::FindDeviceByName(const string& name) const {
57 return gtl::FindPtrOrNull(device_by_name_, name);
58}
59
60// static
61int DeviceSet::DeviceTypeOrder(const DeviceType& d) {
62 return DeviceFactory::DevicePriority(d.type_string());
63}
64
65static bool DeviceTypeComparator(const DeviceType& a, const DeviceType& b) {
66 // First sort by prioritized device type (higher is preferred) and
67 // then by device name (lexicographically).
68 auto a_priority = DeviceSet::DeviceTypeOrder(a);
69 auto b_priority = DeviceSet::DeviceTypeOrder(b);
70 if (a_priority != b_priority) {
71 return a_priority > b_priority;
72 }
73
74 return StringPiece(a.type()) < StringPiece(b.type());
75}
76
77std::vector<DeviceType> DeviceSet::PrioritizedDeviceTypeList() const {
78 std::vector<DeviceType> result;
79 std::set<string> seen;
80 for (Device* d : devices_) {
81 const auto& t = d->device_type();
82 if (seen.insert(t).second) {
83 result.emplace_back(t);
84 }
85 }
86 std::sort(result.begin(), result.end(), DeviceTypeComparator);
87 return result;
88}
89
90void DeviceSet::SortPrioritizedDeviceTypeVector(
91 PrioritizedDeviceTypeVector* vector) {
92 if (vector == nullptr) return;
93
94 auto device_sort = [](const PrioritizedDeviceTypeVector::value_type& a,
95 const PrioritizedDeviceTypeVector::value_type& b) {
96 // First look at set priorities.
97 if (a.second != b.second) {
98 return a.second > b.second;
99 }
100 // Then fallback to default priorities.
101 return DeviceTypeComparator(a.first, b.first);
102 };
103
104 std::sort(vector->begin(), vector->end(), device_sort);
105}
106
107void DeviceSet::SortPrioritizedDeviceVector(PrioritizedDeviceVector* vector) {
108 auto device_sort = [](const std::pair<Device*, int32>& a,
109 const std::pair<Device*, int32>& b) {
110 if (a.second != b.second) {
111 return a.second > b.second;
112 }
113
114 const string& a_type_name = a.first->device_type();
115 const string& b_type_name = b.first->device_type();
116 if (a_type_name != b_type_name) {
117 auto a_priority = DeviceFactory::DevicePriority(a_type_name);
118 auto b_priority = DeviceFactory::DevicePriority(b_type_name);
119 if (a_priority != b_priority) {
120 return a_priority > b_priority;
121 }
122 }
123
124 if (a.first->IsLocal() != b.first->IsLocal()) {
125 return a.first->IsLocal();
126 }
127
128 return StringPiece(a.first->name()) < StringPiece(b.first->name());
129 };
130 std::sort(vector->begin(), vector->end(), device_sort);
131}
132
133namespace {
134
135void UpdatePrioritizedVectors(
136 const std::vector<Device*>& devices,
137 PrioritizedDeviceVector* prioritized_devices,
138 PrioritizedDeviceTypeVector* prioritized_device_types) {
139 if (prioritized_devices->size() != devices.size()) {
140 for (Device* d : devices) {
141 prioritized_devices->emplace_back(
142 d, DeviceSet::DeviceTypeOrder(DeviceType(d->device_type())));
143 }
144 DeviceSet::SortPrioritizedDeviceVector(prioritized_devices);
145 }
146
147 if (prioritized_device_types != nullptr &&
148 prioritized_device_types->size() != devices.size()) {
149 std::set<DeviceType> seen;
150 for (const std::pair<Device*, int32>& p : *prioritized_devices) {
151 DeviceType t(p.first->device_type());
152 if (seen.insert(t).second) {
153 prioritized_device_types->emplace_back(t, p.second);
154 }
155 }
156 }
157}
158
159} // namespace
160
161const PrioritizedDeviceVector& DeviceSet::prioritized_devices() const {
162 mutex_lock l(devices_mu_);
163 UpdatePrioritizedVectors(devices_, &prioritized_devices_,
164 /* prioritized_device_types */ nullptr);
165 return prioritized_devices_;
166}
167
168const PrioritizedDeviceTypeVector& DeviceSet::prioritized_device_types() const {
169 mutex_lock l(devices_mu_);
170 UpdatePrioritizedVectors(devices_, &prioritized_devices_,
171 &prioritized_device_types_);
172 return prioritized_device_types_;
173}
174
175} // namespace tensorflow
176