1 | /* Copyright 2018 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | |
16 | #ifndef TENSORFLOW_CORE_GRAPPLER_MUTABLE_GRAPH_VIEW_H_ |
17 | #define TENSORFLOW_CORE_GRAPPLER_MUTABLE_GRAPH_VIEW_H_ |
18 | |
19 | #include <set> |
20 | #include <string> |
21 | |
22 | #include "absl/container/flat_hash_set.h" |
23 | #include "absl/strings/string_view.h" |
24 | #include "absl/types/span.h" |
25 | #include "tensorflow/core/framework/graph.pb.h" |
26 | #include "tensorflow/core/framework/node_def.pb.h" |
27 | #include "tensorflow/core/graph/graph.h" |
28 | #include "tensorflow/core/graph/tensor_id.h" |
29 | #include "tensorflow/core/grappler/graph_view.h" |
30 | #include "tensorflow/core/grappler/op_types.h" |
31 | #include "tensorflow/core/lib/core/status.h" |
32 | #include "tensorflow/core/platform/types.h" |
33 | |
34 | namespace tensorflow { |
35 | namespace grappler { |
36 | |
// String constant exported by this header for MutableGraphView.
// NOTE(review): the value reads "ConstantFolding..." even though the constant
// is named after MutableGraphView; presumably it is a name prefix for Identity
// nodes generated to anchor control dependencies on Switch nodes, kept equal
// to the constant-folding prefix for compatibility -- confirm in
// mutable_graph_view.cc before changing it.
constexpr char kMutableGraphViewCtrl[] = "ConstantFoldingCtrl";
38 | |
39 | // A utility class to simplify the traversal of a GraphDef that, unlike |
40 | // GraphView, supports updating the graph. Note that you should not modify the |
41 | // graph separately, because the view will get out of sync. |
42 | |
43 | class MutableGraphView : public internal::GraphViewInternal<GraphDef, NodeDef> { |
44 | public: |
45 | explicit MutableGraphView(GraphDef* graph) : GraphViewInternal(graph) { |
46 | for (NodeDef& node : *graph->mutable_node()) AddUniqueNodeOrDie(&node); |
47 | for (NodeDef& node : *graph->mutable_node()) AddAndDedupFanouts(&node); |
48 | } |
49 | |
50 | // Lookup fanouts/fanins using immutable ports. |
51 | using GraphViewInternal::GetFanout; |
52 | const absl::flat_hash_set<InputPort>& GetFanout( |
53 | const GraphView::OutputPort& port) const; |
54 | |
55 | using GraphViewInternal::GetFanin; |
56 | absl::flat_hash_set<OutputPort> GetFanin( |
57 | const GraphView::InputPort& port) const; |
58 | |
59 | using GraphViewInternal::GetRegularFanin; |
60 | const OutputPort GetRegularFanin(const GraphView::InputPort& port) const; |
61 | |
62 | // Adds a new node to graph and updates the view. Returns a pointer to the |
63 | // node in graph. |
64 | NodeDef* AddNode(NodeDef&& node); |
65 | |
66 | // Adds all nodes from the `subgraph` to the underlying graph and updates the |
67 | // view. `subgraph` doesn't have to be a valid graph definition on it's own, |
68 | // it can have edges to the nodes that are not in it, however after adding |
69 | // it to the underlying graph, final graph must be valid. |
70 | // |
71 | // If subgraph function library is not empty, all new functions will be added |
72 | // to the graph. Functions that appear with the same name in both subgraph and |
73 | // the graph represented by *this, must have identical function definitions. |
74 | // |
75 | // IMPORTANT: All nodes and functions of the given subgraph moved into the |
76 | // underlying graph, which leaves subgraph in valid but undefined state. |
77 | Status AddSubgraph(GraphDef&& subgraph); |
78 | |
79 | // Updates node `node_name` op, device, and attributes. This will clear any |
80 | // existing attributes. If it is not possible to update the node or if the |
81 | // node does not exist, an error will be returned and nothing will be modified |
82 | // in the graph. |
83 | Status UpdateNode(absl::string_view node_name, absl::string_view op, |
84 | absl::string_view device, |
85 | absl::Span<const std::pair<string, AttrValue>> attrs); |
86 | |
87 | // Updates node `from_node_name` name to `to_node_name`. If `to_node_name` is |
88 | // in use, node `from_node_name` does not exist, or node `from_node_name` has |
89 | // fanouts and `update_fanouts` is set to false, an error will be returned and |
90 | // nothing will be modified in the graph. |
91 | Status UpdateNodeName(absl::string_view from_node_name, |
92 | absl::string_view to_node_name, bool update_fanouts); |
93 | |
94 | // Swap node names `from_node_name` and `to_node_name`. Self loops of one node |
95 | // are removed by updating the inputs introducing self loops to use the other |
96 | // node's name. Setting `update_fanouts` to false will exclude other fanouts |
97 | // from having their inputs updated, but inputs introducing self loops will |
98 | // always be updated regardless of `update_fanouts. |
99 | // |
100 | // Example: |
101 | // 1. foo(other:3, bar:2, ^bar) |
102 | // 2. bar(foo:3, other:1, foo:1, ^foo) |
103 | // 3. other(foo:5, bar:6) |
104 | // |
105 | // After calling SwapNodeNames("foo", "bar", false): |
106 | // 1. bar(other:3, foo:2, ^foo) |
107 | // 2. foo(bar:3, other:1, bar:1, ^bar) |
108 | // 3. other(foo:5, bar:6) |
109 | // |
110 | // After calling SwapNodeNames("foo", "bar", true): |
111 | // 1. bar(other:3, foo:2, ^foo) |
112 | // 2. foo(bar:3, other:1, bar:1, ^bar) |
113 | // 3. other(bar:5, foo:6) |
114 | // |
115 | // If it is not possible to swap node names (i.e. nodes do not exist or Switch |
116 | // control dependency may be introduced), an error will be returned and |
117 | // nothing will be modified in the graph. |
118 | Status SwapNodeNames(absl::string_view from_node_name, |
119 | absl::string_view to_node_name, bool update_fanouts); |
120 | |
121 | // Updates all fanouts (input ports fetching output tensors) from |
122 | // `from_node_name` to the `to_node_name`, including control dependencies. |
123 | // |
124 | // Example: We have 3 nodes that use `bar` node output tensors as inputs: |
125 | // 1. foo1(bar:0, bar:1, other:0) |
126 | // 2. foo2(bar:1, other:1) |
127 | // 3. foo3(other:2, ^bar) |
128 | // |
129 | // After calling UpdateFanouts(bar, new_bar): |
130 | // 1. foo1(new_bar:0, new_bar:1, other:0) |
131 | // 2. foo2(new_bar:1, other:1) |
132 | // 3. foo3(other:2, ^new_bar) |
133 | Status UpdateFanouts(absl::string_view from_node_name, |
134 | absl::string_view to_node_name); |
135 | |
136 | // Adds regular fanin `fanin` to node `node_name`. If the node or fanin do not |
137 | // exist in the graph, nothing will be modified in the graph. Otherwise fanin |
138 | // will be added after existing non control dependency fanins. Control |
139 | // dependencies will be deduped. To add control dependencies, use |
140 | // AddControllingFanin. |
141 | Status AddRegularFanin(absl::string_view node_name, const TensorId& fanin); |
142 | |
143 | // Adds regular fanin `fanin` to node `node_name` at port `port`. If the node |
144 | // or fanin do not exist in the graph, nothing will be modified in the graph. |
145 | // Otherwise fanin will be inserted at port `port`. Control dependencies will |
146 | // be deduped. To add control dependencies, use AddControllingFanin. |
147 | // |
148 | // If the port is not a valid port (less than 0 or greater than the number of |
149 | // regular fanins), this will result in an error and the node will not be |
150 | // modified. |
151 | Status AddRegularFaninByPort(absl::string_view node_name, int port, |
152 | const TensorId& fanin); |
153 | |
154 | // Adds control dependency `fanin` to the target node named `node_name`. To |
155 | // add regular fanins, use AddRegularFanin. |
156 | // |
157 | // Case 1: If the fanin is not a Switch node, the control dependency is simply |
158 | // added to the target node: |
159 | // |
160 | // fanin -^> target node. |
161 | // |
162 | // Case 2: If the fanin is a Switch node, we cannot anchor a control |
163 | // dependency on it, because unlike other nodes, only one of its outputs will |
164 | // be generated when the node is activated. In this case, we try to find an |
165 | // Identity/IdentityN node in the fanout of the relevant port of the Switch |
166 | // and add it as a fanin to the target node. If no such Identity/IdentityN |
167 | // node can be found, a new Identity node will be created. In both cases, we |
168 | // end up with: |
169 | // |
170 | // fanin -> Identity{N} -^> target node. |
171 | // |
172 | // If the control dependency being added is redundant (control dependency |
173 | // already exists or control dependency can be deduped from regular fanins), |
174 | // this will not result in an error and the node will not be modified. |
175 | Status AddControllingFanin(absl::string_view node_name, |
176 | const TensorId& fanin); |
177 | |
178 | // Removes regular fanin `fanin` from node `node_name`. If the node or fanin |
179 | // do not exist in the graph, nothing will be modified in the graph. If there |
180 | // are multiple inputs that match the fanin, all of them will be removed. To |
181 | // remove controlling fanins, use RemoveControllingFanin. |
182 | // |
183 | // If the fanin being removed doesn't exist in the node's inputs, this will |
184 | // not result in an error and the node will not be modified. |
185 | Status RemoveRegularFanin(absl::string_view node_name, const TensorId& fanin); |
186 | |
187 | // Removes regular fanin at port `port` from node `node_name`. If the node |
188 | // does not exist in the graph, nothing will be modified in the graph. |
189 | // To remove controlling fanins, use RemoveControllingFanin. |
190 | // |
191 | // If the port is not a valid port (less than 0 or greater than the last index |
192 | // of the regular fanins), this will result in an error and the node will not |
193 | // be modified. |
194 | Status RemoveRegularFaninByPort(absl::string_view node_name, int port); |
195 | |
196 | // Removes control dependency `fanin_node_name` from the target node named |
197 | // `node_name`. If the node or fanin do not exist in the graph, nothing will |
198 | // be modified in the graph. To remove regular fanins, use RemoveRegularFanin. |
199 | // |
200 | // If the fanin being removed doesn't exist in the node's inputs, this will |
201 | // not result in an error and the node will not be modified. |
202 | Status RemoveControllingFanin(absl::string_view node_name, |
203 | absl::string_view fanin_node_name); |
204 | |
205 | // Removes all fanins from node `node_name`. Control dependencies will be |
206 | // retained if keep_controlling_fanins is true. |
207 | // |
208 | // If no fanins are removed, this will not result in an error and the node |
209 | // will not be modified. |
210 | Status RemoveAllFanins(absl::string_view node_name, |
211 | bool keep_controlling_fanins); |
212 | |
213 | // Replaces all fanins `from_fanin` with `to_fanin` in node `node_name`. If |
214 | // the fanins or node do not exist, nothing will be modified in the graph. |
215 | // Control dependencies will be deduped. |
216 | // |
217 | // If the fanin being updated doesn't exist in the node's inputs, this will |
218 | // not result in an error and the node will not be modified. |
219 | Status UpdateFanin(absl::string_view node_name, const TensorId& from_fanin, |
220 | const TensorId& to_fanin); |
221 | |
222 | // Replaces fanin at port `port` in node `node_name` with fanin `fanin`. If |
223 | // the fanins or node do not exist, nothing will be modified in the graph. |
224 | // Control dependencies will be deduped. |
225 | // |
226 | // If the port is not a valid port (less than 0 or greater than the last index |
227 | // of the regular fanins), this will result in an error and the node will not |
228 | // be modified. |
229 | Status UpdateRegularFaninByPort(absl::string_view node_name, int port, |
230 | const TensorId& fanin); |
231 | |
232 | // Swaps fanins at ports `from_port` and `to_port` in node `node_name`. If the |
233 | // node does not exist, nothing will be modified in the graph. |
234 | // |
235 | // If the ports are not a valid port (less than 0 or greater than the last |
236 | // index of the regular fanins), this will result in an error and the node |
237 | // will not be modified. |
238 | Status SwapRegularFaninsByPorts(absl::string_view node_name, int from_port, |
239 | int to_port); |
240 | |
241 | // Updates all regular fanins to equivalent controlling fanins. If it is not |
242 | // possible, an error will be returned and nothing will be modified in the |
243 | // graph. |
244 | Status UpdateAllRegularFaninsToControlling(absl::string_view node_name); |
245 | |
246 | // Deletes nodes from the graph. If a node can't be safely removed, |
247 | // specifically if a node still has fanouts, an error will be returned. Nodes |
248 | // that can't be found are ignored. |
249 | Status DeleteNodes(const absl::flat_hash_set<string>& nodes_to_delete); |
250 | |
251 | private: |
252 | // Adds fanouts for fanins of node to graph, while deduping control |
253 | // dependencies from existing control dependencies and regular fanins. Note, |
254 | // node inputs will be mutated if control dependencies can be deduped. |
255 | void AddAndDedupFanouts(NodeDef* node); |
256 | |
257 | // Finds next output port smaller than fanin.port_id and update. The |
258 | // max_regular_output_port is only updated if fanin.port_id is the same as the |
259 | // current max_regular_output_port and if the fanouts set is empty. If there |
260 | // are no regular outputs, max_regular_output_port will be erased. |
261 | void UpdateMaxRegularOutputPortForRemovedFanin( |
262 | const OutputPort& fanin, |
263 | const absl::flat_hash_set<InputPort>& fanin_fanouts); |
264 | |
265 | // Updates max regular output port for newly added fanin by checking the |
266 | // current max and updating if the newly added fanin is of a larger port. |
267 | void UpdateMaxRegularOutputPortForAddedFanin(const OutputPort& fanin); |
268 | |
269 | // Updates all fanouts (input ports fetching output tensors) from `from_node` |
270 | // to the `to_node`, including control dependencies. |
271 | // |
272 | // Example: We have 3 nodes that use `bar` node output tensors as inputs: |
273 | // 1. foo1(bar:0, bar:1, other:0) |
274 | // 2. foo2(bar:1, other:1) |
275 | // 3. foo3(other:2, ^bar) |
276 | // |
277 | // After calling UpdateFanouts(bar, new_bar): |
278 | // 1. foo1(new_bar:0, new_bar:1, other:0) |
279 | // 2. foo2(new_bar:1, other:1) |
280 | // 3. foo3(other:2, ^new_bar) |
281 | // |
282 | // IMPORTANT: If `from_node` or `to_node` is not in the underlying graph, the |
283 | // behavior is undefined. |
284 | Status UpdateFanoutsInternal(NodeDef* from_node, NodeDef* to_node); |
285 | |
286 | // Adds fanin to node. If fanin is a control dependency, existing control |
287 | // dependencies will be checked first before adding. Otherwise fanin will be |
288 | // added after existing non control dependency inputs. |
289 | bool AddFaninInternal(NodeDef* node, const OutputPort& fanin); |
290 | |
291 | // Finds control dependency node to be used based on fanin. If fanin is not a |
292 | // Switch node, fanin.node is simply returned. Otherwise this will try to find |
293 | // a candidate Identity node consuming fanin, as the control dependency. If it |
294 | // is not possible or will introduce a self loop, an error message will be |
295 | // set. If nullptr is returned with no error |
296 | // GetOrCreateIdentityConsumingSwitch should be called to generate the new |
297 | // Identity node. |
298 | NodeDef* GetControllingFaninToAdd(absl::string_view node_name, |
299 | const OutputPort& fanin, string* error_msg); |
300 | |
301 | // Finds a generated Identity node consuming Switch node `fanin.node` at port |
302 | // `fanin.port_id`. If such a node does not exist, a new Identity node will be |
303 | // created. |
304 | NodeDef* GetOrCreateIdentityConsumingSwitch(const OutputPort& fanin); |
305 | |
306 | // Removes all instances of regular fanin `fanin` from node `node`. |
307 | bool RemoveRegularFaninInternal(NodeDef* node, const OutputPort& fanin); |
308 | |
309 | // Removes controlling fanin `fanin_node` from node if such controlling fanin |
310 | // exists. |
311 | bool RemoveControllingFaninInternal(NodeDef* node, NodeDef* fanin_node); |
312 | |
313 | // Checks if nodes to be deleted are missing or have any fanouts that will |
314 | // remain in the graph. If node is removed in either case, the graph will |
315 | // enter an invalid state. |
316 | Status CheckNodesCanBeDeleted( |
317 | const absl::flat_hash_set<string>& nodes_to_delete); |
318 | |
319 | // Removes fanins of the deleted node from internal state. Control |
320 | // dependencies are retained iff keep_controlling_fanins is true. |
321 | void RemoveFaninsInternal(NodeDef* deleted_node, |
322 | bool keep_controlling_fanins); |
323 | |
324 | // Removes fanouts of the deleted node from internal state. |
325 | void RemoveFanoutsInternal(NodeDef* deleted_node); |
326 | }; |
327 | |
328 | } // end namespace grappler |
329 | } // end namespace tensorflow |
330 | |
331 | #endif // TENSORFLOW_CORE_GRAPPLER_MUTABLE_GRAPH_VIEW_H_ |
332 | |