1/* Copyright 2022 The TensorFlow Authors. All Rights Reserved.
2
3Licensed under the Apache License, Version 2.0 (the "License");
4you may not use this file except in compliance with the License.
5You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9Unless required by applicable law or agreed to in writing, software
10distributed under the License is distributed on an "AS IS" BASIS,
11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12See the License for the specific language governing permissions and
13limitations under the License.
14==============================================================================*/
15
#include "tensorflow/cc/saved_model/fingerprinting.h"

#include <algorithm>
#include <string>
#include <utility>

#include "absl/container/btree_map.h"
#include "absl/strings/strip.h"
#include "tensorflow/cc/saved_model/constants.h"
#include "tensorflow/core/framework/function.pb.h"
#include "tensorflow/core/framework/op_def.pb.h"
#include "tensorflow/core/framework/versions.pb.h"
#include "tensorflow/core/graph/regularization/simple_delete.h"
#include "tensorflow/core/graph/regularization/util.h"
#include "tensorflow/core/lib/strings/proto_serialization.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/fingerprint.h"
#include "tensorflow/core/platform/path.h"
#include "tensorflow/core/platform/statusor.h"
#include "tensorflow/core/protobuf/fingerprint.pb.h"
#include "tensorflow/core/protobuf/meta_graph.pb.h"
#include "tensorflow/core/protobuf/saved_model.pb.h"
#include "tensorflow/core/protobuf/saved_object_graph.pb.h"
#include "tensorflow/core/util/tensor_bundle/naming.h"
39
40namespace tensorflow::saved_model::fingerprinting {
41
// Producer version written into the `version` field of every FingerprintDef
// built in this file; it identifies which revision of the fingerprinting code
// computed the hashes.
const int kFingerprintProducer{0};
44namespace {
45
46uint64 RegularizeAndHashSignatureDefs(
47 const google::protobuf::Map<std::string, SignatureDef>& signature_def_map) {
48 // Sort `signature_def_map`, which is an unordered map from string keys to
49 // SignatureDefs.
50 absl::btree_map<std::string, SignatureDef> sorted_signature_defs;
51 sorted_signature_defs.insert(signature_def_map.begin(),
52 signature_def_map.end());
53 uint64 result_hash = 0;
54 for (const auto& item : sorted_signature_defs) {
55 std::string signature_def_string;
56 SerializeToStringDeterministic(item.second, &signature_def_string);
57 result_hash = FingerprintCat64(
58 result_hash, tensorflow::Fingerprint64(signature_def_string));
59 }
60 return result_hash;
61}
62
63// The SavedObjectGraph contains two parts: the list of nodes and the map of
64// concrete functions. Regularization treats these two parts separately.
65uint64 RegularizeAndHashSavedObjectGraph(
66 const SavedObjectGraph& object_graph_def) {
67 // Sort `concrete_functions`, which is an unordered map from function names to
68 // SavedConcreteFunction, using the suffix UID of the function name. Assumes
69 // that the trackable children are listed in a deterministic order during
70 // serialization.
71 absl::btree_map<int, std::string> uid_to_function_names;
72 for (const auto& [name, concrete_function] :
73 object_graph_def.concrete_functions()) {
74 StatusOr<int> uid = graph_regularization::GetSuffixUID(name);
75 // All valid function names should end in an UID.
76 if (uid.ok()) {
77 uid_to_function_names.insert({*uid, name});
78 } else {
79 LOG(ERROR) << uid.status().error_message();
80 }
81 }
82 uint64 result_hash = 0;
83 for (const auto& [uid, function_name] : uid_to_function_names) {
84 // Hash the function name (with the UID stripped).
85 result_hash = FingerprintCat64(result_hash,
86 tensorflow::Fingerprint64(absl::StripSuffix(
87 function_name, std::to_string(uid))));
88 // Hash the serialized concrete function.
89 std::string concrete_function_string;
90 SerializeToStringDeterministic(
91 object_graph_def.concrete_functions().at(function_name),
92 &concrete_function_string);
93 result_hash = FingerprintCat64(
94 result_hash, tensorflow::Fingerprint64(concrete_function_string));
95 }
96 // TODO(b/241294832): Complete canonicalization of `object_graph_def.nodes`.
97 return result_hash;
98}
99
100// Returns the hash of the checkpoint .index file, 0 if there is none.
101uint64 HashCheckpointIndexFile(absl::string_view model_dir) {
102 std::string meta_filename = MetaFilename(io::JoinPath(
103 model_dir, kSavedModelVariablesDirectory, kSavedModelVariablesFilename));
104 std::string data;
105 Status read_status = ReadFileToString(Env::Default(), meta_filename, &data);
106 if (read_status.ok()) {
107 return tensorflow::Fingerprint64(data);
108 } else {
109 LOG(WARNING) << read_status.error_message();
110 return 0;
111 }
112}
113
114} // namespace
115
116FingerprintDef CreateFingerprintDef(const MetaGraphDef& metagraph,
117 absl::string_view export_dir) {
118 // Create a copy of `metagraph` which will be used and mutated for fingerprint
119 // computation.
120 MetaGraphDef metagraph_copy = metagraph;
121 FingerprintDef fingerprint_def;
122 // Set fingerprint field #1.
123 fingerprint_def.set_graph_def_checksum(
124 graph_regularization::ComputeHash(metagraph_copy.graph_def()));
125 // Set fingerprint field #2.
126 graph_regularization::SimpleDelete(*metagraph_copy.mutable_graph_def());
127 fingerprint_def.set_graph_def_program_hash(
128 graph_regularization::ComputeHash(metagraph_copy.graph_def()));
129 // Set fingerprint field #3.
130 fingerprint_def.set_signature_def_hash(
131 RegularizeAndHashSignatureDefs(metagraph_copy.signature_def()));
132 // Set fingerprint field #4.
133 StatusOr<uint64> object_graph_hash =
134 RegularizeAndHashSavedObjectGraph(metagraph_copy.object_graph_def());
135 fingerprint_def.set_saved_object_graph_hash(
136 RegularizeAndHashSavedObjectGraph(metagraph_copy.object_graph_def()));
137 // Set fingerprint field #5.
138 fingerprint_def.set_checkpoint_hash(HashCheckpointIndexFile(export_dir));
139 // Set version of the fingerprint.
140 VersionDef* version = fingerprint_def.mutable_version();
141 version->set_producer(kFingerprintProducer);
142
143 return fingerprint_def;
144}
145
146} // namespace tensorflow::saved_model::fingerprinting
147