1/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3Licensed under the Apache License, Version 2.0 (the "License");
4you may not use this file except in compliance with the License.
5You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9Unless required by applicable law or agreed to in writing, software
10distributed under the License is distributed on an "AS IS" BASIS,
11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12See the License for the specific language governing permissions and
13limitations under the License.
14==============================================================================*/
15
16#ifndef TENSORFLOW_CORE_GRAPPLER_GRAPH_VIEW_H_
17#define TENSORFLOW_CORE_GRAPPLER_GRAPH_VIEW_H_
18
19#include <unordered_map>
20#include <unordered_set>
21#include "absl/container/flat_hash_map.h"
22#include "absl/container/flat_hash_set.h"
23#include "absl/hash/hash.h"
24#include "absl/strings/string_view.h"
25#include "tensorflow/core/framework/graph.pb.h"
26#include "tensorflow/core/framework/node_def.pb.h"
27#include "tensorflow/core/framework/op_def.pb.h"
28#include "tensorflow/core/graph/tensor_id.h"
29#include "tensorflow/core/grappler/utils.h"
30#include "tensorflow/core/lib/gtl/map_util.h"
31#include "tensorflow/core/platform/types.h"
32
33namespace tensorflow {
34namespace grappler {
35
// Map a node/op's input/output port_id to arg_id.
//
// The port_id refers to the n-th tensor of the node, while the arg_id refers to
// the n-th arg of the op. These two can be different if an op's arg is a list
// of tensors.
//
// Parameters:
//   node:    the NodeDef whose port is being translated.
//   op:      the OpDef used to resolve the args (presumably the definition of
//            `node`'s op -- confirm against the implementation).
//   port_id: index of the output (resp. input) tensor of `node`.
//
// We return -1 for any invalid port_id (i.e., no corresponding arg_id).
int OpOutputPortIdToArgId(const NodeDef& node, const OpDef& op, int port_id);
int OpInputPortIdToArgId(const NodeDef& node, const OpDef& op, int port_id);
45
46namespace internal {
47
48// GraphViewInternal is a helper class to simplify graph traversal. It creates
49// an immutable view of the nodes and edges represented by a GraphDef protocol
50// buffer.
51//
52// There are two public classes implementing GraphViewInternal:
53//
54// - GraphView: constructed from the `const GraphDef` and doesn't allow
55// to mutate underlying graph via input/output ports lookup functions (ports
56// have const pointers to nodes).
57//
// - MutableGraphView: constructed from the `GraphDef` and allows mutating
//   the graph via input/output ports lookup functions (ports have non-const
//   pointers to nodes), and also has a couple of additional functions to
//   add/remove/replace nodes in the graph.
62//
63// --------------------------- !!! WARNING !!! ---------------------------------
64// Removing nodes from the graph outside of MutableGraphView will
65// lead to segfaults! Guaranteed by absl::string_view!
66// -----------------------------------------------------------------------------
67//
68template <typename GraphDefT, typename NodeDefT>
69class GraphViewInternal {
70 public:
71 struct Port {
72 Port() : node(nullptr), port_id(0) {}
73 Port(NodeDefT* n, int port) : node(n), port_id(port) {}
74
75 bool operator==(const Port& other) const {
76 return node == other.node && port_id == other.port_id;
77 }
78
79 template <typename H>
80 friend H AbslHashValue(H h, const Port& p) {
81 return H::combine(std::move(h), p.node, p.port_id);
82 }
83
84 NodeDefT* node;
85 int port_id;
86 };
87
88 struct InputPort : public Port {
89 using Port::Port;
90 };
91
92 struct OutputPort : public Port {
93 using Port::Port;
94 };
95
96 struct Edge {
97 Edge(OutputPort s, InputPort d) : src(s), dst(d) {}
98
99 bool operator==(const Edge& other) const {
100 return src == other.src && dst == other.dst;
101 }
102
103 template <typename H>
104 friend H AbslHashValue(H h, const Edge& e) {
105 return H::combine(std::move(h), e.src, e.dst);
106 }
107
108 OutputPort src;
109 InputPort dst;
110 };
111
112 GraphDefT* graph() const { return graph_; }
113
114 // Finds a node by name or return `nullptr` if it's not in the graph view.
115 NodeDefT* GetNode(absl::string_view node_name) const {
116 return gtl::FindWithDefault(nodes_, node_name, nullptr);
117 }
118
119 // Checks if a node by name is in the graph view.
120 bool HasNode(absl::string_view node_name) const {
121 return GetNode(node_name) != nullptr;
122 }
123
124 // Gets the specified input port. Note that the special '-1' port_id can be
125 // used to access the controlling nodes (i.e. the nodes connected to node_name
126 // through an incoming control dependency).
127 InputPort GetInputPort(absl::string_view node_name, int port_id) const {
128 return InputPort(GetNode(node_name), port_id);
129 }
130
131 // Gets the specified output port. Note that the special '-1' port_id can be
132 // used to access the controlled nodes (i.e. the nodes connected to node_name
133 // through an outgoing control dependency).
134 OutputPort GetOutputPort(absl::string_view node_name, int port_id) const {
135 return OutputPort(GetNode(node_name), port_id);
136 }
137
138 // Gets the input port(s) in the immediate fanout of an output port.
139 const absl::flat_hash_set<InputPort>& GetFanout(
140 const OutputPort& port) const {
141 return gtl::FindWithDefault(fanouts_, port, fanout_not_found_value_);
142 }
143
144 // Gets the output port(s) in the immediate fanin of an input port.
145 absl::flat_hash_set<OutputPort> GetFanin(const InputPort& port) const {
146 if (port.port_id >= 0) {
147 OutputPort regular_fanin = GetRegularFanin(port);
148 if (regular_fanin.node == nullptr) {
149 return {};
150 }
151 return {regular_fanin};
152 }
153
154 // Collect fanin for the control input.
155 absl::flat_hash_set<OutputPort> result;
156 const int first_control_port =
157 gtl::FindWithDefault(max_regular_input_port_, port.node, -1) + 1;
158 for (int i = first_control_port; i < port.node->input_size(); ++i) {
159 TensorId tensor_id = ParseTensorName(port.node->input(i));
160
161 auto it = nodes_.find(tensor_id.node());
162 if (it != nodes_.end()) result.emplace(it->second, tensor_id.index());
163 }
164 return result;
165 }
166
167 // Special case: regular (i.e. non-control) input ports can only have one
168 // fanin. If port.port_id is out of range or is a control dependency, then an
169 // empty OutputPort is returned.
170 const OutputPort GetRegularFanin(const InputPort& port) const {
171 if (port.port_id < 0 ||
172 port.port_id >
173 gtl::FindWithDefault(max_regular_input_port_, port.node, -1)) {
174 return OutputPort();
175 }
176
177 TensorId tensor_id = ParseTensorName(port.node->input(port.port_id));
178 return GetOutputPort(tensor_id.node(), tensor_id.index());
179 }
180
181 // Checks if a tensor id is a fanin of the node.
182 bool HasFanin(const NodeDefT& node, const TensorId& fanin) const {
183 int end = node.input_size();
184 if (end == 0 || fanin.index() < -1) {
185 return false;
186 }
187
188 const int num_regular_fanins =
189 gtl::FindWithDefault(max_regular_input_port_, &node, -1) + 1;
190 int start = 0;
191 if (fanin.index() > -1) {
192 end = num_regular_fanins;
193 } else {
194 start = num_regular_fanins;
195 }
196 for (int i = start; i < end; ++i) {
197 if (ParseTensorName(node.input(i)) == fanin) {
198 return true;
199 }
200 }
201 return false;
202 }
203
204 // Gets all the input ports in the immediate fanout of a node. Include the
205 // controlled nodes iff include_controlled_nodes is true.
206 absl::flat_hash_set<InputPort> GetFanouts(
207 const NodeDefT& node, bool include_controlled_nodes) const {
208 absl::flat_hash_set<InputPort> result;
209
210 OutputPort port;
211 port.node = const_cast<NodeDefT*>(&node);
212 const int first_port_id = include_controlled_nodes ? -1 : 0;
213 const int last_port_id =
214 gtl::FindWithDefault(max_regular_output_port_, &node, -1);
215
216 for (int i = first_port_id; i <= last_port_id; ++i) {
217 port.port_id = i;
218 auto it = fanouts_.find(port);
219 if (it != fanouts_.end()) {
220 result.insert(it->second.begin(), it->second.end());
221 }
222 }
223 return result;
224 }
225
226 // Gets all the output ports in the immediate fanin of a node. Include the
227 // controlling nodes iff include_controlling_nodes is true.
228 absl::flat_hash_set<OutputPort> GetFanins(
229 const NodeDefT& node, bool include_controlling_nodes) const {
230 absl::flat_hash_set<OutputPort> result;
231 const int max_input_port =
232 include_controlling_nodes
233 ? node.input_size() - 1
234 : gtl::FindWithDefault(max_regular_input_port_, &node, -1);
235 for (int i = 0; i <= max_input_port; ++i) {
236 TensorId tensor_id = ParseTensorName(node.input(i));
237
238 auto it = nodes_.find(tensor_id.node());
239 if (it != nodes_.end()) result.emplace(it->second, tensor_id.index());
240 }
241 return result;
242 }
243
244 // Gets the number of ports in the immediate fanin of a node. Count the
245 // controlling nodes iff include_controlling_nodes is true.
246 int NumFanins(const NodeDefT& node, bool include_controlling_nodes) const {
247 if (include_controlling_nodes) {
248 return node.input_size();
249 }
250 return gtl::FindWithDefault(max_regular_input_port_, &node, -1) + 1;
251 }
252
253 // Gets the number of ports in the immediate fanout of a node. Count the
254 // controlled nodes iff include_controlled_nodes is true.
255 int NumFanouts(const NodeDefT& node, bool include_controlled_nodes) const {
256 int count = 0;
257
258 OutputPort port;
259 port.node = const_cast<NodeDefT*>(&node);
260 const int first_port_id = include_controlled_nodes ? -1 : 0;
261 const int last_port_id =
262 gtl::FindWithDefault(max_regular_output_port_, &node, -1);
263
264 for (int i = first_port_id; i <= last_port_id; ++i) {
265 port.port_id = i;
266 auto it = fanouts_.find(port);
267 if (it != fanouts_.end()) count += it->second.size();
268 }
269
270 return count;
271 }
272
273 // Gets all the edges in the immediate fanout of a node. Include the
274 // controlled edges iff include_controlled_edges is true.
275 absl::flat_hash_set<Edge> GetFanoutEdges(
276 const NodeDefT& node, bool include_controlled_edges) const {
277 absl::flat_hash_set<Edge> result;
278
279 OutputPort port;
280 port.node = const_cast<NodeDefT*>(&node);
281 const int first_port_id = include_controlled_edges ? -1 : 0;
282 const int last_port_id =
283 gtl::FindWithDefault(max_regular_output_port_, &node, -1);
284
285 for (int i = first_port_id; i <= last_port_id; ++i) {
286 port.port_id = i;
287 auto it = fanouts_.find(port);
288 if (it != fanouts_.end()) {
289 for (auto itr = it->second.begin(); itr != it->second.end(); ++itr) {
290 result.emplace(/*src=*/port, /*dst=*/*itr);
291 }
292 }
293 }
294 return result;
295 }
296
297 // Gets all the edges in the immediate fanin of a node. Include the
298 // controlling edges iff include_controlling_edges is true.
299 absl::flat_hash_set<Edge> GetFaninEdges(
300 const NodeDefT& node, bool include_controlling_edges) const {
301 absl::flat_hash_set<Edge> result;
302 const int max_input_port =
303 include_controlling_edges
304 ? node.input_size() - 1
305 : gtl::FindWithDefault(max_regular_input_port_, &node, -1);
306 for (int i = 0; i <= max_input_port; ++i) {
307 TensorId tensor_id = ParseTensorName(node.input(i));
308
309 auto it = nodes_.find(tensor_id.node());
310 if (it != nodes_.end()) {
311 result.emplace(/*src=*/OutputPort(it->second, tensor_id.index()),
312 /*dst=*/InputPort(const_cast<NodeDefT*>(&node), i));
313 }
314 }
315 return result;
316 }
317
318 protected:
319 explicit GraphViewInternal(GraphDefT* graph) : graph_(graph) {}
320
321 Status AddUniqueNode(NodeDefT* node) {
322 auto inserted = nodes_.emplace(node->name(), node);
323 return inserted.second
324 ? OkStatus()
325 : errors::InvalidArgument("Non unique node name detected: ",
326 node->name());
327 }
328
329 // TODO(ezhulenev): Remove this function.
330 void AddUniqueNodeOrDie(NodeDefT* node) {
331 Status st = AddUniqueNode(node);
332 CHECK(st.ok()) << st.error_message();
333 }
334
335 // TODO(lyandy): Checks for self loops, Switch control dependencies, fanins
336 // exist, and all regular fanins come before controlling fanins.
337 void AddFanouts(NodeDefT* node) {
338 int max_input_port = -1;
339 for (int i = 0; i < node->input_size(); ++i) {
340 TensorId tensor_id = ParseTensorName(node->input(i));
341 OutputPort output(nodes_[tensor_id.node()], tensor_id.index());
342
343 if (output.port_id < 0) {
344 fanouts_[output].emplace(node, -1);
345 } else {
346 max_input_port = i;
347 max_regular_output_port_[output.node] =
348 std::max(max_regular_output_port_[output.node], output.port_id);
349 fanouts_[output].emplace(node, i);
350 }
351 }
352 if (max_input_port > -1) {
353 max_regular_input_port_[node] = max_input_port;
354 }
355 }
356
357 // Access to the mutable internal state for MutableGraphView.
358 absl::flat_hash_map<absl::string_view, NodeDefT*>& nodes() { return nodes_; }
359
360 absl::flat_hash_map<OutputPort, absl::flat_hash_set<InputPort>>& fanouts() {
361 return fanouts_;
362 }
363
364 absl::flat_hash_map<const NodeDefT*, int>& max_regular_input_port() {
365 return max_regular_input_port_;
366 }
367
368 absl::flat_hash_map<const NodeDefT*, int>& max_regular_output_port() {
369 return max_regular_output_port_;
370 }
371
372 private:
  GraphDefT* graph_;  // must outlive the graph view

  // A mapping from the node name to the node itself. The keys are
  // string_views built from each node's own name() when the node is
  // registered, so they do not own the characters they refer to.
  absl::flat_hash_map<absl::string_view, NodeDefT*> nodes_;

  // A mapping from the output port to all inputs that read from it. Control
  // dependencies are stored under a negative output port id, with consumer
  // input port -1.
  absl::flat_hash_map<OutputPort, absl::flat_hash_set<InputPort>> fanouts_;

  // Keep a maximum index of input tensors of the node, i.e. the position of
  // its last regular (non-control) input. Nodes without regular inputs have no
  // entry; lookups fall back to -1.
  absl::flat_hash_map<const NodeDefT*, int> max_regular_input_port_;

  // Keep a maximum index of tensor fetched from the node. It doesn't guarantee
  // that all tensors in the [0, max_regular_output_port] range are actually
  // fetched by other nodes.
  absl::flat_hash_map<const NodeDefT*, int> max_regular_output_port_;

  // If the node has no fanouts at given output port (output tensor consumers)
  // we return a reference to this set from `GetFanout` (we can't construct new
  // empty set every time, because we need a non-dangling reference).
  absl::flat_hash_set<InputPort> fanout_not_found_value_;
393};
394
395} // namespace internal
396
397// Immutable GraphView that keeps the constness of the GraphDef. If you need to
398// mutate the graph or the nodes via the graph view lookup functions, see
399// MutableGraphView.
400class GraphView
401 : public internal::GraphViewInternal<const GraphDef, const NodeDef> {
402 public:
403 explicit GraphView(const GraphDef* graph) : GraphViewInternal(graph) {
404 for (const NodeDef& node : graph->node()) AddUniqueNodeOrDie(&node);
405 for (const NodeDef& node : graph->node()) AddFanouts(&node);
406 }
407};
408
// Helper predicates over a GraphView. `node` is the node being queried;
// `port` selects one of its output ports and defaults to 0.

// Returns true if node has one (or zero) fanout nodes at given output port.
bool HasSingleFanoutNode(const GraphView& graph_view, const NodeDef* node,
                         int port = 0);

// Returns true if node has at least one fanout node at given output port.
bool HasFanouts(const GraphView& graph_view, const NodeDef* node, int port = 0);

// Returns true if the node has at least one input control dependency.
bool HasControlFanin(const GraphView& graph_view, const NodeDef* node);

// Returns true if the node has at least one output control dependency.
bool HasControlFanout(const GraphView& graph_view, const NodeDef* node);

// Returns true if the node has at least one input or output control dependency.
bool HasControlFaninOrFanout(const GraphView& graph_view, const NodeDef* node);
421
422} // end namespace grappler
423} // end namespace tensorflow
424
425#endif // TENSORFLOW_CORE_GRAPPLER_GRAPH_VIEW_H_
426