1 | /* Copyright 2022 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | |
16 | #include "tensorflow/core/common_runtime/optimize_cross_host_control_deps.h" |
17 | |
18 | #include <vector> |
19 | |
20 | #include "tensorflow/core/framework/node_def.pb.h" |
21 | #include "tensorflow/core/framework/node_def_builder.h" |
22 | #include "tensorflow/core/platform/errors.h" |
23 | #include "tensorflow/core/platform/strcat.h" |
24 | |
25 | namespace tensorflow { |
26 | |
27 | namespace { |
28 | |
29 | Status BuildNoopNode(const Node& source, StringPiece name, const string& device, |
30 | Graph* graph, Node** node) { |
31 | NodeDefBuilder builder(name, "NoOp" , NodeDebugInfo(source)); |
32 | if (!device.empty()) { |
33 | builder.Device(device); |
34 | } |
35 | NodeDef def; |
36 | TF_RETURN_IF_ERROR(builder.Finalize(&def)); |
37 | |
38 | TF_ASSIGN_OR_RETURN(*node, graph->AddNode(def)); |
39 | if (!device.empty()) { |
40 | (*node)->set_assigned_device_name(device); |
41 | } |
42 | return OkStatus(); |
43 | } |
44 | |
45 | const string& RequestedOrAssignedDevice(const Node* n) { |
46 | if (!n->assigned_device_name().empty()) { |
47 | return n->assigned_device_name(); |
48 | } |
49 | return n->requested_device(); |
50 | } |
51 | |
52 | } // namespace |
53 | |
54 | Status OptimizeCrossHostControlOutputEdges(Graph* graph, |
55 | int cross_host_edges_threshold) { |
56 | string src_host_device; |
57 | string dst_host_device; |
58 | for (Node* n : graph->op_nodes()) { |
59 | if (n->out_edges().size() < cross_host_edges_threshold) { |
60 | continue; |
61 | } |
62 | absl::flat_hash_map<string, std::vector<const Edge*>> |
63 | cross_host_control_edges; |
64 | TF_RETURN_IF_ERROR(DeviceNameUtils::DeviceNameToCpuDeviceName( |
65 | RequestedOrAssignedDevice(n), &src_host_device)); |
66 | for (const Edge* edge : n->out_edges()) { |
67 | if (!edge->IsControlEdge() || edge->dst()->IsSink()) { |
68 | continue; |
69 | } |
70 | |
71 | TF_RETURN_IF_ERROR(DeviceNameUtils::DeviceNameToCpuDeviceName( |
72 | RequestedOrAssignedDevice(edge->dst()), &dst_host_device)); |
73 | if (DeviceNameUtils::IsSameAddressSpace(src_host_device, |
74 | dst_host_device)) { |
75 | continue; |
76 | } |
77 | auto iter = cross_host_control_edges.find(dst_host_device); |
78 | if (iter == cross_host_control_edges.end()) { |
79 | cross_host_control_edges[dst_host_device] = {edge}; |
80 | } else { |
81 | iter->second.push_back(edge); |
82 | } |
83 | } |
84 | for (const auto& pair : cross_host_control_edges) { |
85 | if (pair.second.size() < cross_host_edges_threshold) { |
86 | continue; |
87 | } |
88 | VLOG(1) << "Optmize cross host output control edge, src node: " |
89 | << n->name() << " src device: " << src_host_device |
90 | << " dst host device: " << pair.first |
91 | << " edges size: " << pair.second.size(); |
92 | Node* control_after; |
93 | TF_RETURN_IF_ERROR(BuildNoopNode( |
94 | *n, graph->NewName(strings::StrCat(n->name(), "/" , "control_after" )), |
95 | /*device=*/pair.first, graph, &control_after)); |
96 | graph->AddControlEdge(n, control_after); |
97 | for (const Edge* edge : pair.second) { |
98 | graph->AddControlEdge(control_after, edge->dst()); |
99 | graph->RemoveEdge(edge); |
100 | } |
101 | } |
102 | } |
103 | return OkStatus(); |
104 | } |
105 | |
106 | Status OptimizeCrossHostControlInputEdges(Graph* graph, |
107 | int cross_host_edges_threshold) { |
108 | absl::flat_hash_map<Node*, std::vector<const Edge*>> node_control_input_edges; |
109 | for (Node* n : graph->op_nodes()) { |
110 | for (const Edge* edge : n->out_edges()) { |
111 | if (!edge->IsControlEdge() || edge->dst()->IsSink()) { |
112 | continue; |
113 | } |
114 | Node* dst = edge->dst(); |
115 | auto iter = node_control_input_edges.find(dst); |
116 | if (iter == node_control_input_edges.end()) { |
117 | node_control_input_edges[dst] = {edge}; |
118 | } else { |
119 | node_control_input_edges[dst].push_back(edge); |
120 | } |
121 | } |
122 | } |
123 | |
124 | string src_host_device; |
125 | string dst_host_device; |
126 | for (auto& pair : node_control_input_edges) { |
127 | Node* dst = pair.first; |
128 | const std::vector<const Edge*>& input_edges = pair.second; |
129 | |
130 | if (input_edges.size() < cross_host_edges_threshold) { |
131 | continue; |
132 | } |
133 | |
134 | absl::flat_hash_map<string, std::vector<const Edge*>> |
135 | cross_host_control_edges; |
136 | TF_RETURN_IF_ERROR(DeviceNameUtils::DeviceNameToCpuDeviceName( |
137 | RequestedOrAssignedDevice(dst), &dst_host_device)); |
138 | for (const Edge* edge : input_edges) { |
139 | TF_RETURN_IF_ERROR(DeviceNameUtils::DeviceNameToCpuDeviceName( |
140 | RequestedOrAssignedDevice(edge->src()), &src_host_device)); |
141 | if (DeviceNameUtils::IsSameAddressSpace(src_host_device, |
142 | dst_host_device)) { |
143 | continue; |
144 | } |
145 | auto iter = cross_host_control_edges.find(src_host_device); |
146 | if (iter == cross_host_control_edges.end()) { |
147 | cross_host_control_edges[src_host_device] = {edge}; |
148 | } else { |
149 | iter->second.push_back(edge); |
150 | } |
151 | } |
152 | for (const auto& pair : cross_host_control_edges) { |
153 | if (pair.second.size() < cross_host_edges_threshold) { |
154 | continue; |
155 | } |
156 | VLOG(0) << "Optmize cross host input control edge, dst node: " |
157 | << dst->name() << " dst device: " << dst_host_device |
158 | << " src host device: " << pair.first |
159 | << " edges size: " << pair.second.size(); |
160 | Node* control_before; |
161 | TF_RETURN_IF_ERROR(BuildNoopNode( |
162 | *dst, |
163 | graph->NewName(strings::StrCat(dst->name(), "/" , "control_before" )), |
164 | /*device=*/pair.first, graph, &control_before)); |
165 | graph->AddControlEdge(control_before, dst); |
166 | for (const Edge* edge : pair.second) { |
167 | graph->AddControlEdge(edge->src(), control_before); |
168 | graph->RemoveEdge(edge); |
169 | } |
170 | } |
171 | } |
172 | return OkStatus(); |
173 | } |
174 | |
175 | } // namespace tensorflow |
176 | |