1 | /* Copyright 2016 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | |
#include "tensorflow/cc/saved_model/bundle_v2.h"

#include <memory>
#include <string>
#include <utility>

#include "tensorflow/cc/saved_model/constants.h"
#include "tensorflow/cc/saved_model/metrics.h"
#include "tensorflow/cc/saved_model/reader.h"
#include "tensorflow/cc/saved_model/util.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/path.h"
#include "tensorflow/core/platform/strcat.h"
#include "tensorflow/core/protobuf/saved_model.pb.h"
#include "tensorflow/core/protobuf/trackable_object_graph.pb.h"
#include "tensorflow/core/util/tensor_bundle/byte_swap.h"
31 | |
32 | namespace tensorflow { |
33 | namespace { |
34 | |
35 | using error::Code::NOT_FOUND; |
36 | using strings::StrCat; |
37 | |
38 | // `tensorflow::SavedModelV2Bundle::Load` API label. |
39 | constexpr char kCCLoadBundleV2Label[] = "cc_load_bundle_v2" ; |
40 | |
41 | Status ReadSavedModelProto(const string& export_dir, |
42 | SavedModel* saved_model_proto) { |
43 | LOG(INFO) << "Reading SavedModel from: " << export_dir; |
44 | |
45 | const string saved_model_pb_path = |
46 | io::JoinPath(export_dir, kSavedModelFilenamePb); |
47 | Status found_pb = Env::Default()->FileExists(saved_model_pb_path); |
48 | if (found_pb.ok()) { |
49 | Status result = |
50 | ReadBinaryProto(Env::Default(), saved_model_pb_path, saved_model_proto); |
51 | if (result.ok()) { |
52 | metrics::SavedModelRead(saved_model::GetWriteVersion(*saved_model_proto)) |
53 | .IncrementBy(1); |
54 | } |
55 | return result; |
56 | } |
57 | |
58 | const string saved_model_pbtxt_path = |
59 | io::JoinPath(export_dir, kSavedModelFilenamePbTxt); |
60 | Status found_pbtxt = Env::Default()->FileExists(saved_model_pbtxt_path); |
61 | if (found_pbtxt.ok()) { |
62 | Status result = ReadTextProto(Env::Default(), saved_model_pbtxt_path, |
63 | saved_model_proto); |
64 | if (result.ok()) { |
65 | metrics::SavedModelRead(saved_model::GetWriteVersion(*saved_model_proto)) |
66 | .IncrementBy(1); |
67 | } |
68 | return result; |
69 | } |
70 | |
71 | Status err; |
72 | if (found_pb.code() == found_pbtxt.code()) { |
73 | err = Status(found_pb.code(), StrCat(found_pb.error_message(), "\n" , |
74 | found_pbtxt.error_message())); |
75 | } else if (found_pb.code() == NOT_FOUND) { |
76 | err = found_pbtxt; |
77 | } else if (found_pbtxt.code() == NOT_FOUND) { |
78 | err = found_pb; |
79 | } else { |
80 | // found_pb and found_pbtxt both errored, w/ different codes, neither being |
81 | // NOT_FOUND. |
82 | err = Status( |
83 | error::Code::INTERNAL, |
84 | StrCat("Different errors encountered while looking for saved_model.pb " |
85 | "and saved_model.pbtxt in the export directory path \"" , |
86 | export_dir, "\": \n" , found_pb.ToString(), "\n" , |
87 | found_pbtxt.ToString())); |
88 | } |
89 | |
90 | return err; |
91 | } |
92 | |
93 | Status ReadCheckpointObjectGraph(BundleReader* bundle_reader, |
94 | TrackableObjectGraph* object_graph) { |
95 | Tensor object_graph_tensor; |
96 | TF_RETURN_WITH_CONTEXT_IF_ERROR( |
97 | bundle_reader->Lookup(kObjectGraphProtoKey, &object_graph_tensor), |
98 | "SavedModel checkpoint does not contain object graph." ); |
99 | if (object_graph_tensor.dtype() != DT_STRING || |
100 | object_graph_tensor.dims() != 0 || |
101 | object_graph_tensor.NumElements() != 1) { |
102 | return Status( |
103 | error::Code::FAILED_PRECONDITION, |
104 | "SavedModel checkpoint object graph was not the correct type." ); |
105 | } |
106 | |
107 | const tstring* object_graph_string = reinterpret_cast<const tstring*>( |
108 | object_graph_tensor.tensor_data().data()); |
109 | if (!object_graph->ParseFromString(*object_graph_string)) { |
110 | return Status( |
111 | error::Code::FAILED_PRECONDITION, |
112 | "SavedModel checkpoint object graph could not be deserialized." ); |
113 | } |
114 | return OkStatus(); |
115 | } |
116 | |
117 | } // namespace |
118 | |
119 | Status SavedModelV2Bundle::Load(const std::string& export_dir, |
120 | SavedModelV2Bundle* const bundle) { |
121 | metrics::SavedModelReadApi(kCCLoadBundleV2Label).IncrementBy(1); |
122 | SavedModel saved_model_proto; |
123 | TF_RETURN_IF_ERROR(ReadSavedModelProto(export_dir, &saved_model_proto)); |
124 | |
125 | // Load MetaGraphDef. |
126 | // In version 2 SavedModels, there is only one MetaGraphDef. |
127 | if (saved_model_proto.meta_graphs_size() != 1) { |
128 | return Status( |
129 | error::Code::INVALID_ARGUMENT, |
130 | strings::StrCat( |
131 | "SavedModelV2 should have exactly one MetaGraphDef but actually " , |
132 | "contains " , saved_model_proto.meta_graphs_size())); |
133 | } |
134 | bundle->meta_graph_def_ = |
135 | std::move(*saved_model_proto.mutable_meta_graphs(0)); |
136 | |
137 | // Correct the endiness of Tensor content on big-endian system |
138 | if (!port::kLittleEndian) { |
139 | TF_RETURN_IF_ERROR(ByteSwapTensorContent(&(bundle->meta_graph_def_))); |
140 | } |
141 | |
142 | // Load GraphDebugInfo. |
143 | TF_RETURN_IF_ERROR( |
144 | ReadSavedModelDebugInfoIfPresent(export_dir, &bundle->debug_info_)); |
145 | |
146 | const std::string variables_dir = |
147 | io::JoinPath(export_dir, kSavedModelVariablesDirectory); |
148 | if (!Env::Default()->FileExists(variables_dir).ok()) { |
149 | LOG(INFO) |
150 | << "No checkpoint found, assuming this is a program-only SavedModel" ; |
151 | } else { |
152 | // Load the variables checkpoint reader. |
153 | const std::string variables_prefix = |
154 | io::JoinPath(variables_dir, kSavedModelVariablesFilename); |
155 | bundle->variable_reader_.reset( |
156 | new BundleReader(Env::Default(), variables_prefix)); |
157 | TF_RETURN_WITH_CONTEXT_IF_ERROR( |
158 | bundle->variable_reader_->status(), |
159 | "Unable to load SavedModel variables checkpoint from " , |
160 | variables_prefix); |
161 | |
162 | // Deserialize the object graph proto from the tensor bundle. |
163 | TF_RETURN_IF_ERROR(ReadCheckpointObjectGraph( |
164 | bundle->variable_reader_.get(), &bundle->trackable_object_graph_)); |
165 | } |
166 | return OkStatus(); |
167 | } |
168 | |
169 | Status SavedModelV2Bundle::VisitObjectsToRestore( |
170 | RestoreObjectsCallback callback) { |
171 | if (saved_object_graph().nodes_size() == 0 || |
172 | trackable_object_graph().nodes_size() == 0) { |
173 | return OkStatus(); |
174 | } |
175 | |
176 | // Start from root nodes of both the SavedObjectGraph and TrackableObjectGraph |
177 | // and descend to leaves. Note that the TrackableObjectGraph can have cycles |
178 | // (as can the SavedObjectGraph). |
179 | // This is detected and cycle edges are skipped. |
180 | const SavedObject* root_saved_object = &saved_object_graph().nodes(0); |
181 | const TrackableObjectGraph::TrackableObject* root_trackable_object = |
182 | &trackable_object_graph().nodes(0); |
183 | absl::flat_hash_set<int> trackable_node_ids; |
184 | return RecurseObjectsToRestore(root_saved_object, 0, root_trackable_object, |
185 | std::string(), &trackable_node_ids, |
186 | std::move(callback)); |
187 | } |
188 | |
189 | Status SavedModelV2Bundle::RecurseObjectsToRestore( |
190 | const SavedObject* saved_object, int saved_object_node_id, |
191 | const TrackableObjectGraph::TrackableObject* trackable_object, |
192 | std::string object_name, absl::flat_hash_set<int>* seen_trackable_node_ids, |
193 | RestoreObjectsCallback callback) { |
194 | // Callback if any attributes or slot variables. |
195 | // Note that the root is always excluded from the search (it can never |
196 | // be a restorable object). This matches some logic on the Python side. |
197 | if (saved_object_node_id != 0 && |
198 | (trackable_object->attributes_size() > 0 || |
199 | trackable_object->slot_variables_size() > 0)) { |
200 | TF_RETURN_WITH_CONTEXT_IF_ERROR( |
201 | callback(saved_object_node_id, *trackable_object), "Unable to restore " , |
202 | object_name); |
203 | } |
204 | |
205 | for (const auto& trackable_child_ref : trackable_object->children()) { |
206 | const auto& local_name = trackable_child_ref.local_name(); |
207 | |
208 | // Compute the full child name. |
209 | std::string child_name; |
210 | if (object_name.empty()) { |
211 | child_name = local_name; |
212 | } else { |
213 | child_name = strings::StrCat(object_name, "." , local_name); |
214 | } |
215 | |
216 | // Descend down the trackable graph. |
217 | int trackable_child_node_id = trackable_child_ref.node_id(); |
218 | if (!seen_trackable_node_ids->insert(trackable_child_node_id).second) { |
219 | // Cycle or duplicate detected - ignore this branch. |
220 | continue; |
221 | } |
222 | if (trackable_child_node_id < 0 || |
223 | trackable_child_node_id >= trackable_object_graph().nodes_size()) { |
224 | return errors::FailedPrecondition( |
225 | strings::StrCat("Illegal trackable child node id for " , child_name)); |
226 | } |
227 | const auto* trackable_child = |
228 | &trackable_object_graph().nodes(trackable_child_node_id); |
229 | |
230 | // Descend down the saved object graph. |
231 | int saved_child_node_id = -1; |
232 | const SavedObject* saved_child = nullptr; |
233 | for (const auto& saved_child_ref : saved_object->children()) { |
234 | if (saved_child_ref.local_name() == local_name) { |
235 | // Found. |
236 | saved_child_node_id = saved_child_ref.node_id(); |
237 | if (saved_child_node_id >= 0 && |
238 | saved_child_node_id < saved_object_graph().nodes_size()) { |
239 | saved_child = &saved_object_graph().nodes(saved_child_node_id); |
240 | } |
241 | break; |
242 | } |
243 | } |
244 | |
245 | if (!saved_child) { |
246 | return Status( |
247 | errors::Code::FAILED_PRECONDITION, |
248 | strings::StrCat("Could not find saved object to restore for " , |
249 | child_name)); |
250 | } |
251 | |
252 | TF_RETURN_IF_ERROR(RecurseObjectsToRestore( |
253 | saved_child, saved_child_node_id, trackable_child, child_name, |
254 | seen_trackable_node_ids, callback)); |
255 | } |
256 | return OkStatus(); |
257 | } |
258 | |
259 | } // namespace tensorflow |
260 | |