1 | /* Copyright 2015 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | |
16 | #include "tensorflow/core/graph/graph_partition.h" |
17 | |
18 | #include <deque> |
19 | #include <queue> |
20 | #include <unordered_map> |
21 | #include <unordered_set> |
22 | #include <utility> |
23 | #include <vector> |
24 | |
25 | #include "absl/container/flat_hash_map.h" |
26 | #include "tensorflow/core/framework/function.h" |
27 | #include "tensorflow/core/framework/memory_types.h" |
28 | #include "tensorflow/core/framework/node_def_builder.h" |
29 | #include "tensorflow/core/framework/tensor.pb.h" |
30 | #include "tensorflow/core/framework/types.h" |
31 | #include "tensorflow/core/framework/versions.pb.h" |
32 | #include "tensorflow/core/graph/algorithm.h" |
33 | #include "tensorflow/core/graph/control_flow.h" |
34 | #include "tensorflow/core/graph/costmodel.h" |
35 | #include "tensorflow/core/graph/graph_def_builder.h" |
36 | #include "tensorflow/core/graph/node_builder.h" |
37 | #include "tensorflow/core/graph/tensor_id.h" |
38 | #include "tensorflow/core/lib/core/errors.h" |
39 | #include "tensorflow/core/lib/hash/hash.h" |
40 | #include "tensorflow/core/lib/strings/str_util.h" |
41 | #include "tensorflow/core/platform/logging.h" |
42 | #include "tensorflow/core/util/device_name_utils.h" |
43 | #include "tensorflow/core/util/dump_graph.h" |
44 | |
45 | namespace tensorflow { |
46 | |
47 | namespace { |
48 | |
49 | inline bool IsMerge(const NodeDef& node_def) { |
50 | return node_def.op() == "Merge" || node_def.op() == "RefMerge" || |
51 | node_def.op() == "_XlaMerge" ; |
52 | } |
53 | |
54 | inline bool IsNextIteration(const NodeDef& node_def) { |
55 | return node_def.op() == "NextIteration" || |
56 | node_def.op() == "RefNextIteration" ; |
57 | } |
58 | |
59 | struct DupRecvKey { |
60 | int src_node_id; // Edge's src node id |
61 | int src_output_slot; // Edge's src node output slot |
62 | GraphDef* dst_graph; // Edge's dst node is in this subgraph |
63 | bool recv_output_on_host; // The output of recv is on host |
64 | |
65 | template <typename H> |
66 | friend H AbslHashValue(H h, const DupRecvKey& c) { |
67 | return H::combine(std::move(h), c.src_node_id, c.src_output_slot, |
68 | reinterpret_cast<std::uintptr_t>(c.dst_graph), |
69 | c.recv_output_on_host); |
70 | } |
71 | |
72 | friend bool operator==(const DupRecvKey& x, const DupRecvKey& y) { |
73 | return (x.src_node_id == y.src_node_id) && |
74 | (x.src_output_slot == y.src_output_slot) && |
75 | (x.dst_graph == y.dst_graph) && |
76 | (x.recv_output_on_host == y.recv_output_on_host); |
77 | } |
78 | }; |
79 | |
80 | // struct used to store the recvs, so that start times can be properly updated |
81 | struct RecvInfo { |
82 | NodeDef* recv; |
83 | NodeDef* real_recv; |
84 | int64_t start_time; |
85 | }; |
86 | |
87 | typedef absl::flat_hash_map<DupRecvKey, RecvInfo> DupRecvTable; |
88 | |
// Key for the per-port memory-type maps: a node id together with an input or
// output index of that node.
// TODO(power): migrate back to std::pair when absl::Hash is fixed for MSVC.
struct NodePort {
  int node_id;
  int index;

  friend bool operator==(const NodePort& x, const NodePort& y) {
    const bool same_node = x.node_id == y.node_id;
    return same_node && x.index == y.index;
  }

  // Abseil hash hook: combines the node id and then the port index.
  template <typename H>
  friend H AbslHashValue(H h, const NodePort& c) {
    return H::combine(std::move(h), c.node_id, c.index);
  }
};
105 | |
106 | typedef absl::flat_hash_map<NodePort, MemoryType> MemoryTypeMap; |
107 | |
108 | // We collect the following information about the graph before performing |
109 | // graph partitioning. |
110 | struct GraphInfo { |
111 | std::vector<DeviceType> device_types; |
112 | MemoryTypeMap input_types; |
113 | MemoryTypeMap output_types; |
114 | std::vector<ControlFlowInfo> cf_info; |
115 | }; |
116 | |
117 | DataType EdgeType(const Edge* e) { |
118 | if (e->IsControlEdge()) { |
119 | return DT_FLOAT; |
120 | } else { |
121 | return e->dst()->input_type(e->dst_input()); |
122 | } |
123 | } |
124 | |
125 | // Return true iff we need to add the same device send/recv for 'edge'. |
126 | bool NeedSameDeviceSendRecv(const Edge* edge, const GraphInfo& info) { |
127 | if (edge->IsControlEdge()) { |
128 | return false; |
129 | } |
130 | |
131 | const Node* src = edge->src(); |
132 | const Node* dst = edge->dst(); |
133 | if (src->assigned_device_name() == dst->assigned_device_name()) { |
134 | int src_port = edge->src_output(); |
135 | int dst_port = edge->dst_input(); |
136 | if (info.device_types[src->id()] != DEVICE_CPU) { |
137 | auto src_it = info.output_types.find({src->id(), src_port}); |
138 | DCHECK(src_it != info.output_types.end()); |
139 | auto dst_it = info.input_types.find({dst->id(), dst_port}); |
140 | DCHECK(dst_it != info.input_types.end()); |
141 | return src_it->second != dst_it->second; |
142 | } |
143 | } |
144 | return false; |
145 | } |
146 | |
147 | // Return true iff (dst, dst_input) is specified on host memory. |
148 | bool IsDstInputOnHost(const Edge* edge, const GraphInfo& info) { |
149 | const Node* dst = edge->dst(); |
150 | int dst_port = edge->dst_input(); |
151 | if (info.device_types[dst->id()] != DEVICE_CPU) { |
152 | if (edge->IsControlEdge()) return false; |
153 | auto dst_it = info.input_types.find({dst->id(), dst_port}); |
154 | DCHECK(dst_it != info.input_types.end()); |
155 | return dst_it->second == HOST_MEMORY; |
156 | } |
157 | return true; |
158 | } |
159 | |
160 | // Add an input to dst that comes from the "src_slot" output of the |
161 | // node named by "src_name". |
162 | void AddInput(NodeDef* dst, StringPiece src_name, int src_slot) { |
163 | if (src_slot == Graph::kControlSlot) { |
164 | dst->add_input(strings::StrCat("^" , src_name)); |
165 | } else if (src_slot == 0) { |
166 | dst->add_input(src_name.data(), src_name.size()); |
167 | } else { |
168 | dst->add_input(strings::StrCat(src_name, ":" , src_slot)); |
169 | } |
170 | } |
171 | |
172 | // Add a control edge from each input to each recv. |
173 | void AddReadControl(const std::vector<NodeDef*>& recvs, |
174 | const std::vector<string>& inputs) { |
175 | for (NodeDef* recv : recvs) { |
176 | for (const string& input : inputs) { |
177 | recv->add_input(strings::StrCat("^" , input)); |
178 | } |
179 | } |
180 | } |
181 | |
182 | void SetSendRecvAttrs(const PartitionOptions& opts, const Edge* edge, |
183 | const string& tensor_name_attr, NodeDefBuilder* builder) { |
184 | builder->Attr("tensor_name" , tensor_name_attr); |
185 | builder->Attr("send_device" , edge->src()->assigned_device_name()); |
186 | builder->Attr("send_device_incarnation" , |
187 | static_cast<int64_t>( |
188 | opts.get_incarnation(edge->src()->assigned_device_name()))); |
189 | builder->Attr("recv_device" , edge->dst()->assigned_device_name()); |
190 | builder->Attr("client_terminated" , false); |
191 | builder->Attr("_src" , edge->src()->name()); |
192 | builder->Attr("_dst" , edge->dst()->name()); |
193 | } |
194 | |
195 | NodeDef* AddSend(const PartitionOptions& opts, const GraphInfo& g_info, |
196 | GraphDef* gdef, const Edge* edge, |
197 | NodeDefBuilder::NodeOut send_from, int64_t start_time, |
198 | const string& tensor_name_attr, Status* status) { |
199 | const DataType dtype = send_from.data_type; |
200 | const DataType cast_dtype = opts.should_cast ? opts.should_cast(edge) : dtype; |
201 | const Node* src = edge->src(); |
202 | const int src_port = edge->src_output(); |
203 | |
204 | // host_memory = true iff we need to use HostSend/HostCast. |
205 | bool host_memory = false; |
206 | if (!edge->IsControlEdge()) { |
207 | auto src_it = g_info.output_types.find({src->id(), src_port}); |
208 | DCHECK(src_it != g_info.output_types.end()); |
209 | host_memory = (src_it->second == HOST_MEMORY); |
210 | } |
211 | |
212 | // Add a cast node that casts dtype to cast_dtype. |
213 | // NOTE(yuanbyu): Only cast for cross-device send/recv. |
214 | if (dtype != cast_dtype && !NeedSameDeviceSendRecv(edge, g_info)) { |
215 | const string cast_op = (host_memory) ? "_HostCast" : "Cast" ; |
216 | NodeDefBuilder cast_builder(opts.new_name(src->name()), cast_op, |
217 | NodeDebugInfo(*src)); |
218 | cast_builder.Device(src->assigned_device_name()).Input(send_from); |
219 | if (opts.scheduling_for_recvs) { |
220 | cast_builder.Attr("_start_time" , start_time); |
221 | } |
222 | cast_builder.Attr("DstT" , cast_dtype); |
223 | |
224 | if (cast_dtype == DT_BFLOAT16) { |
225 | // the below attribute specifies that the cast to bfloat16 should use |
226 | // truncation. This is needed to retain legacy behavior when we change |
227 | // the default bfloat16 casts to use rounding instead of truncation |
228 | cast_builder.Attr("Truncate" , true); |
229 | } |
230 | |
231 | NodeDef* cast = gdef->add_node(); |
232 | *status = cast_builder.Finalize(cast, /*consume=*/true); |
233 | if (!status->ok()) return nullptr; |
234 | |
235 | // Connect the Send op to the cast. |
236 | send_from.Reset(cast->name(), 0, cast_dtype); |
237 | } |
238 | |
239 | // Add the send node. |
240 | const string send_op = (host_memory) ? "_HostSend" : "_Send" ; |
241 | NodeDefBuilder send_builder(opts.new_name(src->name()), send_op, |
242 | NodeDebugInfo(*src)); |
243 | SetSendRecvAttrs(opts, edge, tensor_name_attr, &send_builder); |
244 | send_builder.Device(src->assigned_device_name()).Input(send_from); |
245 | if (opts.scheduling_for_recvs) { |
246 | send_builder.Attr("_start_time" , start_time); |
247 | } |
248 | NodeDef* send = gdef->add_node(); |
249 | *status = send_builder.Finalize(send, /*consume=*/true); |
250 | return send; |
251 | } |
252 | |
// Adds to `gdef` the receive side of a transfer for `edge`, on the device of
// the edge's destination. Returns the node that consumers of the edge should
// read from:
//  - a Cast/_HostCast node converting back to the edge dtype, when
//    `opts.should_cast` picked a different transfer dtype (only for
//    cross-device send/recv);
//  - otherwise, for a control edge, an Identity node reading the recv;
//  - otherwise the _Recv/_HostRecv node itself.
// `*real_recv` is set to the _Recv/_HostRecv node once it has been built.
// On any build failure the error is stored in `*status` and nullptr is
// returned.
NodeDef* AddRecv(const PartitionOptions& opts, const GraphInfo& g_info,
                 GraphDef* gdef, const Edge* edge, NodeDef** real_recv,
                 const string& tensor_name_attr, Status* status) {
  const DataType dtype = EdgeType(edge);
  const Node* src = edge->src();
  const Node* dst = edge->dst();
  const int dst_port = edge->dst_input();
  // Dtype actually transferred; may differ from `dtype` when casting.
  DataType cast_dtype = dtype;

  // NOTE(yuanbyu): Only cast for cross-device send/recv.
  if (opts.should_cast && !NeedSameDeviceSendRecv(edge, g_info)) {
    cast_dtype = opts.should_cast(edge);
  }

  // host_memory = true iff we need to use HostRecv/HostCast.
  // Also log the introduction of the send-recv pair, for performance debugging.
  bool host_memory = false;
  if (!edge->IsControlEdge()) {
    auto dst_it = g_info.input_types.find({dst->id(), dst_port});
    DCHECK(dst_it != g_info.input_types.end());
    host_memory = (dst_it->second == HOST_MEMORY);
    bool src_host_memory = false;
    // The source memory type is only needed for the log message, so it is
    // looked up only when verbose logging is enabled.
    if (VLOG_IS_ON(1)) {
      const int src_port = edge->src_output();
      auto src_it = g_info.output_types.find({src->id(), src_port});
      DCHECK(src_it != g_info.output_types.end());
      src_host_memory = (src_it->second == HOST_MEMORY);
    }
    VLOG(1) << "Receiving data"
            << " from " << src->name() << " (" << src->type_string() << ")"
            << " on " << src->assigned_device_name() << " in "
            << (src_host_memory ? "host memory" : "device memory") << " for "
            << dst->name() << " (" << dst->type_string() << ")"
            << " on " << dst->assigned_device_name() << " in "
            << (host_memory ? "host memory" : "device memory");
  } else {
    // Log control-edge transfers too, but don't mention memory space since it's
    // irrelevant.
    VLOG(1) << "Receiving control"
            << " from " << src->name() << " (" << src->type_string() << ")"
            << " on " << src->assigned_device_name() << " for " << dst->name()
            << " (" << dst->type_string() << ")"
            << " on " << dst->assigned_device_name();
  }

  // Add the recv node. It is typed with the transfer dtype (cast_dtype).
  const string recv_op = (host_memory) ? "_HostRecv" : "_Recv";
  NodeDefBuilder recv_builder(opts.new_name(src->name()), recv_op,
                              NodeDebugInfo(*src));
  SetSendRecvAttrs(opts, edge, tensor_name_attr, &recv_builder);
  recv_builder.Device(dst->assigned_device_name())
      .Attr("tensor_type", cast_dtype);
  NodeDef* recv = gdef->add_node();
  *status = recv_builder.Finalize(recv, /*consume=*/true);
  if (!status->ok()) return nullptr;
  *real_recv = recv;

  // Add the cast node (from cast_dtype to dtype) or an Identity node.
  if (dtype != cast_dtype) {
    const string cast_op = (host_memory) ? "_HostCast" : "Cast";
    NodeDefBuilder cast_builder(opts.new_name(src->name()), cast_op,
                                NodeDebugInfo(*src));
    cast_builder.Attr("DstT", dtype);
    cast_builder.Device(dst->assigned_device_name())
        .Input(recv->name(), 0, cast_dtype);
    NodeDef* cast = gdef->add_node();
    *status = cast_builder.Finalize(cast, /*consume=*/true);
    if (!status->ok()) return nullptr;
    return cast;
  } else if (edge->IsControlEdge()) {
    // An Identity is only needed for control edges.
    NodeDefBuilder id_builder(opts.new_name(src->name()), "Identity",
                              NodeDebugInfo(*src));
    id_builder.Device(dst->assigned_device_name())
        .Input(recv->name(), 0, cast_dtype);
    NodeDef* id = gdef->add_node();
    *status = id_builder.Finalize(id, /*consume=*/true);
    if (!status->ok()) return nullptr;
    return id;
  } else {
    return recv;
  }
}
336 | |
337 | NodeDef* AddDummyConst(const PartitionOptions& opts, GraphDef* gdef, |
338 | const Edge* edge, Status* status) { |
339 | const Node* src = edge->src(); |
340 | Tensor tensor(DT_FLOAT, TensorShape({0})); |
341 | NodeDef* result = gdef->add_node(); |
342 | *status = NodeDefBuilder(opts.new_name(src->name()), "Const" ) |
343 | .Device(src->assigned_device_name()) |
344 | .Attr("dtype" , DT_FLOAT) |
345 | .Attr("value" , tensor) |
346 | .Finalize(result, /*consume=*/true); |
347 | return result; |
348 | } |
349 | |
350 | // A dummy node for scheduling. |
351 | NodeDef* AddControlTrigger(const PartitionOptions& opts, GraphDef* gdef, |
352 | const string& assigned_device_name, int64_t epoch, |
353 | int64_t starttime, Status* status) { |
354 | NodeDef* result = gdef->add_node(); |
355 | *status = NodeDefBuilder(opts.new_name(strings::StrCat("synch_" , epoch)), |
356 | "ControlTrigger" ) |
357 | .Device(assigned_device_name) |
358 | .Attr("_start_time" , starttime) |
359 | .Finalize(result, /*consume=*/true); |
360 | return result; |
361 | } |
362 | |
363 | // Optimize colocation for control flow nodes. For cond, we want the |
364 | // switch nodes to colocate with its data input. This is particularly |
365 | // needed for conditional reading of a remote variable. It may also |
366 | // reduce the number of devices involved in a loop. |
367 | // TODO(yuanbyu): In this case, we don't respect the requested device in |
368 | // the GraphDef for these nodes. Ideally, the placer would enforce the |
369 | // colocation to render this unnecessary. |
370 | void OptimizeControlFlowColocation(Graph* graph) { |
371 | auto visit = [](Node* node) { |
372 | if (IsSwitch(node)) { |
373 | for (const Edge* in_edge : node->in_edges()) { |
374 | if (in_edge->dst_input() == 0) { |
375 | // Colocate with the data input. |
376 | node->set_assigned_device_name( |
377 | in_edge->src()->assigned_device_name()); |
378 | return; |
379 | } |
380 | } |
381 | } else if (IsExit(node)) { |
382 | for (const Edge* in_edge : node->in_edges()) { |
383 | if (!in_edge->IsControlEdge()) { |
384 | // Colocate with upstream node. |
385 | node->set_assigned_device_name( |
386 | in_edge->src()->assigned_device_name()); |
387 | return; |
388 | } |
389 | } |
390 | } else { |
391 | if ((IsEnter(node) && !IsRefType(node->input_type(0))) || |
392 | IsNextIteration(node)) { |
393 | const Edge* data_edge = nullptr; |
394 | for (const Edge* out_edge : node->out_edges()) { |
395 | if (!out_edge->IsControlEdge()) { |
396 | data_edge = out_edge; |
397 | break; |
398 | } |
399 | } |
400 | // Colocate with the first downstream data node. |
401 | if (data_edge) { |
402 | node->set_assigned_device_name( |
403 | data_edge->dst()->assigned_device_name()); |
404 | } |
405 | } |
406 | } |
407 | }; |
408 | DFS(*graph, visit, {}); |
409 | } |
410 | |
411 | string ControlLoopName(const string& name) { |
412 | return strings::StrCat("_cloop" , name); |
413 | } |
414 | |
415 | bool IsControlLoop(const Node* node) { |
416 | const string& name = node->name(); |
417 | return absl::StartsWith(name, "_cloop" ); |
418 | } |
419 | |
420 | // An enter node for control flow. |
421 | Node* AddControlEnter(Graph* g, const string& node_name, |
422 | const string& device_name, const string& frame_name, |
423 | const int parallel_iterations, Status* status) { |
424 | NodeBuilder node_builder(node_name, "Enter" , g->op_registry()); |
425 | node_builder.Input({"dummy" , 0, DT_FLOAT}); |
426 | node_builder.Attr("frame_name" , frame_name); |
427 | node_builder.Attr("parallel_iterations" , parallel_iterations); |
428 | Node* res_node; |
429 | *status = node_builder.Finalize(g, &res_node, /*consume=*/true); |
430 | if (!status->ok()) return nullptr; |
431 | res_node->set_assigned_device_name(device_name); |
432 | return res_node; |
433 | } |
434 | |
435 | // A merge node for control flow. |
436 | Node* AddControlMerge(const string& in_name1, const string& in_name2, Graph* g, |
437 | const string& node_name, const string& device_name, |
438 | Status* status) { |
439 | NodeBuilder node_builder(node_name, "Merge" , g->op_registry()); |
440 | node_builder.Input({{in_name1, 0, DT_FLOAT}, {in_name2, 0, DT_FLOAT}}); |
441 | Node* res_node; |
442 | *status = node_builder.Finalize(g, &res_node, /*consume=*/true); |
443 | if (!status->ok()) return nullptr; |
444 | res_node->set_assigned_device_name(device_name); |
445 | return res_node; |
446 | } |
447 | |
448 | // A switch node for control flow. |
449 | Node* AddControlSwitch(NodeBuilder::NodeOut input1, NodeBuilder::NodeOut input2, |
450 | const string& device_name, |
451 | const GraphDefBuilder::Options& bopts) { |
452 | Node* res_node = |
453 | ops::BinaryOp("Switch" , std::move(input1), std::move(input2), bopts); |
454 | if (bopts.HaveError()) return nullptr; |
455 | res_node->set_assigned_device_name(device_name); |
456 | return res_node; |
457 | } |
458 | |
459 | // A next_iteration node for control flow. |
460 | Node* AddControlNext(NodeBuilder::NodeOut input, const string& device_name, |
461 | const GraphDefBuilder::Options& bopts) { |
462 | Node* res_node = ops::UnaryOp("NextIteration" , std::move(input), bopts); |
463 | if (bopts.HaveError()) return nullptr; |
464 | res_node->set_assigned_device_name(device_name); |
465 | return res_node; |
466 | } |
467 | |
468 | Node* EmptyConst(const GraphDefBuilder::Options& options) { |
469 | if (options.HaveError()) return nullptr; |
470 | NodeBuilder node_builder(options.GetNameForOp("Const" ), "Const" , |
471 | options.op_registry()); |
472 | const DataType dt = DataTypeToEnum<float>::v(); |
473 | TensorProto proto; |
474 | proto.set_dtype(dt); |
475 | TensorShape empty_shape({0}); |
476 | empty_shape.AsProto(proto.mutable_tensor_shape()); |
477 | node_builder.Attr("dtype" , dt).Attr("value" , proto); |
478 | return options.FinalizeBuilder(&node_builder); |
479 | } |
480 | |
481 | // A dummy const node for control flow. |
482 | Node* AddControlConst(const string& device_name, |
483 | const GraphDefBuilder::Options& bopts) { |
484 | Node* res_node = EmptyConst(bopts); |
485 | if (bopts.HaveError()) return nullptr; |
486 | res_node->set_assigned_device_name(device_name); |
487 | return res_node; |
488 | } |
489 | |
490 | // A synthetic loop, made up of dummy nodes. It performs control-flow actions |
491 | // on behalf of a leader on a different device. |
492 | struct ControlLoop { |
493 | Node* enter = nullptr; |
494 | Node* merge = nullptr; |
495 | Node* switch_node = nullptr; |
496 | }; |
497 | |
498 | // Add the control flow info of a new node added during partitioning. |
499 | // The new node has the same control flow info as src. |
500 | void AddControlFlowInfo(const Node* node, const Node* src, |
501 | std::vector<ControlFlowInfo>* cf_info) { |
502 | int id = node->id(); |
503 | if (static_cast<size_t>(id) >= cf_info->size()) { |
504 | cf_info->resize(id + 1); |
505 | } |
506 | const ControlFlowInfo& src_info = (*cf_info)[src->id()]; |
507 | ControlFlowInfo* info = &(*cf_info)[id]; |
508 | info->frame = src_info.frame; |
509 | info->parent_frame = src_info.parent_frame; |
510 | info->frame_name = src_info.frame_name; |
511 | } |
512 | |
513 | // Constructs a control loop. Returns a struct containing the newly created |
514 | // enter, merge, and switch nodes. The enter and merge nodes are used in the |
515 | // recursive construction of control loops for nested frames (loops). The |
516 | // switch node will be connected to the LoopCond node. The merge node will |
517 | // be connected to all the recvs of the same frame by control edges when |
518 | // the actual partitioning happens. |
519 | Status AddControlLoop(const PartitionOptions& opts, Graph* g, const Node* src, |
520 | const Edge* edge, Node* loop_cond, |
521 | std::vector<ControlFlowInfo>* cf_info, |
522 | ControlLoop* loop) { |
523 | Status status; |
524 | GraphDefBuilder::Options bopts(g, &status); |
525 | const ControlFlowInfo& src_info = (*cf_info)[src->id()]; |
526 | const string& device_name = edge->dst()->assigned_device_name(); |
527 | const string& frame_name = src_info.frame_name; |
528 | int parallel_iterations; |
529 | status = GetNodeAttr(src_info.frame->attrs(), "parallel_iterations" , |
530 | ¶llel_iterations); |
531 | if (!status.ok()) return status; |
532 | |
533 | // The names of the nodes to be added. |
534 | const string& enter_name = |
535 | ControlLoopName(opts.new_name(edge->dst()->name())); |
536 | const string& merge_name = |
537 | ControlLoopName(opts.new_name(edge->dst()->name())); |
538 | const string& switch_name = |
539 | ControlLoopName(opts.new_name(edge->dst()->name())); |
540 | const string& next_name = ControlLoopName(opts.new_name(edge->dst()->name())); |
541 | |
542 | // Add the nodes to the graph g. |
543 | Node* enter = AddControlEnter(g, enter_name, device_name, frame_name, |
544 | parallel_iterations, &status); |
545 | if (!status.ok()) return status; |
546 | Node* merge = AddControlMerge(enter_name, next_name, g, merge_name, |
547 | device_name, &status); |
548 | if (!status.ok()) return status; |
549 | Node* switch_node = AddControlSwitch(merge, loop_cond, device_name, |
550 | bopts.WithName(switch_name)); |
551 | if (!status.ok()) return status; |
552 | Node* next = |
553 | AddControlNext({switch_node, 1}, device_name, bopts.WithName(next_name)); |
554 | if (!status.ok()) return status; |
555 | |
556 | // Add control flow info for these new nodes: |
557 | AddControlFlowInfo(enter, src, cf_info); |
558 | AddControlFlowInfo(merge, src, cf_info); |
559 | AddControlFlowInfo(switch_node, src, cf_info); |
560 | AddControlFlowInfo(next, src, cf_info); |
561 | |
562 | // Add input edges for the newly created merge node: |
563 | g->AddEdge(enter, 0, merge, 0); |
564 | g->AddEdge(next, 0, merge, 1); |
565 | |
566 | loop->enter = enter; |
567 | loop->merge = merge; |
568 | loop->switch_node = switch_node; |
569 | return OkStatus(); |
570 | } |
571 | |
572 | // Build memory and device type info for every node in the graph. |
573 | // TODO(yuanbyu): It might be simpler if we convert MemoryType to |
574 | // DeviceType for the inputs/outputs of each node. |
575 | Status BuildMemoryDeviceInfo(const Graph& g, GraphInfo* info) { |
576 | MemoryTypeVector input_memory_types; |
577 | MemoryTypeVector output_memory_types; |
578 | |
579 | info->device_types.resize(g.num_node_ids(), DEVICE_CPU); |
580 | for (const Node* node : g.op_nodes()) { |
581 | DeviceNameUtils::ParsedName parsed; |
582 | if (!DeviceNameUtils::ParseFullName(node->assigned_device_name(), |
583 | &parsed)) { |
584 | return errors::Internal("Malformed assigned device '" , |
585 | node->assigned_device_name(), "'" ); |
586 | } |
587 | |
588 | TF_RETURN_IF_ERROR(MemoryTypesForNode( |
589 | g.op_registry(), DeviceType(parsed.type), node->def(), |
590 | &input_memory_types, &output_memory_types)); |
591 | |
592 | int node_id = node->id(); |
593 | info->device_types[node_id] = DeviceType(parsed.type); |
594 | for (int i = 0; i < input_memory_types.size(); ++i) { |
595 | info->input_types[{node_id, i}] = input_memory_types[i]; |
596 | } |
597 | for (int i = 0; i < output_memory_types.size(); ++i) { |
598 | info->output_types[{node_id, i}] = output_memory_types[i]; |
599 | } |
600 | } |
601 | return OkStatus(); |
602 | } |
603 | |
604 | const Node* InputFrame(const Node* node, |
605 | const std::vector<ControlFlowInfo>& cf_info) { |
606 | // An input is in the same frame as the node except for Enter nodes. |
607 | // The input of Enter is in the parent frame of the Enter node. |
608 | if (!node->IsEnter()) { |
609 | return node; |
610 | } |
611 | return cf_info[node->id()].parent_frame; |
612 | } |
613 | |
614 | const Node* OutputFrame(const Node* node, |
615 | const std::vector<ControlFlowInfo>& cf_info) { |
616 | // An output is in the same frame as the node except for Exit nodes. |
617 | // The output of Exit is in the parent frame of the Exit node. |
618 | if (!node->IsExit()) { |
619 | return node; |
620 | } |
621 | return cf_info[node->id()].parent_frame; |
622 | } |
623 | |
624 | // Each participating device needs to decide a) if there is a next iteration, |
625 | // and b) if the loop terminates. We take the approach to encode this control |
626 | // flow logic in the dataflow graph. There are at least two possible encodings. |
627 | // In a completely decentralized encoding, the participants communicate peer |
628 | // to peer. The other encoding uses a frame leader (the participant who owns |
629 | // the pivot termination predicate) to broadcast the termination condition to |
630 | // all the participants. For now we take the latter because it is simpler. |
631 | // |
632 | // TODO(yuanbyu): The correctness of this construction is rather subtle. I got |
633 | // it wrong many times so it would be nice to write a proof to be sure. |
634 | Status AddControlFlow(const PartitionOptions& opts, Graph* g, |
635 | GraphInfo* g_info) { |
636 | Status status; |
637 | GraphDefBuilder::Options bopts(g, &status); |
638 | std::vector<ControlFlowInfo>& cf_info = g_info->cf_info; |
639 | |
640 | // Build the control flow info for every node. |
641 | status = BuildControlFlowInfo(g, &cf_info); |
642 | if (!status.ok()) return status; |
643 | |
644 | OptimizeControlFlowColocation(g); |
645 | |
646 | // The map from frames to their LoopCond nodes. |
647 | std::unordered_map<string, Node*> frame_cond_map; |
648 | int num_node_ids = g->num_node_ids(); |
649 | for (int i = 0; i < num_node_ids; ++i) { |
650 | Node* node = g->FindNodeId(i); |
651 | if (node == nullptr) continue; |
652 | |
653 | if (IsLoopCond(node)) { |
654 | const string& frame_name = cf_info[node->id()].frame_name; |
655 | DCHECK(!frame_name.empty()); |
656 | frame_cond_map[frame_name] = node; |
657 | } |
658 | } |
659 | |
660 | // Add all control loops for cross-device frames. |
661 | // A control loop is added only when there is a cross-device edge in a |
662 | // non-root frame. Nothing is added if there is no loops. We also don't |
663 | // add anything for a frame that is completely local to a device. For |
664 | // nested loops, we stack the control loops together by connecting |
665 | // the merge of the outer loop to the enter of the inner loop. |
666 | // |
667 | // A map from <frame_name, device_name> to ControlLoop. |
668 | std::unordered_map<string, ControlLoop> control_loops; |
669 | int num_edge_ids = g->num_edge_ids(); |
670 | for (int i = 0; i < num_edge_ids; ++i) { |
671 | const Edge* edge = g->FindEdgeId(i); |
672 | if (edge == nullptr) continue; |
673 | |
674 | const Node* src = edge->src(); |
675 | const Node* dst = edge->dst(); |
676 | // Skip Sink/Source nodes. |
677 | if (!src->IsOp() || !dst->IsOp()) continue; |
678 | |
679 | const string& src_device = src->assigned_device_name(); |
680 | const string& dst_device = dst->assigned_device_name(); |
681 | // Skip local edges. |
682 | if (src_device == dst_device) continue; |
683 | |
684 | const Node* src_frame = OutputFrame(src, cf_info); |
685 | const Node* dst_frame = InputFrame(dst, cf_info); |
686 | const string& src_frame_name = cf_info[src_frame->id()].frame_name; |
687 | const string& dst_frame_name = cf_info[dst_frame->id()].frame_name; |
688 | // Skip if src and dst are not in the same frame. |
689 | if (src_frame_name.empty() || src_frame_name != dst_frame_name) { |
690 | continue; |
691 | } |
692 | |
693 | // Add the control loop. Start by adding the control loop for the |
694 | // current frame if needed, and recursively adding the control loop |
695 | // for its outer frame when nested. |
696 | ControlLoop child_loop; |
697 | while (true) { |
698 | const string& curr_frame_name = cf_info[src_frame->id()].frame_name; |
699 | if (curr_frame_name.empty()) { |
700 | // We have reached the root frame. |
701 | if (child_loop.merge != nullptr) { |
702 | const string& node_name = opts.new_name(edge->dst()->name()); |
703 | const string& device_name = edge->dst()->assigned_device_name(); |
704 | Node* const_node = |
705 | AddControlConst(device_name, bopts.WithName(node_name)); |
706 | if (!status.ok()) return status; |
707 | AddControlFlowInfo(const_node, src_frame, &cf_info); |
708 | g->AddEdge(const_node, 0, child_loop.enter, 0); |
709 | } |
710 | break; |
711 | } |
712 | |
713 | const string& cl_key = strings::StrCat(curr_frame_name, "$$" , dst_device); |
714 | auto it = control_loops.find(cl_key); |
715 | if (it != control_loops.end()) { |
716 | if (child_loop.enter != nullptr) { |
717 | g->AddEdge(it->second.merge, 0, child_loop.enter, 0); |
718 | } |
719 | break; |
720 | } |
721 | |
722 | // Get the frame's LoopCond. |
723 | auto cond_it = frame_cond_map.find(curr_frame_name); |
724 | if (cond_it == frame_cond_map.end()) { |
725 | return errors::InvalidArgument( |
726 | "A cross-device loop must have a pivot predicate: " , |
727 | curr_frame_name); |
728 | } |
729 | Node* loop_cond = cond_it->second; |
730 | |
731 | // Add the control loop. |
732 | ControlLoop curr_loop; |
733 | status = AddControlLoop(opts, g, src_frame, edge, loop_cond, &cf_info, |
734 | &curr_loop); |
735 | if (!status.ok()) return status; |
736 | control_loops[cl_key] = curr_loop; |
737 | |
738 | if (child_loop.enter != nullptr) { |
739 | // Connect the merge of the outer loop to the enter of the inner. |
740 | g->AddEdge(curr_loop.merge, 0, child_loop.enter, 0); |
741 | } |
742 | src_frame = cf_info[src_frame->id()].parent_frame; |
743 | child_loop = curr_loop; |
744 | } |
745 | } |
746 | |
747 | // For a cross-device edge, on the dst device, add a control edge |
748 | // from the merge node of the control loop to dst. If a send/recv is |
749 | // introduced for this edge in future partitioning, we delete this |
750 | // control edge and add a new control edge from the merge to the recv. |
751 | num_edge_ids = g->num_edge_ids(); |
752 | for (int i = 0; i < num_edge_ids; ++i) { |
753 | const Edge* edge = g->FindEdgeId(i); |
754 | if (edge == nullptr) continue; |
755 | |
756 | const Node* src = edge->src(); |
757 | Node* dst = edge->dst(); |
758 | // Skip Sink/Source nodes. |
759 | if (!src->IsOp() || !dst->IsOp()) continue; |
760 | |
761 | const string& src_device = src->assigned_device_name(); |
762 | const string& dst_device = dst->assigned_device_name(); |
763 | if (src_device != dst_device) { |
764 | const Node* src_frame = OutputFrame(src, cf_info); |
765 | const Node* dst_frame = InputFrame(dst, cf_info); |
766 | const string& src_frame_name = cf_info[src_frame->id()].frame_name; |
767 | const string& dst_frame_name = cf_info[dst_frame->id()].frame_name; |
768 | if (!src_frame_name.empty() && src_frame_name == dst_frame_name) { |
769 | const string& cl_key = |
770 | strings::StrCat(dst_frame_name, "$$" , dst_device); |
771 | ControlLoop loop = control_loops[cl_key]; |
772 | DCHECK(loop.enter != nullptr); |
773 | // Note that we'll create multiple duplicate edges if dst has multiple |
774 | // cross-device inputs. This is expected by the logic in Partition(), so |
775 | // it can add control edges to the recv nodes once they're created. |
776 | g->AddControlEdge(loop.merge, dst, /*allow_duplicates=*/true); |
777 | } |
778 | } |
779 | } |
780 | return OkStatus(); |
781 | } |
782 | |
783 | struct PriorityTopoSortNode { |
784 | PriorityTopoSortNode(const NodeDef* n, int64_t st) |
785 | : node(n), start_time(st) {} |
786 | |
787 | const NodeDef* node; |
788 | int64_t start_time; |
789 | }; |
790 | |
791 | struct PriorityTopoSortNodeGreater { |
792 | bool operator()(const PriorityTopoSortNode& left, |
793 | const PriorityTopoSortNode& right) { |
794 | return left.start_time > right.start_time; |
795 | } |
796 | }; |
797 | |
798 | } // namespace |
799 | |
800 | // Returns in <nodes> the nodes that should participate in epoch-based recv |
801 | // scheduling, along with their times; <nodes> is ordered by increasing |
802 | // start_time. Returns in <node_to_start_time_out> the timing for all nodes, |
803 | // even those not in <nodes>. |
804 | // |
805 | // Comparing to sorting on the node's start time only, this also processes the |
806 | // nodes in dependency order, and updates start times to ensure a node's |
807 | // start_time > the start time for all dependencies. |
808 | // |
809 | // Note that graph_partition_test.cc accesses this function for testing, even |
810 | // though it's not declared in the header. |
811 | Status TopologicalSortNodesWithTimePriority( |
812 | const GraphDef* gdef, |
813 | std::vector<std::pair<const NodeDef*, int64_t>>* nodes, |
814 | std::unordered_map<const NodeDef*, int64_t>* node_to_start_time_out) { |
815 | // Queue of nodes to process; lowest start time is returned first. |
816 | std::priority_queue<PriorityTopoSortNode, std::vector<PriorityTopoSortNode>, |
817 | PriorityTopoSortNodeGreater> |
818 | q; |
819 | std::unordered_map<const NodeDef*, int64_t> node_to_start_time; |
820 | auto enqueue = [&q, &node_to_start_time](const NodeDef* node) { |
821 | const int64_t start_time = node_to_start_time[node]; |
822 | q.emplace(node, start_time); |
823 | }; |
824 | |
825 | // Build initial structures, initial contents of queue. |
826 | std::unordered_map<string, std::vector<const NodeDef*>> node_to_output_nodes; |
827 | std::unordered_map<const NodeDef*, int> inputs_needed; |
828 | for (int n = 0; n < gdef->node_size(); ++n) { |
829 | const NodeDef* ndef = &gdef->node(n); |
830 | for (int i = 0; i < ndef->input_size(); ++i) { |
831 | node_to_output_nodes[string(ParseTensorName(ndef->input(i)).first)] |
832 | .push_back(ndef); |
833 | } |
834 | int64_t start_time; |
835 | TF_RETURN_IF_ERROR(GetNodeAttr(*ndef, "_start_time" , &start_time)); |
836 | node_to_start_time[ndef] = start_time; |
837 | inputs_needed[ndef] = ndef->input_size(); |
838 | if (ndef->input_size() == 0) { |
839 | enqueue(ndef); |
840 | } |
841 | } |
842 | |
843 | // Determine which merge nodes are parts of loops; these |
844 | // need to happen in the traversal after all non-NextIteration inputs |
845 | // are run. |
846 | for (int n = 0; n < gdef->node_size(); ++n) { |
847 | const NodeDef* ndef = &gdef->node(n); |
848 | if (IsNextIteration(*ndef)) { |
849 | for (const NodeDef* n : node_to_output_nodes[ndef->name()]) { |
850 | if (IsMerge(*n)) { |
851 | // n is a merge that is part of a loop structure. |
852 | // It doesn't need to wait for this NextIteration loop |
853 | // when doing the traversal. |
854 | --inputs_needed[n]; |
855 | } |
856 | } |
857 | } |
858 | } |
859 | |
860 | // Traverse. |
861 | std::vector<std::pair<const NodeDef*, int64_t>> start_times; |
862 | start_times.reserve(gdef->node_size()); |
863 | while (!q.empty()) { |
864 | PriorityTopoSortNode cur = q.top(); |
865 | q.pop(); |
866 | |
867 | start_times.emplace_back(cur.node, cur.start_time); |
868 | |
869 | for (const NodeDef* n : node_to_output_nodes[cur.node->name()]) { |
870 | auto& output_start_time = node_to_start_time[n]; |
871 | if (output_start_time <= cur.start_time) { |
872 | output_start_time = cur.start_time + 1; |
873 | } |
874 | if (--inputs_needed[n] == 0) { |
875 | enqueue(n); |
876 | } |
877 | } |
878 | } |
879 | |
880 | // Done. |
881 | nodes->swap(start_times); |
882 | node_to_start_time_out->swap(node_to_start_time); |
883 | return OkStatus(); |
884 | } |
885 | |
886 | Status AddControlEdges(const PartitionOptions& opts, |
887 | std::unordered_map<string, GraphDef>* partitions) { |
888 | Status status; |
889 | // TODO(yuanbyu): Very naive for now. To be improved. |
890 | const int num_epochs = 100; |
891 | const int prefetch = 6; |
892 | |
893 | for (auto& part : *partitions) { |
894 | GraphDef* gdef = &part.second; |
895 | std::vector<std::pair<const NodeDef*, int64_t>> start_times; |
896 | std::unordered_map<const NodeDef*, int64_t> node_to_start_time; |
897 | status = TopologicalSortNodesWithTimePriority(gdef, &start_times, |
898 | &node_to_start_time); |
899 | if (!status.ok()) { |
900 | return status; |
901 | } |
902 | |
903 | // Add a dummy node for every epoch, and add a control edge from the |
904 | // "last" node in the preceding epoch to the dummy node. |
905 | string device_name = gdef->node(0).device(); |
906 | int64_t makespan = start_times.back().second; |
907 | int64_t resolution = (makespan / num_epochs) + 1; |
908 | |
909 | int i = 0; |
910 | int j = 0; |
911 | std::vector<NodeDef*> dummys; |
912 | while (i < num_epochs && static_cast<size_t>(j) < start_times.size()) { |
913 | if (i * resolution > start_times[j].second) { |
914 | j++; |
915 | } else { |
916 | NodeDef* dummy = AddControlTrigger(opts, gdef, device_name, i, |
917 | i * resolution, &status); |
918 | if (!status.ok()) { |
919 | return status; |
920 | } |
921 | dummys.push_back(dummy); |
922 | if (j > 0) { |
923 | string src_name = start_times[j - 1].first->name(); |
924 | AddInput(dummy, src_name, Graph::kControlSlot); |
925 | } |
926 | i++; |
927 | } |
928 | } |
929 | |
930 | // Finally, add the control edges to recvs. |
931 | for (int n = 0; n < gdef->node_size(); ++n) { |
932 | NodeDef* ndef = gdef->mutable_node(n); |
933 | if (ndef->op() == "_Recv" ) { |
934 | const int64_t start_time = node_to_start_time[ndef]; |
935 | const int recv_epoch = start_time / resolution; |
936 | if (recv_epoch >= prefetch) { |
937 | NodeDef* dummy = dummys[recv_epoch - prefetch]; |
938 | AddInput(ndef, dummy->name(), Graph::kControlSlot); |
939 | } |
940 | } |
941 | } |
942 | } |
943 | return OkStatus(); |
944 | } |
945 | |
946 | // If 'ndef' is a Send or Recv, fills its attr send_device_incarnation |
947 | // if possible. |
948 | void SetIncarnation(const PartitionOptions& opts, NodeDef* ndef) { |
949 | StringPiece op(ndef->op()); |
950 | if (op != "_Send" && op != "_Recv" ) { |
951 | // Not related to send/recv. |
952 | return; |
953 | } |
954 | const string& send_device = GetNodeAttrString(*ndef, "send_device" ); |
955 | if (send_device.empty()) { |
956 | // No known send_device. The runtime will detect it later. |
957 | return; |
958 | } |
959 | int64_t incarnation = PartitionOptions::kIllegalIncarnation; |
960 | if (!TryGetNodeAttr(*ndef, "send_device_incarnation" , &incarnation) || |
961 | (incarnation == PartitionOptions::kIllegalIncarnation)) { |
962 | incarnation = opts.get_incarnation(send_device); |
963 | SetAttrValue(incarnation, |
964 | &((*ndef->mutable_attr())["send_device_incarnation" ])); |
965 | } |
966 | } |
967 | |
968 | // Sets attribute send_device_incarnation of all Send/Recv nodes in |
969 | // 'gdef', if possible. |
970 | void SetIncarnation(const PartitionOptions& opts, GraphDef* gdef) { |
971 | for (NodeDef& ndef : *gdef->mutable_node()) { |
972 | SetIncarnation(opts, &ndef); |
973 | } |
974 | for (FunctionDef& fdef : *gdef->mutable_library()->mutable_function()) { |
975 | for (NodeDef& ndef : *fdef.mutable_node_def()) { |
976 | SetIncarnation(opts, &ndef); |
977 | } |
978 | } |
979 | } |
980 | |
981 | Status Partition(const PartitionOptions& opts, Graph* g, |
982 | std::unordered_map<string, GraphDef>* partitions) { |
983 | Status status; |
984 | partitions->clear(); |
985 | |
986 | GraphInfo g_info; |
987 | if (!opts.control_flow_added) { |
988 | // Add the "code" for distributed execution of control flow. Code is |
989 | // added only for the frames that are placed on multiple devices. The |
990 | // new graph is an equivalent transformation of the original graph and |
991 | // has the property that it can be subsequently partitioned arbitrarily |
992 | // (down to the level of individual device) for distributed execution. |
993 | status = AddControlFlow(opts, g, &g_info); |
994 | if (!status.ok()) return status; |
995 | } |
996 | |
997 | // At this point, all the graph mutations have been done. Build memory |
998 | // and device type info for every node and edge in the graph. |
999 | status = BuildMemoryDeviceInfo(*g, &g_info); |
1000 | if (!status.ok()) return status; |
1001 | |
1002 | string dstp; |
1003 | std::vector<const Edge*> inputs; |
1004 | DupRecvTable dup_recv(3); |
1005 | // For a node dst, 'ref_recvs' remembers the recvs introduced by a ref |
1006 | // edge to dst. 'ref_control_inputs' remembers the inputs by a non-ref |
1007 | // edge to dst. We will add a control edge for every pair in |
1008 | // (ref_recvs x ref_control_inputs). |
1009 | std::vector<NodeDef*> ref_recvs; |
1010 | std::vector<string> ref_control_inputs; |
1011 | |
1012 | int32_t num_data = 0; |
1013 | int32_t num_control = 0; |
1014 | for (const Node* dst : g->op_nodes()) { |
1015 | dstp = opts.node_to_loc(dst); |
1016 | GraphDef* dst_graph = &(*partitions)[dstp]; |
1017 | NodeDef* dst_def = dst_graph->add_node(); |
1018 | *dst_def = dst->def(); |
1019 | MergeDebugInfo(NodeDebugInfo(dst->def()), dst_def); |
1020 | dst_def->set_device(dst->assigned_device_name()); |
1021 | dst_def->clear_input(); // Inputs are filled below |
1022 | if (opts.need_to_record_start_times) { |
1023 | int64_t start_time; |
1024 | status = GetNodeAttr(*dst_def, "_start_time" , &start_time); |
1025 | if (errors::IsNotFound(status)) { |
1026 | start_time = opts.start_times[dst->id()].value(); |
1027 | AddNodeAttr("_start_time" , start_time, dst_def); |
1028 | } else if (!status.ok()) { |
1029 | return status; |
1030 | } |
1031 | } |
1032 | |
1033 | // Arrange the incoming edges to dst so that input[i] holds the |
1034 | // input flowing into slot numbered i. Trailing entries in input[] |
1035 | // hold control edges. |
1036 | inputs.clear(); |
1037 | inputs.resize(dst->num_inputs(), nullptr); |
1038 | ref_recvs.clear(); |
1039 | ref_control_inputs.clear(); |
1040 | const Edge* control_flow_edge = nullptr; |
1041 | int32_t num_control_flow_edges = 0; |
1042 | int32_t num_input_edges = 0; |
1043 | for (const Edge* edge : dst->in_edges()) { |
1044 | if (edge->IsControlEdge()) { |
1045 | if (IsMerge(edge->src()) && IsControlLoop(edge->src())) { |
1046 | // This is one of the control edges added for control flow. There |
1047 | // can be multiple such edges as the dest node may have multiple |
1048 | // remote inputs. We keep track of the number of such edges. |
1049 | control_flow_edge = edge; |
1050 | ++num_control_flow_edges; |
1051 | } else { |
1052 | inputs.push_back(edge); |
1053 | } |
1054 | } else { |
1055 | DCHECK(inputs[edge->dst_input()] == nullptr); |
1056 | inputs[edge->dst_input()] = edge; |
1057 | ++num_input_edges; |
1058 | } |
1059 | } |
1060 | |
1061 | if (num_input_edges != dst->num_inputs()) { |
1062 | return errors::InvalidArgument("Incomplete graph, missing " , |
1063 | (dst->num_inputs() - num_input_edges), |
1064 | " inputs for " , dst->name()); |
1065 | } |
1066 | |
1067 | // Process in order so that all data edges are added as inputs to |
1068 | // dst in Edge::dst_input() order. |
1069 | for (const Edge* edge : inputs) { |
1070 | const Node* src = edge->src(); |
1071 | if (!src->IsOp()) continue; // Skip Sink/Source nodes. |
1072 | |
1073 | GraphDef* src_graph = &(*partitions)[opts.node_to_loc(src)]; |
1074 | if (src_graph == dst_graph && !NeedSameDeviceSendRecv(edge, g_info)) { |
1075 | // Same partition and compatible memory types: |
1076 | AddInput(dst_def, src->name(), edge->src_output()); |
1077 | if (edge->IsControlEdge() || |
1078 | !IsRefType(src->output_type(edge->src_output()))) { |
1079 | ref_control_inputs.push_back(src->name()); |
1080 | } |
1081 | continue; |
1082 | } |
1083 | |
1084 | int64_t send_start_time = 0; |
1085 | int64_t recv_start_time = 0; |
1086 | if (opts.scheduling_for_recvs) { |
1087 | status = GetNodeAttr(src->attrs(), "_start_time" , &send_start_time); |
1088 | if (errors::IsNotFound(status) && opts.need_to_record_start_times) { |
1089 | send_start_time = opts.start_times[src->id()].value(); |
1090 | } else if (!status.ok()) { |
1091 | return status; |
1092 | } |
1093 | |
1094 | status = GetNodeAttr(dst->attrs(), "_start_time" , &recv_start_time); |
1095 | if (errors::IsNotFound(status) && opts.need_to_record_start_times) { |
1096 | recv_start_time = opts.start_times[dst->id()].value(); |
1097 | } else if (!status.ok()) { |
1098 | return status; |
1099 | } |
1100 | } |
1101 | |
1102 | // Check whether there is already a send/recv pair transferring |
1103 | // the same tensor/control from the src to dst partition. |
1104 | const bool on_host = IsDstInputOnHost(edge, g_info); |
1105 | DupRecvKey key{src->id(), edge->src_output(), dst_graph, on_host}; |
1106 | auto iter = dup_recv.find(key); |
1107 | if (iter != dup_recv.end()) { |
1108 | // We found one. Reuse the data/control transferred already. |
1109 | const string& recv_node_name = iter->second.recv->name(); |
1110 | if (edge->IsControlEdge()) { |
1111 | AddInput(dst_def, recv_node_name, Graph::kControlSlot); |
1112 | } else { |
1113 | AddInput(dst_def, recv_node_name, 0); |
1114 | } |
1115 | ref_control_inputs.push_back(recv_node_name); |
1116 | |
1117 | // We want the start_time for the recv to be the smallest of the start |
1118 | // times of it's consumers. So we update this whenever we use a recv, |
1119 | // and write it out to the attribute at the end of the subroutine |
1120 | if (iter->second.start_time > recv_start_time) { |
1121 | iter->second.start_time = recv_start_time; |
1122 | } |
1123 | continue; |
1124 | } |
1125 | |
1126 | NodeDefBuilder::NodeOut send_from; |
1127 | if (edge->IsControlEdge()) { |
1128 | // Insert a dummy const node that will generate a tiny |
1129 | // data element to be sent from send to recv. |
1130 | VLOG(1) << "Send/Recv control: " << src->assigned_device_name() << "[" |
1131 | << src->name() << "] -> " << dst->assigned_device_name() << "[" |
1132 | << dst->name() << "]" ; |
1133 | NodeDef* dummy = AddDummyConst(opts, src_graph, edge, &status); |
1134 | if (!status.ok()) return status; |
1135 | // Set the start time for this dummy node. |
1136 | if (opts.scheduling_for_recvs) { |
1137 | AddNodeAttr("_start_time" , send_start_time, dummy); |
1138 | } |
1139 | AddInput(dummy, src->name(), Graph::kControlSlot); |
1140 | send_from.Reset(dummy->name(), 0, DT_FLOAT); |
1141 | } else { |
1142 | send_from.Reset(src->name(), edge->src_output(), EdgeType(edge)); |
1143 | } |
1144 | |
1145 | string tensor_name_attr; |
1146 | if (opts.get_tensor_name_attr) { |
1147 | tensor_name_attr = opts.get_tensor_name_attr(edge); |
1148 | } else { |
1149 | tensor_name_attr = |
1150 | strings::StrCat("edge_" , edge->id(), "_" , edge->src()->name()); |
1151 | } |
1152 | |
1153 | // Need to split edge by placing matching send/recv nodes on |
1154 | // the src/dst sides of the edge. |
1155 | NodeDef* send = AddSend(opts, g_info, src_graph, edge, send_from, |
1156 | send_start_time, tensor_name_attr, &status); |
1157 | if (!status.ok()) return status; |
1158 | |
1159 | NodeDef* real_recv = nullptr; |
1160 | NodeDef* recv = AddRecv(opts, g_info, dst_graph, edge, &real_recv, |
1161 | tensor_name_attr, &status); |
1162 | if (!status.ok()) return status; |
1163 | |
1164 | // Fix up the control flow edge. |
1165 | // NOTE(yuanbyu): 'real_recv' must be the real recv node. |
1166 | if (src_graph == dst_graph) { |
1167 | // For same device send/recv, add a control edge from send to recv. |
1168 | // This prevents the asynchronous recv kernel from being scheduled |
1169 | // before the data is available. |
1170 | AddInput(real_recv, send->name(), Graph::kControlSlot); |
1171 | } else if (control_flow_edge != nullptr) { |
1172 | // Redirect control edge to the real recv since this is not the same |
1173 | // device send/recv. |
1174 | --num_control_flow_edges; |
1175 | AddInput(real_recv, control_flow_edge->src()->name(), |
1176 | Graph::kControlSlot); |
1177 | } |
1178 | |
1179 | if (!edge->IsControlEdge() && |
1180 | IsRefType(src->output_type(edge->src_output()))) { |
1181 | AddNodeAttr("_start_time" , recv_start_time, recv); |
1182 | if (real_recv != recv) { |
1183 | AddNodeAttr("_start_time" , recv_start_time, real_recv); |
1184 | } |
1185 | // If src is of ref type and the edge is not a control edge, dst has |
1186 | // read semantics and therefore we must control the recv. |
1187 | ref_recvs.push_back(real_recv); |
1188 | } else { |
1189 | // Memorize the send/recv pair, only if this is not a "ref" edge. |
1190 | // NOTE(yuanbyu): Collapsing ref edges requires extreme care so |
1191 | // for now we don't do it. |
1192 | dup_recv[key] = {recv, real_recv, recv_start_time}; |
1193 | ref_control_inputs.push_back(recv->name()); |
1194 | } |
1195 | |
1196 | if (edge->IsControlEdge()) { |
1197 | ++num_control; |
1198 | AddInput(dst_def, recv->name(), Graph::kControlSlot); |
1199 | } else { |
1200 | ++num_data; |
1201 | AddInput(dst_def, recv->name(), 0); |
1202 | } |
1203 | } |
1204 | |
1205 | // Add control edges from 'ref_control_inputs' to 'ref_recvs'. |
1206 | // NOTE(yuanbyu): Adding these control edges should not introduce |
1207 | // deadlocks. 'dst' has implicit "read" nodes that, when we split |
1208 | // across devices, are made explicit; Retargeting the dependencies |
1209 | // to 'dst' to those nodes would not introduce cycles if there isn't |
1210 | // one before the transformation. |
1211 | // NOTE(yuanbyu): This may impact performance because it defers the |
1212 | // execution of recvs until all the other inputs become available. |
1213 | AddReadControl(ref_recvs, ref_control_inputs); |
1214 | |
1215 | // Add back the control edges for control flow that are not used. |
1216 | if (control_flow_edge != nullptr) { |
1217 | for (int i = 0; i < num_control_flow_edges; ++i) { |
1218 | AddInput(dst_def, control_flow_edge->src()->name(), |
1219 | Graph::kControlSlot); |
1220 | } |
1221 | } |
1222 | } |
1223 | |
1224 | const FunctionLibraryDefinition* flib_def = opts.flib_def; |
1225 | if (flib_def == nullptr) { |
1226 | flib_def = &g->flib_def(); |
1227 | } |
1228 | |
1229 | // Set versions, function library and send/recv incarnation. |
1230 | for (auto& it : *partitions) { |
1231 | GraphDef* gdef = &it.second; |
1232 | *gdef->mutable_versions() = g->versions(); |
1233 | // Prune unreachable functions from `flib_def` before adding them to `gdef`. |
1234 | *gdef->mutable_library() = flib_def->ReachableDefinitions(*gdef).ToProto(); |
1235 | |
1236 | // Traverse the graph to fill every send/recv op's incarnation |
1237 | // information. |
1238 | SetIncarnation(opts, gdef); |
1239 | } |
1240 | |
1241 | // Set the start times for recvs at the very end. |
1242 | if (opts.scheduling_for_recvs) { |
1243 | for (auto& it : dup_recv) { |
1244 | AddNodeAttr("_start_time" , it.second.start_time, it.second.recv); |
1245 | if (it.second.real_recv != it.second.recv) { |
1246 | AddNodeAttr("_start_time" , it.second.start_time, it.second.real_recv); |
1247 | } |
1248 | } |
1249 | } |
1250 | |
1251 | VLOG(1) << "Added send/recv: controls=" << num_control |
1252 | << ", data=" << num_data; |
1253 | if (VLOG_IS_ON(2)) { |
1254 | for (auto& it : *partitions) { |
1255 | GraphDef* gdef = &it.second; |
1256 | DumpGraphDefToFile(strings::StrCat("partition_" , it.first, "_" , |
1257 | reinterpret_cast<uintptr_t>(gdef)), |
1258 | *gdef); |
1259 | } |
1260 | } |
1261 | return OkStatus(); |
1262 | } |
1263 | |
1264 | } // namespace tensorflow |
1265 | |