1/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2
3Licensed under the Apache License, Version 2.0 (the "License");
4you may not use this file except in compliance with the License.
5You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9Unless required by applicable law or agreed to in writing, software
10distributed under the License is distributed on an "AS IS" BASIS,
11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12See the License for the specific language governing permissions and
13limitations under the License.
14==============================================================================*/
15
16#ifndef TENSORFLOW_CORE_COMMON_RUNTIME_DEVICE_SET_H_
17#define TENSORFLOW_CORE_COMMON_RUNTIME_DEVICE_SET_H_
18
19#include <memory>
20#include <unordered_map>
21#include <vector>
22
23#include "tensorflow/core/common_runtime/device.h"
24#include "tensorflow/core/platform/macros.h"
25#include "tensorflow/core/platform/types.h"
26#include "tensorflow/core/util/device_name_utils.h"
27
28namespace tensorflow {
29
30typedef std::vector<std::pair<Device*, int32>> PrioritizedDeviceVector;
31
32// DeviceSet is a container class for managing the various types of
33// devices used by a model.
34class DeviceSet {
35 public:
36 DeviceSet();
37 ~DeviceSet();
38
39 // Does not take ownership of 'device'.
40 void AddDevice(Device* device) TF_LOCKS_EXCLUDED(devices_mu_);
41
42 // Set the device designated as the "client". This device
43 // must also be registered via AddDevice().
44 void set_client_device(Device* device) {
45 DCHECK(client_device_ == nullptr);
46 client_device_ = device;
47 }
48
49 // Returns a pointer to the device designated as the "client".
50 Device* client_device() const { return client_device_; }
51
52 // Return the list of devices in this set.
53 const std::vector<Device*>& devices() const { return devices_; }
54
55 // Given a DeviceNameUtils::ParsedName (which may have some
56 // wildcards for different components), fills "*devices" with all
57 // devices in "*this" that match "spec".
58 void FindMatchingDevices(const DeviceNameUtils::ParsedName& spec,
59 std::vector<Device*>* devices) const;
60
61 // Finds the device with the given "fullname". Returns nullptr if
62 // not found.
63 Device* FindDeviceByName(const string& fullname) const;
64
65 // Return the list of unique device types in this set, ordered
66 // with more preferable devices earlier.
67 std::vector<DeviceType> PrioritizedDeviceTypeList() const;
68
69 // Return the prioritized list of devices in this set.
70 // Devices are prioritized first by `DeviceTypeOrder`, then by name.
71 const PrioritizedDeviceVector& prioritized_devices() const
72 TF_LOCKS_EXCLUDED(devices_mu_);
73
74 // Return the prioritized list of unique device types in this set.
75 //
76 // The list will be ordered by decreasing priority. The priorities (the second
77 // element in the list's `std::pair<DeviceType, int32>`) will be initialized
78 // to the value of `DeviceTypeOrder` for the device types.
79 const PrioritizedDeviceTypeVector& prioritized_device_types() const
80 TF_LOCKS_EXCLUDED(devices_mu_);
81
82 // An order to sort by device types according to system-determined
83 // priority.
84 //
85 // Higher result implies higher priority.
86 static int DeviceTypeOrder(const DeviceType& d);
87
88 // Sorts a PrioritizedDeviceVector according to devices and explicit
89 // priorities.
90 //
91 // After a call to this function, the argument vector will be sorted by
92 // explicit priority (the second element in the `std::pair<DeviceType,
93 // int32>`), then by `DeviceTypeOrder` of the device type, then by device
94 // locality, and lastly by device name.
95 static void SortPrioritizedDeviceVector(PrioritizedDeviceVector* vector);
96
97 // Sorts a PrioritizedDeviceTypeVector according to types and explicit
98 // priorities.
99 //
100 // After a call to this function, the argument vector will be sorted by
101 // explicit priority (the second element in the `std::pair<DeviceType,
102 // int32>`), then by `DeviceTypeOrder` of the device type.
103 static void SortPrioritizedDeviceTypeVector(
104 PrioritizedDeviceTypeVector* vector);
105
106 private:
107 mutable mutex devices_mu_;
108
109 // Not owned.
110 std::vector<Device*> devices_;
111
112 // Cached prioritized vector, created on-the-fly when
113 // prioritized_devices() is called.
114 mutable PrioritizedDeviceVector prioritized_devices_
115 TF_GUARDED_BY(devices_mu_);
116
117 // Cached prioritized vector, created on-the-fly when
118 // prioritized_device_types() is called.
119 mutable PrioritizedDeviceTypeVector prioritized_device_types_
120 TF_GUARDED_BY(devices_mu_);
121
122 // Fullname -> device* for device in devices_.
123 std::unordered_map<string, Device*> device_by_name_;
124
125 // client_device_ points to an element of devices_ that we consider
126 // to be the client device (in this local process).
127 Device* client_device_ = nullptr;
128
129 TF_DISALLOW_COPY_AND_ASSIGN(DeviceSet);
130};
131
132} // namespace tensorflow
133
134#endif // TENSORFLOW_CORE_COMMON_RUNTIME_DEVICE_SET_H_
135