1/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2
3Licensed under the Apache License, Version 2.0 (the "License");
4you may not use this file except in compliance with the License.
5You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9Unless required by applicable law or agreed to in writing, software
10distributed under the License is distributed on an "AS IS" BASIS,
11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12See the License for the specific language governing permissions and
13limitations under the License.
14==============================================================================*/
15
16#ifndef TENSORFLOW_CORE_GRAPPLER_GRAPH_TOPOLOGY_VIEW_H_
17#define TENSORFLOW_CORE_GRAPPLER_GRAPH_TOPOLOGY_VIEW_H_
18
#include <vector>

#include "absl/container/flat_hash_map.h"
#include "absl/container/inlined_vector.h"
#include "absl/strings/string_view.h"
#include "absl/types/optional.h"
#include "absl/types/span.h"
#include "tensorflow/core/graph/tensor_id.h"
#include "tensorflow/core/grappler/graph_view.h"
26
27namespace tensorflow {
28namespace grappler {
29
30// GraphTopologyView is a helper class to simplify `node-to-node` connectivity
31// traversals. Regular `GraphView` simplifies `tensor-to-tensor` traversals:
// connections between output tensors and inputs of consumer nodes. For the
// topology view we are focused on nodes connected to nodes, and it's irrelevant
// whether this connection is formed by one or multiple individual tensors.
35//
36// Example:
37// a = Placeholder(..)
38// b = Placeholder(..)
39// c = AddN([a, a, b])
40//
41// GraphView edges: [a:0 -> c:0, a:0 -> c:1, b:0 -> c:2]
42// GraphTopologyView edges: [a -> c, b -> c]
43//
44// GraphView is used for exploring single node fanins and fanouts, and
45// GraphTopologyView is focused on efficient full graph traversals (computing
// graph node properties from transitive fanouts, etc.).
47class GraphTopologyView {
48 public:
49 GraphTopologyView() = default;
50 explicit GraphTopologyView(bool skip_invalid_edges)
51 : skip_invalid_edges_(skip_invalid_edges) {}
52
53 // Initialize graph topology view from the graph. It's possible to pass
54 // additional edges that do not exist in a graph, but must be respected when
55 // computing graph topology. Example: Tensorflow runtime allows concurrent
56 // execution of dequeue/enqueue ops from the same queue resource, but we might
57 // want to enforce ordering between them for the purpose of graph analysis.
58 Status InitializeFromGraph(const GraphDef& graph,
59 absl::Span<const GraphView::Edge> ephemeral_edges,
60 bool ignore_control_edges);
61 Status InitializeFromGraph(const GraphDef& graph,
62 absl::Span<const GraphView::Edge> ephemeral_edges);
63 Status InitializeFromGraph(const GraphDef& graph, bool ignore_control_edges);
64 Status InitializeFromGraph(const GraphDef& graph);
65
66 bool is_initialized() const { return graph_ != nullptr; }
67 int num_nodes() const { return num_nodes_; }
68 const GraphDef* graph() const { return graph_; }
69
70 // Returns true iff the node exists in the underlying graph.
71 bool HasNode(absl::string_view node_name) const;
72
73 // Finds a node by name or returns `nullptr` if it's not in the graph.
74 const NodeDef* GetNode(absl::string_view node_name) const;
75 // Returns a node corresponding to the given node index.
76 const NodeDef* GetNode(int node_idx) const;
77
78 // Returns a node index for the given node name, if the name exists in the
79 // underlying graph. Otherwise returns empty optional.
80 const absl::optional<int> GetNodeIndex(absl::string_view node_name) const;
81 // Returns a node index for the given node, if the node belongs to the
82 // underlying graph. Otherwise returns empty optional.
83 const absl::optional<int> GetNodeIndex(const NodeDef& node) const;
84
85 // Returns all the node indexes that are in the direct fanin of the given
86 // node. If the `node_idx` is outside of [0, num_nodes_) returns empty vector.
87 const absl::InlinedVector<int, 4>& GetFanin(int node_idx) const;
88 // Returns all the node indexes that are in the direct fanout of the given
89 // node. If the `node_idx` is outside of [0, num_nodes_) returns empty vector.
90 const absl::InlinedVector<int, 2>& GetFanout(int node_idx) const;
91
92 private:
93 // If true, all invalid edges and inputs (srd, dst or input node not found in
94 // a graph) will be skipped, otherwise initialization will fail with error.
95 bool skip_invalid_edges_ = false;
96
97 // WARN: `graph_` must outlive this object and graph nodes must not be
98 // destructed, because node names captured with absl::string_view.
99 const GraphDef* graph_ = nullptr; // do not own
100 int num_nodes_ = 0;
101 std::vector<absl::string_view> index_to_node_name_;
102 absl::flat_hash_map<absl::string_view, int> node_name_to_index_;
103 std::vector<absl::InlinedVector<int, 4>> fanins_; // node_idx->input nodes
104 std::vector<absl::InlinedVector<int, 2>> fanouts_; // node_idx->output nodes
105
106 // We need a valid reference to return from GetFanin/GetFanout if the
107 // `node_idx` argument is outside of the [0, num_nodes_) range.
108 absl::InlinedVector<int, 4> empty_fanin_;
109 absl::InlinedVector<int, 2> empty_fanout_;
110};
111
112} // end namespace grappler
113} // end namespace tensorflow
114
115#endif // TENSORFLOW_CORE_GRAPPLER_GRAPH_TOPOLOGY_VIEW_H_
116