1 | /* Copyright 2018 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | |
16 | #include "tensorflow/core/grappler/graph_topology_view.h" |
17 | |
18 | #include <algorithm> |
19 | |
20 | #include "absl/container/flat_hash_map.h" |
21 | #include "absl/container/inlined_vector.h" |
22 | #include "absl/strings/string_view.h" |
23 | #include "absl/types/optional.h" |
24 | #include "absl/types/span.h" |
25 | #include "tensorflow/core/framework/graph.pb.h" |
26 | #include "tensorflow/core/framework/node_def.pb.h" |
27 | |
28 | namespace tensorflow { |
29 | namespace grappler { |
30 | |
namespace {

// Sorts `*values` in ascending order, then erases the duplicates. After sorting
// all equal elements are adjacent, so std::unique removes every repeat. The
// container is shrunk in place.
//
// `T` must be a container with random-access iterators that supports
// `erase(first, last)`. Its elements must be comparable with `<` and `==`.
template <typename T>
inline void SortAndRemoveDuplicates(T* values) {
  auto first = values->begin();
  auto last = values->end();
  std::sort(first, last);
  // Sorting and std::unique do not reallocate, so `last` is still end().
  auto unique_end = std::unique(first, last);
  values->erase(unique_end, last);
}

}  // namespace
40 | |
41 | Status GraphTopologyView::InitializeFromGraph( |
42 | const GraphDef& graph, |
43 | const absl::Span<const GraphView::Edge> ephemeral_edges, |
44 | bool ignore_control_edges) { |
45 | if (graph_ != nullptr) { |
46 | return errors::InvalidArgument("GraphTopologyView is already initialized." ); |
47 | } |
48 | |
49 | graph_ = &graph; |
50 | num_nodes_ = graph.node_size(); |
51 | index_to_node_name_.resize(num_nodes_); |
52 | node_name_to_index_.rehash(num_nodes_); |
53 | fanins_.resize(num_nodes_); |
54 | fanouts_.resize(num_nodes_); |
55 | |
56 | // Build map from name to index and vice versa. |
57 | for (int node_idx = 0; node_idx < num_nodes_; ++node_idx) { |
58 | const NodeDef& node = graph.node(node_idx); |
59 | node_name_to_index_.emplace(node.name(), node_idx); |
60 | index_to_node_name_.emplace_back(node.name()); |
61 | } |
62 | |
63 | // 1. Add ephemeral edges to the adjacency lists. |
64 | for (const GraphView::Edge& edge : ephemeral_edges) { |
65 | const auto src = node_name_to_index_.find(edge.src.node->name()); |
66 | const bool valid_src = src != node_name_to_index_.end(); |
67 | if (!valid_src) { |
68 | const string error_message = |
69 | absl::StrCat("Non-existent src node: " , edge.src.node->name()); |
70 | if (skip_invalid_edges_) { |
71 | VLOG(0) << "Skip error: " << error_message; |
72 | } else { |
73 | return errors::InvalidArgument(error_message); |
74 | } |
75 | } |
76 | |
77 | const auto dst = node_name_to_index_.find(edge.dst.node->name()); |
78 | const bool valid_dst = dst != node_name_to_index_.end(); |
79 | |
80 | if (!valid_dst) { |
81 | const string error_message = |
82 | absl::StrCat("Non-existent dst node: " , edge.dst.node->name()); |
83 | if (skip_invalid_edges_) { |
84 | VLOG(0) << "Skip error: " << error_message; |
85 | } else { |
86 | return errors::InvalidArgument(error_message); |
87 | } |
88 | } |
89 | |
90 | if (valid_dst && valid_src) { |
91 | const int src_idx = src->second; |
92 | const int dst_idx = dst->second; |
93 | if (ignore_control_edges && (src_idx < 0 || dst_idx < 0)) { |
94 | continue; |
95 | } |
96 | fanins_[dst_idx].push_back(src_idx); |
97 | fanouts_[src_idx].push_back(dst_idx); |
98 | } |
99 | } |
100 | |
101 | // 2. Add graph edges to the adjacency lists. |
102 | for (int node_idx = 0; node_idx < num_nodes_; ++node_idx) { |
103 | const NodeDef& node = graph.node(node_idx); |
104 | fanins_[node_idx].reserve(node.input_size()); |
105 | |
106 | for (const string& input : node.input()) { |
107 | TensorId tensor = ParseTensorName(input); |
108 | if (ignore_control_edges && IsTensorIdControl(tensor)) { |
109 | continue; |
110 | } |
111 | const auto it = node_name_to_index_.find(tensor.node()); |
112 | const bool valid_input = it != node_name_to_index_.end(); |
113 | |
114 | if (!valid_input) { |
115 | const string error_message = absl::StrCat("Non-existent input " , input, |
116 | " in node " , node.name()); |
117 | if (skip_invalid_edges_) { |
118 | VLOG(3) << "Skip error: " << error_message; |
119 | } else { |
120 | return errors::InvalidArgument(error_message); |
121 | } |
122 | } |
123 | |
124 | if (valid_input) { |
125 | const int input_idx = it->second; |
126 | fanins_[node_idx].push_back(input_idx); |
127 | fanouts_[input_idx].push_back(node_idx); |
128 | } |
129 | } |
130 | |
131 | // Dedup the input list while it's still hot in cache. |
132 | SortAndRemoveDuplicates(&fanins_[node_idx]); |
133 | } |
134 | |
135 | // Dedup outputs for all the graph nodes. |
136 | for (int node_idx = 0; node_idx < num_nodes_; ++node_idx) { |
137 | SortAndRemoveDuplicates(&fanouts_[node_idx]); |
138 | } |
139 | |
140 | return OkStatus(); |
141 | } |
142 | |
143 | Status GraphTopologyView::InitializeFromGraph( |
144 | const GraphDef& graph, |
145 | const absl::Span<const GraphView::Edge> ephemeral_edges) { |
146 | return InitializeFromGraph(graph, ephemeral_edges, |
147 | /*ignore_control_edges=*/false); |
148 | } |
149 | |
150 | Status GraphTopologyView::InitializeFromGraph(const GraphDef& graph, |
151 | bool ignore_control_edges) { |
152 | return InitializeFromGraph(graph, absl::Span<GraphView::Edge>(), |
153 | ignore_control_edges); |
154 | } |
155 | |
156 | Status GraphTopologyView::InitializeFromGraph(const GraphDef& graph) { |
157 | return InitializeFromGraph(graph, absl::Span<GraphView::Edge>(), |
158 | /*ignore_control_edges*/ false); |
159 | } |
160 | |
161 | bool GraphTopologyView::HasNode(const absl::string_view node_name) const { |
162 | DCHECK(is_initialized()) << "GraphTopologyView is not initialized" ; |
163 | const auto it = node_name_to_index_.find(node_name); |
164 | return it != node_name_to_index_.end(); |
165 | } |
166 | |
167 | const NodeDef* GraphTopologyView::GetNode( |
168 | const absl::string_view node_name) const { |
169 | DCHECK(is_initialized()) << "GraphTopologyView is not initialized" ; |
170 | const auto it = node_name_to_index_.find(node_name); |
171 | return it == node_name_to_index_.end() ? nullptr : &graph_->node(it->second); |
172 | } |
173 | |
174 | const NodeDef* GraphTopologyView::GetNode(int node_idx) const { |
175 | DCHECK(is_initialized()) << "GraphTopologyView is not initialized" ; |
176 | DCHECK(node_idx >= 0 && node_idx < num_nodes_) << "node_idx is out of range" ; |
177 | return &graph_->node(node_idx); |
178 | } |
179 | |
180 | const absl::optional<int> GraphTopologyView::GetNodeIndex( |
181 | const absl::string_view node_name) const { |
182 | DCHECK(is_initialized()) << "GraphTopologyView is not initialized" ; |
183 | const auto it = node_name_to_index_.find(node_name); |
184 | DCHECK(it != node_name_to_index_.end()) << "Node doesn't exist in a graph" ; |
185 | return it == node_name_to_index_.end() ? absl::nullopt |
186 | : absl::make_optional(it->second); |
187 | } |
188 | |
189 | const absl::optional<int> GraphTopologyView::GetNodeIndex( |
190 | const NodeDef& node) const { |
191 | return GetNodeIndex(node.name()); |
192 | } |
193 | |
194 | const absl::InlinedVector<int, 4>& GraphTopologyView::GetFanin( |
195 | int node_idx) const { |
196 | DCHECK(is_initialized()) << "GraphTopologyView is not initialized" ; |
197 | const bool is_valid_node_idx = node_idx >= 0 && node_idx < num_nodes_; |
198 | DCHECK(is_valid_node_idx) << "node_idx is out of range" ; |
199 | return is_valid_node_idx ? fanins_[node_idx] : empty_fanin_; |
200 | } |
201 | |
202 | const absl::InlinedVector<int, 2>& GraphTopologyView::GetFanout( |
203 | int node_idx) const { |
204 | DCHECK(is_initialized()) << "GraphTopologyView is not initialized" ; |
205 | const bool is_valid_node_idx = node_idx >= 0 && node_idx < num_nodes_; |
206 | DCHECK(is_valid_node_idx) << "node_idx is out of range" ; |
207 | return is_valid_node_idx ? fanouts_[node_idx] : empty_fanout_; |
208 | } |
209 | |
210 | } // end namespace grappler |
211 | } // end namespace tensorflow |
212 | |