1 | /* Copyright 2019 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | |
16 | #ifndef TENSORFLOW_CORE_COMMON_RUNTIME_COLOCATION_GRAPH_H_ |
17 | #define TENSORFLOW_CORE_COMMON_RUNTIME_COLOCATION_GRAPH_H_ |
18 | |
#include <unordered_map>
#include <unordered_set>
#include <utility>
#include <vector>

#include "absl/strings/str_join.h"
#include "tensorflow/core/common_runtime/device.h"
#include "tensorflow/core/common_runtime/inspecting_placer.h"
#include "tensorflow/core/common_runtime/placer_inspection_required_ops_utils.h"
#include "tensorflow/core/framework/function.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/lib/core/stringpiece.h"
#include "tensorflow/core/util/device_name_utils.h"
#include "tensorflow/core/util/port.h"
31 | |
32 | namespace tensorflow { |
33 | |
34 | // Represents a node in the disjoint node forest and the |
35 | // accumulated constraints on the device used by that node. |
36 | class Member { |
37 | public: |
38 | Member() = default; |
39 | |
40 | Status SetParentAndSupportedDevices( |
41 | const Node& node, const std::vector<DeviceType>& types, |
42 | const DeviceNameUtils::ParsedName* local_address_spec); |
43 | |
44 | const DeviceNameUtils::ParsedName& requested_device_name() const { |
45 | return requested_device_name_; |
46 | } |
47 | |
48 | Status SetAssignedDeviceName(const string& device_name); |
49 | Status SetResourceDeviceName(const Node& node); |
50 | Status SetRequestedDeviceName(const Node& node); |
51 | |
52 | Status FillPossibleDevices(PossibleDevices* possible_device) const; |
53 | |
54 | // Returns whether `src_root` is assigned to a CompositeDevice and `this` is |
55 | // assigned to a physical device. |
56 | bool IsEdgeFromCompositeDeviceToPhysicalDevice(const Member& src_root) const; |
57 | |
58 | Status EnsureCompatibilityAcrossResourceEdge( |
59 | const Node& src, const Member& src_root, |
60 | const Node& dst, /*dst_root is this*/ |
61 | bool log_device_placement); |
62 | |
63 | const PrioritizedDeviceTypeVector& supported_device_types() const { |
64 | return supported_device_types_; |
65 | } |
66 | |
67 | // If `dry_run` is true, just sets `new_root` and `old_root` and does not |
68 | // actually modify anything in the `tree`. |
69 | static void Merge(std::vector<Member>* tree, int x_root, int y_root, |
70 | Member** new_root, Member** old_root, bool dry_run); |
71 | |
72 | // Returns the root node of the disjoint tree to which the node with the |
73 | // given id is connected. |
74 | // FindRoot should be called only for debugging or after the members have |
75 | // been updated with direct root pointers because it does not update |
76 | // root pointers and can traverse many links. It exists to have |
77 | // a const version of FindAndUpdateRoot |
78 | static int FindRoot(const std::vector<Member>& tree, int node_id); |
79 | static int FindAndUpdateRoot(std::vector<Member>* tree, int node_id); |
80 | |
81 | Status MergeDeviceNames(const Member& other, bool allow_soft_placement); |
82 | |
83 | // Updates this to contain the intersection of the device types in |
84 | // this and "other". If the intersection is empty, returns false and does |
85 | // not update this. Else returns true and updates this. |
86 | bool MergeSupportedDevices(const Member& other); |
87 | |
88 | Status AssignDevice(const Node& node); |
89 | |
90 | // If user does not explicitly request XLA device and non-XLA device is |
91 | // supported for this node, use only the non-XLA device. See b/140896502. |
92 | void MaybeExcludeXlaDevices(); |
93 | |
94 | // Limit the possible devices of this (should be a root) to the device |
95 | // specifications in `devices`. |
96 | Status LimitToPossibleDevices(const PossibleDevices& devices, |
97 | bool allow_soft_placement); |
98 | |
99 | void set_possible_devices(std::vector<Device*>&& devices) { |
100 | possible_devices_ = devices; |
101 | } |
102 | const std::vector<Device*>& possible_devices() { return possible_devices_; } |
103 | |
104 | // Returns a (parsed) device name that is based on requested_device_name() |
105 | // but with potentially cleared device type and ID fields. A field is cleared |
106 | // if the assigned_device_name does not specify it. If it does, the field |
107 | // is not cleared because soft placement cannot violate assigned device names. |
108 | DeviceNameUtils::ParsedName GetSoftDeviceName() const; |
109 | |
110 | // Same as GetSoftDeviceName but device type and device ID fields are not |
111 | // cleared if resource device has them set. |
112 | DeviceNameUtils::ParsedName GetPreferredSoftDeviceName() const; |
113 | |
114 | string DebugString() const; |
115 | |
116 | bool has_assigned_device_name() const { return assigned_device_name_.has_id; } |
117 | |
118 | private: |
119 | // Updates this to contain the intersection of the device types in |
120 | // this and `other_devices`. |
121 | bool MergeSupportedDevices(const PrioritizedDeviceTypeVector& other_devices); |
122 | |
123 | // The id of the node that is the parent of this one, or its own |
124 | // id if it is a root. parent <= 0 indicates that this member is invalid. |
125 | int parent_ = -1; |
126 | |
127 | // A proxy for the depth of the tree that is used to prefer |
128 | // connecting smaller trees to larger trees when merging disjoint |
129 | // sets. |
130 | int rank_ = 0; |
131 | |
132 | // Once colocation groups have been formed, the Placer starts actually |
133 | // choosing devices. All nodes in a group must be assigned to the same |
134 | // device. Once we assigned the first device to some node in this group, |
135 | // we set assigned_device_name_index to this device name's index in the |
136 | // graph. |
137 | // The `*_device_name_` fields will contain the parsed name of this device |
138 | // and `possible_devices`, if computed, will contain just this device. |
139 | // `assigned_device_name_index` is an optimization to avoid parsing and |
140 | // comparing device names. The value of -1 signals that a single device |
141 | // has not been chosen yet. |
142 | int assigned_device_name_index_ = -1; |
143 | |
144 | // The merged form of the device requested for this node, with those of all of |
145 | // its children. requested_device_name_ is always kept a specialization (i.e. |
146 | // DeviceNameUtils::IsSpecification) of assigned_device_name_. When no device |
147 | // is requested, this field is set to assigned_device_name_. As a |
148 | // specialization of assigned_device_name_, requested_device_name_ represents |
149 | // the most specific form of all assigned and requested devices of this node |
150 | // and its children, if this node is a root. requested_device_name_ is used |
151 | // to finally select devices for nodes. We can override requested devices due |
152 | // to resource colocation constraints but not assigned devices (unless soft |
153 | // placement is on). |
154 | // INVARIANT: requested_device_name_ is always kept a |
155 | // DeviceNameUtils::IsSpecification of assigned_device_name_ and |
156 | // resource_device_name_. This makes requested_device_name_ the "accumulation |
157 | // of all wishes" about the device. |
158 | DeviceNameUtils::ParsedName requested_device_name_; |
159 | |
160 | // The merged form of the device assigned for this node, with |
161 | // those of all of its children. |
162 | // This field is used to raise errors due to unsatisfiable constraints. |
163 | // Can be a partial specification. |
164 | DeviceNameUtils::ParsedName assigned_device_name_; |
165 | |
166 | // The merged form of the requested resource device assigned for this node, |
167 | // with those of all of its children. |
168 | // This field is used to raise errors due to unsatisfiable constraints. |
169 | // Can be a partial specification. |
170 | // resource_device_name_ is initialized with user-requested device on nodes |
171 | // producing resources, e.g. VarHandleOp. |
172 | // For historical reasons, with soft placement enabled, Placer can "move" |
173 | // resources (place resource producing ops on a device different from what |
174 | // the user explicitly requested) when the colocation group of a resource |
175 | // producing op contains ops that are not supported on the user-requested |
176 | // resource device. A classic example of this is a sparse optimizer (only |
177 | // supported on CPU) used on a GPU variable. In this case, the whole group |
178 | // will be assigned to some device supported by all ops in the colocation |
179 | // group. This is a surprising and unfortunate behavior because: |
180 | // 1. Since soft_placement is on by default, users don't know that their |
181 | // variables are created on a different device than what they requested. |
182 | // Among other things, this can lead to surprising poor performance. |
183 | // 2. Eager runtime cannot "move" resources. The same code can "work" when |
184 | // wrapped in tf.function but will fail when run eagerly. |
185 | // 3. Extra complexity here to preserve these resource moving capabilities. |
186 | DeviceNameUtils::ParsedName resource_device_name_; |
187 | |
188 | // The intersection of all device types supported by this node, |
189 | // and those of all of its children, in priority order |
190 | // of the preferred device. |
191 | // It is possible that supported_device_types_ has an empty intersection with |
192 | // requested/assigned/resource devices. We could have detected such cases |
193 | // as soon as they happen and raise an error. Instead, for historical reasons, |
194 | // we leave such error detection to the final device picking stage. |
195 | PrioritizedDeviceTypeVector supported_device_types_; |
196 | |
197 | // If this node is a root, stores a list of Devices to which this node |
198 | // and all of its children can be assigned. |
199 | // `possible_devices` is empty if they have not yet been computed. |
200 | std::vector<Device*> possible_devices_; |
201 | }; |
202 | |
203 | // This class maintains the connected components of a colocation |
204 | // constraint graph, and uses this information to assign a satisfying |
205 | // device placement to the nodes of the graph. |
206 | // |
207 | // This implementation uses the Union-Find algorithm to efficiently maintain the |
208 | // connected components and incrementally adds edges via |
209 | // ColocationGraph::ColocateNodes() invocations. |
210 | // |
211 | // ColocationGraph does not assign any devices to graph nodes. The |
212 | // `log_device_placement` argument is used to log messages when requested |
213 | // device is ignored. |
214 | class ColocationGraph { |
215 | public: |
216 | // graph, flib_def, and device_set must not be null and must outlive |
217 | // this ColocationGraph. default_local_device can be null. If not, must |
218 | // outlive this. |
219 | ColocationGraph(const Graph* graph, const FunctionStack& stack, |
220 | const FunctionLibraryDefinition* flib_def, |
221 | const DeviceSet* device_set, |
222 | const Device* default_local_device, bool allow_soft_placement, |
223 | bool log_device_placement); |
224 | |
225 | Status Initialize(); |
226 | |
227 | const std::vector<Member>& members() const { return members_; } |
228 | |
229 | // Limit the group containing `node` to the device specifications in |
230 | // `devices`. |
231 | Status LimitToPossibleDevices(const Node& node, |
232 | const PossibleDevices& devices); |
233 | |
234 | // Limits the possible devices of `node`'s colocation group to the device |
235 | // to which `node` is assigned. This makes sure that all nodes in this |
236 | // colocation group will be assigned to the same device. Without this |
237 | // explicit restriction, heuristics can choose a different possible device |
238 | // for other nodes in the group. |
239 | Status LimitToAssignedDevice(const Node& node); |
240 | |
241 | // Returns the root node of the disjoint tree to which the node with the |
242 | // given id is connected. |
243 | // Updates the internal pointers so that future calls will returns faster. |
244 | int FindAndUpdateRoot(int node_id) { |
245 | return Member::FindAndUpdateRoot(&members_, node_id); |
246 | } |
247 | |
248 | // For the given node, subject to the constraints previously given |
249 | // to this ColocationGraph, set its assigned_device_name. Returns OK |
250 | // if a satisfying device can be found, otherwise an error. |
251 | // |
252 | // Note: This method returns a pointer to a field within members_. |
253 | // The caller must not use the returned pointer after there is any possibility |
254 | // that the members_[i].possible_devices field has been modified. |
255 | Status GetDevicesForNode(Node* node, |
256 | const std::vector<Device*>** possible_devices); |
257 | |
258 | // Returns debugging info for the node referred to by 'node_root'. |
259 | string DebugInfo(const int node_root) const; |
260 | |
261 | string DebugString() const; |
262 | |
263 | // Returns a list of devices having type in supported_device_types. The |
264 | // returned list is sorted by preferred type (higher numeric type is |
265 | // preferred). |
266 | static std::vector<Device*> FilterSupportedDevices( |
267 | const std::vector<Device*>& devices, |
268 | const PrioritizedDeviceTypeVector& supported_device_types, |
269 | const Device* default_local_device); |
270 | |
271 | private: |
272 | // Adds each node of the Graph to this ColocationGraph as a singleton. |
273 | // |
274 | // NOTE: The implementation assumes that the ids of nodes passed to |
275 | // this method are dense and zero-based; the memory used will be linear in |
276 | // the largest node ID. |
277 | // NOTE: If this method returns an error, *this is left in an undefined |
278 | // state. |
279 | Status ColocateAllNodes(); |
280 | |
281 | Status ColocateResourceOrRefEdge(const Node* src, const Node* dst); |
282 | |
283 | // Adds colocation constraints to data types known not to support copying. |
284 | Status ColocateUncopiableTypeEdges( |
285 | std::unordered_set<Node*>* inspection_required); |
286 | |
287 | // Updates this ColocationGraph by making sure that all nodes |
288 | // touching resource and/or ref tensors are colocated. |
289 | // As it iterates over the edges, fills the `inspection_required` set with |
290 | // the nodes that |
291 | // PlacerInspectionRequiredOpChecker::IsPlacerInspectionRequired |
292 | // deems as requiring deep inspection by placer. This is an optimization. |
293 | // TODO(mdan): Deprecate in favor of ColocateUncopiableTypeEdges. |
294 | Status ColocateResourceAndRefEdges( |
295 | std::unordered_set<Node*>* inspection_required); |
296 | |
297 | // Updates this ColocationGraph by making sure that all nodes having inputs of |
298 | // a DT_VARIANT data type with a host-only underlying types (e.g. strings) can |
299 | // be placed only on CPU device. We do that by reverse-DFS traversal from all |
300 | // nodes that take variant inputs to the node that produces that variant. |
301 | // TODO(ezhulenev): This function does not yet support "deep op" inspection, |
302 | // that we have for DT_RESOURCE edges. |
303 | Status AddHostOnlyDataTypesConstraints(); |
304 | |
305 | Status AddInspectionConstraints( |
306 | const std::unordered_set<Node*>& inspection_required); |
307 | |
308 | // Applies colocation groups for `node`'s inputs and outputs to this |
309 | // ColocationGraph. |
310 | // `groups` are the colocation groups to which `nodes`'s inputs and outputs |
311 | // belong. |
312 | // `node` is a node requiring deep inspection (e.g. a node calling |
313 | // a function) |
314 | // |
315 | // For example, consider a `node` taking two inputs and producing one output |
316 | // a b |
317 | // | | |
318 | // v v |
319 | // node |
320 | // | |
321 | // v |
322 | // c |
323 | // |
324 | // `groups` can tell us that `a` and `c` must be colocated and their device |
325 | // must be a GPU. `b` might be in a group by itself without any device |
326 | // restrictions. |
327 | // |
328 | // ApplyIOColocationGroups will have an effect of calling |
329 | // ColocateNodes(a, c) and LimitToPossibleDevices(`a`, "GPU"). The colocation |
330 | // group of the `node` itself is not directly impacted. |
331 | // |
332 | Status ApplyIOColocationGroups(const IOColocationGroups& groups, |
333 | const Node& node); |
334 | |
335 | Status ColocateNodeToGroup( |
336 | std::unordered_map<StringPiece, const Node*, StringPieceHasher>* |
337 | colocation_group_root, |
338 | const Node* node, StringPiece colocation_group); |
339 | |
340 | // Merge the (possibly disjoint) sets containing nodes "x" and |
341 | // "y". Returns OK if the all nodes in the union of these sets can |
342 | // be placed on the same device type. |
343 | // |
344 | // If this method returns an error, *this is unchanged. |
345 | Status ColocateNodes(const Node& x, const Node& y); |
346 | |
347 | // This overload of ColocateNodes() allows a caller to provide the root node |
348 | // ids for the two nodes. For large graphs, this noticeably reduces the |
349 | // graph load time. |
350 | // If this method returns an error, *this is unchanged. |
351 | Status ColocateNodes(const Node& x, int x_root, const Node& y, int y_root); |
352 | |
353 | void GetSoftDeviceCandidates(const Node& node, const Member& root_member, |
354 | int root_id, |
355 | std::vector<Device*>* possible_devices); |
356 | |
357 | Status InitializeMembers(); |
358 | |
359 | Status InitializeMemberWithAssignedDevice(const string& assigned_device_name, |
360 | const string& node_type, |
361 | Member* member); |
362 | |
363 | Status InitializeMember(const Node& node, Member* member); |
364 | |
365 | // Returns the root node of the disjoint tree to which the node with the |
366 | // given id is connected. |
367 | // FindRoot should be called only for debugging or after the members have |
368 | // been updated with direct root pointers because it does not update |
369 | // root pointers and can traverse many links. It exists to have |
370 | // a const version of FindAndUpdateRoot |
371 | int FindRoot(int node_id) const { |
372 | return Member::FindRoot(members_, node_id); |
373 | } |
374 | |
375 | const Graph& graph_; |
376 | const FunctionStack stack_; |
377 | std::vector<Member> members_; |
378 | InspectingPlacer inspecting_placer_; |
379 | PlacerInspectionRequiredOpChecker inspection_required_checker_; |
380 | const DeviceSet& device_set_; |
381 | const std::vector<DeviceType> device_types_; |
382 | const DeviceNameUtils::ParsedName local_address_spec_; |
383 | const Device* default_local_device_; |
384 | const bool allow_soft_placement_; |
385 | const bool log_device_placement_; |
386 | |
387 | TF_DISALLOW_COPY_AND_ASSIGN(ColocationGraph); |
388 | }; |
389 | |
390 | } // namespace tensorflow |
391 | |
392 | #endif // TENSORFLOW_CORE_COMMON_RUNTIME_COLOCATION_GRAPH_H_ |
393 | |