1/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2
3Licensed under the Apache License, Version 2.0 (the "License");
4you may not use this file except in compliance with the License.
5You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9Unless required by applicable law or agreed to in writing, software
10distributed under the License is distributed on an "AS IS" BASIS,
11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12See the License for the specific language governing permissions and
13limitations under the License.
14==============================================================================*/
15
16#ifndef TENSORFLOW_CC_CLIENT_CLIENT_SESSION_H_
17#define TENSORFLOW_CC_CLIENT_CLIENT_SESSION_H_
18
19#include <memory>
20#include <string>
21#include <unordered_map>
22#include <vector>
23
24#include "tensorflow/cc/framework/ops.h"
25#include "tensorflow/cc/framework/scope.h"
26#include "tensorflow/core/public/session_options.h"
27
// Forward declaration only. This header refers to ThreadPoolOptions solely
// through `const&` parameters, so its complete definition is not required here.
namespace tsl::thread {
struct ThreadPoolOptions;
}  // namespace tsl::thread
33
34namespace tensorflow {
35
// Make tsl::thread::ThreadPoolOptions visible as
// tensorflow::thread::ThreadPoolOptions, so code inside namespace `tensorflow`
// can spell it `thread::ThreadPoolOptions`. The `Run`/`RunCallable`
// declarations in this header use that spelling.
namespace thread {
using tsl::thread::ThreadPoolOptions;
}  // namespace thread
39
40/// @addtogroup core
41/// @{
42
43/// A `ClientSession` object lets the caller drive the evaluation of the
44/// TensorFlow graph constructed with the C++ API.
45///
46/// Example:
47///
48/// Scope root = Scope::NewRootScope();
49/// auto a = Placeholder(root, DT_INT32);
50/// auto c = Add(root, a, {41});
51///
52/// ClientSession session(root);
53/// std::vector<Tensor> outputs;
54///
55/// Status s = session.Run({ {a, {1}} }, {c}, &outputs);
56/// if (!s.ok()) { ... }
57class ClientSession {
58 public:
59 /// A data type to represent feeds to a Run call.
60 ///
61 /// This is a map of `Output` objects returned by op-constructors to the value
62 /// to feed them with. See `Input::Initializer` for details on what can be
63 /// used as feed values.
64 typedef std::unordered_map<Output, Input::Initializer, OutputHash> FeedType;
65
66 /// Create a new session to evaluate the graph contained in `scope` by
67 /// connecting to the TensorFlow runtime specified by `target`.
68 ClientSession(const Scope& scope, const string& target);
69
70 /// Same as above, but use the empty string ("") as the target specification.
71 explicit ClientSession(const Scope& scope);
72
73 /// Create a new session, configuring it with `session_options`.
74 ClientSession(const Scope& scope, const SessionOptions& session_options);
75
76 ~ClientSession();
77
78 /// Evaluate the tensors in `fetch_outputs`. The values are returned as
79 /// `Tensor` objects in `outputs`. The number and order of `outputs` will
80 /// match `fetch_outputs`.
81 Status Run(const std::vector<Output>& fetch_outputs,
82 std::vector<Tensor>* outputs) const;
83
84 /// Same as above, but use the mapping in `inputs` as feeds.
85 Status Run(const FeedType& inputs, const std::vector<Output>& fetch_outputs,
86 std::vector<Tensor>* outputs) const;
87
88 /// Same as above. Additionally runs the operations ins `run_outputs`.
89 Status Run(const FeedType& inputs, const std::vector<Output>& fetch_outputs,
90 const std::vector<Operation>& run_outputs,
91 std::vector<Tensor>* outputs) const;
92
93 /// Use `run_options` to turn on performance profiling. `run_metadata`, if not
94 /// null, is filled in with the profiling results.
95 Status Run(const RunOptions& run_options, const FeedType& inputs,
96 const std::vector<Output>& fetch_outputs,
97 const std::vector<Operation>& run_outputs,
98 std::vector<Tensor>* outputs, RunMetadata* run_metadata) const;
99
100 /// Same as above. Additionally allows user to provide custom threadpool
101 /// implementation via ThreadPoolOptions.
102 Status Run(const RunOptions& run_options, const FeedType& inputs,
103 const std::vector<Output>& fetch_outputs,
104 const std::vector<Operation>& run_outputs,
105 std::vector<Tensor>* outputs, RunMetadata* run_metadata,
106 const thread::ThreadPoolOptions& threadpool_options) const;
107
108 /// \brief A handle to a subgraph, created with
109 /// `ClientSession::MakeCallable()`.
110 typedef int64_t CallableHandle;
111
112 /// \brief Creates a `handle` for invoking the subgraph defined by
113 /// `callable_options`.
114 /// NOTE: This API is still experimental and may change.
115 Status MakeCallable(const CallableOptions& callable_options,
116 CallableHandle* out_handle);
117
118 /// \brief Invokes the subgraph named by `handle` with the given options and
119 /// input tensors.
120 ///
121 /// The order of tensors in `feed_tensors` must match the order of names in
122 /// `CallableOptions::feed()` and the order of tensors in `fetch_tensors` will
123 /// match the order of names in `CallableOptions::fetch()` when this subgraph
124 /// was created.
125 /// NOTE: This API is still experimental and may change.
126 Status RunCallable(CallableHandle handle,
127 const std::vector<Tensor>& feed_tensors,
128 std::vector<Tensor>* fetch_tensors,
129 RunMetadata* run_metadata);
130
131 /// \brief Invokes the subgraph named by `handle` with the given options and
132 /// input tensors.
133 ///
134 /// The order of tensors in `feed_tensors` must match the order of names in
135 /// `CallableOptions::feed()` and the order of tensors in `fetch_tensors` will
136 /// match the order of names in `CallableOptions::fetch()` when this subgraph
137 /// was created.
138 /// NOTE: This API is still experimental and may change.
139 Status RunCallable(CallableHandle handle,
140 const std::vector<Tensor>& feed_tensors,
141 std::vector<Tensor>* fetch_tensors,
142 RunMetadata* run_metadata,
143 const thread::ThreadPoolOptions& options);
144
145 /// \brief Releases resources associated with the given `handle` in this
146 /// session.
147 /// NOTE: This API is still experimental and may change.
148 Status ReleaseCallable(CallableHandle handle);
149
150 private:
151 class Impl;
152 std::unique_ptr<Impl> impl_;
153 Impl* impl() { return impl_.get(); }
154 const Impl* impl() const { return impl_.get(); }
155};
156
157/// @}
158
159} // end namespace tensorflow
160
161#endif // TENSORFLOW_CC_CLIENT_CLIENT_SESSION_H_
162