1/* Copyright 2022 The TensorFlow Authors. All Rights Reserved.
2
3Licensed under the Apache License, Version 2.0 (the "License");
4you may not use this file except in compliance with the License.
5You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9Unless required by applicable law or agreed to in writing, software
10distributed under the License is distributed on an "AS IS" BASIS,
11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12See the License for the specific language governing permissions and
13limitations under the License.
14==============================================================================*/
15
16#include "tensorflow/core/common_runtime/optimize_cross_host_control_deps.h"
17
18#include <vector>
19
20#include "tensorflow/core/framework/node_def.pb.h"
21#include "tensorflow/core/framework/node_def_builder.h"
22#include "tensorflow/core/platform/errors.h"
23#include "tensorflow/core/platform/strcat.h"
24
25namespace tensorflow {
26
27namespace {
28
29Status BuildNoopNode(const Node& source, StringPiece name, const string& device,
30 Graph* graph, Node** node) {
31 NodeDefBuilder builder(name, "NoOp", NodeDebugInfo(source));
32 if (!device.empty()) {
33 builder.Device(device);
34 }
35 NodeDef def;
36 TF_RETURN_IF_ERROR(builder.Finalize(&def));
37
38 TF_ASSIGN_OR_RETURN(*node, graph->AddNode(def));
39 if (!device.empty()) {
40 (*node)->set_assigned_device_name(device);
41 }
42 return OkStatus();
43}
44
45const string& RequestedOrAssignedDevice(const Node* n) {
46 if (!n->assigned_device_name().empty()) {
47 return n->assigned_device_name();
48 }
49 return n->requested_device();
50}
51
52} // namespace
53
54Status OptimizeCrossHostControlOutputEdges(Graph* graph,
55 int cross_host_edges_threshold) {
56 string src_host_device;
57 string dst_host_device;
58 for (Node* n : graph->op_nodes()) {
59 if (n->out_edges().size() < cross_host_edges_threshold) {
60 continue;
61 }
62 absl::flat_hash_map<string, std::vector<const Edge*>>
63 cross_host_control_edges;
64 TF_RETURN_IF_ERROR(DeviceNameUtils::DeviceNameToCpuDeviceName(
65 RequestedOrAssignedDevice(n), &src_host_device));
66 for (const Edge* edge : n->out_edges()) {
67 if (!edge->IsControlEdge() || edge->dst()->IsSink()) {
68 continue;
69 }
70
71 TF_RETURN_IF_ERROR(DeviceNameUtils::DeviceNameToCpuDeviceName(
72 RequestedOrAssignedDevice(edge->dst()), &dst_host_device));
73 if (DeviceNameUtils::IsSameAddressSpace(src_host_device,
74 dst_host_device)) {
75 continue;
76 }
77 auto iter = cross_host_control_edges.find(dst_host_device);
78 if (iter == cross_host_control_edges.end()) {
79 cross_host_control_edges[dst_host_device] = {edge};
80 } else {
81 iter->second.push_back(edge);
82 }
83 }
84 for (const auto& pair : cross_host_control_edges) {
85 if (pair.second.size() < cross_host_edges_threshold) {
86 continue;
87 }
88 VLOG(1) << "Optmize cross host output control edge, src node: "
89 << n->name() << " src device: " << src_host_device
90 << " dst host device: " << pair.first
91 << " edges size: " << pair.second.size();
92 Node* control_after;
93 TF_RETURN_IF_ERROR(BuildNoopNode(
94 *n, graph->NewName(strings::StrCat(n->name(), "/", "control_after")),
95 /*device=*/pair.first, graph, &control_after));
96 graph->AddControlEdge(n, control_after);
97 for (const Edge* edge : pair.second) {
98 graph->AddControlEdge(control_after, edge->dst());
99 graph->RemoveEdge(edge);
100 }
101 }
102 }
103 return OkStatus();
104}
105
106Status OptimizeCrossHostControlInputEdges(Graph* graph,
107 int cross_host_edges_threshold) {
108 absl::flat_hash_map<Node*, std::vector<const Edge*>> node_control_input_edges;
109 for (Node* n : graph->op_nodes()) {
110 for (const Edge* edge : n->out_edges()) {
111 if (!edge->IsControlEdge() || edge->dst()->IsSink()) {
112 continue;
113 }
114 Node* dst = edge->dst();
115 auto iter = node_control_input_edges.find(dst);
116 if (iter == node_control_input_edges.end()) {
117 node_control_input_edges[dst] = {edge};
118 } else {
119 node_control_input_edges[dst].push_back(edge);
120 }
121 }
122 }
123
124 string src_host_device;
125 string dst_host_device;
126 for (auto& pair : node_control_input_edges) {
127 Node* dst = pair.first;
128 const std::vector<const Edge*>& input_edges = pair.second;
129
130 if (input_edges.size() < cross_host_edges_threshold) {
131 continue;
132 }
133
134 absl::flat_hash_map<string, std::vector<const Edge*>>
135 cross_host_control_edges;
136 TF_RETURN_IF_ERROR(DeviceNameUtils::DeviceNameToCpuDeviceName(
137 RequestedOrAssignedDevice(dst), &dst_host_device));
138 for (const Edge* edge : input_edges) {
139 TF_RETURN_IF_ERROR(DeviceNameUtils::DeviceNameToCpuDeviceName(
140 RequestedOrAssignedDevice(edge->src()), &src_host_device));
141 if (DeviceNameUtils::IsSameAddressSpace(src_host_device,
142 dst_host_device)) {
143 continue;
144 }
145 auto iter = cross_host_control_edges.find(src_host_device);
146 if (iter == cross_host_control_edges.end()) {
147 cross_host_control_edges[src_host_device] = {edge};
148 } else {
149 iter->second.push_back(edge);
150 }
151 }
152 for (const auto& pair : cross_host_control_edges) {
153 if (pair.second.size() < cross_host_edges_threshold) {
154 continue;
155 }
156 VLOG(0) << "Optmize cross host input control edge, dst node: "
157 << dst->name() << " dst device: " << dst_host_device
158 << " src host device: " << pair.first
159 << " edges size: " << pair.second.size();
160 Node* control_before;
161 TF_RETURN_IF_ERROR(BuildNoopNode(
162 *dst,
163 graph->NewName(strings::StrCat(dst->name(), "/", "control_before")),
164 /*device=*/pair.first, graph, &control_before));
165 graph->AddControlEdge(control_before, dst);
166 for (const Edge* edge : pair.second) {
167 graph->AddControlEdge(edge->src(), control_before);
168 graph->RemoveEdge(edge);
169 }
170 }
171 }
172 return OkStatus();
173}
174
175} // namespace tensorflow
176