1 | /* Copyright 2015 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | |
16 | #ifndef TENSORFLOW_CORE_PUBLIC_SESSION_H_ |
17 | #define TENSORFLOW_CORE_PUBLIC_SESSION_H_ |
18 | |
19 | #include <string> |
20 | #include <vector> |
21 | |
22 | #include "tensorflow/core/framework/device_attributes.pb.h" |
23 | #include "tensorflow/core/framework/graph.pb.h" |
24 | #include "tensorflow/core/framework/tensor.h" |
25 | #include "tensorflow/core/lib/core/errors.h" |
26 | #include "tensorflow/core/lib/core/status.h" |
27 | #include "tensorflow/core/platform/env.h" |
28 | #include "tensorflow/core/protobuf/config.pb.h" |
29 | #include "tensorflow/core/public/session_options.h" |
30 | |
// Forward declaration only: Session below takes this type exclusively by
// const reference, so the complete definition is not needed here.
namespace tsl {
namespace thread {
struct ThreadPoolOptions;
}  // namespace thread
}  // namespace tsl
36 | |
37 | namespace tensorflow { |
38 | |
39 | class DeviceMgr; |
40 | |
41 | namespace thread { |
42 | using tsl::thread::ThreadPoolOptions; |
43 | } |
44 | |
45 | /// \brief A Session instance lets a caller drive a TensorFlow graph |
46 | /// computation. |
47 | /// |
48 | /// When a Session is created with a given target, a new Session object |
49 | /// is bound to the universe of resources specified by that target. |
50 | /// Those resources are available to this session to perform |
51 | /// computation described in the GraphDef. After extending the session |
52 | /// with a graph, the caller uses the Run() API to perform the |
53 | /// computation and potentially fetch outputs as Tensors. |
54 | /// |
55 | /// Example: |
56 | /// |
57 | /// ```c++ |
58 | /// |
59 | /// tensorflow::GraphDef graph; |
60 | /// // ... Create or load graph into "graph". |
61 | /// |
62 | /// // This example uses the default options which connects |
63 | /// // to a local runtime. |
64 | /// tensorflow::SessionOptions options; |
65 | /// std::unique_ptr<tensorflow::Session> |
66 | /// session(tensorflow::NewSession(options)); |
67 | /// |
68 | /// // Create the session with this graph. |
69 | /// tensorflow::Status s = session->Create(graph); |
70 | /// if (!s.ok()) { ... } |
71 | /// |
72 | /// // Run the graph and fetch the first output of the "output" |
73 | /// // operation, and also run to but do not return anything |
74 | /// // for the "update_state" operation. |
75 | /// std::vector<tensorflow::Tensor> outputs; |
76 | /// s = session->Run({}, {"output:0"}, {"update_state"}, &outputs); |
77 | /// if (!s.ok()) { ... } |
78 | /// |
79 | /// // Map the output as a flattened float tensor, and do something |
80 | /// // with it. |
81 | /// auto output_tensor = outputs[0].flat<float>(); |
82 | /// if (output_tensor(0) > 0.5) { ... } |
83 | /// |
84 | /// // Close the session to release the resources associated with |
85 | /// // this session. |
86 | /// session->Close(); |
87 | /// |
88 | /// ``` |
89 | /// |
90 | /// A Session allows concurrent calls to Run(), though a Session must |
91 | /// be created / extended by a single thread. |
92 | /// |
93 | /// Only one thread must call Close(), and Close() must only be called |
94 | /// after all other calls to Run() have returned. |
95 | class Session { |
96 | public: |
97 | Session(); |
98 | virtual ~Session(); |
99 | |
100 | /// \brief Create the graph to be used for the session. |
101 | /// |
102 | /// Returns an error if this session has already been created with a |
103 | /// graph. To re-use the session with a different graph, the caller |
104 | /// must Close() the session first. |
105 | virtual Status Create(const GraphDef& graph) = 0; |
106 | #ifndef SWIG |
107 | virtual Status Create(GraphDef&& graph) { return Create(graph); } |
108 | #endif |
109 | |
110 | /// \brief Adds operations to the graph that is already registered with the |
111 | /// Session. |
112 | /// |
113 | /// The names of new operations in "graph" must not exist in the |
114 | /// graph that is already registered. |
115 | virtual Status Extend(const GraphDef& graph) = 0; |
116 | #ifndef SWIG |
117 | virtual Status Extend(GraphDef&& graph) { return Extend(graph); } |
118 | #endif |
119 | |
120 | /// \brief Runs the graph with the provided input tensors and fills |
121 | /// `outputs` for the endpoints specified in `output_tensor_names`. |
122 | /// Runs to but does not return Tensors for the nodes in |
123 | /// `target_tensor_names`. |
124 | /// |
125 | /// The order of tensors in `outputs` will match the order provided |
126 | /// by `output_tensor_names`. |
127 | /// |
128 | /// If `Run` returns `OK()`, then `outputs->size()` will be equal to |
129 | /// `output_tensor_names.size()`. If `Run` does not return `OK()`, the |
130 | /// state of `outputs` is undefined. |
131 | /// |
132 | /// REQUIRES: The name of each Tensor of the input or output must |
133 | /// match a "Tensor endpoint" in the `GraphDef` passed to `Create()`. |
134 | /// |
135 | /// REQUIRES: At least one of `output_tensor_names` and |
136 | /// `target_tensor_names` must be non-empty. |
137 | /// |
138 | /// REQUIRES: outputs is not nullptr if `output_tensor_names` is non-empty. |
139 | virtual Status Run(const std::vector<std::pair<std::string, Tensor> >& inputs, |
140 | const std::vector<std::string>& output_tensor_names, |
141 | const std::vector<std::string>& target_tensor_names, |
142 | std::vector<Tensor>* outputs) = 0; |
143 | |
144 | /// \brief Implementations which support `RunOptions`. |
145 | // |
146 | /// NOTE: This API is still experimental and may change. |
147 | virtual Status Create(const RunOptions& run_options, const GraphDef& graph) { |
148 | return errors::Unimplemented( |
149 | "Create(const RunOptions& run_options, const GraphDef& graph) is not " |
150 | "supported for this session." ); |
151 | } |
152 | virtual Status Extend(const RunOptions& run_options, const GraphDef& graph) { |
153 | return errors::Unimplemented( |
154 | "Extend(const RunOptions& run_options, const GraphDef& graph) is not " |
155 | "supported for this session." ); |
156 | } |
157 | #ifndef SWIG |
158 | virtual Status Create(const RunOptions& run_options, GraphDef&& graph) { |
159 | return Create(run_options, graph); |
160 | } |
161 | virtual Status Extend(const RunOptions& run_options, GraphDef&& graph) { |
162 | return Extend(run_options, graph); |
163 | } |
164 | #endif |
165 | virtual Status Close(const RunOptions& run_options) { |
166 | return errors::Unimplemented( |
167 | "Close(const RunOptions& run_options) is not supported for this " |
168 | "session." ); |
169 | } |
170 | |
171 | /// \brief Like `Run`, but allows users to pass in a `RunOptions` proto and |
172 | /// to retrieve non-Tensor metadata output via a `RunMetadata` proto for this |
173 | /// step. `run_metadata` may be nullptr, in which case any metadata output is |
174 | /// discarded. |
175 | /// NOTE: This API is still experimental and may change. |
176 | virtual Status Run(const RunOptions& run_options, |
177 | const std::vector<std::pair<std::string, Tensor> >& inputs, |
178 | const std::vector<std::string>& output_tensor_names, |
179 | const std::vector<std::string>& target_tensor_names, |
180 | std::vector<Tensor>* outputs, RunMetadata* run_metadata); |
181 | |
182 | /// \brief Like `Run` with `RunOptions` proto, but allows user to provide |
183 | /// custom threadpool implementation via ThreadPoolOptions. |
184 | /// NOTE: This API is still experimental and may change. |
185 | virtual Status Run(const RunOptions& run_options, |
186 | const std::vector<std::pair<std::string, Tensor> >& inputs, |
187 | const std::vector<std::string>& output_tensor_names, |
188 | const std::vector<std::string>& target_tensor_names, |
189 | std::vector<Tensor>* outputs, RunMetadata* run_metadata, |
190 | const thread::ThreadPoolOptions& threadpool_options) { |
191 | return errors::Unimplemented( |
192 | "Run with threadpool is not supported for this session." ); |
193 | } |
194 | |
195 | /// \brief Sets up a graph for partial execution. All future feeds and |
196 | /// fetches are specified by `input_names` and `output_names`. Returns |
197 | /// `handle` that can be used to perform a sequence of partial feeds and |
198 | /// fetches. |
199 | /// NOTE: This API is still experimental and may change. |
200 | virtual Status PRunSetup(const std::vector<std::string>& input_names, |
201 | const std::vector<std::string>& output_names, |
202 | const std::vector<std::string>& target_nodes, |
203 | std::string* handle); |
204 | |
205 | /// \brief Continues the pending execution specified by `handle` with the |
206 | /// provided input tensors and fills `outputs` for the endpoints specified |
207 | /// in `output_names`. |
208 | /// NOTE: This API is still experimental and may change. |
209 | virtual Status PRun( |
210 | const std::string& handle, |
211 | const std::vector<std::pair<std::string, Tensor> >& inputs, |
212 | const std::vector<std::string>& output_names, |
213 | std::vector<Tensor>* outputs); |
214 | |
215 | /// \brief List devices in the session. |
216 | /// |
217 | /// Retrieves the list of available devices within the session, and populates |
218 | /// *response. This API is optional. If it is unimplemented, Status will |
219 | /// return a corresponding error message, and *response will be unmodified. |
220 | virtual Status ListDevices(std::vector<DeviceAttributes>* response) = 0; |
221 | |
222 | /// \brief Closes this session. |
223 | /// |
224 | /// Closing a session releases the resources used by this session |
225 | /// on the TensorFlow runtime (specified during session creation by |
226 | /// the `SessionOptions::target` field). |
227 | virtual Status Close() = 0; |
228 | |
229 | // NOTE(ashankar): As of July 2017, this method was added to facilitate some |
230 | // experimentation. Reconsider/re-evaluate after September 2017. |
231 | // |
232 | // Sets `*output` to the `DeviceMgr` that owns accessible devices in the |
233 | // address-space of the caller. |
234 | virtual Status LocalDeviceManager(const DeviceMgr** output) { |
235 | return errors::Unimplemented( |
236 | "LocalDeviceManager is not supported for this session." ); |
237 | } |
238 | |
239 | /// \brief A handle to a subgraph, created with `Session::MakeCallable()`. |
240 | typedef int64_t CallableHandle; |
241 | |
242 | /// \brief Creates a `handle` for invoking the subgraph defined by |
243 | /// `callable_options`. |
244 | /// NOTE: This API is still experimental and may change. |
245 | virtual Status MakeCallable(const CallableOptions& callable_options, |
246 | CallableHandle* out_handle) { |
247 | return errors::Unimplemented( |
248 | "MakeCallable is not supported for this session." ); |
249 | } |
250 | |
251 | /// \brief Invokes the subgraph named by `handle` with the given options and |
252 | /// input tensors. |
253 | /// |
254 | /// The order of tensors in `feed_tensors` must and `fetch_tensors` will |
255 | /// match the order of names in `CallableOptions::feed()` and |
256 | /// `CallableOptions::fetch()` when this subgraph was created. |
257 | /// NOTE: This API is still experimental and may change. |
258 | virtual Status RunCallable(CallableHandle handle, |
259 | const std::vector<Tensor>& feed_tensors, |
260 | std::vector<Tensor>* fetch_tensors, |
261 | RunMetadata* run_metadata) { |
262 | return errors::Unimplemented( |
263 | "RunCallable is not supported for this session." ); |
264 | } |
265 | |
266 | /// \brief Invokes the subgraph named by `handle` with the given options and |
267 | /// input tensors. User can provide custom threadpool implementation via |
268 | /// threadpool_options. |
269 | /// |
270 | /// The order of tensors in `feed_tensors` must and `fetch_tensors` will |
271 | /// match the order of names in `CallableOptions::feed()` and |
272 | /// `CallableOptions::fetch()` when this subgraph was created. |
273 | /// NOTE: This API is still experimental and may change. |
274 | virtual Status RunCallable( |
275 | CallableHandle handle, const std::vector<Tensor>& feed_tensors, |
276 | std::vector<Tensor>* fetch_tensors, RunMetadata* run_metadata, |
277 | const thread::ThreadPoolOptions& threadpool_options) { |
278 | return errors::Unimplemented( |
279 | "RunCallable with threadpool is not supported for this session." ); |
280 | } |
281 | |
282 | /// \brief Releases resources associated with the given `handle` in this |
283 | /// session. |
284 | /// NOTE: This API is still experimental and may change. |
285 | virtual Status ReleaseCallable(CallableHandle handle) { |
286 | return errors::Unimplemented( |
287 | "ReleaseCallable is not supported for this session." ); |
288 | } |
289 | |
290 | /// \brief Release global graph-related state in this session. |
291 | /// |
292 | /// After calling `this->Finalize()`, calls to `this->Run()` with previously |
293 | /// unseen feeds and fetches, and calls to `this->MakeCallable()` will fail. |
294 | /// Using `MakeCallable()` and `RunCallable()` is recommended, because |
295 | /// explicit callable creation makes it clearer where the `Finalize()` call |
296 | /// should be placed. |
297 | /// |
298 | /// This API can be used in conjunction with a "warmup" phase to reduce the |
299 | /// memory consumed by the session: |
300 | /// |
301 | /// 1. Call `Session::Create()`. |
302 | /// 2. Call `Session::MakeCallable()` for all subgraphs that you will execute |
303 | /// in the session. |
304 | /// 3. Call `Session::Finalize()` to release global graph-related state. |
305 | /// 4. Call `Session::RunCallable()` with the handle(s) created in step 2. |
306 | /// |
307 | /// NOTE: This API is still experimental and may change. |
308 | virtual Status Finalize() { |
309 | return errors::Unimplemented("Finalize is not supported for this session." ); |
310 | } |
311 | }; |
312 | |
313 | /// \brief Create a new session with the given options. |
314 | /// |
315 | /// If session creation succeeds, the new `Session` will be stored in |
316 | /// `*out_session`, the caller will take ownership of the returned |
317 | /// `*out_session`, and this function will return `OK()`. Otherwise, this |
318 | /// function will return an error status and set *out_session to nullptr. |
319 | Status NewSession(const SessionOptions& options, Session** out_session); |
320 | |
321 | /// \brief Resets resource containers associated with a target. |
322 | /// |
323 | /// Reset() allows misbehaving or slow sessions to be aborted and closed, and |
324 | /// causes their resources eventually to be released. Reset() does not wait |
325 | /// for the computations in old sessions to cease; it merely starts the |
326 | /// process of tearing them down. However, if a new session is started after |
327 | /// a Reset(), the new session is isolated from changes that old sessions |
328 | /// (started prior to the Reset()) may continue to make to resources, provided |
329 | /// all those resources are in containers listed in "containers". |
330 | /// |
331 | /// Old sessions may continue to have side-effects on resources not in |
332 | /// containers listed in "containers", and thus may affect future |
333 | /// sessions' results in ways that are hard to predict. Thus, if well-defined |
334 | /// behavior is desired, it is recommended that all containers be listed in |
335 | /// "containers". |
336 | /// |
337 | /// `containers` is a vector of string representation of resource container |
338 | /// names. When a resource container is reset, the resources held by the |
339 | /// container will be released. In particular, all Variables in the container |
340 | /// will become undefined. If the "containers" vector is empty, the default |
341 | /// container is assumed. If the "containers" vector is non-empty, the |
342 | /// default container should be listed explicitly. |
343 | /// |
344 | /// If Reset succeeds, this function will return `OK()`. Otherwise, this |
345 | /// function will return an error status. |
346 | Status Reset(const SessionOptions& options, |
347 | const std::vector<std::string>& containers); |
348 | |
349 | /// \brief Create a new session with the given options. |
350 | /// |
351 | /// If a new `Session` object could not be created, this function will |
352 | /// return nullptr. |
353 | /// |
354 | /// *Strongly prefer* the version of NewSession that returns Status, |
355 | /// which contains more helpful error information. |
356 | Session* NewSession(const SessionOptions& options); |
357 | |
358 | } // end namespace tensorflow |
359 | |
360 | #endif // TENSORFLOW_CORE_PUBLIC_SESSION_H_ |
361 | |