1/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2
3Licensed under the Apache License, Version 2.0 (the "License");
4you may not use this file except in compliance with the License.
5You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9Unless required by applicable law or agreed to in writing, software
10distributed under the License is distributed on an "AS IS" BASIS,
11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12See the License for the specific language governing permissions and
13limitations under the License.
14==============================================================================*/
15#include "tensorflow/core/data/hash_utils.h"
16
17#include <queue>
18
19#include "absl/container/flat_hash_map.h"
20#include "absl/container/flat_hash_set.h"
21#include "absl/strings/str_cat.h"
22#include "absl/strings/str_join.h"
23#include "tensorflow/core/common_runtime/function.h"
24#include "tensorflow/core/data/dataset_utils.h"
25#include "tensorflow/core/framework/attr_value.pb.h"
26#include "tensorflow/core/framework/dataset.h"
27#include "tensorflow/core/framework/function.h"
28#include "tensorflow/core/framework/node_def_util.h"
29#include "tensorflow/core/framework/op.h"
30#include "tensorflow/core/framework/op_def.pb.h"
31#include "tensorflow/core/framework/op_def_builder.h"
32#include "tensorflow/core/framework/op_def_util.h"
33#include "tensorflow/core/framework/op_kernel.h"
34#include "tensorflow/core/framework/tensor.pb.h"
35#include "tensorflow/core/framework/types.h"
36#include "tensorflow/core/graph/graph_def_builder.h"
37#include "tensorflow/core/lib/core/errors.h"
38#include "tensorflow/core/lib/hash/hash.h"
39#include "tensorflow/core/lib/strings/proto_serialization.h"
40#include "tensorflow/core/platform/errors.h"
41#include "tensorflow/core/platform/regexp.h"
42#include "tensorflow/core/platform/status.h"
43#include "tensorflow/core/util/work_sharder.h"
44
45namespace tensorflow {
46namespace data {
47namespace {
48
// Ops whose seed-like inputs are skipped when hashing a graph (consulted by
// `ShouldIgnoreInput`).
// clang-format off
constexpr std::array<const char*, 3> kOpsWithSeed{{
    "AnonymousRandomSeedGenerator",
    "ShuffleDataset",
    "ShuffleAndRepeatDataset",
}};
// clang-format on

// Names of op input args whose values are seeds and are therefore skipped.
constexpr char kSeedInputName[] = "seed";
constexpr char kSeed2InputName[] = "seed2";
constexpr char kSeedGeneratorInputName[] = "seed_generator";
59
60template <std::size_t SIZE>
61bool IsNodeOfType(const NodeDef& node,
62 const std::array<const char*, SIZE>& op_types) {
63 for (const auto& type : op_types) {
64 if (MatchesAnyVersion(type, node.op())) {
65 return true;
66 }
67 }
68 return false;
69}
70
71Status GetSink(const GraphDef& graph_def, const NodeDef** sink) {
72 for (auto& node : graph_def.node()) {
73 if (node.op() == "_Retval") {
74 *sink = &node;
75 break;
76 }
77 }
78
79 if (sink == nullptr) {
80 return errors::Internal("Cannot find sink node for dataset graph.");
81 }
82 return OkStatus();
83}
84
85Status ShouldIgnoreInput(const NodeDef& node, int i, bool* result) {
86 *result = false;
87 if (IsNodeOfType(node, kOpsWithSeed)) {
88 const OpRegistrationData* reg;
89 auto status = OpRegistry::Global()->LookUp(node.op(), &reg);
90
91 if (status.ok()) {
92 if (reg->op_def.input_arg_size() > i) {
93 const std::string input_arg_name = reg->op_def.input_arg(i).name();
94 if (input_arg_name == kSeedInputName ||
95 input_arg_name == kSeed2InputName ||
96 input_arg_name == kSeedGeneratorInputName) {
97 VLOG(2) << "Ignoring arg: " << input_arg_name
98 << " from node: " << node.name();
99 *result = true;
100 return OkStatus();
101 }
102 }
103 } else if (errors::IsNotFound(status)) {
104 LOG(WARNING) << "Cannot find " << node.op()
105 << " in global op registry, so cannot determine which "
106 "inputs are seeds.";
107 } else {
108 return status;
109 }
110 }
111 return OkStatus();
112}
113
114Status ParseInputNodeName(absl::string_view input_name,
115 absl::string_view* node_name,
116 absl::string_view* suffix, bool* is_control_input) {
117 if (input_name[0] == '^') {
118 *node_name = input_name.substr(1);
119 *is_control_input = true;
120 return OkStatus();
121 }
122 std::pair<absl::string_view, absl::string_view> node_spec =
123 absl::StrSplit(input_name, absl::MaxSplits(':', 1));
124 *node_name = node_spec.first;
125 *suffix = node_spec.second;
126 *is_control_input = false;
127 return OkStatus();
128}
129
130// Given a graph_def and a root_node, this class computes a fingerprint that
131// tries to capture the structure of the graph rooted at the provided node.
132// It does not at any point rely on the names of the nodes in the graph and
133// just relies on the connections between different nodes. In the presence of
134// multiple cycles in the graph, there is a non-zero possibility that two
135// graphs with different structure might end up with the same fingerprint
136// as in order to break cycles we prune away some edges (in a deterministic
137// fashion though). Idea for this algorithm was borrowed from:
138// https://stackoverflow.com/questions/11338746/directed-graphs-with-a-given-root-node-match-another-directed-graph-for-equali
139class GraphHasher {
140 using NodeCache = absl::flat_hash_map<const NodeDef*, uint64>;
141 using FunctionCache = absl::flat_hash_map<const FunctionDef*, uint64>;
142 using AttrCache =
143 absl::flat_hash_map<std::pair<const NodeDef*, bool>, uint64>;
144
145 public:
146 // `GraphHasher` does not take ownership of `graph_def`, `root_node`, or
147 // `flib_def`.
148 explicit GraphHasher(const GraphDef* graph, const NodeDef* root,
149 const FunctionLibraryDefinition* flib)
150 : graph_(graph), root_(root), flib_(flib) {
151 node_cache_ = std::make_shared<NodeCache>();
152 function_cache_ = std::make_shared<FunctionCache>();
153 attr_cache_ = std::make_shared<AttrCache>();
154 }
155 explicit GraphHasher(const GraphDef* graph, const NodeDef* root,
156 const FunctionLibraryDefinition* flib,
157 std::shared_ptr<NodeCache> node_cache,
158 std::shared_ptr<FunctionCache> function_cache,
159 std::shared_ptr<AttrCache> attr_cache)
160 : graph_(graph),
161 root_(root),
162 flib_(flib),
163 node_cache_(node_cache),
164 function_cache_(function_cache),
165 attr_cache_(attr_cache) {}
166
167 Status Init() {
168 // Construct a map of name -> NodeDef to avoid repeated linear searches.
169 absl::flat_hash_map<absl::string_view, const NodeDef*> node_def_by_name;
170 node_def_by_name.reserve(graph_->node_size());
171 for (const auto& node : graph_->node()) {
172 auto result = node_def_by_name.emplace(node.name(), &node);
173 if (TF_PREDICT_FALSE(!result.second)) {
174 auto node_name_formatter =
175 [](std::string* out,
176 const decltype(node_def_by_name)::value_type& item) {
177 absl::StrAppend(out, "'", item.first, "'");
178 };
179 return errors::Internal(
180 "Encountered graph with duplicate node name '", node.name(),
181 "' in [", absl::StrJoin(node_def_by_name, ",", node_name_formatter),
182 "]");
183 }
184 }
185 // Pre-process the graph to do a BFS and prune away cycles that might cause
186 // problems.
187 absl::flat_hash_set<absl::string_view> visited;
188 std::queue<const NodeDef*> bfs_queue;
189 bfs_queue.push(root_);
190 while (!bfs_queue.empty()) {
191 const NodeDef* node = bfs_queue.front();
192 bfs_queue.pop();
193 if (visited.contains(node->name())) {
194 continue;
195 }
196 visited.insert(node->name());
197 NodeRep node_rep;
198 for (int i = 0; i < node->input_size(); ++i) {
199 DCHECK_GT(node->input(i).length(), 0);
200
201 // We skip trying to take the hash of the seeds of any ops, as they
202 // are irrelevant to the hash of the graph and may vary from run to run.
203 bool should_ignore_input = false;
204 TF_RETURN_IF_ERROR(ShouldIgnoreInput(*node, i, &should_ignore_input));
205 if (should_ignore_input) continue;
206
207 absl::string_view node_name, suffix;
208 bool is_control_input;
209 TF_RETURN_IF_ERROR(ParseInputNodeName(node->input(i), &node_name,
210 &suffix, &is_control_input));
211
212 auto* input_node = gtl::FindPtrOrNull(node_def_by_name, node_name);
213 if (input_node == nullptr) {
214 return errors::Internal("Graph node [", node->name(), "] has input [",
215 node_name, "] that doesn't exist in graph");
216 }
217
218 // If we've already seen this node before, skip it and don't add it to
219 // the queue.
220 if (visited.contains(node_name)) {
221 EdgeRep cycle_edge(node, input_node);
222 cycle_forming_edges_.insert(cycle_edge.GetHash());
223 continue;
224 }
225 if (is_control_input) {
226 node_rep.node_control_inputs.push_back(input_node);
227 } else {
228 node_rep.node_inputs.push_back(std::make_pair(input_node, suffix));
229 bfs_queue.push(input_node);
230 }
231 }
232 nodes_[node] = node_rep;
233 }
234 return OkStatus();
235 }
236
237 Status HashRoot(uint64* hash) { return HashNode(root_, hash); }
238
239 Status CheckEqual(GraphHasher* that) {
240 return CheckNodesEqual(root_, that, that->root_);
241 }
242
243 private:
244 Status HashNode(const NodeDef* node, uint64* hash) {
245 auto it = node_cache_->find(node);
246 if (it != node_cache_->end()) {
247 *hash = it->second;
248 return OkStatus();
249 }
250
251 NodeRep* node_rep = gtl::FindOrNull(nodes_, node);
252 if (node_rep == nullptr) {
253 return errors::InvalidArgument("Could not find node: ", node->name());
254 }
255
256 uint64 non_input_hash;
257 TF_RETURN_IF_ERROR(
258 HashNodeNonInput(node, /*hash_functions=*/true, &non_input_hash));
259
260 uint64 control_inputs_hash;
261 TF_RETURN_IF_ERROR(
262 HashControlInputs(node_rep->node_control_inputs, &control_inputs_hash));
263
264 // Hash regular inputs. We combine them in an ordered fashion.
265 uint64 inputs_hash = 0;
266 for (const auto& input : node_rep->node_inputs) {
267 uint64 node_hash = 0;
268 EdgeRep edge(node, input.first);
269 // If the edge was pruned we get the non input node hash to avoid cycles.
270 if (cycle_forming_edges_.contains(edge.GetHash())) {
271 TF_RETURN_IF_ERROR(
272 HashNodeNonInput(input.first, /*hash_functions=*/true, &node_hash));
273 } else {
274 TF_RETURN_IF_ERROR(HashNode(input.first, &node_hash));
275 }
276 inputs_hash = Hash64Combine(
277 inputs_hash, Hash64Combine(node_hash, Hash64(input.second.data(),
278 input.second.size())));
279 }
280
281 *hash = Hash64Combine(non_input_hash,
282 Hash64Combine(control_inputs_hash, inputs_hash));
283 auto result = node_cache_->emplace(node, *hash);
284 if (!result.second) {
285 return errors::Internal(absl::StrCat("Computed the hash for node ",
286 node->DebugString(), " twice!"));
287 }
288 return OkStatus();
289 }
290
291 Status CheckNodesEqual(const NodeDef* this_node, GraphHasher* that,
292 const NodeDef* that_node) {
293 Status s = CheckNodesEqualHelper(this_node, that, that_node);
294 if (!s.ok()) {
295 return errors::FailedPrecondition("Nodes ", this_node->name(), " and ",
296 that_node->name(),
297 " are not the same:\n", s);
298 }
299 return s;
300 }
301
302 Status CheckNodesEqualHelper(const NodeDef* this_node, GraphHasher* that,
303 const NodeDef* that_node) {
304 TF_RETURN_IF_ERROR(CheckNodesEqualNonInput(this_node, that, that_node,
305 /*compare_functions=*/true));
306
307 TF_RETURN_IF_ERROR(
308 CheckControlInputsEqual(nodes_[this_node].node_control_inputs, that,
309 that->nodes_[that_node].node_control_inputs));
310
311 auto& this_node_inputs = nodes_[this_node].node_inputs;
312 auto& that_node_inputs = that->nodes_[that_node].node_inputs;
313 if (this_node_inputs.size() != that_node_inputs.size()) {
314 return errors::FailedPrecondition(
315 "Nodes have different numbers of node inputs: ",
316 this_node_inputs.size(), " vs ", that_node_inputs.size());
317 }
318 for (int i = 0; i < this_node_inputs.size(); ++i) {
319 const NodeDef* this_input = this_node_inputs[i].first;
320 const NodeDef* that_input = that_node_inputs[i].first;
321 if (is_cycle_forming_edge(this_node, this_input)) {
322 TF_RETURN_IF_ERROR(CheckNodesEqualNonInput(this_input, that, that_input,
323 /*compare_functions=*/true));
324 } else {
325 TF_RETURN_IF_ERROR(CheckNodesEqual(this_input, that, that_input));
326 }
327 absl::string_view this_input_suffix = this_node_inputs[i].second;
328 absl::string_view that_input_suffix = that_node_inputs[i].second;
329 if (this_input_suffix != that_input_suffix) {
330 return errors::FailedPrecondition(
331 "Node inputs ", this_input->name(), " and ", that_input->name(),
332 " have different suffixes: ", this_input_suffix, " vs ",
333 that_input_suffix);
334 }
335 }
336 return OkStatus();
337 }
338
339 Status HashNodeNonInput(const NodeDef* node, bool hash_functions,
340 uint64* hash) {
341 auto iter = attr_cache_->find(std::make_pair(node, hash_functions));
342 if (iter != attr_cache_->end()) {
343 *hash = iter->second;
344 return OkStatus();
345 }
346 // Hash Attrs. We get the list of attrs from the op registry and then look
347 // up their values in the NodeDef attr map. This avoids looping over
348 // a map which is non-deterministic.
349 uint64 attrs_hash = 0;
350 const OpRegistrationData* reg;
351 TF_RETURN_IF_ERROR(flib_->LookUp(node->op(), &reg));
352 uint64 op_hash = 0;
353 if (reg->is_function_op) {
354 if (hash_functions) {
355 TF_RETURN_IF_ERROR(HashFunction(node->op(), node->attr(), &op_hash));
356 }
357 } else {
358 op_hash = Hash64(node->op());
359 }
360
361 for (const auto& attr : reg->op_def.attr()) {
362 const auto& attr_key = attr.name();
363 // Ignore "metadata" attribute of tf.data operations.
364 if (DatasetOpKernel::IsDatasetOp(reg->op_def) && attr_key == "metadata")
365 continue;
366 auto node_attr_iter = node->attr().find(attr_key);
367 if (node_attr_iter == node->attr().end()) {
368 continue;
369 }
370 const auto& attr_value = node_attr_iter->second;
371 if (attr_key == kColocationAttrName ||
372 attr_key == kColocationGroupPrefix) {
373 continue;
374 }
375 uint64 attr_hash = 0;
376 TF_RETURN_IF_ERROR(
377 HashAttr(attr_key, attr_value, hash_functions, &attr_hash));
378 attrs_hash = Hash64Combine(attrs_hash, attr_hash);
379 }
380
381 // Hash Device.
382 uint64 device_hash = Hash64(node->device());
383
384 *hash = Hash64Combine(op_hash, Hash64Combine(attrs_hash, device_hash));
385
386 auto result =
387 attr_cache_->emplace(std::make_pair(node, hash_functions), *hash);
388 if (!result.second) {
389 return errors::Internal(absl::StrCat(
390 "Computed the hash for non-input node: ", node->DebugString(),
391 " and hash function bool: ", hash_functions, "twice!"));
392 }
393 return OkStatus();
394 }
395
396 Status CheckNodesEqualNonInput(const NodeDef* this_node, GraphHasher* that,
397 const NodeDef* that_node,
398 bool compare_functions) {
399 // We get the list of attrs from the op registry and then look
400 // up their values in the NodeDef attr map. This avoids looping over
401 // a map which is non-deterministic.
402 const OpRegistrationData* reg;
403 TF_RETURN_IF_ERROR(flib_->LookUp(this_node->op(), &reg));
404 if (reg->is_function_op) {
405 if (compare_functions) {
406 TF_RETURN_IF_ERROR(
407 CheckFunctionsEqual(this_node->op(), this_node->attr(), that,
408 that_node->op(), that_node->attr()));
409 }
410 } else {
411 if (this_node->op() != that_node->op()) {
412 return errors::FailedPrecondition(
413 "ops for nodes ", this_node->name(), " and ", that_node->name(),
414 " are different: ", this_node->op(), " != ", that_node->op());
415 }
416 }
417
418 for (const auto& attr : reg->op_def.attr()) {
419 const auto& attr_key = attr.name();
420 const bool this_has_attr = this_node->attr().contains(attr_key);
421 const bool that_has_attr = that_node->attr().contains(attr_key);
422 if (this_has_attr != that_has_attr) {
423 return errors::FailedPrecondition(
424 "attr with key ", attr_key, " is different for nodes ",
425 this_node->name(), " and ", that_node->name(),
426 ". Present in former: ", this_has_attr,
427 ". Present in latter: ", that_has_attr);
428 }
429 if (!this_has_attr) {
430 continue;
431 }
432 if (attr_key == kColocationAttrName ||
433 attr_key == kColocationGroupPrefix) {
434 continue;
435 }
436 const auto& this_attr = this_node->attr().at(attr_key);
437 const auto& that_attr = that_node->attr().at(attr_key);
438 TF_RETURN_IF_ERROR(CheckAttrsEqual(attr_key, this_attr, that, that_attr,
439 compare_functions));
440 }
441
442 if (this_node->device() != that_node->device()) {
443 return errors::FailedPrecondition(
444 "Devices are different for nodes ", this_node->name(), " and ",
445 that_node->name(), ": ", this_node->device(), " vs ",
446 that_node->device());
447 }
448 return OkStatus();
449 }
450
451 Status HashAttr(const std::string& attr_name, const AttrValue& attr_value,
452 bool hash_functions, uint64* hash) {
453 uint64 value_hash = 0;
454 if (attr_value.has_func()) {
455 if (hash_functions) {
456 TF_RETURN_IF_ERROR(HashFunction(attr_value.func(), &value_hash));
457 }
458 } else if (attr_value.has_list() && attr_value.list().func_size() > 0) {
459 if (hash_functions) {
460 for (auto& func : attr_value.list().func()) {
461 uint64 func_hash;
462 TF_RETURN_IF_ERROR(HashFunction(func, &func_hash));
463 value_hash = Hash64Combine(value_hash, func_hash);
464 }
465 }
466 } else {
467 value_hash = DeterministicProtoHash64(attr_value);
468 }
469 *hash = Hash64Combine(Hash64(attr_name), value_hash);
470 return OkStatus();
471 }
472
473 Status CheckAttrsEqual(const std::string& attr_name,
474 const AttrValue& this_attr, GraphHasher* that,
475 const AttrValue& that_attr, bool compare_functions) {
476 if (this_attr.has_func() != that_attr.has_func()) {
477 return errors::FailedPrecondition(
478 "AttrValues are of different types: ", this_attr.DebugString(),
479 " vs ", that_attr.DebugString());
480 }
481 if (this_attr.has_func()) {
482 if (compare_functions) {
483 TF_RETURN_IF_ERROR(
484 CheckFunctionsEqual(this_attr.func(), that, that_attr.func()));
485 }
486 return OkStatus();
487 }
488 if (this_attr.has_list() != that_attr.has_list()) {
489 return errors::FailedPrecondition(
490 "AttrValues are of different types: ", this_attr.DebugString(),
491 " vs ", that_attr.DebugString());
492 }
493 if (this_attr.has_list()) {
494 if (this_attr.list().func_size() != that_attr.list().func_size()) {
495 return errors::FailedPrecondition(
496 "AttrValues have func lists of different sizes: ",
497 this_attr.DebugString(), " vs ", that_attr.DebugString());
498 }
499 if (compare_functions) {
500 for (int i = 0; i < this_attr.list().func_size(); ++i) {
501 TF_RETURN_IF_ERROR(CheckFunctionsEqual(this_attr.list().func(i), that,
502 that_attr.list().func(i)));
503 }
504 }
505 return OkStatus();
506 }
507 uint64 this_hash, that_hash;
508 TF_RETURN_IF_ERROR(
509 HashAttr(attr_name, this_attr, /*hash_functions=*/true, &this_hash));
510 TF_RETURN_IF_ERROR(that->HashAttr(attr_name, that_attr,
511 /*hash_functions=*/true, &that_hash));
512 if (this_hash != that_hash) {
513 return errors::FailedPrecondition(
514 "AttrValues are different: ", this_attr.DebugString(), " vs ",
515 that_attr.DebugString());
516 }
517 return OkStatus();
518 }
519
520 Status HashFunction(const NameAttrList& func, uint64* hash) {
521 return HashFunction(func.name(), func.attr(), hash);
522 }
523
524 Status HashFunction(const std::string& name, const AttrValueMap& attrs,
525 uint64* hash) {
526 const FunctionDef* fdef = flib_->Find(name);
527 auto it = function_cache_->find(fdef);
528 if (it != function_cache_->end()) {
529 *hash = it->second;
530 return OkStatus();
531 }
532
533 // Convert to a GraphDef.
534 std::unique_ptr<FunctionBody> fbody;
535 TF_RETURN_IF_ERROR(
536 FunctionDefToBodyHelper(*fdef, AttrSlice(&attrs), flib_, &fbody));
537 GraphDef graph_def = fbody->graph->ToGraphDefDebug();
538
539 // For each return node, we create a new GraphHasher to compute a hash.
540 // We then combine these hashes to produce the hash ordered.
541 uint64 ret_nodes_hash = 0;
542 for (const auto& ret_node : fbody->ret_nodes) {
543 uint64 ret_node_hash = 0;
544 GraphHasher hasher(&graph_def, &ret_node->def(), flib_, node_cache_,
545 function_cache_, attr_cache_);
546 TF_RETURN_IF_ERROR(hasher.Init());
547 TF_RETURN_IF_ERROR(hasher.HashRoot(&ret_node_hash));
548 ret_nodes_hash = Hash64Combine(ret_nodes_hash, ret_node_hash);
549 }
550
551 std::vector<const NodeDef*> control_rets;
552 control_rets.reserve(fbody->control_ret_nodes.size());
553 for (const auto& control_ret_node : fbody->control_ret_nodes) {
554 control_rets.push_back(&control_ret_node->def());
555 }
556 uint64 control_ret_nodes_hash = 0;
557 TF_RETURN_IF_ERROR(
558 HashControlInputs(control_rets, &control_ret_nodes_hash));
559
560 *hash = Hash64Combine(ret_nodes_hash, control_ret_nodes_hash);
561 auto result = function_cache_->emplace(fdef, *hash);
562 if (!result.second) {
563 return errors::Internal(
564 absl::StrCat("Computed the hash for function ", name, " twice!"));
565 }
566 return OkStatus();
567 }
568
569 Status CheckFunctionsEqual(const NameAttrList& this_func, GraphHasher* that,
570 const NameAttrList& that_func) {
571 return CheckFunctionsEqual(this_func.name(), this_func.attr(), that,
572 that_func.name(), that_func.attr());
573 }
574 Status CheckFunctionsEqual(const std::string& this_name,
575 const AttrValueMap& this_attrs, GraphHasher* that,
576 const std::string& that_name,
577 const AttrValueMap& that_attrs) {
578 Status s = CheckFunctionsEqualHelper(this_name, this_attrs, that, that_name,
579 that_attrs);
580 if (!s.ok()) {
581 return errors::FailedPrecondition("Functions ", this_name, " and ",
582 that_name, " are not the same:\n", s);
583 }
584 return s;
585 }
586
587 Status CheckFunctionsEqualHelper(const std::string& this_name,
588 const AttrValueMap& this_attrs,
589 GraphHasher* that,
590 const std::string& that_name,
591 const AttrValueMap& that_attrs) {
592 const FunctionDef* this_fdef = flib_->Find(this_name);
593 const FunctionDef* that_fdef = that->flib_->Find(that_name);
594
595 // Convert to GraphDefs.
596 std::unique_ptr<FunctionBody> this_fbody;
597 TF_RETURN_IF_ERROR(FunctionDefToBodyHelper(
598 *this_fdef, AttrSlice(&this_attrs), flib_, &this_fbody));
599 GraphDef this_graph_def = this_fbody->graph->ToGraphDefDebug();
600 std::unique_ptr<FunctionBody> that_fbody;
601 TF_RETURN_IF_ERROR(FunctionDefToBodyHelper(
602 *that_fdef, AttrSlice(&that_attrs), that->flib_, &that_fbody));
603 GraphDef that_graph_def = that_fbody->graph->ToGraphDefDebug();
604
605 if (this_fbody->ret_nodes.size() != that_fbody->ret_nodes.size()) {
606 return errors::FailedPrecondition(
607 "Different numbers of ret nodes for functions ", this_name, " and ",
608 that_name, ": ", this_fbody->ret_nodes.size(), " vs ",
609 that_fbody->ret_nodes.size());
610 }
611 for (int i = 0; i < this_fbody->ret_nodes.size(); ++i) {
612 const NodeDef* this_root = &this_fbody->ret_nodes[i]->def();
613 const NodeDef* that_root = &that_fbody->ret_nodes[i]->def();
614 GraphHasher this_hasher(&this_graph_def, this_root, flib_, node_cache_,
615 function_cache_, attr_cache_);
616 TF_RETURN_IF_ERROR(this_hasher.Init());
617 GraphHasher that_hasher(&that_graph_def, that_root, that->flib_,
618 node_cache_, function_cache_, attr_cache_);
619 TF_RETURN_IF_ERROR(that_hasher.Init());
620 TF_RETURN_IF_ERROR(this_hasher.CheckEqual(&that_hasher));
621 }
622
623 std::vector<const NodeDef*> this_control_rets;
624 this_control_rets.reserve(this_fbody->control_ret_nodes.size());
625 for (const auto& control_ret_node : this_fbody->control_ret_nodes) {
626 this_control_rets.push_back(&control_ret_node->def());
627 }
628 std::vector<const NodeDef*> that_control_rets;
629 that_control_rets.reserve(that_fbody->control_ret_nodes.size());
630 for (const auto& control_ret_node : that_fbody->control_ret_nodes) {
631 that_control_rets.push_back(&control_ret_node->def());
632 }
633 TF_RETURN_IF_ERROR(
634 CheckControlInputsEqual(this_control_rets, that, that_control_rets));
635 return OkStatus();
636 }
637
638 Status HashControlInputs(const std::vector<const NodeDef*>& inputs,
639 uint64* hash) {
640 *hash = 0;
641 for (const NodeDef* input : inputs) {
642 uint64 node_hash = 0;
643 TF_RETURN_IF_ERROR(
644 HashNodeNonInput(input, /*hash_functions=*/false, &node_hash));
645 *hash = Hash64CombineUnordered(*hash, node_hash);
646 }
647 return OkStatus();
648 }
649
650 Status CheckControlInputsEqual(
651 const std::vector<const NodeDef*>& this_inputs, GraphHasher* that,
652 const std::vector<const NodeDef*>& that_inputs) {
653 absl::flat_hash_map<uint64, const NodeDef*> this_hashes;
654 for (const NodeDef* input : this_inputs) {
655 uint64 node_hash = 0;
656 TF_RETURN_IF_ERROR(
657 HashNodeNonInput(input, /*hash_functions=*/false, &node_hash));
658 this_hashes[node_hash] = input;
659 }
660 absl::flat_hash_map<uint64, const NodeDef*> that_hashes;
661 for (const NodeDef* input : that_inputs) {
662 uint64 node_hash = 0;
663 TF_RETURN_IF_ERROR(
664 HashNodeNonInput(input, /*hash_functions=*/false, &node_hash));
665 auto this_iter = this_hashes.find(node_hash);
666 if (this_iter != this_hashes.end()) {
667 this_hashes.erase(this_iter);
668 } else {
669 that_hashes[node_hash] = input;
670 }
671 }
672 if (!this_hashes.empty()) {
673 auto formatter = [](string* out,
674 const decltype(this_hashes)::value_type& item) {
675 out->append(item.second->name());
676 };
677 return errors::FailedPrecondition(
678 "Control dependencies are different. One node has dependencies [",
679 absl::StrJoin(this_hashes, ", ", formatter),
680 "], which don't match any of the other node's dependencies [",
681 absl::StrJoin(that_hashes, ", ", formatter), "]");
682 }
683 return OkStatus();
684 }
685
686 private:
687 bool is_cycle_forming_edge(const NodeDef* start, const NodeDef* end) {
688 EdgeRep edge(start, end);
689 return cycle_forming_edges_.contains(edge.GetHash());
690 }
691
692 struct NodeRep {
693 std::vector<const NodeDef*> node_control_inputs;
694 std::vector<std::pair<const NodeDef*, absl::string_view>> node_inputs;
695 };
696
697 struct EdgeRep {
698 const NodeDef* start_node;
699 const NodeDef* end_node;
700
701 EdgeRep(const NodeDef* start, const NodeDef* end)
702 : start_node(start), end_node(end) {}
703
704 uint64 GetHash() {
705 return Hash64Combine(absl::Hash<const NodeDef*>()(start_node),
706 absl::Hash<const NodeDef*>()(end_node));
707 }
708 };
709 const GraphDef* const graph_; // Not owned.
710 const NodeDef* const root_; // Not owned.
711 const FunctionLibraryDefinition* const flib_; // Not owned.
712 // Edges that need to be pruned as their presence will cause cycles.
713 absl::flat_hash_set<uint64> cycle_forming_edges_;
714 absl::flat_hash_map<const NodeDef*, NodeRep> nodes_;
715 std::shared_ptr<NodeCache> node_cache_;
716 std::shared_ptr<FunctionCache> function_cache_;
717 std::shared_ptr<AttrCache> attr_cache_;
718};
719
720} // anonymous namespace
721
722Status HashTensor(const Tensor& tensor, uint64* hash) {
723 const tstring* s = nullptr;
724 // Hash tensor type.
725 *hash = Hash64Combine(0, tensor.dtype());
726 // Hash tensor shape.
727 for (int i = 0; i < tensor.shape().dims(); ++i) {
728 *hash = Hash64Combine(*hash, tensor.shape().dim_size(i));
729 }
730 // Hash tensor data.
731 switch (tensor.dtype()) {
732 case DT_RESOURCE:
733 case DT_VARIANT:
734 return errors::Unimplemented("Hashing ", DataTypeString(tensor.dtype()),
735 " is not supported.");
736 case DT_STRING:
737 s = tensor.flat<tstring>().data();
738 for (int i = 0; i < tensor.NumElements(); ++i, ++s) {
739 *hash = Hash64Combine(*hash, Hash64(s->data(), s->size()));
740 }
741 break;
742 default:
743 *hash = Hash64(tensor.tensor_data().data(), tensor.tensor_data().size());
744 }
745 return OkStatus();
746}
747
748Status HashNode(const GraphDef& graph, const NodeDef& node, uint64* hash) {
749 const FunctionLibraryDefinition flib_def(OpRegistry::Global(),
750 graph.library());
751 return HashNode(graph, node, flib_def, hash);
752}
753
754Status HashNode(const GraphDef& graph, const NodeDef& node,
755 const FunctionLibraryDefinition& flib_def, uint64* hash) {
756 GraphHasher hasher(&graph, &node, &flib_def);
757 TF_RETURN_IF_ERROR(hasher.Init());
758 return hasher.HashRoot(hash);
759}
760
761Status HashGraph(const GraphDef& graph_def, uint64* hash) {
762 const NodeDef* sink = nullptr;
763 TF_RETURN_IF_ERROR(GetSink(graph_def, &sink));
764 return HashNode(graph_def, *sink, hash);
765}
766
767Status CheckGraphsEqual(const GraphDef& a, const GraphDef& b) {
768 const NodeDef* sink_a;
769 TF_RETURN_IF_ERROR(GetSink(a, &sink_a));
770 const NodeDef* sink_b;
771 TF_RETURN_IF_ERROR(GetSink(b, &sink_b));
772 return CheckSubgraphsEqual(a, sink_a, b, sink_b);
773}
774
775Status CheckSubgraphsEqual(const GraphDef& a, const NodeDef* node_a,
776 const GraphDef& b, const NodeDef* node_b) {
777 const FunctionLibraryDefinition flib_def_a(OpRegistry::Global(), a.library());
778 GraphHasher hasher_a(&a, node_a, &flib_def_a);
779 TF_RETURN_IF_ERROR(hasher_a.Init());
780
781 const FunctionLibraryDefinition flib_def_b(OpRegistry::Global(), b.library());
782 GraphHasher hasher_b(&b, node_b, &flib_def_b);
783 TF_RETURN_IF_ERROR(hasher_b.Init());
784
785 return hasher_a.CheckEqual(&hasher_b);
786}
787
788} // namespace data
789} // namespace tensorflow
790