1/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2
3Licensed under the Apache License, Version 2.0 (the "License");
4you may not use this file except in compliance with the License.
5You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9Unless required by applicable law or agreed to in writing, software
10distributed under the License is distributed on an "AS IS" BASIS,
11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12See the License for the specific language governing permissions and
13limitations under the License.
14==============================================================================*/
15
#include "tensorflow/core/util/equal_graph_def.h"

#include <map>
#include <set>
#include <unordered_map>
#include <unordered_set>

#include "tensorflow/core/framework/attr_value.pb.h"
#include "tensorflow/core/framework/attr_value_util.h"
#include "tensorflow/core/framework/graph.pb.h"
#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/framework/node_def_util.h"
#include "tensorflow/core/lib/hash/hash.h"
#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/protobuf.h"
29
30namespace tensorflow {
31
32bool EqualGraphDef(const GraphDef& actual, const GraphDef& expected,
33 string* diff, const EqualGraphDefOptions& options) {
34 // Intentionally do not check that versions match so that this routine can
35 // be used for less brittle golden file tests.
36 return EqualRepeatedNodeDef(actual.node(), expected.node(), diff, options);
37}
38
39uint64 GraphDefHash(const GraphDef& gdef, const EqualGraphDefOptions& options) {
40 return RepeatedNodeDefHash(gdef.node(), options);
41}
42
43bool EqualRepeatedNodeDef(const protobuf::RepeatedPtrField<NodeDef>& actual,
44 const protobuf::RepeatedPtrField<NodeDef>& expected,
45 string* diff, const EqualGraphDefOptions& options) {
46 std::unordered_map<string, const NodeDef*> actual_index;
47 for (const NodeDef& node : actual) {
48 actual_index[node.name()] = &node;
49 }
50
51 for (const NodeDef& expected_node : expected) {
52 auto actual_iter = actual_index.find(expected_node.name());
53 if (actual_iter == actual_index.end()) {
54 if (diff != nullptr) {
55 *diff = strings::StrCat("Did not find expected node '",
56 SummarizeNodeDef(expected_node), "'");
57 }
58 return false;
59 }
60
61 if (!EqualNodeDef(*actual_iter->second, expected_node, diff, options)) {
62 return false;
63 }
64
65 actual_index.erase(actual_iter);
66 }
67
68 if (!actual_index.empty()) {
69 if (diff != nullptr) {
70 *diff =
71 strings::StrCat("Found unexpected node '",
72 SummarizeNodeDef(*actual_index.begin()->second), "'");
73 }
74 return false;
75 }
76
77 return true;
78}
79
80uint64 RepeatedNodeDefHash(const protobuf::RepeatedPtrField<NodeDef>& ndefs,
81 const EqualGraphDefOptions& options) {
82 uint64 h = 0xDECAFCAFFE;
83 // Insert NodeDefs into map to deterministically sort by name
84 std::map<string, const NodeDef*> nodes;
85 for (const NodeDef& node : ndefs) {
86 nodes[node.name()] = &node;
87 }
88 for (const auto& pair : nodes) {
89 h = Hash64(pair.first.data(), pair.first.size(), h);
90 h = Hash64Combine(NodeDefHash(*pair.second, options), h);
91 }
92 return h;
93}
94
95namespace {
96
97string JoinStringField(const protobuf::RepeatedPtrField<string>& f) {
98 string ret;
99 for (int i = 0; i < f.size(); ++i) {
100 if (i > 0) strings::StrAppend(&ret, ", ");
101 strings::StrAppend(&ret, f.Get(i));
102 }
103 return ret;
104}
105
106} // namespace
107
108bool EqualNodeDef(const NodeDef& actual, const NodeDef& expected, string* diff,
109 const EqualGraphDefOptions& options) {
110 if (actual.name() != expected.name()) {
111 if (diff != nullptr) {
112 *diff = strings::StrCat("Actual node name '", actual.name(),
113 "' is not expected '", expected.name(), "'");
114 }
115 return false;
116 }
117
118 if (actual.op() != expected.op()) {
119 if (diff != nullptr) {
120 *diff = strings::StrCat("Node named '", actual.name(), "' has op '",
121 actual.op(), "' that is not expected '",
122 expected.op(), "'");
123 }
124 return false;
125 }
126
127 if (actual.device() != expected.device()) {
128 if (diff != nullptr) {
129 *diff = strings::StrCat("Node named '", actual.name(), "' has device '",
130 actual.device(), "' that is not expected '",
131 expected.device(), "'");
132 }
133 return false;
134 }
135
136 if (actual.input_size() != expected.input_size()) {
137 if (diff != nullptr) {
138 *diff = strings::StrCat("Node named '", actual.name(), "' has inputs '",
139 JoinStringField(actual.input()),
140 "' that don't match expected '",
141 JoinStringField(expected.input()), "'");
142 }
143 return false;
144 }
145
146 int first_control_input = actual.input_size();
147 for (int i = 0; i < actual.input_size(); ++i) {
148 if (absl::StartsWith(actual.input(i), "^")) {
149 first_control_input = i;
150 break;
151 }
152 // Special case for inputs: "tensor" is equivalent to "tensor:0"
153 if (actual.input(i) != expected.input(i) &&
154 actual.input(i) != strings::StrCat(expected.input(i), ":0") &&
155 strings::StrCat(actual.input(i), ":0") != expected.input(i)) {
156 if (diff != nullptr) {
157 *diff = strings::StrCat("Node named '", actual.name(), "' has input ",
158 i, " '", actual.input(i),
159 "' that doesn't match expected '",
160 expected.input(i), "'");
161 }
162 return false;
163 }
164 }
165
166 std::unordered_set<string> actual_control;
167 std::unordered_set<string> expected_control;
168 for (int i = first_control_input; i < actual.input_size(); ++i) {
169 actual_control.insert(actual.input(i));
170 expected_control.insert(expected.input(i));
171 }
172 for (const auto& e : expected_control) {
173 if (actual_control.erase(e) == 0) {
174 if (diff != nullptr) {
175 *diff = strings::StrCat("Node named '", actual.name(),
176 "' missing expected control input '", e, "'");
177 }
178 return false;
179 }
180 }
181 if (!actual_control.empty()) {
182 if (diff != nullptr) {
183 *diff = strings::StrCat("Node named '", actual.name(),
184 "' has unexpected control input '",
185 *actual_control.begin(), "'");
186 }
187 return false;
188 }
189
190 std::unordered_set<string> actual_attr;
191 for (const auto& a : actual.attr()) {
192 if (options.ignore_internal_attrs && !a.first.empty() &&
193 a.first[0] == '_') {
194 continue;
195 }
196 actual_attr.insert(a.first);
197 }
198 for (const auto& e : expected.attr()) {
199 if (options.ignore_internal_attrs && !e.first.empty() &&
200 e.first[0] == '_') {
201 continue;
202 }
203
204 if (actual_attr.erase(e.first) == 0) {
205 if (diff != nullptr) {
206 *diff = strings::StrCat("Node named '", actual.name(),
207 "' missing expected attr '", e.first,
208 "' with value: ", SummarizeAttrValue(e.second));
209 }
210 return false;
211 }
212 auto iter = actual.attr().find(e.first);
213 if (!AreAttrValuesEqual(e.second, iter->second)) {
214 if (diff != nullptr) {
215 *diff = strings::StrCat(
216 "Node named '", actual.name(), "' has attr '", e.first,
217 "' with value: ", SummarizeAttrValue(iter->second),
218 " that does not match expected: ", SummarizeAttrValue(e.second));
219 }
220 return false;
221 }
222 }
223 if (!actual_attr.empty()) {
224 if (diff != nullptr) {
225 *diff = strings::StrCat(
226 "Node named '", actual.name(), "' has unexpected attr '",
227 *actual_attr.begin(), "' with value: ",
228 SummarizeAttrValue(actual.attr().find(*actual_attr.begin())->second));
229 }
230 return false;
231 }
232
233 return true;
234}
235
236uint64 NodeDefHash(const NodeDef& ndef, const EqualGraphDefOptions& options) {
237 uint64 h = Hash64(ndef.name());
238 h = Hash64(ndef.op().data(), ndef.op().size(), h);
239 h = Hash64(ndef.device().data(), ndef.device().size(), h);
240
241 // Normal inputs. Order important.
242 int first_control_input = ndef.input_size();
243 for (int i = 0; i < ndef.input_size(); ++i) {
244 if (absl::StartsWith(ndef.input(i), "^")) {
245 first_control_input = i;
246 break;
247 }
248 h = Hash64(ndef.input(i).data(), ndef.input(i).size(), h);
249 }
250
251 // Control inputs. Order irrelevant.
252 std::set<string> ndef_control;
253 for (int i = first_control_input; i < ndef.input_size(); ++i) {
254 ndef_control.insert(ndef.input(i));
255 }
256 for (const string& s : ndef_control) {
257 h = Hash64(s.data(), s.size(), h);
258 }
259
260 // Attributes
261 std::map<string, AttrValue> ndef_attr;
262 for (const auto& a : ndef.attr()) {
263 if (options.ignore_internal_attrs && !a.first.empty() &&
264 a.first[0] == '_') {
265 continue;
266 }
267 ndef_attr[a.first] = a.second;
268 }
269 for (const auto& a : ndef_attr) {
270 h = Hash64(a.first.data(), a.first.size(), h);
271 h = Hash64Combine(AttrValueHash(a.second), h);
272 }
273
274 return h;
275}
276
277} // namespace tensorflow
278