1 | /* Copyright 2022 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | |
#include "tensorflow/cc/saved_model/fingerprinting.h"

#include <algorithm>
#include <string>
#include <utility>

#include "absl/container/btree_map.h"
#include "absl/strings/strip.h"
#include "tensorflow/cc/saved_model/constants.h"
#include "tensorflow/core/framework/function.pb.h"
#include "tensorflow/core/framework/op_def.pb.h"
#include "tensorflow/core/framework/versions.pb.h"
#include "tensorflow/core/graph/regularization/simple_delete.h"
#include "tensorflow/core/graph/regularization/util.h"
#include "tensorflow/core/lib/strings/proto_serialization.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/fingerprint.h"
#include "tensorflow/core/platform/path.h"
#include "tensorflow/core/platform/statusor.h"
#include "tensorflow/core/protobuf/fingerprint.pb.h"
#include "tensorflow/core/protobuf/meta_graph.pb.h"
#include "tensorflow/core/protobuf/saved_model.pb.h"
#include "tensorflow/core/protobuf/saved_object_graph.pb.h"
#include "tensorflow/core/util/tensor_bundle/naming.h"
39 | |
40 | namespace tensorflow::saved_model::fingerprinting { |
41 | |
// Version of the code that produced the fingerprint. CreateFingerprintDef()
// stores it in `FingerprintDef.version.producer`.
// NOTE(review): presumably this should be bumped whenever the hashing scheme
// below changes, so fingerprints from different schemes can be told apart.
const int kFingerprintProducer = 0;
44 | namespace { |
45 | |
46 | uint64 RegularizeAndHashSignatureDefs( |
47 | const google::protobuf::Map<std::string, SignatureDef>& signature_def_map) { |
48 | // Sort `signature_def_map`, which is an unordered map from string keys to |
49 | // SignatureDefs. |
50 | absl::btree_map<std::string, SignatureDef> sorted_signature_defs; |
51 | sorted_signature_defs.insert(signature_def_map.begin(), |
52 | signature_def_map.end()); |
53 | uint64 result_hash = 0; |
54 | for (const auto& item : sorted_signature_defs) { |
55 | std::string signature_def_string; |
56 | SerializeToStringDeterministic(item.second, &signature_def_string); |
57 | result_hash = FingerprintCat64( |
58 | result_hash, tensorflow::Fingerprint64(signature_def_string)); |
59 | } |
60 | return result_hash; |
61 | } |
62 | |
63 | // The SavedObjectGraph contains two parts: the list of nodes and the map of |
64 | // concrete functions. Regularization treats these two parts separately. |
65 | uint64 RegularizeAndHashSavedObjectGraph( |
66 | const SavedObjectGraph& object_graph_def) { |
67 | // Sort `concrete_functions`, which is an unordered map from function names to |
68 | // SavedConcreteFunction, using the suffix UID of the function name. Assumes |
69 | // that the trackable children are listed in a deterministic order during |
70 | // serialization. |
71 | absl::btree_map<int, std::string> uid_to_function_names; |
72 | for (const auto& [name, concrete_function] : |
73 | object_graph_def.concrete_functions()) { |
74 | StatusOr<int> uid = graph_regularization::GetSuffixUID(name); |
75 | // All valid function names should end in an UID. |
76 | if (uid.ok()) { |
77 | uid_to_function_names.insert({*uid, name}); |
78 | } else { |
79 | LOG(ERROR) << uid.status().error_message(); |
80 | } |
81 | } |
82 | uint64 result_hash = 0; |
83 | for (const auto& [uid, function_name] : uid_to_function_names) { |
84 | // Hash the function name (with the UID stripped). |
85 | result_hash = FingerprintCat64(result_hash, |
86 | tensorflow::Fingerprint64(absl::StripSuffix( |
87 | function_name, std::to_string(uid)))); |
88 | // Hash the serialized concrete function. |
89 | std::string concrete_function_string; |
90 | SerializeToStringDeterministic( |
91 | object_graph_def.concrete_functions().at(function_name), |
92 | &concrete_function_string); |
93 | result_hash = FingerprintCat64( |
94 | result_hash, tensorflow::Fingerprint64(concrete_function_string)); |
95 | } |
96 | // TODO(b/241294832): Complete canonicalization of `object_graph_def.nodes`. |
97 | return result_hash; |
98 | } |
99 | |
100 | // Returns the hash of the checkpoint .index file, 0 if there is none. |
101 | uint64 HashCheckpointIndexFile(absl::string_view model_dir) { |
102 | std::string meta_filename = MetaFilename(io::JoinPath( |
103 | model_dir, kSavedModelVariablesDirectory, kSavedModelVariablesFilename)); |
104 | std::string data; |
105 | Status read_status = ReadFileToString(Env::Default(), meta_filename, &data); |
106 | if (read_status.ok()) { |
107 | return tensorflow::Fingerprint64(data); |
108 | } else { |
109 | LOG(WARNING) << read_status.error_message(); |
110 | return 0; |
111 | } |
112 | } |
113 | |
114 | } // namespace |
115 | |
116 | FingerprintDef CreateFingerprintDef(const MetaGraphDef& metagraph, |
117 | absl::string_view export_dir) { |
118 | // Create a copy of `metagraph` which will be used and mutated for fingerprint |
119 | // computation. |
120 | MetaGraphDef metagraph_copy = metagraph; |
121 | FingerprintDef fingerprint_def; |
122 | // Set fingerprint field #1. |
123 | fingerprint_def.set_graph_def_checksum( |
124 | graph_regularization::ComputeHash(metagraph_copy.graph_def())); |
125 | // Set fingerprint field #2. |
126 | graph_regularization::SimpleDelete(*metagraph_copy.mutable_graph_def()); |
127 | fingerprint_def.set_graph_def_program_hash( |
128 | graph_regularization::ComputeHash(metagraph_copy.graph_def())); |
129 | // Set fingerprint field #3. |
130 | fingerprint_def.set_signature_def_hash( |
131 | RegularizeAndHashSignatureDefs(metagraph_copy.signature_def())); |
132 | // Set fingerprint field #4. |
133 | StatusOr<uint64> object_graph_hash = |
134 | RegularizeAndHashSavedObjectGraph(metagraph_copy.object_graph_def()); |
135 | fingerprint_def.set_saved_object_graph_hash( |
136 | RegularizeAndHashSavedObjectGraph(metagraph_copy.object_graph_def())); |
137 | // Set fingerprint field #5. |
138 | fingerprint_def.set_checkpoint_hash(HashCheckpointIndexFile(export_dir)); |
139 | // Set version of the fingerprint. |
140 | VersionDef* version = fingerprint_def.mutable_version(); |
141 | version->set_producer(kFingerprintProducer); |
142 | |
143 | return fingerprint_def; |
144 | } |
145 | |
146 | } // namespace tensorflow::saved_model::fingerprinting |
147 | |