1/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2
3Licensed under the Apache License, Version 2.0 (the "License");
4you may not use this file except in compliance with the License.
5You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9Unless required by applicable law or agreed to in writing, software
10distributed under the License is distributed on an "AS IS" BASIS,
11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12See the License for the specific language governing permissions and
13limitations under the License.
14==============================================================================*/
15
16#ifndef TENSORFLOW_CORE_COMMON_RUNTIME_COLOCATION_GRAPH_H_
17#define TENSORFLOW_CORE_COMMON_RUNTIME_COLOCATION_GRAPH_H_
18
#include <unordered_map>
#include <unordered_set>
#include <utility>
#include <vector>

#include "absl/strings/str_join.h"
#include "tensorflow/core/common_runtime/device.h"
#include "tensorflow/core/common_runtime/inspecting_placer.h"
#include "tensorflow/core/common_runtime/placer_inspection_required_ops_utils.h"
#include "tensorflow/core/framework/function.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/lib/core/stringpiece.h"
#include "tensorflow/core/util/device_name_utils.h"
#include "tensorflow/core/util/port.h"
31
32namespace tensorflow {
33
34// Represents a node in the disjoint node forest and the
35// accumulated constraints on the device used by that node.
36class Member {
37 public:
38 Member() = default;
39
40 Status SetParentAndSupportedDevices(
41 const Node& node, const std::vector<DeviceType>& types,
42 const DeviceNameUtils::ParsedName* local_address_spec);
43
44 const DeviceNameUtils::ParsedName& requested_device_name() const {
45 return requested_device_name_;
46 }
47
48 Status SetAssignedDeviceName(const string& device_name);
49 Status SetResourceDeviceName(const Node& node);
50 Status SetRequestedDeviceName(const Node& node);
51
52 Status FillPossibleDevices(PossibleDevices* possible_device) const;
53
54 // Returns whether `src_root` is assigned to a CompositeDevice and `this` is
55 // assigned to a physical device.
56 bool IsEdgeFromCompositeDeviceToPhysicalDevice(const Member& src_root) const;
57
58 Status EnsureCompatibilityAcrossResourceEdge(
59 const Node& src, const Member& src_root,
60 const Node& dst, /*dst_root is this*/
61 bool log_device_placement);
62
63 const PrioritizedDeviceTypeVector& supported_device_types() const {
64 return supported_device_types_;
65 }
66
67 // If `dry_run` is true, just sets `new_root` and `old_root` and does not
68 // actually modify anything in the `tree`.
69 static void Merge(std::vector<Member>* tree, int x_root, int y_root,
70 Member** new_root, Member** old_root, bool dry_run);
71
72 // Returns the root node of the disjoint tree to which the node with the
73 // given id is connected.
74 // FindRoot should be called only for debugging or after the members have
75 // been updated with direct root pointers because it does not update
76 // root pointers and can traverse many links. It exists to have
77 // a const version of FindAndUpdateRoot
78 static int FindRoot(const std::vector<Member>& tree, int node_id);
79 static int FindAndUpdateRoot(std::vector<Member>* tree, int node_id);
80
81 Status MergeDeviceNames(const Member& other, bool allow_soft_placement);
82
83 // Updates this to contain the intersection of the device types in
84 // this and "other". If the intersection is empty, returns false and does
85 // not update this. Else returns true and updates this.
86 bool MergeSupportedDevices(const Member& other);
87
88 Status AssignDevice(const Node& node);
89
90 // If user does not explicitly request XLA device and non-XLA device is
91 // supported for this node, use only the non-XLA device. See b/140896502.
92 void MaybeExcludeXlaDevices();
93
94 // Limit the possible devices of this (should be a root) to the device
95 // specifications in `devices`.
96 Status LimitToPossibleDevices(const PossibleDevices& devices,
97 bool allow_soft_placement);
98
99 void set_possible_devices(std::vector<Device*>&& devices) {
100 possible_devices_ = devices;
101 }
102 const std::vector<Device*>& possible_devices() { return possible_devices_; }
103
104 // Returns a (parsed) device name that is based on requested_device_name()
105 // but with potentially cleared device type and ID fields. A field is cleared
106 // if the assigned_device_name does not specify it. If it does, the field
107 // is not cleared because soft placement cannot violate assigned device names.
108 DeviceNameUtils::ParsedName GetSoftDeviceName() const;
109
110 // Same as GetSoftDeviceName but device type and device ID fields are not
111 // cleared if resource device has them set.
112 DeviceNameUtils::ParsedName GetPreferredSoftDeviceName() const;
113
114 string DebugString() const;
115
116 bool has_assigned_device_name() const { return assigned_device_name_.has_id; }
117
118 private:
119 // Updates this to contain the intersection of the device types in
120 // this and `other_devices`.
121 bool MergeSupportedDevices(const PrioritizedDeviceTypeVector& other_devices);
122
123 // The id of the node that is the parent of this one, or its own
124 // id if it is a root. parent <= 0 indicates that this member is invalid.
125 int parent_ = -1;
126
127 // A proxy for the depth of the tree that is used to prefer
128 // connecting smaller trees to larger trees when merging disjoint
129 // sets.
130 int rank_ = 0;
131
132 // Once colocation groups have been formed, the Placer starts actually
133 // choosing devices. All nodes in a group must be assigned to the same
134 // device. Once we assigned the first device to some node in this group,
135 // we set assigned_device_name_index to this device name's index in the
136 // graph.
137 // The `*_device_name_` fields will contain the parsed name of this device
138 // and `possible_devices`, if computed, will contain just this device.
139 // `assigned_device_name_index` is an optimization to avoid parsing and
140 // comparing device names. The value of -1 signals that a single device
141 // has not been chosen yet.
142 int assigned_device_name_index_ = -1;
143
144 // The merged form of the device requested for this node, with those of all of
145 // its children. requested_device_name_ is always kept a specialization (i.e.
146 // DeviceNameUtils::IsSpecification) of assigned_device_name_. When no device
147 // is requested, this field is set to assigned_device_name_. As a
148 // specialization of assigned_device_name_, requested_device_name_ represents
149 // the most specific form of all assigned and requested devices of this node
150 // and its children, if this node is a root. requested_device_name_ is used
151 // to finally select devices for nodes. We can override requested devices due
152 // to resource colocation constraints but not assigned devices (unless soft
153 // placement is on).
154 // INVARIANT: requested_device_name_ is always kept a
155 // DeviceNameUtils::IsSpecification of assigned_device_name_ and
156 // resource_device_name_. This makes requested_device_name_ the "accumulation
157 // of all wishes" about the device.
158 DeviceNameUtils::ParsedName requested_device_name_;
159
160 // The merged form of the device assigned for this node, with
161 // those of all of its children.
162 // This field is used to raise errors due to unsatisfiable constraints.
163 // Can be a partial specification.
164 DeviceNameUtils::ParsedName assigned_device_name_;
165
166 // The merged form of the requested resource device assigned for this node,
167 // with those of all of its children.
168 // This field is used to raise errors due to unsatisfiable constraints.
169 // Can be a partial specification.
170 // resource_device_name_ is initialized with user-requested device on nodes
171 // producing resources, e.g. VarHandleOp.
172 // For historical reasons, with soft placement enabled, Placer can "move"
173 // resources (place resource producing ops on a device different from what
174 // the user explicitly requested) when the colocation group of a resource
175 // producing op contains ops that are not supported on the user-requested
176 // resource device. A classic example of this is a sparse optimizer (only
177 // supported on CPU) used on a GPU variable. In this case, the whole group
178 // will be assigned to some device supported by all ops in the colocation
179 // group. This is a surprising and unfortunate behavior because:
180 // 1. Since soft_placement is on by default, users don't know that their
181 // variables are created on a different device than what they requested.
182 // Among other things, this can lead to surprising poor performance.
183 // 2. Eager runtime cannot "move" resources. The same code can "work" when
184 // wrapped in tf.function but will fail when run eagerly.
185 // 3. Extra complexity here to preserve these resource moving capabilities.
186 DeviceNameUtils::ParsedName resource_device_name_;
187
188 // The intersection of all device types supported by this node,
189 // and those of all of its children, in priority order
190 // of the preferred device.
191 // It is possible that supported_device_types_ has an empty intersection with
192 // requested/assigned/resource devices. We could have detected such cases
193 // as soon as they happen and raise an error. Instead, for historical reasons,
194 // we leave such error detection to the final device picking stage.
195 PrioritizedDeviceTypeVector supported_device_types_;
196
197 // If this node is a root, stores a list of Devices to which this node
198 // and all of its children can be assigned.
199 // `possible_devices` is empty if they have not yet been computed.
200 std::vector<Device*> possible_devices_;
201};
202
203// This class maintains the connected components of a colocation
204// constraint graph, and uses this information to assign a satisfying
205// device placement to the nodes of the graph.
206//
207// This implementation uses the Union-Find algorithm to efficiently maintain the
208// connected components and incrementally adds edges via
209// ColocationGraph::ColocateNodes() invocations.
210//
211// ColocationGraph does not assign any devices to graph nodes. The
212// `log_device_placement` argument is used to log messages when requested
213// device is ignored.
214class ColocationGraph {
215 public:
216 // graph, flib_def, and device_set must not be null and must outlive
217 // this ColocationGraph. default_local_device can be null. If not, must
218 // outlive this.
219 ColocationGraph(const Graph* graph, const FunctionStack& stack,
220 const FunctionLibraryDefinition* flib_def,
221 const DeviceSet* device_set,
222 const Device* default_local_device, bool allow_soft_placement,
223 bool log_device_placement);
224
225 Status Initialize();
226
227 const std::vector<Member>& members() const { return members_; }
228
229 // Limit the group containing `node` to the device specifications in
230 // `devices`.
231 Status LimitToPossibleDevices(const Node& node,
232 const PossibleDevices& devices);
233
234 // Limits the possible devices of `node`'s colocation group to the device
235 // to which `node` is assigned. This makes sure that all nodes in this
236 // colocation group will be assigned to the same device. Without this
237 // explicit restriction, heuristics can choose a different possible device
238 // for other nodes in the group.
239 Status LimitToAssignedDevice(const Node& node);
240
241 // Returns the root node of the disjoint tree to which the node with the
242 // given id is connected.
243 // Updates the internal pointers so that future calls will returns faster.
244 int FindAndUpdateRoot(int node_id) {
245 return Member::FindAndUpdateRoot(&members_, node_id);
246 }
247
248 // For the given node, subject to the constraints previously given
249 // to this ColocationGraph, set its assigned_device_name. Returns OK
250 // if a satisfying device can be found, otherwise an error.
251 //
252 // Note: This method returns a pointer to a field within members_.
253 // The caller must not use the returned pointer after there is any possibility
254 // that the members_[i].possible_devices field has been modified.
255 Status GetDevicesForNode(Node* node,
256 const std::vector<Device*>** possible_devices);
257
258 // Returns debugging info for the node referred to by 'node_root'.
259 string DebugInfo(const int node_root) const;
260
261 string DebugString() const;
262
263 // Returns a list of devices having type in supported_device_types. The
264 // returned list is sorted by preferred type (higher numeric type is
265 // preferred).
266 static std::vector<Device*> FilterSupportedDevices(
267 const std::vector<Device*>& devices,
268 const PrioritizedDeviceTypeVector& supported_device_types,
269 const Device* default_local_device);
270
271 private:
272 // Adds each node of the Graph to this ColocationGraph as a singleton.
273 //
274 // NOTE: The implementation assumes that the ids of nodes passed to
275 // this method are dense and zero-based; the memory used will be linear in
276 // the largest node ID.
277 // NOTE: If this method returns an error, *this is left in an undefined
278 // state.
279 Status ColocateAllNodes();
280
281 Status ColocateResourceOrRefEdge(const Node* src, const Node* dst);
282
283 // Adds colocation constraints to data types known not to support copying.
284 Status ColocateUncopiableTypeEdges(
285 std::unordered_set<Node*>* inspection_required);
286
287 // Updates this ColocationGraph by making sure that all nodes
288 // touching resource and/or ref tensors are colocated.
289 // As it iterates over the edges, fills the `inspection_required` set with
290 // the nodes that
291 // PlacerInspectionRequiredOpChecker::IsPlacerInspectionRequired
292 // deems as requiring deep inspection by placer. This is an optimization.
293 // TODO(mdan): Deprecate in favor of ColocateUncopiableTypeEdges.
294 Status ColocateResourceAndRefEdges(
295 std::unordered_set<Node*>* inspection_required);
296
297 // Updates this ColocationGraph by making sure that all nodes having inputs of
298 // a DT_VARIANT data type with a host-only underlying types (e.g. strings) can
299 // be placed only on CPU device. We do that by reverse-DFS traversal from all
300 // nodes that take variant inputs to the node that produces that variant.
301 // TODO(ezhulenev): This function does not yet support "deep op" inspection,
302 // that we have for DT_RESOURCE edges.
303 Status AddHostOnlyDataTypesConstraints();
304
305 Status AddInspectionConstraints(
306 const std::unordered_set<Node*>& inspection_required);
307
308 // Applies colocation groups for `node`'s inputs and outputs to this
309 // ColocationGraph.
310 // `groups` are the colocation groups to which `nodes`'s inputs and outputs
311 // belong.
312 // `node` is a node requiring deep inspection (e.g. a node calling
313 // a function)
314 //
315 // For example, consider a `node` taking two inputs and producing one output
316 // a b
317 // | |
318 // v v
319 // node
320 // |
321 // v
322 // c
323 //
324 // `groups` can tell us that `a` and `c` must be colocated and their device
325 // must be a GPU. `b` might be in a group by itself without any device
326 // restrictions.
327 //
328 // ApplyIOColocationGroups will have an effect of calling
329 // ColocateNodes(a, c) and LimitToPossibleDevices(`a`, "GPU"). The colocation
330 // group of the `node` itself is not directly impacted.
331 //
332 Status ApplyIOColocationGroups(const IOColocationGroups& groups,
333 const Node& node);
334
335 Status ColocateNodeToGroup(
336 std::unordered_map<StringPiece, const Node*, StringPieceHasher>*
337 colocation_group_root,
338 const Node* node, StringPiece colocation_group);
339
340 // Merge the (possibly disjoint) sets containing nodes "x" and
341 // "y". Returns OK if the all nodes in the union of these sets can
342 // be placed on the same device type.
343 //
344 // If this method returns an error, *this is unchanged.
345 Status ColocateNodes(const Node& x, const Node& y);
346
347 // This overload of ColocateNodes() allows a caller to provide the root node
348 // ids for the two nodes. For large graphs, this noticeably reduces the
349 // graph load time.
350 // If this method returns an error, *this is unchanged.
351 Status ColocateNodes(const Node& x, int x_root, const Node& y, int y_root);
352
353 void GetSoftDeviceCandidates(const Node& node, const Member& root_member,
354 int root_id,
355 std::vector<Device*>* possible_devices);
356
357 Status InitializeMembers();
358
359 Status InitializeMemberWithAssignedDevice(const string& assigned_device_name,
360 const string& node_type,
361 Member* member);
362
363 Status InitializeMember(const Node& node, Member* member);
364
365 // Returns the root node of the disjoint tree to which the node with the
366 // given id is connected.
367 // FindRoot should be called only for debugging or after the members have
368 // been updated with direct root pointers because it does not update
369 // root pointers and can traverse many links. It exists to have
370 // a const version of FindAndUpdateRoot
371 int FindRoot(int node_id) const {
372 return Member::FindRoot(members_, node_id);
373 }
374
375 const Graph& graph_;
376 const FunctionStack stack_;
377 std::vector<Member> members_;
378 InspectingPlacer inspecting_placer_;
379 PlacerInspectionRequiredOpChecker inspection_required_checker_;
380 const DeviceSet& device_set_;
381 const std::vector<DeviceType> device_types_;
382 const DeviceNameUtils::ParsedName local_address_spec_;
383 const Device* default_local_device_;
384 const bool allow_soft_placement_;
385 const bool log_device_placement_;
386
387 TF_DISALLOW_COPY_AND_ASSIGN(ColocationGraph);
388};
389
390} // namespace tensorflow
391
392#endif // TENSORFLOW_CORE_COMMON_RUNTIME_COLOCATION_GRAPH_H_
393