1 | /* Copyright 2019 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | |
16 | #include "tensorflow/core/common_runtime/colocation_graph.h" |
17 | |
#include <algorithm>
#include <memory>
#include <set>
#include <unordered_map>
#include <unordered_set>
#include <utility>
#include <vector>
24 | |
25 | #include "absl/algorithm/container.h" |
26 | #include "absl/container/flat_hash_set.h" |
27 | #include "absl/strings/str_join.h" |
28 | #include "absl/types/optional.h" |
29 | #include "tensorflow/core/common_runtime/composite_device.h" |
30 | #include "tensorflow/core/common_runtime/device.h" |
31 | #include "tensorflow/core/common_runtime/device_set.h" |
32 | #include "tensorflow/core/common_runtime/input_colocation_exemption_registry.h" |
33 | #include "tensorflow/core/common_runtime/inspecting_placer.h" |
34 | #include "tensorflow/core/common_runtime/partitioning_utils.h" |
35 | #include "tensorflow/core/framework/attr_value.pb.h" |
36 | #include "tensorflow/core/framework/attr_value_util.h" |
37 | #include "tensorflow/core/framework/dataset.h" |
38 | #include "tensorflow/core/framework/device_attributes.pb.h" |
39 | #include "tensorflow/core/framework/full_type.pb.h" |
40 | #include "tensorflow/core/framework/full_type_util.h" |
41 | #include "tensorflow/core/framework/function.h" |
42 | #include "tensorflow/core/framework/node_def_util.h" |
43 | #include "tensorflow/core/framework/op_kernel.h" |
44 | #include "tensorflow/core/framework/types.h" |
45 | #include "tensorflow/core/framework/types.pb.h" |
46 | #include "tensorflow/core/graph/algorithm.h" |
47 | #include "tensorflow/core/graph/graph_node_util.h" |
48 | #include "tensorflow/core/lib/core/errors.h" |
49 | #include "tensorflow/core/lib/core/stringpiece.h" |
50 | #include "tensorflow/core/lib/strings/str_util.h" |
51 | #include "tensorflow/core/lib/strings/strcat.h" |
52 | #include "tensorflow/core/util/device_name_utils.h" |
53 | #include "tensorflow/core/util/dump_graph.h" |
54 | #include "tensorflow/core/util/port.h" |
55 | |
56 | namespace tensorflow { |
57 | |
58 | namespace { |
59 | |
// We hoist the conversion from C-style string literal to StringPiece here,
// so that we can avoid the many repeated calls to strlen(). These wrap the
// colocation attribute name ("_class") and the "loc:@" group prefix declared
// in node_def_util.h.
const StringPiece kColocationAttrNameStringPiece(kColocationAttrName);
const StringPiece kColocationGroupPrefixStringPiece(kColocationGroupPrefix);
64 | |
65 | // Using absl::StrJoin with lambda does not work in tf-lite builds. |
66 | std::vector<string> DevicesToString(const std::vector<Device*> devices) { |
67 | std::vector<string> v; |
68 | v.reserve(devices.size()); |
69 | for (Device* d : devices) { |
70 | v.push_back(d->name()); |
71 | } |
72 | return v; |
73 | } |
74 | |
75 | // Using absl::StrJoin with lambda does not work in tf-lite builds. |
76 | std::vector<string> DeviceTypeAndPriorityToString( |
77 | const PrioritizedDeviceTypeVector& devices) { |
78 | std::vector<string> v; |
79 | v.reserve(devices.size()); |
80 | for (const std::pair<DeviceType, int32>& device_and_type : devices) { |
81 | v.push_back(DeviceTypeString(device_and_type.first)); |
82 | } |
83 | return v; |
84 | } |
85 | |
86 | bool IsRefOrResource(DataType data_type) { |
87 | return IsRefType(data_type) || data_type == DT_RESOURCE; |
88 | } |
89 | |
90 | // While Placer can override requested device on ops processing |
91 | // resources, i.e. node that take (and potentially return) a resource, |
92 | // it must not override requested device on ops generating a resource, |
93 | // e.g. VarHandleOp, _Arg. Such ops are currently no-input, single resource/ref |
94 | // output nodes. |
95 | bool IsRefOrResourceGeneratorNode(const Node& node) { |
96 | return node.num_inputs() == 0 && node.num_outputs() == 1 && |
97 | IsRefOrResource(node.output_type(0)); |
98 | } |
99 | |
100 | bool IsExemptFromResourceInputColocation(const Node* node) { |
101 | // Note: Partitioned function calls, which place and partition their |
102 | // function bodies, are exempt from this check: they forward resource and |
103 | // ref inputs to operations that are appropriately placed, instead of |
104 | // dereferencing them. |
105 | const string& op_type = node->op_def().name(); |
106 | auto exempt_ops = InputColocationExemptionRegistry::Global()->Get(); |
107 | return exempt_ops.find(op_type) != exempt_ops.end(); |
108 | } |
109 | |
110 | bool HasPriorities(const PrioritizedDeviceTypeVector& device_types) { |
111 | for (const auto& prioritized_device_type : device_types) { |
112 | if (prioritized_device_type.second != 0) return true; |
113 | } |
114 | return false; |
115 | } |
116 | |
117 | bool ArePrioritiesSame(const PrioritizedDeviceTypeVector& a_types, |
118 | const PrioritizedDeviceTypeVector& b_types) { |
119 | if (a_types.size() != b_types.size()) { |
120 | return false; |
121 | } |
122 | for (int i = 0; i < a_types.size(); ++i) { |
123 | if (a_types[i].first != b_types[i].first) { |
124 | return false; |
125 | } |
126 | } |
127 | return true; |
128 | } |
129 | |
130 | bool IsXlaDevice(absl::string_view device_type) { |
131 | if (device_type == "XLA_CPU_JIT" || device_type == "XLA_GPU_JIT" || |
132 | device_type == "XLA_TPU_JIT" ) { |
133 | // Symbolic XLA device. |
134 | return true; |
135 | } |
136 | |
137 | return (device_type == "XLA_CPU" || device_type == "XLA_GPU" || |
138 | device_type == "TPU" ); |
139 | } |
140 | |
// Returns true iff `device_type` is the composite device type constant
// declared in composite_device.h.
bool IsCompositeDevice(absl::string_view device_type) {
  return device_type == kCompositeDeviceType;
}
144 | |
145 | // TODO(mdan): This is still too coarse. |
146 | // Host-memory constraints are specific to kernel registrations, so in theory |
147 | // they depend on the assigned device. |
148 | // So we need a constraint model of the kind: <<node device>>: <<output_device>> |
149 | bool HasHostMemoryOutType(const Node& node) { |
150 | if (!node.def().has_experimental_type()) { |
151 | return false; |
152 | } |
153 | const FullTypeDef& ft = node.def().experimental_type(); |
154 | DCHECK(ft.type_id() == TFT_PRODUCT) << ft.DebugString(); |
155 | |
156 | for (const auto& arg : ft.args()) { |
157 | if (full_type::IsHostMemoryType(arg)) { |
158 | return true; |
159 | } |
160 | } |
161 | |
162 | return false; |
163 | } |
164 | } // namespace |
165 | |
166 | Status Member::SetParentAndSupportedDevices( |
167 | const Node& node, const std::vector<DeviceType>& types, |
168 | const DeviceNameUtils::ParsedName* local_address_spec) { |
169 | int id = node.id(); |
170 | if (id < 0) { |
171 | return errors::Internal("Placer should not be creating a Member for node: " , |
172 | node.DebugString()); |
173 | } |
174 | parent_ = id; |
175 | return SupportedDeviceTypesForNode( |
176 | types, node.def(), &supported_device_types_, local_address_spec); |
177 | } |
178 | |
179 | Status Member::SetAssignedDeviceName(const string& device_name) { |
180 | if (DeviceNameUtils::HasSomeDetails(requested_device_name_)) { |
181 | return errors::Internal( |
182 | "Setting assigned device name when there is a requested device set " |
183 | "is unsupported" ); |
184 | } |
185 | if (!DeviceNameUtils::ParseFullName(device_name, &assigned_device_name_)) { |
186 | return errors::Internal("Malformed assigned device '" , device_name, "'" ); |
187 | } |
188 | // Set requested device to assigned_device to maintain the invariant that |
189 | // requested is a specialization of assigned. |
190 | requested_device_name_ = assigned_device_name_; |
191 | return OkStatus(); |
192 | } |
193 | |
194 | Status Member::SetResourceDeviceName(const Node& node) { |
195 | if (DeviceNameUtils::HasSomeDetails(requested_device_name_)) { |
196 | return errors::Internal( |
197 | "Setting resource device name when there is a requested device set " |
198 | "is unsupported" ); |
199 | } |
200 | |
201 | if (!DeviceNameUtils::ParseFullName(node.requested_device(), |
202 | &resource_device_name_)) { |
203 | return errors::InvalidArgument("Malformed device specification '" , |
204 | node.requested_device(), |
205 | "' in node: " , node.DebugString()); |
206 | } |
207 | |
208 | // Set requested device to resource device to maintain the invariant that |
209 | // requested is a specialization of resource. |
210 | requested_device_name_ = resource_device_name_; |
211 | return OkStatus(); |
212 | } |
213 | |
214 | Status Member::SetRequestedDeviceName(const Node& node) { |
215 | if (DeviceNameUtils::HasSomeDetails(assigned_device_name_)) { |
216 | return errors::Internal( |
217 | "Setting requested device name when there is an assigned device set " |
218 | "is unsupported" ); |
219 | } |
220 | if (DeviceNameUtils::HasSomeDetails(resource_device_name_)) { |
221 | return errors::Internal( |
222 | "Setting requested device name when there is a resource device set " |
223 | "is unsupported" ); |
224 | } |
225 | if (!DeviceNameUtils::ParseFullName(node.requested_device(), |
226 | &requested_device_name_)) { |
227 | return errors::InvalidArgument("Malformed device specification '" , |
228 | node.requested_device(), |
229 | "' in node: " , node.DebugString()); |
230 | } |
231 | return OkStatus(); |
232 | } |
233 | |
// Copies this member's device constraints (requested/resource names and
// supported types) into `possible_device`. Fails with an Internal error if an
// assigned device already exists, since callers rely on device assignment not
// having started yet.
Status Member::FillPossibleDevices(PossibleDevices* possible_device) const {
  if (DeviceNameUtils::HasSomeDetails(assigned_device_name_)) {
    return errors::Internal(
        "Cannot fill PossibleDevices from a member that has non-empty assigned "
        "device. Did we start assigning devices to functions called by deep "
        "ops? ",
        DebugString());
  }
  possible_device->requested_device_name = requested_device_name_;
  possible_device->resource_device_name = resource_device_name_;
  possible_device->device_types = supported_device_types_;
  return OkStatus();
}
247 | |
248 | bool Member::IsEdgeFromCompositeDeviceToPhysicalDevice( |
249 | const Member& src_root) const { |
250 | auto compatible_edge_from_composite_device_to_physical_device = |
251 | [](const DeviceNameUtils::ParsedName& src_device, |
252 | const DeviceNameUtils::ParsedName& dst_device) -> bool { |
253 | return src_device.has_type && dst_device.has_type && |
254 | IsCompositeDevice(src_device.type) && |
255 | !IsCompositeDevice(dst_device.type); |
256 | }; |
257 | if (compatible_edge_from_composite_device_to_physical_device( |
258 | src_root.assigned_device_name_, assigned_device_name_) || |
259 | compatible_edge_from_composite_device_to_physical_device( |
260 | src_root.resource_device_name_, resource_device_name_) || |
261 | compatible_edge_from_composite_device_to_physical_device( |
262 | src_root.requested_device_name_, requested_device_name_)) { |
263 | return true; |
264 | } |
265 | return false; |
266 | } |
267 | |
// Checks that this (dst) member's devices are compatible with `src_root`'s
// across a reference/resource edge. Incompatible assigned or resource devices
// are hard errors; if only the requested devices disagree, this member's
// requested device is overridden by the source's (optionally logged).
Status Member::EnsureCompatibilityAcrossResourceEdge(
    const Node& src, const Member& src_root,
    const Node& dst, /*dst_root is this*/
    bool log_device_placement) {
  // Assigned devices must be compatible outright; there is no soft-placement
  // escape hatch once devices have been assigned.
  if (!DeviceNameUtils::AreCompatibleDevNames(src_root.assigned_device_name_,
                                              assigned_device_name_)) {
    return errors::InvalidArgument(
        "Cannot place the graph because a reference or resource edge "
        "connects colocation groups with incompatible assigned devices: ",
        DeviceNameUtils::ParsedNameToString(src_root.assigned_device_name_),
        " vs ", DeviceNameUtils::ParsedNameToString(assigned_device_name_),
        ". The edge src node is name='", src.name(), "' (op='", src.def().op(),
        "'), and the dst node is name='", dst.name(), "' (op='", dst.def().op(),
        "').");
  }

  // Likewise for resource devices.
  if (!DeviceNameUtils::AreCompatibleDevNames(src_root.resource_device_name_,
                                              resource_device_name_)) {
    return errors::InvalidArgument(
        "Cannot place the graph because a reference or resource edge "
        "connects colocation groups with incompatible resource devices: ",
        DeviceNameUtils::ParsedNameToString(src_root.resource_device_name_),
        " vs ", DeviceNameUtils::ParsedNameToString(resource_device_name_),
        ". The edge src node is name='", src.name(), "' (op='", src.def().op(),
        "'), and the dst node is name='", dst.name(), "' (op='", dst.def().op(),
        "').");
  }

  if (DeviceNameUtils::AreCompatibleDevNames(src_root.requested_device_name_,
                                             requested_device_name_)) {
    return OkStatus();
  }

  // If we are here, assigned and resource devices are compatible but requested
  // ones are not. We will be overriding the requested device for destination
  // node, but need to preserve the invariant that it will be a specialization
  // of the assigned and resource devices.
  if (log_device_placement) {
    LOG(INFO) << "Ignoring device specification "
              << DeviceNameUtils::ParsedNameToString(requested_device_name_)
              << " for node '" << dst.name()
              << "' because the input edge from '" << src.name()
              << "' is a reference connection and already has a device "
                 "field set to "
              << DeviceNameUtils::ParsedNameToString(
                     src_root.requested_device_name_);
  }
  requested_device_name_ = src_root.requested_device_name_;
  DeviceNameUtils::EnsureSpecification(&requested_device_name_,
                                       assigned_device_name_);
  DeviceNameUtils::EnsureSpecification(&requested_device_name_,
                                       resource_device_name_);
  return OkStatus();
}
322 | |
// Unions the disjoint-set trees rooted at `x_root` and `y_root` using
// union-by-rank. On return, `*new_root` points at the member that is the root
// of the merged tree and `*old_root` at the absorbed root. When `dry_run` is
// true, the would-be roots are reported but no parent_/rank_ field is
// mutated.
void Member::Merge(std::vector<Member>* tree, int x_root, int y_root,
                   Member** new_root, Member** old_root, bool dry_run) {
  Member& x_root_member = (*tree)[x_root];
  Member& y_root_member = (*tree)[y_root];

  // Merge the sets by setting the parent pointer of the smaller tree's root
  // node to point to the root of the larger tree. Together with path
  // compression in ColocationGraph::FindRoot, this ensures that we do not
  // experience pathological performance on graphs such as chains.
  int new_root_id, old_root_id;
  if (x_root_member.rank_ < y_root_member.rank_) {
    // The tree rooted at x_root is shallower, so connect it to
    // y_root. The rank of y_root is unchanged because its new
    // child has strictly less rank.
    if (!dry_run) {
      x_root_member.parent_ = y_root;
    }
    new_root_id = y_root;
    old_root_id = x_root;
  } else if (x_root_member.rank_ > y_root_member.rank_) {
    // The tree rooted at y_root is shallower, so connect it to
    // x_root. The rank of x_root is unchanged because its new
    // child has strictly less rank.
    if (!dry_run) {
      y_root_member.parent_ = x_root;
    }
    new_root_id = x_root;
    old_root_id = y_root;
  } else {
    if (!dry_run) {
      // Both trees have the same rank, so break the tie by choosing
      // x_root as the new root.
      y_root_member.parent_ = x_root;
      // Increment the rank of the tree rooted at x_root, because it
      // is now strictly deeper than before.
      ++x_root_member.rank_;
    }
    new_root_id = x_root;
    old_root_id = y_root;
  }

  *new_root = &(*tree)[new_root_id];
  *old_root = &(*tree)[old_root_id];
}
367 | |
368 | // tree is non-const because we can change some `parent` pointers in some |
369 | // members for more efficient future lookups. The vector itself is not |
370 | // changed. |
371 | int Member::FindAndUpdateRoot(std::vector<Member>* tree, int node_id) { |
372 | Member& member = (*tree)[node_id]; |
373 | if (member.parent_ == node_id) { |
374 | // member.parent is the root of this disjoint tree. Do nothing. |
375 | } else { |
376 | member.parent_ = FindAndUpdateRoot(tree, member.parent_); |
377 | } |
378 | // Now it is guaranteed that member.parent is the root of this disjoint |
379 | // tree. |
380 | return member.parent_; |
381 | } |
382 | |
383 | int Member::FindRoot(const std::vector<Member>& tree, int node_id) { |
384 | const Member& member = tree[node_id]; |
385 | if (member.parent_ == node_id) { |
386 | return member.parent_; |
387 | } |
388 | return FindRoot(tree, member.parent_); |
389 | } |
390 | |
// Merges `other`'s assigned/resource/requested device names into this
// member's, transactionally: every merge is first performed on a local copy,
// so no field is modified unless all three merges succeed.
Status Member::MergeDeviceNames(const Member& other,
                                bool allow_soft_placement) {
  // Assuming the "requested is a specialization of assigned and resource
  // devices" invariant holds for this and `other`, it will hold after the
  // merges below.
  DeviceNameUtils::ParsedName assigned_device_name_copy = assigned_device_name_;
  TF_RETURN_IF_ERROR(DeviceNameUtils::MergeDevNames(
      &assigned_device_name_copy, other.assigned_device_name_));

  DeviceNameUtils::ParsedName resource_device_name_copy = resource_device_name_;
  TF_RETURN_IF_ERROR(DeviceNameUtils::MergeDevNames(
      &resource_device_name_copy, other.resource_device_name_));

  DeviceNameUtils::ParsedName requested_device_name_copy =
      requested_device_name_;
  // Only the requested-name merge honors soft placement.
  TF_RETURN_IF_ERROR(DeviceNameUtils::MergeDevNames(
      &requested_device_name_copy, other.requested_device_name_,
      allow_soft_placement));

  // Re-establish the specialization invariant on the merged requested name.
  DeviceNameUtils::EnsureSpecification(&requested_device_name_copy,
                                       assigned_device_name_copy);
  DeviceNameUtils::EnsureSpecification(&requested_device_name_copy,
                                       resource_device_name_copy);

  // We checked for all errors, now change the devices.
  assigned_device_name_ = std::move(assigned_device_name_copy);
  resource_device_name_ = std::move(resource_device_name_copy);
  requested_device_name_ = std::move(requested_device_name_copy);
  return OkStatus();
}
421 | |
// Updates this to contain the intersection of the device types in
// this and "other". Returns false when the intersection is empty, in which
// case supported_device_types_ is left unchanged (see the vector overload
// below).
bool Member::MergeSupportedDevices(const Member& other) {
  return MergeSupportedDevices(other.supported_device_types_);
}
427 | |
428 | bool Member::MergeSupportedDevices( |
429 | const PrioritizedDeviceTypeVector& other_devices) { |
430 | // Generate intersection with priorities. |
431 | // Each vector contains the same device types but with different priorities. |
432 | // The priorities are taken from the corresponding source vector. |
433 | PrioritizedDeviceTypeVector target_intersection; |
434 | PrioritizedDeviceTypeVector other_intersection; |
435 | |
436 | for (const auto& prioritized_device_type : supported_device_types_) { |
437 | bool found = false; |
438 | for (const auto& other_prioritized_device_type : other_devices) { |
439 | if (prioritized_device_type.first == |
440 | other_prioritized_device_type.first) { |
441 | found = true; |
442 | other_intersection.push_back(other_prioritized_device_type); |
443 | break; |
444 | } |
445 | } |
446 | if (found) { |
447 | target_intersection.push_back(prioritized_device_type); |
448 | } |
449 | } |
450 | |
451 | DeviceSet::SortPrioritizedDeviceTypeVector(&target_intersection); |
452 | DeviceSet::SortPrioritizedDeviceTypeVector(&other_intersection); |
453 | |
454 | PrioritizedDeviceTypeVector result; |
455 | |
456 | bool is_target_prioritized = HasPriorities(target_intersection); |
457 | bool is_other_prioritized = HasPriorities(other_intersection); |
458 | if (!is_target_prioritized && !is_other_prioritized) { |
459 | // If neither are prioritized then we just return the original i.e. target |
460 | // prioritization. |
461 | result = target_intersection; |
462 | } else if (is_target_prioritized && !is_other_prioritized) { |
463 | // If only one is prioritized, then we respect priorities of that in the |
464 | // intersection. |
465 | result = target_intersection; |
466 | } else if (!is_target_prioritized && is_other_prioritized) { |
467 | result = other_intersection; |
468 | } else { |
469 | // If both have priorities and agree then we go with that. If the |
470 | // prioritization order is different, then we just fallback to the default |
471 | // i.e. what the DeviceTypeOrder suggests. In that case, we also set the |
472 | // merged priorities to 0, so that downstream merges work correctly as well. |
473 | if (ArePrioritiesSame(target_intersection, other_intersection)) { |
474 | result = target_intersection; |
475 | } else { |
476 | for (const auto& prioritized_device : target_intersection) { |
477 | result.push_back(std::make_pair(prioritized_device.first, 0)); |
478 | } |
479 | DeviceSet::SortPrioritizedDeviceTypeVector(&result); |
480 | } |
481 | } |
482 | |
483 | if (result.empty()) { |
484 | return false; |
485 | } |
486 | supported_device_types_ = result; |
487 | return true; |
488 | } |
489 | |
490 | Status Member::AssignDevice(const Node& node) { |
491 | if (node.assigned_device_name_index() == assigned_device_name_index_) { |
492 | return OkStatus(); |
493 | } |
494 | |
495 | DeviceNameUtils::ParsedName parsed; |
496 | DeviceNameUtils::ParseFullName(node.assigned_device_name(), &parsed); |
497 | Status s = DeviceNameUtils::MergeDevNames(&assigned_device_name_, parsed); |
498 | if (!s.ok()) { |
499 | return errors::Internal( |
500 | "Constraining by assigned device should not cause an error. Original " |
501 | "root's assigned device name: " , |
502 | DeviceNameUtils::ParsedNameToString(assigned_device_name_), |
503 | " node's assigned device name \"" , node.assigned_device_name(), |
504 | ". Error: " , s.error_message()); |
505 | } |
506 | s = DeviceNameUtils::MergeOverrideDevNames(&resource_device_name_, parsed); |
507 | if (!s.ok()) { |
508 | return errors::Internal( |
509 | "Constraining by assigned device should not cause an error. Original " |
510 | "root's resource device name: " , |
511 | DeviceNameUtils::ParsedNameToString(resource_device_name_), |
512 | " node's assigned device name \"" , node.assigned_device_name(), |
513 | ". Error: " , s.error_message()); |
514 | } |
515 | s = DeviceNameUtils::MergeOverrideDevNames(&requested_device_name_, parsed); |
516 | if (!s.ok()) { |
517 | return errors::Internal( |
518 | "Constraining by assigned device should not cause an error. Original " |
519 | "root's requested device name: \"" , |
520 | DeviceNameUtils::ParsedNameToString(requested_device_name_), |
521 | "\", node's assigned device name \"" , node.assigned_device_name(), |
522 | "\". Error: " , s.error_message()); |
523 | } |
524 | |
525 | assigned_device_name_index_ = node.assigned_device_name_index(); |
526 | // Clear cached possible_devices, if any. |
527 | possible_devices_.clear(); |
528 | return OkStatus(); |
529 | } |
530 | |
531 | void Member::MaybeExcludeXlaDevices() { |
532 | for (const auto& parsed_name : |
533 | {requested_device_name_, assigned_device_name_, resource_device_name_}) { |
534 | // Don't exculde XLA devices from supported devices if member is explicitly |
535 | // assigned to a CompositeDevice. |
536 | if (parsed_name.has_type && (IsXlaDevice(parsed_name.type) || |
537 | IsCompositeDevice(parsed_name.type))) { |
538 | return; |
539 | } |
540 | } |
541 | |
542 | PrioritizedDeviceTypeVector non_xla_types; |
543 | absl::c_copy_if(supported_device_types_, std::back_inserter(non_xla_types), |
544 | [&](const std::pair<DeviceType, int32>& entry) { |
545 | return !IsXlaDevice(entry.first.type_string()); |
546 | }); |
547 | |
548 | // TODO(b/141216278) Remove all XLA device types from the supported device |
549 | // types if the node has no requested/assigned/resource XLA device. |
550 | if (!non_xla_types.empty() && |
551 | non_xla_types.size() < supported_device_types_.size()) { |
552 | supported_device_types_ = std::move(non_xla_types); |
553 | } |
554 | } |
555 | |
// Narrows this member's requested/resource device names and supported device
// types to those allowed by `devices`.
Status Member::LimitToPossibleDevices(const PossibleDevices& devices,
                                      bool allow_soft_placement) {
  TF_RETURN_IF_ERROR(DeviceNameUtils::MergeDevNames(
      &requested_device_name_, devices.requested_device_name,
      allow_soft_placement));
  TF_RETURN_IF_ERROR(DeviceNameUtils::MergeDevNames(
      &resource_device_name_, devices.resource_device_name));
  // NOTE(review): the bool result (false when the device-type intersection is
  // empty, leaving supported_device_types_ unchanged) is ignored here —
  // presumably an empty intersection surfaces later during device lookup;
  // confirm.
  MergeSupportedDevices(devices.device_types);
  return OkStatus();
}
566 | |
// Human-readable dump of this member's state, used in logs and error
// messages (e.g. by FillPossibleDevices and AssignDevice above).
string Member::DebugString() const {
  return absl::StrCat(
      "Member(assigned_device_name_index_=", assigned_device_name_index_,
      " requested_device_name_='",
      DeviceNameUtils::ParsedNameToString(requested_device_name_),
      "' assigned_device_name_='",
      DeviceNameUtils::ParsedNameToString(assigned_device_name_),
      "' resource_device_name_='",
      DeviceNameUtils::ParsedNameToString(resource_device_name_),
      "' supported_device_types_=[",
      absl::StrJoin(DeviceTypeAndPriorityToString(supported_device_types_),
                    ", "),
      "] possible_devices_=[",
      absl::StrJoin(DevicesToString(possible_devices_), ", "), "]");
}
582 | |
583 | DeviceNameUtils::ParsedName Member::GetSoftDeviceName() const { |
584 | DeviceNameUtils::ParsedName soft_device_name = requested_device_name_; |
585 | if (!assigned_device_name_.has_type) { |
586 | soft_device_name.type.clear(); |
587 | soft_device_name.has_type = false; |
588 | } |
589 | if (!assigned_device_name_.has_id) { |
590 | soft_device_name.has_id = false; |
591 | } |
592 | return soft_device_name; |
593 | } |
594 | |
595 | DeviceNameUtils::ParsedName Member::GetPreferredSoftDeviceName() const { |
596 | DeviceNameUtils::ParsedName soft_device_name = requested_device_name_; |
597 | if (!assigned_device_name_.has_type && !resource_device_name_.has_type) { |
598 | soft_device_name.type.clear(); |
599 | soft_device_name.has_type = false; |
600 | } |
601 | if (!assigned_device_name_.has_id && !resource_device_name_.has_id) { |
602 | soft_device_name.has_id = false; |
603 | } |
604 | return soft_device_name; |
605 | } |
606 | |
607 | // Returns ParsedName whose address space (i.e. job, replica, task) identifies |
608 | // the address space directly accessible by the local process. If the address |
609 | // space is fully specified and it is exactly the same as the address space |
610 | // of a device, then all kernels of that device should be registered in the |
611 | // local process. |
612 | static const DeviceNameUtils::ParsedName LocalAddressSpec( |
613 | const Device* client_device, const Device* default_local_device) { |
614 | if (client_device != nullptr) { |
615 | return DeviceNameUtils::AddressSpace(client_device->parsed_name()); |
616 | } |
617 | |
618 | if (default_local_device != nullptr) { |
619 | return DeviceNameUtils::AddressSpace(default_local_device->parsed_name()); |
620 | } |
621 | |
622 | // TODO(b/139617593) Return the name of the first local device in device_set_ |
623 | // once we can trust the output of Device::IsLocal(). |
624 | return DeviceNameUtils::ParsedName(); |
625 | } |
626 | |
// Constructs a colocation graph over `graph`. References to `graph`,
// `flib_def` and `device_set` are retained; they must outlive this object.
ColocationGraph::ColocationGraph(const Graph* graph, const FunctionStack& stack,
                                 const FunctionLibraryDefinition* flib_def,
                                 const DeviceSet* device_set,
                                 const Device* default_local_device,
                                 bool allow_soft_placement,
                                 bool log_device_placement)
    : graph_(*graph),
      stack_(stack),
      inspecting_placer_(stack, flib_def, device_set, default_local_device,
                         allow_soft_placement, log_device_placement),
      inspection_required_checker_(graph, flib_def),
      device_set_(*device_set),
      device_types_(device_set->PrioritizedDeviceTypeList()),
      local_address_spec_(
          LocalAddressSpec(device_set->client_device(), default_local_device)),
      default_local_device_(default_local_device),
      allow_soft_placement_(allow_soft_placement),
      log_device_placement_(log_device_placement) {
  // One Member slot per node id, so node ids index members_ directly.
  members_.resize(graph_.num_node_ids());
}
647 | |
// Adds each node of the Graph to this ColocationGraph as a singleton.
//
// NOTE: The implementation assumes that the ids of nodes passed to
// this method are dense and zero-based; the memory used will be linear in
// the largest node ID.
// NOTE: If this method returns an error, *this is left in an undefined
// state.
Status ColocationGraph::ColocateAllNodes() {
  // This maps from a colocation group identifier to the 'root' of that
  // colocation group. Note that the keys in this map are StringPiece; the
  // actual strings are stored under the NodeDef. The lifetime of this map
  // is limited to this ColocateAllNodes() method, and no part of the
  // NodeDef trees are changed during the lifetime of this method, so using
  // StringPiece as a key is safe.
  //
  // Also, as a further optimization, we remove the "loc:@" prefix from
  // "class" attribute values, when they are used as keys in this table.
  // This allows us to use StringPiece values that refer to substrings of
  // 'string' values stored in NodeDef attribute lists, as well as StringPiece
  // values that refer to 'string' values from NodeDef::name(), without
  // performing any string allocations.
  std::unordered_map<StringPiece, const Node*, StringPieceHasher>
      colocation_group_root;

  for (const Node* node : graph_.op_nodes()) {
    // When adding the node, identify whether it is part of a colocation
    // group.

    // This code is effectively the equivalent of GetNodeAttr() for a string
    // array, but it avoids all internal allocations (the allocation of the
    // backing store of the std::vector<string> as well as the copies of the
    // strings within it). Instead, we combine the query of the colocation
    // attribute with the calls to ColocateNodeToGroup.
    const AttrValue* attr_value =
        node->attrs().Find(kColocationAttrNameStringPiece);
    if (attr_value != nullptr) {
      if (attr_value->has_list()) {
        for (const string& class_spec : attr_value->list().s()) {
          StringPiece spec(class_spec);
          // Only entries with the "loc:@" prefix name colocation groups.
          if (absl::ConsumePrefix(&spec, kColocationGroupPrefixStringPiece)) {
            TF_RETURN_IF_ERROR(
                ColocateNodeToGroup(&colocation_group_root, node, spec));
          }
        }
      } else if (!attr_value->s().empty()) {
        // Tolerate a malformed attribute: log and continue rather than fail
        // placement for the whole graph.
        LOG(ERROR) << "The value for colocation attribute '_class' must be a "
                      "list of strings, not a single string: "
                   << node->DebugString();
      }
    }

    // Each node belongs to a colocation group with the node's name.
    TF_RETURN_IF_ERROR(
        ColocateNodeToGroup(&colocation_group_root, node, node->name()));
  }

  return OkStatus();
}
706 | |
// Colocates the groups of `src` and `dst`, which are connected by a
// reference or resource edge, after checking device compatibility across the
// edge. Composite-device -> physical-device edges are deliberately left
// uncolocated.
Status ColocationGraph::ColocateResourceOrRefEdge(const Node* src,
                                                  const Node* dst) {
  // Colocate `src` and `dst` to maintain the invariant that nodes
  // connected by reference edges are colocated.
  int src_root_id = FindAndUpdateRoot(src->id());
  int dst_root_id = FindAndUpdateRoot(dst->id());
  auto& src_root = members_[src_root_id];
  auto& dst_root = members_[dst_root_id];

  if (dst_root.IsEdgeFromCompositeDeviceToPhysicalDevice(src_root)) {
    // If the src root is assigned to a composite device and the dst root is
    // assigned to a physical device, don't colocate the dst root with the src
    // root.
    return OkStatus();
  }
  // May override dst_root's requested device with src_root's (see
  // EnsureCompatibilityAcrossResourceEdge).
  TF_RETURN_IF_ERROR(dst_root.EnsureCompatibilityAcrossResourceEdge(
      *src, src_root, *dst, log_device_placement_));
  Status status = ColocateNodes(*src, src_root_id, *dst, dst_root_id);
  if (!status.ok()) {
    // Wrap the failure with edge context and attach the dst node's def.
    return AttachDef(
        errors::InvalidArgument(
            "Nodes were connected by a reference or resource connection "
            "(requiring them to be on the same device), but the two nodes "
            "were assigned two different devices: ",
            status.error_message()),
        *dst);
  }
  return OkStatus();
}
736 | |
737 | Status ColocationGraph::ColocateResourceAndRefEdges( |
738 | std::unordered_set<Node*>* inspection_required) { |
739 | // If `node` has an input edge with reference type, add an edge from the |
740 | // source of that edge to `node`. |
741 | for (const Edge* edge : graph_.edges()) { |
742 | if (edge->IsControlEdge()) { |
743 | continue; |
744 | } |
745 | Node* src = edge->src(); |
746 | Node* dst = edge->dst(); |
747 | bool needs_inspection; |
748 | TF_RETURN_IF_ERROR(inspection_required_checker_.IsPlacerInspectionRequired( |
749 | *src, &needs_inspection)); |
750 | if (needs_inspection) { |
751 | inspection_required->insert(src); |
752 | continue; |
753 | } |
754 | TF_RETURN_IF_ERROR(inspection_required_checker_.IsPlacerInspectionRequired( |
755 | *dst, &needs_inspection)); |
756 | if (needs_inspection) { |
757 | inspection_required->insert(dst); |
758 | continue; |
759 | } |
760 | |
761 | DataType input_type = dst->input_type(edge->dst_input()); |
762 | |
763 | // Colocate two DatasetOp nodes connected by edge of dtype=DT_VARIANT. |
764 | // This is needed to get around the issue in b/135705778. |
765 | if (input_type == DT_VARIANT && |
766 | data::DatasetOpKernel::IsDatasetOp(src->op_def()) && |
767 | data::DatasetOpKernel::IsDatasetOp(dst->op_def())) { |
768 | TF_RETURN_IF_ERROR(ColocateResourceOrRefEdge(src, dst)); |
769 | continue; |
770 | } |
771 | |
772 | // Even though we can look inside function calling ops, we make an exception |
773 | // here mostly for performance reasons. Looking inside function calling ops |
774 | // is extra overhead. It is only necessary when they return resources. When |
775 | // they don't, we don't look inside them and make this exception here. |
776 | // Looking inside, could potentially enable us to make better placement |
777 | // decisions. It might be worth doing at some point. |
778 | if ((input_type == DT_RESOURCE || IsRefType(input_type)) && |
779 | !IsExemptFromResourceInputColocation(dst)) { |
780 | TF_RETURN_IF_ERROR(ColocateResourceOrRefEdge(src, dst)); |
781 | } |
782 | } |
783 | |
784 | return OkStatus(); |
785 | } |
786 | |
787 | namespace { |
788 | // Returns tensor list element data type, if the node is one of the ops that |
789 | // operate with TensorLists. Otherwise returns DT_INVALID. |
790 | // TODO(b/199443424): Don't use op names, use FullType here. |
791 | DataType GetElementDataType(const Node& node) { |
792 | static absl::flat_hash_set<std::string>* tensor_list_ops = |
793 | new absl::flat_hash_set<std::string>( |
794 | {"TensorListReserve" , "TensorListFromTensor" , "EmptyTensorList" , |
795 | "TensorListSplit" , "TensorListScatter" , "TensorListScatterV2" , |
796 | "TensorListScatterIntoExistingList" , "TensorListPushBack" , |
797 | "TensorListPushBackBatch" , "TensorListPopBack" , "TensorListStack" , |
798 | "TensorListConcat" , "TensorListConcatV2" , "TensorListGetItem" , |
799 | "TensorListSetItem" , "TensorListGather" , "TensorListConcatLists" }); |
800 | |
801 | if (tensor_list_ops->contains(node.type_string())) { |
802 | DataType element_type; |
803 | if (GetNodeAttr(node.attrs(), "element_dtype" , &element_type).ok()) { |
804 | return element_type; |
805 | } |
806 | } |
807 | |
808 | return DT_INVALID; |
809 | } |
810 | } // namespace |
811 | |
// Pins colocation groups to CPU when a node consumes a DT_VARIANT whose
// underlying element type must live on the host (e.g. host-only TensorList
// elements).
Status ColocationGraph::AddHostOnlyDataTypesConstraints() {
  auto is_variant = [](DataType dtype) -> bool { return dtype == DT_VARIANT; };

  auto is_cpu_device = [](const std::pair<DeviceType, int32>& entry) -> bool {
    return entry.first == DEVICE_CPU;
  };

  for (Node* node : graph_.nodes()) {
    // Skip nodes that do not have DT_VARIANT inputs.
    if (absl::c_none_of(node->input_types(), is_variant)) {
      continue;
    }

    // Skip nodes that can't be placed on GPU anyway.
    Member& root = members_[FindAndUpdateRoot(node->id())];
    if (absl::c_all_of(root.supported_device_types(), is_cpu_device)) {
      continue;
    }

    // Tri-state decision: unset = not yet determined; once set, `true`
    // restricts the whole colocation group to CPU below.
    absl::optional<bool> constrain_to_host;

    // This is a list of special nodes that we know to have no HostMemory
    // inputs, so if they receive a host-only data type, they must necessarily
    // be constrained to the host.
    // This is brittle. In general, this should be handled by accounting for
    // HostMemory as a constraint when the node's device is known, not ahead of
    // time.
    // A less ideal, but still better alternative is to look for ops which
    // have no HostMemory kernels for the corresponding input. Unfortunately,
    // determining that is challenging because we lack a map from input names
    // to node input indices.
    // TODO(mdan): Fix this.
    if (node->IsRetval() || node->IsIdentity() || node->IsControlFlow() ||
        node->IsFunctionCall()) {
      for (const auto& edge : node->in_edges()) {
        if (HasHostMemoryOutType(*edge->src())) {
          // Skip nodes in colocation groups that already have a device
          // assignment
          if (root.has_assigned_device_name()) {
            VLOG(4) << "Special node has host-only data type input "
                    << "but is in a colocation group that already has a device "
                    << "assignment, so NOT adding constraint:\n"
                    << node->def().DebugString() << "\nedge:\n"
                    << edge->DebugString();
            break;
          } else {
            VLOG(4) << "Special node has host-only data type input, "
                    << "adding constraint:\n"
                    << node->def().DebugString() << "\nedge:\n"
                    << edge->DebugString();
            constrain_to_host = true;
            break;
          }
        }
      }
    }

    if (!constrain_to_host.has_value()) {
      // Legacy slow path. This covers legacy data types and ops which have not
      // been upgraded to FullType.
      auto edge_filter = [&](const Edge& edge) -> bool {
        // We already found the underlying data type.
        if (constrain_to_host.has_value()) return false;

        // Otherwise follow only DT_VARIANT data edges.
        auto edge_dtype = [&]() -> DataType {
          return edge.src()->output_type(edge.src_output());
        };
        return !edge.IsControlEdge() && edge_dtype() == DT_VARIANT;
      };

      auto enter = [&](Node* n) -> void {
        DataType element_type = GetElementDataType(*n);
        // To handle nested lists continue traversal after finding a TensorList
        // operation that uses DT_VARIANT for element type.
        if (element_type == DT_INVALID || element_type == DT_VARIANT) {
          return;
        }
        constrain_to_host = DataTypeAlwaysOnHost(element_type);
      };

      // Trace backwards through the variant-producing fan-in of `node`
      // looking for a TensorList op that reveals the concrete element type.
      ReverseDFSFrom(graph_, {node}, enter, /*leave=*/nullptr,
                     /*stable_comparator=*/nullptr, edge_filter);
    }

    if (constrain_to_host.has_value() && *constrain_to_host) {
      VLOG(2) << "Constraining node " << node->name()
              << " to CPU: it has an input with host-only "
                 "underlying data type.";

      // Restrict possible device types to CPU only.
      PossibleDevices possible_devices;
      absl::c_copy_if(root.supported_device_types(),
                      std::back_inserter(possible_devices.device_types),
                      is_cpu_device);

      TF_RETURN_IF_ERROR(root.LimitToPossibleDevices(
          possible_devices, /*allow_soft_placement=*/false));
    }
  }

  return OkStatus();
}
915 | |
916 | Status ColocationGraph::AddInspectionConstraints( |
917 | const std::unordered_set<Node*>& inspection_required) { |
918 | for (Node* node : inspection_required) { |
919 | IOColocationGroups groups; |
920 | TF_RETURN_IF_ERROR( |
921 | inspecting_placer_.ComputeIOColocationGroups(*node, &groups)); |
922 | VLOG(2) << "Computed IOColocationGroups for node " << node->name() |
923 | << ":\n\t" << groups.DebugString(); |
924 | TF_RETURN_IF_ERROR(ApplyIOColocationGroups(groups, *node)); |
925 | } |
926 | return OkStatus(); |
927 | } |
928 | |
929 | Status ColocationGraph::Initialize() { |
930 | TF_RETURN_IF_ERROR(InitializeMembers()); |
931 | |
932 | std::unordered_set<Node*> inspection_required; |
933 | TF_RETURN_IF_ERROR(ColocateResourceAndRefEdges(&inspection_required)); |
934 | TF_RETURN_IF_ERROR(AddInspectionConstraints(inspection_required)); |
935 | TF_RETURN_IF_ERROR(ColocateAllNodes()); |
936 | TF_RETURN_IF_ERROR(AddHostOnlyDataTypesConstraints()); |
937 | |
938 | for (Node* node : graph_.op_nodes()) { |
939 | int root_id = FindAndUpdateRoot(node->id()); |
940 | members_[root_id].MaybeExcludeXlaDevices(); |
941 | } |
942 | |
943 | return OkStatus(); |
944 | } |
945 | |
// A (node, has_resource_input) pair: the bool records whether the node
// receives a resource input from the node requiring placer inspection.
using NodeAndBool = std::pair<const Node*, bool>;
949 | |
950 | namespace { |
951 | |
952 | // Returns a vector of node names from `nodes`. |
953 | std::vector<string> NodeAndBoolToString(const std::vector<NodeAndBool>& nodes) { |
954 | std::vector<string> v; |
955 | v.reserve(nodes.size()); |
956 | for (const NodeAndBool& node_and_bool : nodes) { |
957 | v.push_back(node_and_bool.first->name()); |
958 | } |
959 | return v; |
960 | } |
961 | |
962 | // Given a node requiring placer inspection and its IOColocationGroups, |
963 | // computes `group_nodes`. |
964 | // group_nodes[i] contains the nodes that are members of colocation |
965 | // group i. These nodes are inputs or outputs of `node`. |
966 | // group_nodes[i][j] is a pair containing a node and whether this node |
967 | // has a resource input from `node`. |
968 | // Note: |
969 | // The same node can be added multiple times to the same group. |
970 | // The same node can be added to multiple groups. |
971 | Status GetGroupNodes(const IOColocationGroups& groups, const Node& node, |
972 | std::vector<std::vector<NodeAndBool>>* group_nodes) { |
973 | group_nodes->reserve(groups.group_devices.size()); |
974 | for (int arg_idx = 0; arg_idx < groups.input_groups.size(); ++arg_idx) { |
975 | const Node* src; |
976 | TF_RETURN_IF_ERROR(node.input_node(arg_idx, &src)); |
977 | int group_id = groups.input_groups[arg_idx]; |
978 | (*group_nodes)[group_id].emplace_back(src, false); |
979 | } |
980 | |
981 | for (const Edge* edge : node.out_edges()) { |
982 | if (edge->IsControlEdge()) { |
983 | continue; |
984 | } |
985 | |
986 | int group_id = groups.output_groups[edge->src_output()]; |
987 | (*group_nodes)[group_id].emplace_back( |
988 | edge->dst(), edge->dst()->input_type(edge->dst_input()) == DT_RESOURCE); |
989 | } |
990 | |
991 | if (VLOG_IS_ON(2)) { |
992 | VLOG(2) << "Colocated inputs/outputs of node: " << node.DebugString(); |
993 | for (const std::vector<NodeAndBool>& nodes : *group_nodes) { |
994 | VLOG(2) << "\t[" << absl::StrJoin(NodeAndBoolToString(nodes), "\t\n" ) |
995 | << "]" ; |
996 | } |
997 | } |
998 | return OkStatus(); |
999 | } |
1000 | |
1001 | // Returns whether the device_type in `device_attributes` is supported. |
1002 | bool IsSupportedDeviceType(const DeviceAttributes& device_attributes, |
1003 | const DeviceType& supported_type) { |
1004 | if (DeviceType(device_attributes.device_type()) == supported_type) { |
1005 | return true; |
1006 | } |
1007 | return IsCompositeDevice(device_attributes.device_type()); |
1008 | } |
1009 | |
1010 | } // namespace |
1011 | |
// Applies the input/output colocation groups computed for `node` (a node
// requiring placer inspection): colocates the members of each group and
// limits each group to its allowed devices.
Status ColocationGraph::ApplyIOColocationGroups(
    const IOColocationGroups& groups, const Node& node) {
  // Sanity checks: the groups must describe exactly this node's I/O arity.
  if (groups.input_groups.size() != node.num_inputs()) {
    return errors::Internal(
        "Cannot apply input/output device constraints to node ",
        node.DebugString(), " because input_groups.size() (",
        groups.input_groups.size(),
        ") is different from number of inputs into the op node (",
        node.num_inputs(), ")");
  }
  if (groups.output_groups.size() != node.num_outputs()) {
    return errors::Internal(
        "Cannot apply input/output device constraints to node ",
        node.DebugString(), " because output_groups.size() (",
        groups.output_groups.size(),
        ") is different from number of outputs into the op node (",
        node.num_outputs(), ")");
  }

  // group_nodes[i] contains the nodes that are members of colocation
  // group i. These nodes are inputs or outputs of `node`.
  // group_nodes[i][j] is a pair containing the node and whether this node
  // has a resource input from `node`.
  // The same node can be added multiple times to the same group.
  // The same node can be added to multiple groups.
  // NOTE: group ids are guaranteed to be dense, i.e. [0, num_groups).
  std::vector<std::vector<NodeAndBool>> group_nodes(
      groups.group_devices.size());
  TF_RETURN_IF_ERROR(GetGroupNodes(groups, node, &group_nodes));

  // Colocate nodes in each group: the first member is the anchor and every
  // later member is colocated with it (through a resource edge when flagged).
  for (const std::vector<NodeAndBool>& nodes : group_nodes) {
    for (int i = 1; i < nodes.size(); ++i) {
      VLOG(2) << "Colocating \"" << nodes[0].first->name() << "\" and \""
              << nodes[i].first->name() << "\"";
      if (nodes[i].second) {
        TF_RETURN_IF_ERROR(
            ColocateResourceOrRefEdge(nodes[0].first, nodes[i].first));
      } else {
        TF_RETURN_IF_ERROR(ColocateNodes(*nodes[0].first, *nodes[i].first));
      }
    }
  }

  // Limit devices in each group
  for (int group_id = 0; group_id < groups.group_devices.size(); ++group_id) {
    // Nothing to do for empty groups. Groups can be empty if some output
    // of an op is not used.
    if (group_nodes[group_id].empty()) {
      continue;
    }
    // All members were colocated above, so restricting any single member
    // restricts the whole group.
    const Node* group_node = group_nodes[group_id][0].first;
    const PossibleDevices& possible_devices = groups.group_devices[group_id];
    TF_RETURN_IF_ERROR(LimitToPossibleDevices(*group_node, possible_devices));
  }

  return OkStatus();
}
1070 | |
1071 | Status ColocationGraph::ColocateNodeToGroup( |
1072 | std::unordered_map<StringPiece, const Node*, StringPieceHasher>* |
1073 | colocation_group_root, |
1074 | const Node* node, StringPiece colocation_group) { |
1075 | const Node*& root_node = (*colocation_group_root)[colocation_group]; |
1076 | if (root_node == nullptr) { |
1077 | // This is the first node of the colocation group, so |
1078 | // designate this node as the 'root' of that colocation group. |
1079 | root_node = node; |
1080 | } else { |
1081 | // Try to colocate the node with the root. If there is an |
1082 | // error, return it. |
1083 | Status s = ColocateNodes(*node, *root_node); |
1084 | if (!s.ok()) { |
1085 | if (!allow_soft_placement_) { |
1086 | return AttachDef(s, *node); |
1087 | } |
1088 | if (log_device_placement_) { |
1089 | LOG(INFO) << "Ignoring request to colocate node '" << node->name() |
1090 | << "' with nodes in colocation group '" << colocation_group |
1091 | << "' because soft placement is on and an attempt at doing " |
1092 | "so resulted in the following error: " |
1093 | << AttachDef(s, *node).ToString(); |
1094 | } |
1095 | } |
1096 | } |
1097 | return OkStatus(); |
1098 | } |
1099 | |
1100 | // Merge the (possibly disjoint) sets containing nodes "x" and |
1101 | // "y". Returns OK if the all nodes in the union of these sets can |
1102 | // be placed on the same device type. |
1103 | // |
1104 | // NOTE: If this method returns an error, *this is left in an undefined |
1105 | // state. |
1106 | Status ColocationGraph::ColocateNodes(const Node& x, const Node& y) { |
1107 | int x_root = FindAndUpdateRoot(x.id()); |
1108 | int y_root = FindAndUpdateRoot(y.id()); |
1109 | return ColocateNodes(x, x_root, y, y_root); |
1110 | } |
1111 | |
// This overload of ColocateNodes() allows a caller to provide the root node
// ids for the two nodes. For large graphs, this noticeably reduces the
// graph load time.
Status ColocationGraph::ColocateNodes(const Node& x, int x_root, const Node& y,
                                      int y_root) {
  // Already in the same colocation set; nothing to merge.
  if (x_root == y_root) {
    return OkStatus();
  }

  // Dry-run merge first: determines which member would become the new root
  // without mutating the union-find structure. The real merge happens only
  // after the compatibility checks below pass.
  Member* new_root_member;
  Member* old_root_member;
  Member::Merge(&members_, x_root, y_root, &new_root_member, &old_root_member,
                /*dry_run=*/true);

  // Merge the partial device specifications, and ensure that they are
  // compatible. NULL options_ is treated as allowing soft placement.
  // If there is an error, nothing is modified.
  // TODO(mrry): Consider enriching the error message by pointing
  // out which nodes have the explicit partial device
  // specifications that caused this conflict.
  Status s = new_root_member->MergeDeviceNames(*old_root_member,
                                               allow_soft_placement_);
  if (!s.ok()) {
    return errors::InvalidArgument(
        "Cannot colocate nodes ",
        errors::FormatColocationNodeForError(x.name()), " and ",
        errors::FormatColocationNodeForError(y.name()), ": ",
        s.error_message());
  }

  // Ensure that the common root has at least one supported device
  // type, by computing the intersection of
  // new_root_member.supported_device_types and
  // old_root_member.supported_device_types.
  if (!new_root_member->MergeSupportedDevices(*old_root_member)) {
    return errors::InvalidArgument(
        "Cannot colocate nodes ",
        errors::FormatColocationNodeForError(x.name()), " and ",
        errors::FormatColocationNodeForError(y.name()),
        " because no device type supports both of those nodes and the "
        "other nodes colocated with them.",
        DebugInfo(x_root), DebugInfo(y_root));
  }

  // All error checks are done, merge the colocation graphs.
  Member::Merge(&members_, x_root, y_root, &new_root_member, &old_root_member,
                /*dry_run=*/false);
  return OkStatus();
}
1161 | |
1162 | Status ColocationGraph::LimitToAssignedDevice(const Node& node) { |
1163 | if (node.assigned_device_name_index() < 0) { |
1164 | return errors::Internal( |
1165 | "Expected an assigned node as argument to LimitToAssignedDevice but " |
1166 | "got: " , |
1167 | node.DebugString()); |
1168 | } |
1169 | int root = FindAndUpdateRoot(node.id()); |
1170 | Member& root_member = members_[root]; |
1171 | return root_member.AssignDevice(node); |
1172 | } |
1173 | |
1174 | void ColocationGraph::GetSoftDeviceCandidates( |
1175 | const Node& node, const Member& root_member, int root_id, |
1176 | std::vector<Device*>* possible_devices) { |
1177 | // Try to find supported devices that don't violate resource devices. |
1178 | // The soft_device_name is the same as the requested device name |
1179 | // without specifying the device type or ID (if assigned and requested |
1180 | // devices does not specify them). |
1181 | DeviceNameUtils::ParsedName soft_device_name = |
1182 | root_member.GetPreferredSoftDeviceName(); |
1183 | device_set_.FindMatchingDevices(soft_device_name, possible_devices); |
1184 | if (!possible_devices->empty()) { |
1185 | *possible_devices = FilterSupportedDevices( |
1186 | *possible_devices, root_member.supported_device_types(), |
1187 | default_local_device_); |
1188 | } |
1189 | |
1190 | if (!possible_devices->empty()) { |
1191 | return; |
1192 | } |
1193 | |
1194 | // TODO(iga): Disallow changing resource devices when this ColocationGraph |
1195 | // is for : |
1196 | // - a function called by an op requiring deep inspection, or |
1197 | // - a graph containing ops requiring inspection. |
1198 | // It is fairly tricky to make changing resource devices in presence of |
1199 | // ops requiring inspection work correctly. One thing it would require is to |
1200 | // communicate these "resource movement" decisions across Placer instances. |
1201 | |
1202 | // Failed to find supported devices that don't violate resource devices. |
1203 | // Try finding some devices that violated resource devices. |
1204 | // If we succeed, we will log a warning below. |
1205 | soft_device_name = root_member.GetSoftDeviceName(); |
1206 | device_set_.FindMatchingDevices(soft_device_name, possible_devices); |
1207 | if (!possible_devices->empty()) { |
1208 | *possible_devices = FilterSupportedDevices( |
1209 | *possible_devices, root_member.supported_device_types(), |
1210 | default_local_device_); |
1211 | } |
1212 | |
1213 | if (!possible_devices->empty()) { |
1214 | LOG(WARNING) |
1215 | << "Failed to place the graph without changing the devices of some " |
1216 | "resources. Some of the operations (that had to be colocated with " |
1217 | "resource generating operations) are not supported on the " |
1218 | "resources' devices. Current candidate devices are [\n " |
1219 | << absl::StrJoin(DevicesToString(*possible_devices), "\n " ) |
1220 | << "].\nSee below for details of this colocation group:" |
1221 | << DebugInfo(root_id); |
1222 | } |
1223 | } |
1224 | |
1225 | Status ColocationGraph::LimitToPossibleDevices(const Node& node, |
1226 | const PossibleDevices& devices) { |
1227 | int root = FindAndUpdateRoot(node.id()); |
1228 | Member& root_member = members_[root]; |
1229 | return root_member.LimitToPossibleDevices(devices, allow_soft_placement_); |
1230 | } |
1231 | |
1232 | Status ColocationGraph::GetDevicesForNode( |
1233 | Node* node, const std::vector<Device*>** possible_devices) { |
1234 | *possible_devices = nullptr; |
1235 | const int node_root = FindAndUpdateRoot(node->id()); |
1236 | if (!members_[node_root].possible_devices().empty()) { |
1237 | *possible_devices = &members_[node_root].possible_devices(); |
1238 | return OkStatus(); |
1239 | } |
1240 | |
1241 | Member& root_member = members_[node_root]; |
1242 | |
1243 | // We have not yet computed the possible devices for the |
1244 | // colocated node set containing 'node', so we do so now using the |
1245 | // constraints on the root node. |
1246 | |
1247 | // "devices" will contain the set of feasible placements for the |
1248 | // colocated node set containing 'node'. |
1249 | // NOTE: Basing possible device computation on requested device name |
1250 | // is guaranteed to respect the assigned and resource device names because |
1251 | // requested device is always a specialization of both. |
1252 | std::vector<Device*> devices; |
1253 | if (DeviceNameUtils::HasSomeDetails(root_member.requested_device_name())) { |
1254 | // The root node has a (possibly partial) device |
1255 | // specification, so enumerate the physical devices that |
1256 | // conform to it. |
1257 | device_set_.FindMatchingDevices(root_member.requested_device_name(), |
1258 | &devices); |
1259 | |
1260 | if (!devices.empty()) { |
1261 | // Filter devices into those that are compatible with the root |
1262 | // node (and its children). |
1263 | devices = FilterSupportedDevices( |
1264 | devices, root_member.supported_device_types(), default_local_device_); |
1265 | } |
1266 | |
1267 | // Perform soft placement if allow_soft_placement_ is set. |
1268 | if (devices.empty() && allow_soft_placement_) { |
1269 | GetSoftDeviceCandidates(*node, root_member, node_root, &devices); |
1270 | } |
1271 | |
1272 | if (devices.empty()) { |
1273 | // Return an error when a physical device that matches an explicit |
1274 | // device specification is not found. This ensures that we don't |
1275 | // assign a node to GPU when the user wanted to force it on CPU. |
1276 | string debug_info = DebugInfo(node_root); |
1277 | |
1278 | DeviceNameUtils::ParsedName specified_device_name; |
1279 | if (DeviceNameUtils::ParseFullName(node->requested_device(), |
1280 | &specified_device_name) && |
1281 | specified_device_name == root_member.requested_device_name()) { |
1282 | // The specified device and merged set device match, and |
1283 | // will appear in the GraphDef (for debugging), so just |
1284 | // print the specified device. |
1285 | std::vector<Device*> devices_matching_nodedef; |
1286 | device_set_.FindMatchingDevices(specified_device_name, |
1287 | &devices_matching_nodedef); |
1288 | if (devices_matching_nodedef.empty()) { |
1289 | // Sometimes it is almost impossible to understand the problem |
1290 | // without a list of available devices. |
1291 | std::vector<string> device_names; |
1292 | for (const Device* device : device_set_.devices()) { |
1293 | device_names.push_back(device->name()); |
1294 | } |
1295 | std::sort(device_names.begin(), device_names.end()); |
1296 | |
1297 | string gpu_msg = "" ; |
1298 | if (!IsGoogleCudaEnabled() && |
1299 | absl::AsciiStrToLower(specified_device_name.type) == "gpu" ) { |
1300 | gpu_msg = |
1301 | " The requested device appears to be a GPU, but CUDA is not " |
1302 | "enabled." ; |
1303 | } |
1304 | |
1305 | return errors::InvalidArgument( |
1306 | errors::FormatNodeNameForError(node->name()), |
1307 | " was explicitly assigned to " , node->requested_device(), |
1308 | " but available devices are [ " , |
1309 | absl::StrJoin(device_names, ", " ), " ]. Make sure " , |
1310 | "the device specification refers to a valid device." , gpu_msg); |
1311 | } else if (specified_device_name.has_type) { |
1312 | return errors::InvalidArgument( |
1313 | "Could not satisfy explicit device specification '" , |
1314 | node->requested_device(), "' because no supported kernel for " , |
1315 | specified_device_name.type, " devices is available." , debug_info, |
1316 | "\nOp: " , node->type_string(), |
1317 | "\nNode attrs: " , node->attrs().DebugString(), |
1318 | "\nRegistered kernels:\n" , |
1319 | KernelsRegisteredForOp(node->type_string())); |
1320 | } else { |
1321 | return errors::InvalidArgument( |
1322 | "Could not satisfy explicit device specification '" , |
1323 | node->requested_device(), debug_info); |
1324 | } |
1325 | } else { |
1326 | // The specified device may be a valid device but the |
1327 | // merged set device is different, so print both. |
1328 | // TODO(b/129057603): There are many possibilities at this point. |
1329 | // Provide good error messages. |
1330 | return errors::InvalidArgument( |
1331 | "Could not satisfy explicit device specification '" , |
1332 | node->requested_device(), "' because the node " , |
1333 | errors::FormatColocationNodeForError(node->name()), |
1334 | " was colocated with a group of nodes that " , |
1335 | "required incompatible device '" , |
1336 | DeviceNameUtils::ParsedNameToString( |
1337 | root_member.requested_device_name()), |
1338 | "'. All available devices [" , |
1339 | absl::StrJoin(DevicesToString(device_set_.devices()), ", " ), "]. " , |
1340 | debug_info); |
1341 | } |
1342 | } |
1343 | } else { |
1344 | // The device is completely unspecified, so enumerate the devices that |
1345 | // support all of the nodes in the set. |
1346 | if (device_set_.devices().empty()) { |
1347 | return errors::Internal("No devices are registered" ); |
1348 | } |
1349 | devices = FilterSupportedDevices(device_set_.devices(), |
1350 | root_member.supported_device_types(), |
1351 | default_local_device_); |
1352 | |
1353 | if (devices.empty()) { |
1354 | return errors::InvalidArgument( |
1355 | "Node had no OpKernel registered to support this operation: " , |
1356 | "Operation was " , node->type_string(), " and inputs were [" , |
1357 | DataTypeVectorString(node->input_types()), "].\n" , |
1358 | DebugInfo(node_root)); |
1359 | } |
1360 | } |
1361 | |
1362 | // Cache the result of the possible devices for this node group. |
1363 | root_member.set_possible_devices(std::move(devices)); |
1364 | *possible_devices = &root_member.possible_devices(); |
1365 | return OkStatus(); |
1366 | } |
1367 | |
1368 | Status ColocationGraph::InitializeMembers() { |
1369 | for (Node* node : graph_.op_nodes()) { |
1370 | Status status = InitializeMember(*node, &members_[node->id()]); |
1371 | if (!status.ok()) { |
1372 | return AttachDef(status, *node); |
1373 | } |
1374 | } |
1375 | return OkStatus(); |
1376 | } |
1377 | |
1378 | string ColocationGraph::DebugString() const { |
1379 | std::unordered_set<int> roots; |
1380 | std::vector<string> root_strings; |
1381 | for (const Node* node : graph_.nodes()) { |
1382 | if (!node->IsOp()) { |
1383 | continue; |
1384 | } |
1385 | int node_root = FindRoot(node->id()); |
1386 | if (roots.count(node_root) == 0) { |
1387 | root_strings.push_back(DebugInfo(node_root)); |
1388 | roots.insert(node_root); |
1389 | } |
1390 | } |
1391 | return absl::StrJoin(root_strings, "\n" ); |
1392 | } |
1393 | |
// Returns debugging info for the node referred to by 'node_root'.
// The result lists, for the whole colocation group: the op types with their
// registered device types, and each member node with its requested (and, if
// present, framework-assigned) device. Returns an empty string if no group
// member is found.
string ColocationGraph::DebugInfo(const int node_root) const {
  string text(
      "\nColocation Debug Info:\n"
      "Colocation group had the following types and supported devices: ");

  // If this node is part of a colocation group, then we want to
  // collect the mapping of ops to supported devices, so that
  // the user can see why an unsatisfiable placement occurred.

  std::unordered_map<string, string> type_to_devices;
  std::vector<const Node*> colocation_nodes;
  int num_nodes_found = 0;

  // Scan the whole graph for members of this colocation group.
  for (const Node* node : graph_.nodes()) {
    if (!node->IsOp()) {
      continue;
    }
    int id = node->id();
    if (FindRoot(id) != node_root) {
      continue;
    }
    ++num_nodes_found;
    colocation_nodes.push_back(node);

    // Best-effort: errors from the kernel-registry lookup are ignored since
    // this is purely diagnostic output.
    PrioritizedDeviceTypeVector supported_types;
    SupportedDeviceTypesForNode(device_types_, node->def(), &supported_types,
                                &local_address_spec_)
        .IgnoreError();
    string devices_registered;
    for (const auto& device_type : supported_types) {
      strings::StrAppend(&devices_registered,
                         DeviceTypeString(device_type.first), " ");
    }

    const string& op_type = node->type_string();
    type_to_devices[op_type] = std::move(devices_registered);
  }
  strings::StrAppend(&text, "\nRoot ", members_[node_root].DebugString());

  for (const auto& td : type_to_devices) {
    strings::StrAppend(&text, "\n", td.first, ": ", td.second);
  }
  strings::StrAppend(&text,
                     "\n\nColocation members, user-requested devices, and "
                     "framework assigned devices, if any:");
  for (const Node* node : colocation_nodes) {
    strings::StrAppend(&text, "\n  ", node->name(), " (", node->type_string(),
                       ") ", node->requested_device());
    if (node->has_assigned_device_name()) {
      strings::StrAppend(
          &text, " framework assigned device=", node->assigned_device_name());
    }
  }
  strings::StrAppend(&text, "\n");

  // No member nodes found for this root: return an empty string instead of a
  // misleading header.
  if (num_nodes_found <= 0) {
    text.clear();
  }
  return text;
}
1455 | |
1456 | Status ColocationGraph::InitializeMemberWithAssignedDevice( |
1457 | const string& assigned_device_name, const string& node_type, |
1458 | Member* member) { |
1459 | // This node has already been assigned to a device, so we |
1460 | // respect this placement, after sanity-checking it. |
1461 | // NOTE: Since any assignment must have been performed by |
1462 | // the TensorFlow runtime, we consider errors in this branch to |
1463 | // be INTERNAL. |
1464 | TF_RETURN_IF_ERROR(member->SetAssignedDeviceName(assigned_device_name)); |
1465 | |
1466 | // Since assigned device must be a full specification, do extra checks. |
1467 | const Device* assigned_device = |
1468 | device_set_.FindDeviceByName(assigned_device_name); |
1469 | if (assigned_device == nullptr) { |
1470 | // TODO(b/129295848, b/122851476): Remove the bit about cross-host function |
1471 | // calls when they are supported. |
1472 | return errors::Internal( |
1473 | "Assigned device '" , assigned_device_name, |
1474 | "' does not match any device. This error can happen when one attempts " |
1475 | "to run a tf.function with resource inputs residing on remote devices. " |
1476 | "This use case is currently not supported. Here are the devices " |
1477 | "available on this machine: [" , |
1478 | absl::StrJoin(DevicesToString(device_set_.devices()), ", " ), "]." , |
1479 | "If you are seeing this error when running using a tf.Session, set " |
1480 | "share_cluster_devices_in_session to true in the tf.ConfigProto." ); |
1481 | } |
1482 | |
1483 | for (const auto& d : member->supported_device_types()) { |
1484 | if (IsSupportedDeviceType(assigned_device->attributes(), d.first)) { |
1485 | return OkStatus(); |
1486 | } |
1487 | } |
1488 | |
1489 | return errors::Internal("Assigned device '" , assigned_device_name, |
1490 | "' does not have registered OpKernel support " |
1491 | "for " , |
1492 | node_type); |
1493 | } |
1494 | |
1495 | Status ColocationGraph::InitializeMember(const Node& node, Member* member) { |
1496 | TF_RETURN_IF_ERROR(member->SetParentAndSupportedDevices( |
1497 | node, device_types_, &local_address_spec_)); |
1498 | |
1499 | if (node.has_assigned_device_name()) { |
1500 | TF_RETURN_IF_ERROR(InitializeMemberWithAssignedDevice( |
1501 | node.assigned_device_name(), node.type_string(), member)); |
1502 | } else { |
1503 | // This node has not yet been assigned to a device, so we |
1504 | // calculate any constraints due to the set of registered |
1505 | // kernels and any (partial) user-provided device specification |
1506 | // in the NodeDef. |
1507 | |
1508 | // If no kernels are registered for this op type, fail with an error. |
1509 | if (member->supported_device_types().empty()) { |
1510 | std::set<string> registered_device_types; |
1511 | for (Device* d : device_set_.devices()) { |
1512 | registered_device_types.insert(d->device_type()); |
1513 | } |
1514 | return errors::InvalidArgument( |
1515 | "No OpKernel was registered to support Op '" , node.type_string(), |
1516 | "' used by " , errors::FormatNodeNameForError(node.name()), |
1517 | " with these attrs: [" , node.attrs().DebugString(), |
1518 | "]\n" |
1519 | "Registered devices: [" , |
1520 | absl::StrJoin(registered_device_types, ", " ), "]\n" , |
1521 | "Registered kernels:\n" , KernelsRegisteredForOp(node.type_string())); |
1522 | } |
1523 | |
1524 | // If the NodeDef contains a device, then we interpret it as a |
1525 | // (partial) device specification. |
1526 | if (!node.requested_device().empty()) { |
1527 | if (IsRefOrResourceGeneratorNode(node)) { |
1528 | // Treat requested device on resource generating nodes as assigned |
1529 | // device so that we don't override it. |
1530 | TF_RETURN_IF_ERROR(member->SetResourceDeviceName(node)); |
1531 | } else { |
1532 | // The user has specified a device in the NodeDef, try to find a |
1533 | // valid device matching their specification in the set of |
1534 | // devices. |
1535 | // NOTE: The full name may specify a device that is not in |
1536 | // n.supported_device_types(), but we check that in AssignDevice(). |
1537 | TF_RETURN_IF_ERROR(member->SetRequestedDeviceName(node)); |
1538 | } |
1539 | } |
1540 | } |
1541 | return OkStatus(); |
1542 | } |
1543 | |
1544 | // Returns a list of devices having type in supported_device_types. The |
1545 | // returned list is sorted by preferred type (higher numeric type is preferred). |
1546 | /*static*/ std::vector<Device*> ColocationGraph::FilterSupportedDevices( |
1547 | const std::vector<Device*>& devices, |
1548 | const PrioritizedDeviceTypeVector& supported_device_types, |
1549 | const Device* default_local_device) { |
1550 | Device* filtered_default_device = nullptr; |
1551 | PrioritizedDeviceVector prioritized_filtered_devices; |
1552 | for (const auto& supported_device_type : supported_device_types) { |
1553 | for (Device* device : devices) { |
1554 | if (IsSupportedDeviceType(device->attributes(), |
1555 | supported_device_type.first)) { |
1556 | if (default_local_device && |
1557 | (device == default_local_device || |
1558 | // TODO(nareshmodi, fishx): At times the device pointer in the |
1559 | // device set is different to the one passed in as the default |
1560 | // device. Figure out why this might be. |
1561 | device->name() == default_local_device->name())) { |
1562 | filtered_default_device = device; |
1563 | } else { |
1564 | prioritized_filtered_devices.emplace_back( |
1565 | device, supported_device_type.second); |
1566 | } |
1567 | } |
1568 | } |
1569 | } |
1570 | DeviceSet::SortPrioritizedDeviceVector(&prioritized_filtered_devices); |
1571 | |
1572 | std::vector<Device*> filtered_devices; |
1573 | if (filtered_default_device != nullptr) { |
1574 | filtered_devices.emplace_back(filtered_default_device); |
1575 | } |
1576 | for (const auto& prioritized_filtered_device : prioritized_filtered_devices) { |
1577 | filtered_devices.push_back(prioritized_filtered_device.first); |
1578 | } |
1579 | return filtered_devices; |
1580 | } |
1581 | |
1582 | } // namespace tensorflow |
1583 | |