/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/core/common_runtime/colocation_graph.h"

#include <memory>
#include <set>
#include <unordered_map>
#include <unordered_set>
#include <utility>
#include <vector>

#include "absl/algorithm/container.h"
#include "absl/container/flat_hash_set.h"
#include "absl/strings/str_join.h"
#include "absl/types/optional.h"
#include "tensorflow/core/common_runtime/composite_device.h"
#include "tensorflow/core/common_runtime/device.h"
#include "tensorflow/core/common_runtime/device_set.h"
#include "tensorflow/core/common_runtime/input_colocation_exemption_registry.h"
#include "tensorflow/core/common_runtime/inspecting_placer.h"
#include "tensorflow/core/common_runtime/partitioning_utils.h"
#include "tensorflow/core/framework/attr_value.pb.h"
#include "tensorflow/core/framework/attr_value_util.h"
#include "tensorflow/core/framework/dataset.h"
#include "tensorflow/core/framework/device_attributes.pb.h"
#include "tensorflow/core/framework/full_type.pb.h"
#include "tensorflow/core/framework/full_type_util.h"
#include "tensorflow/core/framework/function.h"
#include "tensorflow/core/framework/node_def_util.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/graph/algorithm.h"
#include "tensorflow/core/graph/graph_node_util.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/stringpiece.h"
#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/util/device_name_utils.h"
#include "tensorflow/core/util/dump_graph.h"
#include "tensorflow/core/util/port.h"

namespace tensorflow {

namespace {

// We hoist the conversion from C-style string literal to StringPiece here,
// so that we can avoid the many repeated calls to strlen().
const StringPiece kColocationAttrNameStringPiece(kColocationAttrName);
const StringPiece kColocationGroupPrefixStringPiece(kColocationGroupPrefix);

// Using absl::StrJoin with lambda does not work in tf-lite builds.
std::vector<string> DevicesToString(const std::vector<Device*>& devices) {
  std::vector<string> v;
  v.reserve(devices.size());
  for (Device* d : devices) {
    v.push_back(d->name());
  }
  return v;
}

// Using absl::StrJoin with lambda does not work in tf-lite builds.
std::vector<string> DeviceTypeAndPriorityToString(
    const PrioritizedDeviceTypeVector& devices) {
  std::vector<string> v;
  v.reserve(devices.size());
  for (const std::pair<DeviceType, int32>& device_and_type : devices) {
    v.push_back(DeviceTypeString(device_and_type.first));
  }
  return v;
}

bool IsRefOrResource(DataType data_type) {
  return IsRefType(data_type) || data_type == DT_RESOURCE;
}
// While Placer can override the requested device on ops processing
// resources, i.e. nodes that take (and potentially return) a resource,
// it must not override the requested device on ops generating a resource,
// e.g. VarHandleOp, _Arg. Such ops are currently no-input, single resource/ref
// output nodes.
bool IsRefOrResourceGeneratorNode(const Node& node) {
  return node.num_inputs() == 0 && node.num_outputs() == 1 &&
         IsRefOrResource(node.output_type(0));
}
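
// A minimal sketch of the shapes this predicate distinguishes (hypothetical
// example nodes, not taken from any particular graph):
//
//   VarHandleOp:      0 inputs, 1 output of type DT_RESOURCE -> generator
//   AssignVariableOp: inputs (resource, value), no outputs   -> not a generator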

bool IsExemptFromResourceInputColocation(const Node* node) {
  // Note: Partitioned function calls, which place and partition their
  // function bodies, are exempt from this check: they forward resource and
  // ref inputs to operations that are appropriately placed, instead of
  // dereferencing them.
  const string& op_type = node->op_def().name();
  auto exempt_ops = InputColocationExemptionRegistry::Global()->Get();
  return exempt_ops.find(op_type) != exempt_ops.end();
}

bool HasPriorities(const PrioritizedDeviceTypeVector& device_types) {
  for (const auto& prioritized_device_type : device_types) {
    if (prioritized_device_type.second != 0) return true;
  }
  return false;
}

bool ArePrioritiesSame(const PrioritizedDeviceTypeVector& a_types,
                       const PrioritizedDeviceTypeVector& b_types) {
  if (a_types.size() != b_types.size()) {
    return false;
  }
  for (int i = 0; i < a_types.size(); ++i) {
    if (a_types[i].first != b_types[i].first) {
      return false;
    }
  }
  return true;
}

bool IsXlaDevice(absl::string_view device_type) {
  if (device_type == "XLA_CPU_JIT" || device_type == "XLA_GPU_JIT" ||
      device_type == "XLA_TPU_JIT") {
    // Symbolic XLA device.
    return true;
  }

  return (device_type == "XLA_CPU" || device_type == "XLA_GPU" ||
          device_type == "TPU");
}

bool IsCompositeDevice(absl::string_view device_type) {
  return device_type == kCompositeDeviceType;
}

// TODO(mdan): This is still too coarse.
// Host-memory constraints are specific to kernel registrations, so in theory
// they depend on the assigned device.
// So we need a constraint model of the kind: <<node device>>: <<output_device>>
bool HasHostMemoryOutType(const Node& node) {
  if (!node.def().has_experimental_type()) {
    return false;
  }
  const FullTypeDef& ft = node.def().experimental_type();
  DCHECK(ft.type_id() == TFT_PRODUCT) << ft.DebugString();

  for (const auto& arg : ft.args()) {
    if (full_type::IsHostMemoryType(arg)) {
      return true;
    }
  }

  return false;
}
}  // namespace

Status Member::SetParentAndSupportedDevices(
    const Node& node, const std::vector<DeviceType>& types,
    const DeviceNameUtils::ParsedName* local_address_spec) {
  int id = node.id();
  if (id < 0) {
    return errors::Internal("Placer should not be creating a Member for node: ",
                            node.DebugString());
  }
  parent_ = id;
  return SupportedDeviceTypesForNode(
      types, node.def(), &supported_device_types_, local_address_spec);
}

Status Member::SetAssignedDeviceName(const string& device_name) {
  if (DeviceNameUtils::HasSomeDetails(requested_device_name_)) {
    return errors::Internal(
        "Setting assigned device name when there is a requested device set "
        "is unsupported");
  }
  if (!DeviceNameUtils::ParseFullName(device_name, &assigned_device_name_)) {
    return errors::Internal("Malformed assigned device '", device_name, "'");
  }
  // Set requested device to assigned_device to maintain the invariant that
  // requested is a specialization of assigned.
  requested_device_name_ = assigned_device_name_;
  return OkStatus();
}

Status Member::SetResourceDeviceName(const Node& node) {
  if (DeviceNameUtils::HasSomeDetails(requested_device_name_)) {
    return errors::Internal(
        "Setting resource device name when there is a requested device set "
        "is unsupported");
  }

  if (!DeviceNameUtils::ParseFullName(node.requested_device(),
                                      &resource_device_name_)) {
    return errors::InvalidArgument("Malformed device specification '",
                                   node.requested_device(),
                                   "' in node: ", node.DebugString());
  }

  // Set requested device to resource device to maintain the invariant that
  // requested is a specialization of resource.
  requested_device_name_ = resource_device_name_;
  return OkStatus();
}

Status Member::SetRequestedDeviceName(const Node& node) {
  if (DeviceNameUtils::HasSomeDetails(assigned_device_name_)) {
    return errors::Internal(
        "Setting requested device name when there is an assigned device set "
        "is unsupported");
  }
  if (DeviceNameUtils::HasSomeDetails(resource_device_name_)) {
    return errors::Internal(
        "Setting requested device name when there is a resource device set "
        "is unsupported");
  }
  if (!DeviceNameUtils::ParseFullName(node.requested_device(),
                                      &requested_device_name_)) {
    return errors::InvalidArgument("Malformed device specification '",
                                   node.requested_device(),
                                   "' in node: ", node.DebugString());
  }
  return OkStatus();
}
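
// For reference, a fully specified name accepted by ParseFullName looks like
// this (hypothetical values):
//
//   "/job:worker/replica:0/task:1/device:GPU:0"
//     -> job=worker, replica=0, task=1, type=GPU, id=0
//
// A partial specification such as "/device:GPU:0" leaves job/replica/task
// unset, and HasSomeDetails() is true as soon as any field is set.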

Status Member::FillPossibleDevices(PossibleDevices* possible_device) const {
  if (DeviceNameUtils::HasSomeDetails(assigned_device_name_)) {
    return errors::Internal(
        "Cannot fill PossibleDevices from a member that has non-empty assigned "
        "device. Did we start assigning devices to functions called by deep "
        "ops? ",
        DebugString());
  }
  possible_device->requested_device_name = requested_device_name_;
  possible_device->resource_device_name = resource_device_name_;
  possible_device->device_types = supported_device_types_;
  return OkStatus();
}

bool Member::IsEdgeFromCompositeDeviceToPhysicalDevice(
    const Member& src_root) const {
  auto compatible_edge_from_composite_device_to_physical_device =
      [](const DeviceNameUtils::ParsedName& src_device,
         const DeviceNameUtils::ParsedName& dst_device) -> bool {
    return src_device.has_type && dst_device.has_type &&
           IsCompositeDevice(src_device.type) &&
           !IsCompositeDevice(dst_device.type);
  };
  if (compatible_edge_from_composite_device_to_physical_device(
          src_root.assigned_device_name_, assigned_device_name_) ||
      compatible_edge_from_composite_device_to_physical_device(
          src_root.resource_device_name_, resource_device_name_) ||
      compatible_edge_from_composite_device_to_physical_device(
          src_root.requested_device_name_, requested_device_name_)) {
    return true;
  }
  return false;
}

Status Member::EnsureCompatibilityAcrossResourceEdge(
    const Node& src, const Member& src_root,
    const Node& dst, /*dst_root is this*/
    bool log_device_placement) {
  if (!DeviceNameUtils::AreCompatibleDevNames(src_root.assigned_device_name_,
                                              assigned_device_name_)) {
    return errors::InvalidArgument(
        "Cannot place the graph because a reference or resource edge "
        "connects colocation groups with incompatible assigned devices: ",
        DeviceNameUtils::ParsedNameToString(src_root.assigned_device_name_),
        " vs ", DeviceNameUtils::ParsedNameToString(assigned_device_name_),
        ". The edge src node is name='", src.name(), "' (op='", src.def().op(),
        "'), and the dst node is name='", dst.name(), "' (op='", dst.def().op(),
        "').");
  }

  if (!DeviceNameUtils::AreCompatibleDevNames(src_root.resource_device_name_,
                                              resource_device_name_)) {
    return errors::InvalidArgument(
        "Cannot place the graph because a reference or resource edge "
        "connects colocation groups with incompatible resource devices: ",
        DeviceNameUtils::ParsedNameToString(src_root.resource_device_name_),
        " vs ", DeviceNameUtils::ParsedNameToString(resource_device_name_),
        ". The edge src node is name='", src.name(), "' (op='", src.def().op(),
        "'), and the dst node is name='", dst.name(), "' (op='", dst.def().op(),
        "').");
  }

  if (DeviceNameUtils::AreCompatibleDevNames(src_root.requested_device_name_,
                                             requested_device_name_)) {
    return OkStatus();
  }

  // If we are here, assigned and resource devices are compatible but requested
  // ones are not. We will be overriding the requested device for the
  // destination node, but we need to preserve the invariant that it remains a
  // specialization of the assigned and resource devices.
  if (log_device_placement) {
    LOG(INFO) << "Ignoring device specification "
              << DeviceNameUtils::ParsedNameToString(requested_device_name_)
              << " for node '" << dst.name()
              << "' because the input edge from '" << src.name()
              << "' is a reference connection and already has a device "
                 "field set to "
              << DeviceNameUtils::ParsedNameToString(
                     src_root.requested_device_name_);
  }
  requested_device_name_ = src_root.requested_device_name_;
  DeviceNameUtils::EnsureSpecification(&requested_device_name_,
                                       assigned_device_name_);
  DeviceNameUtils::EnsureSpecification(&requested_device_name_,
                                       resource_device_name_);
  return OkStatus();
}

void Member::Merge(std::vector<Member>* tree, int x_root, int y_root,
                   Member** new_root, Member** old_root, bool dry_run) {
  Member& x_root_member = (*tree)[x_root];
  Member& y_root_member = (*tree)[y_root];

  // Merge the sets by setting the parent pointer of the smaller tree's root
  // node to point to the root of the larger tree. Together with path
  // compression in ColocationGraph::FindRoot, this ensures that we do not
  // experience pathological performance on graphs such as chains.
  int new_root_id, old_root_id;
  if (x_root_member.rank_ < y_root_member.rank_) {
    // The tree rooted at x_root is shallower, so connect it to
    // y_root. The rank of y_root is unchanged because its new
    // child has strictly less rank.
    if (!dry_run) {
      x_root_member.parent_ = y_root;
    }
    new_root_id = y_root;
    old_root_id = x_root;
  } else if (x_root_member.rank_ > y_root_member.rank_) {
    // The tree rooted at y_root is shallower, so connect it to
    // x_root. The rank of x_root is unchanged because its new
    // child has strictly less rank.
    if (!dry_run) {
      y_root_member.parent_ = x_root;
    }
    new_root_id = x_root;
    old_root_id = y_root;
  } else {
    if (!dry_run) {
      // Both trees have the same rank, so break the tie by choosing
      // x_root as the new root.
      y_root_member.parent_ = x_root;
      // Increment the rank of the tree rooted at x_root, because it
      // is now strictly deeper than before.
      ++x_root_member.rank_;
    }
    new_root_id = x_root;
    old_root_id = y_root;
  }

  *new_root = &(*tree)[new_root_id];
  *old_root = &(*tree)[old_root_id];
}
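
// A usage sketch of the union-by-rank merge above (hypothetical indices):
//
//   Member* new_root;
//   Member* old_root;
//   Member::Merge(&members, /*x_root=*/3, /*y_root=*/7, &new_root, &old_root,
//                 /*dry_run=*/false);
//   // If members[3].rank_ < members[7].rank_, members[3] is re-parented under
//   // members[7], new_root points at members[7], and both ranks are
//   // unchanged. With dry_run=true only the out-parameters are set.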

// tree is non-const because we can change some `parent` pointers in some
// members for more efficient future lookups. The vector itself is not
// changed.
int Member::FindAndUpdateRoot(std::vector<Member>* tree, int node_id) {
  Member& member = (*tree)[node_id];
  if (member.parent_ == node_id) {
    // member.parent is the root of this disjoint tree. Do nothing.
  } else {
    member.parent_ = FindAndUpdateRoot(tree, member.parent_);
  }
  // Now it is guaranteed that member.parent is the root of this disjoint
  // tree.
  return member.parent_;
}

int Member::FindRoot(const std::vector<Member>& tree, int node_id) {
  const Member& member = tree[node_id];
  if (member.parent_ == node_id) {
    return member.parent_;
  }
  return FindRoot(tree, member.parent_);
}
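
// Sketch of the path compression performed by FindAndUpdateRoot (hypothetical
// parent chain 0 -> 1 -> 2, where 2 is its own parent and hence the root):
//
//   int root = Member::FindAndUpdateRoot(&members, 0);  // returns 2
//   // Side effect: members[0].parent_ == 2 and members[1].parent_ == 2, so
//   // later lookups are O(1). FindRoot answers the same query without
//   // mutating the tree, for use in const methods such as DebugString().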

Status Member::MergeDeviceNames(const Member& other,
                                bool allow_soft_placement) {
  // Assuming the "requested is a specialization of assigned and resource
  // devices" invariant holds for this and `other`, it will hold after the
  // merges below.
  DeviceNameUtils::ParsedName assigned_device_name_copy = assigned_device_name_;
  TF_RETURN_IF_ERROR(DeviceNameUtils::MergeDevNames(
      &assigned_device_name_copy, other.assigned_device_name_));

  DeviceNameUtils::ParsedName resource_device_name_copy = resource_device_name_;
  TF_RETURN_IF_ERROR(DeviceNameUtils::MergeDevNames(
      &resource_device_name_copy, other.resource_device_name_));

  DeviceNameUtils::ParsedName requested_device_name_copy =
      requested_device_name_;
  TF_RETURN_IF_ERROR(DeviceNameUtils::MergeDevNames(
      &requested_device_name_copy, other.requested_device_name_,
      allow_soft_placement));

  DeviceNameUtils::EnsureSpecification(&requested_device_name_copy,
                                       assigned_device_name_copy);
  DeviceNameUtils::EnsureSpecification(&requested_device_name_copy,
                                       resource_device_name_copy);

  // We checked for all errors, now change the devices.
  assigned_device_name_ = std::move(assigned_device_name_copy);
  resource_device_name_ = std::move(resource_device_name_copy);
  requested_device_name_ = std::move(requested_device_name_copy);
  return OkStatus();
}

// Updates this to contain the intersection of the device types in
// this and "other".
bool Member::MergeSupportedDevices(const Member& other) {
  return MergeSupportedDevices(other.supported_device_types_);
}

bool Member::MergeSupportedDevices(
    const PrioritizedDeviceTypeVector& other_devices) {
  // Generate intersection with priorities.
  // Each vector contains the same device types but with different priorities.
  // The priorities are taken from the corresponding source vector.
  PrioritizedDeviceTypeVector target_intersection;
  PrioritizedDeviceTypeVector other_intersection;

  for (const auto& prioritized_device_type : supported_device_types_) {
    bool found = false;
    for (const auto& other_prioritized_device_type : other_devices) {
      if (prioritized_device_type.first ==
          other_prioritized_device_type.first) {
        found = true;
        other_intersection.push_back(other_prioritized_device_type);
        break;
      }
    }
    if (found) {
      target_intersection.push_back(prioritized_device_type);
    }
  }

  DeviceSet::SortPrioritizedDeviceTypeVector(&target_intersection);
  DeviceSet::SortPrioritizedDeviceTypeVector(&other_intersection);

  PrioritizedDeviceTypeVector result;

  bool is_target_prioritized = HasPriorities(target_intersection);
  bool is_other_prioritized = HasPriorities(other_intersection);
  if (!is_target_prioritized && !is_other_prioritized) {
    // If neither is prioritized, then we just return the original (i.e.
    // target) prioritization.
    result = target_intersection;
  } else if (is_target_prioritized && !is_other_prioritized) {
    // If only one is prioritized, then we respect the priorities of that one
    // in the intersection.
    result = target_intersection;
  } else if (!is_target_prioritized && is_other_prioritized) {
    result = other_intersection;
  } else {
    // If both have priorities and agree, then we go with that. If the
    // prioritization order is different, then we just fall back to the
    // default, i.e. what the DeviceTypeOrder suggests. In that case, we also
    // set the merged priorities to 0, so that downstream merges work correctly
    // as well.
    if (ArePrioritiesSame(target_intersection, other_intersection)) {
      result = target_intersection;
    } else {
      for (const auto& prioritized_device : target_intersection) {
        result.push_back(std::make_pair(prioritized_device.first, 0));
      }
      DeviceSet::SortPrioritizedDeviceTypeVector(&result);
    }
  }

  if (result.empty()) {
    return false;
  }
  supported_device_types_ = result;
  return true;
}
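
// Sketch of the merge rules above (hypothetical priorities):
//
//   this:  [(GPU, 2), (CPU, 1)]  other: [(GPU, 0), (CPU, 0)]
//     -> only `this` is prioritized, so its ordering wins:
//        [(GPU, 2), (CPU, 1)]
//
//   this:  [(GPU, 2), (CPU, 1)]  other: [(CPU, 2), (GPU, 1)]
//     -> conflicting priorities; all priorities are reset to 0 and the result
//        is re-sorted, falling back to the DeviceTypeOrder-based default.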

Status Member::AssignDevice(const Node& node) {
  if (node.assigned_device_name_index() == assigned_device_name_index_) {
    return OkStatus();
  }

  DeviceNameUtils::ParsedName parsed;
  DeviceNameUtils::ParseFullName(node.assigned_device_name(), &parsed);
  Status s = DeviceNameUtils::MergeDevNames(&assigned_device_name_, parsed);
  if (!s.ok()) {
    return errors::Internal(
        "Constraining by assigned device should not cause an error. Original "
        "root's assigned device name: \"",
        DeviceNameUtils::ParsedNameToString(assigned_device_name_),
        "\", node's assigned device name \"", node.assigned_device_name(),
        "\". Error: ", s.error_message());
  }
  s = DeviceNameUtils::MergeOverrideDevNames(&resource_device_name_, parsed);
  if (!s.ok()) {
    return errors::Internal(
        "Constraining by assigned device should not cause an error. Original "
        "root's resource device name: \"",
        DeviceNameUtils::ParsedNameToString(resource_device_name_),
        "\", node's assigned device name \"", node.assigned_device_name(),
        "\". Error: ", s.error_message());
  }
  s = DeviceNameUtils::MergeOverrideDevNames(&requested_device_name_, parsed);
  if (!s.ok()) {
    return errors::Internal(
        "Constraining by assigned device should not cause an error. Original "
        "root's requested device name: \"",
        DeviceNameUtils::ParsedNameToString(requested_device_name_),
        "\", node's assigned device name \"", node.assigned_device_name(),
        "\". Error: ", s.error_message());
  }

  assigned_device_name_index_ = node.assigned_device_name_index();
  // Clear cached possible_devices, if any.
  possible_devices_.clear();
  return OkStatus();
}

void Member::MaybeExcludeXlaDevices() {
  for (const auto& parsed_name :
       {requested_device_name_, assigned_device_name_, resource_device_name_}) {
    // Don't exclude XLA devices from supported devices if the member is
    // explicitly assigned to a CompositeDevice.
    if (parsed_name.has_type && (IsXlaDevice(parsed_name.type) ||
                                 IsCompositeDevice(parsed_name.type))) {
      return;
    }
  }

  PrioritizedDeviceTypeVector non_xla_types;
  absl::c_copy_if(supported_device_types_, std::back_inserter(non_xla_types),
                  [&](const std::pair<DeviceType, int32>& entry) {
                    return !IsXlaDevice(entry.first.type_string());
                  });

  // TODO(b/141216278) Remove all XLA device types from the supported device
  // types if the node has no requested/assigned/resource XLA device.
  if (!non_xla_types.empty() &&
      non_xla_types.size() < supported_device_types_.size()) {
    supported_device_types_ = std::move(non_xla_types);
  }
}

Status Member::LimitToPossibleDevices(const PossibleDevices& devices,
                                      bool allow_soft_placement) {
  TF_RETURN_IF_ERROR(DeviceNameUtils::MergeDevNames(
      &requested_device_name_, devices.requested_device_name,
      allow_soft_placement));
  TF_RETURN_IF_ERROR(DeviceNameUtils::MergeDevNames(
      &resource_device_name_, devices.resource_device_name));
  MergeSupportedDevices(devices.device_types);
  return OkStatus();
}

string Member::DebugString() const {
  return absl::StrCat(
      "Member(assigned_device_name_index_=", assigned_device_name_index_,
      " requested_device_name_='",
      DeviceNameUtils::ParsedNameToString(requested_device_name_),
      "' assigned_device_name_='",
      DeviceNameUtils::ParsedNameToString(assigned_device_name_),
      "' resource_device_name_='",
      DeviceNameUtils::ParsedNameToString(resource_device_name_),
      "' supported_device_types_=[",
      absl::StrJoin(DeviceTypeAndPriorityToString(supported_device_types_),
                    ", "),
      "] possible_devices_=[",
      absl::StrJoin(DevicesToString(possible_devices_), ", "), "]");
}

DeviceNameUtils::ParsedName Member::GetSoftDeviceName() const {
  DeviceNameUtils::ParsedName soft_device_name = requested_device_name_;
  if (!assigned_device_name_.has_type) {
    soft_device_name.type.clear();
    soft_device_name.has_type = false;
  }
  if (!assigned_device_name_.has_id) {
    soft_device_name.has_id = false;
  }
  return soft_device_name;
}

DeviceNameUtils::ParsedName Member::GetPreferredSoftDeviceName() const {
  DeviceNameUtils::ParsedName soft_device_name = requested_device_name_;
  if (!assigned_device_name_.has_type && !resource_device_name_.has_type) {
    soft_device_name.type.clear();
    soft_device_name.has_type = false;
  }
  if (!assigned_device_name_.has_id && !resource_device_name_.has_id) {
    soft_device_name.has_id = false;
  }
  return soft_device_name;
}
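
// Sketch (hypothetical names): with requested name "/job:worker/device:GPU:0"
// and no assigned or resource type/id constraints, the preferred soft name
// drops the type and id, yielding "/job:worker", so soft placement may
// consider any device in that job.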

// Returns ParsedName whose address space (i.e. job, replica, task) identifies
// the address space directly accessible by the local process. If the address
// space is fully specified and it is exactly the same as the address space
// of a device, then all kernels of that device should be registered in the
// local process.
static const DeviceNameUtils::ParsedName LocalAddressSpec(
    const Device* client_device, const Device* default_local_device) {
  if (client_device != nullptr) {
    return DeviceNameUtils::AddressSpace(client_device->parsed_name());
  }

  if (default_local_device != nullptr) {
    return DeviceNameUtils::AddressSpace(default_local_device->parsed_name());
  }

  // TODO(b/139617593) Return the name of the first local device in device_set_
  // once we can trust the output of Device::IsLocal().
  return DeviceNameUtils::ParsedName();
}

ColocationGraph::ColocationGraph(const Graph* graph, const FunctionStack& stack,
                                 const FunctionLibraryDefinition* flib_def,
                                 const DeviceSet* device_set,
                                 const Device* default_local_device,
                                 bool allow_soft_placement,
                                 bool log_device_placement)
    : graph_(*graph),
      stack_(stack),
      inspecting_placer_(stack, flib_def, device_set, default_local_device,
                         allow_soft_placement, log_device_placement),
      inspection_required_checker_(graph, flib_def),
      device_set_(*device_set),
      device_types_(device_set->PrioritizedDeviceTypeList()),
      local_address_spec_(
          LocalAddressSpec(device_set->client_device(), default_local_device)),
      default_local_device_(default_local_device),
      allow_soft_placement_(allow_soft_placement),
      log_device_placement_(log_device_placement) {
  members_.resize(graph_.num_node_ids());
}

// Adds each node of the Graph to this ColocationGraph as a singleton.
//
// NOTE: The implementation assumes that the ids of nodes passed to
// this method are dense and zero-based; the memory used will be linear in
// the largest node ID.
// NOTE: If this method returns an error, *this is left in an undefined
// state.
Status ColocationGraph::ColocateAllNodes() {
  // This maps from a colocation group identifier to the 'root' of that
  // colocation group. Note that the keys in this map are StringPiece; the
  // actual strings are stored under the NodeDef. The lifetime of this map
  // is limited to this ColocateAllNodes() method, and no part of the
  // NodeDef trees are changed during the lifetime of this method, so using
  // StringPiece as a key is safe.
  //
  // Also, as a further optimization, we remove the "loc:@" prefix from
  // "class" attribute values, when they are used as keys in this table.
  // This allows us to use StringPiece values that refer to substrings of
  // 'string' values stored in NodeDef attribute lists, as well as StringPiece
  // values that refer to 'string' values from NodeDef::name(), without
  // performing any string allocations.
  std::unordered_map<StringPiece, const Node*, StringPieceHasher>
      colocation_group_root;

  for (const Node* node : graph_.op_nodes()) {
    // When adding the node, identify whether it is part of a colocation
    // group.

    // This code is effectively the equivalent of GetNodeAttr() for a string
    // array, but it avoids all internal allocations (the allocation of the
    // backing store of the std::vector<string> as well as the copies of the
    // strings within it). Instead, we combine the query of the colocation
    // attribute with the calls to ColocateNodeToGroup.
    const AttrValue* attr_value =
        node->attrs().Find(kColocationAttrNameStringPiece);
    if (attr_value != nullptr) {
      if (attr_value->has_list()) {
        for (const string& class_spec : attr_value->list().s()) {
          StringPiece spec(class_spec);
          if (absl::ConsumePrefix(&spec, kColocationGroupPrefixStringPiece)) {
            TF_RETURN_IF_ERROR(
                ColocateNodeToGroup(&colocation_group_root, node, spec));
          }
        }
      } else if (!attr_value->s().empty()) {
        LOG(ERROR) << "The value for colocation attribute '_class' must be a "
                      "list of strings, not a single string: "
                   << node->DebugString();
      }
    }

    // Each node belongs to a colocation group with the node's name.
    TF_RETURN_IF_ERROR(
        ColocateNodeToGroup(&colocation_group_root, node, node->name()));
  }

  return OkStatus();
}
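
// For reference, a colocation constraint in a NodeDef looks like this
// (hypothetical nodes "x" and "y"):
//
//   node { name: "y" op: "Identity"
//          attr { key: "_class" value { list { s: "loc:@x" } } } }
//
// The "loc:@" prefix is consumed above, so "y" joins the group keyed by "x";
// independently, every node joins a singleton group keyed by its own name.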

Status ColocationGraph::ColocateResourceOrRefEdge(const Node* src,
                                                  const Node* dst) {
  // Colocate `src` and `dst` to maintain the invariant that nodes
  // connected by reference edges are colocated.
  int src_root_id = FindAndUpdateRoot(src->id());
  int dst_root_id = FindAndUpdateRoot(dst->id());
  auto& src_root = members_[src_root_id];
  auto& dst_root = members_[dst_root_id];

  if (dst_root.IsEdgeFromCompositeDeviceToPhysicalDevice(src_root)) {
    // If the src root is assigned to a composite device and the dst root is
    // assigned to a physical device, don't colocate the dst root with the src
    // root.
    return OkStatus();
  }
  TF_RETURN_IF_ERROR(dst_root.EnsureCompatibilityAcrossResourceEdge(
      *src, src_root, *dst, log_device_placement_));
  Status status = ColocateNodes(*src, src_root_id, *dst, dst_root_id);
  if (!status.ok()) {
    return AttachDef(
        errors::InvalidArgument(
            "Nodes were connected by a reference or resource connection "
            "(requiring them to be on the same device), but the two nodes "
            "were assigned two different devices: ",
            status.error_message()),
        *dst);
  }
  return OkStatus();
}

Status ColocationGraph::ColocateResourceAndRefEdges(
    std::unordered_set<Node*>* inspection_required) {
  // If `node` has an input edge with reference type, add an edge from the
  // source of that edge to `node`.
  for (const Edge* edge : graph_.edges()) {
    if (edge->IsControlEdge()) {
      continue;
    }
    Node* src = edge->src();
    Node* dst = edge->dst();
    bool needs_inspection;
    TF_RETURN_IF_ERROR(inspection_required_checker_.IsPlacerInspectionRequired(
        *src, &needs_inspection));
    if (needs_inspection) {
      inspection_required->insert(src);
      continue;
    }
    TF_RETURN_IF_ERROR(inspection_required_checker_.IsPlacerInspectionRequired(
        *dst, &needs_inspection));
    if (needs_inspection) {
      inspection_required->insert(dst);
      continue;
    }

    DataType input_type = dst->input_type(edge->dst_input());

    // Colocate two DatasetOp nodes connected by edge of dtype=DT_VARIANT.
    // This is needed to get around the issue in b/135705778.
    if (input_type == DT_VARIANT &&
        data::DatasetOpKernel::IsDatasetOp(src->op_def()) &&
        data::DatasetOpKernel::IsDatasetOp(dst->op_def())) {
      TF_RETURN_IF_ERROR(ColocateResourceOrRefEdge(src, dst));
      continue;
    }

    // Even though we can look inside function-calling ops, we make an
    // exception here, mostly for performance reasons: looking inside adds
    // overhead and is only necessary when these ops return resources. When
    // they don't, we skip the inspection, although looking inside could
    // potentially enable better placement decisions and might be worth doing
    // at some point.
    if ((input_type == DT_RESOURCE || IsRefType(input_type)) &&
        !IsExemptFromResourceInputColocation(dst)) {
      TF_RETURN_IF_ERROR(ColocateResourceOrRefEdge(src, dst));
    }
  }

  return OkStatus();
}

namespace {
// Returns the tensor list element data type, if the node is one of the ops
// that operate on TensorLists. Otherwise returns DT_INVALID.
// TODO(b/199443424): Don't use op names, use FullType here.
DataType GetElementDataType(const Node& node) {
  static absl::flat_hash_set<std::string>* tensor_list_ops =
      new absl::flat_hash_set<std::string>(
          {"TensorListReserve", "TensorListFromTensor", "EmptyTensorList",
           "TensorListSplit", "TensorListScatter", "TensorListScatterV2",
           "TensorListScatterIntoExistingList", "TensorListPushBack",
           "TensorListPushBackBatch", "TensorListPopBack", "TensorListStack",
           "TensorListConcat", "TensorListConcatV2", "TensorListGetItem",
           "TensorListSetItem", "TensorListGather", "TensorListConcatLists"});

  if (tensor_list_ops->contains(node.type_string())) {
    DataType element_type;
    if (GetNodeAttr(node.attrs(), "element_dtype", &element_type).ok()) {
      return element_type;
    }
  }

  return DT_INVALID;
}
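
// Sketch (hypothetical node): a TensorListReserve with
// attr { key: "element_dtype" value { type: DT_STRING } } yields DT_STRING
// here, while an op outside the list above yields DT_INVALID.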
}  // namespace

Status ColocationGraph::AddHostOnlyDataTypesConstraints() {
  auto is_variant = [](DataType dtype) -> bool { return dtype == DT_VARIANT; };

  auto is_cpu_device = [](const std::pair<DeviceType, int32>& entry) -> bool {
    return entry.first == DEVICE_CPU;
  };

  for (Node* node : graph_.nodes()) {
    // Skip nodes that do not have DT_VARIANT inputs.
    if (absl::c_none_of(node->input_types(), is_variant)) {
      continue;
    }

    // Skip nodes that can't be placed on GPU anyway.
    Member& root = members_[FindAndUpdateRoot(node->id())];
    if (absl::c_all_of(root.supported_device_types(), is_cpu_device)) {
      continue;
    }

    absl::optional<bool> constrain_to_host;

    // This is a list of special nodes that we know to have no HostMemory
    // inputs, so if they receive a host-only data type, they must necessarily
    // be constrained to the host.
    // This is brittle. In general, this should be handled by accounting for
    // HostMemory as a constraint when the node's device is known, not ahead of
    // time.
    // A less ideal, but still better alternative is to look for ops which
    // have no HostMemory kernels for the corresponding input. Unfortunately,
    // determining that is challenging because we lack a map from input names
    // to node input indices.
    // TODO(mdan): Fix this.
    if (node->IsRetval() || node->IsIdentity() || node->IsControlFlow() ||
        node->IsFunctionCall()) {
      for (const auto& edge : node->in_edges()) {
        if (HasHostMemoryOutType(*edge->src())) {
          // Skip nodes in colocation groups that already have a device
          // assignment
          if (root.has_assigned_device_name()) {
            VLOG(4) << "Special node has host-only data type input "
                    << "but is in a colocation group that already has a device "
                    << "assignment, so NOT adding constraint:\n"
                    << node->def().DebugString() << "\nedge:\n"
                    << edge->DebugString();
            break;
          } else {
            VLOG(4) << "Special node has host-only data type input, "
                    << "adding constraint:\n"
                    << node->def().DebugString() << "\nedge:\n"
                    << edge->DebugString();
            constrain_to_host = true;
            break;
          }
        }
      }
    }

    if (!constrain_to_host.has_value()) {
      // Legacy slow path. This covers legacy data types and ops which have not
      // been upgraded to FullType.
      auto edge_filter = [&](const Edge& edge) -> bool {
        // We already found the underlying data type.
        if (constrain_to_host.has_value()) return false;

        // Otherwise follow only DT_VARIANT data edges.
        auto edge_dtype = [&]() -> DataType {
          return edge.src()->output_type(edge.src_output());
        };
        return !edge.IsControlEdge() && edge_dtype() == DT_VARIANT;
      };

      auto enter = [&](Node* n) -> void {
        DataType element_type = GetElementDataType(*n);
        // To handle nested lists continue traversal after finding a TensorList
        // operation that uses DT_VARIANT for element type.
        if (element_type == DT_INVALID || element_type == DT_VARIANT) {
          return;
        }
        constrain_to_host = DataTypeAlwaysOnHost(element_type);
      };

      ReverseDFSFrom(graph_, {node}, enter, /*leave=*/nullptr,
                     /*stable_comparator=*/nullptr, edge_filter);
    }

    if (constrain_to_host.has_value() && *constrain_to_host) {
      VLOG(2) << "Constraining node " << node->name()
              << " to CPU: it has an input with host-only "
                 "underlying data type.";

      // Restrict possible device types to CPU only.
      PossibleDevices possible_devices;
      absl::c_copy_if(root.supported_device_types(),
                      std::back_inserter(possible_devices.device_types),
                      is_cpu_device);

      TF_RETURN_IF_ERROR(root.LimitToPossibleDevices(
          possible_devices, /*allow_soft_placement=*/false));
    }
  }

  return OkStatus();
}
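
// Sketch of the legacy traversal above (hypothetical two-node chain):
//
//   EmptyTensorList(element_dtype=DT_STRING) --DT_VARIANT--> Identity
//
// Starting from the Identity node, the reverse DFS follows the DT_VARIANT
// edge back to EmptyTensorList, reads element_dtype=DT_STRING, and since
// DataTypeAlwaysOnHost(DT_STRING) is true, the Identity node's colocation
// group is restricted to CPU.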

Status ColocationGraph::AddInspectionConstraints(
    const std::unordered_set<Node*>& inspection_required) {
  for (Node* node : inspection_required) {
    IOColocationGroups groups;
    TF_RETURN_IF_ERROR(
        inspecting_placer_.ComputeIOColocationGroups(*node, &groups));
    VLOG(2) << "Computed IOColocationGroups for node " << node->name()
            << ":\n\t" << groups.DebugString();
    TF_RETURN_IF_ERROR(ApplyIOColocationGroups(groups, *node));
  }
  return OkStatus();
}

Status ColocationGraph::Initialize() {
  TF_RETURN_IF_ERROR(InitializeMembers());

  std::unordered_set<Node*> inspection_required;
  TF_RETURN_IF_ERROR(ColocateResourceAndRefEdges(&inspection_required));
  TF_RETURN_IF_ERROR(AddInspectionConstraints(inspection_required));
  TF_RETURN_IF_ERROR(ColocateAllNodes());
  TF_RETURN_IF_ERROR(AddHostOnlyDataTypesConstraints());

  for (Node* node : graph_.op_nodes()) {
    int root_id = FindAndUpdateRoot(node->id());
    members_[root_id].MaybeExcludeXlaDevices();
  }

  return OkStatus();
}
// Pair containing a node and whether this node has a resource input
// from the node requiring placer inspection.
using NodeAndBool = std::pair<const Node*, bool>;

namespace {

// Returns a vector of node names from `nodes`.
std::vector<string> NodeAndBoolToString(const std::vector<NodeAndBool>& nodes) {
  std::vector<string> v;
  v.reserve(nodes.size());
  for (const NodeAndBool& node_and_bool : nodes) {
    v.push_back(node_and_bool.first->name());
  }
  return v;
}

// Given a node requiring placer inspection and its IOColocationGroups,
// computes `group_nodes`.
// group_nodes[i] contains the nodes that are members of colocation
// group i. These nodes are inputs or outputs of `node`.
// group_nodes[i][j] is a pair containing a node and whether this node
// has a resource input from `node`.
// Note:
// The same node can be added multiple times to the same group.
// The same node can be added to multiple groups.
Status GetGroupNodes(const IOColocationGroups& groups, const Node& node,
                     std::vector<std::vector<NodeAndBool>>* group_nodes) {
  group_nodes->reserve(groups.group_devices.size());
  for (int arg_idx = 0; arg_idx < groups.input_groups.size(); ++arg_idx) {
    const Node* src;
    TF_RETURN_IF_ERROR(node.input_node(arg_idx, &src));
    int group_id = groups.input_groups[arg_idx];
    (*group_nodes)[group_id].emplace_back(src, false);
  }

  for (const Edge* edge : node.out_edges()) {
    if (edge->IsControlEdge()) {
      continue;
    }

    int group_id = groups.output_groups[edge->src_output()];
    (*group_nodes)[group_id].emplace_back(
        edge->dst(), edge->dst()->input_type(edge->dst_input()) == DT_RESOURCE);
  }

  if (VLOG_IS_ON(2)) {
    VLOG(2) << "Colocated inputs/outputs of node: " << node.DebugString();
    for (const std::vector<NodeAndBool>& nodes : *group_nodes) {
      VLOG(2) << "\t[" << absl::StrJoin(NodeAndBoolToString(nodes), "\t\n")
              << "]";
    }
  }
  return OkStatus();
}

// Returns whether the device_type in `device_attributes` is supported.
bool IsSupportedDeviceType(const DeviceAttributes& device_attributes,
                           const DeviceType& supported_type) {
  if (DeviceType(device_attributes.device_type()) == supported_type) {
    return true;
  }
  return IsCompositeDevice(device_attributes.device_type());
}

}  // namespace

Status ColocationGraph::ApplyIOColocationGroups(
    const IOColocationGroups& groups, const Node& node) {
  if (groups.input_groups.size() != node.num_inputs()) {
    return errors::Internal(
        "Cannot apply input/output device constraints to node ",
        node.DebugString(), " because input_groups.size() (",
        groups.input_groups.size(),
        ") is different from the number of inputs of the op node (",
        node.num_inputs(), ")");
  }
  if (groups.output_groups.size() != node.num_outputs()) {
    return errors::Internal(
        "Cannot apply input/output device constraints to node ",
        node.DebugString(), " because output_groups.size() (",
        groups.output_groups.size(),
        ") is different from the number of outputs of the op node (",
        node.num_outputs(), ")");
  }

  // group_nodes[i] contains the nodes that are members of colocation
  // group i. These nodes are inputs or outputs of `node`.
  // group_nodes[i][j] is a pair containing the node and whether this node
  // has a resource input from `node`.
  // The same node can be added multiple times to the same group.
  // The same node can be added to multiple groups.
  // NOTE: group ids are guaranteed to be [0, 1, ..., num_groups - 1].
  std::vector<std::vector<NodeAndBool>> group_nodes(
      groups.group_devices.size());
  TF_RETURN_IF_ERROR(GetGroupNodes(groups, node, &group_nodes));

  // Colocate nodes in each group
  for (const std::vector<NodeAndBool>& nodes : group_nodes) {
    for (int i = 1; i < nodes.size(); ++i) {
      VLOG(2) << "Colocating \"" << nodes[0].first->name() << "\" and \""
              << nodes[i].first->name() << "\"";
      if (nodes[i].second) {
        TF_RETURN_IF_ERROR(
            ColocateResourceOrRefEdge(nodes[0].first, nodes[i].first));
      } else {
        TF_RETURN_IF_ERROR(ColocateNodes(*nodes[0].first, *nodes[i].first));
      }
    }
  }

  // Limit devices in each group
  for (int group_id = 0; group_id < groups.group_devices.size(); ++group_id) {
    // Nothing to do for empty groups. Groups can be empty if some output
    // of an op is not used.
    if (group_nodes[group_id].empty()) {
      continue;
    }
    const Node* group_node = group_nodes[group_id][0].first;
    const PossibleDevices& possible_devices = groups.group_devices[group_id];
    TF_RETURN_IF_ERROR(LimitToPossibleDevices(*group_node, possible_devices));
  }

  return OkStatus();
}

Status ColocationGraph::ColocateNodeToGroup(
    std::unordered_map<StringPiece, const Node*, StringPieceHasher>*
        colocation_group_root,
    const Node* node, StringPiece colocation_group) {
  const Node*& root_node = (*colocation_group_root)[colocation_group];
  if (root_node == nullptr) {
    // This is the first node of the colocation group, so
    // designate this node as the 'root' of that colocation group.
    root_node = node;
  } else {
    // Try to colocate the node with the root. If there is an
    // error, return it.
    Status s = ColocateNodes(*node, *root_node);
    if (!s.ok()) {
      if (!allow_soft_placement_) {
        return AttachDef(s, *node);
      }
      if (log_device_placement_) {
        LOG(INFO) << "Ignoring request to colocate node '" << node->name()
                  << "' with nodes in colocation group '" << colocation_group
                  << "' because soft placement is on and an attempt at doing "
                     "so resulted in the following error: "
                  << AttachDef(s, *node).ToString();
      }
    }
  }
  return OkStatus();
}

// Merge the (possibly disjoint) sets containing nodes "x" and
// "y". Returns OK if all the nodes in the union of these sets can
// be placed on the same device type.
//
// NOTE: If this method returns an error, *this is left in an undefined
// state.
Status ColocationGraph::ColocateNodes(const Node& x, const Node& y) {
  int x_root = FindAndUpdateRoot(x.id());
  int y_root = FindAndUpdateRoot(y.id());
  return ColocateNodes(x, x_root, y, y_root);
}

// This overload of ColocateNodes() allows a caller to provide the root node
// ids for the two nodes. For large graphs, this noticeably reduces the
// graph load time.
Status ColocationGraph::ColocateNodes(const Node& x, int x_root, const Node& y,
                                      int y_root) {
  if (x_root == y_root) {
    return OkStatus();
  }

  Member* new_root_member;
  Member* old_root_member;
  Member::Merge(&members_, x_root, y_root, &new_root_member, &old_root_member,
                /*dry_run=*/true);

  // Merge the partial device specifications, and ensure that they are
  // compatible. NULL options_ is treated as allowing soft placement.
  // If there is an error, nothing is modified.
  // TODO(mrry): Consider enriching the error message by pointing
  // out which nodes have the explicit partial device
  // specifications that caused this conflict.
  Status s = new_root_member->MergeDeviceNames(*old_root_member,
                                               allow_soft_placement_);
  if (!s.ok()) {
    return errors::InvalidArgument(
        "Cannot colocate nodes ",
        errors::FormatColocationNodeForError(x.name()), " and ",
        errors::FormatColocationNodeForError(y.name()), ": ",
        s.error_message());
  }

  // Ensure that the common root has at least one supported device
  // type, by computing the intersection of
  // new_root_member.supported_device_types and
  // old_root_member.supported_device_types.
  if (!new_root_member->MergeSupportedDevices(*old_root_member)) {
    return errors::InvalidArgument(
        "Cannot colocate nodes ",
        errors::FormatColocationNodeForError(x.name()), " and ",
        errors::FormatColocationNodeForError(y.name()),
        " because no device type supports both of those nodes and the "
        "other nodes colocated with them.",
        DebugInfo(x_root), DebugInfo(y_root));
  }

  // All error checks are done, merge the colocation graphs.
  Member::Merge(&members_, x_root, y_root, &new_root_member, &old_root_member,
                /*dry_run=*/false);
  return OkStatus();
}
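
// Usage sketch (hypothetical nodes x and y already tracked in members_):
//
//   Status s = colocation_graph.ColocateNodes(x, y);
//   // On failure, members_ is unchanged: the first Merge above runs with
//   // dry_run=true, and the re-parenting Merge runs only after the
//   // device-name and supported-device merges both succeed.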

Status ColocationGraph::LimitToAssignedDevice(const Node& node) {
  if (node.assigned_device_name_index() < 0) {
    return errors::Internal(
        "Expected an assigned node as argument to LimitToAssignedDevice but "
        "got: ",
        node.DebugString());
  }
  int root = FindAndUpdateRoot(node.id());
  Member& root_member = members_[root];
  return root_member.AssignDevice(node);
}

void ColocationGraph::GetSoftDeviceCandidates(
    const Node& node, const Member& root_member, int root_id,
    std::vector<Device*>* possible_devices) {
  // Try to find supported devices that don't violate resource devices.
  // The soft_device_name is the same as the requested device name
  // without specifying the device type or ID (if the assigned and requested
  // devices do not specify them).
  DeviceNameUtils::ParsedName soft_device_name =
      root_member.GetPreferredSoftDeviceName();
  device_set_.FindMatchingDevices(soft_device_name, possible_devices);
  if (!possible_devices->empty()) {
    *possible_devices = FilterSupportedDevices(
        *possible_devices, root_member.supported_device_types(),
        default_local_device_);
  }

  if (!possible_devices->empty()) {
    return;
  }

  // TODO(iga): Disallow changing resource devices when this ColocationGraph
  // is for:
  // - a function called by an op requiring deep inspection, or
  // - a graph containing ops requiring inspection.
  // It is fairly tricky to make changing resource devices in the presence of
  // ops requiring inspection work correctly. One thing it would require is to
  // communicate these "resource movement" decisions across Placer instances.

  // Failed to find supported devices that don't violate resource devices.
  // Try finding some devices that violate resource devices.
  // If we succeed, we will log a warning below.
  soft_device_name = root_member.GetSoftDeviceName();
  device_set_.FindMatchingDevices(soft_device_name, possible_devices);
  if (!possible_devices->empty()) {
    *possible_devices = FilterSupportedDevices(
        *possible_devices, root_member.supported_device_types(),
        default_local_device_);
  }

  if (!possible_devices->empty()) {
    LOG(WARNING)
        << "Failed to place the graph without changing the devices of some "
           "resources. Some of the operations (that had to be colocated with "
           "resource generating operations) are not supported on the "
           "resources' devices. Current candidate devices are [\n  "
        << absl::StrJoin(DevicesToString(*possible_devices), "\n  ")
        << "].\nSee below for details of this colocation group:"
        << DebugInfo(root_id);
  }
}

Status ColocationGraph::LimitToPossibleDevices(const Node& node,
                                               const PossibleDevices& devices) {
  int root = FindAndUpdateRoot(node.id());
  Member& root_member = members_[root];
  return root_member.LimitToPossibleDevices(devices, allow_soft_placement_);
}

Status ColocationGraph::GetDevicesForNode(
    Node* node, const std::vector<Device*>** possible_devices) {
  *possible_devices = nullptr;
  const int node_root = FindAndUpdateRoot(node->id());
  if (!members_[node_root].possible_devices().empty()) {
    *possible_devices = &members_[node_root].possible_devices();
    return OkStatus();
  }

  Member& root_member = members_[node_root];

  // We have not yet computed the possible devices for the
  // colocated node set containing 'node', so we do so now using the
  // constraints on the root node.

  // "devices" will contain the set of feasible placements for the
  // colocated node set containing 'node'.
  // NOTE: Basing possible device computation on requested device name
  // is guaranteed to respect the assigned and resource device names because
  // requested device is always a specialization of both.
  std::vector<Device*> devices;
  if (DeviceNameUtils::HasSomeDetails(root_member.requested_device_name())) {
    // The root node has a (possibly partial) device
    // specification, so enumerate the physical devices that
    // conform to it.
    device_set_.FindMatchingDevices(root_member.requested_device_name(),
                                    &devices);

    if (!devices.empty()) {
      // Filter devices into those that are compatible with the root
      // node (and its children).
      devices = FilterSupportedDevices(
          devices, root_member.supported_device_types(), default_local_device_);
    }

    // Perform soft placement if allow_soft_placement_ is set.
    if (devices.empty() && allow_soft_placement_) {
      GetSoftDeviceCandidates(*node, root_member, node_root, &devices);
    }

    if (devices.empty()) {
      // Return an error when a physical device that matches an explicit
      // device specification is not found. This ensures that we don't
      // assign a node to GPU when the user wanted to force it on CPU.
      string debug_info = DebugInfo(node_root);

      DeviceNameUtils::ParsedName specified_device_name;
      if (DeviceNameUtils::ParseFullName(node->requested_device(),
                                         &specified_device_name) &&
          specified_device_name == root_member.requested_device_name()) {
        // The specified device and merged set device match, and
        // will appear in the GraphDef (for debugging), so just
        // print the specified device.
        std::vector<Device*> devices_matching_nodedef;
        device_set_.FindMatchingDevices(specified_device_name,
                                        &devices_matching_nodedef);
        if (devices_matching_nodedef.empty()) {
          // Sometimes it is almost impossible to understand the problem
          // without a list of available devices.
          std::vector<string> device_names;
          for (const Device* device : device_set_.devices()) {
            device_names.push_back(device->name());
          }
          std::sort(device_names.begin(), device_names.end());

          string gpu_msg = "";
          if (!IsGoogleCudaEnabled() &&
              absl::AsciiStrToLower(specified_device_name.type) == "gpu") {
            gpu_msg =
                " The requested device appears to be a GPU, but CUDA is not "
                "enabled.";
          }

          return errors::InvalidArgument(
              errors::FormatNodeNameForError(node->name()),
              " was explicitly assigned to ", node->requested_device(),
              " but available devices are [ ",
              absl::StrJoin(device_names, ", "), " ]. Make sure ",
              "the device specification refers to a valid device.", gpu_msg);
        } else if (specified_device_name.has_type) {
          return errors::InvalidArgument(
              "Could not satisfy explicit device specification '",
              node->requested_device(), "' because no supported kernel for ",
              specified_device_name.type, " devices is available.", debug_info,
              "\nOp: ", node->type_string(),
              "\nNode attrs: ", node->attrs().DebugString(),
              "\nRegistered kernels:\n",
              KernelsRegisteredForOp(node->type_string()));
        } else {
          return errors::InvalidArgument(
              "Could not satisfy explicit device specification '",
              node->requested_device(), "'.", debug_info);
        }
      } else {
        // The specified device may be a valid device but the
        // merged set device is different, so print both.
        // TODO(b/129057603): There are many possibilities at this point.
        // Provide good error messages.
        return errors::InvalidArgument(
            "Could not satisfy explicit device specification '",
            node->requested_device(), "' because the node ",
            errors::FormatColocationNodeForError(node->name()),
            " was colocated with a group of nodes that ",
            "required incompatible device '",
            DeviceNameUtils::ParsedNameToString(
                root_member.requested_device_name()),
            "'. All available devices [",
            absl::StrJoin(DevicesToString(device_set_.devices()), ", "), "]. ",
            debug_info);
      }
    }
  } else {
    // The device is completely unspecified, so enumerate the devices that
    // support all of the nodes in the set.
    if (device_set_.devices().empty()) {
      return errors::Internal("No devices are registered");
    }
    devices = FilterSupportedDevices(device_set_.devices(),
                                     root_member.supported_device_types(),
                                     default_local_device_);

    if (devices.empty()) {
      return errors::InvalidArgument(
          "Node had no OpKernel registered to support this operation: ",
          "Operation was ", node->type_string(), " and inputs were [",
          DataTypeVectorString(node->input_types()), "].\n",
          DebugInfo(node_root));
    }
  }

  // Cache the result of the possible devices for this node group.
  root_member.set_possible_devices(std::move(devices));
  *possible_devices = &root_member.possible_devices();
  return OkStatus();
}

Status ColocationGraph::InitializeMembers() {
  for (Node* node : graph_.op_nodes()) {
    Status status = InitializeMember(*node, &members_[node->id()]);
    if (!status.ok()) {
      return AttachDef(status, *node);
    }
  }
  return OkStatus();
}

string ColocationGraph::DebugString() const {
  std::unordered_set<int> roots;
  std::vector<string> root_strings;
  for (const Node* node : graph_.nodes()) {
    if (!node->IsOp()) {
      continue;
    }
    int node_root = FindRoot(node->id());
    if (roots.count(node_root) == 0) {
      root_strings.push_back(DebugInfo(node_root));
      roots.insert(node_root);
    }
  }
  return absl::StrJoin(root_strings, "\n");
}

// Returns debugging info for the node referred to by 'node_root'.
string ColocationGraph::DebugInfo(const int node_root) const {
  string text(
      "\nColocation Debug Info:\n"
      "Colocation group had the following types and supported devices: ");

  // If this node is part of a colocation group, then we want to
  // collect the mapping of ops to supported devices, so that
  // the user can see why an unsatisfiable placement occurred.

  std::unordered_map<string, string> type_to_devices;
  std::vector<const Node*> colocation_nodes;
  int num_nodes_found = 0;

  for (const Node* node : graph_.nodes()) {
    if (!node->IsOp()) {
      continue;
    }
    int id = node->id();
    if (FindRoot(id) != node_root) {
      continue;
    }
    ++num_nodes_found;
    colocation_nodes.push_back(node);

    PrioritizedDeviceTypeVector supported_types;
    SupportedDeviceTypesForNode(device_types_, node->def(), &supported_types,
                                &local_address_spec_)
        .IgnoreError();
    string devices_registered;
    for (const auto& device_type : supported_types) {
      strings::StrAppend(&devices_registered,
                         DeviceTypeString(device_type.first), " ");
    }

    const string& op_type = node->type_string();
    type_to_devices[op_type] = std::move(devices_registered);
  }
  strings::StrAppend(&text, "\nRoot ", members_[node_root].DebugString());

  for (const auto& td : type_to_devices) {
    strings::StrAppend(&text, "\n", td.first, ": ", td.second);
  }
  strings::StrAppend(&text,
                     "\n\nColocation members, user-requested devices, and "
                     "framework assigned devices, if any:");
  for (const Node* node : colocation_nodes) {
    strings::StrAppend(&text, "\n  ", node->name(), " (", node->type_string(),
                       ") ", node->requested_device());
    if (node->has_assigned_device_name()) {
      strings::StrAppend(
          &text, " framework assigned device=", node->assigned_device_name());
    }
  }
  strings::StrAppend(&text, "\n");

  if (num_nodes_found <= 0) {
    text.clear();
  }
  return text;
}

Status ColocationGraph::InitializeMemberWithAssignedDevice(
    const string& assigned_device_name, const string& node_type,
    Member* member) {
  // This node has already been assigned to a device, so we
  // respect this placement, after sanity-checking it.
  // NOTE: Since any assignment must have been performed by
  // the TensorFlow runtime, we consider errors in this branch to
  // be INTERNAL.
  TF_RETURN_IF_ERROR(member->SetAssignedDeviceName(assigned_device_name));

  // Since an assigned device must be a full specification, do extra checks.
  const Device* assigned_device =
      device_set_.FindDeviceByName(assigned_device_name);
  if (assigned_device == nullptr) {
    // TODO(b/129295848, b/122851476): Remove the bit about cross-host function
    // calls when they are supported.
    return errors::Internal(
        "Assigned device '", assigned_device_name,
        "' does not match any device. This error can happen when one attempts "
        "to run a tf.function with resource inputs residing on remote devices. "
        "This use case is currently not supported. Here are the devices "
        "available on this machine: [",
        absl::StrJoin(DevicesToString(device_set_.devices()), ", "), "]. ",
        "If you are seeing this error when running using a tf.Session, set "
        "share_cluster_devices_in_session to true in the tf.ConfigProto.");
  }

  for (const auto& d : member->supported_device_types()) {
    if (IsSupportedDeviceType(assigned_device->attributes(), d.first)) {
      return OkStatus();
    }
  }

  return errors::Internal("Assigned device '", assigned_device_name,
                          "' does not have registered OpKernel support "
                          "for ",
                          node_type);
}

Status ColocationGraph::InitializeMember(const Node& node, Member* member) {
  TF_RETURN_IF_ERROR(member->SetParentAndSupportedDevices(
      node, device_types_, &local_address_spec_));

  if (node.has_assigned_device_name()) {
    TF_RETURN_IF_ERROR(InitializeMemberWithAssignedDevice(
        node.assigned_device_name(), node.type_string(), member));
  } else {
    // This node has not yet been assigned to a device, so we
    // calculate any constraints due to the set of registered
    // kernels and any (partial) user-provided device specification
    // in the NodeDef.

    // If no kernels are registered for this op type, fail with an error.
    if (member->supported_device_types().empty()) {
      std::set<string> registered_device_types;
      for (Device* d : device_set_.devices()) {
        registered_device_types.insert(d->device_type());
      }
      return errors::InvalidArgument(
          "No OpKernel was registered to support Op '", node.type_string(),
          "' used by ", errors::FormatNodeNameForError(node.name()),
          " with these attrs: [", node.attrs().DebugString(),
          "]\n"
          "Registered devices: [",
          absl::StrJoin(registered_device_types, ", "), "]\n",
          "Registered kernels:\n", KernelsRegisteredForOp(node.type_string()));
    }

    // If the NodeDef contains a device, then we interpret it as a
    // (partial) device specification.
    if (!node.requested_device().empty()) {
      if (IsRefOrResourceGeneratorNode(node)) {
        // Treat requested device on resource generating nodes as assigned
        // device so that we don't override it.
        TF_RETURN_IF_ERROR(member->SetResourceDeviceName(node));
      } else {
        // The user has specified a device in the NodeDef, try to find a
        // valid device matching their specification in the set of
        // devices.
        // NOTE: The full name may specify a device that is not in
        // n.supported_device_types(), but we check that in AssignDevice().
        TF_RETURN_IF_ERROR(member->SetRequestedDeviceName(node));
      }
    }
  }
  return OkStatus();
}

// Returns a list of devices having a type in supported_device_types. The
// returned list is sorted by preferred type (a higher numeric priority is
// preferred).
/*static*/ std::vector<Device*> ColocationGraph::FilterSupportedDevices(
    const std::vector<Device*>& devices,
    const PrioritizedDeviceTypeVector& supported_device_types,
    const Device* default_local_device) {
  Device* filtered_default_device = nullptr;
  PrioritizedDeviceVector prioritized_filtered_devices;
  for (const auto& supported_device_type : supported_device_types) {
    for (Device* device : devices) {
      if (IsSupportedDeviceType(device->attributes(),
                                supported_device_type.first)) {
        if (default_local_device &&
            (device == default_local_device ||
             // TODO(nareshmodi, fishx): At times the device pointer in the
             // device set is different to the one passed in as the default
             // device. Figure out why this might be.
             device->name() == default_local_device->name())) {
          filtered_default_device = device;
        } else {
          prioritized_filtered_devices.emplace_back(
              device, supported_device_type.second);
        }
      }
    }
  }
  DeviceSet::SortPrioritizedDeviceVector(&prioritized_filtered_devices);

  std::vector<Device*> filtered_devices;
  if (filtered_default_device != nullptr) {
    filtered_devices.emplace_back(filtered_default_device);
  }
  for (const auto& prioritized_filtered_device : prioritized_filtered_devices) {
    filtered_devices.push_back(prioritized_filtered_device.first);
  }
  return filtered_devices;
}
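
// Sketch (hypothetical devices and priorities): with supported_device_types
// [(GPU, 1), (CPU, 0)] and devices [CPU:0, GPU:0, GPU:1], the result is
// [GPU:0, GPU:1, CPU:0]; if default_local_device is CPU:0, it is hoisted to
// the front instead: [CPU:0, GPU:0, GPU:1].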

}  // namespace tensorflow