1 | /* Copyright 2015 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | |
#include "tensorflow/core/util/equal_graph_def.h"

#include <map>
#include <set>
#include <unordered_map>
#include <unordered_set>

#include "absl/strings/match.h"
#include "absl/strings/string_view.h"
#include "tensorflow/core/framework/attr_value.pb.h"
#include "tensorflow/core/framework/attr_value_util.h"
#include "tensorflow/core/framework/graph.pb.h"
#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/framework/node_def_util.h"
#include "tensorflow/core/lib/hash/hash.h"
#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/protobuf.h"
29 | |
30 | namespace tensorflow { |
31 | |
32 | bool EqualGraphDef(const GraphDef& actual, const GraphDef& expected, |
33 | string* diff, const EqualGraphDefOptions& options) { |
34 | // Intentionally do not check that versions match so that this routine can |
35 | // be used for less brittle golden file tests. |
36 | return EqualRepeatedNodeDef(actual.node(), expected.node(), diff, options); |
37 | } |
38 | |
39 | uint64 GraphDefHash(const GraphDef& gdef, const EqualGraphDefOptions& options) { |
40 | return RepeatedNodeDefHash(gdef.node(), options); |
41 | } |
42 | |
43 | bool EqualRepeatedNodeDef(const protobuf::RepeatedPtrField<NodeDef>& actual, |
44 | const protobuf::RepeatedPtrField<NodeDef>& expected, |
45 | string* diff, const EqualGraphDefOptions& options) { |
46 | std::unordered_map<string, const NodeDef*> actual_index; |
47 | for (const NodeDef& node : actual) { |
48 | actual_index[node.name()] = &node; |
49 | } |
50 | |
51 | for (const NodeDef& expected_node : expected) { |
52 | auto actual_iter = actual_index.find(expected_node.name()); |
53 | if (actual_iter == actual_index.end()) { |
54 | if (diff != nullptr) { |
55 | *diff = strings::StrCat("Did not find expected node '" , |
56 | SummarizeNodeDef(expected_node), "'" ); |
57 | } |
58 | return false; |
59 | } |
60 | |
61 | if (!EqualNodeDef(*actual_iter->second, expected_node, diff, options)) { |
62 | return false; |
63 | } |
64 | |
65 | actual_index.erase(actual_iter); |
66 | } |
67 | |
68 | if (!actual_index.empty()) { |
69 | if (diff != nullptr) { |
70 | *diff = |
71 | strings::StrCat("Found unexpected node '" , |
72 | SummarizeNodeDef(*actual_index.begin()->second), "'" ); |
73 | } |
74 | return false; |
75 | } |
76 | |
77 | return true; |
78 | } |
79 | |
80 | uint64 RepeatedNodeDefHash(const protobuf::RepeatedPtrField<NodeDef>& ndefs, |
81 | const EqualGraphDefOptions& options) { |
82 | uint64 h = 0xDECAFCAFFE; |
83 | // Insert NodeDefs into map to deterministically sort by name |
84 | std::map<string, const NodeDef*> nodes; |
85 | for (const NodeDef& node : ndefs) { |
86 | nodes[node.name()] = &node; |
87 | } |
88 | for (const auto& pair : nodes) { |
89 | h = Hash64(pair.first.data(), pair.first.size(), h); |
90 | h = Hash64Combine(NodeDefHash(*pair.second, options), h); |
91 | } |
92 | return h; |
93 | } |
94 | |
95 | namespace { |
96 | |
97 | string JoinStringField(const protobuf::RepeatedPtrField<string>& f) { |
98 | string ret; |
99 | for (int i = 0; i < f.size(); ++i) { |
100 | if (i > 0) strings::StrAppend(&ret, ", " ); |
101 | strings::StrAppend(&ret, f.Get(i)); |
102 | } |
103 | return ret; |
104 | } |
105 | |
106 | } // namespace |
107 | |
108 | bool EqualNodeDef(const NodeDef& actual, const NodeDef& expected, string* diff, |
109 | const EqualGraphDefOptions& options) { |
110 | if (actual.name() != expected.name()) { |
111 | if (diff != nullptr) { |
112 | *diff = strings::StrCat("Actual node name '" , actual.name(), |
113 | "' is not expected '" , expected.name(), "'" ); |
114 | } |
115 | return false; |
116 | } |
117 | |
118 | if (actual.op() != expected.op()) { |
119 | if (diff != nullptr) { |
120 | *diff = strings::StrCat("Node named '" , actual.name(), "' has op '" , |
121 | actual.op(), "' that is not expected '" , |
122 | expected.op(), "'" ); |
123 | } |
124 | return false; |
125 | } |
126 | |
127 | if (actual.device() != expected.device()) { |
128 | if (diff != nullptr) { |
129 | *diff = strings::StrCat("Node named '" , actual.name(), "' has device '" , |
130 | actual.device(), "' that is not expected '" , |
131 | expected.device(), "'" ); |
132 | } |
133 | return false; |
134 | } |
135 | |
136 | if (actual.input_size() != expected.input_size()) { |
137 | if (diff != nullptr) { |
138 | *diff = strings::StrCat("Node named '" , actual.name(), "' has inputs '" , |
139 | JoinStringField(actual.input()), |
140 | "' that don't match expected '" , |
141 | JoinStringField(expected.input()), "'" ); |
142 | } |
143 | return false; |
144 | } |
145 | |
146 | int first_control_input = actual.input_size(); |
147 | for (int i = 0; i < actual.input_size(); ++i) { |
148 | if (absl::StartsWith(actual.input(i), "^" )) { |
149 | first_control_input = i; |
150 | break; |
151 | } |
152 | // Special case for inputs: "tensor" is equivalent to "tensor:0" |
153 | if (actual.input(i) != expected.input(i) && |
154 | actual.input(i) != strings::StrCat(expected.input(i), ":0" ) && |
155 | strings::StrCat(actual.input(i), ":0" ) != expected.input(i)) { |
156 | if (diff != nullptr) { |
157 | *diff = strings::StrCat("Node named '" , actual.name(), "' has input " , |
158 | i, " '" , actual.input(i), |
159 | "' that doesn't match expected '" , |
160 | expected.input(i), "'" ); |
161 | } |
162 | return false; |
163 | } |
164 | } |
165 | |
166 | std::unordered_set<string> actual_control; |
167 | std::unordered_set<string> expected_control; |
168 | for (int i = first_control_input; i < actual.input_size(); ++i) { |
169 | actual_control.insert(actual.input(i)); |
170 | expected_control.insert(expected.input(i)); |
171 | } |
172 | for (const auto& e : expected_control) { |
173 | if (actual_control.erase(e) == 0) { |
174 | if (diff != nullptr) { |
175 | *diff = strings::StrCat("Node named '" , actual.name(), |
176 | "' missing expected control input '" , e, "'" ); |
177 | } |
178 | return false; |
179 | } |
180 | } |
181 | if (!actual_control.empty()) { |
182 | if (diff != nullptr) { |
183 | *diff = strings::StrCat("Node named '" , actual.name(), |
184 | "' has unexpected control input '" , |
185 | *actual_control.begin(), "'" ); |
186 | } |
187 | return false; |
188 | } |
189 | |
190 | std::unordered_set<string> actual_attr; |
191 | for (const auto& a : actual.attr()) { |
192 | if (options.ignore_internal_attrs && !a.first.empty() && |
193 | a.first[0] == '_') { |
194 | continue; |
195 | } |
196 | actual_attr.insert(a.first); |
197 | } |
198 | for (const auto& e : expected.attr()) { |
199 | if (options.ignore_internal_attrs && !e.first.empty() && |
200 | e.first[0] == '_') { |
201 | continue; |
202 | } |
203 | |
204 | if (actual_attr.erase(e.first) == 0) { |
205 | if (diff != nullptr) { |
206 | *diff = strings::StrCat("Node named '" , actual.name(), |
207 | "' missing expected attr '" , e.first, |
208 | "' with value: " , SummarizeAttrValue(e.second)); |
209 | } |
210 | return false; |
211 | } |
212 | auto iter = actual.attr().find(e.first); |
213 | if (!AreAttrValuesEqual(e.second, iter->second)) { |
214 | if (diff != nullptr) { |
215 | *diff = strings::StrCat( |
216 | "Node named '" , actual.name(), "' has attr '" , e.first, |
217 | "' with value: " , SummarizeAttrValue(iter->second), |
218 | " that does not match expected: " , SummarizeAttrValue(e.second)); |
219 | } |
220 | return false; |
221 | } |
222 | } |
223 | if (!actual_attr.empty()) { |
224 | if (diff != nullptr) { |
225 | *diff = strings::StrCat( |
226 | "Node named '" , actual.name(), "' has unexpected attr '" , |
227 | *actual_attr.begin(), "' with value: " , |
228 | SummarizeAttrValue(actual.attr().find(*actual_attr.begin())->second)); |
229 | } |
230 | return false; |
231 | } |
232 | |
233 | return true; |
234 | } |
235 | |
236 | uint64 NodeDefHash(const NodeDef& ndef, const EqualGraphDefOptions& options) { |
237 | uint64 h = Hash64(ndef.name()); |
238 | h = Hash64(ndef.op().data(), ndef.op().size(), h); |
239 | h = Hash64(ndef.device().data(), ndef.device().size(), h); |
240 | |
241 | // Normal inputs. Order important. |
242 | int first_control_input = ndef.input_size(); |
243 | for (int i = 0; i < ndef.input_size(); ++i) { |
244 | if (absl::StartsWith(ndef.input(i), "^" )) { |
245 | first_control_input = i; |
246 | break; |
247 | } |
248 | h = Hash64(ndef.input(i).data(), ndef.input(i).size(), h); |
249 | } |
250 | |
251 | // Control inputs. Order irrelevant. |
252 | std::set<string> ndef_control; |
253 | for (int i = first_control_input; i < ndef.input_size(); ++i) { |
254 | ndef_control.insert(ndef.input(i)); |
255 | } |
256 | for (const string& s : ndef_control) { |
257 | h = Hash64(s.data(), s.size(), h); |
258 | } |
259 | |
260 | // Attributes |
261 | std::map<string, AttrValue> ndef_attr; |
262 | for (const auto& a : ndef.attr()) { |
263 | if (options.ignore_internal_attrs && !a.first.empty() && |
264 | a.first[0] == '_') { |
265 | continue; |
266 | } |
267 | ndef_attr[a.first] = a.second; |
268 | } |
269 | for (const auto& a : ndef_attr) { |
270 | h = Hash64(a.first.data(), a.first.size(), h); |
271 | h = Hash64Combine(AttrValueHash(a.second), h); |
272 | } |
273 | |
274 | return h; |
275 | } |
276 | |
277 | } // namespace tensorflow |
278 | |