1 | /* Copyright 2017 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | |
16 | #ifndef TENSORFLOW_CORE_GRAPPLER_GRAPH_VIEW_H_ |
17 | #define TENSORFLOW_CORE_GRAPPLER_GRAPH_VIEW_H_ |
18 | |
19 | #include <unordered_map> |
20 | #include <unordered_set> |
21 | #include "absl/container/flat_hash_map.h" |
22 | #include "absl/container/flat_hash_set.h" |
23 | #include "absl/hash/hash.h" |
24 | #include "absl/strings/string_view.h" |
25 | #include "tensorflow/core/framework/graph.pb.h" |
26 | #include "tensorflow/core/framework/node_def.pb.h" |
27 | #include "tensorflow/core/framework/op_def.pb.h" |
28 | #include "tensorflow/core/graph/tensor_id.h" |
29 | #include "tensorflow/core/grappler/utils.h" |
30 | #include "tensorflow/core/lib/gtl/map_util.h" |
31 | #include "tensorflow/core/platform/types.h" |
32 | |
33 | namespace tensorflow { |
34 | namespace grappler { |
35 | |
// Maps a node/op's input/output port_id to the corresponding arg_id.
//
// The port_id refers to the n-th tensor of the node, while the arg_id refers to
// the n-th arg of the op. These two can be different if an op's arg is a list
// of tensors, in which case several consecutive ports belong to one arg.
//
// OpOutputPortIdToArgId maps an output port of `node` to an output arg of
// `op`; OpInputPortIdToArgId does the same for input ports / input args.
//
// Both return -1 for any invalid port_id (i.e., no corresponding arg_id).
int OpOutputPortIdToArgId(const NodeDef& node, const OpDef& op, int port_id);
int OpInputPortIdToArgId(const NodeDef& node, const OpDef& op, int port_id);
45 | |
46 | namespace internal { |
47 | |
48 | // GraphViewInternal is a helper class to simplify graph traversal. It creates |
49 | // an immutable view of the nodes and edges represented by a GraphDef protocol |
50 | // buffer. |
51 | // |
52 | // There are two public classes implementing GraphViewInternal: |
53 | // |
54 | // - GraphView: constructed from the `const GraphDef` and doesn't allow |
55 | // to mutate underlying graph via input/output ports lookup functions (ports |
56 | // have const pointers to nodes). |
57 | // |
58 | // - MutableGraphView: constructed from the 'GraphDef` and allows to mutate |
59 | // the graph via input/output ports lookup functions (ports have non-const |
60 | // pointers to nodes), and also have couple additional functions to |
61 | // add/remove/replace nodes in the graph. |
62 | // |
63 | // --------------------------- !!! WARNING !!! --------------------------------- |
64 | // Removing nodes from the graph outside of MutableGraphView will |
65 | // lead to segfaults! Guaranteed by absl::string_view! |
66 | // ----------------------------------------------------------------------------- |
67 | // |
68 | template <typename GraphDefT, typename NodeDefT> |
69 | class GraphViewInternal { |
70 | public: |
71 | struct Port { |
72 | Port() : node(nullptr), port_id(0) {} |
73 | Port(NodeDefT* n, int port) : node(n), port_id(port) {} |
74 | |
75 | bool operator==(const Port& other) const { |
76 | return node == other.node && port_id == other.port_id; |
77 | } |
78 | |
79 | template <typename H> |
80 | friend H AbslHashValue(H h, const Port& p) { |
81 | return H::combine(std::move(h), p.node, p.port_id); |
82 | } |
83 | |
84 | NodeDefT* node; |
85 | int port_id; |
86 | }; |
87 | |
88 | struct InputPort : public Port { |
89 | using Port::Port; |
90 | }; |
91 | |
92 | struct OutputPort : public Port { |
93 | using Port::Port; |
94 | }; |
95 | |
96 | struct Edge { |
97 | Edge(OutputPort s, InputPort d) : src(s), dst(d) {} |
98 | |
99 | bool operator==(const Edge& other) const { |
100 | return src == other.src && dst == other.dst; |
101 | } |
102 | |
103 | template <typename H> |
104 | friend H AbslHashValue(H h, const Edge& e) { |
105 | return H::combine(std::move(h), e.src, e.dst); |
106 | } |
107 | |
108 | OutputPort src; |
109 | InputPort dst; |
110 | }; |
111 | |
112 | GraphDefT* graph() const { return graph_; } |
113 | |
114 | // Finds a node by name or return `nullptr` if it's not in the graph view. |
115 | NodeDefT* GetNode(absl::string_view node_name) const { |
116 | return gtl::FindWithDefault(nodes_, node_name, nullptr); |
117 | } |
118 | |
119 | // Checks if a node by name is in the graph view. |
120 | bool HasNode(absl::string_view node_name) const { |
121 | return GetNode(node_name) != nullptr; |
122 | } |
123 | |
124 | // Gets the specified input port. Note that the special '-1' port_id can be |
125 | // used to access the controlling nodes (i.e. the nodes connected to node_name |
126 | // through an incoming control dependency). |
127 | InputPort GetInputPort(absl::string_view node_name, int port_id) const { |
128 | return InputPort(GetNode(node_name), port_id); |
129 | } |
130 | |
131 | // Gets the specified output port. Note that the special '-1' port_id can be |
132 | // used to access the controlled nodes (i.e. the nodes connected to node_name |
133 | // through an outgoing control dependency). |
134 | OutputPort GetOutputPort(absl::string_view node_name, int port_id) const { |
135 | return OutputPort(GetNode(node_name), port_id); |
136 | } |
137 | |
138 | // Gets the input port(s) in the immediate fanout of an output port. |
139 | const absl::flat_hash_set<InputPort>& GetFanout( |
140 | const OutputPort& port) const { |
141 | return gtl::FindWithDefault(fanouts_, port, fanout_not_found_value_); |
142 | } |
143 | |
144 | // Gets the output port(s) in the immediate fanin of an input port. |
145 | absl::flat_hash_set<OutputPort> GetFanin(const InputPort& port) const { |
146 | if (port.port_id >= 0) { |
147 | OutputPort regular_fanin = GetRegularFanin(port); |
148 | if (regular_fanin.node == nullptr) { |
149 | return {}; |
150 | } |
151 | return {regular_fanin}; |
152 | } |
153 | |
154 | // Collect fanin for the control input. |
155 | absl::flat_hash_set<OutputPort> result; |
156 | const int first_control_port = |
157 | gtl::FindWithDefault(max_regular_input_port_, port.node, -1) + 1; |
158 | for (int i = first_control_port; i < port.node->input_size(); ++i) { |
159 | TensorId tensor_id = ParseTensorName(port.node->input(i)); |
160 | |
161 | auto it = nodes_.find(tensor_id.node()); |
162 | if (it != nodes_.end()) result.emplace(it->second, tensor_id.index()); |
163 | } |
164 | return result; |
165 | } |
166 | |
167 | // Special case: regular (i.e. non-control) input ports can only have one |
168 | // fanin. If port.port_id is out of range or is a control dependency, then an |
169 | // empty OutputPort is returned. |
170 | const OutputPort GetRegularFanin(const InputPort& port) const { |
171 | if (port.port_id < 0 || |
172 | port.port_id > |
173 | gtl::FindWithDefault(max_regular_input_port_, port.node, -1)) { |
174 | return OutputPort(); |
175 | } |
176 | |
177 | TensorId tensor_id = ParseTensorName(port.node->input(port.port_id)); |
178 | return GetOutputPort(tensor_id.node(), tensor_id.index()); |
179 | } |
180 | |
181 | // Checks if a tensor id is a fanin of the node. |
182 | bool HasFanin(const NodeDefT& node, const TensorId& fanin) const { |
183 | int end = node.input_size(); |
184 | if (end == 0 || fanin.index() < -1) { |
185 | return false; |
186 | } |
187 | |
188 | const int num_regular_fanins = |
189 | gtl::FindWithDefault(max_regular_input_port_, &node, -1) + 1; |
190 | int start = 0; |
191 | if (fanin.index() > -1) { |
192 | end = num_regular_fanins; |
193 | } else { |
194 | start = num_regular_fanins; |
195 | } |
196 | for (int i = start; i < end; ++i) { |
197 | if (ParseTensorName(node.input(i)) == fanin) { |
198 | return true; |
199 | } |
200 | } |
201 | return false; |
202 | } |
203 | |
204 | // Gets all the input ports in the immediate fanout of a node. Include the |
205 | // controlled nodes iff include_controlled_nodes is true. |
206 | absl::flat_hash_set<InputPort> GetFanouts( |
207 | const NodeDefT& node, bool include_controlled_nodes) const { |
208 | absl::flat_hash_set<InputPort> result; |
209 | |
210 | OutputPort port; |
211 | port.node = const_cast<NodeDefT*>(&node); |
212 | const int first_port_id = include_controlled_nodes ? -1 : 0; |
213 | const int last_port_id = |
214 | gtl::FindWithDefault(max_regular_output_port_, &node, -1); |
215 | |
216 | for (int i = first_port_id; i <= last_port_id; ++i) { |
217 | port.port_id = i; |
218 | auto it = fanouts_.find(port); |
219 | if (it != fanouts_.end()) { |
220 | result.insert(it->second.begin(), it->second.end()); |
221 | } |
222 | } |
223 | return result; |
224 | } |
225 | |
226 | // Gets all the output ports in the immediate fanin of a node. Include the |
227 | // controlling nodes iff include_controlling_nodes is true. |
228 | absl::flat_hash_set<OutputPort> GetFanins( |
229 | const NodeDefT& node, bool include_controlling_nodes) const { |
230 | absl::flat_hash_set<OutputPort> result; |
231 | const int max_input_port = |
232 | include_controlling_nodes |
233 | ? node.input_size() - 1 |
234 | : gtl::FindWithDefault(max_regular_input_port_, &node, -1); |
235 | for (int i = 0; i <= max_input_port; ++i) { |
236 | TensorId tensor_id = ParseTensorName(node.input(i)); |
237 | |
238 | auto it = nodes_.find(tensor_id.node()); |
239 | if (it != nodes_.end()) result.emplace(it->second, tensor_id.index()); |
240 | } |
241 | return result; |
242 | } |
243 | |
244 | // Gets the number of ports in the immediate fanin of a node. Count the |
245 | // controlling nodes iff include_controlling_nodes is true. |
246 | int NumFanins(const NodeDefT& node, bool include_controlling_nodes) const { |
247 | if (include_controlling_nodes) { |
248 | return node.input_size(); |
249 | } |
250 | return gtl::FindWithDefault(max_regular_input_port_, &node, -1) + 1; |
251 | } |
252 | |
253 | // Gets the number of ports in the immediate fanout of a node. Count the |
254 | // controlled nodes iff include_controlled_nodes is true. |
255 | int NumFanouts(const NodeDefT& node, bool include_controlled_nodes) const { |
256 | int count = 0; |
257 | |
258 | OutputPort port; |
259 | port.node = const_cast<NodeDefT*>(&node); |
260 | const int first_port_id = include_controlled_nodes ? -1 : 0; |
261 | const int last_port_id = |
262 | gtl::FindWithDefault(max_regular_output_port_, &node, -1); |
263 | |
264 | for (int i = first_port_id; i <= last_port_id; ++i) { |
265 | port.port_id = i; |
266 | auto it = fanouts_.find(port); |
267 | if (it != fanouts_.end()) count += it->second.size(); |
268 | } |
269 | |
270 | return count; |
271 | } |
272 | |
273 | // Gets all the edges in the immediate fanout of a node. Include the |
274 | // controlled edges iff include_controlled_edges is true. |
275 | absl::flat_hash_set<Edge> GetFanoutEdges( |
276 | const NodeDefT& node, bool include_controlled_edges) const { |
277 | absl::flat_hash_set<Edge> result; |
278 | |
279 | OutputPort port; |
280 | port.node = const_cast<NodeDefT*>(&node); |
281 | const int first_port_id = include_controlled_edges ? -1 : 0; |
282 | const int last_port_id = |
283 | gtl::FindWithDefault(max_regular_output_port_, &node, -1); |
284 | |
285 | for (int i = first_port_id; i <= last_port_id; ++i) { |
286 | port.port_id = i; |
287 | auto it = fanouts_.find(port); |
288 | if (it != fanouts_.end()) { |
289 | for (auto itr = it->second.begin(); itr != it->second.end(); ++itr) { |
290 | result.emplace(/*src=*/port, /*dst=*/*itr); |
291 | } |
292 | } |
293 | } |
294 | return result; |
295 | } |
296 | |
297 | // Gets all the edges in the immediate fanin of a node. Include the |
298 | // controlling edges iff include_controlling_edges is true. |
299 | absl::flat_hash_set<Edge> GetFaninEdges( |
300 | const NodeDefT& node, bool include_controlling_edges) const { |
301 | absl::flat_hash_set<Edge> result; |
302 | const int max_input_port = |
303 | include_controlling_edges |
304 | ? node.input_size() - 1 |
305 | : gtl::FindWithDefault(max_regular_input_port_, &node, -1); |
306 | for (int i = 0; i <= max_input_port; ++i) { |
307 | TensorId tensor_id = ParseTensorName(node.input(i)); |
308 | |
309 | auto it = nodes_.find(tensor_id.node()); |
310 | if (it != nodes_.end()) { |
311 | result.emplace(/*src=*/OutputPort(it->second, tensor_id.index()), |
312 | /*dst=*/InputPort(const_cast<NodeDefT*>(&node), i)); |
313 | } |
314 | } |
315 | return result; |
316 | } |
317 | |
318 | protected: |
319 | explicit GraphViewInternal(GraphDefT* graph) : graph_(graph) {} |
320 | |
321 | Status AddUniqueNode(NodeDefT* node) { |
322 | auto inserted = nodes_.emplace(node->name(), node); |
323 | return inserted.second |
324 | ? OkStatus() |
325 | : errors::InvalidArgument("Non unique node name detected: " , |
326 | node->name()); |
327 | } |
328 | |
329 | // TODO(ezhulenev): Remove this function. |
330 | void AddUniqueNodeOrDie(NodeDefT* node) { |
331 | Status st = AddUniqueNode(node); |
332 | CHECK(st.ok()) << st.error_message(); |
333 | } |
334 | |
335 | // TODO(lyandy): Checks for self loops, Switch control dependencies, fanins |
336 | // exist, and all regular fanins come before controlling fanins. |
337 | void AddFanouts(NodeDefT* node) { |
338 | int max_input_port = -1; |
339 | for (int i = 0; i < node->input_size(); ++i) { |
340 | TensorId tensor_id = ParseTensorName(node->input(i)); |
341 | OutputPort output(nodes_[tensor_id.node()], tensor_id.index()); |
342 | |
343 | if (output.port_id < 0) { |
344 | fanouts_[output].emplace(node, -1); |
345 | } else { |
346 | max_input_port = i; |
347 | max_regular_output_port_[output.node] = |
348 | std::max(max_regular_output_port_[output.node], output.port_id); |
349 | fanouts_[output].emplace(node, i); |
350 | } |
351 | } |
352 | if (max_input_port > -1) { |
353 | max_regular_input_port_[node] = max_input_port; |
354 | } |
355 | } |
356 | |
357 | // Access to the mutable internal state for MutableGraphView. |
358 | absl::flat_hash_map<absl::string_view, NodeDefT*>& nodes() { return nodes_; } |
359 | |
360 | absl::flat_hash_map<OutputPort, absl::flat_hash_set<InputPort>>& fanouts() { |
361 | return fanouts_; |
362 | } |
363 | |
364 | absl::flat_hash_map<const NodeDefT*, int>& max_regular_input_port() { |
365 | return max_regular_input_port_; |
366 | } |
367 | |
368 | absl::flat_hash_map<const NodeDefT*, int>& max_regular_output_port() { |
369 | return max_regular_output_port_; |
370 | } |
371 | |
372 | private: |
373 | GraphDefT* graph_; // must outlive the graph view |
374 | |
375 | // A mapping from the node name to the node itself. |
376 | absl::flat_hash_map<absl::string_view, NodeDefT*> nodes_; |
377 | |
378 | // A mapping from the output port to all inputs that read from it. |
379 | absl::flat_hash_map<OutputPort, absl::flat_hash_set<InputPort>> fanouts_; |
380 | |
381 | // Keep a maximum index of input tensors of the node. |
382 | absl::flat_hash_map<const NodeDefT*, int> max_regular_input_port_; |
383 | |
384 | // Keep a maximum index of tensor fetched from the node. It doesn't guarantee |
385 | // that all tensors in the [0, max_regular_output_port] range are actually |
386 | // fetched by other nodes. |
387 | absl::flat_hash_map<const NodeDefT*, int> max_regular_output_port_; |
388 | |
389 | // If the node has no fanouts at given output port (output tensor consumers) |
390 | // we return a reference to this set from `GetFanout` (we can't construct new |
391 | // empty set every time, because we need a non-dangling reference). |
392 | absl::flat_hash_set<InputPort> fanout_not_found_value_; |
393 | }; |
394 | |
395 | } // namespace internal |
396 | |
397 | // Immutable GraphView that keeps the constness of the GraphDef. If you need to |
398 | // mutate the graph or the nodes via the graph view lookup functions, see |
399 | // MutableGraphView. |
400 | class GraphView |
401 | : public internal::GraphViewInternal<const GraphDef, const NodeDef> { |
402 | public: |
403 | explicit GraphView(const GraphDef* graph) : GraphViewInternal(graph) { |
404 | for (const NodeDef& node : graph->node()) AddUniqueNodeOrDie(&node); |
405 | for (const NodeDef& node : graph->node()) AddFanouts(&node); |
406 | } |
407 | }; |
408 | |
// Returns true if node has one (or zero) fanout nodes at given output port.
// `port` defaults to 0, the node's first output.
bool HasSingleFanoutNode(const GraphView& graph_view, const NodeDef* node,
                         int port = 0);

// Returns true if node has at least one fanout node at given output port.
// `port` defaults to 0, the node's first output.
bool HasFanouts(const GraphView& graph_view, const NodeDef* node, int port = 0);

// Returns true if the node has at least one input control dependency.
bool HasControlFanin(const GraphView& graph_view, const NodeDef* node);

// Returns true if the node has at least one output control dependency.
bool HasControlFanout(const GraphView& graph_view, const NodeDef* node);

// Returns true if the node has at least one input or output control dependency.
bool HasControlFaninOrFanout(const GraphView& graph_view, const NodeDef* node);
421 | |
422 | } // end namespace grappler |
423 | } // end namespace tensorflow |
424 | |
425 | #endif // TENSORFLOW_CORE_GRAPPLER_GRAPH_VIEW_H_ |
426 | |