1/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2
3Licensed under the Apache License, Version 2.0 (the "License");
4you may not use this file except in compliance with the License.
5You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9Unless required by applicable law or agreed to in writing, software
10distributed under the License is distributed on an "AS IS" BASIS,
11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12See the License for the specific language governing permissions and
13limitations under the License.
14==============================================================================*/
15
16#ifndef TENSORFLOW_CORE_GRAPPLER_MUTABLE_GRAPH_VIEW_H_
17#define TENSORFLOW_CORE_GRAPPLER_MUTABLE_GRAPH_VIEW_H_
18
19#include <set>
20#include <string>
21
22#include "absl/container/flat_hash_set.h"
23#include "absl/strings/string_view.h"
24#include "absl/types/span.h"
25#include "tensorflow/core/framework/graph.pb.h"
26#include "tensorflow/core/framework/node_def.pb.h"
27#include "tensorflow/core/graph/graph.h"
28#include "tensorflow/core/graph/tensor_id.h"
29#include "tensorflow/core/grappler/graph_view.h"
30#include "tensorflow/core/grappler/op_types.h"
31#include "tensorflow/core/lib/core/status.h"
32#include "tensorflow/core/platform/types.h"
33
34namespace tensorflow {
35namespace grappler {
36
// Name string exported alongside MutableGraphView. Its value is
// "ConstantFoldingCtrl" rather than a MutableGraphView-specific name;
// presumably this lets nodes generated via this view be treated like those
// produced by constant folding — confirm in mutable_graph_view.cc before
// changing it. `constexpr` keeps this a compile-time constant with internal
// linkage, exactly as the previous `const char[]` declaration had.
constexpr char kMutableGraphViewCtrl[] = "ConstantFoldingCtrl";
38
39// A utility class to simplify the traversal of a GraphDef that, unlike
40// GraphView, supports updating the graph. Note that you should not modify the
41// graph separately, because the view will get out of sync.
42
43class MutableGraphView : public internal::GraphViewInternal<GraphDef, NodeDef> {
44 public:
45 explicit MutableGraphView(GraphDef* graph) : GraphViewInternal(graph) {
46 for (NodeDef& node : *graph->mutable_node()) AddUniqueNodeOrDie(&node);
47 for (NodeDef& node : *graph->mutable_node()) AddAndDedupFanouts(&node);
48 }
49
50 // Lookup fanouts/fanins using immutable ports.
51 using GraphViewInternal::GetFanout;
52 const absl::flat_hash_set<InputPort>& GetFanout(
53 const GraphView::OutputPort& port) const;
54
55 using GraphViewInternal::GetFanin;
56 absl::flat_hash_set<OutputPort> GetFanin(
57 const GraphView::InputPort& port) const;
58
59 using GraphViewInternal::GetRegularFanin;
60 const OutputPort GetRegularFanin(const GraphView::InputPort& port) const;
61
62 // Adds a new node to graph and updates the view. Returns a pointer to the
63 // node in graph.
64 NodeDef* AddNode(NodeDef&& node);
65
66 // Adds all nodes from the `subgraph` to the underlying graph and updates the
67 // view. `subgraph` doesn't have to be a valid graph definition on it's own,
68 // it can have edges to the nodes that are not in it, however after adding
69 // it to the underlying graph, final graph must be valid.
70 //
71 // If subgraph function library is not empty, all new functions will be added
72 // to the graph. Functions that appear with the same name in both subgraph and
73 // the graph represented by *this, must have identical function definitions.
74 //
75 // IMPORTANT: All nodes and functions of the given subgraph moved into the
76 // underlying graph, which leaves subgraph in valid but undefined state.
77 Status AddSubgraph(GraphDef&& subgraph);
78
79 // Updates node `node_name` op, device, and attributes. This will clear any
80 // existing attributes. If it is not possible to update the node or if the
81 // node does not exist, an error will be returned and nothing will be modified
82 // in the graph.
83 Status UpdateNode(absl::string_view node_name, absl::string_view op,
84 absl::string_view device,
85 absl::Span<const std::pair<string, AttrValue>> attrs);
86
87 // Updates node `from_node_name` name to `to_node_name`. If `to_node_name` is
88 // in use, node `from_node_name` does not exist, or node `from_node_name` has
89 // fanouts and `update_fanouts` is set to false, an error will be returned and
90 // nothing will be modified in the graph.
91 Status UpdateNodeName(absl::string_view from_node_name,
92 absl::string_view to_node_name, bool update_fanouts);
93
94 // Swap node names `from_node_name` and `to_node_name`. Self loops of one node
95 // are removed by updating the inputs introducing self loops to use the other
96 // node's name. Setting `update_fanouts` to false will exclude other fanouts
97 // from having their inputs updated, but inputs introducing self loops will
98 // always be updated regardless of `update_fanouts.
99 //
100 // Example:
101 // 1. foo(other:3, bar:2, ^bar)
102 // 2. bar(foo:3, other:1, foo:1, ^foo)
103 // 3. other(foo:5, bar:6)
104 //
105 // After calling SwapNodeNames("foo", "bar", false):
106 // 1. bar(other:3, foo:2, ^foo)
107 // 2. foo(bar:3, other:1, bar:1, ^bar)
108 // 3. other(foo:5, bar:6)
109 //
110 // After calling SwapNodeNames("foo", "bar", true):
111 // 1. bar(other:3, foo:2, ^foo)
112 // 2. foo(bar:3, other:1, bar:1, ^bar)
113 // 3. other(bar:5, foo:6)
114 //
115 // If it is not possible to swap node names (i.e. nodes do not exist or Switch
116 // control dependency may be introduced), an error will be returned and
117 // nothing will be modified in the graph.
118 Status SwapNodeNames(absl::string_view from_node_name,
119 absl::string_view to_node_name, bool update_fanouts);
120
121 // Updates all fanouts (input ports fetching output tensors) from
122 // `from_node_name` to the `to_node_name`, including control dependencies.
123 //
124 // Example: We have 3 nodes that use `bar` node output tensors as inputs:
125 // 1. foo1(bar:0, bar:1, other:0)
126 // 2. foo2(bar:1, other:1)
127 // 3. foo3(other:2, ^bar)
128 //
129 // After calling UpdateFanouts(bar, new_bar):
130 // 1. foo1(new_bar:0, new_bar:1, other:0)
131 // 2. foo2(new_bar:1, other:1)
132 // 3. foo3(other:2, ^new_bar)
133 Status UpdateFanouts(absl::string_view from_node_name,
134 absl::string_view to_node_name);
135
136 // Adds regular fanin `fanin` to node `node_name`. If the node or fanin do not
137 // exist in the graph, nothing will be modified in the graph. Otherwise fanin
138 // will be added after existing non control dependency fanins. Control
139 // dependencies will be deduped. To add control dependencies, use
140 // AddControllingFanin.
141 Status AddRegularFanin(absl::string_view node_name, const TensorId& fanin);
142
143 // Adds regular fanin `fanin` to node `node_name` at port `port`. If the node
144 // or fanin do not exist in the graph, nothing will be modified in the graph.
145 // Otherwise fanin will be inserted at port `port`. Control dependencies will
146 // be deduped. To add control dependencies, use AddControllingFanin.
147 //
148 // If the port is not a valid port (less than 0 or greater than the number of
149 // regular fanins), this will result in an error and the node will not be
150 // modified.
151 Status AddRegularFaninByPort(absl::string_view node_name, int port,
152 const TensorId& fanin);
153
154 // Adds control dependency `fanin` to the target node named `node_name`. To
155 // add regular fanins, use AddRegularFanin.
156 //
157 // Case 1: If the fanin is not a Switch node, the control dependency is simply
158 // added to the target node:
159 //
160 // fanin -^> target node.
161 //
162 // Case 2: If the fanin is a Switch node, we cannot anchor a control
163 // dependency on it, because unlike other nodes, only one of its outputs will
164 // be generated when the node is activated. In this case, we try to find an
165 // Identity/IdentityN node in the fanout of the relevant port of the Switch
166 // and add it as a fanin to the target node. If no such Identity/IdentityN
167 // node can be found, a new Identity node will be created. In both cases, we
168 // end up with:
169 //
170 // fanin -> Identity{N} -^> target node.
171 //
172 // If the control dependency being added is redundant (control dependency
173 // already exists or control dependency can be deduped from regular fanins),
174 // this will not result in an error and the node will not be modified.
175 Status AddControllingFanin(absl::string_view node_name,
176 const TensorId& fanin);
177
178 // Removes regular fanin `fanin` from node `node_name`. If the node or fanin
179 // do not exist in the graph, nothing will be modified in the graph. If there
180 // are multiple inputs that match the fanin, all of them will be removed. To
181 // remove controlling fanins, use RemoveControllingFanin.
182 //
183 // If the fanin being removed doesn't exist in the node's inputs, this will
184 // not result in an error and the node will not be modified.
185 Status RemoveRegularFanin(absl::string_view node_name, const TensorId& fanin);
186
187 // Removes regular fanin at port `port` from node `node_name`. If the node
188 // does not exist in the graph, nothing will be modified in the graph.
189 // To remove controlling fanins, use RemoveControllingFanin.
190 //
191 // If the port is not a valid port (less than 0 or greater than the last index
192 // of the regular fanins), this will result in an error and the node will not
193 // be modified.
194 Status RemoveRegularFaninByPort(absl::string_view node_name, int port);
195
196 // Removes control dependency `fanin_node_name` from the target node named
197 // `node_name`. If the node or fanin do not exist in the graph, nothing will
198 // be modified in the graph. To remove regular fanins, use RemoveRegularFanin.
199 //
200 // If the fanin being removed doesn't exist in the node's inputs, this will
201 // not result in an error and the node will not be modified.
202 Status RemoveControllingFanin(absl::string_view node_name,
203 absl::string_view fanin_node_name);
204
205 // Removes all fanins from node `node_name`. Control dependencies will be
206 // retained if keep_controlling_fanins is true.
207 //
208 // If no fanins are removed, this will not result in an error and the node
209 // will not be modified.
210 Status RemoveAllFanins(absl::string_view node_name,
211 bool keep_controlling_fanins);
212
213 // Replaces all fanins `from_fanin` with `to_fanin` in node `node_name`. If
214 // the fanins or node do not exist, nothing will be modified in the graph.
215 // Control dependencies will be deduped.
216 //
217 // If the fanin being updated doesn't exist in the node's inputs, this will
218 // not result in an error and the node will not be modified.
219 Status UpdateFanin(absl::string_view node_name, const TensorId& from_fanin,
220 const TensorId& to_fanin);
221
222 // Replaces fanin at port `port` in node `node_name` with fanin `fanin`. If
223 // the fanins or node do not exist, nothing will be modified in the graph.
224 // Control dependencies will be deduped.
225 //
226 // If the port is not a valid port (less than 0 or greater than the last index
227 // of the regular fanins), this will result in an error and the node will not
228 // be modified.
229 Status UpdateRegularFaninByPort(absl::string_view node_name, int port,
230 const TensorId& fanin);
231
232 // Swaps fanins at ports `from_port` and `to_port` in node `node_name`. If the
233 // node does not exist, nothing will be modified in the graph.
234 //
235 // If the ports are not a valid port (less than 0 or greater than the last
236 // index of the regular fanins), this will result in an error and the node
237 // will not be modified.
238 Status SwapRegularFaninsByPorts(absl::string_view node_name, int from_port,
239 int to_port);
240
241 // Updates all regular fanins to equivalent controlling fanins. If it is not
242 // possible, an error will be returned and nothing will be modified in the
243 // graph.
244 Status UpdateAllRegularFaninsToControlling(absl::string_view node_name);
245
246 // Deletes nodes from the graph. If a node can't be safely removed,
247 // specifically if a node still has fanouts, an error will be returned. Nodes
248 // that can't be found are ignored.
249 Status DeleteNodes(const absl::flat_hash_set<string>& nodes_to_delete);
250
251 private:
252 // Adds fanouts for fanins of node to graph, while deduping control
253 // dependencies from existing control dependencies and regular fanins. Note,
254 // node inputs will be mutated if control dependencies can be deduped.
255 void AddAndDedupFanouts(NodeDef* node);
256
257 // Finds next output port smaller than fanin.port_id and update. The
258 // max_regular_output_port is only updated if fanin.port_id is the same as the
259 // current max_regular_output_port and if the fanouts set is empty. If there
260 // are no regular outputs, max_regular_output_port will be erased.
261 void UpdateMaxRegularOutputPortForRemovedFanin(
262 const OutputPort& fanin,
263 const absl::flat_hash_set<InputPort>& fanin_fanouts);
264
265 // Updates max regular output port for newly added fanin by checking the
266 // current max and updating if the newly added fanin is of a larger port.
267 void UpdateMaxRegularOutputPortForAddedFanin(const OutputPort& fanin);
268
269 // Updates all fanouts (input ports fetching output tensors) from `from_node`
270 // to the `to_node`, including control dependencies.
271 //
272 // Example: We have 3 nodes that use `bar` node output tensors as inputs:
273 // 1. foo1(bar:0, bar:1, other:0)
274 // 2. foo2(bar:1, other:1)
275 // 3. foo3(other:2, ^bar)
276 //
277 // After calling UpdateFanouts(bar, new_bar):
278 // 1. foo1(new_bar:0, new_bar:1, other:0)
279 // 2. foo2(new_bar:1, other:1)
280 // 3. foo3(other:2, ^new_bar)
281 //
282 // IMPORTANT: If `from_node` or `to_node` is not in the underlying graph, the
283 // behavior is undefined.
284 Status UpdateFanoutsInternal(NodeDef* from_node, NodeDef* to_node);
285
286 // Adds fanin to node. If fanin is a control dependency, existing control
287 // dependencies will be checked first before adding. Otherwise fanin will be
288 // added after existing non control dependency inputs.
289 bool AddFaninInternal(NodeDef* node, const OutputPort& fanin);
290
291 // Finds control dependency node to be used based on fanin. If fanin is not a
292 // Switch node, fanin.node is simply returned. Otherwise this will try to find
293 // a candidate Identity node consuming fanin, as the control dependency. If it
294 // is not possible or will introduce a self loop, an error message will be
295 // set. If nullptr is returned with no error
296 // GetOrCreateIdentityConsumingSwitch should be called to generate the new
297 // Identity node.
298 NodeDef* GetControllingFaninToAdd(absl::string_view node_name,
299 const OutputPort& fanin, string* error_msg);
300
301 // Finds a generated Identity node consuming Switch node `fanin.node` at port
302 // `fanin.port_id`. If such a node does not exist, a new Identity node will be
303 // created.
304 NodeDef* GetOrCreateIdentityConsumingSwitch(const OutputPort& fanin);
305
306 // Removes all instances of regular fanin `fanin` from node `node`.
307 bool RemoveRegularFaninInternal(NodeDef* node, const OutputPort& fanin);
308
309 // Removes controlling fanin `fanin_node` from node if such controlling fanin
310 // exists.
311 bool RemoveControllingFaninInternal(NodeDef* node, NodeDef* fanin_node);
312
313 // Checks if nodes to be deleted are missing or have any fanouts that will
314 // remain in the graph. If node is removed in either case, the graph will
315 // enter an invalid state.
316 Status CheckNodesCanBeDeleted(
317 const absl::flat_hash_set<string>& nodes_to_delete);
318
319 // Removes fanins of the deleted node from internal state. Control
320 // dependencies are retained iff keep_controlling_fanins is true.
321 void RemoveFaninsInternal(NodeDef* deleted_node,
322 bool keep_controlling_fanins);
323
324 // Removes fanouts of the deleted node from internal state.
325 void RemoveFanoutsInternal(NodeDef* deleted_node);
326};
327
328} // end namespace grappler
329} // end namespace tensorflow
330
331#endif // TENSORFLOW_CORE_GRAPPLER_MUTABLE_GRAPH_VIEW_H_
332