1 | /* Copyright 2016 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | |
16 | #ifndef TENSORFLOW_CC_CLIENT_CLIENT_SESSION_H_ |
17 | #define TENSORFLOW_CC_CLIENT_CLIENT_SESSION_H_ |
18 | |
19 | #include <memory> |
20 | #include <string> |
21 | #include <unordered_map> |
22 | #include <vector> |
23 | |
24 | #include "tensorflow/cc/framework/ops.h" |
25 | #include "tensorflow/cc/framework/scope.h" |
26 | #include "tensorflow/core/public/session_options.h" |
27 | |
// Forward declaration only: this header refers to ThreadPoolOptions purely by
// const reference, so the full definition is not needed here.
namespace tsl {
namespace thread {

struct ThreadPoolOptions;

}  // namespace thread
}  // namespace tsl
33 | |
34 | namespace tensorflow { |
35 | |
36 | namespace thread { |
37 | using tsl::thread::ThreadPoolOptions; |
38 | } |
39 | |
40 | /// @addtogroup core |
41 | /// @{ |
42 | |
43 | /// A `ClientSession` object lets the caller drive the evaluation of the |
44 | /// TensorFlow graph constructed with the C++ API. |
45 | /// |
46 | /// Example: |
47 | /// |
48 | /// Scope root = Scope::NewRootScope(); |
49 | /// auto a = Placeholder(root, DT_INT32); |
50 | /// auto c = Add(root, a, {41}); |
51 | /// |
52 | /// ClientSession session(root); |
53 | /// std::vector<Tensor> outputs; |
54 | /// |
55 | /// Status s = session.Run({ {a, {1}} }, {c}, &outputs); |
56 | /// if (!s.ok()) { ... } |
57 | class ClientSession { |
58 | public: |
59 | /// A data type to represent feeds to a Run call. |
60 | /// |
61 | /// This is a map of `Output` objects returned by op-constructors to the value |
62 | /// to feed them with. See `Input::Initializer` for details on what can be |
63 | /// used as feed values. |
64 | typedef std::unordered_map<Output, Input::Initializer, OutputHash> FeedType; |
65 | |
66 | /// Create a new session to evaluate the graph contained in `scope` by |
67 | /// connecting to the TensorFlow runtime specified by `target`. |
68 | ClientSession(const Scope& scope, const string& target); |
69 | |
70 | /// Same as above, but use the empty string ("") as the target specification. |
71 | explicit ClientSession(const Scope& scope); |
72 | |
73 | /// Create a new session, configuring it with `session_options`. |
74 | ClientSession(const Scope& scope, const SessionOptions& session_options); |
75 | |
76 | ~ClientSession(); |
77 | |
78 | /// Evaluate the tensors in `fetch_outputs`. The values are returned as |
79 | /// `Tensor` objects in `outputs`. The number and order of `outputs` will |
80 | /// match `fetch_outputs`. |
81 | Status Run(const std::vector<Output>& fetch_outputs, |
82 | std::vector<Tensor>* outputs) const; |
83 | |
84 | /// Same as above, but use the mapping in `inputs` as feeds. |
85 | Status Run(const FeedType& inputs, const std::vector<Output>& fetch_outputs, |
86 | std::vector<Tensor>* outputs) const; |
87 | |
88 | /// Same as above. Additionally runs the operations ins `run_outputs`. |
89 | Status Run(const FeedType& inputs, const std::vector<Output>& fetch_outputs, |
90 | const std::vector<Operation>& run_outputs, |
91 | std::vector<Tensor>* outputs) const; |
92 | |
93 | /// Use `run_options` to turn on performance profiling. `run_metadata`, if not |
94 | /// null, is filled in with the profiling results. |
95 | Status Run(const RunOptions& run_options, const FeedType& inputs, |
96 | const std::vector<Output>& fetch_outputs, |
97 | const std::vector<Operation>& run_outputs, |
98 | std::vector<Tensor>* outputs, RunMetadata* run_metadata) const; |
99 | |
100 | /// Same as above. Additionally allows user to provide custom threadpool |
101 | /// implementation via ThreadPoolOptions. |
102 | Status Run(const RunOptions& run_options, const FeedType& inputs, |
103 | const std::vector<Output>& fetch_outputs, |
104 | const std::vector<Operation>& run_outputs, |
105 | std::vector<Tensor>* outputs, RunMetadata* run_metadata, |
106 | const thread::ThreadPoolOptions& threadpool_options) const; |
107 | |
108 | /// \brief A handle to a subgraph, created with |
109 | /// `ClientSession::MakeCallable()`. |
110 | typedef int64_t CallableHandle; |
111 | |
112 | /// \brief Creates a `handle` for invoking the subgraph defined by |
113 | /// `callable_options`. |
114 | /// NOTE: This API is still experimental and may change. |
115 | Status MakeCallable(const CallableOptions& callable_options, |
116 | CallableHandle* out_handle); |
117 | |
118 | /// \brief Invokes the subgraph named by `handle` with the given options and |
119 | /// input tensors. |
120 | /// |
121 | /// The order of tensors in `feed_tensors` must match the order of names in |
122 | /// `CallableOptions::feed()` and the order of tensors in `fetch_tensors` will |
123 | /// match the order of names in `CallableOptions::fetch()` when this subgraph |
124 | /// was created. |
125 | /// NOTE: This API is still experimental and may change. |
126 | Status RunCallable(CallableHandle handle, |
127 | const std::vector<Tensor>& feed_tensors, |
128 | std::vector<Tensor>* fetch_tensors, |
129 | RunMetadata* run_metadata); |
130 | |
131 | /// \brief Invokes the subgraph named by `handle` with the given options and |
132 | /// input tensors. |
133 | /// |
134 | /// The order of tensors in `feed_tensors` must match the order of names in |
135 | /// `CallableOptions::feed()` and the order of tensors in `fetch_tensors` will |
136 | /// match the order of names in `CallableOptions::fetch()` when this subgraph |
137 | /// was created. |
138 | /// NOTE: This API is still experimental and may change. |
139 | Status RunCallable(CallableHandle handle, |
140 | const std::vector<Tensor>& feed_tensors, |
141 | std::vector<Tensor>* fetch_tensors, |
142 | RunMetadata* run_metadata, |
143 | const thread::ThreadPoolOptions& options); |
144 | |
145 | /// \brief Releases resources associated with the given `handle` in this |
146 | /// session. |
147 | /// NOTE: This API is still experimental and may change. |
148 | Status ReleaseCallable(CallableHandle handle); |
149 | |
150 | private: |
151 | class Impl; |
152 | std::unique_ptr<Impl> impl_; |
153 | Impl* impl() { return impl_.get(); } |
154 | const Impl* impl() const { return impl_.get(); } |
155 | }; |
156 | |
157 | /// @} |
158 | |
159 | } // end namespace tensorflow |
160 | |
161 | #endif // TENSORFLOW_CC_CLIENT_CLIENT_SESSION_H_ |
162 | |