1 | /* Copyright 2018 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | |
16 | #include "tensorflow/core/grappler/mutable_graph_view.h" |
17 | |
18 | #include <algorithm> |
19 | #include <utility> |
20 | |
21 | #include "absl/container/flat_hash_map.h" |
22 | #include "absl/strings/str_cat.h" |
23 | #include "absl/strings/str_join.h" |
24 | #include "absl/strings/string_view.h" |
25 | #include "absl/strings/substitute.h" |
26 | #include "tensorflow/core/framework/function.h" |
27 | #include "tensorflow/core/framework/graph.pb.h" |
28 | #include "tensorflow/core/framework/node_def.pb.h" |
29 | #include "tensorflow/core/graph/graph.h" |
30 | #include "tensorflow/core/graph/tensor_id.h" |
31 | #include "tensorflow/core/grappler/op_types.h" |
32 | #include "tensorflow/core/grappler/utils.h" |
33 | #include "tensorflow/core/lib/core/errors.h" |
34 | #include "tensorflow/core/lib/core/stringpiece.h" |
35 | #include "tensorflow/core/lib/gtl/map_util.h" |
36 | #include "tensorflow/core/platform/protobuf.h" |
37 | #include "tensorflow/core/platform/types.h" |
38 | |
39 | namespace tensorflow { |
40 | namespace grappler { |
41 | |
42 | namespace { |
43 | |
44 | bool IsTensorIdPortValid(const TensorId& tensor_id) { |
45 | return tensor_id.index() >= Graph::kControlSlot; |
46 | } |
47 | |
48 | bool IsTensorIdRegular(const TensorId& tensor_id) { |
49 | return tensor_id.index() > Graph::kControlSlot; |
50 | } |
51 | |
52 | bool IsTensorIdControlling(const TensorId& tensor_id) { |
53 | return tensor_id.index() == Graph::kControlSlot; |
54 | } |
55 | |
56 | bool IsOutputPortControlling(const MutableGraphView::OutputPort& port) { |
57 | return port.port_id == Graph::kControlSlot; |
58 | } |
59 | |
60 | // Determines if node is an Identity where it's first regular input is a Switch |
61 | // node. |
62 | bool IsIdentityConsumingSwitch(const MutableGraphView& graph, |
63 | const NodeDef& node) { |
64 | if ((IsIdentity(node) || IsIdentityNSingleInput(node)) && |
65 | node.input_size() > 0) { |
66 | TensorId tensor_id = ParseTensorName(node.input(0)); |
67 | if (IsTensorIdControlling(tensor_id)) { |
68 | return false; |
69 | } |
70 | |
71 | NodeDef* input_node = graph.GetNode(tensor_id.node()); |
72 | if (input_node == nullptr) { |
73 | return false; |
74 | } |
75 | return IsSwitch(*input_node); |
76 | } |
77 | return false; |
78 | } |
79 | |
80 | // Determines if node input can be deduped by regular inputs when used as a |
81 | // control dependency. Specifically, if a node is an Identity that leads to a |
82 | // Switch node, when used as a control dependency, that control dependency |
83 | // should not be deduped even though the same node is used as a regular input. |
84 | bool CanDedupControlWithRegularInput(const MutableGraphView& graph, |
85 | const NodeDef& control_node) { |
86 | return !IsIdentityConsumingSwitch(graph, control_node); |
87 | } |
88 | |
89 | // Determines if node input can be deduped by regular inputs when used as a |
90 | // control dependency. Specifically, if a node is an Identity that leads to a |
91 | // Switch node, when used as a control dependency, that control dependency |
92 | // should not be deduped even though the same node is used as a regular input. |
93 | bool CanDedupControlWithRegularInput(const MutableGraphView& graph, |
94 | absl::string_view control_node_name) { |
95 | NodeDef* control_node = graph.GetNode(control_node_name); |
96 | if (control_node == nullptr) { |
97 | return false; |
98 | } |
99 | return CanDedupControlWithRegularInput(graph, *control_node); |
100 | } |
101 | |
102 | bool HasRegularFaninNode(const MutableGraphView& graph, const NodeDef& node, |
103 | absl::string_view fanin_node_name) { |
104 | const int num_regular_fanins = |
105 | graph.NumFanins(node, /*include_controlling_nodes=*/false); |
106 | for (int i = 0; i < num_regular_fanins; ++i) { |
107 | if (ParseTensorName(node.input(i)).node() == fanin_node_name) { |
108 | return true; |
109 | } |
110 | } |
111 | return false; |
112 | } |
113 | |
// Maps an output port (a node plus an output index, with Graph::kControlSlot
// used for control outputs) to the set of input ports that consume it.
using FanoutsMap =
    absl::flat_hash_map<MutableGraphView::OutputPort,
                        absl::flat_hash_set<MutableGraphView::InputPort>>;
117 | |
118 | void SwapControlledFanoutInputs(const MutableGraphView& graph, |
119 | const FanoutsMap::iterator& control_fanouts, |
120 | absl::string_view to_node_name) { |
121 | absl::string_view from_node_name(control_fanouts->first.node->name()); |
122 | string control = TensorIdToString({to_node_name, Graph::kControlSlot}); |
123 | for (const auto& control_fanout : control_fanouts->second) { |
124 | const int start = graph.NumFanins(*control_fanout.node, |
125 | /*include_controlling_nodes=*/false); |
126 | for (int i = start; i < control_fanout.node->input_size(); ++i) { |
127 | TensorId tensor_id = ParseTensorName(control_fanout.node->input(i)); |
128 | if (tensor_id.node() == from_node_name) { |
129 | control_fanout.node->set_input(i, control); |
130 | break; |
131 | } |
132 | } |
133 | } |
134 | } |
135 | |
136 | void SwapRegularFanoutInputs(FanoutsMap* fanouts, NodeDef* from_node, |
137 | absl::string_view to_node_name, int max_port) { |
138 | MutableGraphView::OutputPort port; |
139 | port.node = from_node; |
140 | for (int i = 0; i <= max_port; ++i) { |
141 | port.port_id = i; |
142 | auto it = fanouts->find(port); |
143 | if (it == fanouts->end()) { |
144 | continue; |
145 | } |
146 | string input = TensorIdToString({to_node_name, i}); |
147 | for (const auto& fanout : it->second) { |
148 | fanout.node->set_input(fanout.port_id, input); |
149 | } |
150 | } |
151 | } |
152 | |
// Maps a node to the highest regular output port index recorded for it.
// Callers scan ports 0..max when walking that node's regular fanouts.
using MaxOutputPortsMap = absl::flat_hash_map<const NodeDef*, int>;
154 | |
155 | void SwapFanoutInputs(const MutableGraphView& graph, FanoutsMap* fanouts, |
156 | MaxOutputPortsMap* max_output_ports, NodeDef* from_node, |
157 | NodeDef* to_node) { |
158 | auto from_control_fanouts = fanouts->find({from_node, Graph::kControlSlot}); |
159 | if (from_control_fanouts != fanouts->end()) { |
160 | SwapControlledFanoutInputs(graph, from_control_fanouts, to_node->name()); |
161 | } |
162 | auto to_control_fanouts = fanouts->find({to_node, Graph::kControlSlot}); |
163 | if (to_control_fanouts != fanouts->end()) { |
164 | SwapControlledFanoutInputs(graph, to_control_fanouts, from_node->name()); |
165 | } |
166 | auto from_max_port = max_output_ports->find(from_node); |
167 | if (from_max_port != max_output_ports->end()) { |
168 | SwapRegularFanoutInputs(fanouts, from_node, to_node->name(), |
169 | from_max_port->second); |
170 | } |
171 | auto to_max_port = max_output_ports->find(to_node); |
172 | if (to_max_port != max_output_ports->end()) { |
173 | SwapRegularFanoutInputs(fanouts, to_node, from_node->name(), |
174 | to_max_port->second); |
175 | } |
176 | } |
177 | |
178 | void SwapFanoutsMapValues(FanoutsMap* fanouts, |
179 | const MutableGraphView::OutputPort& from_port, |
180 | const FanoutsMap::iterator& from_fanouts, |
181 | const MutableGraphView::OutputPort& to_port, |
182 | const FanoutsMap::iterator& to_fanouts) { |
183 | const bool from_exists = from_fanouts != fanouts->end(); |
184 | const bool to_exists = to_fanouts != fanouts->end(); |
185 | |
186 | if (from_exists && to_exists) { |
187 | std::swap(from_fanouts->second, to_fanouts->second); |
188 | } else if (from_exists) { |
189 | fanouts->emplace(to_port, std::move(from_fanouts->second)); |
190 | fanouts->erase(from_port); |
191 | } else if (to_exists) { |
192 | fanouts->emplace(from_port, std::move(to_fanouts->second)); |
193 | fanouts->erase(to_port); |
194 | } |
195 | } |
196 | |
197 | void SwapRegularFanoutsAndMaxPortValues(FanoutsMap* fanouts, |
198 | MaxOutputPortsMap* max_output_ports, |
199 | NodeDef* from_node, NodeDef* to_node) { |
200 | auto from_max_port = max_output_ports->find(from_node); |
201 | auto to_max_port = max_output_ports->find(to_node); |
202 | bool from_exists = from_max_port != max_output_ports->end(); |
203 | bool to_exists = to_max_port != max_output_ports->end(); |
204 | |
205 | auto forward_fanouts = [fanouts](NodeDef* from, NodeDef* to, int start, |
206 | int end) { |
207 | for (int i = start; i <= end; ++i) { |
208 | MutableGraphView::OutputPort from_port(from, i); |
209 | auto from_fanouts = fanouts->find(from_port); |
210 | if (from_fanouts != fanouts->end()) { |
211 | MutableGraphView::OutputPort to_port(to, i); |
212 | fanouts->emplace(to_port, std::move(from_fanouts->second)); |
213 | fanouts->erase(from_port); |
214 | } |
215 | } |
216 | }; |
217 | |
218 | if (from_exists && to_exists) { |
219 | const int from = from_max_port->second; |
220 | const int to = to_max_port->second; |
221 | const int shared = std::min(from, to); |
222 | for (int i = 0; i <= shared; ++i) { |
223 | MutableGraphView::OutputPort from_port(from_node, i); |
224 | auto from_fanouts = fanouts->find(from_port); |
225 | MutableGraphView::OutputPort to_port(to_node, i); |
226 | auto to_fanouts = fanouts->find(to_port); |
227 | SwapFanoutsMapValues(fanouts, from_port, from_fanouts, to_port, |
228 | to_fanouts); |
229 | } |
230 | if (to > from) { |
231 | forward_fanouts(to_node, from_node, shared + 1, to); |
232 | } else if (from > to) { |
233 | forward_fanouts(from_node, to_node, shared + 1, from); |
234 | } |
235 | |
236 | std::swap(from_max_port->second, to_max_port->second); |
237 | } else if (from_exists) { |
238 | forward_fanouts(from_node, to_node, 0, from_max_port->second); |
239 | |
240 | max_output_ports->emplace(to_node, from_max_port->second); |
241 | max_output_ports->erase(from_node); |
242 | } else if (to_exists) { |
243 | forward_fanouts(to_node, from_node, 0, to_max_port->second); |
244 | |
245 | max_output_ports->emplace(from_node, to_max_port->second); |
246 | max_output_ports->erase(to_node); |
247 | } |
248 | } |
249 | |
250 | bool HasFanoutValue(const FanoutsMap& fanouts, const FanoutsMap::iterator& it) { |
251 | return it != fanouts.end() && !it->second.empty(); |
252 | } |
253 | |
254 | Status MutationError(absl::string_view function_name, absl::string_view params, |
255 | absl::string_view msg) { |
256 | return errors::InvalidArgument(absl::Substitute( |
257 | "MutableGraphView::$0($1) error: $2." , function_name, params, msg)); |
258 | } |
259 | |
// Callback that turns an error description into a Status. The handlers
// produced in this file wrap the message with MutationError.
using ErrorHandler = std::function<Status(absl::string_view)>;
261 | |
262 | ErrorHandler UpdateFanoutsError(absl::string_view from_node_name, |
263 | absl::string_view to_node_name) { |
264 | return [from_node_name, to_node_name](absl::string_view msg) { |
265 | string params = absl::Substitute("from_node_name='$0', to_node_name='$1'" , |
266 | from_node_name, to_node_name); |
267 | return MutationError("UpdateFanouts" , params, msg); |
268 | }; |
269 | } |
270 | |
271 | Status CheckFaninIsRegular(const TensorId& fanin, ErrorHandler handler) { |
272 | if (!IsTensorIdRegular(fanin)) { |
273 | return handler(absl::Substitute("fanin '$0' must be a regular tensor id" , |
274 | fanin.ToString())); |
275 | } |
276 | return OkStatus(); |
277 | } |
278 | |
279 | Status CheckFaninIsValid(const TensorId& fanin, ErrorHandler handler) { |
280 | if (!IsTensorIdPortValid(fanin)) { |
281 | return handler(absl::Substitute("fanin '$0' must be a valid tensor id" , |
282 | fanin.ToString())); |
283 | } |
284 | return OkStatus(); |
285 | } |
286 | |
287 | Status CheckAddingFaninToSelf(absl::string_view node_name, |
288 | const TensorId& fanin, ErrorHandler handler) { |
289 | if (node_name == fanin.node()) { |
290 | return handler( |
291 | absl::Substitute("can't add fanin '$0' to self" , fanin.ToString())); |
292 | } |
293 | return OkStatus(); |
294 | } |
295 | |
296 | Status CheckRemovingFaninFromSelf(absl::string_view node_name, |
297 | const TensorId& fanin, ErrorHandler handler) { |
298 | if (node_name == fanin.node()) { |
299 | return handler(absl::Substitute("can't remove fanin '$0' from self" , |
300 | fanin.ToString())); |
301 | } |
302 | return OkStatus(); |
303 | } |
304 | |
305 | string NodeMissingErrorMsg(absl::string_view node_name) { |
306 | return absl::Substitute("node '$0' was not found" , node_name); |
307 | } |
308 | |
309 | Status CheckNodeExists(absl::string_view node_name, NodeDef* node, |
310 | ErrorHandler handler) { |
311 | if (node == nullptr) { |
312 | return handler(NodeMissingErrorMsg(node_name)); |
313 | } |
314 | return OkStatus(); |
315 | } |
316 | |
317 | Status CheckPortRange(int port, int min, int max, ErrorHandler handler) { |
318 | if (port < min || port > max) { |
319 | if (max < min) { |
320 | return handler("no available ports as node has no regular fanins" ); |
321 | } |
322 | return handler( |
323 | absl::Substitute("port must be in range [$0, $1]" , min, max)); |
324 | } |
325 | return OkStatus(); |
326 | } |
327 | |
328 | string SwapNodeNamesSwitchControlErrorMsg(absl::string_view node_name) { |
329 | return absl::Substitute( |
330 | "can't swap node name '$0' as it will become a Switch control dependency" , |
331 | node_name); |
332 | } |
333 | |
334 | string GeneratedNameForIdentityConsumingSwitch( |
335 | const MutableGraphView::OutputPort& fanin) { |
336 | return AddPrefixToNodeName( |
337 | absl::StrCat(fanin.node->name(), "_" , fanin.port_id), |
338 | kMutableGraphViewCtrl); |
339 | } |
340 | |
341 | string PrintInTextFormat(const protobuf::MessageLite& message) { |
342 | // Unfortunately proto2::TextFormat::Printer::PrintToString does not have |
343 | // a overload for MessageLite so here we have to use |
344 | // MessageLite::ShortDebugString. |
345 | return message.ShortDebugString(); |
346 | } |
347 | |
348 | string PrintInTextFormat(const protobuf::Message& message) { |
349 | string message_text; |
350 | ::tensorflow::protobuf::TextFormat::Printer printer; |
351 | printer.SetSingleLineMode(true); |
352 | printer.PrintToString(message, &message_text); |
353 | if (!message_text.empty() && message_text[message_text.size() - 1] == ' ') { |
354 | message_text.resize(message_text.size() - 1); |
355 | } |
356 | return message_text; |
357 | } |
358 | |
359 | } // namespace |
360 | |
// Registers `node`'s inputs in the fanout bookkeeping (fanouts(),
// max_regular_output_port(), max_regular_input_port()) and removes redundant
// control inputs from `node` in place.
//
// A control input "^x" is dropped when x was already seen as an input of
// `node` and either:
//   - x was already a control input (a repeated control dependency), or
//   - x may be deduped against a regular input, i.e. x exists and is not an
//     Identity consuming a Switch (see CanDedupControlWithRegularInput).
// Dropped inputs are swapped to the tail of node->input() and deleted in one
// DeleteSubrange call at the end. This can reorder the remaining control
// inputs.
void MutableGraphView::AddAndDedupFanouts(NodeDef* node) {
  // TODO(lyandy): Checks for self loops, Switch control dependencies, fanins
  // exist, and all regular fanins come before controlling fanins.
  // Both sets hold string_views into node->input() strings. Those strings stay
  // alive until the DeleteSubrange below, after the last use of the sets.
  absl::flat_hash_set<absl::string_view> fanins;
  absl::flat_hash_set<absl::string_view> controlling_fanins;
  int max_input_port = -1;
  int pos = 0;
  const int last_idx = node->input_size() - 1;
  // Inputs at indices (last_pos, last_idx] are the ones scheduled for removal.
  int last_pos = last_idx;
  while (pos <= last_pos) {
    TensorId tensor_id = ParseTensorName(node->input(pos));
    absl::string_view input_node_name = tensor_id.node();
    bool is_control_input = IsTensorIdControlling(tensor_id);
    bool can_dedup_control_with_regular_input =
        CanDedupControlWithRegularInput(*this, input_node_name);
    bool can_dedup_control =
        is_control_input && (can_dedup_control_with_regular_input ||
                             controlling_fanins.contains(input_node_name));
    if (!gtl::InsertIfNotPresent(&fanins, input_node_name) &&
        can_dedup_control) {
      // Duplicate control input: move it to the removal tail. `pos` is not
      // advanced because the element swapped in from the tail still has to be
      // examined.
      node->mutable_input()->SwapElements(pos, last_pos);
      --last_pos;
    } else {
      // NOTE(review): nodes()[...] default-inserts when the fanin node is not
      // (yet) in the view, presumably as a null NodeDef* keyed by a view into
      // this node's input string. See the TODO above. Confirm this is intended.
      OutputPort output(nodes()[input_node_name], tensor_id.index());

      if (is_control_input) {
        fanouts()[output].emplace(node, Graph::kControlSlot);
      } else {
        max_input_port = pos;
        max_regular_output_port()[output.node] =
            std::max(max_regular_output_port()[output.node], output.port_id);
        fanouts()[output].emplace(node, pos);
      }
      ++pos;
    }
    if (is_control_input) {
      controlling_fanins.insert(input_node_name);
    }
  }

  // Drop all duplicate control inputs collected at the tail.
  if (last_pos < last_idx) {
    node->mutable_input()->DeleteSubrange(last_pos + 1, last_idx - last_pos);
  }

  if (max_input_port > -1) {
    max_regular_input_port()[node] = max_input_port;
  }
}
409 | |
410 | void MutableGraphView::UpdateMaxRegularOutputPortForRemovedFanin( |
411 | const OutputPort& fanin, |
412 | const absl::flat_hash_set<InputPort>& fanin_fanouts) { |
413 | int max_port = max_regular_output_port()[fanin.node]; |
414 | if (!fanin_fanouts.empty() || max_port != fanin.port_id) { |
415 | return; |
416 | } |
417 | bool updated_max_port = false; |
418 | for (int i = fanin.port_id - 1; i >= 0; --i) { |
419 | OutputPort fanin_port(fanin.node, i); |
420 | if (!fanouts()[fanin_port].empty()) { |
421 | max_regular_output_port()[fanin.node] = i; |
422 | updated_max_port = true; |
423 | break; |
424 | } |
425 | } |
426 | if (!updated_max_port) { |
427 | max_regular_output_port().erase(fanin.node); |
428 | } |
429 | } |
430 | |
431 | void MutableGraphView::UpdateMaxRegularOutputPortForAddedFanin( |
432 | const OutputPort& fanin) { |
433 | if (max_regular_output_port()[fanin.node] < fanin.port_id) { |
434 | max_regular_output_port()[fanin.node] = fanin.port_id; |
435 | } |
436 | } |
437 | |
438 | const absl::flat_hash_set<MutableGraphView::InputPort>& |
439 | MutableGraphView::GetFanout(const GraphView::OutputPort& port) const { |
440 | return GetFanout(MutableGraphView::OutputPort(const_cast<NodeDef*>(port.node), |
441 | port.port_id)); |
442 | } |
443 | |
444 | absl::flat_hash_set<MutableGraphView::OutputPort> MutableGraphView::GetFanin( |
445 | const GraphView::InputPort& port) const { |
446 | return GetFanin(MutableGraphView::InputPort(const_cast<NodeDef*>(port.node), |
447 | port.port_id)); |
448 | } |
449 | |
450 | const MutableGraphView::OutputPort MutableGraphView::GetRegularFanin( |
451 | const GraphView::InputPort& port) const { |
452 | return GetRegularFanin(MutableGraphView::InputPort( |
453 | const_cast<NodeDef*>(port.node), port.port_id)); |
454 | } |
455 | |
456 | NodeDef* MutableGraphView::AddNode(NodeDef&& node) { |
457 | auto* node_in_graph = graph()->add_node(); |
458 | *node_in_graph = std::move(node); |
459 | |
460 | AddUniqueNodeOrDie(node_in_graph); |
461 | |
462 | AddAndDedupFanouts(node_in_graph); |
463 | return node_in_graph; |
464 | } |
465 | |
466 | Status MutableGraphView::AddSubgraph(GraphDef&& subgraph) { |
467 | // 1. Add all new functions and check that functions with the same name |
468 | // have identical definition. |
469 | const int function_size = subgraph.library().function_size(); |
470 | if (function_size > 0) { |
471 | absl::flat_hash_map<absl::string_view, const FunctionDef*> graph_fdefs; |
472 | for (const FunctionDef& fdef : graph()->library().function()) { |
473 | graph_fdefs.emplace(fdef.signature().name(), &fdef); |
474 | } |
475 | |
476 | for (FunctionDef& fdef : *subgraph.mutable_library()->mutable_function()) { |
477 | const auto graph_fdef = graph_fdefs.find(fdef.signature().name()); |
478 | |
479 | if (graph_fdef == graph_fdefs.end()) { |
480 | VLOG(3) << "Add new function definition: " << fdef.signature().name(); |
481 | graph()->mutable_library()->add_function()->Swap(&fdef); |
482 | } else { |
483 | if (!FunctionDefsEqual(fdef, *graph_fdef->second)) { |
484 | return MutationError( |
485 | "AddSubgraph" , |
486 | absl::Substitute("function_size=$0" , function_size), |
487 | absl::StrCat( |
488 | "Found different function definition with the same name: " , |
489 | fdef.signature().name())); |
490 | } |
491 | } |
492 | } |
493 | } |
494 | |
495 | // 2. Add all nodes to the underlying graph. |
496 | int node_size_before = graph()->node_size(); |
497 | |
498 | for (NodeDef& node : *subgraph.mutable_node()) { |
499 | auto* node_in_graph = graph()->add_node(); |
500 | node_in_graph->Swap(&node); |
501 | TF_RETURN_IF_ERROR(AddUniqueNode(node_in_graph)); |
502 | } |
503 | |
504 | // TODO(ezhulenev, lyandy): Right now AddAndDedupFanouts do not check that |
505 | // fanins actually exists in the graph, and there is already TODO for that. |
506 | |
507 | for (int i = node_size_before; i < graph()->node_size(); ++i) { |
508 | NodeDef* node = graph()->mutable_node(i); |
509 | AddAndDedupFanouts(node); |
510 | } |
511 | |
512 | return OkStatus(); |
513 | } |
514 | |
515 | Status MutableGraphView::UpdateNode( |
516 | absl::string_view node_name, absl::string_view op, absl::string_view device, |
517 | absl::Span<const std::pair<string, AttrValue>> attrs) { |
518 | auto error_status = [node_name, op, device, attrs](absl::string_view msg) { |
519 | std::vector<string> attr_strs; |
520 | attr_strs.reserve(attrs.size()); |
521 | for (const auto& attr : attrs) { |
522 | string attr_str = absl::Substitute("('$0', $1)" , attr.first, |
523 | PrintInTextFormat(attr.second)); |
524 | attr_strs.push_back(attr_str); |
525 | } |
526 | string params = |
527 | absl::Substitute("node_name='$0', op='$1', device='$2', attrs={$3}" , |
528 | node_name, op, device, absl::StrJoin(attr_strs, ", " )); |
529 | return MutationError("UpdateNodeOp" , params, msg); |
530 | }; |
531 | |
532 | NodeDef* node = GetNode(node_name); |
533 | TF_RETURN_IF_ERROR(CheckNodeExists(node_name, node, error_status)); |
534 | |
535 | MutableGraphView::OutputPort control_port(node, Graph::kControlSlot); |
536 | auto control_fanouts = GetFanout(control_port); |
537 | if (op == "Switch" && !control_fanouts.empty()) { |
538 | return error_status( |
539 | "can't change node op to Switch when node drives a control dependency " |
540 | "(alternatively, we could add the identity node needed, but it seems " |
541 | "like an unlikely event and probably a mistake)" ); |
542 | } |
543 | |
544 | if (node->device() != device) { |
545 | node->set_device(string(device)); |
546 | } |
547 | node->mutable_attr()->clear(); |
548 | for (const auto& attr : attrs) { |
549 | (*node->mutable_attr())[attr.first] = attr.second; |
550 | } |
551 | |
552 | if (node->op() == op) { |
553 | return OkStatus(); |
554 | } |
555 | |
556 | node->set_op(string(op)); |
557 | |
558 | if (CanDedupControlWithRegularInput(*this, *node)) { |
559 | for (const auto& control_fanout : control_fanouts) { |
560 | if (HasRegularFaninNode(*this, *control_fanout.node, node->name())) { |
561 | RemoveControllingFaninInternal(control_fanout.node, node); |
562 | } |
563 | } |
564 | } |
565 | |
566 | return OkStatus(); |
567 | } |
568 | |
569 | Status MutableGraphView::UpdateNodeName(absl::string_view from_node_name, |
570 | absl::string_view to_node_name, |
571 | bool update_fanouts) { |
572 | auto error_status = [from_node_name, to_node_name, |
573 | update_fanouts](absl::string_view msg) { |
574 | string params = absl::Substitute( |
575 | "from_node_name='$0', to_node_name='$1', update_fanouts=$2" , |
576 | from_node_name, to_node_name, update_fanouts); |
577 | return MutationError("UpdateNodeName" , params, msg); |
578 | }; |
579 | |
580 | NodeDef* node = GetNode(from_node_name); |
581 | TF_RETURN_IF_ERROR(CheckNodeExists(from_node_name, node, error_status)); |
582 | |
583 | if (node->name() == to_node_name) { |
584 | return OkStatus(); |
585 | } |
586 | if (HasNode(to_node_name)) { |
587 | return error_status( |
588 | "can't update node name because new node name is in use" ); |
589 | } |
590 | auto max_output_port = max_regular_output_port().find(node); |
591 | const bool has_max_output_port = |
592 | max_output_port != max_regular_output_port().end(); |
593 | auto control_fanouts = fanouts().find({node, Graph::kControlSlot}); |
594 | |
595 | if (update_fanouts) { |
596 | SwapControlledFanoutInputs(*this, control_fanouts, to_node_name); |
597 | if (has_max_output_port) { |
598 | SwapRegularFanoutInputs(&fanouts(), node, to_node_name, |
599 | max_output_port->second); |
600 | } |
601 | } else if (has_max_output_port || |
602 | HasFanoutValue(fanouts(), control_fanouts)) { |
603 | return error_status("can't update node name because node has fanouts" ); |
604 | } |
605 | |
606 | nodes().erase(node->name()); |
607 | node->set_name(string(to_node_name)); |
608 | nodes().emplace(node->name(), node); |
609 | return OkStatus(); |
610 | } |
611 | |
// Swaps the names of the nodes `from_node_name` and `to_node_name`.
//
// With `update_fanouts`, consumer inputs are rewritten before the swap so that
// every consumer keeps reading from the same NodeDef. The fanout maps (keyed by
// NodeDef*) then need no change.
//
// Without `update_fanouts`, consumer inputs are left untouched, so after the
// swap they resolve to the other node. The fanout and max-port maps are
// swapped to match. Inputs that would become self loops are pointed back at
// the other node. Control dependencies that became redundant are deduped.
// The swap is refused when it would make a Switch node a control dependency.
//
// Returns an error if either node is missing. Swapping a node with itself is a
// no-op.
Status MutableGraphView::SwapNodeNames(absl::string_view from_node_name,
                                       absl::string_view to_node_name,
                                       bool update_fanouts) {
  auto error_status = [from_node_name, to_node_name,
                       update_fanouts](absl::string_view msg) {
    string params = absl::Substitute(
        "from_node_name='$0', to_node_name='$1', update_fanouts=$2",
        from_node_name, to_node_name, update_fanouts);
    return MutationError("SwapNodeNames", params, msg);
  };

  NodeDef* from_node = GetNode(from_node_name);
  TF_RETURN_IF_ERROR(CheckNodeExists(from_node_name, from_node, error_status));
  if (from_node_name == to_node_name) {
    return OkStatus();
  }
  NodeDef* to_node = GetNode(to_node_name);
  TF_RETURN_IF_ERROR(CheckNodeExists(to_node_name, to_node, error_status));

  // Re-keys nodes() under the swapped names. Both entries are erased first
  // because the map is keyed by the nodes' current names.
  auto swap_names = [this, from_node, to_node]() {
    nodes().erase(from_node->name());
    nodes().erase(to_node->name());
    std::swap(*from_node->mutable_name(), *to_node->mutable_name());
    nodes().emplace(from_node->name(), from_node);
    nodes().emplace(to_node->name(), to_node);
  };

  if (update_fanouts) {
    SwapFanoutInputs(*this, &fanouts(), &max_regular_output_port(), from_node,
                     to_node);
    swap_names();
    return OkStatus();
  }

  // After the swap, consumers of to_node's control output will depend on
  // from_node (and vice versa). Refuse if that node is a Switch.
  bool from_is_switch = IsSwitch(*from_node);
  MutableGraphView::OutputPort to_control(to_node, Graph::kControlSlot);
  auto to_control_fanouts = fanouts().find(to_control);
  if (from_is_switch && HasFanoutValue(fanouts(), to_control_fanouts)) {
    return error_status(SwapNodeNamesSwitchControlErrorMsg(from_node_name));
  }

  bool to_is_switch = IsSwitch(*to_node);
  MutableGraphView::OutputPort from_control(from_node, Graph::kControlSlot);
  auto from_control_fanouts = fanouts().find(from_control);
  if (to_is_switch && HasFanoutValue(fanouts(), from_control_fanouts)) {
    return error_status(SwapNodeNamesSwitchControlErrorMsg(to_node_name));
  }

  // Swap node names.
  swap_names();

  // Swap controlling fanouts.
  //
  // Note: The to and from control fanout iterators are still valid because no
  // mutations have been performed on fanouts() yet.
  SwapFanoutsMapValues(&fanouts(), from_control, from_control_fanouts,
                       to_control, to_control_fanouts);

  // Swap regular fanouts.
  SwapRegularFanoutsAndMaxPortValues(&fanouts(), &max_regular_output_port(),
                                     from_node, to_node);

  // Update fanins to remove self loops. An input of `node` that named the other
  // node before the swap now names `node` itself. Point it back at the other
  // node (`old_node_name`) and move the matching fanout entry accordingly.
  auto update_fanins = [this](NodeDef* node, absl::string_view old_node_name) {
    for (int i = 0; i < node->input_size(); ++i) {
      TensorId tensor_id = ParseTensorName(node->input(i));
      if (tensor_id.node() == node->name()) {
        const int idx = tensor_id.index();
        // Control consumers are recorded with port kControlSlot, regular ones
        // with their input index.
        const int node_idx =
            IsTensorIdControlling(tensor_id) ? Graph::kControlSlot : i;

        MutableGraphView::OutputPort from_fanin(node, idx);
        absl::flat_hash_set<InputPort>* from_fanouts = &fanouts()[from_fanin];
        from_fanouts->erase({node, node_idx});
        UpdateMaxRegularOutputPortForRemovedFanin(from_fanin, *from_fanouts);

        MutableGraphView::OutputPort to_fanin(nodes().at(old_node_name), idx);
        fanouts()[to_fanin].insert({node, node_idx});
        UpdateMaxRegularOutputPortForAddedFanin(to_fanin);
        node->set_input(i, TensorIdToString({old_node_name, idx}));
      }
    }
  };
  update_fanins(from_node, to_node->name());
  update_fanins(to_node, from_node->name());

  // Dedup control dependencies: for each control consumer of `node` that also
  // reads `node` through a regular input, drop the control edge (unless `node`
  // is an Identity consuming a Switch).
  auto dedup_control_fanouts =
      [this](NodeDef* node, const FanoutsMap::iterator& control_fanouts) {
        if (CanDedupControlWithRegularInput(*this, *node) &&
            control_fanouts != fanouts().end()) {
          for (auto it = control_fanouts->second.begin();
               it != control_fanouts->second.end();) {
            // Advance `it` before invalidation from removal.
            const auto& control_fanout = *it++;
            if (HasRegularFaninNode(*this, *control_fanout.node,
                                    node->name())) {
              RemoveControllingFaninInternal(control_fanout.node, node);
            }
          }
        }
      };
  // For a node whose counterpart is a Switch, dedup the control fanouts of each
  // of the node's regular consumers.
  // NOTE(review): `it` points into fanouts() while dedup_control_fanouts may
  // mutate fanouts() through RemoveControllingFaninInternal. If that path can
  // insert into fanouts() and trigger a rehash, `it` would be invalidated.
  // Confirm against RemoveControllingFaninInternal.
  auto dedup_switch_control = [this, dedup_control_fanouts](NodeDef* node) {
    OutputPort port;
    port.node = node;
    const int max_port =
        gtl::FindWithDefault(max_regular_output_port(), node, -1);
    for (int i = 0; i <= max_port; ++i) {
      port.port_id = i;
      auto it = fanouts().find(port);
      if (it == fanouts().end()) {
        continue;
      }
      for (const auto& fanout : it->second) {
        auto fanout_controls =
            fanouts().find({fanout.node, Graph::kControlSlot});
        dedup_control_fanouts(fanout.node, fanout_controls);
      }
    }
  };

  if (!from_is_switch) {
    if (to_is_switch) {
      dedup_switch_control(from_node);
    } else {
      // Fetch iterator again as the original iterator might have been
      // invalidated by container rehash triggered due to mutations.
      auto from_control_fanouts = fanouts().find(from_control);
      dedup_control_fanouts(from_node, from_control_fanouts);
    }
  }
  if (!to_is_switch) {
    if (from_is_switch) {
      dedup_switch_control(to_node);
    } else {
      // Fetch iterator again as the original iterator might have been
      // invalidated by container rehash triggered due to mutations.
      auto to_control_fanouts = fanouts().find(to_control);
      dedup_control_fanouts(to_node, to_control_fanouts);
    }
  }

  return OkStatus();
}
756 | |
757 | Status MutableGraphView::UpdateFanouts(absl::string_view from_node_name, |
758 | absl::string_view to_node_name) { |
759 | NodeDef* from_node = GetNode(from_node_name); |
760 | TF_RETURN_IF_ERROR( |
761 | CheckNodeExists(from_node_name, from_node, |
762 | UpdateFanoutsError(from_node_name, to_node_name))); |
763 | NodeDef* to_node = GetNode(to_node_name); |
764 | TF_RETURN_IF_ERROR(CheckNodeExists( |
765 | to_node_name, to_node, UpdateFanoutsError(from_node_name, to_node_name))); |
766 | |
767 | return UpdateFanoutsInternal(from_node, to_node); |
768 | } |
769 | |
770 | Status MutableGraphView::UpdateFanoutsInternal(NodeDef* from_node, |
771 | NodeDef* to_node) { |
772 | VLOG(2) << absl::Substitute("Update fanouts from '$0' to '$1'." , |
773 | from_node->name(), to_node->name()); |
774 | if (from_node == to_node) { |
775 | return OkStatus(); |
776 | } |
777 | |
778 | // Update internal state with the new output_port->input_port edge. |
779 | const auto add_edge = [this](const OutputPort& output_port, |
780 | const InputPort& input_port) { |
781 | fanouts()[output_port].insert(input_port); |
782 | }; |
783 | |
784 | // Remove invalidated edge from the internal state. |
785 | const auto remove_edge = [this](const OutputPort& output_port, |
786 | const InputPort& input_port) { |
787 | fanouts()[output_port].erase(input_port); |
788 | }; |
789 | |
790 | // For the control fanouts we do not know the input index in a NodeDef, |
791 | // so we have to traverse all control inputs. |
792 | |
793 | auto control_fanouts = |
794 | GetFanout(GraphView::OutputPort(from_node, Graph::kControlSlot)); |
795 | |
796 | bool to_node_is_switch = IsSwitch(*to_node); |
797 | for (const InputPort& control_port : control_fanouts) { |
798 | // Node can't be control dependency of itself. |
799 | if (control_port.node == to_node) continue; |
800 | |
801 | // Can't add Switch node as a control dependency. |
802 | if (to_node_is_switch) { |
803 | // Trying to add a Switch as a control dependency, which if allowed will |
804 | // make the graph invalid. |
805 | return UpdateFanoutsError(from_node->name(), to_node->name())( |
806 | absl::Substitute("can't update fanouts to node '$0' as it will " |
807 | "become a Switch control dependency" , |
808 | to_node->name())); |
809 | } |
810 | |
811 | NodeDef* node = control_port.node; |
812 | RemoveControllingFaninInternal(node, from_node); |
813 | AddFaninInternal(node, {to_node, Graph::kControlSlot}); |
814 | } |
815 | |
816 | // First we update regular fanouts. For the regular fanouts |
817 | // `input_port:port_id` is the input index in NodeDef. |
818 | |
819 | auto regular_edges = |
820 | GetFanoutEdges(*from_node, /*include_controlled_edges=*/false); |
821 | |
822 | // Maximum index of the `from_node` output tensor that is still used as an |
823 | // input to some other node. |
824 | int keep_max_regular_output_port = -1; |
825 | |
826 | for (const Edge& edge : regular_edges) { |
827 | const OutputPort output_port = edge.src; |
828 | const InputPort input_port = edge.dst; |
829 | |
830 | // If the `to_node` reads from the `from_node`, skip this edge (see |
831 | // AddAndUpdateFanoutsWithoutSelfLoops test for an example). |
832 | if (input_port.node == to_node) { |
833 | keep_max_regular_output_port = |
834 | std::max(keep_max_regular_output_port, output_port.port_id); |
835 | continue; |
836 | } |
837 | |
838 | // Update input at destination node. |
839 | input_port.node->set_input( |
840 | input_port.port_id, |
841 | TensorIdToString({to_node->name(), output_port.port_id})); |
842 | |
843 | // Remove old edge between the `from_node` and the fanout node. |
844 | remove_edge(output_port, input_port); |
845 | // Add an edge between the `to_node` and new fanout node. |
846 | add_edge(OutputPort(to_node, output_port.port_id), input_port); |
847 | // Dedup control dependency. |
848 | if (CanDedupControlWithRegularInput(*this, *to_node)) { |
849 | RemoveControllingFaninInternal(input_port.node, to_node); |
850 | } |
851 | } |
852 | |
853 | // Because we update all regular fanouts of `from_node`, we can just copy |
854 | // the value `num_regular_outputs`. |
855 | max_regular_output_port()[to_node] = max_regular_output_port()[from_node]; |
856 | |
857 | // Check if all fanouts were updated to read from the `to_node`. |
858 | if (keep_max_regular_output_port >= 0) { |
859 | max_regular_output_port()[from_node] = keep_max_regular_output_port; |
860 | } else { |
861 | max_regular_output_port().erase(from_node); |
862 | } |
863 | |
864 | return OkStatus(); |
865 | } |
866 | |
// Appends `fanin` as an input of `node` and records the new edge in the
// fanout and max-port bookkeeping.
//
// A regular fanin is placed right after the existing regular inputs, i.e.
// before any control inputs.
//
// A control fanin is skipped when `node` already has a control input from
// `fanin.node`; in that case false is returned. When
// CanDedupControlWithRegularInput holds for `fanin.node`, a regular input
// from that node also counts as a duplicate.
//
// Regular fanins are never deduplicated. Returns true if an input was added.
bool MutableGraphView::AddFaninInternal(NodeDef* node,
                                        const OutputPort& fanin) {
  int num_regular_fanins =
      NumFanins(*node, /*include_controlling_nodes=*/false);
  bool input_is_control = IsOutputPortControlling(fanin);
  bool can_dedup_control_with_regular_input =
      CanDedupControlWithRegularInput(*this, *fanin.node);
  // Don't add duplicate control dependencies.
  if (input_is_control) {
    // Scan only the control inputs, or every input when a regular input from
    // the same node already implies the control dependency.
    const int start =
        can_dedup_control_with_regular_input ? 0 : num_regular_fanins;
    for (int i = start; i < node->input_size(); ++i) {
      if (ParseTensorName(node->input(i)).node() == fanin.node->name()) {
        return false;
      }
    }
  }

  InputPort input;
  input.node = node;
  // All control inputs share the kControlSlot port; a regular input's port is
  // the index it ends up at among the regular inputs.
  input.port_id = input_is_control ? Graph::kControlSlot : num_regular_fanins;

  node->add_input(TensorIdToString({fanin.node->name(), fanin.port_id}));
  if (!input_is_control) {
    const int last_node_input = node->input_size() - 1;
    // If there are control dependencies in node, move newly inserted fanin to
    // be before such control dependencies. The control input displaced from
    // index `num_regular_fanins` moves to the end.
    if (num_regular_fanins < last_node_input) {
      node->mutable_input()->SwapElements(last_node_input, num_regular_fanins);
    }
  }

  fanouts()[fanin].insert(input);
  // NOTE(review): for a control fanin (port kControlSlot) this never raises
  // the max, but operator[] may still create an entry for `fanin.node`;
  // confirm that is intended.
  if (max_regular_output_port()[fanin.node] < fanin.port_id) {
    max_regular_output_port()[fanin.node] = fanin.port_id;
  }

  // Update max input port and dedup control dependencies.
  if (!input_is_control) {
    max_regular_input_port()[node] = num_regular_fanins;
    if (can_dedup_control_with_regular_input) {
      RemoveControllingFaninInternal(node, fanin.node);
    }
  }

  return true;
}
914 | |
915 | Status MutableGraphView::AddRegularFanin(absl::string_view node_name, |
916 | const TensorId& fanin) { |
917 | auto error_status = [node_name, fanin](absl::string_view msg) { |
918 | string params = absl::Substitute("node_name='$0', fanin='$1'" , node_name, |
919 | fanin.ToString()); |
920 | return MutationError("AddRegularFanin" , params, msg); |
921 | }; |
922 | |
923 | TF_RETURN_IF_ERROR(CheckFaninIsRegular(fanin, error_status)); |
924 | TF_RETURN_IF_ERROR(CheckAddingFaninToSelf(node_name, fanin, error_status)); |
925 | NodeDef* node = GetNode(node_name); |
926 | TF_RETURN_IF_ERROR(CheckNodeExists(node_name, node, error_status)); |
927 | NodeDef* fanin_node = GetNode(fanin.node()); |
928 | TF_RETURN_IF_ERROR(CheckNodeExists(fanin.node(), fanin_node, error_status)); |
929 | |
930 | AddFaninInternal(node, {fanin_node, fanin.index()}); |
931 | return OkStatus(); |
932 | } |
933 | |
// Inserts the regular tensor `fanin` as the regular input of `node_name` at
// index `port`.
//
// The regular inputs previously at `port` and above are shifted up by one,
// and their fanout entries are re-keyed to match.
//
// If CanDedupControlWithRegularInput holds for the fanin node, a now-redundant
// control input from that node is removed.
//
// Fails if the fanin is not regular, would feed the node into itself, either
// node is missing, or `port` is outside [0, number of regular inputs].
Status MutableGraphView::AddRegularFaninByPort(absl::string_view node_name,
                                               int port,
                                               const TensorId& fanin) {
  auto error_status = [node_name, port, fanin](absl::string_view msg) {
    string params = absl::Substitute("node_name='$0', port=$1, fanin='$2'",
                                     node_name, port, fanin.ToString());
    return MutationError("AddRegularFaninByPort", params, msg);
  };

  TF_RETURN_IF_ERROR(CheckFaninIsRegular(fanin, error_status));
  TF_RETURN_IF_ERROR(CheckAddingFaninToSelf(node_name, fanin, error_status));
  NodeDef* node = GetNode(node_name);
  TF_RETURN_IF_ERROR(CheckNodeExists(node_name, node, error_status));
  const int num_regular_fanins =
      NumFanins(*node, /*include_controlling_nodes=*/false);
  TF_RETURN_IF_ERROR(
      CheckPortRange(port, /*min=*/0, num_regular_fanins, error_status));
  NodeDef* fanin_node = GetNode(fanin.node());
  TF_RETURN_IF_ERROR(CheckNodeExists(fanin.node(), fanin_node, error_status));

  // Append the new input, then swap it to index `num_regular_fanins`, just
  // past the current regular inputs. A control input previously at that
  // index, if any, moves to the end.
  const int last_node_input = node->input_size();
  node->add_input(TensorIdToString(fanin));
  node->mutable_input()->SwapElements(num_regular_fanins, last_node_input);
  // Bubble the new input down to `port`. Each displaced regular input moves
  // up one slot, and its fanout entry is re-keyed from `i` to `i + 1`. When
  // `port == num_regular_fanins` this loop does not run.
  for (int i = num_regular_fanins - 1; i >= port; --i) {
    TensorId tensor_id = ParseTensorName(node->input(i));
    OutputPort fanin_port(nodes()[tensor_id.node()], tensor_id.index());
    absl::flat_hash_set<InputPort>* fanouts_set = &fanouts()[fanin_port];
    fanouts_set->erase({node, i});
    fanouts_set->insert({node, i + 1});
    node->mutable_input()->SwapElements(i, i + 1);
  }

  // Record the new producer -> node:port edge.
  OutputPort fanin_port(fanin_node, fanin.index());
  fanouts()[fanin_port].insert({node, port});
  UpdateMaxRegularOutputPortForAddedFanin(fanin_port);

  // The node now has `num_regular_fanins + 1` regular inputs.
  max_regular_input_port()[node] = num_regular_fanins;
  if (CanDedupControlWithRegularInput(*this, *fanin_node)) {
    RemoveControllingFaninInternal(node, fanin_node);
  }

  return OkStatus();
}
977 | |
978 | NodeDef* MutableGraphView::GetControllingFaninToAdd(absl::string_view node_name, |
979 | const OutputPort& fanin, |
980 | string* error_msg) { |
981 | if (!IsSwitch(*fanin.node)) { |
982 | return fanin.node; |
983 | } else { |
984 | if (IsOutputPortControlling(fanin)) { |
985 | // Can't add a Switch node control dependency. |
986 | TensorId tensor_id(fanin.node->name(), fanin.port_id); |
987 | *error_msg = absl::Substitute( |
988 | "can't add fanin '$0' as it will become a Switch control dependency" , |
989 | tensor_id.ToString()); |
990 | return nullptr; |
991 | } |
992 | // We can't anchor control dependencies directly on the switch node: unlike |
993 | // other nodes only one of the outputs of the switch node will be generated |
994 | // when the switch node is executed, and we need to make sure the control |
995 | // dependency is only triggered when the corresponding output is triggered. |
996 | // We start by looking for an identity node connected to the output of the |
997 | // switch node, and use it to anchor the control dependency. |
998 | for (const auto& fanout : GetFanout(fanin)) { |
999 | if (IsIdentity(*fanout.node) || IsIdentityNSingleInput(*fanout.node)) { |
1000 | if (fanout.node->name() == node_name) { |
1001 | *error_msg = |
1002 | absl::Substitute("can't add found fanin '$0' to self" , |
1003 | AsControlDependency(fanout.node->name())); |
1004 | return nullptr; |
1005 | } |
1006 | return fanout.node; |
1007 | } |
1008 | } |
1009 | |
1010 | // No node found, check if node to be created is itself. |
1011 | if (GeneratedNameForIdentityConsumingSwitch(fanin) == node_name) { |
1012 | *error_msg = absl::Substitute("can't add generated fanin '$0' to self" , |
1013 | AsControlDependency(string(node_name))); |
1014 | } |
1015 | } |
1016 | return nullptr; |
1017 | } |
1018 | |
1019 | NodeDef* MutableGraphView::GetOrCreateIdentityConsumingSwitch( |
1020 | const OutputPort& fanin) { |
1021 | // We haven't found an existing node where we can anchor the control |
1022 | // dependency: add a new identity node. |
1023 | string identity_name = GeneratedNameForIdentityConsumingSwitch(fanin); |
1024 | NodeDef* identity_node = GetNode(identity_name); |
1025 | if (identity_node == nullptr) { |
1026 | NodeDef new_node; |
1027 | new_node.set_name(identity_name); |
1028 | new_node.set_op("Identity" ); |
1029 | new_node.set_device(fanin.node->device()); |
1030 | (*new_node.mutable_attr())["T" ].set_type(fanin.node->attr().at("T" ).type()); |
1031 | new_node.add_input(TensorIdToString({fanin.node->name(), fanin.port_id})); |
1032 | identity_node = AddNode(std::move(new_node)); |
1033 | } |
1034 | return identity_node; |
1035 | } |
1036 | |
1037 | Status MutableGraphView::AddControllingFanin(absl::string_view node_name, |
1038 | const TensorId& fanin) { |
1039 | auto error_status = [node_name, fanin](absl::string_view msg) { |
1040 | string params = absl::Substitute("node_name='$0', fanin='$1'" , node_name, |
1041 | fanin.ToString()); |
1042 | return MutationError("AddControllingFanin" , params, msg); |
1043 | }; |
1044 | |
1045 | TF_RETURN_IF_ERROR(CheckFaninIsValid(fanin, error_status)); |
1046 | TF_RETURN_IF_ERROR(CheckAddingFaninToSelf(node_name, fanin, error_status)); |
1047 | NodeDef* node = GetNode(node_name); |
1048 | TF_RETURN_IF_ERROR(CheckNodeExists(node_name, node, error_status)); |
1049 | NodeDef* fanin_node = GetNode(fanin.node()); |
1050 | TF_RETURN_IF_ERROR(CheckNodeExists(fanin.node(), fanin_node, error_status)); |
1051 | |
1052 | OutputPort fanin_port(fanin_node, fanin.index()); |
1053 | |
1054 | string error_msg = "" ; |
1055 | NodeDef* control_node = GetControllingFaninToAdd( |
1056 | node_name, {fanin_node, fanin.index()}, &error_msg); |
1057 | if (!error_msg.empty()) { |
1058 | return error_status(error_msg); |
1059 | } |
1060 | if (control_node == nullptr) { |
1061 | control_node = GetOrCreateIdentityConsumingSwitch(fanin_port); |
1062 | } |
1063 | AddFaninInternal(node, {control_node, Graph::kControlSlot}); |
1064 | |
1065 | return OkStatus(); |
1066 | } |
1067 | |
// Removes every regular input of `node` that reads exactly `fanin` (same node
// and output port), keeping the fanout bookkeeping in sync.
//
// The remaining regular inputs are compacted toward the front, preserving
// their relative order, and their fanout entries are re-keyed to the new
// indices. Control inputs are left untouched.
//
// Returns true if at least one input was removed.
bool MutableGraphView::RemoveRegularFaninInternal(NodeDef* node,
                                                  const OutputPort& fanin) {
  // Erases the `node:node_input_port` entry from the fanouts of `fanin_port`,
  // optionally refreshes that producer's max regular output port, and returns
  // the fanout set. Callers only use the returned pointer when
  // `update_max_port` is false.
  auto remove_input = [this, node](const OutputPort& fanin_port,
                                   int node_input_port, bool update_max_port) {
    InputPort input(node, node_input_port);

    absl::flat_hash_set<InputPort>* fanouts_set = &fanouts()[fanin_port];
    fanouts_set->erase(input);
    if (update_max_port) {
      UpdateMaxRegularOutputPortForRemovedFanin(fanin_port, *fanouts_set);
    }
    return fanouts_set;
  };

  auto mutable_inputs = node->mutable_input();
  bool modified = false;
  const int num_regular_fanins =
      NumFanins(*node, /*include_controlling_nodes=*/false);
  int i;
  // Index at which the next retained regular input belongs.
  int curr_pos = 0;
  for (i = 0; i < num_regular_fanins; ++i) {
    TensorId tensor_id = ParseTensorName(node->input(i));
    if (tensor_id.node() == fanin.node->name() &&
        tensor_id.index() == fanin.port_id) {
      remove_input(fanin, i, /*update_max_port=*/true);
      modified = true;
    } else if (modified) {
      // Regular inputs will need to have their ports updated.
      OutputPort fanin_port(nodes()[tensor_id.node()], tensor_id.index());
      auto fanouts_set = remove_input(fanin_port, i, /*update_max_port=*/false);
      fanouts_set->insert({node, curr_pos});
      // Shift inputs to be retained.
      mutable_inputs->SwapElements(i, curr_pos);
      ++curr_pos;
    } else {
      // Skip inputs to be retained until first modification.
      ++curr_pos;
    }
  }

  if (modified) {
    const int last_regular_input_port = curr_pos - 1;
    if (last_regular_input_port < 0) {
      max_regular_input_port().erase(node);
    } else {
      max_regular_input_port()[node] = last_regular_input_port;
    }
    if (curr_pos < i) {
      // Remove fanins from node inputs. Here `i == num_regular_fanins`, so
      // [curr_pos, i) holds exactly the removed inputs.
      mutable_inputs->DeleteSubrange(curr_pos, i - curr_pos);
    }
  }

  return modified;
}
1123 | |
1124 | Status MutableGraphView::RemoveRegularFanin(absl::string_view node_name, |
1125 | const TensorId& fanin) { |
1126 | auto error_status = [node_name, fanin](absl::string_view msg) { |
1127 | string params = absl::Substitute("node_name='$0', fanin='$1'" , node_name, |
1128 | fanin.ToString()); |
1129 | return MutationError("RemoveRegularFanin" , params, msg); |
1130 | }; |
1131 | |
1132 | TF_RETURN_IF_ERROR(CheckFaninIsRegular(fanin, error_status)); |
1133 | TF_RETURN_IF_ERROR( |
1134 | CheckRemovingFaninFromSelf(node_name, fanin, error_status)); |
1135 | NodeDef* node = GetNode(node_name); |
1136 | TF_RETURN_IF_ERROR(CheckNodeExists(node_name, node, error_status)); |
1137 | NodeDef* fanin_node = GetNode(fanin.node()); |
1138 | TF_RETURN_IF_ERROR(CheckNodeExists(fanin.node(), fanin_node, error_status)); |
1139 | |
1140 | RemoveRegularFaninInternal(node, {fanin_node, fanin.index()}); |
1141 | return OkStatus(); |
1142 | } |
1143 | |
1144 | Status MutableGraphView::RemoveRegularFaninByPort(absl::string_view node_name, |
1145 | int port) { |
1146 | auto error_status = [node_name, port](absl::string_view msg) { |
1147 | string params = |
1148 | absl::Substitute("node_name='$0', port=$1" , node_name, port); |
1149 | return MutationError("RemoveRegularFaninByPort" , params, msg); |
1150 | }; |
1151 | |
1152 | NodeDef* node = GetNode(node_name); |
1153 | TF_RETURN_IF_ERROR(CheckNodeExists(node_name, node, error_status)); |
1154 | const int last_regular_fanin_port = |
1155 | gtl::FindWithDefault(max_regular_input_port(), node, -1); |
1156 | TF_RETURN_IF_ERROR( |
1157 | CheckPortRange(port, /*min=*/0, last_regular_fanin_port, error_status)); |
1158 | |
1159 | TensorId tensor_id = ParseTensorName(node->input(port)); |
1160 | OutputPort fanin_port(nodes()[tensor_id.node()], tensor_id.index()); |
1161 | fanouts()[fanin_port].erase({node, port}); |
1162 | auto mutable_inputs = node->mutable_input(); |
1163 | for (int i = port + 1; i <= last_regular_fanin_port; ++i) { |
1164 | TensorId tensor_id = ParseTensorName(node->input(i)); |
1165 | OutputPort fanin_port(nodes()[tensor_id.node()], tensor_id.index()); |
1166 | absl::flat_hash_set<InputPort>* fanouts_set = &fanouts()[fanin_port]; |
1167 | fanouts_set->erase({node, i}); |
1168 | fanouts_set->insert({node, i - 1}); |
1169 | mutable_inputs->SwapElements(i - 1, i); |
1170 | } |
1171 | const int last_node_input = node->input_size() - 1; |
1172 | if (last_regular_fanin_port < last_node_input) { |
1173 | mutable_inputs->SwapElements(last_regular_fanin_port, last_node_input); |
1174 | } |
1175 | mutable_inputs->RemoveLast(); |
1176 | |
1177 | const int updated_last_regular_input_port = last_regular_fanin_port - 1; |
1178 | if (updated_last_regular_input_port < 0) { |
1179 | max_regular_input_port().erase(node); |
1180 | } else { |
1181 | max_regular_input_port()[node] = updated_last_regular_input_port; |
1182 | } |
1183 | |
1184 | return OkStatus(); |
1185 | } |
1186 | |
1187 | bool MutableGraphView::RemoveControllingFaninInternal(NodeDef* node, |
1188 | NodeDef* fanin_node) { |
1189 | for (int i = node->input_size() - 1; i >= 0; --i) { |
1190 | TensorId tensor_id = ParseTensorName(node->input(i)); |
1191 | if (tensor_id.index() > Graph::kControlSlot) { |
1192 | break; |
1193 | } |
1194 | if (tensor_id.node() == fanin_node->name()) { |
1195 | fanouts()[{fanin_node, Graph::kControlSlot}].erase( |
1196 | {node, Graph::kControlSlot}); |
1197 | node->mutable_input()->SwapElements(i, node->input_size() - 1); |
1198 | node->mutable_input()->RemoveLast(); |
1199 | return true; |
1200 | } |
1201 | } |
1202 | return false; |
1203 | } |
1204 | |
1205 | Status MutableGraphView::RemoveControllingFanin( |
1206 | absl::string_view node_name, absl::string_view fanin_node_name) { |
1207 | auto error_status = [node_name, fanin_node_name](absl::string_view msg) { |
1208 | string params = absl::Substitute("node_name='$0', fanin_node_name='$1'" , |
1209 | node_name, fanin_node_name); |
1210 | return MutationError("RemoveControllingFanin" , params, msg); |
1211 | }; |
1212 | |
1213 | TF_RETURN_IF_ERROR(CheckRemovingFaninFromSelf( |
1214 | node_name, {fanin_node_name, Graph::kControlSlot}, error_status)); |
1215 | NodeDef* node = GetNode(node_name); |
1216 | TF_RETURN_IF_ERROR(CheckNodeExists(node_name, node, error_status)); |
1217 | NodeDef* fanin_node = GetNode(fanin_node_name); |
1218 | TF_RETURN_IF_ERROR( |
1219 | CheckNodeExists(fanin_node_name, fanin_node, error_status)); |
1220 | |
1221 | RemoveControllingFaninInternal(node, fanin_node); |
1222 | return OkStatus(); |
1223 | } |
1224 | |
1225 | Status MutableGraphView::RemoveAllFanins(absl::string_view node_name, |
1226 | bool keep_controlling_fanins) { |
1227 | NodeDef* node = GetNode(node_name); |
1228 | if (node == nullptr) { |
1229 | string params = |
1230 | absl::Substitute("node_name='$0', keep_controlling_fanins=$1" , |
1231 | node_name, keep_controlling_fanins); |
1232 | return MutationError("RemoveAllFanins" , params, |
1233 | NodeMissingErrorMsg(node_name)); |
1234 | } |
1235 | |
1236 | if (node->input().empty()) { |
1237 | return OkStatus(); |
1238 | } |
1239 | |
1240 | const int num_regular_fanins = |
1241 | NumFanins(*node, /*include_controlling_nodes=*/false); |
1242 | RemoveFaninsInternal(node, keep_controlling_fanins); |
1243 | if (keep_controlling_fanins) { |
1244 | if (num_regular_fanins == 0) { |
1245 | return OkStatus(); |
1246 | } else if (num_regular_fanins < node->input_size()) { |
1247 | node->mutable_input()->DeleteSubrange(0, num_regular_fanins); |
1248 | } else { |
1249 | node->clear_input(); |
1250 | } |
1251 | } else { |
1252 | node->clear_input(); |
1253 | } |
1254 | return OkStatus(); |
1255 | } |
1256 | |
// Replaces every occurrence of `from_fanin` among the inputs of `node_name`
// with `to_fanin`.
//
// If either fanin is a control dependency, the old fanin is removed. Only if
// something was removed is the new one added via AddFaninInternal, so port
// bookkeeping is rebuilt. Otherwise matching regular inputs are rewritten in
// place.
//
// Fails if a node is missing, if `to_fanin` would make a Switch a control
// dependency, or if either fanin is the node itself.
Status MutableGraphView::UpdateFanin(absl::string_view node_name,
                                     const TensorId& from_fanin,
                                     const TensorId& to_fanin) {
  auto error_status = [node_name, from_fanin, to_fanin](absl::string_view msg) {
    string params =
        absl::Substitute("node_name='$0', from_fanin='$1', to_fanin='$2'",
                         node_name, from_fanin.ToString(), to_fanin.ToString());
    return MutationError("UpdateFanin", params, msg);
  };

  TF_RETURN_IF_ERROR(CheckFaninIsValid(from_fanin, error_status));
  TF_RETURN_IF_ERROR(CheckFaninIsValid(to_fanin, error_status));
  NodeDef* node = GetNode(node_name);
  TF_RETURN_IF_ERROR(CheckNodeExists(node_name, node, error_status));
  NodeDef* from_fanin_node = GetNode(from_fanin.node());
  TF_RETURN_IF_ERROR(
      CheckNodeExists(from_fanin.node(), from_fanin_node, error_status));
  NodeDef* to_fanin_node = GetNode(to_fanin.node());
  TF_RETURN_IF_ERROR(
      CheckNodeExists(to_fanin.node(), to_fanin_node, error_status));

  // When replacing a non control dependency fanin with a control dependency, or
  // vice versa, remove and add, so ports can be updated properly in fanout(s).
  bool to_fanin_is_control = IsTensorIdControlling(to_fanin);
  if (to_fanin_is_control && IsSwitch(*to_fanin_node)) {
    // Can't add Switch node as a control dependency.
    return error_status(
        absl::Substitute("can't update to fanin '$0' as it will become a "
                         "Switch control dependency",
                         to_fanin.ToString()));
  }
  if (node_name == from_fanin.node() || node_name == to_fanin.node()) {
    return error_status("can't update fanin to or from self");
  }

  if (from_fanin == to_fanin) {
    return OkStatus();
  }

  bool from_fanin_is_control = IsTensorIdControlling(from_fanin);
  if (from_fanin_is_control || to_fanin_is_control) {
    bool modified = false;
    if (from_fanin_is_control) {
      modified |= RemoveControllingFaninInternal(node, from_fanin_node);
    } else {
      modified |= RemoveRegularFaninInternal(
          node, {from_fanin_node, from_fanin.index()});
    }
    // Only add the replacement if `from_fanin` actually was an input.
    if (modified) {
      AddFaninInternal(node, {to_fanin_node, to_fanin.index()});
    }
    return OkStatus();
  }

  // In place mutation of regular fanins, requires no shifting of ports.
  string to_fanin_string = TensorIdToString(to_fanin);
  const int num_regular_fanins =
      NumFanins(*node, /*include_controlling_nodes=*/false);
  bool modified = false;
  for (int i = 0; i < num_regular_fanins; ++i) {
    if (ParseTensorName(node->input(i)) == from_fanin) {
      InputPort input(node, i);

      // Move the `node:i` fanout entry from the old producer to the new one.
      OutputPort from_fanin_port(from_fanin_node, from_fanin.index());
      fanouts()[from_fanin_port].erase(input);

      OutputPort to_fanin_port(to_fanin_node, to_fanin.index());
      fanouts()[to_fanin_port].insert(input);

      node->set_input(i, to_fanin_string);
      modified = true;
    }
  }

  // Dedup control dependencies and update max regular output ports.
  if (modified) {
    OutputPort from_fanin_port(from_fanin_node, from_fanin.index());
    UpdateMaxRegularOutputPortForRemovedFanin(
        {from_fanin_node, from_fanin.index()}, fanouts()[from_fanin_port]);
    if (max_regular_output_port()[to_fanin_node] < to_fanin.index()) {
      max_regular_output_port()[to_fanin_node] = to_fanin.index();
    }
    if (CanDedupControlWithRegularInput(*this, *to_fanin_node)) {
      RemoveControllingFaninInternal(node, to_fanin_node);
    }
  }

  return OkStatus();
}
1346 | |
1347 | Status MutableGraphView::UpdateRegularFaninByPort(absl::string_view node_name, |
1348 | int port, |
1349 | const TensorId& fanin) { |
1350 | auto error_status = [node_name, port, fanin](absl::string_view msg) { |
1351 | string params = absl::Substitute("node_name='$0', port=$1, fanin='$2'" , |
1352 | node_name, port, fanin.ToString()); |
1353 | return MutationError("UpdateRegularFaninByPort" , params, msg); |
1354 | }; |
1355 | |
1356 | TF_RETURN_IF_ERROR(CheckFaninIsRegular(fanin, error_status)); |
1357 | TF_RETURN_IF_ERROR(CheckAddingFaninToSelf(node_name, fanin, error_status)); |
1358 | NodeDef* node = GetNode(node_name); |
1359 | TF_RETURN_IF_ERROR(CheckNodeExists(node_name, node, error_status)); |
1360 | const int last_regular_fanin_port = |
1361 | gtl::FindWithDefault(max_regular_input_port(), node, -1); |
1362 | TF_RETURN_IF_ERROR( |
1363 | CheckPortRange(port, /*min=*/0, last_regular_fanin_port, error_status)); |
1364 | NodeDef* fanin_node = GetNode(fanin.node()); |
1365 | TF_RETURN_IF_ERROR(CheckNodeExists(fanin.node(), fanin_node, error_status)); |
1366 | |
1367 | TensorId tensor_id = ParseTensorName(node->input(port)); |
1368 | if (tensor_id == fanin) { |
1369 | return OkStatus(); |
1370 | } |
1371 | |
1372 | InputPort input(node, port); |
1373 | OutputPort from_fanin_port(nodes()[tensor_id.node()], tensor_id.index()); |
1374 | absl::flat_hash_set<InputPort>* from_fanouts = &fanouts()[from_fanin_port]; |
1375 | from_fanouts->erase(input); |
1376 | UpdateMaxRegularOutputPortForRemovedFanin(from_fanin_port, *from_fanouts); |
1377 | |
1378 | OutputPort to_fanin_port(fanin_node, fanin.index()); |
1379 | fanouts()[to_fanin_port].insert(input); |
1380 | UpdateMaxRegularOutputPortForAddedFanin(to_fanin_port); |
1381 | |
1382 | node->set_input(port, TensorIdToString(fanin)); |
1383 | |
1384 | if (CanDedupControlWithRegularInput(*this, *fanin_node)) { |
1385 | RemoveControllingFaninInternal(node, fanin_node); |
1386 | } |
1387 | |
1388 | return OkStatus(); |
1389 | } |
1390 | |
1391 | Status MutableGraphView::SwapRegularFaninsByPorts(absl::string_view node_name, |
1392 | int from_port, int to_port) { |
1393 | auto error_status = [node_name, from_port, to_port](absl::string_view msg) { |
1394 | string params = absl::Substitute("node_name='$0', from_port=$1, to_port=$2" , |
1395 | node_name, from_port, to_port); |
1396 | return MutationError("SwapRegularFaninsByPorts" , params, msg); |
1397 | }; |
1398 | |
1399 | NodeDef* node = GetNode(node_name); |
1400 | TF_RETURN_IF_ERROR(CheckNodeExists(node_name, node, error_status)); |
1401 | const int last_regular_fanin_port = |
1402 | gtl::FindWithDefault(max_regular_input_port(), node, -1); |
1403 | TF_RETURN_IF_ERROR(CheckPortRange(from_port, /*min=*/0, |
1404 | last_regular_fanin_port, error_status)); |
1405 | TF_RETURN_IF_ERROR(CheckPortRange(to_port, /*min=*/0, last_regular_fanin_port, |
1406 | error_status)); |
1407 | |
1408 | if (from_port == to_port) { |
1409 | return OkStatus(); |
1410 | } |
1411 | TensorId from_fanin = ParseTensorName(node->input(from_port)); |
1412 | TensorId to_fanin = ParseTensorName(node->input(to_port)); |
1413 | if (from_fanin == to_fanin) { |
1414 | return OkStatus(); |
1415 | } |
1416 | |
1417 | InputPort from_input(node, from_port); |
1418 | InputPort to_input(node, to_port); |
1419 | NodeDef* from_fanin_node = GetNode(from_fanin.node()); |
1420 | absl::flat_hash_set<InputPort>* from_fanouts = |
1421 | &fanouts()[{from_fanin_node, from_fanin.index()}]; |
1422 | from_fanouts->erase(from_input); |
1423 | from_fanouts->insert(to_input); |
1424 | NodeDef* to_fanin_node = GetNode(to_fanin.node()); |
1425 | absl::flat_hash_set<InputPort>* to_fanouts = |
1426 | &fanouts()[{to_fanin_node, to_fanin.index()}]; |
1427 | to_fanouts->erase(to_input); |
1428 | to_fanouts->insert(from_input); |
1429 | |
1430 | node->mutable_input()->SwapElements(from_port, to_port); |
1431 | |
1432 | return OkStatus(); |
1433 | } |
1434 | |
1435 | Status MutableGraphView::UpdateAllRegularFaninsToControlling( |
1436 | absl::string_view node_name) { |
1437 | auto error_status = [node_name](absl::string_view msg) { |
1438 | string params = absl::Substitute("node_name='$0'" , node_name); |
1439 | return MutationError("UpdateAllRegularFaninsToControlling" , params, msg); |
1440 | }; |
1441 | |
1442 | NodeDef* node = GetNode(node_name); |
1443 | TF_RETURN_IF_ERROR(CheckNodeExists(node_name, node, error_status)); |
1444 | |
1445 | const int num_regular_fanins = |
1446 | NumFanins(*node, /*include_controlling_nodes=*/false); |
1447 | std::vector<OutputPort> regular_fanins; |
1448 | regular_fanins.reserve(num_regular_fanins); |
1449 | std::vector<NodeDef*> controlling_fanins; |
1450 | controlling_fanins.reserve(num_regular_fanins); |
1451 | |
1452 | // Get all regular fanins and derive controlling fanins. |
1453 | for (int i = 0; i < num_regular_fanins; ++i) { |
1454 | TensorId tensor_id = ParseTensorName(node->input(i)); |
1455 | OutputPort fanin_port(nodes()[tensor_id.node()], tensor_id.index()); |
1456 | |
1457 | string error_msg = "" ; |
1458 | NodeDef* control_node = |
1459 | GetControllingFaninToAdd(node_name, fanin_port, &error_msg); |
1460 | if (!error_msg.empty()) { |
1461 | return error_status(error_msg); |
1462 | } |
1463 | |
1464 | regular_fanins.push_back(fanin_port); |
1465 | controlling_fanins.push_back(control_node); |
1466 | } |
1467 | |
1468 | // Replace regular fanins with controlling fanins and dedup. |
1469 | int pos = 0; |
1470 | InputPort input_port(node, Graph::kControlSlot); |
1471 | absl::flat_hash_set<absl::string_view> controls; |
1472 | for (int i = 0; i < num_regular_fanins; ++i) { |
1473 | OutputPort fanin_port = regular_fanins[i]; |
1474 | NodeDef* control = controlling_fanins[i]; |
1475 | if (control == nullptr) { |
1476 | control = GetOrCreateIdentityConsumingSwitch(fanin_port); |
1477 | } |
1478 | fanouts()[fanin_port].erase({node, i}); |
1479 | if (controls.contains(control->name())) { |
1480 | continue; |
1481 | } |
1482 | controls.insert(control->name()); |
1483 | node->set_input(pos, AsControlDependency(control->name())); |
1484 | fanouts()[{control, Graph::kControlSlot}].insert(input_port); |
1485 | ++pos; |
1486 | } |
1487 | |
1488 | // Shift existing controlling fanins and dedup. |
1489 | for (int i = num_regular_fanins; i < node->input_size(); ++i) { |
1490 | TensorId tensor_id = ParseTensorName(node->input(i)); |
1491 | if (controls.contains(tensor_id.node())) { |
1492 | continue; |
1493 | } |
1494 | controls.insert(tensor_id.node()); |
1495 | node->mutable_input()->SwapElements(pos, i); |
1496 | ++pos; |
1497 | } |
1498 | |
1499 | // Remove duplicate controls and leftover regular fanins. |
1500 | node->mutable_input()->DeleteSubrange(pos, node->input_size() - pos); |
1501 | max_regular_input_port().erase(node); |
1502 | |
1503 | return OkStatus(); |
1504 | } |
1505 | |
1506 | Status MutableGraphView::CheckNodesCanBeDeleted( |
1507 | const absl::flat_hash_set<string>& nodes_to_delete) { |
1508 | std::vector<string> missing_nodes; |
1509 | std::vector<string> nodes_with_fanouts; |
1510 | for (const string& node_name_to_delete : nodes_to_delete) { |
1511 | NodeDef* node = GetNode(node_name_to_delete); |
1512 | if (node == nullptr) { |
1513 | // Can't delete missing node. |
1514 | missing_nodes.push_back(node_name_to_delete); |
1515 | continue; |
1516 | } |
1517 | const int max_port = gtl::FindWithDefault(max_regular_output_port(), node, |
1518 | Graph::kControlSlot); |
1519 | for (int i = Graph::kControlSlot; i <= max_port; ++i) { |
1520 | auto it = fanouts().find({node, i}); |
1521 | bool has_retained_fanout = false; |
1522 | if (it != fanouts().end()) { |
1523 | for (const auto& fanout : it->second) { |
1524 | // Check if fanouts are of nodes to be deleted, and if so, they can be |
1525 | // ignored, as they will be removed also. |
1526 | if (!nodes_to_delete.contains(fanout.node->name())) { |
1527 | // Removing node will leave graph in an invalid state. |
1528 | has_retained_fanout = true; |
1529 | break; |
1530 | } |
1531 | } |
1532 | } |
1533 | if (has_retained_fanout) { |
1534 | nodes_with_fanouts.push_back(node_name_to_delete); |
1535 | break; |
1536 | } |
1537 | } |
1538 | } |
1539 | |
1540 | // Error message can get quite long, so we only show the first 5 node names. |
1541 | auto sort_and_sample = [](std::vector<string>* s) { |
1542 | constexpr int kMaxNodeNames = 5; |
1543 | std::sort(s->begin(), s->end()); |
1544 | if (s->size() > kMaxNodeNames) { |
1545 | return absl::StrCat( |
1546 | absl::StrJoin(s->begin(), s->begin() + kMaxNodeNames, ", " ), ", ..." ); |
1547 | } |
1548 | return absl::StrJoin(*s, ", " ); |
1549 | }; |
1550 | |
1551 | if (!missing_nodes.empty()) { |
1552 | VLOG(2) << absl::Substitute("Attempting to delete missing node(s) [$0]." , |
1553 | sort_and_sample(&missing_nodes)); |
1554 | } |
1555 | if (!nodes_with_fanouts.empty()) { |
1556 | std::vector<string> input_node_names(nodes_to_delete.begin(), |
1557 | nodes_to_delete.end()); |
1558 | string params = absl::Substitute("nodes_to_delete={$0}" , |
1559 | sort_and_sample(&input_node_names)); |
1560 | string error_msg = |
1561 | absl::Substitute("can't delete node(s) with retained fanouts(s) [$0]" , |
1562 | sort_and_sample(&nodes_with_fanouts)); |
1563 | return MutationError("DeleteNodes" , params, error_msg); |
1564 | } |
1565 | |
1566 | return OkStatus(); |
1567 | } |
1568 | |
1569 | Status MutableGraphView::DeleteNodes( |
1570 | const absl::flat_hash_set<string>& nodes_to_delete) { |
1571 | TF_RETURN_IF_ERROR(CheckNodesCanBeDeleted(nodes_to_delete)); |
1572 | |
1573 | // Find nodes in internal state and delete. |
1574 | for (const string& node_name_to_delete : nodes_to_delete) { |
1575 | NodeDef* node = GetNode(node_name_to_delete); |
1576 | if (node != nullptr) { |
1577 | RemoveFaninsInternal(node, /*keep_controlling_fanins=*/false); |
1578 | RemoveFanoutsInternal(node); |
1579 | } |
1580 | } |
1581 | for (const string& node_name_to_delete : nodes_to_delete) { |
1582 | nodes().erase(node_name_to_delete); |
1583 | } |
1584 | |
1585 | // Find nodes in graph and delete by partitioning into nodes to retain and |
1586 | // nodes to delete based on input set of nodes to delete by name. |
1587 | // TODO(lyandy): Use a node name->idx hashmap if this is a performance |
1588 | // bottleneck. |
1589 | int pos = 0; |
1590 | const int last_idx = graph()->node_size() - 1; |
1591 | int last_pos = last_idx; |
1592 | while (pos <= last_pos) { |
1593 | if (nodes_to_delete.contains(graph()->node(pos).name())) { |
1594 | graph()->mutable_node()->SwapElements(pos, last_pos); |
1595 | --last_pos; |
1596 | } else { |
1597 | ++pos; |
1598 | } |
1599 | } |
1600 | if (last_pos < last_idx) { |
1601 | graph()->mutable_node()->DeleteSubrange(last_pos + 1, last_idx - last_pos); |
1602 | } |
1603 | |
1604 | return OkStatus(); |
1605 | } |
1606 | |
1607 | void MutableGraphView::RemoveFaninsInternal(NodeDef* deleted_node, |
1608 | bool keep_controlling_fanins) { |
1609 | for (int i = 0; i < deleted_node->input_size(); ++i) { |
1610 | TensorId tensor_id = ParseTensorName(deleted_node->input(i)); |
1611 | bool is_control = IsTensorIdControlling(tensor_id); |
1612 | if (keep_controlling_fanins && is_control) { |
1613 | break; |
1614 | } |
1615 | OutputPort fanin(nodes()[tensor_id.node()], tensor_id.index()); |
1616 | |
1617 | InputPort input; |
1618 | input.node = deleted_node; |
1619 | input.port_id = is_control ? Graph::kControlSlot : i; |
1620 | |
1621 | auto it = fanouts().find(fanin); |
1622 | if (it != fanouts().end()) { |
1623 | absl::flat_hash_set<InputPort>* fanouts_set = &it->second; |
1624 | fanouts_set->erase(input); |
1625 | UpdateMaxRegularOutputPortForRemovedFanin(fanin, *fanouts_set); |
1626 | } |
1627 | } |
1628 | max_regular_input_port().erase(deleted_node); |
1629 | } |
1630 | |
1631 | void MutableGraphView::RemoveFanoutsInternal(NodeDef* deleted_node) { |
1632 | const int max_port = |
1633 | gtl::FindWithDefault(max_regular_output_port(), deleted_node, -1); |
1634 | for (int i = Graph::kControlSlot; i <= max_port; ++i) { |
1635 | fanouts().erase({deleted_node, i}); |
1636 | } |
1637 | max_regular_output_port().erase(deleted_node); |
1638 | } |
1639 | |
1640 | } // end namespace grappler |
1641 | } // end namespace tensorflow |
1642 | |