1/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2
3Licensed under the Apache License, Version 2.0 (the "License");
4you may not use this file except in compliance with the License.
5You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9Unless required by applicable law or agreed to in writing, software
10distributed under the License is distributed on an "AS IS" BASIS,
11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12See the License for the specific language governing permissions and
13limitations under the License.
14==============================================================================*/
15
#include "tensorflow/cc/saved_model/bundle_v2.h"

#include <memory>
#include <string>
#include <utility>

#include "tensorflow/cc/saved_model/constants.h"
#include "tensorflow/cc/saved_model/metrics.h"
#include "tensorflow/cc/saved_model/reader.h"
#include "tensorflow/cc/saved_model/util.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/path.h"
#include "tensorflow/core/platform/strcat.h"
#include "tensorflow/core/protobuf/saved_model.pb.h"
#include "tensorflow/core/protobuf/trackable_object_graph.pb.h"
#include "tensorflow/core/util/tensor_bundle/byte_swap.h"
31
32namespace tensorflow {
33namespace {
34
// File-local shorthands: NOT_FOUND and StrCat are used when assembling the
// error returned by ReadSavedModelProto.
using error::Code::NOT_FOUND;
using strings::StrCat;

// `tensorflow::SavedModelV2Bundle::Load` API label.
// Passed to metrics::SavedModelReadApi once per call to Load.
constexpr char kCCLoadBundleV2Label[] = "cc_load_bundle_v2";
40
41Status ReadSavedModelProto(const string& export_dir,
42 SavedModel* saved_model_proto) {
43 LOG(INFO) << "Reading SavedModel from: " << export_dir;
44
45 const string saved_model_pb_path =
46 io::JoinPath(export_dir, kSavedModelFilenamePb);
47 Status found_pb = Env::Default()->FileExists(saved_model_pb_path);
48 if (found_pb.ok()) {
49 Status result =
50 ReadBinaryProto(Env::Default(), saved_model_pb_path, saved_model_proto);
51 if (result.ok()) {
52 metrics::SavedModelRead(saved_model::GetWriteVersion(*saved_model_proto))
53 .IncrementBy(1);
54 }
55 return result;
56 }
57
58 const string saved_model_pbtxt_path =
59 io::JoinPath(export_dir, kSavedModelFilenamePbTxt);
60 Status found_pbtxt = Env::Default()->FileExists(saved_model_pbtxt_path);
61 if (found_pbtxt.ok()) {
62 Status result = ReadTextProto(Env::Default(), saved_model_pbtxt_path,
63 saved_model_proto);
64 if (result.ok()) {
65 metrics::SavedModelRead(saved_model::GetWriteVersion(*saved_model_proto))
66 .IncrementBy(1);
67 }
68 return result;
69 }
70
71 Status err;
72 if (found_pb.code() == found_pbtxt.code()) {
73 err = Status(found_pb.code(), StrCat(found_pb.error_message(), "\n",
74 found_pbtxt.error_message()));
75 } else if (found_pb.code() == NOT_FOUND) {
76 err = found_pbtxt;
77 } else if (found_pbtxt.code() == NOT_FOUND) {
78 err = found_pb;
79 } else {
80 // found_pb and found_pbtxt both errored, w/ different codes, neither being
81 // NOT_FOUND.
82 err = Status(
83 error::Code::INTERNAL,
84 StrCat("Different errors encountered while looking for saved_model.pb "
85 "and saved_model.pbtxt in the export directory path \"",
86 export_dir, "\": \n", found_pb.ToString(), "\n",
87 found_pbtxt.ToString()));
88 }
89
90 return err;
91}
92
93Status ReadCheckpointObjectGraph(BundleReader* bundle_reader,
94 TrackableObjectGraph* object_graph) {
95 Tensor object_graph_tensor;
96 TF_RETURN_WITH_CONTEXT_IF_ERROR(
97 bundle_reader->Lookup(kObjectGraphProtoKey, &object_graph_tensor),
98 "SavedModel checkpoint does not contain object graph.");
99 if (object_graph_tensor.dtype() != DT_STRING ||
100 object_graph_tensor.dims() != 0 ||
101 object_graph_tensor.NumElements() != 1) {
102 return Status(
103 error::Code::FAILED_PRECONDITION,
104 "SavedModel checkpoint object graph was not the correct type.");
105 }
106
107 const tstring* object_graph_string = reinterpret_cast<const tstring*>(
108 object_graph_tensor.tensor_data().data());
109 if (!object_graph->ParseFromString(*object_graph_string)) {
110 return Status(
111 error::Code::FAILED_PRECONDITION,
112 "SavedModel checkpoint object graph could not be deserialized.");
113 }
114 return OkStatus();
115}
116
117} // namespace
118
119Status SavedModelV2Bundle::Load(const std::string& export_dir,
120 SavedModelV2Bundle* const bundle) {
121 metrics::SavedModelReadApi(kCCLoadBundleV2Label).IncrementBy(1);
122 SavedModel saved_model_proto;
123 TF_RETURN_IF_ERROR(ReadSavedModelProto(export_dir, &saved_model_proto));
124
125 // Load MetaGraphDef.
126 // In version 2 SavedModels, there is only one MetaGraphDef.
127 if (saved_model_proto.meta_graphs_size() != 1) {
128 return Status(
129 error::Code::INVALID_ARGUMENT,
130 strings::StrCat(
131 "SavedModelV2 should have exactly one MetaGraphDef but actually ",
132 "contains ", saved_model_proto.meta_graphs_size()));
133 }
134 bundle->meta_graph_def_ =
135 std::move(*saved_model_proto.mutable_meta_graphs(0));
136
137 // Correct the endiness of Tensor content on big-endian system
138 if (!port::kLittleEndian) {
139 TF_RETURN_IF_ERROR(ByteSwapTensorContent(&(bundle->meta_graph_def_)));
140 }
141
142 // Load GraphDebugInfo.
143 TF_RETURN_IF_ERROR(
144 ReadSavedModelDebugInfoIfPresent(export_dir, &bundle->debug_info_));
145
146 const std::string variables_dir =
147 io::JoinPath(export_dir, kSavedModelVariablesDirectory);
148 if (!Env::Default()->FileExists(variables_dir).ok()) {
149 LOG(INFO)
150 << "No checkpoint found, assuming this is a program-only SavedModel";
151 } else {
152 // Load the variables checkpoint reader.
153 const std::string variables_prefix =
154 io::JoinPath(variables_dir, kSavedModelVariablesFilename);
155 bundle->variable_reader_.reset(
156 new BundleReader(Env::Default(), variables_prefix));
157 TF_RETURN_WITH_CONTEXT_IF_ERROR(
158 bundle->variable_reader_->status(),
159 "Unable to load SavedModel variables checkpoint from ",
160 variables_prefix);
161
162 // Deserialize the object graph proto from the tensor bundle.
163 TF_RETURN_IF_ERROR(ReadCheckpointObjectGraph(
164 bundle->variable_reader_.get(), &bundle->trackable_object_graph_));
165 }
166 return OkStatus();
167}
168
169Status SavedModelV2Bundle::VisitObjectsToRestore(
170 RestoreObjectsCallback callback) {
171 if (saved_object_graph().nodes_size() == 0 ||
172 trackable_object_graph().nodes_size() == 0) {
173 return OkStatus();
174 }
175
176 // Start from root nodes of both the SavedObjectGraph and TrackableObjectGraph
177 // and descend to leaves. Note that the TrackableObjectGraph can have cycles
178 // (as can the SavedObjectGraph).
179 // This is detected and cycle edges are skipped.
180 const SavedObject* root_saved_object = &saved_object_graph().nodes(0);
181 const TrackableObjectGraph::TrackableObject* root_trackable_object =
182 &trackable_object_graph().nodes(0);
183 absl::flat_hash_set<int> trackable_node_ids;
184 return RecurseObjectsToRestore(root_saved_object, 0, root_trackable_object,
185 std::string(), &trackable_node_ids,
186 std::move(callback));
187}
188
189Status SavedModelV2Bundle::RecurseObjectsToRestore(
190 const SavedObject* saved_object, int saved_object_node_id,
191 const TrackableObjectGraph::TrackableObject* trackable_object,
192 std::string object_name, absl::flat_hash_set<int>* seen_trackable_node_ids,
193 RestoreObjectsCallback callback) {
194 // Callback if any attributes or slot variables.
195 // Note that the root is always excluded from the search (it can never
196 // be a restorable object). This matches some logic on the Python side.
197 if (saved_object_node_id != 0 &&
198 (trackable_object->attributes_size() > 0 ||
199 trackable_object->slot_variables_size() > 0)) {
200 TF_RETURN_WITH_CONTEXT_IF_ERROR(
201 callback(saved_object_node_id, *trackable_object), "Unable to restore ",
202 object_name);
203 }
204
205 for (const auto& trackable_child_ref : trackable_object->children()) {
206 const auto& local_name = trackable_child_ref.local_name();
207
208 // Compute the full child name.
209 std::string child_name;
210 if (object_name.empty()) {
211 child_name = local_name;
212 } else {
213 child_name = strings::StrCat(object_name, ".", local_name);
214 }
215
216 // Descend down the trackable graph.
217 int trackable_child_node_id = trackable_child_ref.node_id();
218 if (!seen_trackable_node_ids->insert(trackable_child_node_id).second) {
219 // Cycle or duplicate detected - ignore this branch.
220 continue;
221 }
222 if (trackable_child_node_id < 0 ||
223 trackable_child_node_id >= trackable_object_graph().nodes_size()) {
224 return errors::FailedPrecondition(
225 strings::StrCat("Illegal trackable child node id for ", child_name));
226 }
227 const auto* trackable_child =
228 &trackable_object_graph().nodes(trackable_child_node_id);
229
230 // Descend down the saved object graph.
231 int saved_child_node_id = -1;
232 const SavedObject* saved_child = nullptr;
233 for (const auto& saved_child_ref : saved_object->children()) {
234 if (saved_child_ref.local_name() == local_name) {
235 // Found.
236 saved_child_node_id = saved_child_ref.node_id();
237 if (saved_child_node_id >= 0 &&
238 saved_child_node_id < saved_object_graph().nodes_size()) {
239 saved_child = &saved_object_graph().nodes(saved_child_node_id);
240 }
241 break;
242 }
243 }
244
245 if (!saved_child) {
246 return Status(
247 errors::Code::FAILED_PRECONDITION,
248 strings::StrCat("Could not find saved object to restore for ",
249 child_name));
250 }
251
252 TF_RETURN_IF_ERROR(RecurseObjectsToRestore(
253 saved_child, saved_child_node_id, trackable_child, child_name,
254 seen_trackable_node_ids, callback));
255 }
256 return OkStatus();
257}
258
259} // namespace tensorflow
260