/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/core/graph/graph_partition.h"

#include <deque>
#include <queue>
#include <unordered_map>
#include <unordered_set>
#include <utility>
#include <vector>

#include "absl/container/flat_hash_map.h"
#include "tensorflow/core/framework/function.h"
#include "tensorflow/core/framework/memory_types.h"
#include "tensorflow/core/framework/node_def_builder.h"
#include "tensorflow/core/framework/tensor.pb.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/framework/versions.pb.h"
#include "tensorflow/core/graph/algorithm.h"
#include "tensorflow/core/graph/control_flow.h"
#include "tensorflow/core/graph/costmodel.h"
#include "tensorflow/core/graph/graph_def_builder.h"
#include "tensorflow/core/graph/node_builder.h"
#include "tensorflow/core/graph/tensor_id.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/hash/hash.h"
#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/util/device_name_utils.h"
#include "tensorflow/core/util/dump_graph.h"

namespace tensorflow {

namespace {

inline bool IsMerge(const NodeDef& node_def) {
  return node_def.op() == "Merge" || node_def.op() == "RefMerge" ||
         node_def.op() == "_XlaMerge";
}

inline bool IsNextIteration(const NodeDef& node_def) {
  return node_def.op() == "NextIteration" ||
         node_def.op() == "RefNextIteration";
}

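// Key for deduplicating the Send/Recv pairs created during partitioning: a
// tensor (or control signal) that has already been transferred from the same
// source output to the same destination partition, with the same host/device
// placement of the recv output, is reused rather than sent again.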
struct DupRecvKey {
  int src_node_id;           // Edge's src node id
  int src_output_slot;       // Edge's src node output slot
  GraphDef* dst_graph;       // Edge's dst node is in this subgraph
  bool recv_output_on_host;  // The output of recv is on host

  template <typename H>
  friend H AbslHashValue(H h, const DupRecvKey& c) {
    return H::combine(std::move(h), c.src_node_id, c.src_output_slot,
                      reinterpret_cast<std::uintptr_t>(c.dst_graph),
                      c.recv_output_on_host);
  }

  friend bool operator==(const DupRecvKey& x, const DupRecvKey& y) {
    return (x.src_node_id == y.src_node_id) &&
           (x.src_output_slot == y.src_output_slot) &&
           (x.dst_graph == y.dst_graph) &&
           (x.recv_output_on_host == y.recv_output_on_host);
  }
};

// Struct used to store the recvs, so that start times can be properly updated.
struct RecvInfo {
  NodeDef* recv;
  NodeDef* real_recv;
  int64_t start_time;
};

typedef absl::flat_hash_map<DupRecvKey, RecvInfo> DupRecvTable;

// A map used to store memory types for the inputs/outputs of every node.
// The key is a pair of ints consisting of a node id and input/output index.
// TODO(power): migrate back to std::pair when absl::Hash is fixed for MSVC.
struct NodePort {
  int node_id;
  int index;

  friend bool operator==(const NodePort& x, const NodePort& y) {
    return x.node_id == y.node_id && x.index == y.index;
  }

  template <typename H>
  friend H AbslHashValue(H h, const NodePort& c) {
    return H::combine(std::move(h), c.node_id, c.index);
  }
};

typedef absl::flat_hash_map<NodePort, MemoryType> MemoryTypeMap;

// We collect the following information about the graph before performing
// graph partitioning.
struct GraphInfo {
  std::vector<DeviceType> device_types;
  MemoryTypeMap input_types;
  MemoryTypeMap output_types;
  std::vector<ControlFlowInfo> cf_info;
};

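// Returns the type of the tensor that will flow on 'e'. Control edges are
// materialized as 0-element DT_FLOAT tensors (see AddDummyConst below), so
// they report DT_FLOAT here.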
DataType EdgeType(const Edge* e) {
  if (e->IsControlEdge()) {
    return DT_FLOAT;
  } else {
    return e->dst()->input_type(e->dst_input());
  }
}

// Return true iff we need to add a same-device send/recv for 'edge'.
bool NeedSameDeviceSendRecv(const Edge* edge, const GraphInfo& info) {
  if (edge->IsControlEdge()) {
    return false;
  }

  const Node* src = edge->src();
  const Node* dst = edge->dst();
  if (src->assigned_device_name() == dst->assigned_device_name()) {
    int src_port = edge->src_output();
    int dst_port = edge->dst_input();
    if (info.device_types[src->id()] != DEVICE_CPU) {
      auto src_it = info.output_types.find({src->id(), src_port});
      DCHECK(src_it != info.output_types.end());
      auto dst_it = info.input_types.find({dst->id(), dst_port});
      DCHECK(dst_it != info.input_types.end());
      return src_it->second != dst_it->second;
    }
  }
  return false;
}

// Return true iff (dst, dst_input) is specified on host memory.
bool IsDstInputOnHost(const Edge* edge, const GraphInfo& info) {
  const Node* dst = edge->dst();
  int dst_port = edge->dst_input();
  if (info.device_types[dst->id()] != DEVICE_CPU) {
    if (edge->IsControlEdge()) return false;
    auto dst_it = info.input_types.find({dst->id(), dst_port});
    DCHECK(dst_it != info.input_types.end());
    return dst_it->second == HOST_MEMORY;
  }
  return true;
}

// Add an input to dst that comes from the "src_slot" output of the
// node named by "src_name".
void AddInput(NodeDef* dst, StringPiece src_name, int src_slot) {
  if (src_slot == Graph::kControlSlot) {
    dst->add_input(strings::StrCat("^", src_name));
  } else if (src_slot == 0) {
    dst->add_input(src_name.data(), src_name.size());
  } else {
    dst->add_input(strings::StrCat(src_name, ":", src_slot));
  }
}

// Add a control edge from each input to each recv.
void AddReadControl(const std::vector<NodeDef*>& recvs,
                    const std::vector<string>& inputs) {
  for (NodeDef* recv : recvs) {
    for (const string& input : inputs) {
      recv->add_input(strings::StrCat("^", input));
    }
  }
}

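// Sets the attributes shared by the matching Send and Recv nodes for 'edge'.
// The "_src" and "_dst" attrs record the names of the original endpoints,
// which is useful when debugging the partitioned graphs.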
void SetSendRecvAttrs(const PartitionOptions& opts, const Edge* edge,
                      const string& tensor_name_attr, NodeDefBuilder* builder) {
  builder->Attr("tensor_name", tensor_name_attr);
  builder->Attr("send_device", edge->src()->assigned_device_name());
  builder->Attr("send_device_incarnation",
                static_cast<int64_t>(
                    opts.get_incarnation(edge->src()->assigned_device_name())));
  builder->Attr("recv_device", edge->dst()->assigned_device_name());
  builder->Attr("client_terminated", false);
  builder->Attr("_src", edge->src()->name());
  builder->Attr("_dst", edge->dst()->name());
}

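// Adds a _Send (or _HostSend) node for 'edge' to 'gdef', placed on the source
// node's device. If opts.should_cast requests a different on-the-wire type, a
// Cast (or _HostCast) node is inserted between the producer and the send.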
NodeDef* AddSend(const PartitionOptions& opts, const GraphInfo& g_info,
                 GraphDef* gdef, const Edge* edge,
                 NodeDefBuilder::NodeOut send_from, int64_t start_time,
                 const string& tensor_name_attr, Status* status) {
  const DataType dtype = send_from.data_type;
  const DataType cast_dtype = opts.should_cast ? opts.should_cast(edge) : dtype;
  const Node* src = edge->src();
  const int src_port = edge->src_output();

  // host_memory = true iff we need to use HostSend/HostCast.
  bool host_memory = false;
  if (!edge->IsControlEdge()) {
    auto src_it = g_info.output_types.find({src->id(), src_port});
    DCHECK(src_it != g_info.output_types.end());
    host_memory = (src_it->second == HOST_MEMORY);
  }

  // Add a cast node that casts dtype to cast_dtype.
  // NOTE(yuanbyu): Only cast for cross-device send/recv.
  if (dtype != cast_dtype && !NeedSameDeviceSendRecv(edge, g_info)) {
    const string cast_op = (host_memory) ? "_HostCast" : "Cast";
    NodeDefBuilder cast_builder(opts.new_name(src->name()), cast_op,
                                NodeDebugInfo(*src));
    cast_builder.Device(src->assigned_device_name()).Input(send_from);
    if (opts.scheduling_for_recvs) {
      cast_builder.Attr("_start_time", start_time);
    }
    cast_builder.Attr("DstT", cast_dtype);

    if (cast_dtype == DT_BFLOAT16) {
      // The attribute below specifies that the cast to bfloat16 should use
      // truncation. This is needed to retain legacy behavior while the
      // default bfloat16 cast changes from truncation to rounding.
      cast_builder.Attr("Truncate", true);
    }

    NodeDef* cast = gdef->add_node();
    *status = cast_builder.Finalize(cast, /*consume=*/true);
    if (!status->ok()) return nullptr;

    // Connect the Send op to the cast.
    send_from.Reset(cast->name(), 0, cast_dtype);
  }

  // Add the send node.
  const string send_op = (host_memory) ? "_HostSend" : "_Send";
  NodeDefBuilder send_builder(opts.new_name(src->name()), send_op,
                              NodeDebugInfo(*src));
  SetSendRecvAttrs(opts, edge, tensor_name_attr, &send_builder);
  send_builder.Device(src->assigned_device_name()).Input(send_from);
  if (opts.scheduling_for_recvs) {
    send_builder.Attr("_start_time", start_time);
  }
  NodeDef* send = gdef->add_node();
  *status = send_builder.Finalize(send, /*consume=*/true);
  return send;
}

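// Adds a _Recv (or _HostRecv) node for 'edge' to 'gdef', placed on the
// destination node's device. Returns the node that consumers of the edge
// should read from, which may be a trailing Cast or Identity rather than the
// recv itself; '*real_recv' is always set to the actual recv node.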
NodeDef* AddRecv(const PartitionOptions& opts, const GraphInfo& g_info,
                 GraphDef* gdef, const Edge* edge, NodeDef** real_recv,
                 const string& tensor_name_attr, Status* status) {
  const DataType dtype = EdgeType(edge);
  const Node* src = edge->src();
  const Node* dst = edge->dst();
  const int dst_port = edge->dst_input();
  DataType cast_dtype = dtype;

  // NOTE(yuanbyu): Only cast for cross-device send/recv.
  if (opts.should_cast && !NeedSameDeviceSendRecv(edge, g_info)) {
    cast_dtype = opts.should_cast(edge);
  }

  // host_memory = true iff we need to use HostRecv/HostCast.
  // Also log the introduction of the send/recv pair, for performance
  // debugging.
  bool host_memory = false;
  if (!edge->IsControlEdge()) {
    auto dst_it = g_info.input_types.find({dst->id(), dst_port});
    DCHECK(dst_it != g_info.input_types.end());
    host_memory = (dst_it->second == HOST_MEMORY);
    bool src_host_memory = false;
    if (VLOG_IS_ON(1)) {
      const int src_port = edge->src_output();
      auto src_it = g_info.output_types.find({src->id(), src_port});
      DCHECK(src_it != g_info.output_types.end());
      src_host_memory = (src_it->second == HOST_MEMORY);
    }
    VLOG(1) << "Receiving data"
            << " from " << src->name() << " (" << src->type_string() << ")"
            << " on " << src->assigned_device_name() << " in "
            << (src_host_memory ? "host memory" : "device memory") << " for "
            << dst->name() << " (" << dst->type_string() << ")"
            << " on " << dst->assigned_device_name() << " in "
            << (host_memory ? "host memory" : "device memory");
  } else {
    // Log control-edge transfers too, but don't mention memory space since
    // it's irrelevant.
    VLOG(1) << "Receiving control"
            << " from " << src->name() << " (" << src->type_string() << ")"
            << " on " << src->assigned_device_name() << " for " << dst->name()
            << " (" << dst->type_string() << ")"
            << " on " << dst->assigned_device_name();
  }

  // Add the recv node.
  const string recv_op = (host_memory) ? "_HostRecv" : "_Recv";
  NodeDefBuilder recv_builder(opts.new_name(src->name()), recv_op,
                              NodeDebugInfo(*src));
  SetSendRecvAttrs(opts, edge, tensor_name_attr, &recv_builder);
  recv_builder.Device(dst->assigned_device_name())
      .Attr("tensor_type", cast_dtype);
  NodeDef* recv = gdef->add_node();
  *status = recv_builder.Finalize(recv, /*consume=*/true);
  if (!status->ok()) return nullptr;
  *real_recv = recv;

  // Add the cast node (from cast_dtype to dtype) or an Identity node.
  if (dtype != cast_dtype) {
    const string cast_op = (host_memory) ? "_HostCast" : "Cast";
    NodeDefBuilder cast_builder(opts.new_name(src->name()), cast_op,
                                NodeDebugInfo(*src));
    cast_builder.Attr("DstT", dtype);
    cast_builder.Device(dst->assigned_device_name())
        .Input(recv->name(), 0, cast_dtype);
    NodeDef* cast = gdef->add_node();
    *status = cast_builder.Finalize(cast, /*consume=*/true);
    if (!status->ok()) return nullptr;
    return cast;
  } else if (edge->IsControlEdge()) {
    // An Identity is only needed for control edges.
    NodeDefBuilder id_builder(opts.new_name(src->name()), "Identity",
                              NodeDebugInfo(*src));
    id_builder.Device(dst->assigned_device_name())
        .Input(recv->name(), 0, cast_dtype);
    NodeDef* id = gdef->add_node();
    *status = id_builder.Finalize(id, /*consume=*/true);
    if (!status->ok()) return nullptr;
    return id;
  } else {
    return recv;
  }
}

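// Adds a dummy Const node to 'gdef'. Since a control edge carries no data, a
// 0-element float tensor is sent in its place; the receiving side turns it
// back into a control dependency via an Identity node (see AddRecv above).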
NodeDef* AddDummyConst(const PartitionOptions& opts, GraphDef* gdef,
                       const Edge* edge, Status* status) {
  const Node* src = edge->src();
  Tensor tensor(DT_FLOAT, TensorShape({0}));
  NodeDef* result = gdef->add_node();
  *status = NodeDefBuilder(opts.new_name(src->name()), "Const")
                .Device(src->assigned_device_name())
                .Attr("dtype", DT_FLOAT)
                .Attr("value", tensor)
                .Finalize(result, /*consume=*/true);
  return result;
}

// A dummy node for scheduling.
NodeDef* AddControlTrigger(const PartitionOptions& opts, GraphDef* gdef,
                           const string& assigned_device_name, int64_t epoch,
                           int64_t starttime, Status* status) {
  NodeDef* result = gdef->add_node();
  *status = NodeDefBuilder(opts.new_name(strings::StrCat("synch_", epoch)),
                           "ControlTrigger")
                .Device(assigned_device_name)
                .Attr("_start_time", starttime)
                .Finalize(result, /*consume=*/true);
  return result;
}

// Optimize colocation for control flow nodes. For cond, we want each
// switch node to colocate with its data input. This is particularly
// needed for conditional reading of a remote variable. It may also
// reduce the number of devices involved in a loop.
// TODO(yuanbyu): In this case, we don't respect the requested device in
// the GraphDef for these nodes. Ideally, the placer would enforce the
// colocation to render this unnecessary.
void OptimizeControlFlowColocation(Graph* graph) {
  auto visit = [](Node* node) {
    if (IsSwitch(node)) {
      for (const Edge* in_edge : node->in_edges()) {
        if (in_edge->dst_input() == 0) {
          // Colocate with the data input.
          node->set_assigned_device_name(
              in_edge->src()->assigned_device_name());
          return;
        }
      }
    } else if (IsExit(node)) {
      for (const Edge* in_edge : node->in_edges()) {
        if (!in_edge->IsControlEdge()) {
          // Colocate with upstream node.
          node->set_assigned_device_name(
              in_edge->src()->assigned_device_name());
          return;
        }
      }
    } else {
      if ((IsEnter(node) && !IsRefType(node->input_type(0))) ||
          IsNextIteration(node)) {
        const Edge* data_edge = nullptr;
        for (const Edge* out_edge : node->out_edges()) {
          if (!out_edge->IsControlEdge()) {
            data_edge = out_edge;
            break;
          }
        }
        // Colocate with the first downstream data node.
        if (data_edge) {
          node->set_assigned_device_name(
              data_edge->dst()->assigned_device_name());
        }
      }
    }
  };
  DFS(*graph, visit, {});
}

string ControlLoopName(const string& name) {
  return strings::StrCat("_cloop", name);
}

bool IsControlLoop(const Node* node) {
  const string& name = node->name();
  return absl::StartsWith(name, "_cloop");
}

// An enter node for control flow.
Node* AddControlEnter(Graph* g, const string& node_name,
                      const string& device_name, const string& frame_name,
                      const int parallel_iterations, Status* status) {
  NodeBuilder node_builder(node_name, "Enter", g->op_registry());
  node_builder.Input({"dummy", 0, DT_FLOAT});
  node_builder.Attr("frame_name", frame_name);
  node_builder.Attr("parallel_iterations", parallel_iterations);
  Node* res_node;
  *status = node_builder.Finalize(g, &res_node, /*consume=*/true);
  if (!status->ok()) return nullptr;
  res_node->set_assigned_device_name(device_name);
  return res_node;
}

// A merge node for control flow.
Node* AddControlMerge(const string& in_name1, const string& in_name2, Graph* g,
                      const string& node_name, const string& device_name,
                      Status* status) {
  NodeBuilder node_builder(node_name, "Merge", g->op_registry());
  node_builder.Input({{in_name1, 0, DT_FLOAT}, {in_name2, 0, DT_FLOAT}});
  Node* res_node;
  *status = node_builder.Finalize(g, &res_node, /*consume=*/true);
  if (!status->ok()) return nullptr;
  res_node->set_assigned_device_name(device_name);
  return res_node;
}

// A switch node for control flow.
Node* AddControlSwitch(NodeBuilder::NodeOut input1, NodeBuilder::NodeOut input2,
                       const string& device_name,
                       const GraphDefBuilder::Options& bopts) {
  Node* res_node =
      ops::BinaryOp("Switch", std::move(input1), std::move(input2), bopts);
  if (bopts.HaveError()) return nullptr;
  res_node->set_assigned_device_name(device_name);
  return res_node;
}

// A next_iteration node for control flow.
Node* AddControlNext(NodeBuilder::NodeOut input, const string& device_name,
                     const GraphDefBuilder::Options& bopts) {
  Node* res_node = ops::UnaryOp("NextIteration", std::move(input), bopts);
  if (bopts.HaveError()) return nullptr;
  res_node->set_assigned_device_name(device_name);
  return res_node;
}

Node* EmptyConst(const GraphDefBuilder::Options& options) {
  if (options.HaveError()) return nullptr;
  NodeBuilder node_builder(options.GetNameForOp("Const"), "Const",
                           options.op_registry());
  const DataType dt = DataTypeToEnum<float>::v();
  TensorProto proto;
  proto.set_dtype(dt);
  TensorShape empty_shape({0});
  empty_shape.AsProto(proto.mutable_tensor_shape());
  node_builder.Attr("dtype", dt).Attr("value", proto);
  return options.FinalizeBuilder(&node_builder);
}

// A dummy const node for control flow.
Node* AddControlConst(const string& device_name,
                      const GraphDefBuilder::Options& bopts) {
  Node* res_node = EmptyConst(bopts);
  if (bopts.HaveError()) return nullptr;
  res_node->set_assigned_device_name(device_name);
  return res_node;
}

// A synthetic loop, made up of dummy nodes. It performs control-flow actions
// on behalf of a leader on a different device.
struct ControlLoop {
  Node* enter = nullptr;
  Node* merge = nullptr;
  Node* switch_node = nullptr;
};

// Add the control flow info of a new node added during partitioning.
// The new node has the same control flow info as src.
void AddControlFlowInfo(const Node* node, const Node* src,
                        std::vector<ControlFlowInfo>* cf_info) {
  int id = node->id();
  if (static_cast<size_t>(id) >= cf_info->size()) {
    cf_info->resize(id + 1);
  }
  const ControlFlowInfo& src_info = (*cf_info)[src->id()];
  ControlFlowInfo* info = &(*cf_info)[id];
  info->frame = src_info.frame;
  info->parent_frame = src_info.parent_frame;
  info->frame_name = src_info.frame_name;
}

// Constructs a control loop. Returns a struct containing the newly created
// enter, merge, and switch nodes. The enter and merge nodes are used in the
// recursive construction of control loops for nested frames (loops). The
// switch node will be connected to the LoopCond node. The merge node will
// be connected to all the recvs of the same frame by control edges when
// the actual partitioning happens.
Status AddControlLoop(const PartitionOptions& opts, Graph* g, const Node* src,
                      const Edge* edge, Node* loop_cond,
                      std::vector<ControlFlowInfo>* cf_info,
                      ControlLoop* loop) {
  Status status;
  GraphDefBuilder::Options bopts(g, &status);
  const ControlFlowInfo& src_info = (*cf_info)[src->id()];
  const string& device_name = edge->dst()->assigned_device_name();
  const string& frame_name = src_info.frame_name;
  int parallel_iterations;
  status = GetNodeAttr(src_info.frame->attrs(), "parallel_iterations",
                       &parallel_iterations);
  if (!status.ok()) return status;

  // The names of the nodes to be added.
  const string& enter_name =
      ControlLoopName(opts.new_name(edge->dst()->name()));
  const string& merge_name =
      ControlLoopName(opts.new_name(edge->dst()->name()));
  const string& switch_name =
      ControlLoopName(opts.new_name(edge->dst()->name()));
  const string& next_name = ControlLoopName(opts.new_name(edge->dst()->name()));

  // Add the nodes to the graph g.
  Node* enter = AddControlEnter(g, enter_name, device_name, frame_name,
                                parallel_iterations, &status);
  if (!status.ok()) return status;
  Node* merge = AddControlMerge(enter_name, next_name, g, merge_name,
                                device_name, &status);
  if (!status.ok()) return status;
  Node* switch_node = AddControlSwitch(merge, loop_cond, device_name,
                                       bopts.WithName(switch_name));
  if (!status.ok()) return status;
  Node* next =
      AddControlNext({switch_node, 1}, device_name, bopts.WithName(next_name));
  if (!status.ok()) return status;

  // Add control flow info for these new nodes:
  AddControlFlowInfo(enter, src, cf_info);
  AddControlFlowInfo(merge, src, cf_info);
  AddControlFlowInfo(switch_node, src, cf_info);
  AddControlFlowInfo(next, src, cf_info);

  // Add input edges for the newly created merge node:
  g->AddEdge(enter, 0, merge, 0);
  g->AddEdge(next, 0, merge, 1);

  loop->enter = enter;
  loop->merge = merge;
  loop->switch_node = switch_node;
  return OkStatus();
}

// Build memory and device type info for every node in the graph.
// TODO(yuanbyu): It might be simpler if we convert MemoryType to
// DeviceType for the inputs/outputs of each node.
Status BuildMemoryDeviceInfo(const Graph& g, GraphInfo* info) {
  MemoryTypeVector input_memory_types;
  MemoryTypeVector output_memory_types;

  info->device_types.resize(g.num_node_ids(), DEVICE_CPU);
  for (const Node* node : g.op_nodes()) {
    DeviceNameUtils::ParsedName parsed;
    if (!DeviceNameUtils::ParseFullName(node->assigned_device_name(),
                                        &parsed)) {
      return errors::Internal("Malformed assigned device '",
                              node->assigned_device_name(), "'");
    }

    TF_RETURN_IF_ERROR(MemoryTypesForNode(
        g.op_registry(), DeviceType(parsed.type), node->def(),
        &input_memory_types, &output_memory_types));

    int node_id = node->id();
    info->device_types[node_id] = DeviceType(parsed.type);
    for (int i = 0; i < input_memory_types.size(); ++i) {
      info->input_types[{node_id, i}] = input_memory_types[i];
    }
    for (int i = 0; i < output_memory_types.size(); ++i) {
      info->output_types[{node_id, i}] = output_memory_types[i];
    }
  }
  return OkStatus();
}

const Node* InputFrame(const Node* node,
                       const std::vector<ControlFlowInfo>& cf_info) {
  // An input is in the same frame as the node except for Enter nodes.
  // The input of Enter is in the parent frame of the Enter node.
  if (!node->IsEnter()) {
    return node;
  }
  return cf_info[node->id()].parent_frame;
}

const Node* OutputFrame(const Node* node,
                        const std::vector<ControlFlowInfo>& cf_info) {
  // An output is in the same frame as the node except for Exit nodes.
  // The output of Exit is in the parent frame of the Exit node.
  if (!node->IsExit()) {
    return node;
  }
  return cf_info[node->id()].parent_frame;
}

// Each participating device needs to decide a) if there is a next iteration,
// and b) if the loop terminates. We take the approach of encoding this
// control-flow logic in the dataflow graph. There are at least two possible
// encodings. In a completely decentralized encoding, the participants
// communicate peer to peer. The other encoding uses a frame leader (the
// participant who owns the pivot termination predicate) to broadcast the
// termination condition to all the participants. For now we take the latter
// because it is simpler.
//
// TODO(yuanbyu): The correctness of this construction is rather subtle. I got
// it wrong many times so it would be nice to write a proof to be sure.
Status AddControlFlow(const PartitionOptions& opts, Graph* g,
                      GraphInfo* g_info) {
  Status status;
  GraphDefBuilder::Options bopts(g, &status);
  std::vector<ControlFlowInfo>& cf_info = g_info->cf_info;

  // Build the control flow info for every node.
  status = BuildControlFlowInfo(g, &cf_info);
  if (!status.ok()) return status;

  OptimizeControlFlowColocation(g);

  // The map from frames to their LoopCond nodes.
  std::unordered_map<string, Node*> frame_cond_map;
  int num_node_ids = g->num_node_ids();
  for (int i = 0; i < num_node_ids; ++i) {
    Node* node = g->FindNodeId(i);
    if (node == nullptr) continue;

    if (IsLoopCond(node)) {
      const string& frame_name = cf_info[node->id()].frame_name;
      DCHECK(!frame_name.empty());
      frame_cond_map[frame_name] = node;
    }
  }

  // Add all control loops for cross-device frames.
  // A control loop is added only when there is a cross-device edge in a
  // non-root frame. Nothing is added if there are no loops. We also don't
  // add anything for a frame that is completely local to a device. For
  // nested loops, we stack the control loops together by connecting
  // the merge of the outer loop to the enter of the inner loop.
  //
  // A map from <frame_name, device_name> to ControlLoop.
  std::unordered_map<string, ControlLoop> control_loops;
  int num_edge_ids = g->num_edge_ids();
  for (int i = 0; i < num_edge_ids; ++i) {
    const Edge* edge = g->FindEdgeId(i);
    if (edge == nullptr) continue;

    const Node* src = edge->src();
    const Node* dst = edge->dst();
    // Skip Sink/Source nodes.
    if (!src->IsOp() || !dst->IsOp()) continue;

    const string& src_device = src->assigned_device_name();
    const string& dst_device = dst->assigned_device_name();
    // Skip local edges.
    if (src_device == dst_device) continue;

    const Node* src_frame = OutputFrame(src, cf_info);
    const Node* dst_frame = InputFrame(dst, cf_info);
    const string& src_frame_name = cf_info[src_frame->id()].frame_name;
    const string& dst_frame_name = cf_info[dst_frame->id()].frame_name;
    // Skip if src and dst are not in the same frame.
    if (src_frame_name.empty() || src_frame_name != dst_frame_name) {
      continue;
    }

    // Add the control loop. Start by adding the control loop for the
    // current frame if needed, and recursively adding the control loop
    // for its outer frame when nested.
    ControlLoop child_loop;
    while (true) {
      const string& curr_frame_name = cf_info[src_frame->id()].frame_name;
      if (curr_frame_name.empty()) {
        // We have reached the root frame.
        if (child_loop.merge != nullptr) {
          const string& node_name = opts.new_name(edge->dst()->name());
          const string& device_name = edge->dst()->assigned_device_name();
          Node* const_node =
              AddControlConst(device_name, bopts.WithName(node_name));
          if (!status.ok()) return status;
          AddControlFlowInfo(const_node, src_frame, &cf_info);
          g->AddEdge(const_node, 0, child_loop.enter, 0);
        }
        break;
      }

      const string& cl_key =
          strings::StrCat(curr_frame_name, "$$", dst_device);
      auto it = control_loops.find(cl_key);
      if (it != control_loops.end()) {
        if (child_loop.enter != nullptr) {
          g->AddEdge(it->second.merge, 0, child_loop.enter, 0);
        }
        break;
      }

      // Get the frame's LoopCond.
      auto cond_it = frame_cond_map.find(curr_frame_name);
      if (cond_it == frame_cond_map.end()) {
        return errors::InvalidArgument(
            "A cross-device loop must have a pivot predicate: ",
            curr_frame_name);
      }
      Node* loop_cond = cond_it->second;

      // Add the control loop.
      ControlLoop curr_loop;
      status = AddControlLoop(opts, g, src_frame, edge, loop_cond, &cf_info,
                              &curr_loop);
      if (!status.ok()) return status;
      control_loops[cl_key] = curr_loop;

      if (child_loop.enter != nullptr) {
        // Connect the merge of the outer loop to the enter of the inner.
        g->AddEdge(curr_loop.merge, 0, child_loop.enter, 0);
      }
      src_frame = cf_info[src_frame->id()].parent_frame;
      child_loop = curr_loop;
    }
  }

  // For a cross-device edge, on the dst device, add a control edge
  // from the merge node of the control loop to dst. If a send/recv is
  // introduced for this edge in future partitioning, we delete this
  // control edge and add a new control edge from the merge to the recv.
  num_edge_ids = g->num_edge_ids();
  for (int i = 0; i < num_edge_ids; ++i) {
    const Edge* edge = g->FindEdgeId(i);
    if (edge == nullptr) continue;

    const Node* src = edge->src();
    Node* dst = edge->dst();
    // Skip Sink/Source nodes.
    if (!src->IsOp() || !dst->IsOp()) continue;

    const string& src_device = src->assigned_device_name();
    const string& dst_device = dst->assigned_device_name();
    if (src_device != dst_device) {
      const Node* src_frame = OutputFrame(src, cf_info);
      const Node* dst_frame = InputFrame(dst, cf_info);
      const string& src_frame_name = cf_info[src_frame->id()].frame_name;
      const string& dst_frame_name = cf_info[dst_frame->id()].frame_name;
      if (!src_frame_name.empty() && src_frame_name == dst_frame_name) {
        const string& cl_key =
            strings::StrCat(dst_frame_name, "$$", dst_device);
        ControlLoop loop = control_loops[cl_key];
        DCHECK(loop.enter != nullptr);
        // Note that we'll create multiple duplicate edges if dst has multiple
        // cross-device inputs. This is expected by the logic in Partition(),
        // so it can add control edges to the recv nodes once they're created.
        g->AddControlEdge(loop.merge, dst, /*allow_duplicates=*/true);
      }
    }
  }
  return OkStatus();
}

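// Helpers for TopologicalSortNodesWithTimePriority below: a (node, start_time)
// record, and a comparator that turns std::priority_queue into a min-heap on
// start_time.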
struct PriorityTopoSortNode {
  PriorityTopoSortNode(const NodeDef* n, int64_t st)
      : node(n), start_time(st) {}

  const NodeDef* node;
  int64_t start_time;
};

struct PriorityTopoSortNodeGreater {
  bool operator()(const PriorityTopoSortNode& left,
                  const PriorityTopoSortNode& right) {
    return left.start_time > right.start_time;
  }
};

}  // namespace

// Returns in <nodes> the nodes that should participate in epoch-based recv
// scheduling, along with their times; <nodes> is ordered by increasing
// start_time. Returns in <node_to_start_time_out> the timing for all nodes,
// even those not in <nodes>.
//
// Compared with sorting on the node's start time only, this also processes
// the nodes in dependency order, and updates start times to ensure a node's
// start_time > the start time of all its dependencies.
//
// Note that graph_partition_test.cc accesses this function for testing, even
// though it's not declared in the header.
Status TopologicalSortNodesWithTimePriority(
    const GraphDef* gdef,
    std::vector<std::pair<const NodeDef*, int64_t>>* nodes,
    std::unordered_map<const NodeDef*, int64_t>* node_to_start_time_out) {
  // Queue of nodes to process; lowest start time is returned first.
  std::priority_queue<PriorityTopoSortNode, std::vector<PriorityTopoSortNode>,
                      PriorityTopoSortNodeGreater>
      q;
  std::unordered_map<const NodeDef*, int64_t> node_to_start_time;
  auto enqueue = [&q, &node_to_start_time](const NodeDef* node) {
    const int64_t start_time = node_to_start_time[node];
    q.emplace(node, start_time);
  };

  // Build initial structures, initial contents of queue.
  std::unordered_map<string, std::vector<const NodeDef*>> node_to_output_nodes;
  std::unordered_map<const NodeDef*, int> inputs_needed;
  for (int n = 0; n < gdef->node_size(); ++n) {
    const NodeDef* ndef = &gdef->node(n);
    for (int i = 0; i < ndef->input_size(); ++i) {
      node_to_output_nodes[string(ParseTensorName(ndef->input(i)).first)]
          .push_back(ndef);
    }
    int64_t start_time;
    TF_RETURN_IF_ERROR(GetNodeAttr(*ndef, "_start_time", &start_time));
    node_to_start_time[ndef] = start_time;
    inputs_needed[ndef] = ndef->input_size();
    if (ndef->input_size() == 0) {
      enqueue(ndef);
    }
  }

  // Determine which merge nodes are parts of loops; these
  // need to happen in the traversal after all non-NextIteration inputs
  // are run.
  for (int n = 0; n < gdef->node_size(); ++n) {
    const NodeDef* ndef = &gdef->node(n);
    if (IsNextIteration(*ndef)) {
      for (const NodeDef* n : node_to_output_nodes[ndef->name()]) {
        if (IsMerge(*n)) {
          // n is a merge that is part of a loop structure.
          // It doesn't need to wait for this NextIteration loop
          // when doing the traversal.
          --inputs_needed[n];
        }
      }
    }
  }

  // Traverse.
  std::vector<std::pair<const NodeDef*, int64_t>> start_times;
  start_times.reserve(gdef->node_size());
  while (!q.empty()) {
    PriorityTopoSortNode cur = q.top();
    q.pop();

    start_times.emplace_back(cur.node, cur.start_time);

    for (const NodeDef* n : node_to_output_nodes[cur.node->name()]) {
      auto& output_start_time = node_to_start_time[n];
      if (output_start_time <= cur.start_time) {
        output_start_time = cur.start_time + 1;
      }
      if (--inputs_needed[n] == 0) {
        enqueue(n);
      }
    }
  }

  // Done.
  nodes->swap(start_times);
  node_to_start_time_out->swap(node_to_start_time);
  return OkStatus();
}

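// Adds control edges that throttle recv scheduling. The makespan of each
// partition is divided into num_epochs epochs. A ControlTrigger node is added
// per epoch, gated on the "last" node of the preceding epoch, and every _Recv
// is made to wait for the trigger that runs 'prefetch' epochs before its own
// start time, so recvs cannot run arbitrarily far ahead of the computation.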
Status AddControlEdges(const PartitionOptions& opts,
                       std::unordered_map<string, GraphDef>* partitions) {
  Status status;
  // TODO(yuanbyu): Very naive for now. To be improved.
  const int num_epochs = 100;
  const int prefetch = 6;

  for (auto& part : *partitions) {
    GraphDef* gdef = &part.second;
    std::vector<std::pair<const NodeDef*, int64_t>> start_times;
    std::unordered_map<const NodeDef*, int64_t> node_to_start_time;
    status = TopologicalSortNodesWithTimePriority(gdef, &start_times,
                                                  &node_to_start_time);
    if (!status.ok()) {
      return status;
    }

    // Add a dummy node for every epoch, and add a control edge from the
    // "last" node in the preceding epoch to the dummy node.
    string device_name = gdef->node(0).device();
    int64_t makespan = start_times.back().second;
    int64_t resolution = (makespan / num_epochs) + 1;

    int i = 0;
    int j = 0;
    std::vector<NodeDef*> dummys;
    while (i < num_epochs && static_cast<size_t>(j) < start_times.size()) {
      if (i * resolution > start_times[j].second) {
        j++;
      } else {
        NodeDef* dummy = AddControlTrigger(opts, gdef, device_name, i,
                                           i * resolution, &status);
        if (!status.ok()) {
          return status;
        }
        dummys.push_back(dummy);
        if (j > 0) {
          string src_name = start_times[j - 1].first->name();
          AddInput(dummy, src_name, Graph::kControlSlot);
        }
        i++;
      }
    }

    // Finally, add the control edges to recvs.
    for (int n = 0; n < gdef->node_size(); ++n) {
      NodeDef* ndef = gdef->mutable_node(n);
      if (ndef->op() == "_Recv") {
        const int64_t start_time = node_to_start_time[ndef];
        const int recv_epoch = start_time / resolution;
        if (recv_epoch >= prefetch) {
          NodeDef* dummy = dummys[recv_epoch - prefetch];
          AddInput(ndef, dummy->name(), Graph::kControlSlot);
        }
      }
    }
  }
  return OkStatus();
}

// If 'ndef' is a Send or Recv, fills its attr send_device_incarnation
// if possible.
void SetIncarnation(const PartitionOptions& opts, NodeDef* ndef) {
  StringPiece op(ndef->op());
  if (op != "_Send" && op != "_Recv") {
    // Not related to send/recv.
    return;
  }
  const string& send_device = GetNodeAttrString(*ndef, "send_device");
  if (send_device.empty()) {
    // No known send_device. The runtime will detect it later.
    return;
  }
  int64_t incarnation = PartitionOptions::kIllegalIncarnation;
  if (!TryGetNodeAttr(*ndef, "send_device_incarnation", &incarnation) ||
      (incarnation == PartitionOptions::kIllegalIncarnation)) {
    incarnation = opts.get_incarnation(send_device);
    SetAttrValue(incarnation,
                 &((*ndef->mutable_attr())["send_device_incarnation"]));
  }
}

// Sets attribute send_device_incarnation of all Send/Recv nodes in
// 'gdef', if possible.
void SetIncarnation(const PartitionOptions& opts, GraphDef* gdef) {
  for (NodeDef& ndef : *gdef->mutable_node()) {
    SetIncarnation(opts, &ndef);
  }
  for (FunctionDef& fdef : *gdef->mutable_library()->mutable_function()) {
    for (NodeDef& ndef : *fdef.mutable_node_def()) {
      SetIncarnation(opts, &ndef);
    }
  }
}

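// Partitions 'g' into one GraphDef per location returned by opts.node_to_loc,
// inserting matching Send/Recv pairs across partition boundaries. A minimal
// usage sketch (assuming a placed Graph* g; the incarnation value is purely
// illustrative):
//
//   PartitionOptions popts;
//   popts.node_to_loc = [](const Node* n) {
//     return n->assigned_device_name();
//   };
//   popts.new_name = [g](const string& prefix) { return g->NewName(prefix); };
//   popts.get_incarnation = [](const string&) { return 1; };  // illustrative
//   std::unordered_map<string, GraphDef> partitions;
//   Status s = Partition(popts, g, &partitions);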
Status Partition(const PartitionOptions& opts, Graph* g,
                 std::unordered_map<string, GraphDef>* partitions) {
  Status status;
  partitions->clear();

  GraphInfo g_info;
  if (!opts.control_flow_added) {
    // Add the "code" for distributed execution of control flow. Code is
    // added only for the frames that are placed on multiple devices. The
    // new graph is an equivalent transformation of the original graph and
    // has the property that it can be subsequently partitioned arbitrarily
    // (down to the level of individual device) for distributed execution.
    status = AddControlFlow(opts, g, &g_info);
    if (!status.ok()) return status;
  }

  // At this point, all the graph mutations have been done. Build memory
  // and device type info for every node and edge in the graph.
  status = BuildMemoryDeviceInfo(*g, &g_info);
  if (!status.ok()) return status;

  string dstp;
  std::vector<const Edge*> inputs;
  DupRecvTable dup_recv(3);
  // For a node dst, 'ref_recvs' remembers the recvs introduced by a ref
  // edge to dst. 'ref_control_inputs' remembers the inputs by a non-ref
  // edge to dst. We will add a control edge for every pair in
  // (ref_recvs x ref_control_inputs).
  std::vector<NodeDef*> ref_recvs;
  std::vector<string> ref_control_inputs;

  int32_t num_data = 0;
  int32_t num_control = 0;
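  // Main loop: copy each op node into the partition chosen by
  // opts.node_to_loc, then rewrite its incoming edges either as direct
  // inputs (same partition, compatible memory types) or as Send/Recv pairs
  // split across the source and destination partitions.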
  for (const Node* dst : g->op_nodes()) {
    dstp = opts.node_to_loc(dst);
    GraphDef* dst_graph = &(*partitions)[dstp];
    NodeDef* dst_def = dst_graph->add_node();
    *dst_def = dst->def();
    MergeDebugInfo(NodeDebugInfo(dst->def()), dst_def);
    dst_def->set_device(dst->assigned_device_name());
    dst_def->clear_input();  // Inputs are filled below
    if (opts.need_to_record_start_times) {
      int64_t start_time;
      status = GetNodeAttr(*dst_def, "_start_time", &start_time);
      if (errors::IsNotFound(status)) {
        start_time = opts.start_times[dst->id()].value();
        AddNodeAttr("_start_time", start_time, dst_def);
      } else if (!status.ok()) {
        return status;
      }
    }

    // Arrange the incoming edges to dst so that input[i] holds the
    // input flowing into slot numbered i. Trailing entries in input[]
    // hold control edges.
    inputs.clear();
    inputs.resize(dst->num_inputs(), nullptr);
    ref_recvs.clear();
    ref_control_inputs.clear();
    const Edge* control_flow_edge = nullptr;
    int32_t num_control_flow_edges = 0;
    int32_t num_input_edges = 0;
    for (const Edge* edge : dst->in_edges()) {
      if (edge->IsControlEdge()) {
        if (IsMerge(edge->src()) && IsControlLoop(edge->src())) {
          // This is one of the control edges added for control flow. There
          // can be multiple such edges as the dest node may have multiple
          // remote inputs. We keep track of the number of such edges.
          control_flow_edge = edge;
          ++num_control_flow_edges;
        } else {
          inputs.push_back(edge);
        }
      } else {
        DCHECK(inputs[edge->dst_input()] == nullptr);
        inputs[edge->dst_input()] = edge;
        ++num_input_edges;
      }
    }

    if (num_input_edges != dst->num_inputs()) {
      return errors::InvalidArgument("Incomplete graph, missing ",
                                     (dst->num_inputs() - num_input_edges),
                                     " inputs for ", dst->name());
    }

    // Process in order so that all data edges are added as inputs to
    // dst in Edge::dst_input() order.
    for (const Edge* edge : inputs) {
      const Node* src = edge->src();
      if (!src->IsOp()) continue;  // Skip Sink/Source nodes.

      GraphDef* src_graph = &(*partitions)[opts.node_to_loc(src)];
      if (src_graph == dst_graph && !NeedSameDeviceSendRecv(edge, g_info)) {
        // Same partition and compatible memory types:
        AddInput(dst_def, src->name(), edge->src_output());
        if (edge->IsControlEdge() ||
            !IsRefType(src->output_type(edge->src_output()))) {
          ref_control_inputs.push_back(src->name());
        }
        continue;
      }

      int64_t send_start_time = 0;
      int64_t recv_start_time = 0;
      if (opts.scheduling_for_recvs) {
        status = GetNodeAttr(src->attrs(), "_start_time", &send_start_time);
        if (errors::IsNotFound(status) && opts.need_to_record_start_times) {
          send_start_time = opts.start_times[src->id()].value();
        } else if (!status.ok()) {
          return status;
        }

        status = GetNodeAttr(dst->attrs(), "_start_time", &recv_start_time);
        if (errors::IsNotFound(status) && opts.need_to_record_start_times) {
          recv_start_time = opts.start_times[dst->id()].value();
        } else if (!status.ok()) {
          return status;
        }
      }

      // Check whether there is already a send/recv pair transferring
      // the same tensor/control from the src to dst partition.
      const bool on_host = IsDstInputOnHost(edge, g_info);
      DupRecvKey key{src->id(), edge->src_output(), dst_graph, on_host};
      auto iter = dup_recv.find(key);
      if (iter != dup_recv.end()) {
        // We found one. Reuse the data/control transferred already.
        const string& recv_node_name = iter->second.recv->name();
        if (edge->IsControlEdge()) {
          AddInput(dst_def, recv_node_name, Graph::kControlSlot);
        } else {
          AddInput(dst_def, recv_node_name, 0);
        }
        ref_control_inputs.push_back(recv_node_name);

        // We want the start_time for the recv to be the smallest of the
        // start times of its consumers. So we update this whenever we use a
        // recv, and write it out to the attribute at the end of the
        // subroutine.
        if (iter->second.start_time > recv_start_time) {
          iter->second.start_time = recv_start_time;
        }
        continue;
      }

      NodeDefBuilder::NodeOut send_from;
      if (edge->IsControlEdge()) {
        // Insert a dummy const node that will generate a tiny
        // data element to be sent from send to recv.
        VLOG(1) << "Send/Recv control: " << src->assigned_device_name() << "["
                << src->name() << "] -> " << dst->assigned_device_name() << "["
                << dst->name() << "]";
        NodeDef* dummy = AddDummyConst(opts, src_graph, edge, &status);
        if (!status.ok()) return status;
        // Set the start time for this dummy node.
        if (opts.scheduling_for_recvs) {
          AddNodeAttr("_start_time", send_start_time, dummy);
        }
        AddInput(dummy, src->name(), Graph::kControlSlot);
        send_from.Reset(dummy->name(), 0, DT_FLOAT);
      } else {
        send_from.Reset(src->name(), edge->src_output(), EdgeType(edge));
      }

      string tensor_name_attr;
      if (opts.get_tensor_name_attr) {
        tensor_name_attr = opts.get_tensor_name_attr(edge);
      } else {
        tensor_name_attr =
            strings::StrCat("edge_", edge->id(), "_", edge->src()->name());
      }

      // Need to split edge by placing matching send/recv nodes on
      // the src/dst sides of the edge.
      NodeDef* send = AddSend(opts, g_info, src_graph, edge, send_from,
                              send_start_time, tensor_name_attr, &status);
      if (!status.ok()) return status;

      NodeDef* real_recv = nullptr;
      NodeDef* recv = AddRecv(opts, g_info, dst_graph, edge, &real_recv,
                              tensor_name_attr, &status);
      if (!status.ok()) return status;

      // Fix up the control flow edge.
      // NOTE(yuanbyu): 'real_recv' must be the real recv node.
      if (src_graph == dst_graph) {
        // For same device send/recv, add a control edge from send to recv.
        // This prevents the asynchronous recv kernel from being scheduled
        // before the data is available.
        AddInput(real_recv, send->name(), Graph::kControlSlot);
      } else if (control_flow_edge != nullptr) {
        // Redirect control edge to the real recv since this is not the same
        // device send/recv.
        --num_control_flow_edges;
        AddInput(real_recv, control_flow_edge->src()->name(),
                 Graph::kControlSlot);
      }

      if (!edge->IsControlEdge() &&
          IsRefType(src->output_type(edge->src_output()))) {
        AddNodeAttr("_start_time", recv_start_time, recv);
        if (real_recv != recv) {
          AddNodeAttr("_start_time", recv_start_time, real_recv);
        }
        // If src is of ref type and the edge is not a control edge, dst has
        // read semantics and therefore we must control the recv.
        ref_recvs.push_back(real_recv);
      } else {
        // Memorize the send/recv pair, only if this is not a "ref" edge.
        // NOTE(yuanbyu): Collapsing ref edges requires extreme care so
        // for now we don't do it.
        dup_recv[key] = {recv, real_recv, recv_start_time};
        ref_control_inputs.push_back(recv->name());
      }

      if (edge->IsControlEdge()) {
        ++num_control;
        AddInput(dst_def, recv->name(), Graph::kControlSlot);
      } else {
        ++num_data;
        AddInput(dst_def, recv->name(), 0);
      }
    }

    // Add control edges from 'ref_control_inputs' to 'ref_recvs'.
    // NOTE(yuanbyu): Adding these control edges should not introduce
    // deadlocks. 'dst' has implicit "read" nodes that, when we split
    // across devices, are made explicit; retargeting the dependencies
    // on 'dst' to those nodes would not introduce cycles if there wasn't
    // one before the transformation.
    // NOTE(yuanbyu): This may impact performance because it defers the
    // execution of recvs until all the other inputs become available.
    AddReadControl(ref_recvs, ref_control_inputs);

    // Add back the control edges for control flow that are not used.
    if (control_flow_edge != nullptr) {
      for (int i = 0; i < num_control_flow_edges; ++i) {
        AddInput(dst_def, control_flow_edge->src()->name(),
                 Graph::kControlSlot);
      }
    }
  }

  const FunctionLibraryDefinition* flib_def = opts.flib_def;
  if (flib_def == nullptr) {
    flib_def = &g->flib_def();
  }

  // Set versions, function library and send/recv incarnation.
  for (auto& it : *partitions) {
    GraphDef* gdef = &it.second;
    *gdef->mutable_versions() = g->versions();
    // Prune unreachable functions from `flib_def` before adding them to
    // `gdef`.
    *gdef->mutable_library() = flib_def->ReachableDefinitions(*gdef).ToProto();

    // Traverse the graph to fill every send/recv op's incarnation
    // information.
    SetIncarnation(opts, gdef);
  }

  // Set the start times for recvs at the very end.
  if (opts.scheduling_for_recvs) {
    for (auto& it : dup_recv) {
      AddNodeAttr("_start_time", it.second.start_time, it.second.recv);
      if (it.second.real_recv != it.second.recv) {
        AddNodeAttr("_start_time", it.second.start_time, it.second.real_recv);
      }
    }
  }

  VLOG(1) << "Added send/recv: controls=" << num_control
          << ", data=" << num_data;
  if (VLOG_IS_ON(2)) {
    for (auto& it : *partitions) {
      GraphDef* gdef = &it.second;
      DumpGraphDefToFile(strings::StrCat("partition_", it.first, "_",
                                         reinterpret_cast<uintptr_t>(gdef)),
                         *gdef);
    }
  }
  return OkStatus();
}

}  // namespace tensorflow