/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/cc/saved_model/loader.h"

#include <unordered_set>

#include "tensorflow/cc/saved_model/constants.h"
#include "tensorflow/cc/saved_model/loader_util.h"
#include "tensorflow/cc/saved_model/metrics.h"
#include "tensorflow/cc/saved_model/reader.h"
#include "tensorflow/cc/saved_model/util.h"
#include "tensorflow/core/framework/attr_value.pb.h"
#include "tensorflow/core/framework/function.pb.h"
#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/framework/op_def.pb.h"
#include "tensorflow/core/framework/tensor.pb.h"
#include "tensorflow/core/lib/io/path.h"
#include "tensorflow/core/lib/monitoring/counter.h"
#include "tensorflow/core/lib/monitoring/sampler.h"
#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/platform/file_system_helper.h"
#include "tensorflow/core/platform/statusor.h"
#include "tensorflow/core/protobuf/graph_debug_info.pb.h"
#include "tensorflow/core/protobuf/meta_graph.pb.h"
#include "tensorflow/core/protobuf/saver.pb.h"
#include "tensorflow/core/public/session.h"
#include "tensorflow/core/public/session_options.h"
#include "tensorflow/core/util/tensor_bundle/naming.h"

namespace tensorflow {
namespace {

auto* load_attempt_count = monitoring::Counter<2>::New(
    "/tensorflow/cc/saved_model/load_attempt_count",
    "The number of times a SavedModel load was attempted, partitioned by "
    "success or failure status.",
    "model_path", "status");
auto* load_latency = monitoring::Counter<1>::New(
    "/tensorflow/cc/saved_model/load_latency",
    "Latency in microseconds for SavedModels that were successfully loaded.",
    "model_path");
auto* load_latency_by_stage = monitoring::Sampler<2>::New(
    {
        "/tensorflow/cc/saved_model/load_latency_by_stage",  // metric name
        "Distribution of wall time spent (in microseconds) in each stage "
        "(restore graph from disk, run init graph op, etc) when loading the "
        "model",
        "model_path",
        "stage",
    },
    // Scale of 10, power of 1.8 with bucket count 33 (~20 minutes).
    monitoring::Buckets::Exponential(10, 1.8, 33));

constexpr char kLoadAttemptFail[] = "fail";
constexpr char kLoadAttemptSuccess[] = "success";
// `tensorflow::LoadSavedModel` API label.
constexpr char kCCLoadLabel[] = "cc_load";

uint64 GetLatencyMicroseconds(const uint64 start_microseconds) {
  const uint64 end_microseconds = EnvTime::NowMicros();
  // Avoid clock skew.
  if (end_microseconds < start_microseconds) return 0;
  return end_microseconds - start_microseconds;
}

// Ensure that constant tensors loaded from the saved model have valid shape.
// Also ensure that constant nodes have a value assigned to them.
// TODO(b/154763635): this is temporary and will be replaced with a better
// audit.
static Status ValidateNode(const NodeDef& node) {
  const auto node_iterator = node.attr().find("value");
  if (node_iterator != node.attr().end()) {
    AttrValue node_value = node_iterator->second;
    if (node_value.has_tensor()) {
      const PartialTensorShape node_shape(node_value.tensor().tensor_shape());
      // A negative element count means the shape is not fully defined, which
      // is invalid for a constant tensor.
      if (node_shape.num_elements() < 0) {
        return errors::FailedPrecondition(
            "Saved model contains node \"", node.name(), "\" (op \"",
            node.op(), "\") which initializes from a tensor with ",
            node_shape.num_elements(), " elements");
      }
    }
  } else if (node.op() == "Const") {
    return errors::FailedPrecondition(
        "Saved model contains node \"", node.name(),
        "\" which is a constant tensor but no value has been provided");
  }
  return OkStatus();
}

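// Checks that `function` does not call itself directly. Note that this only
// catches direct self-recursion; mutual recursion between different library
// functions is not detected by this check.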
static Status ValidateFunctionNotRecursive(const FunctionDef& function) {
  const auto& function_name = function.signature().name();
  for (const auto& node : function.node_def()) {
    if (node.op() == function_name) {
      return errors::FailedPrecondition(
          "Function ", function_name,
          " is self-recursive, and TensorFlow does not support this "
          "scenario.");
    }
  }

  return OkStatus();
}

static Status ValidateSavedTensors(const GraphDef& graph_def) {
  for (const auto& node : graph_def.node()) {
    TF_RETURN_IF_ERROR(ValidateNode(node));
  }

  if (graph_def.has_library()) {
    const FunctionDefLibrary& library = graph_def.library();
    for (const auto& function : library.function()) {
      for (const auto& node : function.node_def()) {
        TF_RETURN_IF_ERROR(ValidateNode(node));
      }

      // Also check that there is no recursion in the library.
      TF_RETURN_IF_ERROR(ValidateFunctionNotRecursive(function));
    }
  }

  return OkStatus();
}

Tensor CreateStringTensor(const string& value) {
  Tensor tensor(DT_STRING, TensorShape({}));
  tensor.scalar<tstring>()() = value;
  return tensor;
}

void AddAssetsTensorsToInputs(const StringPiece export_dir,
                              const std::vector<AssetFileDef>& asset_file_defs,
                              std::vector<std::pair<string, Tensor>>* inputs) {
  if (asset_file_defs.empty()) {
    return;
  }
  for (const auto& asset_file_def : asset_file_defs) {
    Tensor assets_file_path_tensor = CreateStringTensor(io::JoinPath(
        export_dir, kSavedModelAssetsDirectory, asset_file_def.filename()));
    inputs->push_back(
        {asset_file_def.tensor_info().name(), assets_file_path_tensor});
  }
}

// Like Session::Run(), but uses the Make/Run/ReleaseCallable() API to avoid
// leaving behind non-GC'ed state.
//
// Detailed motivation behind this approach, from ashankar@:
//
// Each call to Session::Run() that identifies a new subgraph (based on feeds
// and fetches) creates some data structures that live as long as the session
// (the partitioned graph, associated executors etc.).
//
// A pathological case of this would be if, say, the initialization op
// (main_op/legacy_init_op) involves the use of a large constant. Then we
// allocate memory for that large constant that will just stick around until
// the session dies. With this Callable mechanism, that memory will be
// released right after ReleaseCallable returns.
//
// However, the resource manager state remains.
Status RunOnce(const RunOptions& run_options,
               const std::vector<std::pair<string, Tensor>>& inputs,
               const std::vector<string>& output_tensor_names,
               const std::vector<string>& target_node_names,
               std::vector<Tensor>* outputs, RunMetadata* run_metadata,
               Session* session) {
  CallableOptions callable_options;
  std::vector<Tensor> feed_tensors;
  *callable_options.mutable_run_options() = run_options;
  for (const auto& input : inputs) {
    const string& name = input.first;
    const Tensor& tensor = input.second;
    callable_options.add_feed(name);
    feed_tensors.push_back(tensor);
  }
  for (const string& output_tensor_name : output_tensor_names) {
    callable_options.add_fetch(output_tensor_name);
  }
  for (const string& target_node_name : target_node_names) {
    callable_options.add_target(target_node_name);
  }

  Session::CallableHandle callable_handle;
  TF_RETURN_IF_ERROR(
      session->MakeCallable(callable_options, &callable_handle));
  const Status run_status = session->RunCallable(callable_handle, feed_tensors,
                                                 outputs, run_metadata);
  // Be sure to call ReleaseCallable() regardless of the outcome of
  // RunCallable().
  session->ReleaseCallable(callable_handle).IgnoreError();
  return run_status;
}

// RunInitOp will return OK if the initialization op was run successfully.
// An empty init_op_name indicates that there are no init ops to run.
Status RunInitOp(const RunOptions& run_options, const string& export_dir,
                 const MetaGraphDef& meta_graph_def,
                 const std::vector<AssetFileDef>& asset_file_defs,
                 Session* session, const string& init_op_name) {
  if (!init_op_name.empty()) {
    LOG(INFO) << "Running initialization op on SavedModel bundle at path: "
              << export_dir;
    std::vector<std::pair<string, Tensor>> inputs;
    AddAssetsTensorsToInputs(export_dir, asset_file_defs, &inputs);
    RunMetadata run_metadata;
    return RunOnce(run_options, inputs, {}, {init_op_name},
                   nullptr /* outputs */, &run_metadata, session);
  }
  return OkStatus();
}

Status RunRestore(const RunOptions& run_options, const string& export_dir,
                  const StringPiece restore_op_name,
                  const StringPiece variable_filename_const_op_name,
                  const std::vector<AssetFileDef>& asset_file_defs,
                  Session* session) {
  LOG(INFO) << "Restoring SavedModel bundle.";
  // Find path to variables to be restored in export directory.
  const string variables_directory =
      io::JoinPath(export_dir, kSavedModelVariablesDirectory);
  // Check for saver checkpoints in v2 format. Models exported in the
  // checkpoint v2 format will have a variables.index file. The corresponding
  // variables are stored in the variables.data-?????-of-????? files.
  const string variables_index_path = io::JoinPath(
      variables_directory, MetaFilename(kSavedModelVariablesFilename));
  TF_ASSIGN_OR_RETURN(
      bool variables_index_exists,
      internal::FileExists(Env::Default(), variables_index_path));
  if (!variables_index_exists) {
    LOG(INFO) << "The specified SavedModel has no variables; no checkpoints "
                 "were restored. File does not exist: "
              << variables_index_path;
    return OkStatus();
  }
  const string variables_path =
      io::JoinPath(variables_directory, kSavedModelVariablesFilename);

  // Add variables to the graph.
  Tensor variables_path_tensor(DT_STRING, TensorShape({}));
  variables_path_tensor.scalar<tstring>()() = variables_path;

  std::vector<std::pair<string, Tensor>> inputs = {
      {string(variable_filename_const_op_name), variables_path_tensor}};

  AddAssetsTensorsToInputs(export_dir, asset_file_defs, &inputs);

  RunMetadata run_metadata;
  return RunOnce(run_options, inputs, {}, {string(restore_op_name)},
                 nullptr /* outputs */, &run_metadata, session);
}

}  // namespace

SavedModelBundleInterface::~SavedModelBundleInterface() {}

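// Creates a new session using `session_options`, validates the constant
// tensors in `meta_graph`'s GraphDef, and installs that GraphDef into the
// session.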
Status LoadMetagraphIntoSession(const SessionOptions& session_options,
                                const MetaGraphDef& meta_graph,
                                std::unique_ptr<Session>* session) {
  Session* session_p = nullptr;
  TF_RETURN_IF_ERROR(NewSession(session_options, &session_p));
  session->reset(session_p);
  TF_RETURN_IF_ERROR(ValidateSavedTensors(meta_graph.graph_def()));
  return (*session)->Create(meta_graph.graph_def());
}

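// Implements the load pipeline for a SavedModelBundle: read the MetaGraphDef
// (and debug info, if present) from `export_dir`, create a session containing
// its graph, then restore variables and run the init op via RestoreSession().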
Status LoadSavedModelInternal(const SessionOptions& session_options,
                              const RunOptions& run_options,
                              const string& export_dir,
                              const std::unordered_set<string>& tags,
                              SavedModelBundle* const bundle) {
  TF_RETURN_IF_ERROR(ReadMetaGraphDefFromSavedModel(export_dir, tags,
                                                    &bundle->meta_graph_def));
  TF_RETURN_IF_ERROR(
      ReadSavedModelDebugInfoIfPresent(export_dir, &bundle->debug_info));
  TF_RETURN_IF_ERROR(LoadMetagraphIntoSession(
      session_options, bundle->meta_graph_def, &bundle->session));
  TF_RETURN_IF_ERROR(RestoreSession(run_options, bundle->meta_graph_def,
                                    export_dir, &bundle->session));
  return OkStatus();
}

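// Example use, as a minimal sketch (the path and the "serve" tag below are
// illustrative, not part of this API):
//
//   SavedModelBundle bundle;
//   Status load_status =
//       LoadSavedModel(SessionOptions(), RunOptions(), "/tmp/model",
//                      {"serve"}, &bundle);
//   if (load_status.ok()) {
//     // bundle.session can now be used with Session::Run().
//   }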
Status LoadSavedModel(const SessionOptions& session_options,
                      const RunOptions& run_options, const string& export_dir,
                      const std::unordered_set<string>& tags,
                      SavedModelBundle* const bundle) {
  metrics::SavedModelReadApi(kCCLoadLabel).IncrementBy(1);

  // TODO(robson): Add tests for the counters.
  const uint64 start_microseconds = Env::Default()->NowMicros();
  const Status status = LoadSavedModelInternal(session_options, run_options,
                                               export_dir, tags, bundle);
  auto log_and_count = [&](const string& status_str) {
    LOG(INFO) << "SavedModel load for tags { " << absl::StrJoin(tags, " ")
              << " }; Status: " << status_str << ": " << status << ". Took "
              << GetLatencyMicroseconds(start_microseconds)
              << " microseconds.";
    load_attempt_count->GetCell(export_dir, status_str)->IncrementBy(1);
  };
  if (status.ok()) {
    log_and_count(kLoadAttemptSuccess);
  } else {
    log_and_count(kLoadAttemptFail);
  }
  load_latency->GetCell(export_dir)
      ->IncrementBy(GetLatencyMicroseconds(start_microseconds));
  return status;
}

namespace {
// Session wrapper that prevents calls to Session::Create(), Session::Extend(),
// and the deprecated partial-run methods.
//
// Limiting the available methods on a returned Session gives us the option
// to replace the Session with a cut-down implementation, without breaking any
// users.
class LiteSessionWrapper : public Session {
 public:
  explicit LiteSessionWrapper(std::unique_ptr<Session> wrapped)
      : wrapped_(std::move(wrapped)) {}

  Status Create(const GraphDef& graph) override {
    return errors::Unimplemented("Session::Create()");
  }
  Status Create(GraphDef&& graph) override {
    return errors::Unimplemented("Session::Create()");
  }

  Status Extend(const GraphDef& graph) override {
    return errors::Unimplemented("Session::Extend()");
  }
  Status Extend(GraphDef&& graph) override {
    return errors::Unimplemented("Session::Extend()");
  }

  Status Run(const std::vector<std::pair<string, Tensor>>& inputs,
             const std::vector<string>& output_tensor_names,
             const std::vector<string>& target_node_names,
             std::vector<Tensor>* outputs) override {
    return wrapped_->Run(inputs, output_tensor_names, target_node_names,
                         outputs);
  }

  Status Create(const RunOptions& run_options,
                const GraphDef& graph) override {
    return errors::Unimplemented("Session::Create()");
  }
  Status Extend(const RunOptions& run_options,
                const GraphDef& graph) override {
    return errors::Unimplemented("Session::Extend()");
  }
  Status Create(const RunOptions& run_options, GraphDef&& graph) override {
    return errors::Unimplemented("Session::Create()");
  }
  Status Extend(const RunOptions& run_options, GraphDef&& graph) override {
    return errors::Unimplemented("Session::Extend()");
  }
  Status Close(const RunOptions& run_options) override {
    return wrapped_->Close(run_options);
  }

  Status Run(const RunOptions& run_options,
             const std::vector<std::pair<string, Tensor>>& inputs,
             const std::vector<string>& output_tensor_names,
             const std::vector<string>& target_node_names,
             std::vector<Tensor>* outputs,
             RunMetadata* run_metadata) override {
    return wrapped_->Run(run_options, inputs, output_tensor_names,
                         target_node_names, outputs, run_metadata);
  }

  Status PRunSetup(const std::vector<string>& input_names,
                   const std::vector<string>& output_names,
                   const std::vector<string>& target_nodes,
                   string* handle) override {
    return errors::Unimplemented("Session::PRunSetup()");
  }

  Status PRun(const string& handle,
              const std::vector<std::pair<string, Tensor>>& inputs,
              const std::vector<string>& output_names,
              std::vector<Tensor>* outputs) override {
    return errors::Unimplemented("Session::PRun()");
  }

  Status ListDevices(std::vector<DeviceAttributes>* response) override {
    return wrapped_->ListDevices(response);
  }

  Status Close() override { return wrapped_->Close(); }

  Status MakeCallable(const CallableOptions& callable_options,
                      CallableHandle* out_handle) override {
    return wrapped_->MakeCallable(callable_options, out_handle);
  }

  Status RunCallable(CallableHandle handle,
                     const std::vector<Tensor>& feed_tensors,
                     std::vector<Tensor>* fetch_tensors,
                     RunMetadata* run_metadata) override {
    return wrapped_->RunCallable(handle, feed_tensors, fetch_tensors,
                                 run_metadata);
  }

  Status RunCallable(
      CallableHandle handle, const std::vector<Tensor>& feed_tensors,
      std::vector<Tensor>* fetch_tensors, RunMetadata* run_metadata,
      const thread::ThreadPoolOptions& threadpool_options) override {
    return wrapped_->RunCallable(handle, feed_tensors, fetch_tensors,
                                 run_metadata, threadpool_options);
  }

  Status ReleaseCallable(CallableHandle handle) override {
    return wrapped_->ReleaseCallable(handle);
  }

 private:
  const std::unique_ptr<Session> wrapped_;
};
}  // namespace

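// Restores variables into `*session` (when `meta_graph` carries a SaverDef)
// and then runs the init op recorded in `meta_graph`, recording the wall time
// of each stage in the load_latency_by_stage metric.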
Status RestoreSession(const RunOptions& run_options,
                      const MetaGraphDef& meta_graph, const string& export_dir,
                      std::unique_ptr<Session>* session) {
  const uint64 read_start_microseconds = Env::Default()->NowMicros();
  std::vector<AssetFileDef> asset_file_defs;
  TF_RETURN_IF_ERROR(internal::GetAssetFileDefs(meta_graph, &asset_file_defs));
  if (meta_graph.has_saver_def()) {
    TF_RETURN_IF_ERROR(RunRestore(run_options, export_dir,
                                  meta_graph.saver_def().restore_op_name(),
                                  meta_graph.saver_def().filename_tensor_name(),
                                  asset_file_defs, session->get()));
  }
  // Record wall time spent restoring the graph from disk, but postpone metric
  // increments until graph init finishes.
  const uint64 restore_graph_walltime =
      GetLatencyMicroseconds(read_start_microseconds);

  const uint64 graph_init_start_microseconds = Env::Default()->NowMicros();
  string init_op_name;
  TF_RETURN_IF_ERROR(
      internal::GetInitOp(export_dir, meta_graph, &init_op_name));
  TF_RETURN_IF_ERROR(RunInitOp(run_options, export_dir, meta_graph,
                               asset_file_defs, session->get(), init_op_name));
  load_latency_by_stage->GetCell(export_dir, "restore_graph")
      ->Add(restore_graph_walltime);
  // Record wall time spent in the init op.
  load_latency_by_stage->GetCell(export_dir, "init_graph")
      ->Add(GetLatencyMicroseconds(graph_init_start_microseconds));
  return OkStatus();
}

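// Example use, as a minimal sketch (the path and the "serve" tag below are
// illustrative, and an enclosing function returning Status is assumed for
// TF_RETURN_IF_ERROR):
//
//   SavedModelBundleLite bundle;
//   TF_RETURN_IF_ERROR(LoadSavedModel(SessionOptions(), RunOptions(),
//                                     "/tmp/model", {"serve"}, &bundle));
//   // bundle.GetSession() returns a session that rejects Create()/Extend().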
Status LoadSavedModel(const SessionOptions& session_options,
                      const RunOptions& run_options, const string& export_dir,
                      const std::unordered_set<string>& tags,
                      SavedModelBundleLite* const bundle) {
  SavedModelBundle legacy_bundle;
  SessionOptions rewritten_options(session_options);
  // We disallow calls to Session::Extend() on the returned session, so we can
  // reduce memory consumption by not storing the original GraphDef.
  rewritten_options.config.mutable_experimental()
      ->set_optimize_for_static_graph(true);
  // Disallowing the `RunOptions.output_partition_graphs` option (typically
  // used in debugging and tests) allows us to reduce memory consumption
  // further by not storing the rewritten subgraph for each signature.
  rewritten_options.config.mutable_experimental()
      ->set_disable_output_partition_graphs(true);
  // TODO(mrry): Consider specializing the session creation to reduce peak
  // RAM consumption by using `Session::Create(GraphDef&&)`.
  TF_RETURN_IF_ERROR(LoadSavedModel(rewritten_options, run_options, export_dir,
                                    tags, &legacy_bundle));
  *bundle = SavedModelBundleLite(
      absl::make_unique<LiteSessionWrapper>(std::move(legacy_bundle.session)),
      std::move(*legacy_bundle.meta_graph_def.mutable_signature_def()));
  return OkStatus();
}

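// Heuristic check only: returns true when `export_dir` contains a
// saved_model.pb or saved_model.pbtxt file. It does not validate that the
// file parses or that any particular tag set is present.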
bool MaybeSavedModelDirectory(const string& export_dir) {
  const string saved_model_pb_path =
      io::JoinPath(export_dir, kSavedModelFilenamePb);
  const string saved_model_pbtxt_path =
      io::JoinPath(export_dir, kSavedModelFilenamePbTxt);
  return Env::Default()->FileExists(saved_model_pb_path).ok() ||
         Env::Default()->FileExists(saved_model_pbtxt_path).ok();
}

}  // namespace tensorflow