1 | /* Copyright 2018 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | |
16 | #ifndef TENSORFLOW_CORE_GRAPPLER_GRAPH_TOPOLOGY_VIEW_H_ |
17 | #define TENSORFLOW_CORE_GRAPPLER_GRAPH_TOPOLOGY_VIEW_H_ |
18 | |
19 | #include "absl/container/flat_hash_map.h" |
20 | #include "absl/container/inlined_vector.h" |
21 | #include "absl/strings/string_view.h" |
22 | #include "absl/types/optional.h" |
23 | #include "absl/types/span.h" |
24 | #include "tensorflow/core/graph/tensor_id.h" |
25 | #include "tensorflow/core/grappler/graph_view.h" |
26 | |
27 | namespace tensorflow { |
28 | namespace grappler { |
29 | |
30 | // GraphTopologyView is a helper class to simplify `node-to-node` connectivity |
31 | // traversals. Regular `GraphView` simplifies `tensor-to-tensor` traversals: |
32 | // connections between output tensors and inputs of a consumer nodes. For the |
33 | // topology view we are focused on nodes connected to nodes, and it's irrelevant |
34 | // if this connection is formed by one or multiple individual tensors. |
35 | // |
36 | // Example: |
37 | // a = Placeholder(..) |
38 | // b = Placeholder(..) |
39 | // c = AddN([a, a, b]) |
40 | // |
41 | // GraphView edges: [a:0 -> c:0, a:0 -> c:1, b:0 -> c:2] |
42 | // GraphTopologyView edges: [a -> c, b -> c] |
43 | // |
44 | // GraphView is used for exploring single node fanins and fanouts, and |
45 | // GraphTopologyView is focused on efficient full graph traversals (computing |
46 | // graph node properties from transitive fanouts, etc...). |
47 | class GraphTopologyView { |
48 | public: |
49 | GraphTopologyView() = default; |
50 | explicit GraphTopologyView(bool skip_invalid_edges) |
51 | : skip_invalid_edges_(skip_invalid_edges) {} |
52 | |
53 | // Initialize graph topology view from the graph. It's possible to pass |
54 | // additional edges that do not exist in a graph, but must be respected when |
55 | // computing graph topology. Example: Tensorflow runtime allows concurrent |
56 | // execution of dequeue/enqueue ops from the same queue resource, but we might |
57 | // want to enforce ordering between them for the purpose of graph analysis. |
58 | Status InitializeFromGraph(const GraphDef& graph, |
59 | absl::Span<const GraphView::Edge> ephemeral_edges, |
60 | bool ignore_control_edges); |
61 | Status InitializeFromGraph(const GraphDef& graph, |
62 | absl::Span<const GraphView::Edge> ephemeral_edges); |
63 | Status InitializeFromGraph(const GraphDef& graph, bool ignore_control_edges); |
64 | Status InitializeFromGraph(const GraphDef& graph); |
65 | |
66 | bool is_initialized() const { return graph_ != nullptr; } |
67 | int num_nodes() const { return num_nodes_; } |
68 | const GraphDef* graph() const { return graph_; } |
69 | |
70 | // Returns true iff the node exists in the underlying graph. |
71 | bool HasNode(absl::string_view node_name) const; |
72 | |
73 | // Finds a node by name or returns `nullptr` if it's not in the graph. |
74 | const NodeDef* GetNode(absl::string_view node_name) const; |
75 | // Returns a node corresponding to the given node index. |
76 | const NodeDef* GetNode(int node_idx) const; |
77 | |
78 | // Returns a node index for the given node name, if the name exists in the |
79 | // underlying graph. Otherwise returns empty optional. |
80 | const absl::optional<int> GetNodeIndex(absl::string_view node_name) const; |
81 | // Returns a node index for the given node, if the node belongs to the |
82 | // underlying graph. Otherwise returns empty optional. |
83 | const absl::optional<int> GetNodeIndex(const NodeDef& node) const; |
84 | |
85 | // Returns all the node indexes that are in the direct fanin of the given |
86 | // node. If the `node_idx` is outside of [0, num_nodes_) returns empty vector. |
87 | const absl::InlinedVector<int, 4>& GetFanin(int node_idx) const; |
88 | // Returns all the node indexes that are in the direct fanout of the given |
89 | // node. If the `node_idx` is outside of [0, num_nodes_) returns empty vector. |
90 | const absl::InlinedVector<int, 2>& GetFanout(int node_idx) const; |
91 | |
92 | private: |
93 | // If true, all invalid edges and inputs (srd, dst or input node not found in |
94 | // a graph) will be skipped, otherwise initialization will fail with error. |
95 | bool skip_invalid_edges_ = false; |
96 | |
97 | // WARN: `graph_` must outlive this object and graph nodes must not be |
98 | // destructed, because node names captured with absl::string_view. |
99 | const GraphDef* graph_ = nullptr; // do not own |
100 | int num_nodes_ = 0; |
101 | std::vector<absl::string_view> index_to_node_name_; |
102 | absl::flat_hash_map<absl::string_view, int> node_name_to_index_; |
103 | std::vector<absl::InlinedVector<int, 4>> fanins_; // node_idx->input nodes |
104 | std::vector<absl::InlinedVector<int, 2>> fanouts_; // node_idx->output nodes |
105 | |
106 | // We need a valid reference to return from GetFanin/GetFanout if the |
107 | // `node_idx` argument is outside of the [0, num_nodes_) range. |
108 | absl::InlinedVector<int, 4> empty_fanin_; |
109 | absl::InlinedVector<int, 2> empty_fanout_; |
110 | }; |
111 | |
112 | } // end namespace grappler |
113 | } // end namespace tensorflow |
114 | |
115 | #endif // TENSORFLOW_CORE_GRAPPLER_GRAPH_TOPOLOGY_VIEW_H_ |
116 | |