1/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2
3Licensed under the Apache License, Version 2.0 (the "License");
4you may not use this file except in compliance with the License.
5You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9Unless required by applicable law or agreed to in writing, software
10distributed under the License is distributed on an "AS IS" BASIS,
11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12See the License for the specific language governing permissions and
13limitations under the License.
14==============================================================================*/
15
16#include "tensorflow/core/grappler/graph_topology_view.h"
17
18#include <algorithm>
19
20#include "absl/container/flat_hash_map.h"
21#include "absl/container/inlined_vector.h"
22#include "absl/strings/string_view.h"
23#include "absl/types/optional.h"
24#include "absl/types/span.h"
25#include "tensorflow/core/framework/graph.pb.h"
26#include "tensorflow/core/framework/node_def.pb.h"
27
28namespace tensorflow {
29namespace grappler {
30
31namespace {
32
// Sorts the container behind `v` in ascending order and then erases repeated
// elements, so that each distinct value appears exactly once.
template <typename T>
inline void SortAndRemoveDuplicates(T* v) {
  auto first = v->begin();
  auto last = v->end();
  std::sort(first, last);
  // std::unique compacts the sorted range and returns the new logical end;
  // everything from there to `last` is leftover and gets erased.
  v->erase(std::unique(first, last), last);
}
38
39} // namespace
40
41Status GraphTopologyView::InitializeFromGraph(
42 const GraphDef& graph,
43 const absl::Span<const GraphView::Edge> ephemeral_edges,
44 bool ignore_control_edges) {
45 if (graph_ != nullptr) {
46 return errors::InvalidArgument("GraphTopologyView is already initialized.");
47 }
48
49 graph_ = &graph;
50 num_nodes_ = graph.node_size();
51 index_to_node_name_.resize(num_nodes_);
52 node_name_to_index_.rehash(num_nodes_);
53 fanins_.resize(num_nodes_);
54 fanouts_.resize(num_nodes_);
55
56 // Build map from name to index and vice versa.
57 for (int node_idx = 0; node_idx < num_nodes_; ++node_idx) {
58 const NodeDef& node = graph.node(node_idx);
59 node_name_to_index_.emplace(node.name(), node_idx);
60 index_to_node_name_.emplace_back(node.name());
61 }
62
63 // 1. Add ephemeral edges to the adjacency lists.
64 for (const GraphView::Edge& edge : ephemeral_edges) {
65 const auto src = node_name_to_index_.find(edge.src.node->name());
66 const bool valid_src = src != node_name_to_index_.end();
67 if (!valid_src) {
68 const string error_message =
69 absl::StrCat("Non-existent src node: ", edge.src.node->name());
70 if (skip_invalid_edges_) {
71 VLOG(0) << "Skip error: " << error_message;
72 } else {
73 return errors::InvalidArgument(error_message);
74 }
75 }
76
77 const auto dst = node_name_to_index_.find(edge.dst.node->name());
78 const bool valid_dst = dst != node_name_to_index_.end();
79
80 if (!valid_dst) {
81 const string error_message =
82 absl::StrCat("Non-existent dst node: ", edge.dst.node->name());
83 if (skip_invalid_edges_) {
84 VLOG(0) << "Skip error: " << error_message;
85 } else {
86 return errors::InvalidArgument(error_message);
87 }
88 }
89
90 if (valid_dst && valid_src) {
91 const int src_idx = src->second;
92 const int dst_idx = dst->second;
93 if (ignore_control_edges && (src_idx < 0 || dst_idx < 0)) {
94 continue;
95 }
96 fanins_[dst_idx].push_back(src_idx);
97 fanouts_[src_idx].push_back(dst_idx);
98 }
99 }
100
101 // 2. Add graph edges to the adjacency lists.
102 for (int node_idx = 0; node_idx < num_nodes_; ++node_idx) {
103 const NodeDef& node = graph.node(node_idx);
104 fanins_[node_idx].reserve(node.input_size());
105
106 for (const string& input : node.input()) {
107 TensorId tensor = ParseTensorName(input);
108 if (ignore_control_edges && IsTensorIdControl(tensor)) {
109 continue;
110 }
111 const auto it = node_name_to_index_.find(tensor.node());
112 const bool valid_input = it != node_name_to_index_.end();
113
114 if (!valid_input) {
115 const string error_message = absl::StrCat("Non-existent input ", input,
116 " in node ", node.name());
117 if (skip_invalid_edges_) {
118 VLOG(3) << "Skip error: " << error_message;
119 } else {
120 return errors::InvalidArgument(error_message);
121 }
122 }
123
124 if (valid_input) {
125 const int input_idx = it->second;
126 fanins_[node_idx].push_back(input_idx);
127 fanouts_[input_idx].push_back(node_idx);
128 }
129 }
130
131 // Dedup the input list while it's still hot in cache.
132 SortAndRemoveDuplicates(&fanins_[node_idx]);
133 }
134
135 // Dedup outputs for all the graph nodes.
136 for (int node_idx = 0; node_idx < num_nodes_; ++node_idx) {
137 SortAndRemoveDuplicates(&fanouts_[node_idx]);
138 }
139
140 return OkStatus();
141}
142
143Status GraphTopologyView::InitializeFromGraph(
144 const GraphDef& graph,
145 const absl::Span<const GraphView::Edge> ephemeral_edges) {
146 return InitializeFromGraph(graph, ephemeral_edges,
147 /*ignore_control_edges=*/false);
148}
149
150Status GraphTopologyView::InitializeFromGraph(const GraphDef& graph,
151 bool ignore_control_edges) {
152 return InitializeFromGraph(graph, absl::Span<GraphView::Edge>(),
153 ignore_control_edges);
154}
155
156Status GraphTopologyView::InitializeFromGraph(const GraphDef& graph) {
157 return InitializeFromGraph(graph, absl::Span<GraphView::Edge>(),
158 /*ignore_control_edges*/ false);
159}
160
161bool GraphTopologyView::HasNode(const absl::string_view node_name) const {
162 DCHECK(is_initialized()) << "GraphTopologyView is not initialized";
163 const auto it = node_name_to_index_.find(node_name);
164 return it != node_name_to_index_.end();
165}
166
167const NodeDef* GraphTopologyView::GetNode(
168 const absl::string_view node_name) const {
169 DCHECK(is_initialized()) << "GraphTopologyView is not initialized";
170 const auto it = node_name_to_index_.find(node_name);
171 return it == node_name_to_index_.end() ? nullptr : &graph_->node(it->second);
172}
173
174const NodeDef* GraphTopologyView::GetNode(int node_idx) const {
175 DCHECK(is_initialized()) << "GraphTopologyView is not initialized";
176 DCHECK(node_idx >= 0 && node_idx < num_nodes_) << "node_idx is out of range";
177 return &graph_->node(node_idx);
178}
179
180const absl::optional<int> GraphTopologyView::GetNodeIndex(
181 const absl::string_view node_name) const {
182 DCHECK(is_initialized()) << "GraphTopologyView is not initialized";
183 const auto it = node_name_to_index_.find(node_name);
184 DCHECK(it != node_name_to_index_.end()) << "Node doesn't exist in a graph";
185 return it == node_name_to_index_.end() ? absl::nullopt
186 : absl::make_optional(it->second);
187}
188
189const absl::optional<int> GraphTopologyView::GetNodeIndex(
190 const NodeDef& node) const {
191 return GetNodeIndex(node.name());
192}
193
194const absl::InlinedVector<int, 4>& GraphTopologyView::GetFanin(
195 int node_idx) const {
196 DCHECK(is_initialized()) << "GraphTopologyView is not initialized";
197 const bool is_valid_node_idx = node_idx >= 0 && node_idx < num_nodes_;
198 DCHECK(is_valid_node_idx) << "node_idx is out of range";
199 return is_valid_node_idx ? fanins_[node_idx] : empty_fanin_;
200}
201
202const absl::InlinedVector<int, 2>& GraphTopologyView::GetFanout(
203 int node_idx) const {
204 DCHECK(is_initialized()) << "GraphTopologyView is not initialized";
205 const bool is_valid_node_idx = node_idx >= 0 && node_idx < num_nodes_;
206 DCHECK(is_valid_node_idx) << "node_idx is out of range";
207 return is_valid_node_idx ? fanouts_[node_idx] : empty_fanout_;
208}
209
210} // end namespace grappler
211} // end namespace tensorflow
212