1 | /* Copyright 2020 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
#include "tensorflow/core/data/hash_utils.h"

#include <cstdint>
#include <memory>
#include <queue>
#include <string>
#include <utility>
#include <vector>

#include "absl/container/flat_hash_map.h"
#include "absl/container/flat_hash_set.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/str_join.h"
#include "tensorflow/core/common_runtime/function.h"
#include "tensorflow/core/data/dataset_utils.h"
#include "tensorflow/core/framework/attr_value.pb.h"
#include "tensorflow/core/framework/dataset.h"
#include "tensorflow/core/framework/function.h"
#include "tensorflow/core/framework/node_def_util.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/op_def.pb.h"
#include "tensorflow/core/framework/op_def_builder.h"
#include "tensorflow/core/framework/op_def_util.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/tensor.pb.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/graph/graph_def_builder.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/hash/hash.h"
#include "tensorflow/core/lib/strings/proto_serialization.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/platform/regexp.h"
#include "tensorflow/core/platform/status.h"
#include "tensorflow/core/util/work_sharder.h"
44 | |
45 | namespace tensorflow { |
46 | namespace data { |
47 | namespace { |
48 | |
// Op types whose seed-related inputs are left out of the graph hash; the
// seed values can differ between otherwise identical runs.
// clang-format off
constexpr std::array<const char*, 3> kOpsWithSeed{
    "AnonymousRandomSeedGenerator",
    "ShuffleDataset",
    "ShuffleAndRepeatDataset",
};
// clang-format on
// Input-argument names (as registered in an op's OpDef) treated as seeds.
constexpr char kSeedInputName[] = "seed";
constexpr char kSeed2InputName[] = "seed2";
constexpr char kSeedGeneratorInputName[] = "seed_generator";
59 | |
60 | template <std::size_t SIZE> |
61 | bool IsNodeOfType(const NodeDef& node, |
62 | const std::array<const char*, SIZE>& op_types) { |
63 | for (const auto& type : op_types) { |
64 | if (MatchesAnyVersion(type, node.op())) { |
65 | return true; |
66 | } |
67 | } |
68 | return false; |
69 | } |
70 | |
71 | Status GetSink(const GraphDef& graph_def, const NodeDef** sink) { |
72 | for (auto& node : graph_def.node()) { |
73 | if (node.op() == "_Retval" ) { |
74 | *sink = &node; |
75 | break; |
76 | } |
77 | } |
78 | |
79 | if (sink == nullptr) { |
80 | return errors::Internal("Cannot find sink node for dataset graph." ); |
81 | } |
82 | return OkStatus(); |
83 | } |
84 | |
85 | Status ShouldIgnoreInput(const NodeDef& node, int i, bool* result) { |
86 | *result = false; |
87 | if (IsNodeOfType(node, kOpsWithSeed)) { |
88 | const OpRegistrationData* reg; |
89 | auto status = OpRegistry::Global()->LookUp(node.op(), ®); |
90 | |
91 | if (status.ok()) { |
92 | if (reg->op_def.input_arg_size() > i) { |
93 | const std::string input_arg_name = reg->op_def.input_arg(i).name(); |
94 | if (input_arg_name == kSeedInputName || |
95 | input_arg_name == kSeed2InputName || |
96 | input_arg_name == kSeedGeneratorInputName) { |
97 | VLOG(2) << "Ignoring arg: " << input_arg_name |
98 | << " from node: " << node.name(); |
99 | *result = true; |
100 | return OkStatus(); |
101 | } |
102 | } |
103 | } else if (errors::IsNotFound(status)) { |
104 | LOG(WARNING) << "Cannot find " << node.op() |
105 | << " in global op registry, so cannot determine which " |
106 | "inputs are seeds." ; |
107 | } else { |
108 | return status; |
109 | } |
110 | } |
111 | return OkStatus(); |
112 | } |
113 | |
114 | Status ParseInputNodeName(absl::string_view input_name, |
115 | absl::string_view* node_name, |
116 | absl::string_view* suffix, bool* is_control_input) { |
117 | if (input_name[0] == '^') { |
118 | *node_name = input_name.substr(1); |
119 | *is_control_input = true; |
120 | return OkStatus(); |
121 | } |
122 | std::pair<absl::string_view, absl::string_view> node_spec = |
123 | absl::StrSplit(input_name, absl::MaxSplits(':', 1)); |
124 | *node_name = node_spec.first; |
125 | *suffix = node_spec.second; |
126 | *is_control_input = false; |
127 | return OkStatus(); |
128 | } |
129 | |
130 | // Given a graph_def and a root_node, this class computes a fingerprint that |
131 | // tries to capture the structure of the graph rooted at the provided node. |
132 | // It does not at any point rely on the names of the nodes in the graph and |
133 | // just relies on the connections between different nodes. In the presence of |
134 | // multiple cycles in the graph, there is a non-zero possibility that two |
135 | // graphs with different structure might end up with the same fingerprint |
136 | // as in order to break cycles we prune away some edges (in a deterministic |
137 | // fashion though). Idea for this algorithm was borrowed from: |
138 | // https://stackoverflow.com/questions/11338746/directed-graphs-with-a-given-root-node-match-another-directed-graph-for-equali |
139 | class GraphHasher { |
140 | using NodeCache = absl::flat_hash_map<const NodeDef*, uint64>; |
141 | using FunctionCache = absl::flat_hash_map<const FunctionDef*, uint64>; |
142 | using AttrCache = |
143 | absl::flat_hash_map<std::pair<const NodeDef*, bool>, uint64>; |
144 | |
145 | public: |
146 | // `GraphHasher` does not take ownership of `graph_def`, `root_node`, or |
147 | // `flib_def`. |
148 | explicit GraphHasher(const GraphDef* graph, const NodeDef* root, |
149 | const FunctionLibraryDefinition* flib) |
150 | : graph_(graph), root_(root), flib_(flib) { |
151 | node_cache_ = std::make_shared<NodeCache>(); |
152 | function_cache_ = std::make_shared<FunctionCache>(); |
153 | attr_cache_ = std::make_shared<AttrCache>(); |
154 | } |
155 | explicit GraphHasher(const GraphDef* graph, const NodeDef* root, |
156 | const FunctionLibraryDefinition* flib, |
157 | std::shared_ptr<NodeCache> node_cache, |
158 | std::shared_ptr<FunctionCache> function_cache, |
159 | std::shared_ptr<AttrCache> attr_cache) |
160 | : graph_(graph), |
161 | root_(root), |
162 | flib_(flib), |
163 | node_cache_(node_cache), |
164 | function_cache_(function_cache), |
165 | attr_cache_(attr_cache) {} |
166 | |
167 | Status Init() { |
168 | // Construct a map of name -> NodeDef to avoid repeated linear searches. |
169 | absl::flat_hash_map<absl::string_view, const NodeDef*> node_def_by_name; |
170 | node_def_by_name.reserve(graph_->node_size()); |
171 | for (const auto& node : graph_->node()) { |
172 | auto result = node_def_by_name.emplace(node.name(), &node); |
173 | if (TF_PREDICT_FALSE(!result.second)) { |
174 | auto node_name_formatter = |
175 | [](std::string* out, |
176 | const decltype(node_def_by_name)::value_type& item) { |
177 | absl::StrAppend(out, "'" , item.first, "'" ); |
178 | }; |
179 | return errors::Internal( |
180 | "Encountered graph with duplicate node name '" , node.name(), |
181 | "' in [" , absl::StrJoin(node_def_by_name, "," , node_name_formatter), |
182 | "]" ); |
183 | } |
184 | } |
185 | // Pre-process the graph to do a BFS and prune away cycles that might cause |
186 | // problems. |
187 | absl::flat_hash_set<absl::string_view> visited; |
188 | std::queue<const NodeDef*> bfs_queue; |
189 | bfs_queue.push(root_); |
190 | while (!bfs_queue.empty()) { |
191 | const NodeDef* node = bfs_queue.front(); |
192 | bfs_queue.pop(); |
193 | if (visited.contains(node->name())) { |
194 | continue; |
195 | } |
196 | visited.insert(node->name()); |
197 | NodeRep node_rep; |
198 | for (int i = 0; i < node->input_size(); ++i) { |
199 | DCHECK_GT(node->input(i).length(), 0); |
200 | |
201 | // We skip trying to take the hash of the seeds of any ops, as they |
202 | // are irrelevant to the hash of the graph and may vary from run to run. |
203 | bool should_ignore_input = false; |
204 | TF_RETURN_IF_ERROR(ShouldIgnoreInput(*node, i, &should_ignore_input)); |
205 | if (should_ignore_input) continue; |
206 | |
207 | absl::string_view node_name, suffix; |
208 | bool is_control_input; |
209 | TF_RETURN_IF_ERROR(ParseInputNodeName(node->input(i), &node_name, |
210 | &suffix, &is_control_input)); |
211 | |
212 | auto* input_node = gtl::FindPtrOrNull(node_def_by_name, node_name); |
213 | if (input_node == nullptr) { |
214 | return errors::Internal("Graph node [" , node->name(), "] has input [" , |
215 | node_name, "] that doesn't exist in graph" ); |
216 | } |
217 | |
218 | // If we've already seen this node before, skip it and don't add it to |
219 | // the queue. |
220 | if (visited.contains(node_name)) { |
221 | EdgeRep cycle_edge(node, input_node); |
222 | cycle_forming_edges_.insert(cycle_edge.GetHash()); |
223 | continue; |
224 | } |
225 | if (is_control_input) { |
226 | node_rep.node_control_inputs.push_back(input_node); |
227 | } else { |
228 | node_rep.node_inputs.push_back(std::make_pair(input_node, suffix)); |
229 | bfs_queue.push(input_node); |
230 | } |
231 | } |
232 | nodes_[node] = node_rep; |
233 | } |
234 | return OkStatus(); |
235 | } |
236 | |
237 | Status HashRoot(uint64* hash) { return HashNode(root_, hash); } |
238 | |
239 | Status CheckEqual(GraphHasher* that) { |
240 | return CheckNodesEqual(root_, that, that->root_); |
241 | } |
242 | |
243 | private: |
244 | Status HashNode(const NodeDef* node, uint64* hash) { |
245 | auto it = node_cache_->find(node); |
246 | if (it != node_cache_->end()) { |
247 | *hash = it->second; |
248 | return OkStatus(); |
249 | } |
250 | |
251 | NodeRep* node_rep = gtl::FindOrNull(nodes_, node); |
252 | if (node_rep == nullptr) { |
253 | return errors::InvalidArgument("Could not find node: " , node->name()); |
254 | } |
255 | |
256 | uint64 non_input_hash; |
257 | TF_RETURN_IF_ERROR( |
258 | HashNodeNonInput(node, /*hash_functions=*/true, &non_input_hash)); |
259 | |
260 | uint64 control_inputs_hash; |
261 | TF_RETURN_IF_ERROR( |
262 | HashControlInputs(node_rep->node_control_inputs, &control_inputs_hash)); |
263 | |
264 | // Hash regular inputs. We combine them in an ordered fashion. |
265 | uint64 inputs_hash = 0; |
266 | for (const auto& input : node_rep->node_inputs) { |
267 | uint64 node_hash = 0; |
268 | EdgeRep edge(node, input.first); |
269 | // If the edge was pruned we get the non input node hash to avoid cycles. |
270 | if (cycle_forming_edges_.contains(edge.GetHash())) { |
271 | TF_RETURN_IF_ERROR( |
272 | HashNodeNonInput(input.first, /*hash_functions=*/true, &node_hash)); |
273 | } else { |
274 | TF_RETURN_IF_ERROR(HashNode(input.first, &node_hash)); |
275 | } |
276 | inputs_hash = Hash64Combine( |
277 | inputs_hash, Hash64Combine(node_hash, Hash64(input.second.data(), |
278 | input.second.size()))); |
279 | } |
280 | |
281 | *hash = Hash64Combine(non_input_hash, |
282 | Hash64Combine(control_inputs_hash, inputs_hash)); |
283 | auto result = node_cache_->emplace(node, *hash); |
284 | if (!result.second) { |
285 | return errors::Internal(absl::StrCat("Computed the hash for node " , |
286 | node->DebugString(), " twice!" )); |
287 | } |
288 | return OkStatus(); |
289 | } |
290 | |
291 | Status CheckNodesEqual(const NodeDef* this_node, GraphHasher* that, |
292 | const NodeDef* that_node) { |
293 | Status s = CheckNodesEqualHelper(this_node, that, that_node); |
294 | if (!s.ok()) { |
295 | return errors::FailedPrecondition("Nodes " , this_node->name(), " and " , |
296 | that_node->name(), |
297 | " are not the same:\n" , s); |
298 | } |
299 | return s; |
300 | } |
301 | |
302 | Status CheckNodesEqualHelper(const NodeDef* this_node, GraphHasher* that, |
303 | const NodeDef* that_node) { |
304 | TF_RETURN_IF_ERROR(CheckNodesEqualNonInput(this_node, that, that_node, |
305 | /*compare_functions=*/true)); |
306 | |
307 | TF_RETURN_IF_ERROR( |
308 | CheckControlInputsEqual(nodes_[this_node].node_control_inputs, that, |
309 | that->nodes_[that_node].node_control_inputs)); |
310 | |
311 | auto& this_node_inputs = nodes_[this_node].node_inputs; |
312 | auto& that_node_inputs = that->nodes_[that_node].node_inputs; |
313 | if (this_node_inputs.size() != that_node_inputs.size()) { |
314 | return errors::FailedPrecondition( |
315 | "Nodes have different numbers of node inputs: " , |
316 | this_node_inputs.size(), " vs " , that_node_inputs.size()); |
317 | } |
318 | for (int i = 0; i < this_node_inputs.size(); ++i) { |
319 | const NodeDef* this_input = this_node_inputs[i].first; |
320 | const NodeDef* that_input = that_node_inputs[i].first; |
321 | if (is_cycle_forming_edge(this_node, this_input)) { |
322 | TF_RETURN_IF_ERROR(CheckNodesEqualNonInput(this_input, that, that_input, |
323 | /*compare_functions=*/true)); |
324 | } else { |
325 | TF_RETURN_IF_ERROR(CheckNodesEqual(this_input, that, that_input)); |
326 | } |
327 | absl::string_view this_input_suffix = this_node_inputs[i].second; |
328 | absl::string_view that_input_suffix = that_node_inputs[i].second; |
329 | if (this_input_suffix != that_input_suffix) { |
330 | return errors::FailedPrecondition( |
331 | "Node inputs " , this_input->name(), " and " , that_input->name(), |
332 | " have different suffixes: " , this_input_suffix, " vs " , |
333 | that_input_suffix); |
334 | } |
335 | } |
336 | return OkStatus(); |
337 | } |
338 | |
339 | Status HashNodeNonInput(const NodeDef* node, bool hash_functions, |
340 | uint64* hash) { |
341 | auto iter = attr_cache_->find(std::make_pair(node, hash_functions)); |
342 | if (iter != attr_cache_->end()) { |
343 | *hash = iter->second; |
344 | return OkStatus(); |
345 | } |
346 | // Hash Attrs. We get the list of attrs from the op registry and then look |
347 | // up their values in the NodeDef attr map. This avoids looping over |
348 | // a map which is non-deterministic. |
349 | uint64 attrs_hash = 0; |
350 | const OpRegistrationData* reg; |
351 | TF_RETURN_IF_ERROR(flib_->LookUp(node->op(), ®)); |
352 | uint64 op_hash = 0; |
353 | if (reg->is_function_op) { |
354 | if (hash_functions) { |
355 | TF_RETURN_IF_ERROR(HashFunction(node->op(), node->attr(), &op_hash)); |
356 | } |
357 | } else { |
358 | op_hash = Hash64(node->op()); |
359 | } |
360 | |
361 | for (const auto& attr : reg->op_def.attr()) { |
362 | const auto& attr_key = attr.name(); |
363 | // Ignore "metadata" attribute of tf.data operations. |
364 | if (DatasetOpKernel::IsDatasetOp(reg->op_def) && attr_key == "metadata" ) |
365 | continue; |
366 | auto node_attr_iter = node->attr().find(attr_key); |
367 | if (node_attr_iter == node->attr().end()) { |
368 | continue; |
369 | } |
370 | const auto& attr_value = node_attr_iter->second; |
371 | if (attr_key == kColocationAttrName || |
372 | attr_key == kColocationGroupPrefix) { |
373 | continue; |
374 | } |
375 | uint64 attr_hash = 0; |
376 | TF_RETURN_IF_ERROR( |
377 | HashAttr(attr_key, attr_value, hash_functions, &attr_hash)); |
378 | attrs_hash = Hash64Combine(attrs_hash, attr_hash); |
379 | } |
380 | |
381 | // Hash Device. |
382 | uint64 device_hash = Hash64(node->device()); |
383 | |
384 | *hash = Hash64Combine(op_hash, Hash64Combine(attrs_hash, device_hash)); |
385 | |
386 | auto result = |
387 | attr_cache_->emplace(std::make_pair(node, hash_functions), *hash); |
388 | if (!result.second) { |
389 | return errors::Internal(absl::StrCat( |
390 | "Computed the hash for non-input node: " , node->DebugString(), |
391 | " and hash function bool: " , hash_functions, "twice!" )); |
392 | } |
393 | return OkStatus(); |
394 | } |
395 | |
396 | Status CheckNodesEqualNonInput(const NodeDef* this_node, GraphHasher* that, |
397 | const NodeDef* that_node, |
398 | bool compare_functions) { |
399 | // We get the list of attrs from the op registry and then look |
400 | // up their values in the NodeDef attr map. This avoids looping over |
401 | // a map which is non-deterministic. |
402 | const OpRegistrationData* reg; |
403 | TF_RETURN_IF_ERROR(flib_->LookUp(this_node->op(), ®)); |
404 | if (reg->is_function_op) { |
405 | if (compare_functions) { |
406 | TF_RETURN_IF_ERROR( |
407 | CheckFunctionsEqual(this_node->op(), this_node->attr(), that, |
408 | that_node->op(), that_node->attr())); |
409 | } |
410 | } else { |
411 | if (this_node->op() != that_node->op()) { |
412 | return errors::FailedPrecondition( |
413 | "ops for nodes " , this_node->name(), " and " , that_node->name(), |
414 | " are different: " , this_node->op(), " != " , that_node->op()); |
415 | } |
416 | } |
417 | |
418 | for (const auto& attr : reg->op_def.attr()) { |
419 | const auto& attr_key = attr.name(); |
420 | const bool this_has_attr = this_node->attr().contains(attr_key); |
421 | const bool that_has_attr = that_node->attr().contains(attr_key); |
422 | if (this_has_attr != that_has_attr) { |
423 | return errors::FailedPrecondition( |
424 | "attr with key " , attr_key, " is different for nodes " , |
425 | this_node->name(), " and " , that_node->name(), |
426 | ". Present in former: " , this_has_attr, |
427 | ". Present in latter: " , that_has_attr); |
428 | } |
429 | if (!this_has_attr) { |
430 | continue; |
431 | } |
432 | if (attr_key == kColocationAttrName || |
433 | attr_key == kColocationGroupPrefix) { |
434 | continue; |
435 | } |
436 | const auto& this_attr = this_node->attr().at(attr_key); |
437 | const auto& that_attr = that_node->attr().at(attr_key); |
438 | TF_RETURN_IF_ERROR(CheckAttrsEqual(attr_key, this_attr, that, that_attr, |
439 | compare_functions)); |
440 | } |
441 | |
442 | if (this_node->device() != that_node->device()) { |
443 | return errors::FailedPrecondition( |
444 | "Devices are different for nodes " , this_node->name(), " and " , |
445 | that_node->name(), ": " , this_node->device(), " vs " , |
446 | that_node->device()); |
447 | } |
448 | return OkStatus(); |
449 | } |
450 | |
451 | Status HashAttr(const std::string& attr_name, const AttrValue& attr_value, |
452 | bool hash_functions, uint64* hash) { |
453 | uint64 value_hash = 0; |
454 | if (attr_value.has_func()) { |
455 | if (hash_functions) { |
456 | TF_RETURN_IF_ERROR(HashFunction(attr_value.func(), &value_hash)); |
457 | } |
458 | } else if (attr_value.has_list() && attr_value.list().func_size() > 0) { |
459 | if (hash_functions) { |
460 | for (auto& func : attr_value.list().func()) { |
461 | uint64 func_hash; |
462 | TF_RETURN_IF_ERROR(HashFunction(func, &func_hash)); |
463 | value_hash = Hash64Combine(value_hash, func_hash); |
464 | } |
465 | } |
466 | } else { |
467 | value_hash = DeterministicProtoHash64(attr_value); |
468 | } |
469 | *hash = Hash64Combine(Hash64(attr_name), value_hash); |
470 | return OkStatus(); |
471 | } |
472 | |
473 | Status CheckAttrsEqual(const std::string& attr_name, |
474 | const AttrValue& this_attr, GraphHasher* that, |
475 | const AttrValue& that_attr, bool compare_functions) { |
476 | if (this_attr.has_func() != that_attr.has_func()) { |
477 | return errors::FailedPrecondition( |
478 | "AttrValues are of different types: " , this_attr.DebugString(), |
479 | " vs " , that_attr.DebugString()); |
480 | } |
481 | if (this_attr.has_func()) { |
482 | if (compare_functions) { |
483 | TF_RETURN_IF_ERROR( |
484 | CheckFunctionsEqual(this_attr.func(), that, that_attr.func())); |
485 | } |
486 | return OkStatus(); |
487 | } |
488 | if (this_attr.has_list() != that_attr.has_list()) { |
489 | return errors::FailedPrecondition( |
490 | "AttrValues are of different types: " , this_attr.DebugString(), |
491 | " vs " , that_attr.DebugString()); |
492 | } |
493 | if (this_attr.has_list()) { |
494 | if (this_attr.list().func_size() != that_attr.list().func_size()) { |
495 | return errors::FailedPrecondition( |
496 | "AttrValues have func lists of different sizes: " , |
497 | this_attr.DebugString(), " vs " , that_attr.DebugString()); |
498 | } |
499 | if (compare_functions) { |
500 | for (int i = 0; i < this_attr.list().func_size(); ++i) { |
501 | TF_RETURN_IF_ERROR(CheckFunctionsEqual(this_attr.list().func(i), that, |
502 | that_attr.list().func(i))); |
503 | } |
504 | } |
505 | return OkStatus(); |
506 | } |
507 | uint64 this_hash, that_hash; |
508 | TF_RETURN_IF_ERROR( |
509 | HashAttr(attr_name, this_attr, /*hash_functions=*/true, &this_hash)); |
510 | TF_RETURN_IF_ERROR(that->HashAttr(attr_name, that_attr, |
511 | /*hash_functions=*/true, &that_hash)); |
512 | if (this_hash != that_hash) { |
513 | return errors::FailedPrecondition( |
514 | "AttrValues are different: " , this_attr.DebugString(), " vs " , |
515 | that_attr.DebugString()); |
516 | } |
517 | return OkStatus(); |
518 | } |
519 | |
520 | Status HashFunction(const NameAttrList& func, uint64* hash) { |
521 | return HashFunction(func.name(), func.attr(), hash); |
522 | } |
523 | |
524 | Status HashFunction(const std::string& name, const AttrValueMap& attrs, |
525 | uint64* hash) { |
526 | const FunctionDef* fdef = flib_->Find(name); |
527 | auto it = function_cache_->find(fdef); |
528 | if (it != function_cache_->end()) { |
529 | *hash = it->second; |
530 | return OkStatus(); |
531 | } |
532 | |
533 | // Convert to a GraphDef. |
534 | std::unique_ptr<FunctionBody> fbody; |
535 | TF_RETURN_IF_ERROR( |
536 | FunctionDefToBodyHelper(*fdef, AttrSlice(&attrs), flib_, &fbody)); |
537 | GraphDef graph_def = fbody->graph->ToGraphDefDebug(); |
538 | |
539 | // For each return node, we create a new GraphHasher to compute a hash. |
540 | // We then combine these hashes to produce the hash ordered. |
541 | uint64 ret_nodes_hash = 0; |
542 | for (const auto& ret_node : fbody->ret_nodes) { |
543 | uint64 ret_node_hash = 0; |
544 | GraphHasher hasher(&graph_def, &ret_node->def(), flib_, node_cache_, |
545 | function_cache_, attr_cache_); |
546 | TF_RETURN_IF_ERROR(hasher.Init()); |
547 | TF_RETURN_IF_ERROR(hasher.HashRoot(&ret_node_hash)); |
548 | ret_nodes_hash = Hash64Combine(ret_nodes_hash, ret_node_hash); |
549 | } |
550 | |
551 | std::vector<const NodeDef*> control_rets; |
552 | control_rets.reserve(fbody->control_ret_nodes.size()); |
553 | for (const auto& control_ret_node : fbody->control_ret_nodes) { |
554 | control_rets.push_back(&control_ret_node->def()); |
555 | } |
556 | uint64 control_ret_nodes_hash = 0; |
557 | TF_RETURN_IF_ERROR( |
558 | HashControlInputs(control_rets, &control_ret_nodes_hash)); |
559 | |
560 | *hash = Hash64Combine(ret_nodes_hash, control_ret_nodes_hash); |
561 | auto result = function_cache_->emplace(fdef, *hash); |
562 | if (!result.second) { |
563 | return errors::Internal( |
564 | absl::StrCat("Computed the hash for function " , name, " twice!" )); |
565 | } |
566 | return OkStatus(); |
567 | } |
568 | |
569 | Status CheckFunctionsEqual(const NameAttrList& this_func, GraphHasher* that, |
570 | const NameAttrList& that_func) { |
571 | return CheckFunctionsEqual(this_func.name(), this_func.attr(), that, |
572 | that_func.name(), that_func.attr()); |
573 | } |
574 | Status CheckFunctionsEqual(const std::string& this_name, |
575 | const AttrValueMap& this_attrs, GraphHasher* that, |
576 | const std::string& that_name, |
577 | const AttrValueMap& that_attrs) { |
578 | Status s = CheckFunctionsEqualHelper(this_name, this_attrs, that, that_name, |
579 | that_attrs); |
580 | if (!s.ok()) { |
581 | return errors::FailedPrecondition("Functions " , this_name, " and " , |
582 | that_name, " are not the same:\n" , s); |
583 | } |
584 | return s; |
585 | } |
586 | |
587 | Status CheckFunctionsEqualHelper(const std::string& this_name, |
588 | const AttrValueMap& this_attrs, |
589 | GraphHasher* that, |
590 | const std::string& that_name, |
591 | const AttrValueMap& that_attrs) { |
592 | const FunctionDef* this_fdef = flib_->Find(this_name); |
593 | const FunctionDef* that_fdef = that->flib_->Find(that_name); |
594 | |
595 | // Convert to GraphDefs. |
596 | std::unique_ptr<FunctionBody> this_fbody; |
597 | TF_RETURN_IF_ERROR(FunctionDefToBodyHelper( |
598 | *this_fdef, AttrSlice(&this_attrs), flib_, &this_fbody)); |
599 | GraphDef this_graph_def = this_fbody->graph->ToGraphDefDebug(); |
600 | std::unique_ptr<FunctionBody> that_fbody; |
601 | TF_RETURN_IF_ERROR(FunctionDefToBodyHelper( |
602 | *that_fdef, AttrSlice(&that_attrs), that->flib_, &that_fbody)); |
603 | GraphDef that_graph_def = that_fbody->graph->ToGraphDefDebug(); |
604 | |
605 | if (this_fbody->ret_nodes.size() != that_fbody->ret_nodes.size()) { |
606 | return errors::FailedPrecondition( |
607 | "Different numbers of ret nodes for functions " , this_name, " and " , |
608 | that_name, ": " , this_fbody->ret_nodes.size(), " vs " , |
609 | that_fbody->ret_nodes.size()); |
610 | } |
611 | for (int i = 0; i < this_fbody->ret_nodes.size(); ++i) { |
612 | const NodeDef* this_root = &this_fbody->ret_nodes[i]->def(); |
613 | const NodeDef* that_root = &that_fbody->ret_nodes[i]->def(); |
614 | GraphHasher this_hasher(&this_graph_def, this_root, flib_, node_cache_, |
615 | function_cache_, attr_cache_); |
616 | TF_RETURN_IF_ERROR(this_hasher.Init()); |
617 | GraphHasher that_hasher(&that_graph_def, that_root, that->flib_, |
618 | node_cache_, function_cache_, attr_cache_); |
619 | TF_RETURN_IF_ERROR(that_hasher.Init()); |
620 | TF_RETURN_IF_ERROR(this_hasher.CheckEqual(&that_hasher)); |
621 | } |
622 | |
623 | std::vector<const NodeDef*> this_control_rets; |
624 | this_control_rets.reserve(this_fbody->control_ret_nodes.size()); |
625 | for (const auto& control_ret_node : this_fbody->control_ret_nodes) { |
626 | this_control_rets.push_back(&control_ret_node->def()); |
627 | } |
628 | std::vector<const NodeDef*> that_control_rets; |
629 | that_control_rets.reserve(that_fbody->control_ret_nodes.size()); |
630 | for (const auto& control_ret_node : that_fbody->control_ret_nodes) { |
631 | that_control_rets.push_back(&control_ret_node->def()); |
632 | } |
633 | TF_RETURN_IF_ERROR( |
634 | CheckControlInputsEqual(this_control_rets, that, that_control_rets)); |
635 | return OkStatus(); |
636 | } |
637 | |
638 | Status HashControlInputs(const std::vector<const NodeDef*>& inputs, |
639 | uint64* hash) { |
640 | *hash = 0; |
641 | for (const NodeDef* input : inputs) { |
642 | uint64 node_hash = 0; |
643 | TF_RETURN_IF_ERROR( |
644 | HashNodeNonInput(input, /*hash_functions=*/false, &node_hash)); |
645 | *hash = Hash64CombineUnordered(*hash, node_hash); |
646 | } |
647 | return OkStatus(); |
648 | } |
649 | |
650 | Status CheckControlInputsEqual( |
651 | const std::vector<const NodeDef*>& this_inputs, GraphHasher* that, |
652 | const std::vector<const NodeDef*>& that_inputs) { |
653 | absl::flat_hash_map<uint64, const NodeDef*> this_hashes; |
654 | for (const NodeDef* input : this_inputs) { |
655 | uint64 node_hash = 0; |
656 | TF_RETURN_IF_ERROR( |
657 | HashNodeNonInput(input, /*hash_functions=*/false, &node_hash)); |
658 | this_hashes[node_hash] = input; |
659 | } |
660 | absl::flat_hash_map<uint64, const NodeDef*> that_hashes; |
661 | for (const NodeDef* input : that_inputs) { |
662 | uint64 node_hash = 0; |
663 | TF_RETURN_IF_ERROR( |
664 | HashNodeNonInput(input, /*hash_functions=*/false, &node_hash)); |
665 | auto this_iter = this_hashes.find(node_hash); |
666 | if (this_iter != this_hashes.end()) { |
667 | this_hashes.erase(this_iter); |
668 | } else { |
669 | that_hashes[node_hash] = input; |
670 | } |
671 | } |
672 | if (!this_hashes.empty()) { |
673 | auto formatter = [](string* out, |
674 | const decltype(this_hashes)::value_type& item) { |
675 | out->append(item.second->name()); |
676 | }; |
677 | return errors::FailedPrecondition( |
678 | "Control dependencies are different. One node has dependencies [" , |
679 | absl::StrJoin(this_hashes, ", " , formatter), |
680 | "], which don't match any of the other node's dependencies [" , |
681 | absl::StrJoin(that_hashes, ", " , formatter), "]" ); |
682 | } |
683 | return OkStatus(); |
684 | } |
685 | |
686 | private: |
687 | bool is_cycle_forming_edge(const NodeDef* start, const NodeDef* end) { |
688 | EdgeRep edge(start, end); |
689 | return cycle_forming_edges_.contains(edge.GetHash()); |
690 | } |
691 | |
692 | struct NodeRep { |
693 | std::vector<const NodeDef*> node_control_inputs; |
694 | std::vector<std::pair<const NodeDef*, absl::string_view>> node_inputs; |
695 | }; |
696 | |
697 | struct EdgeRep { |
698 | const NodeDef* start_node; |
699 | const NodeDef* end_node; |
700 | |
701 | EdgeRep(const NodeDef* start, const NodeDef* end) |
702 | : start_node(start), end_node(end) {} |
703 | |
704 | uint64 GetHash() { |
705 | return Hash64Combine(absl::Hash<const NodeDef*>()(start_node), |
706 | absl::Hash<const NodeDef*>()(end_node)); |
707 | } |
708 | }; |
709 | const GraphDef* const graph_; // Not owned. |
710 | const NodeDef* const root_; // Not owned. |
711 | const FunctionLibraryDefinition* const flib_; // Not owned. |
712 | // Edges that need to be pruned as their presence will cause cycles. |
713 | absl::flat_hash_set<uint64> cycle_forming_edges_; |
714 | absl::flat_hash_map<const NodeDef*, NodeRep> nodes_; |
715 | std::shared_ptr<NodeCache> node_cache_; |
716 | std::shared_ptr<FunctionCache> function_cache_; |
717 | std::shared_ptr<AttrCache> attr_cache_; |
718 | }; |
719 | |
720 | } // anonymous namespace |
721 | |
722 | Status HashTensor(const Tensor& tensor, uint64* hash) { |
723 | const tstring* s = nullptr; |
724 | // Hash tensor type. |
725 | *hash = Hash64Combine(0, tensor.dtype()); |
726 | // Hash tensor shape. |
727 | for (int i = 0; i < tensor.shape().dims(); ++i) { |
728 | *hash = Hash64Combine(*hash, tensor.shape().dim_size(i)); |
729 | } |
730 | // Hash tensor data. |
731 | switch (tensor.dtype()) { |
732 | case DT_RESOURCE: |
733 | case DT_VARIANT: |
734 | return errors::Unimplemented("Hashing " , DataTypeString(tensor.dtype()), |
735 | " is not supported." ); |
736 | case DT_STRING: |
737 | s = tensor.flat<tstring>().data(); |
738 | for (int i = 0; i < tensor.NumElements(); ++i, ++s) { |
739 | *hash = Hash64Combine(*hash, Hash64(s->data(), s->size())); |
740 | } |
741 | break; |
742 | default: |
743 | *hash = Hash64(tensor.tensor_data().data(), tensor.tensor_data().size()); |
744 | } |
745 | return OkStatus(); |
746 | } |
747 | |
748 | Status HashNode(const GraphDef& graph, const NodeDef& node, uint64* hash) { |
749 | const FunctionLibraryDefinition flib_def(OpRegistry::Global(), |
750 | graph.library()); |
751 | return HashNode(graph, node, flib_def, hash); |
752 | } |
753 | |
754 | Status HashNode(const GraphDef& graph, const NodeDef& node, |
755 | const FunctionLibraryDefinition& flib_def, uint64* hash) { |
756 | GraphHasher hasher(&graph, &node, &flib_def); |
757 | TF_RETURN_IF_ERROR(hasher.Init()); |
758 | return hasher.HashRoot(hash); |
759 | } |
760 | |
761 | Status HashGraph(const GraphDef& graph_def, uint64* hash) { |
762 | const NodeDef* sink = nullptr; |
763 | TF_RETURN_IF_ERROR(GetSink(graph_def, &sink)); |
764 | return HashNode(graph_def, *sink, hash); |
765 | } |
766 | |
767 | Status CheckGraphsEqual(const GraphDef& a, const GraphDef& b) { |
768 | const NodeDef* sink_a; |
769 | TF_RETURN_IF_ERROR(GetSink(a, &sink_a)); |
770 | const NodeDef* sink_b; |
771 | TF_RETURN_IF_ERROR(GetSink(b, &sink_b)); |
772 | return CheckSubgraphsEqual(a, sink_a, b, sink_b); |
773 | } |
774 | |
775 | Status CheckSubgraphsEqual(const GraphDef& a, const NodeDef* node_a, |
776 | const GraphDef& b, const NodeDef* node_b) { |
777 | const FunctionLibraryDefinition flib_def_a(OpRegistry::Global(), a.library()); |
778 | GraphHasher hasher_a(&a, node_a, &flib_def_a); |
779 | TF_RETURN_IF_ERROR(hasher_a.Init()); |
780 | |
781 | const FunctionLibraryDefinition flib_def_b(OpRegistry::Global(), b.library()); |
782 | GraphHasher hasher_b(&b, node_b, &flib_def_b); |
783 | TF_RETURN_IF_ERROR(hasher_b.Init()); |
784 | |
785 | return hasher_a.CheckEqual(&hasher_b); |
786 | } |
787 | |
788 | } // namespace data |
789 | } // namespace tensorflow |
790 | |