1 | /* Copyright 2016 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | |
/// SavedModel loading functions and the SavedModelBundle / SavedModelBundleLite types.
17 | |
18 | #ifndef TENSORFLOW_CC_SAVED_MODEL_LOADER_H_ |
19 | #define TENSORFLOW_CC_SAVED_MODEL_LOADER_H_ |
20 | |
21 | #include <string> |
22 | #include <unordered_set> |
23 | |
24 | #include "tensorflow/core/lib/core/status.h" |
25 | #include "tensorflow/core/protobuf/graph_debug_info.pb.h" |
26 | #include "tensorflow/core/protobuf/meta_graph.pb.h" |
27 | #include "tensorflow/core/public/session.h" |
28 | |
29 | namespace tensorflow { |
30 | |
31 | /// Represents a SavedModel that is loaded from storage. |
32 | class SavedModelBundleInterface { |
33 | public: |
34 | virtual ~SavedModelBundleInterface(); |
35 | |
36 | /// Returns the TensorFlow Session that can be used to interact with the |
37 | /// SavedModel. |
38 | virtual Session* GetSession() const = 0; |
39 | |
40 | /// Returns a map from signature name to SignatureDef for all signatures in |
41 | /// in the SavedModel. |
42 | virtual const protobuf::Map<string, SignatureDef>& GetSignatures() const = 0; |
43 | }; |
44 | |
45 | /// SavedModel representation once the SavedModel is loaded from storage. |
46 | /// |
47 | /// NOTE: Prefer to use SavedModelBundleLite in new code, as it consumes less |
48 | /// RAM. |
49 | struct SavedModelBundle : public SavedModelBundleInterface { |
50 | /// A TensorFlow Session does not Close itself on destruction. To avoid |
51 | /// resource leaks, we explicitly call Close on Sessions that we create. |
52 | ~SavedModelBundle() override { |
53 | if (session) { |
54 | session->Close().IgnoreError(); |
55 | } |
56 | } |
57 | |
58 | SavedModelBundle() = default; |
59 | |
60 | Session* GetSession() const override { return session.get(); } |
61 | const protobuf::Map<string, SignatureDef>& GetSignatures() const override { |
62 | return meta_graph_def.signature_def(); |
63 | } |
64 | |
65 | std::unique_ptr<Session> session; |
66 | MetaGraphDef meta_graph_def; |
67 | std::unique_ptr<GraphDebugInfo> debug_info; |
68 | }; |
69 | |
70 | // A version of SavedModelBundle that avoids storing a potentially large |
71 | // MetaGraphDef. Prefer to use SavedModelBundleLite in new code. |
72 | class SavedModelBundleLite : public SavedModelBundleInterface { |
73 | public: |
74 | SavedModelBundleLite() = default; |
75 | SavedModelBundleLite(SavedModelBundleLite&& other) = default; |
76 | SavedModelBundleLite& operator=(SavedModelBundleLite&& other) = default; |
77 | |
78 | SavedModelBundleLite(std::unique_ptr<Session> session, |
79 | protobuf::Map<string, SignatureDef> signatures) |
80 | : session_(std::move(session)), signatures_(std::move(signatures)) {} |
81 | |
82 | /// A TensorFlow Session does not Close itself on destruction. To avoid |
83 | /// resource leaks, we explicitly call Close on Sessions that we create. |
84 | ~SavedModelBundleLite() override { |
85 | if (session_) { |
86 | session_->Close().IgnoreError(); |
87 | } |
88 | } |
89 | |
90 | Session* GetSession() const override { return session_.get(); } |
91 | const protobuf::Map<string, SignatureDef>& GetSignatures() const override { |
92 | return signatures_; |
93 | } |
94 | |
95 | private: |
96 | std::unique_ptr<Session> session_; |
97 | protobuf::Map<string, SignatureDef> signatures_; |
98 | }; |
99 | |
100 | // Restore variable and resources in the SavedModel export dir for the |
101 | // indicated metagraph. |
102 | // The recommended way to load a saved model is to call LoadSavedModel, |
103 | // which provides an already initialized Metagraph, Session, and DebugInfo. |
104 | Status RestoreSession(const RunOptions& run_options, |
105 | const MetaGraphDef& meta_graph, const string& export_dir, |
106 | std::unique_ptr<Session>* session); |
107 | |
108 | // Initialize a session which wraps this metagraph. |
109 | // The recommended way to load a saved model is to call LoadSavedModel, |
110 | // which provides an already initialized Metagraph, Session, and DebugInfo. |
111 | Status LoadMetagraphIntoSession(const SessionOptions& session_options, |
112 | const MetaGraphDef& meta_graph, |
113 | std::unique_ptr<Session>* session); |
114 | |
115 | /// Loads a SavedModel from the specified export directory. The MetaGraphDef |
116 | /// to be loaded is identified by the supplied tags, corresponding exactly to |
117 | /// the set of tags used at SavedModel build time. Stores a SavedModel bundle in |
118 | /// *bundle with a session and the requested MetaGraphDef, if found. |
119 | /// |
120 | /// NOTE: Prefer the overload that takes a SavedModelBundleLite* in new code. |
121 | Status LoadSavedModel(const SessionOptions& session_options, |
122 | const RunOptions& run_options, const string& export_dir, |
123 | const std::unordered_set<string>& tags, |
124 | SavedModelBundle* const bundle); |
125 | |
126 | /// Loads a SavedModel from the specified export directory. The MetaGraphDef |
127 | /// to be loaded is identified by the supplied tags, corresponding exactly to |
128 | /// the set of tags used at SavedModel build time. Stores a SavedModel bundle |
129 | /// in *bundle with a session created from the requested MetaGraphDef if found. |
130 | /// |
131 | /// This overload creates a SavedModelBundleLite, which consumes less RAM than |
132 | /// an equivalent SavedModelBundle. |
133 | Status LoadSavedModel(const SessionOptions& session_options, |
134 | const RunOptions& run_options, const string& export_dir, |
135 | const std::unordered_set<string>& tags, |
136 | SavedModelBundleLite* const bundle); |
137 | |
/// Reports whether `export_dir` could contain a SavedModel, without loading
/// any data. A `false` result means the directory definitely does not hold a
/// SavedModel. A `true` result only means it may hold one; it is no guarantee
/// that the SavedModel can actually be loaded.
bool MaybeSavedModelDirectory(const std::string& export_dir);
144 | |
145 | } // namespace tensorflow |
146 | |
147 | #endif // TENSORFLOW_CC_SAVED_MODEL_LOADER_H_ |
148 | |