1/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2
3Licensed under the Apache License, Version 2.0 (the "License");
4you may not use this file except in compliance with the License.
5You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9Unless required by applicable law or agreed to in writing, software
10distributed under the License is distributed on an "AS IS" BASIS,
11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12See the License for the specific language governing permissions and
13limitations under the License.
14==============================================================================*/
15
#include "tensorflow/core/grappler/mutable_graph_view.h"

#include <algorithm>
#include <utility>
#include <vector>

#include "absl/container/flat_hash_map.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/str_join.h"
#include "absl/strings/string_view.h"
#include "absl/strings/substitute.h"
#include "tensorflow/core/framework/function.h"
#include "tensorflow/core/framework/graph.pb.h"
#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/graph/graph.h"
#include "tensorflow/core/graph/tensor_id.h"
#include "tensorflow/core/grappler/op_types.h"
#include "tensorflow/core/grappler/utils.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/stringpiece.h"
#include "tensorflow/core/lib/gtl/map_util.h"
#include "tensorflow/core/platform/protobuf.h"
#include "tensorflow/core/platform/types.h"
38
39namespace tensorflow {
40namespace grappler {
41
42namespace {
43
44bool IsTensorIdPortValid(const TensorId& tensor_id) {
45 return tensor_id.index() >= Graph::kControlSlot;
46}
47
48bool IsTensorIdRegular(const TensorId& tensor_id) {
49 return tensor_id.index() > Graph::kControlSlot;
50}
51
52bool IsTensorIdControlling(const TensorId& tensor_id) {
53 return tensor_id.index() == Graph::kControlSlot;
54}
55
56bool IsOutputPortControlling(const MutableGraphView::OutputPort& port) {
57 return port.port_id == Graph::kControlSlot;
58}
59
60// Determines if node is an Identity where it's first regular input is a Switch
61// node.
62bool IsIdentityConsumingSwitch(const MutableGraphView& graph,
63 const NodeDef& node) {
64 if ((IsIdentity(node) || IsIdentityNSingleInput(node)) &&
65 node.input_size() > 0) {
66 TensorId tensor_id = ParseTensorName(node.input(0));
67 if (IsTensorIdControlling(tensor_id)) {
68 return false;
69 }
70
71 NodeDef* input_node = graph.GetNode(tensor_id.node());
72 if (input_node == nullptr) {
73 return false;
74 }
75 return IsSwitch(*input_node);
76 }
77 return false;
78}
79
80// Determines if node input can be deduped by regular inputs when used as a
81// control dependency. Specifically, if a node is an Identity that leads to a
82// Switch node, when used as a control dependency, that control dependency
83// should not be deduped even though the same node is used as a regular input.
84bool CanDedupControlWithRegularInput(const MutableGraphView& graph,
85 const NodeDef& control_node) {
86 return !IsIdentityConsumingSwitch(graph, control_node);
87}
88
89// Determines if node input can be deduped by regular inputs when used as a
90// control dependency. Specifically, if a node is an Identity that leads to a
91// Switch node, when used as a control dependency, that control dependency
92// should not be deduped even though the same node is used as a regular input.
93bool CanDedupControlWithRegularInput(const MutableGraphView& graph,
94 absl::string_view control_node_name) {
95 NodeDef* control_node = graph.GetNode(control_node_name);
96 if (control_node == nullptr) {
97 return false;
98 }
99 return CanDedupControlWithRegularInput(graph, *control_node);
100}
101
102bool HasRegularFaninNode(const MutableGraphView& graph, const NodeDef& node,
103 absl::string_view fanin_node_name) {
104 const int num_regular_fanins =
105 graph.NumFanins(node, /*include_controlling_nodes=*/false);
106 for (int i = 0; i < num_regular_fanins; ++i) {
107 if (ParseTensorName(node.input(i)).node() == fanin_node_name) {
108 return true;
109 }
110 }
111 return false;
112}
113
// Maps each output port to the set of input ports that consume it. This is the
// type of MutableGraphView::fanouts(), which the helpers below receive as
// `FanoutsMap*`.
using FanoutsMap =
    absl::flat_hash_map<MutableGraphView::OutputPort,
                        absl::flat_hash_set<MutableGraphView::InputPort>>;
117
118void SwapControlledFanoutInputs(const MutableGraphView& graph,
119 const FanoutsMap::iterator& control_fanouts,
120 absl::string_view to_node_name) {
121 absl::string_view from_node_name(control_fanouts->first.node->name());
122 string control = TensorIdToString({to_node_name, Graph::kControlSlot});
123 for (const auto& control_fanout : control_fanouts->second) {
124 const int start = graph.NumFanins(*control_fanout.node,
125 /*include_controlling_nodes=*/false);
126 for (int i = start; i < control_fanout.node->input_size(); ++i) {
127 TensorId tensor_id = ParseTensorName(control_fanout.node->input(i));
128 if (tensor_id.node() == from_node_name) {
129 control_fanout.node->set_input(i, control);
130 break;
131 }
132 }
133 }
134}
135
136void SwapRegularFanoutInputs(FanoutsMap* fanouts, NodeDef* from_node,
137 absl::string_view to_node_name, int max_port) {
138 MutableGraphView::OutputPort port;
139 port.node = from_node;
140 for (int i = 0; i <= max_port; ++i) {
141 port.port_id = i;
142 auto it = fanouts->find(port);
143 if (it == fanouts->end()) {
144 continue;
145 }
146 string input = TensorIdToString({to_node_name, i});
147 for (const auto& fanout : it->second) {
148 fanout.node->set_input(fanout.port_id, input);
149 }
150 }
151}
152
// Maps a node to the highest output port through which it has a regular
// fanout; used as the upper bound when iterating a node's regular output ports.
// This is the type of MutableGraphView::max_regular_output_port(), which the
// helpers below receive as `MaxOutputPortsMap*`.
using MaxOutputPortsMap = absl::flat_hash_map<const NodeDef*, int>;
154
155void SwapFanoutInputs(const MutableGraphView& graph, FanoutsMap* fanouts,
156 MaxOutputPortsMap* max_output_ports, NodeDef* from_node,
157 NodeDef* to_node) {
158 auto from_control_fanouts = fanouts->find({from_node, Graph::kControlSlot});
159 if (from_control_fanouts != fanouts->end()) {
160 SwapControlledFanoutInputs(graph, from_control_fanouts, to_node->name());
161 }
162 auto to_control_fanouts = fanouts->find({to_node, Graph::kControlSlot});
163 if (to_control_fanouts != fanouts->end()) {
164 SwapControlledFanoutInputs(graph, to_control_fanouts, from_node->name());
165 }
166 auto from_max_port = max_output_ports->find(from_node);
167 if (from_max_port != max_output_ports->end()) {
168 SwapRegularFanoutInputs(fanouts, from_node, to_node->name(),
169 from_max_port->second);
170 }
171 auto to_max_port = max_output_ports->find(to_node);
172 if (to_max_port != max_output_ports->end()) {
173 SwapRegularFanoutInputs(fanouts, to_node, from_node->name(),
174 to_max_port->second);
175 }
176}
177
178void SwapFanoutsMapValues(FanoutsMap* fanouts,
179 const MutableGraphView::OutputPort& from_port,
180 const FanoutsMap::iterator& from_fanouts,
181 const MutableGraphView::OutputPort& to_port,
182 const FanoutsMap::iterator& to_fanouts) {
183 const bool from_exists = from_fanouts != fanouts->end();
184 const bool to_exists = to_fanouts != fanouts->end();
185
186 if (from_exists && to_exists) {
187 std::swap(from_fanouts->second, to_fanouts->second);
188 } else if (from_exists) {
189 fanouts->emplace(to_port, std::move(from_fanouts->second));
190 fanouts->erase(from_port);
191 } else if (to_exists) {
192 fanouts->emplace(from_port, std::move(to_fanouts->second));
193 fanouts->erase(to_port);
194 }
195}
196
197void SwapRegularFanoutsAndMaxPortValues(FanoutsMap* fanouts,
198 MaxOutputPortsMap* max_output_ports,
199 NodeDef* from_node, NodeDef* to_node) {
200 auto from_max_port = max_output_ports->find(from_node);
201 auto to_max_port = max_output_ports->find(to_node);
202 bool from_exists = from_max_port != max_output_ports->end();
203 bool to_exists = to_max_port != max_output_ports->end();
204
205 auto forward_fanouts = [fanouts](NodeDef* from, NodeDef* to, int start,
206 int end) {
207 for (int i = start; i <= end; ++i) {
208 MutableGraphView::OutputPort from_port(from, i);
209 auto from_fanouts = fanouts->find(from_port);
210 if (from_fanouts != fanouts->end()) {
211 MutableGraphView::OutputPort to_port(to, i);
212 fanouts->emplace(to_port, std::move(from_fanouts->second));
213 fanouts->erase(from_port);
214 }
215 }
216 };
217
218 if (from_exists && to_exists) {
219 const int from = from_max_port->second;
220 const int to = to_max_port->second;
221 const int shared = std::min(from, to);
222 for (int i = 0; i <= shared; ++i) {
223 MutableGraphView::OutputPort from_port(from_node, i);
224 auto from_fanouts = fanouts->find(from_port);
225 MutableGraphView::OutputPort to_port(to_node, i);
226 auto to_fanouts = fanouts->find(to_port);
227 SwapFanoutsMapValues(fanouts, from_port, from_fanouts, to_port,
228 to_fanouts);
229 }
230 if (to > from) {
231 forward_fanouts(to_node, from_node, shared + 1, to);
232 } else if (from > to) {
233 forward_fanouts(from_node, to_node, shared + 1, from);
234 }
235
236 std::swap(from_max_port->second, to_max_port->second);
237 } else if (from_exists) {
238 forward_fanouts(from_node, to_node, 0, from_max_port->second);
239
240 max_output_ports->emplace(to_node, from_max_port->second);
241 max_output_ports->erase(from_node);
242 } else if (to_exists) {
243 forward_fanouts(to_node, from_node, 0, to_max_port->second);
244
245 max_output_ports->emplace(from_node, to_max_port->second);
246 max_output_ports->erase(to_node);
247 }
248}
249
250bool HasFanoutValue(const FanoutsMap& fanouts, const FanoutsMap::iterator& it) {
251 return it != fanouts.end() && !it->second.empty();
252}
253
254Status MutationError(absl::string_view function_name, absl::string_view params,
255 absl::string_view msg) {
256 return errors::InvalidArgument(absl::Substitute(
257 "MutableGraphView::$0($1) error: $2.", function_name, params, msg));
258}
259
// Callback that turns a short error message into a full Status; in this file
// each handler wraps the message with MutationError() and the calling method's
// parameters.
using ErrorHandler = std::function<Status(absl::string_view)>;
261
262ErrorHandler UpdateFanoutsError(absl::string_view from_node_name,
263 absl::string_view to_node_name) {
264 return [from_node_name, to_node_name](absl::string_view msg) {
265 string params = absl::Substitute("from_node_name='$0', to_node_name='$1'",
266 from_node_name, to_node_name);
267 return MutationError("UpdateFanouts", params, msg);
268 };
269}
270
271Status CheckFaninIsRegular(const TensorId& fanin, ErrorHandler handler) {
272 if (!IsTensorIdRegular(fanin)) {
273 return handler(absl::Substitute("fanin '$0' must be a regular tensor id",
274 fanin.ToString()));
275 }
276 return OkStatus();
277}
278
279Status CheckFaninIsValid(const TensorId& fanin, ErrorHandler handler) {
280 if (!IsTensorIdPortValid(fanin)) {
281 return handler(absl::Substitute("fanin '$0' must be a valid tensor id",
282 fanin.ToString()));
283 }
284 return OkStatus();
285}
286
287Status CheckAddingFaninToSelf(absl::string_view node_name,
288 const TensorId& fanin, ErrorHandler handler) {
289 if (node_name == fanin.node()) {
290 return handler(
291 absl::Substitute("can't add fanin '$0' to self", fanin.ToString()));
292 }
293 return OkStatus();
294}
295
296Status CheckRemovingFaninFromSelf(absl::string_view node_name,
297 const TensorId& fanin, ErrorHandler handler) {
298 if (node_name == fanin.node()) {
299 return handler(absl::Substitute("can't remove fanin '$0' from self",
300 fanin.ToString()));
301 }
302 return OkStatus();
303}
304
305string NodeMissingErrorMsg(absl::string_view node_name) {
306 return absl::Substitute("node '$0' was not found", node_name);
307}
308
309Status CheckNodeExists(absl::string_view node_name, NodeDef* node,
310 ErrorHandler handler) {
311 if (node == nullptr) {
312 return handler(NodeMissingErrorMsg(node_name));
313 }
314 return OkStatus();
315}
316
317Status CheckPortRange(int port, int min, int max, ErrorHandler handler) {
318 if (port < min || port > max) {
319 if (max < min) {
320 return handler("no available ports as node has no regular fanins");
321 }
322 return handler(
323 absl::Substitute("port must be in range [$0, $1]", min, max));
324 }
325 return OkStatus();
326}
327
328string SwapNodeNamesSwitchControlErrorMsg(absl::string_view node_name) {
329 return absl::Substitute(
330 "can't swap node name '$0' as it will become a Switch control dependency",
331 node_name);
332}
333
334string GeneratedNameForIdentityConsumingSwitch(
335 const MutableGraphView::OutputPort& fanin) {
336 return AddPrefixToNodeName(
337 absl::StrCat(fanin.node->name(), "_", fanin.port_id),
338 kMutableGraphViewCtrl);
339}
340
341string PrintInTextFormat(const protobuf::MessageLite& message) {
342 // Unfortunately proto2::TextFormat::Printer::PrintToString does not have
343 // a overload for MessageLite so here we have to use
344 // MessageLite::ShortDebugString.
345 return message.ShortDebugString();
346}
347
348string PrintInTextFormat(const protobuf::Message& message) {
349 string message_text;
350 ::tensorflow::protobuf::TextFormat::Printer printer;
351 printer.SetSingleLineMode(true);
352 printer.PrintToString(message, &message_text);
353 if (!message_text.empty() && message_text[message_text.size() - 1] == ' ') {
354 message_text.resize(message_text.size() - 1);
355 }
356 return message_text;
357}
358
359} // namespace
360
// Registers the fanout edges contributed by every input of `node` and removes
// duplicate control inputs from `node` in place.
//
// A control input is dropped when its node was already seen as an input of
// `node` and either (a) it was already seen as a control input, or (b) it may
// be deduped against a regular input (CanDedupControlWithRegularInput).
// Dropped inputs are swapped to the tail of `node->input()` and deleted in one
// DeleteSubrange at the end, which may reorder the remaining control inputs.
// Also records the max regular output port of each fanin node and the max
// regular input port of `node`.
void MutableGraphView::AddAndDedupFanouts(NodeDef* node) {
  // TODO(lyandy): Checks for self loops, Switch control dependencies, fanins
  // exist, and all regular fanins come before controlling fanins.
  // NOTE(review): the string_views stored in these sets point into the
  // strings of node->input(); this presumably relies on SwapElements not
  // relocating those strings — confirm.
  absl::flat_hash_set<absl::string_view> fanins;
  absl::flat_hash_set<absl::string_view> controlling_fanins;
  int max_input_port = -1;
  int pos = 0;
  const int last_idx = node->input_size() - 1;
  // Inputs at indices (last_pos, last_idx] are the ones scheduled for removal.
  int last_pos = last_idx;
  while (pos <= last_pos) {
    TensorId tensor_id = ParseTensorName(node->input(pos));
    absl::string_view input_node_name = tensor_id.node();
    bool is_control_input = IsTensorIdControlling(tensor_id);
    bool can_dedup_control_with_regular_input =
        CanDedupControlWithRegularInput(*this, input_node_name);
    bool can_dedup_control =
        is_control_input && (can_dedup_control_with_regular_input ||
                             controlling_fanins.contains(input_node_name));
    if (!gtl::InsertIfNotPresent(&fanins, input_node_name) &&
        can_dedup_control) {
      // Duplicate control input: move it to the removal tail. `pos` is not
      // advanced because the element swapped in still has to be examined.
      node->mutable_input()->SwapElements(pos, last_pos);
      --last_pos;
    } else {
      // NOTE(review): nodes()[...] is operator[]; for a fanin that does not
      // exist this presumably inserts a null entry (see the TODO above).
      OutputPort output(nodes()[input_node_name], tensor_id.index());

      if (is_control_input) {
        fanouts()[output].emplace(node, Graph::kControlSlot);
      } else {
        max_input_port = pos;
        max_regular_output_port()[output.node] =
            std::max(max_regular_output_port()[output.node], output.port_id);
        fanouts()[output].emplace(node, pos);
      }
      ++pos;
    }
    if (is_control_input) {
      controlling_fanins.insert(input_node_name);
    }
  }

  // Drop the duplicates accumulated at the tail.
  if (last_pos < last_idx) {
    node->mutable_input()->DeleteSubrange(last_pos + 1, last_idx - last_pos);
  }

  if (max_input_port > -1) {
    max_regular_input_port()[node] = max_input_port;
  }
}
409
410void MutableGraphView::UpdateMaxRegularOutputPortForRemovedFanin(
411 const OutputPort& fanin,
412 const absl::flat_hash_set<InputPort>& fanin_fanouts) {
413 int max_port = max_regular_output_port()[fanin.node];
414 if (!fanin_fanouts.empty() || max_port != fanin.port_id) {
415 return;
416 }
417 bool updated_max_port = false;
418 for (int i = fanin.port_id - 1; i >= 0; --i) {
419 OutputPort fanin_port(fanin.node, i);
420 if (!fanouts()[fanin_port].empty()) {
421 max_regular_output_port()[fanin.node] = i;
422 updated_max_port = true;
423 break;
424 }
425 }
426 if (!updated_max_port) {
427 max_regular_output_port().erase(fanin.node);
428 }
429}
430
431void MutableGraphView::UpdateMaxRegularOutputPortForAddedFanin(
432 const OutputPort& fanin) {
433 if (max_regular_output_port()[fanin.node] < fanin.port_id) {
434 max_regular_output_port()[fanin.node] = fanin.port_id;
435 }
436}
437
438const absl::flat_hash_set<MutableGraphView::InputPort>&
439MutableGraphView::GetFanout(const GraphView::OutputPort& port) const {
440 return GetFanout(MutableGraphView::OutputPort(const_cast<NodeDef*>(port.node),
441 port.port_id));
442}
443
444absl::flat_hash_set<MutableGraphView::OutputPort> MutableGraphView::GetFanin(
445 const GraphView::InputPort& port) const {
446 return GetFanin(MutableGraphView::InputPort(const_cast<NodeDef*>(port.node),
447 port.port_id));
448}
449
450const MutableGraphView::OutputPort MutableGraphView::GetRegularFanin(
451 const GraphView::InputPort& port) const {
452 return GetRegularFanin(MutableGraphView::InputPort(
453 const_cast<NodeDef*>(port.node), port.port_id));
454}
455
456NodeDef* MutableGraphView::AddNode(NodeDef&& node) {
457 auto* node_in_graph = graph()->add_node();
458 *node_in_graph = std::move(node);
459
460 AddUniqueNodeOrDie(node_in_graph);
461
462 AddAndDedupFanouts(node_in_graph);
463 return node_in_graph;
464}
465
466Status MutableGraphView::AddSubgraph(GraphDef&& subgraph) {
467 // 1. Add all new functions and check that functions with the same name
468 // have identical definition.
469 const int function_size = subgraph.library().function_size();
470 if (function_size > 0) {
471 absl::flat_hash_map<absl::string_view, const FunctionDef*> graph_fdefs;
472 for (const FunctionDef& fdef : graph()->library().function()) {
473 graph_fdefs.emplace(fdef.signature().name(), &fdef);
474 }
475
476 for (FunctionDef& fdef : *subgraph.mutable_library()->mutable_function()) {
477 const auto graph_fdef = graph_fdefs.find(fdef.signature().name());
478
479 if (graph_fdef == graph_fdefs.end()) {
480 VLOG(3) << "Add new function definition: " << fdef.signature().name();
481 graph()->mutable_library()->add_function()->Swap(&fdef);
482 } else {
483 if (!FunctionDefsEqual(fdef, *graph_fdef->second)) {
484 return MutationError(
485 "AddSubgraph",
486 absl::Substitute("function_size=$0", function_size),
487 absl::StrCat(
488 "Found different function definition with the same name: ",
489 fdef.signature().name()));
490 }
491 }
492 }
493 }
494
495 // 2. Add all nodes to the underlying graph.
496 int node_size_before = graph()->node_size();
497
498 for (NodeDef& node : *subgraph.mutable_node()) {
499 auto* node_in_graph = graph()->add_node();
500 node_in_graph->Swap(&node);
501 TF_RETURN_IF_ERROR(AddUniqueNode(node_in_graph));
502 }
503
504 // TODO(ezhulenev, lyandy): Right now AddAndDedupFanouts do not check that
505 // fanins actually exists in the graph, and there is already TODO for that.
506
507 for (int i = node_size_before; i < graph()->node_size(); ++i) {
508 NodeDef* node = graph()->mutable_node(i);
509 AddAndDedupFanouts(node);
510 }
511
512 return OkStatus();
513}
514
515Status MutableGraphView::UpdateNode(
516 absl::string_view node_name, absl::string_view op, absl::string_view device,
517 absl::Span<const std::pair<string, AttrValue>> attrs) {
518 auto error_status = [node_name, op, device, attrs](absl::string_view msg) {
519 std::vector<string> attr_strs;
520 attr_strs.reserve(attrs.size());
521 for (const auto& attr : attrs) {
522 string attr_str = absl::Substitute("('$0', $1)", attr.first,
523 PrintInTextFormat(attr.second));
524 attr_strs.push_back(attr_str);
525 }
526 string params =
527 absl::Substitute("node_name='$0', op='$1', device='$2', attrs={$3}",
528 node_name, op, device, absl::StrJoin(attr_strs, ", "));
529 return MutationError("UpdateNodeOp", params, msg);
530 };
531
532 NodeDef* node = GetNode(node_name);
533 TF_RETURN_IF_ERROR(CheckNodeExists(node_name, node, error_status));
534
535 MutableGraphView::OutputPort control_port(node, Graph::kControlSlot);
536 auto control_fanouts = GetFanout(control_port);
537 if (op == "Switch" && !control_fanouts.empty()) {
538 return error_status(
539 "can't change node op to Switch when node drives a control dependency "
540 "(alternatively, we could add the identity node needed, but it seems "
541 "like an unlikely event and probably a mistake)");
542 }
543
544 if (node->device() != device) {
545 node->set_device(string(device));
546 }
547 node->mutable_attr()->clear();
548 for (const auto& attr : attrs) {
549 (*node->mutable_attr())[attr.first] = attr.second;
550 }
551
552 if (node->op() == op) {
553 return OkStatus();
554 }
555
556 node->set_op(string(op));
557
558 if (CanDedupControlWithRegularInput(*this, *node)) {
559 for (const auto& control_fanout : control_fanouts) {
560 if (HasRegularFaninNode(*this, *control_fanout.node, node->name())) {
561 RemoveControllingFaninInternal(control_fanout.node, node);
562 }
563 }
564 }
565
566 return OkStatus();
567}
568
569Status MutableGraphView::UpdateNodeName(absl::string_view from_node_name,
570 absl::string_view to_node_name,
571 bool update_fanouts) {
572 auto error_status = [from_node_name, to_node_name,
573 update_fanouts](absl::string_view msg) {
574 string params = absl::Substitute(
575 "from_node_name='$0', to_node_name='$1', update_fanouts=$2",
576 from_node_name, to_node_name, update_fanouts);
577 return MutationError("UpdateNodeName", params, msg);
578 };
579
580 NodeDef* node = GetNode(from_node_name);
581 TF_RETURN_IF_ERROR(CheckNodeExists(from_node_name, node, error_status));
582
583 if (node->name() == to_node_name) {
584 return OkStatus();
585 }
586 if (HasNode(to_node_name)) {
587 return error_status(
588 "can't update node name because new node name is in use");
589 }
590 auto max_output_port = max_regular_output_port().find(node);
591 const bool has_max_output_port =
592 max_output_port != max_regular_output_port().end();
593 auto control_fanouts = fanouts().find({node, Graph::kControlSlot});
594
595 if (update_fanouts) {
596 SwapControlledFanoutInputs(*this, control_fanouts, to_node_name);
597 if (has_max_output_port) {
598 SwapRegularFanoutInputs(&fanouts(), node, to_node_name,
599 max_output_port->second);
600 }
601 } else if (has_max_output_port ||
602 HasFanoutValue(fanouts(), control_fanouts)) {
603 return error_status("can't update node name because node has fanouts");
604 }
605
606 nodes().erase(node->name());
607 node->set_name(string(to_node_name));
608 nodes().emplace(node->name(), node);
609 return OkStatus();
610}
611
// Swaps the names of nodes `from_node_name` and `to_node_name`.
//
// With `update_fanouts` = true, the input strings of all consumers are
// rewritten (SwapFanoutInputs) so each consumer keeps reading from the same
// NodeDef object. The fanout maps, keyed by NodeDef*, are left unchanged.
//
// With `update_fanouts` = false, consumers keep their input strings and so end
// up reading from the other node. The fanout maps and max regular output ports
// are swapped to match. Inputs that become self loops are redirected to the
// other node. Control dependencies that become redundant with a regular input
// are deduped. The swap is rejected if it would make a Switch node a control
// dependency.
//
// Returns an error (via MutationError) if either node is missing. Swapping a
// name with itself is a no-op once `from_node_name` is found.
Status MutableGraphView::SwapNodeNames(absl::string_view from_node_name,
                                       absl::string_view to_node_name,
                                       bool update_fanouts) {
  auto error_status = [from_node_name, to_node_name,
                       update_fanouts](absl::string_view msg) {
    string params = absl::Substitute(
        "from_node_name='$0', to_node_name='$1', update_fanouts=$2",
        from_node_name, to_node_name, update_fanouts);
    return MutationError("SwapNodeNames", params, msg);
  };

  NodeDef* from_node = GetNode(from_node_name);
  TF_RETURN_IF_ERROR(CheckNodeExists(from_node_name, from_node, error_status));
  if (from_node_name == to_node_name) {
    return OkStatus();
  }
  NodeDef* to_node = GetNode(to_node_name);
  TF_RETURN_IF_ERROR(CheckNodeExists(to_node_name, to_node, error_status));

  // Re-registers both nodes in nodes() under their swapped names.
  auto swap_names = [this, from_node, to_node]() {
    nodes().erase(from_node->name());
    nodes().erase(to_node->name());
    std::swap(*from_node->mutable_name(), *to_node->mutable_name());
    nodes().emplace(from_node->name(), from_node);
    nodes().emplace(to_node->name(), to_node);
  };

  if (update_fanouts) {
    SwapFanoutInputs(*this, &fanouts(), &max_regular_output_port(), from_node,
                     to_node);
    swap_names();
    return OkStatus();
  }

  // After the swap, to_node's control consumers would read from from_node;
  // reject if from_node is a Switch.
  bool from_is_switch = IsSwitch(*from_node);
  MutableGraphView::OutputPort to_control(to_node, Graph::kControlSlot);
  auto to_control_fanouts = fanouts().find(to_control);
  if (from_is_switch && HasFanoutValue(fanouts(), to_control_fanouts)) {
    return error_status(SwapNodeNamesSwitchControlErrorMsg(from_node_name));
  }

  // Symmetric check for to_node being a Switch.
  bool to_is_switch = IsSwitch(*to_node);
  MutableGraphView::OutputPort from_control(from_node, Graph::kControlSlot);
  auto from_control_fanouts = fanouts().find(from_control);
  if (to_is_switch && HasFanoutValue(fanouts(), from_control_fanouts)) {
    return error_status(SwapNodeNamesSwitchControlErrorMsg(to_node_name));
  }

  // Swap node names.
  swap_names();

  // Swap controlling fanouts.
  //
  // Note: To and from control fanout iterators are still valid as no mutations
  // has been performed on fanouts().
  SwapFanoutsMapValues(&fanouts(), from_control, from_control_fanouts,
                       to_control, to_control_fanouts);

  // Swap regular fanouts.
  SwapRegularFanoutsAndMaxPortValues(&fanouts(), &max_regular_output_port(),
                                     from_node, to_node);

  // Update fanins to remove self loops.
  //
  // An input of `node` that now names `node` itself used to name the other
  // node. `old_node_name` is `node`'s pre-swap name, which the other node now
  // carries. The edge is moved back to the other node in the fanout map and
  // the input string is rewritten to `old_node_name`.
  auto update_fanins = [this](NodeDef* node, absl::string_view old_node_name) {
    for (int i = 0; i < node->input_size(); ++i) {
      TensorId tensor_id = ParseTensorName(node->input(i));
      if (tensor_id.node() == node->name()) {
        const int idx = tensor_id.index();
        // Control edges are keyed by kControlSlot, regular edges by the
        // input index.
        const int node_idx =
            IsTensorIdControlling(tensor_id) ? Graph::kControlSlot : i;

        MutableGraphView::OutputPort from_fanin(node, idx);
        absl::flat_hash_set<InputPort>* from_fanouts = &fanouts()[from_fanin];
        from_fanouts->erase({node, node_idx});
        UpdateMaxRegularOutputPortForRemovedFanin(from_fanin, *from_fanouts);

        MutableGraphView::OutputPort to_fanin(nodes().at(old_node_name), idx);
        fanouts()[to_fanin].insert({node, node_idx});
        UpdateMaxRegularOutputPortForAddedFanin(to_fanin);
        node->set_input(i, TensorIdToString({old_node_name, idx}));
      }
    }
  };
  update_fanins(from_node, to_node->name());
  update_fanins(to_node, from_node->name());

  // Dedup control dependencies.
  //
  // For each control consumer of `node` that also reads `node` through a
  // regular input, drop the now-redundant control input (unless `node` is an
  // Identity fed by a Switch).
  auto dedup_control_fanouts =
      [this](NodeDef* node, const FanoutsMap::iterator& control_fanouts) {
        if (CanDedupControlWithRegularInput(*this, *node) &&
            control_fanouts != fanouts().end()) {
          for (auto it = control_fanouts->second.begin();
               it != control_fanouts->second.end();) {
            // Advance `it` before invalidation from removal.
            const auto& control_fanout = *it++;
            if (HasRegularFaninNode(*this, *control_fanout.node,
                                    node->name())) {
              RemoveControllingFaninInternal(control_fanout.node, node);
            }
          }
        }
      };
  // Re-runs control dedup for every regular consumer of `node`. Used when
  // `node` is not a Switch but took over the fanouts of one. Such a consumer
  // may be an Identity whose control edges were exempt from dedup only because
  // its input was a Switch.
  auto dedup_switch_control = [this, dedup_control_fanouts](NodeDef* node) {
    OutputPort port;
    port.node = node;
    const int max_port =
        gtl::FindWithDefault(max_regular_output_port(), node, -1);
    for (int i = 0; i <= max_port; ++i) {
      port.port_id = i;
      auto it = fanouts().find(port);
      if (it == fanouts().end()) {
        continue;
      }
      for (const auto& fanout : it->second) {
        auto fanout_controls =
            fanouts().find({fanout.node, Graph::kControlSlot});
        dedup_control_fanouts(fanout.node, fanout_controls);
      }
    }
  };

  if (!from_is_switch) {
    if (to_is_switch) {
      dedup_switch_control(from_node);
    } else {
      // Fetch iterator again as the original iterator might have been
      // invalidated by container rehash triggered due to mutations.
      auto from_control_fanouts = fanouts().find(from_control);
      dedup_control_fanouts(from_node, from_control_fanouts);
    }
  }
  if (!to_is_switch) {
    if (from_is_switch) {
      dedup_switch_control(to_node);
    } else {
      // Fetch iterator again as the original iterator might have been
      // invalidated by container rehash triggered due to mutations.
      auto to_control_fanouts = fanouts().find(to_control);
      dedup_control_fanouts(to_node, to_control_fanouts);
    }
  }

  return OkStatus();
}
756
757Status MutableGraphView::UpdateFanouts(absl::string_view from_node_name,
758 absl::string_view to_node_name) {
759 NodeDef* from_node = GetNode(from_node_name);
760 TF_RETURN_IF_ERROR(
761 CheckNodeExists(from_node_name, from_node,
762 UpdateFanoutsError(from_node_name, to_node_name)));
763 NodeDef* to_node = GetNode(to_node_name);
764 TF_RETURN_IF_ERROR(CheckNodeExists(
765 to_node_name, to_node, UpdateFanoutsError(from_node_name, to_node_name)));
766
767 return UpdateFanoutsInternal(from_node, to_node);
768}
769
770Status MutableGraphView::UpdateFanoutsInternal(NodeDef* from_node,
771 NodeDef* to_node) {
772 VLOG(2) << absl::Substitute("Update fanouts from '$0' to '$1'.",
773 from_node->name(), to_node->name());
774 if (from_node == to_node) {
775 return OkStatus();
776 }
777
778 // Update internal state with the new output_port->input_port edge.
779 const auto add_edge = [this](const OutputPort& output_port,
780 const InputPort& input_port) {
781 fanouts()[output_port].insert(input_port);
782 };
783
784 // Remove invalidated edge from the internal state.
785 const auto remove_edge = [this](const OutputPort& output_port,
786 const InputPort& input_port) {
787 fanouts()[output_port].erase(input_port);
788 };
789
790 // For the control fanouts we do not know the input index in a NodeDef,
791 // so we have to traverse all control inputs.
792
793 auto control_fanouts =
794 GetFanout(GraphView::OutputPort(from_node, Graph::kControlSlot));
795
796 bool to_node_is_switch = IsSwitch(*to_node);
797 for (const InputPort& control_port : control_fanouts) {
798 // Node can't be control dependency of itself.
799 if (control_port.node == to_node) continue;
800
801 // Can't add Switch node as a control dependency.
802 if (to_node_is_switch) {
803 // Trying to add a Switch as a control dependency, which if allowed will
804 // make the graph invalid.
805 return UpdateFanoutsError(from_node->name(), to_node->name())(
806 absl::Substitute("can't update fanouts to node '$0' as it will "
807 "become a Switch control dependency",
808 to_node->name()));
809 }
810
811 NodeDef* node = control_port.node;
812 RemoveControllingFaninInternal(node, from_node);
813 AddFaninInternal(node, {to_node, Graph::kControlSlot});
814 }
815
816 // First we update regular fanouts. For the regular fanouts
817 // `input_port:port_id` is the input index in NodeDef.
818
819 auto regular_edges =
820 GetFanoutEdges(*from_node, /*include_controlled_edges=*/false);
821
822 // Maximum index of the `from_node` output tensor that is still used as an
823 // input to some other node.
824 int keep_max_regular_output_port = -1;
825
826 for (const Edge& edge : regular_edges) {
827 const OutputPort output_port = edge.src;
828 const InputPort input_port = edge.dst;
829
830 // If the `to_node` reads from the `from_node`, skip this edge (see
831 // AddAndUpdateFanoutsWithoutSelfLoops test for an example).
832 if (input_port.node == to_node) {
833 keep_max_regular_output_port =
834 std::max(keep_max_regular_output_port, output_port.port_id);
835 continue;
836 }
837
838 // Update input at destination node.
839 input_port.node->set_input(
840 input_port.port_id,
841 TensorIdToString({to_node->name(), output_port.port_id}));
842
843 // Remove old edge between the `from_node` and the fanout node.
844 remove_edge(output_port, input_port);
845 // Add an edge between the `to_node` and new fanout node.
846 add_edge(OutputPort(to_node, output_port.port_id), input_port);
847 // Dedup control dependency.
848 if (CanDedupControlWithRegularInput(*this, *to_node)) {
849 RemoveControllingFaninInternal(input_port.node, to_node);
850 }
851 }
852
853 // Because we update all regular fanouts of `from_node`, we can just copy
854 // the value `num_regular_outputs`.
855 max_regular_output_port()[to_node] = max_regular_output_port()[from_node];
856
857 // Check if all fanouts were updated to read from the `to_node`.
858 if (keep_max_regular_output_port >= 0) {
859 max_regular_output_port()[from_node] = keep_max_regular_output_port;
860 } else {
861 max_regular_output_port().erase(from_node);
862 }
863
864 return OkStatus();
865}
866
867bool MutableGraphView::AddFaninInternal(NodeDef* node,
868 const OutputPort& fanin) {
869 int num_regular_fanins =
870 NumFanins(*node, /*include_controlling_nodes=*/false);
871 bool input_is_control = IsOutputPortControlling(fanin);
872 bool can_dedup_control_with_regular_input =
873 CanDedupControlWithRegularInput(*this, *fanin.node);
874 // Don't add duplicate control dependencies.
875 if (input_is_control) {
876 const int start =
877 can_dedup_control_with_regular_input ? 0 : num_regular_fanins;
878 for (int i = start; i < node->input_size(); ++i) {
879 if (ParseTensorName(node->input(i)).node() == fanin.node->name()) {
880 return false;
881 }
882 }
883 }
884
885 InputPort input;
886 input.node = node;
887 input.port_id = input_is_control ? Graph::kControlSlot : num_regular_fanins;
888
889 node->add_input(TensorIdToString({fanin.node->name(), fanin.port_id}));
890 if (!input_is_control) {
891 const int last_node_input = node->input_size() - 1;
892 // If there are control dependencies in node, move newly inserted fanin to
893 // be before such control dependencies.
894 if (num_regular_fanins < last_node_input) {
895 node->mutable_input()->SwapElements(last_node_input, num_regular_fanins);
896 }
897 }
898
899 fanouts()[fanin].insert(input);
900 if (max_regular_output_port()[fanin.node] < fanin.port_id) {
901 max_regular_output_port()[fanin.node] = fanin.port_id;
902 }
903
904 // Update max input port and dedup control dependencies.
905 if (!input_is_control) {
906 max_regular_input_port()[node] = num_regular_fanins;
907 if (can_dedup_control_with_regular_input) {
908 RemoveControllingFaninInternal(node, fanin.node);
909 }
910 }
911
912 return true;
913}
914
915Status MutableGraphView::AddRegularFanin(absl::string_view node_name,
916 const TensorId& fanin) {
917 auto error_status = [node_name, fanin](absl::string_view msg) {
918 string params = absl::Substitute("node_name='$0', fanin='$1'", node_name,
919 fanin.ToString());
920 return MutationError("AddRegularFanin", params, msg);
921 };
922
923 TF_RETURN_IF_ERROR(CheckFaninIsRegular(fanin, error_status));
924 TF_RETURN_IF_ERROR(CheckAddingFaninToSelf(node_name, fanin, error_status));
925 NodeDef* node = GetNode(node_name);
926 TF_RETURN_IF_ERROR(CheckNodeExists(node_name, node, error_status));
927 NodeDef* fanin_node = GetNode(fanin.node());
928 TF_RETURN_IF_ERROR(CheckNodeExists(fanin.node(), fanin_node, error_status));
929
930 AddFaninInternal(node, {fanin_node, fanin.index()});
931 return OkStatus();
932}
933
// Inserts `fanin` as a regular input of the node named `node_name` at input
// index `port`. The regular inputs currently at `port` and above each move up
// by one, and their fanout entries are re-keyed to match. `port` is
// range-checked against [0, number of regular fanins]. The upper bound places
// the new input right after the last existing regular input, before any
// control dependencies.
//
// Errors come from the Check* helpers: the fanin is not regular, the fanin is
// the node itself, either node is missing, or `port` is out of range.
Status MutableGraphView::AddRegularFaninByPort(absl::string_view node_name,
                                               int port,
                                               const TensorId& fanin) {
  auto error_status = [node_name, port, fanin](absl::string_view msg) {
    string params = absl::Substitute("node_name='$0', port=$1, fanin='$2'",
                                     node_name, port, fanin.ToString());
    return MutationError("AddRegularFaninByPort", params, msg);
  };

  TF_RETURN_IF_ERROR(CheckFaninIsRegular(fanin, error_status));
  TF_RETURN_IF_ERROR(CheckAddingFaninToSelf(node_name, fanin, error_status));
  NodeDef* node = GetNode(node_name);
  TF_RETURN_IF_ERROR(CheckNodeExists(node_name, node, error_status));
  const int num_regular_fanins =
      NumFanins(*node, /*include_controlling_nodes=*/false);
  TF_RETURN_IF_ERROR(
      CheckPortRange(port, /*min=*/0, num_regular_fanins, error_status));
  NodeDef* fanin_node = GetNode(fanin.node());
  TF_RETURN_IF_ERROR(CheckNodeExists(fanin.node(), fanin_node, error_status));

  // Append the new input, then swap it into slot `num_regular_fanins` (just
  // after the existing regular inputs). If that slot held a control
  // dependency, the control dependency moves to the end. With no control
  // dependencies the swap is a no-op.
  const int last_node_input = node->input_size();
  node->add_input(TensorIdToString(fanin));
  node->mutable_input()->SwapElements(num_regular_fanins, last_node_input);
  // Bubble the new input down to `port`. Each regular input it passes moves up
  // one slot, so its producer's fanout entry is re-keyed from {node, i} to
  // {node, i + 1}.
  for (int i = num_regular_fanins - 1; i >= port; --i) {
    TensorId tensor_id = ParseTensorName(node->input(i));
    OutputPort fanin_port(nodes()[tensor_id.node()], tensor_id.index());
    absl::flat_hash_set<InputPort>* fanouts_set = &fanouts()[fanin_port];
    fanouts_set->erase({node, i});
    fanouts_set->insert({node, i + 1});
    node->mutable_input()->SwapElements(i, i + 1);
  }

  // Record the new edge and update the producer's max regular output port.
  OutputPort fanin_port(fanin_node, fanin.index());
  fanouts()[fanin_port].insert({node, port});
  UpdateMaxRegularOutputPortForAddedFanin(fanin_port);

  // The node gained one regular input, so its last regular port is the old
  // regular-fanin count.
  max_regular_input_port()[node] = num_regular_fanins;
  // Drop a now-redundant control dependency on `fanin_node` when
  // CanDedupControlWithRegularInput permits it.
  if (CanDedupControlWithRegularInput(*this, *fanin_node)) {
    RemoveControllingFaninInternal(node, fanin_node);
  }

  return OkStatus();
}
977
978NodeDef* MutableGraphView::GetControllingFaninToAdd(absl::string_view node_name,
979 const OutputPort& fanin,
980 string* error_msg) {
981 if (!IsSwitch(*fanin.node)) {
982 return fanin.node;
983 } else {
984 if (IsOutputPortControlling(fanin)) {
985 // Can't add a Switch node control dependency.
986 TensorId tensor_id(fanin.node->name(), fanin.port_id);
987 *error_msg = absl::Substitute(
988 "can't add fanin '$0' as it will become a Switch control dependency",
989 tensor_id.ToString());
990 return nullptr;
991 }
992 // We can't anchor control dependencies directly on the switch node: unlike
993 // other nodes only one of the outputs of the switch node will be generated
994 // when the switch node is executed, and we need to make sure the control
995 // dependency is only triggered when the corresponding output is triggered.
996 // We start by looking for an identity node connected to the output of the
997 // switch node, and use it to anchor the control dependency.
998 for (const auto& fanout : GetFanout(fanin)) {
999 if (IsIdentity(*fanout.node) || IsIdentityNSingleInput(*fanout.node)) {
1000 if (fanout.node->name() == node_name) {
1001 *error_msg =
1002 absl::Substitute("can't add found fanin '$0' to self",
1003 AsControlDependency(fanout.node->name()));
1004 return nullptr;
1005 }
1006 return fanout.node;
1007 }
1008 }
1009
1010 // No node found, check if node to be created is itself.
1011 if (GeneratedNameForIdentityConsumingSwitch(fanin) == node_name) {
1012 *error_msg = absl::Substitute("can't add generated fanin '$0' to self",
1013 AsControlDependency(string(node_name)));
1014 }
1015 }
1016 return nullptr;
1017}
1018
1019NodeDef* MutableGraphView::GetOrCreateIdentityConsumingSwitch(
1020 const OutputPort& fanin) {
1021 // We haven't found an existing node where we can anchor the control
1022 // dependency: add a new identity node.
1023 string identity_name = GeneratedNameForIdentityConsumingSwitch(fanin);
1024 NodeDef* identity_node = GetNode(identity_name);
1025 if (identity_node == nullptr) {
1026 NodeDef new_node;
1027 new_node.set_name(identity_name);
1028 new_node.set_op("Identity");
1029 new_node.set_device(fanin.node->device());
1030 (*new_node.mutable_attr())["T"].set_type(fanin.node->attr().at("T").type());
1031 new_node.add_input(TensorIdToString({fanin.node->name(), fanin.port_id}));
1032 identity_node = AddNode(std::move(new_node));
1033 }
1034 return identity_node;
1035}
1036
1037Status MutableGraphView::AddControllingFanin(absl::string_view node_name,
1038 const TensorId& fanin) {
1039 auto error_status = [node_name, fanin](absl::string_view msg) {
1040 string params = absl::Substitute("node_name='$0', fanin='$1'", node_name,
1041 fanin.ToString());
1042 return MutationError("AddControllingFanin", params, msg);
1043 };
1044
1045 TF_RETURN_IF_ERROR(CheckFaninIsValid(fanin, error_status));
1046 TF_RETURN_IF_ERROR(CheckAddingFaninToSelf(node_name, fanin, error_status));
1047 NodeDef* node = GetNode(node_name);
1048 TF_RETURN_IF_ERROR(CheckNodeExists(node_name, node, error_status));
1049 NodeDef* fanin_node = GetNode(fanin.node());
1050 TF_RETURN_IF_ERROR(CheckNodeExists(fanin.node(), fanin_node, error_status));
1051
1052 OutputPort fanin_port(fanin_node, fanin.index());
1053
1054 string error_msg = "";
1055 NodeDef* control_node = GetControllingFaninToAdd(
1056 node_name, {fanin_node, fanin.index()}, &error_msg);
1057 if (!error_msg.empty()) {
1058 return error_status(error_msg);
1059 }
1060 if (control_node == nullptr) {
1061 control_node = GetOrCreateIdentityConsumingSwitch(fanin_port);
1062 }
1063 AddFaninInternal(node, {control_node, Graph::kControlSlot});
1064
1065 return OkStatus();
1066}
1067
// Removes every regular input of `node` that reads `fanin`, matching both the
// producer name and the output port. The surviving regular inputs are
// compacted to the front in their original order, and their fanout entries
// are re-keyed to their new ports. Control dependencies (inputs past the
// regular ones) are not touched.
//
// Returns true if at least one input was removed.
bool MutableGraphView::RemoveRegularFaninInternal(NodeDef* node,
                                                  const OutputPort& fanin) {
  // Erases the edge `fanin_port` -> {node, node_input_port} and returns the
  // producer's fanout set, so the caller can re-insert the edge at a new port.
  // The returned pointer is used right away, before any further insertion
  // into `fanouts()`. When `update_max_port` is set, the producer's max
  // regular output port is refreshed after the erase.
  auto remove_input = [this, node](const OutputPort& fanin_port,
                                   int node_input_port, bool update_max_port) {
    InputPort input(node, node_input_port);

    absl::flat_hash_set<InputPort>* fanouts_set = &fanouts()[fanin_port];
    fanouts_set->erase(input);
    if (update_max_port) {
      UpdateMaxRegularOutputPortForRemovedFanin(fanin_port, *fanouts_set);
    }
    return fanouts_set;
  };

  auto mutable_inputs = node->mutable_input();
  bool modified = false;
  const int num_regular_fanins =
      NumFanins(*node, /*include_controlling_nodes=*/false);
  // `i` scans the regular inputs. `curr_pos` is the next slot a retained
  // input is compacted into. `i` is declared outside the loop because its
  // final value (== num_regular_fanins) bounds the range deleted below.
  int i;
  int curr_pos = 0;
  for (i = 0; i < num_regular_fanins; ++i) {
    TensorId tensor_id = ParseTensorName(node->input(i));
    if (tensor_id.node() == fanin.node->name() &&
        tensor_id.index() == fanin.port_id) {
      remove_input(fanin, i, /*update_max_port=*/true);
      modified = true;
    } else if (modified) {
      // Regular inputs will need to have their ports updated.
      OutputPort fanin_port(nodes()[tensor_id.node()], tensor_id.index());
      auto fanouts_set = remove_input(fanin_port, i, /*update_max_port=*/false);
      fanouts_set->insert({node, curr_pos});
      // Shift inputs to be retained.
      mutable_inputs->SwapElements(i, curr_pos);
      ++curr_pos;
    } else {
      // Skip inputs to be retained until first modification.
      ++curr_pos;
    }
  }

  if (modified) {
    const int last_regular_input_port = curr_pos - 1;
    if (last_regular_input_port < 0) {
      // No regular inputs remain.
      max_regular_input_port().erase(node);
    } else {
      max_regular_input_port()[node] = last_regular_input_port;
    }
    if (curr_pos < i) {
      // Remove fanins from node inputs. Slots [curr_pos, i) now hold the
      // removed entries; the control dependencies after them shift down.
      mutable_inputs->DeleteSubrange(curr_pos, i - curr_pos);
    }
  }

  return modified;
}
1123
1124Status MutableGraphView::RemoveRegularFanin(absl::string_view node_name,
1125 const TensorId& fanin) {
1126 auto error_status = [node_name, fanin](absl::string_view msg) {
1127 string params = absl::Substitute("node_name='$0', fanin='$1'", node_name,
1128 fanin.ToString());
1129 return MutationError("RemoveRegularFanin", params, msg);
1130 };
1131
1132 TF_RETURN_IF_ERROR(CheckFaninIsRegular(fanin, error_status));
1133 TF_RETURN_IF_ERROR(
1134 CheckRemovingFaninFromSelf(node_name, fanin, error_status));
1135 NodeDef* node = GetNode(node_name);
1136 TF_RETURN_IF_ERROR(CheckNodeExists(node_name, node, error_status));
1137 NodeDef* fanin_node = GetNode(fanin.node());
1138 TF_RETURN_IF_ERROR(CheckNodeExists(fanin.node(), fanin_node, error_status));
1139
1140 RemoveRegularFaninInternal(node, {fanin_node, fanin.index()});
1141 return OkStatus();
1142}
1143
1144Status MutableGraphView::RemoveRegularFaninByPort(absl::string_view node_name,
1145 int port) {
1146 auto error_status = [node_name, port](absl::string_view msg) {
1147 string params =
1148 absl::Substitute("node_name='$0', port=$1", node_name, port);
1149 return MutationError("RemoveRegularFaninByPort", params, msg);
1150 };
1151
1152 NodeDef* node = GetNode(node_name);
1153 TF_RETURN_IF_ERROR(CheckNodeExists(node_name, node, error_status));
1154 const int last_regular_fanin_port =
1155 gtl::FindWithDefault(max_regular_input_port(), node, -1);
1156 TF_RETURN_IF_ERROR(
1157 CheckPortRange(port, /*min=*/0, last_regular_fanin_port, error_status));
1158
1159 TensorId tensor_id = ParseTensorName(node->input(port));
1160 OutputPort fanin_port(nodes()[tensor_id.node()], tensor_id.index());
1161 fanouts()[fanin_port].erase({node, port});
1162 auto mutable_inputs = node->mutable_input();
1163 for (int i = port + 1; i <= last_regular_fanin_port; ++i) {
1164 TensorId tensor_id = ParseTensorName(node->input(i));
1165 OutputPort fanin_port(nodes()[tensor_id.node()], tensor_id.index());
1166 absl::flat_hash_set<InputPort>* fanouts_set = &fanouts()[fanin_port];
1167 fanouts_set->erase({node, i});
1168 fanouts_set->insert({node, i - 1});
1169 mutable_inputs->SwapElements(i - 1, i);
1170 }
1171 const int last_node_input = node->input_size() - 1;
1172 if (last_regular_fanin_port < last_node_input) {
1173 mutable_inputs->SwapElements(last_regular_fanin_port, last_node_input);
1174 }
1175 mutable_inputs->RemoveLast();
1176
1177 const int updated_last_regular_input_port = last_regular_fanin_port - 1;
1178 if (updated_last_regular_input_port < 0) {
1179 max_regular_input_port().erase(node);
1180 } else {
1181 max_regular_input_port()[node] = updated_last_regular_input_port;
1182 }
1183
1184 return OkStatus();
1185}
1186
1187bool MutableGraphView::RemoveControllingFaninInternal(NodeDef* node,
1188 NodeDef* fanin_node) {
1189 for (int i = node->input_size() - 1; i >= 0; --i) {
1190 TensorId tensor_id = ParseTensorName(node->input(i));
1191 if (tensor_id.index() > Graph::kControlSlot) {
1192 break;
1193 }
1194 if (tensor_id.node() == fanin_node->name()) {
1195 fanouts()[{fanin_node, Graph::kControlSlot}].erase(
1196 {node, Graph::kControlSlot});
1197 node->mutable_input()->SwapElements(i, node->input_size() - 1);
1198 node->mutable_input()->RemoveLast();
1199 return true;
1200 }
1201 }
1202 return false;
1203}
1204
1205Status MutableGraphView::RemoveControllingFanin(
1206 absl::string_view node_name, absl::string_view fanin_node_name) {
1207 auto error_status = [node_name, fanin_node_name](absl::string_view msg) {
1208 string params = absl::Substitute("node_name='$0', fanin_node_name='$1'",
1209 node_name, fanin_node_name);
1210 return MutationError("RemoveControllingFanin", params, msg);
1211 };
1212
1213 TF_RETURN_IF_ERROR(CheckRemovingFaninFromSelf(
1214 node_name, {fanin_node_name, Graph::kControlSlot}, error_status));
1215 NodeDef* node = GetNode(node_name);
1216 TF_RETURN_IF_ERROR(CheckNodeExists(node_name, node, error_status));
1217 NodeDef* fanin_node = GetNode(fanin_node_name);
1218 TF_RETURN_IF_ERROR(
1219 CheckNodeExists(fanin_node_name, fanin_node, error_status));
1220
1221 RemoveControllingFaninInternal(node, fanin_node);
1222 return OkStatus();
1223}
1224
1225Status MutableGraphView::RemoveAllFanins(absl::string_view node_name,
1226 bool keep_controlling_fanins) {
1227 NodeDef* node = GetNode(node_name);
1228 if (node == nullptr) {
1229 string params =
1230 absl::Substitute("node_name='$0', keep_controlling_fanins=$1",
1231 node_name, keep_controlling_fanins);
1232 return MutationError("RemoveAllFanins", params,
1233 NodeMissingErrorMsg(node_name));
1234 }
1235
1236 if (node->input().empty()) {
1237 return OkStatus();
1238 }
1239
1240 const int num_regular_fanins =
1241 NumFanins(*node, /*include_controlling_nodes=*/false);
1242 RemoveFaninsInternal(node, keep_controlling_fanins);
1243 if (keep_controlling_fanins) {
1244 if (num_regular_fanins == 0) {
1245 return OkStatus();
1246 } else if (num_regular_fanins < node->input_size()) {
1247 node->mutable_input()->DeleteSubrange(0, num_regular_fanins);
1248 } else {
1249 node->clear_input();
1250 }
1251 } else {
1252 node->clear_input();
1253 }
1254 return OkStatus();
1255}
1256
1257Status MutableGraphView::UpdateFanin(absl::string_view node_name,
1258 const TensorId& from_fanin,
1259 const TensorId& to_fanin) {
1260 auto error_status = [node_name, from_fanin, to_fanin](absl::string_view msg) {
1261 string params =
1262 absl::Substitute("node_name='$0', from_fanin='$1', to_fanin='$2'",
1263 node_name, from_fanin.ToString(), to_fanin.ToString());
1264 return MutationError("UpdateFanin", params, msg);
1265 };
1266
1267 TF_RETURN_IF_ERROR(CheckFaninIsValid(from_fanin, error_status));
1268 TF_RETURN_IF_ERROR(CheckFaninIsValid(to_fanin, error_status));
1269 NodeDef* node = GetNode(node_name);
1270 TF_RETURN_IF_ERROR(CheckNodeExists(node_name, node, error_status));
1271 NodeDef* from_fanin_node = GetNode(from_fanin.node());
1272 TF_RETURN_IF_ERROR(
1273 CheckNodeExists(from_fanin.node(), from_fanin_node, error_status));
1274 NodeDef* to_fanin_node = GetNode(to_fanin.node());
1275 TF_RETURN_IF_ERROR(
1276 CheckNodeExists(to_fanin.node(), to_fanin_node, error_status));
1277
1278 // When replacing a non control dependency fanin with a control dependency, or
1279 // vice versa, remove and add, so ports can be updated properly in fanout(s).
1280 bool to_fanin_is_control = IsTensorIdControlling(to_fanin);
1281 if (to_fanin_is_control && IsSwitch(*to_fanin_node)) {
1282 // Can't add Switch node as a control dependency.
1283 return error_status(
1284 absl::Substitute("can't update to fanin '$0' as it will become a "
1285 "Switch control dependency",
1286 to_fanin.ToString()));
1287 }
1288 if (node_name == from_fanin.node() || node_name == to_fanin.node()) {
1289 return error_status("can't update fanin to or from self");
1290 }
1291
1292 if (from_fanin == to_fanin) {
1293 return OkStatus();
1294 }
1295
1296 bool from_fanin_is_control = IsTensorIdControlling(from_fanin);
1297 if (from_fanin_is_control || to_fanin_is_control) {
1298 bool modified = false;
1299 if (from_fanin_is_control) {
1300 modified |= RemoveControllingFaninInternal(node, from_fanin_node);
1301 } else {
1302 modified |= RemoveRegularFaninInternal(
1303 node, {from_fanin_node, from_fanin.index()});
1304 }
1305 if (modified) {
1306 AddFaninInternal(node, {to_fanin_node, to_fanin.index()});
1307 }
1308 return OkStatus();
1309 }
1310
1311 // In place mutation of regular fanins, requires no shifting of ports.
1312 string to_fanin_string = TensorIdToString(to_fanin);
1313 const int num_regular_fanins =
1314 NumFanins(*node, /*include_controlling_nodes=*/false);
1315 bool modified = false;
1316 for (int i = 0; i < num_regular_fanins; ++i) {
1317 if (ParseTensorName(node->input(i)) == from_fanin) {
1318 InputPort input(node, i);
1319
1320 OutputPort from_fanin_port(from_fanin_node, from_fanin.index());
1321 fanouts()[from_fanin_port].erase(input);
1322
1323 OutputPort to_fanin_port(to_fanin_node, to_fanin.index());
1324 fanouts()[to_fanin_port].insert(input);
1325
1326 node->set_input(i, to_fanin_string);
1327 modified = true;
1328 }
1329 }
1330
1331 // Dedup control dependencies and update max regular output ports.
1332 if (modified) {
1333 OutputPort from_fanin_port(from_fanin_node, from_fanin.index());
1334 UpdateMaxRegularOutputPortForRemovedFanin(
1335 {from_fanin_node, from_fanin.index()}, fanouts()[from_fanin_port]);
1336 if (max_regular_output_port()[to_fanin_node] < to_fanin.index()) {
1337 max_regular_output_port()[to_fanin_node] = to_fanin.index();
1338 }
1339 if (CanDedupControlWithRegularInput(*this, *to_fanin_node)) {
1340 RemoveControllingFaninInternal(node, to_fanin_node);
1341 }
1342 }
1343
1344 return OkStatus();
1345}
1346
1347Status MutableGraphView::UpdateRegularFaninByPort(absl::string_view node_name,
1348 int port,
1349 const TensorId& fanin) {
1350 auto error_status = [node_name, port, fanin](absl::string_view msg) {
1351 string params = absl::Substitute("node_name='$0', port=$1, fanin='$2'",
1352 node_name, port, fanin.ToString());
1353 return MutationError("UpdateRegularFaninByPort", params, msg);
1354 };
1355
1356 TF_RETURN_IF_ERROR(CheckFaninIsRegular(fanin, error_status));
1357 TF_RETURN_IF_ERROR(CheckAddingFaninToSelf(node_name, fanin, error_status));
1358 NodeDef* node = GetNode(node_name);
1359 TF_RETURN_IF_ERROR(CheckNodeExists(node_name, node, error_status));
1360 const int last_regular_fanin_port =
1361 gtl::FindWithDefault(max_regular_input_port(), node, -1);
1362 TF_RETURN_IF_ERROR(
1363 CheckPortRange(port, /*min=*/0, last_regular_fanin_port, error_status));
1364 NodeDef* fanin_node = GetNode(fanin.node());
1365 TF_RETURN_IF_ERROR(CheckNodeExists(fanin.node(), fanin_node, error_status));
1366
1367 TensorId tensor_id = ParseTensorName(node->input(port));
1368 if (tensor_id == fanin) {
1369 return OkStatus();
1370 }
1371
1372 InputPort input(node, port);
1373 OutputPort from_fanin_port(nodes()[tensor_id.node()], tensor_id.index());
1374 absl::flat_hash_set<InputPort>* from_fanouts = &fanouts()[from_fanin_port];
1375 from_fanouts->erase(input);
1376 UpdateMaxRegularOutputPortForRemovedFanin(from_fanin_port, *from_fanouts);
1377
1378 OutputPort to_fanin_port(fanin_node, fanin.index());
1379 fanouts()[to_fanin_port].insert(input);
1380 UpdateMaxRegularOutputPortForAddedFanin(to_fanin_port);
1381
1382 node->set_input(port, TensorIdToString(fanin));
1383
1384 if (CanDedupControlWithRegularInput(*this, *fanin_node)) {
1385 RemoveControllingFaninInternal(node, fanin_node);
1386 }
1387
1388 return OkStatus();
1389}
1390
1391Status MutableGraphView::SwapRegularFaninsByPorts(absl::string_view node_name,
1392 int from_port, int to_port) {
1393 auto error_status = [node_name, from_port, to_port](absl::string_view msg) {
1394 string params = absl::Substitute("node_name='$0', from_port=$1, to_port=$2",
1395 node_name, from_port, to_port);
1396 return MutationError("SwapRegularFaninsByPorts", params, msg);
1397 };
1398
1399 NodeDef* node = GetNode(node_name);
1400 TF_RETURN_IF_ERROR(CheckNodeExists(node_name, node, error_status));
1401 const int last_regular_fanin_port =
1402 gtl::FindWithDefault(max_regular_input_port(), node, -1);
1403 TF_RETURN_IF_ERROR(CheckPortRange(from_port, /*min=*/0,
1404 last_regular_fanin_port, error_status));
1405 TF_RETURN_IF_ERROR(CheckPortRange(to_port, /*min=*/0, last_regular_fanin_port,
1406 error_status));
1407
1408 if (from_port == to_port) {
1409 return OkStatus();
1410 }
1411 TensorId from_fanin = ParseTensorName(node->input(from_port));
1412 TensorId to_fanin = ParseTensorName(node->input(to_port));
1413 if (from_fanin == to_fanin) {
1414 return OkStatus();
1415 }
1416
1417 InputPort from_input(node, from_port);
1418 InputPort to_input(node, to_port);
1419 NodeDef* from_fanin_node = GetNode(from_fanin.node());
1420 absl::flat_hash_set<InputPort>* from_fanouts =
1421 &fanouts()[{from_fanin_node, from_fanin.index()}];
1422 from_fanouts->erase(from_input);
1423 from_fanouts->insert(to_input);
1424 NodeDef* to_fanin_node = GetNode(to_fanin.node());
1425 absl::flat_hash_set<InputPort>* to_fanouts =
1426 &fanouts()[{to_fanin_node, to_fanin.index()}];
1427 to_fanouts->erase(to_input);
1428 to_fanouts->insert(from_input);
1429
1430 node->mutable_input()->SwapElements(from_port, to_port);
1431
1432 return OkStatus();
1433}
1434
// Converts every regular input of the node named `node_name` into a control
// dependency on the same producer. For a Switch producer, the dependency is
// anchored on an Identity consuming that Switch output, reused or created.
// Resulting duplicates, including ones against pre-existing control
// dependencies, are removed, and the node is left with no regular inputs.
//
// All anchors are resolved in a read-only pass before anything is mutated, so
// if GetControllingFaninToAdd reports an error the node is left unchanged.
Status MutableGraphView::UpdateAllRegularFaninsToControlling(
    absl::string_view node_name) {
  auto error_status = [node_name](absl::string_view msg) {
    string params = absl::Substitute("node_name='$0'", node_name);
    return MutationError("UpdateAllRegularFaninsToControlling", params, msg);
  };

  NodeDef* node = GetNode(node_name);
  TF_RETURN_IF_ERROR(CheckNodeExists(node_name, node, error_status));

  const int num_regular_fanins =
      NumFanins(*node, /*include_controlling_nodes=*/false);
  std::vector<OutputPort> regular_fanins;
  regular_fanins.reserve(num_regular_fanins);
  std::vector<NodeDef*> controlling_fanins;
  controlling_fanins.reserve(num_regular_fanins);

  // Get all regular fanins and derive controlling fanins.
  // A nullptr entry in `controlling_fanins` means "create an Identity anchor
  // for this Switch output" (deferred to the mutation pass below).
  for (int i = 0; i < num_regular_fanins; ++i) {
    TensorId tensor_id = ParseTensorName(node->input(i));
    OutputPort fanin_port(nodes()[tensor_id.node()], tensor_id.index());

    string error_msg = "";
    NodeDef* control_node =
        GetControllingFaninToAdd(node_name, fanin_port, &error_msg);
    if (!error_msg.empty()) {
      return error_status(error_msg);
    }

    regular_fanins.push_back(fanin_port);
    controlling_fanins.push_back(control_node);
  }

  // Replace regular fanins with controlling fanins and dedup.
  // `pos` is the next write slot; since pos <= i, overwriting node->input(pos)
  // never clobbers an entry that has not been read yet (those were already
  // captured in `regular_fanins`).
  int pos = 0;
  InputPort input_port(node, Graph::kControlSlot);
  absl::flat_hash_set<absl::string_view> controls;
  for (int i = 0; i < num_regular_fanins; ++i) {
    OutputPort fanin_port = regular_fanins[i];
    NodeDef* control = controlling_fanins[i];
    if (control == nullptr) {
      control = GetOrCreateIdentityConsumingSwitch(fanin_port);
    }
    // NOTE(review): unlike the other removal paths in this file, this erase is
    // not followed by UpdateMaxRegularOutputPortForRemovedFanin, so the
    // producer's max regular output port may be left stale — confirm whether
    // that is intended.
    fanouts()[fanin_port].erase({node, i});
    if (controls.contains(control->name())) {
      continue;
    }
    controls.insert(control->name());
    node->set_input(pos, AsControlDependency(control->name()));
    fanouts()[{control, Graph::kControlSlot}].insert(input_port);
    ++pos;
  }

  // Shift existing controlling fanins and dedup.
  for (int i = num_regular_fanins; i < node->input_size(); ++i) {
    TensorId tensor_id = ParseTensorName(node->input(i));
    if (controls.contains(tensor_id.node())) {
      continue;
    }
    controls.insert(tensor_id.node());
    node->mutable_input()->SwapElements(pos, i);
    ++pos;
  }

  // Remove duplicate controls and leftover regular fanins.
  node->mutable_input()->DeleteSubrange(pos, node->input_size() - pos);
  // Only control dependencies remain.
  max_regular_input_port().erase(node);

  return OkStatus();
}
1505
1506Status MutableGraphView::CheckNodesCanBeDeleted(
1507 const absl::flat_hash_set<string>& nodes_to_delete) {
1508 std::vector<string> missing_nodes;
1509 std::vector<string> nodes_with_fanouts;
1510 for (const string& node_name_to_delete : nodes_to_delete) {
1511 NodeDef* node = GetNode(node_name_to_delete);
1512 if (node == nullptr) {
1513 // Can't delete missing node.
1514 missing_nodes.push_back(node_name_to_delete);
1515 continue;
1516 }
1517 const int max_port = gtl::FindWithDefault(max_regular_output_port(), node,
1518 Graph::kControlSlot);
1519 for (int i = Graph::kControlSlot; i <= max_port; ++i) {
1520 auto it = fanouts().find({node, i});
1521 bool has_retained_fanout = false;
1522 if (it != fanouts().end()) {
1523 for (const auto& fanout : it->second) {
1524 // Check if fanouts are of nodes to be deleted, and if so, they can be
1525 // ignored, as they will be removed also.
1526 if (!nodes_to_delete.contains(fanout.node->name())) {
1527 // Removing node will leave graph in an invalid state.
1528 has_retained_fanout = true;
1529 break;
1530 }
1531 }
1532 }
1533 if (has_retained_fanout) {
1534 nodes_with_fanouts.push_back(node_name_to_delete);
1535 break;
1536 }
1537 }
1538 }
1539
1540 // Error message can get quite long, so we only show the first 5 node names.
1541 auto sort_and_sample = [](std::vector<string>* s) {
1542 constexpr int kMaxNodeNames = 5;
1543 std::sort(s->begin(), s->end());
1544 if (s->size() > kMaxNodeNames) {
1545 return absl::StrCat(
1546 absl::StrJoin(s->begin(), s->begin() + kMaxNodeNames, ", "), ", ...");
1547 }
1548 return absl::StrJoin(*s, ", ");
1549 };
1550
1551 if (!missing_nodes.empty()) {
1552 VLOG(2) << absl::Substitute("Attempting to delete missing node(s) [$0].",
1553 sort_and_sample(&missing_nodes));
1554 }
1555 if (!nodes_with_fanouts.empty()) {
1556 std::vector<string> input_node_names(nodes_to_delete.begin(),
1557 nodes_to_delete.end());
1558 string params = absl::Substitute("nodes_to_delete={$0}",
1559 sort_and_sample(&input_node_names));
1560 string error_msg =
1561 absl::Substitute("can't delete node(s) with retained fanouts(s) [$0]",
1562 sort_and_sample(&nodes_with_fanouts));
1563 return MutationError("DeleteNodes", params, error_msg);
1564 }
1565
1566 return OkStatus();
1567}
1568
// Deletes every node named in `nodes_to_delete` from both the view's indices
// and the underlying GraphDef. Fails up front, without changing anything, if
// CheckNodesCanBeDeleted reports a node whose outputs are still consumed by a
// node that is not being deleted. Names that do not exist are skipped.
//
// The relative order of the retained nodes in the GraphDef is NOT preserved,
// because deleted nodes are swapped toward the end.
Status MutableGraphView::DeleteNodes(
    const absl::flat_hash_set<string>& nodes_to_delete) {
  TF_RETURN_IF_ERROR(CheckNodesCanBeDeleted(nodes_to_delete));

  // Find nodes in internal state and delete.
  // Edge bookkeeping is cleared for all doomed nodes first; the name->node
  // entries are erased in a separate pass afterwards.
  for (const string& node_name_to_delete : nodes_to_delete) {
    NodeDef* node = GetNode(node_name_to_delete);
    if (node != nullptr) {
      RemoveFaninsInternal(node, /*keep_controlling_fanins=*/false);
      RemoveFanoutsInternal(node);
    }
  }
  for (const string& node_name_to_delete : nodes_to_delete) {
    nodes().erase(node_name_to_delete);
  }

  // Find nodes in graph and delete by partitioning into nodes to retain and
  // nodes to delete based on input set of nodes to delete by name.
  // TODO(lyandy): Use a node name->idx hashmap if this is a performance
  // bottleneck.
  // Two-pointer partition: [0, pos) holds retained nodes and (last_pos,
  // last_idx] holds nodes to delete. A doomed node at `pos` is swapped with
  // `last_pos`, and `pos` is re-examined without advancing.
  int pos = 0;
  const int last_idx = graph()->node_size() - 1;
  int last_pos = last_idx;
  while (pos <= last_pos) {
    if (nodes_to_delete.contains(graph()->node(pos).name())) {
      graph()->mutable_node()->SwapElements(pos, last_pos);
      --last_pos;
    } else {
      ++pos;
    }
  }
  if (last_pos < last_idx) {
    graph()->mutable_node()->DeleteSubrange(last_pos + 1, last_idx - last_pos);
  }

  return OkStatus();
}
1606
1607void MutableGraphView::RemoveFaninsInternal(NodeDef* deleted_node,
1608 bool keep_controlling_fanins) {
1609 for (int i = 0; i < deleted_node->input_size(); ++i) {
1610 TensorId tensor_id = ParseTensorName(deleted_node->input(i));
1611 bool is_control = IsTensorIdControlling(tensor_id);
1612 if (keep_controlling_fanins && is_control) {
1613 break;
1614 }
1615 OutputPort fanin(nodes()[tensor_id.node()], tensor_id.index());
1616
1617 InputPort input;
1618 input.node = deleted_node;
1619 input.port_id = is_control ? Graph::kControlSlot : i;
1620
1621 auto it = fanouts().find(fanin);
1622 if (it != fanouts().end()) {
1623 absl::flat_hash_set<InputPort>* fanouts_set = &it->second;
1624 fanouts_set->erase(input);
1625 UpdateMaxRegularOutputPortForRemovedFanin(fanin, *fanouts_set);
1626 }
1627 }
1628 max_regular_input_port().erase(deleted_node);
1629}
1630
1631void MutableGraphView::RemoveFanoutsInternal(NodeDef* deleted_node) {
1632 const int max_port =
1633 gtl::FindWithDefault(max_regular_output_port(), deleted_node, -1);
1634 for (int i = Graph::kControlSlot; i <= max_port; ++i) {
1635 fanouts().erase({deleted_node, i});
1636 }
1637 max_regular_output_port().erase(deleted_node);
1638}
1639
1640} // end namespace grappler
1641} // end namespace tensorflow
1642