1 | /* Copyright 2020 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | #include "tensorflow/core/common_runtime/replicate_per_replica_nodes.h" |
16 | |
17 | #include <algorithm> |
18 | #include <queue> |
19 | |
20 | #include "absl/strings/str_cat.h" |
21 | #include "tensorflow/core/common_runtime/optimize_cross_host_control_deps.h" |
22 | #include "tensorflow/core/framework/node_def.pb.h" |
23 | #include "tensorflow/core/framework/node_def_builder.h" |
24 | #include "tensorflow/core/platform/errors.h" |
25 | |
26 | namespace tensorflow { |
27 | namespace { |
28 | |
// Minimum number of cross-host control edges per host before
// OptimizeCrossHostControl{Output,Input}Edges groups them (see calls at the
// bottom of this file). NOTE(review): name has a known typo ("Theshold");
// renaming would touch all call sites, so it is kept as-is here.
constexpr int kOptimizeCrossHostEdgesTheshold = 8;
30 | |
31 | // A helper for rewriting nodes assigned to a virtual composite device. |
// A helper for rewriting nodes assigned to a virtual composite device.
// For every original node it tracks one replica slot per allowed device
// (created lazily), and provides edge-rewiring routines for each
// (src device kind, dst device kind) combination.
class ReplicateHelper {
 public:
  // Registers `node` for replication: allocates `num_allowed_devices` replica
  // slots, all initialized to nullptr. Replicas are created on demand by
  // ReplicateNode(). Returns InvalidArgument if `node` was already registered.
  Status InitializeNode(const Node* node, int num_allowed_devices) {
    if (replicated_nodes_map_.find(node) != replicated_nodes_map_.end()) {
      return errors::InvalidArgument("Node ", node->name(),
                                     " has been replicated.");
    }
    std::vector<Node*> replicated_nodes(num_allowed_devices, nullptr);
    replicated_nodes_map_.emplace(node, std::move(replicated_nodes));
    return OkStatus();
  }

  // Replicate the given node to an allowed device. No-op if the replica for
  // `allowed_device_index` already exists. An `_Arg` replica gets a
  // "sub_index" attribute identifying which component of the packed input it
  // represents.
  Status ReplicateNode(const Node* node,
                       const std::vector<string>& allowed_devices,
                       int allowed_device_index, Graph* graph) {
    auto& replicated_nodes = replicated_nodes_map_.at(node);
    if (replicated_nodes[allowed_device_index] != nullptr) {
      return OkStatus();
    }
    const auto& device = allowed_devices.at(allowed_device_index);
    NodeDef node_def = node->def();
    // "/R<i>" suffix keeps replica names recognizable; NewName guarantees
    // uniqueness within the graph.
    const string suffix = strings::StrCat("/R", allowed_device_index);
    node_def.set_name(graph->NewName(strings::StrCat(node_def.name(), suffix)));
    TF_ASSIGN_OR_RETURN(Node * replicated_node, graph->AddNode(node_def));
    replicated_node->set_assigned_device_name(device);
    if (replicated_node->IsArg()) {
      replicated_node->AddAttr("sub_index", allowed_device_index);
    }
    replicated_nodes[allowed_device_index] = replicated_node;
    return OkStatus();
  }

  // Replace an edge (a regular device -> composite device) with
  // N edges (a regular device -> allowed devices): the single producer fans
  // out to every existing replica of the dst node.
  void ReplicateFromRegularDeviceToCompositeDevice(const Edge* edge,
                                                   Graph* graph) const {
    Node* src = edge->src();
    const std::vector<Node*>& dst_replicated_nodes =
        replicated_nodes_map_.at(edge->dst());
    for (Node* dst : dst_replicated_nodes) {
      // Skip a replicated dst node without any consumer.
      if (dst == nullptr) {
        continue;
      }
      graph->AddEdge(src, edge->src_output(), dst, edge->dst_input());
    }
  }

  // Replace an edge (composite device -> composite device) with
  // N edges (allowed devices -> allowed devices): replica i of src feeds
  // replica i of dst.
  Status ReplicateFromCompositeDeviceToCompositeDevice(
      const Edge* edge, const std::vector<string>& allowed_devices,
      Graph* graph) {
    const std::vector<Node*>& src_replicated_nodes =
        replicated_nodes_map_.at(edge->src());
    const std::vector<Node*>& dst_replicated_nodes =
        replicated_nodes_map_.at(edge->dst());
    if (src_replicated_nodes.size() != dst_replicated_nodes.size()) {
      return errors::InvalidArgument(
          "Nodes assigned to the same composite device should have the "
          "same number of replicated nodes. Found an edge from node ",
          edge->src()->name(), " (", src_replicated_nodes.size(),
          " replicated nodes) to node ", edge->dst()->name(), " (",
          dst_replicated_nodes.size(), " replicated nodes).");
    }
    for (int i = 0; i < src_replicated_nodes.size(); ++i) {
      Node* dst = dst_replicated_nodes.at(i);
      // Skip a replicated dst node without any consumer.
      if (dst == nullptr) {
        continue;
      }
      // The i-th src replica may not exist yet; create it on demand before
      // wiring the edge.
      TF_RETURN_IF_ERROR(ReplicateNode(edge->src(), allowed_devices, i, graph));
      graph->AddEdge(src_replicated_nodes.at(i), edge->src_output(), dst,
                     edge->dst_input());
    }
    return OkStatus();
  }

  // Data edge: replace an edge (composite device -> a regular device) with
  // one edge (one allowed device -> a regular device).
  // Control edge: replace an edge (composite device -> a regular device) with
  // N edges (allowed devices -> a regular device).
  Status ReplicateFromCompositeDeviceToRegularDevice(
      const Edge* edge, const std::vector<string>& allowed_devices,
      Graph* graph) {
    const std::vector<Node*>& src_replicated_nodes =
        replicated_nodes_map_.at(edge->src());
    Node* dst = edge->dst();
    const string& dst_device = dst->assigned_device_name();
    bool found_src_node = false;
    // Fast path: dst lives on one of the allowed devices — connect it to the
    // matching replica only.
    for (int i = 0; i < allowed_devices.size(); ++i) {
      if (allowed_devices.at(i) == dst_device) {
        TF_RETURN_IF_ERROR(
            ReplicateNode(edge->src(), allowed_devices, i, graph));
        graph->AddEdge(src_replicated_nodes.at(i), edge->src_output(), dst,
                       edge->dst_input());
        found_src_node = true;
        break;
      }
    }
    if (!found_src_node) {
      // dst is on a device outside the composite: materialize every replica
      // of src first, then decide how to connect them.
      for (int i = 0; i < allowed_devices.size(); ++i) {
        TF_RETURN_IF_ERROR(
            ReplicateNode(edge->src(), allowed_devices, i, graph));
      }
      if (edge->IsControlEdge()) {
        for (Node* replicated_node : src_replicated_nodes) {
          // Duplication check in `Graph::AddControlEdge` is expensive for the
          // dst node with a lot of input edges. Here each (src, dst) pair
          // will only occur once so it is safe to skip the duplication check.
          graph->AddControlEdge(replicated_node, dst,
                                /*allow_duplicates=*/true);
        }
        return OkStatus();
      }
      if (edge->src()->type_string() == "_Arg") {
        // This happens when the dst node runs on a host CPU and
        // captures a function with an arg node assigned to the same
        // composite device (e.g. ScanDataset).
        // For this case, we insert a PackOp between replicated nodes and the
        // dst node. The dst node is responsible for unpacking the packed
        // tensor.
        // Add '/Packed' as a substring to the name of the new node, which
        // could be helpful when debugging the graph.
        NodeDefBuilder pack_builder(
            graph->NewName(absl::StrCat(edge->src()->name(), "/Packed")),
            "Pack");
        const int num_replicas = src_replicated_nodes.size();
        pack_builder.Attr("N", num_replicas);
        const DataType dtype = edge->src()->output_type(edge->src_output());
        pack_builder.Attr("T", dtype);
        std::vector<NodeDefBuilder::NodeOut> inputs;
        inputs.reserve(src_replicated_nodes.size());
        for (Node* replicated_node : src_replicated_nodes) {
          inputs.emplace_back(NodeDefBuilder::NodeOut{
              replicated_node->name(), edge->src_output(), dtype});
        }
        pack_builder.Input(inputs);
        NodeDef pack_def;
        TF_RETURN_IF_ERROR(pack_builder.Finalize(&pack_def));
        TF_ASSIGN_OR_RETURN(Node * pack_node, graph->AddNode(pack_def));
        pack_node->set_assigned_device_name(dst->assigned_device_name());
        // Replica i feeds input i of the Pack node.
        for (int i = 0; i < src_replicated_nodes.size(); ++i) {
          graph->AddEdge(src_replicated_nodes[i], edge->src_output(), pack_node,
                         i);
        }
        graph->AddEdge(pack_node, /*x=*/0, dst, edge->dst_input());
      } else {
        return errors::InvalidArgument(
            "Dst node should be assigned to an allowed device. Found an "
            "edge from node ",
            edge->src()->name(), " assigned to ",
            edge->src()->assigned_device_name(), " to node ", dst->name(),
            " assigned to ", dst_device);
      }
    }
    return OkStatus();
  }

 private:
  // Map from original nodes to corresponding replicated nodes.
  absl::flat_hash_map<const Node*, std::vector<Node*>> replicated_nodes_map_;
};
197 | |
198 | // Replicate the nodes in cluster_nodes and update edges. |
199 | Status ReplicateNodesAndEdges(const std::vector<string>& allowed_devices, |
200 | absl::flat_hash_map<Node*, int>* cluster_nodes, |
201 | ReplicateHelper* helper, Graph* graph) { |
202 | // Contains nodes in cluster_nodes whose out nodes are all on physical |
203 | // devices. |
204 | std::queue<Node*> nodes_ready_to_delete; |
205 | for (auto& pair : *cluster_nodes) { |
206 | Node* node = pair.first; |
207 | for (const Edge* edge : node->out_edges()) { |
208 | Node* dst = edge->dst(); |
209 | if (dst->assigned_device_name() != node->assigned_device_name()) { |
210 | // The dst node is assigned to a different device. |
211 | TF_RETURN_IF_ERROR(helper->ReplicateFromCompositeDeviceToRegularDevice( |
212 | edge, allowed_devices, graph)); |
213 | --pair.second; |
214 | } |
215 | } |
216 | // Node is ready to delete when all its consumer nodes are assigned to a |
217 | // physical device. |
218 | if (cluster_nodes->at(node) == 0) { |
219 | nodes_ready_to_delete.push(node); |
220 | } |
221 | } |
222 | |
223 | while (!nodes_ready_to_delete.empty()) { |
224 | Node* node = nodes_ready_to_delete.front(); |
225 | nodes_ready_to_delete.pop(); |
226 | |
227 | // Update input edges. |
228 | for (const Edge* edge : node->in_edges()) { |
229 | Node* src = edge->src(); |
230 | if (src->assigned_device_name() != node->assigned_device_name()) { |
231 | // The source node is assigned to a different device. |
232 | helper->ReplicateFromRegularDeviceToCompositeDevice(edge, graph); |
233 | } else { |
234 | // The source node is assigned to the same composite device. |
235 | TF_RETURN_IF_ERROR( |
236 | helper->ReplicateFromCompositeDeviceToCompositeDevice( |
237 | edge, allowed_devices, graph)); |
238 | if (--(*cluster_nodes)[src] == 0) { |
239 | nodes_ready_to_delete.push(src); |
240 | } |
241 | } |
242 | } |
243 | |
244 | // Remove the original node. |
245 | cluster_nodes->erase(node); |
246 | graph->RemoveNode(node); |
247 | } |
248 | return OkStatus(); |
249 | } |
250 | |
251 | } // namespace |
252 | |
253 | Status ReplicatePerReplicaNodesInFunctionGraph( |
254 | const absl::flat_hash_map<string, const std::vector<string>*>& |
255 | composite_devices, |
256 | Graph* graph) { |
257 | std::set<string> composite_device_names; |
258 | for (const auto& it : composite_devices) { |
259 | composite_device_names.insert(it.first); |
260 | } |
261 | // Map from a composite device to a cluster of nodes assigned to the |
262 | // composite device and the numbers of their out edges to process. |
263 | absl::flat_hash_map<string, absl::flat_hash_map<Node*, int>> |
264 | composite_device_to_cluster_nodes; |
265 | for (Node* n : graph->op_nodes()) { |
266 | if (composite_device_names.find(n->assigned_device_name()) != |
267 | composite_device_names.end()) { |
268 | // TODO(b/145922293): Validate that an _Arg node assigned to a |
269 | // CompositeDevice should have an attribute indicating that the _Arg node |
270 | // represents a packed input. |
271 | composite_device_to_cluster_nodes[n->assigned_device_name()].emplace( |
272 | n, n->out_edges().size()); |
273 | } |
274 | } |
275 | |
276 | if (composite_device_to_cluster_nodes.empty()) { |
277 | VLOG(1) << "No nodes with composiste device found." ; |
278 | return OkStatus(); |
279 | } |
280 | |
281 | for (auto& it : composite_device_to_cluster_nodes) { |
282 | const std::vector<string>& allowed_devices = |
283 | *composite_devices.at(it.first); |
284 | if (allowed_devices.empty()) { |
285 | return errors::InvalidArgument("No allowed device of composite device: " , |
286 | it.first); |
287 | } |
288 | absl::flat_hash_map<Node*, int>& cluster_nodes = it.second; |
289 | if (allowed_devices.size() == 1) { |
290 | // Reuse the original nodes if there is only one allowed device. |
291 | for (const auto& pair : it.second) { |
292 | Node* n = pair.first; |
293 | n->set_assigned_device_name(allowed_devices.at(0)); |
294 | if (n->IsArg()) { |
295 | n->AddAttr("sub_index" , 0); |
296 | } |
297 | } |
298 | continue; |
299 | } |
300 | ReplicateHelper helper; |
301 | for (const auto& pair : cluster_nodes) { |
302 | TF_RETURN_IF_ERROR( |
303 | helper.InitializeNode(pair.first, allowed_devices.size())); |
304 | } |
305 | |
306 | TF_RETURN_IF_ERROR(ReplicateNodesAndEdges(allowed_devices, &cluster_nodes, |
307 | &helper, graph)); |
308 | |
309 | if (!cluster_nodes.empty()) { |
310 | return errors::InvalidArgument( |
311 | "There are still " , cluster_nodes.size(), |
312 | " nodes on CompositiveDevice " , |
313 | cluster_nodes.begin()->first->assigned_device_name()); |
314 | } |
315 | } |
316 | |
317 | // Optimize cross host control output/input edges. We apply the optimizations |
318 | // at the end to reduce the newly created cross-host edges caused by |
319 | // per-replica nodes/edges replications. |
320 | TF_RETURN_IF_ERROR(OptimizeCrossHostControlOutputEdges( |
321 | graph, kOptimizeCrossHostEdgesTheshold)); |
322 | TF_RETURN_IF_ERROR(OptimizeCrossHostControlInputEdges( |
323 | graph, kOptimizeCrossHostEdgesTheshold)); |
324 | |
325 | return OkStatus(); |
326 | } |
327 | |
328 | } // namespace tensorflow |
329 | |