/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
15
// Helpers for loading the persistent representation of a SavedModelV2.
// Note: this header is depended on by code that does not use the full
// runtime, so its own dependencies should be kept restricted.
19
20#ifndef TENSORFLOW_CC_SAVED_MODEL_BUNDLE_V2_H_
21#define TENSORFLOW_CC_SAVED_MODEL_BUNDLE_V2_H_
22
23#include <functional>
24#include <memory>
25#include <string>
26
27#include "absl/container/flat_hash_set.h"
28#include "tensorflow/core/lib/core/status.h"
29#include "tensorflow/core/protobuf/graph_debug_info.pb.h"
30#include "tensorflow/core/protobuf/meta_graph.pb.h"
31#include "tensorflow/core/protobuf/saved_object_graph.pb.h"
32#include "tensorflow/core/protobuf/trackable_object_graph.pb.h"
33#include "tensorflow/core/util/tensor_bundle/tensor_bundle.h"
34
35namespace tensorflow {
36
37/// Represents a version 2 SavedModel that is loaded from storage (but not yet
38/// loaded into an executable in-memory representation).
39class SavedModelV2Bundle {
40 public:
41 using RestoreObjectsCallback =
42 std::function<Status(int, const TrackableObjectGraph::TrackableObject&)>;
43
44 /// Loads persistent representations for a SavedModelV2 from the specified
45 /// export directory.
46 static Status Load(const std::string& export_dir, SavedModelV2Bundle* bundle);
47
48 /// MetaGraphDef from the loaded SavedModel.
49 MetaGraphDef& meta_graph_def() { return meta_graph_def_; }
50
51 /// SavedObjectGraph from the MetaGraphDef.
52 const SavedObjectGraph& saved_object_graph() {
53 return meta_graph_def().object_graph_def();
54 }
55
56 /// TrackableObjectGraph loaded from the variable_reader() checkpoint.
57 TrackableObjectGraph& trackable_object_graph() {
58 return trackable_object_graph_;
59 }
60
61 /// BundleReader for accessing the variables bundle.
62 BundleReader* variable_reader() { return variable_reader_.get(); }
63
64 /// The GraphDebugInfo (or nullptr if none).
65 GraphDebugInfo* debug_info() { return debug_info_.get(); }
66
67 /// Restores objects, invoking the callback with the node id in the
68 /// saved_object_graph() and the corresponding TrackableObject from the
69 /// trackable_object_graph(). The callback may use the variable_reader() but
70 /// must not modify the underlying saved_object_graph().
71 Status VisitObjectsToRestore(RestoreObjectsCallback callback);
72
73 private:
74 Status RecurseObjectsToRestore(
75 const SavedObject* saved_object, int saved_object_node_id,
76 const TrackableObjectGraph::TrackableObject* trackable_object,
77 std::string object_name,
78 absl::flat_hash_set<int>* seen_trackable_node_ids,
79 RestoreObjectsCallback callback);
80
81 MetaGraphDef meta_graph_def_;
82 TrackableObjectGraph trackable_object_graph_;
83 std::unique_ptr<BundleReader> variable_reader_;
84 std::unique_ptr<GraphDebugInfo> debug_info_;
85};
86
87} // namespace tensorflow
88
89#endif // TENSORFLOW_CC_SAVED_MODEL_BUNDLE_V2_H_
90