1 | /* Copyright 2016 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | |
// Helpers for loading the persistent representation of a SavedModelV2.
// Note: this header is depended on by code that does not use the full
// runtime, so its dependencies should be kept to a minimum.
19 | |
20 | #ifndef TENSORFLOW_CC_SAVED_MODEL_BUNDLE_V2_H_ |
21 | #define TENSORFLOW_CC_SAVED_MODEL_BUNDLE_V2_H_ |
22 | |
23 | #include <functional> |
24 | #include <memory> |
25 | #include <string> |
26 | |
27 | #include "absl/container/flat_hash_set.h" |
28 | #include "tensorflow/core/lib/core/status.h" |
29 | #include "tensorflow/core/protobuf/graph_debug_info.pb.h" |
30 | #include "tensorflow/core/protobuf/meta_graph.pb.h" |
31 | #include "tensorflow/core/protobuf/saved_object_graph.pb.h" |
32 | #include "tensorflow/core/protobuf/trackable_object_graph.pb.h" |
33 | #include "tensorflow/core/util/tensor_bundle/tensor_bundle.h" |
34 | |
35 | namespace tensorflow { |
36 | |
37 | /// Represents a version 2 SavedModel that is loaded from storage (but not yet |
38 | /// loaded into an executable in-memory representation). |
39 | class SavedModelV2Bundle { |
40 | public: |
41 | using RestoreObjectsCallback = |
42 | std::function<Status(int, const TrackableObjectGraph::TrackableObject&)>; |
43 | |
44 | /// Loads persistent representations for a SavedModelV2 from the specified |
45 | /// export directory. |
46 | static Status Load(const std::string& export_dir, SavedModelV2Bundle* bundle); |
47 | |
48 | /// MetaGraphDef from the loaded SavedModel. |
49 | MetaGraphDef& meta_graph_def() { return meta_graph_def_; } |
50 | |
51 | /// SavedObjectGraph from the MetaGraphDef. |
52 | const SavedObjectGraph& saved_object_graph() { |
53 | return meta_graph_def().object_graph_def(); |
54 | } |
55 | |
56 | /// TrackableObjectGraph loaded from the variable_reader() checkpoint. |
57 | TrackableObjectGraph& trackable_object_graph() { |
58 | return trackable_object_graph_; |
59 | } |
60 | |
61 | /// BundleReader for accessing the variables bundle. |
62 | BundleReader* variable_reader() { return variable_reader_.get(); } |
63 | |
64 | /// The GraphDebugInfo (or nullptr if none). |
65 | GraphDebugInfo* debug_info() { return debug_info_.get(); } |
66 | |
67 | /// Restores objects, invoking the callback with the node id in the |
68 | /// saved_object_graph() and the corresponding TrackableObject from the |
69 | /// trackable_object_graph(). The callback may use the variable_reader() but |
70 | /// must not modify the underlying saved_object_graph(). |
71 | Status VisitObjectsToRestore(RestoreObjectsCallback callback); |
72 | |
73 | private: |
74 | Status RecurseObjectsToRestore( |
75 | const SavedObject* saved_object, int saved_object_node_id, |
76 | const TrackableObjectGraph::TrackableObject* trackable_object, |
77 | std::string object_name, |
78 | absl::flat_hash_set<int>* seen_trackable_node_ids, |
79 | RestoreObjectsCallback callback); |
80 | |
81 | MetaGraphDef meta_graph_def_; |
82 | TrackableObjectGraph trackable_object_graph_; |
83 | std::unique_ptr<BundleReader> variable_reader_; |
84 | std::unique_ptr<GraphDebugInfo> debug_info_; |
85 | }; |
86 | |
87 | } // namespace tensorflow |
88 | |
89 | #endif // TENSORFLOW_CC_SAVED_MODEL_BUNDLE_V2_H_ |
90 | |