1/* Copyright 2022 The TensorFlow Authors. All Rights Reserved.
2
3Licensed under the Apache License, Version 2.0 (the "License");
4you may not use this file except in compliance with the License.
5You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9Unless required by applicable law or agreed to in writing, software
10distributed under the License is distributed on an "AS IS" BASIS,
11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12See the License for the specific language governing permissions and
13limitations under the License.
14==============================================================================*/
15
16#include "tensorflow/core/common_runtime/device_propagation.h"
17
18#include <string>
19#include <utility>
20
21#include "absl/container/flat_hash_set.h"
22#include "tensorflow/core/framework/types.h"
23#include "tensorflow/core/graph/algorithm.h"
24#include "tensorflow/core/graph/graph.h"
25
26namespace tensorflow {
27
28namespace {
29
30const std::string& AssignedOrRequestedDevice(const Node& node) {
31 if (!node.assigned_device_name().empty()) {
32 return node.assigned_device_name();
33 }
34 return node.requested_device();
35}
36
37void UpdateDeviceFromInputs(
38 const device_propagation::NodeFilter& node_filter,
39 const device_propagation::DeviceFilter& device_filter, Node* node) {
40 if (!AssignedOrRequestedDevice(*node).empty() || !node_filter(*node)) {
41 return;
42 }
43 string proposed_device = "";
44 Node* proposed_src = nullptr;
45 // Scan the input edges, propagate device assignment from its inputs to this
46 // node iff all input nodes has the same device assignment and the device is
47 // propagatable (checked by `device_filter`). Some kinds of edges are
48 // ignored.
49 for (const Edge* e : node->in_edges()) {
50 // Ignore control edge.
51 if (e->IsControlEdge()) {
52 continue;
53 }
54 Node* src = e->src();
55 const string& src_device = AssignedOrRequestedDevice(*src);
56
57 // Ignore LoopCond -> Switch and Enter -> Merge. In other words, the device
58 // placement of a Switch op is determined by all its non-LoopCond inputs and
59 // that of a Merge op is determined by all its non-Enter inputs.
60 if ((node->IsSwitch() && src->IsLoopCond()) ||
61 (node->IsMerge() && src->IsEnter())) {
62 continue;
63 }
64
65 // If a source device is not propagatable, stop.
66 if (!device_filter(src_device)) return;
67
68 if (proposed_src == nullptr) {
69 proposed_device = src_device;
70 proposed_src = src;
71 } else if (proposed_device != src_device) {
72 // The device assignments of some input nodes are not the same. Stop.
73 return;
74 }
75 }
76 if (proposed_src) {
77 node->set_assigned_device_name(proposed_src->assigned_device_name());
78 node->set_requested_device(proposed_src->requested_device());
79 }
80}
81
82} // namespace
83
84void PropagateDevices(const device_propagation::NodeFilter& node_filter,
85 const device_propagation::DeviceFilter& device_filter,
86 Graph* graph) {
87 ReverseDFS(*graph, {}, [&node_filter, &device_filter](Node* node) {
88 UpdateDeviceFromInputs(node_filter, device_filter, node);
89 });
90}
91
92} // namespace tensorflow
93