1 | /* Copyright 2016 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | |
16 | #include "tensorflow/cc/client/client_session.h" |
17 | |
18 | #include <unordered_map> |
19 | #include <utility> |
20 | #include <vector> |
21 | |
22 | #include "tensorflow/core/platform/env.h" |
23 | #include "tensorflow/core/platform/mutex.h" |
24 | #include "tensorflow/core/protobuf/config.pb.h" |
25 | #include "tensorflow/core/public/session.h" |
26 | #include "tensorflow/core/public/session_options.h" |
27 | |
28 | namespace tensorflow { |
29 | |
30 | class ClientSession::Impl { |
31 | private: |
32 | friend class ClientSession; |
33 | |
34 | Impl(Session* session, std::shared_ptr<Graph> graph) |
35 | : session_(session), graph_(std::move(graph)) {} |
36 | |
37 | static SessionOptions MakeDefaultSessionOptions(const string& target); |
38 | Status MaybeExtendGraph() const; |
39 | |
40 | std::unique_ptr<Session> session_; |
41 | std::shared_ptr<Graph> graph_; |
42 | |
43 | mutable mutex mu_; |
44 | mutable int last_num_graph_nodes_ TF_GUARDED_BY(mu_) = 0; |
45 | }; |
46 | |
47 | ClientSession::ClientSession(const Scope& scope, const string& target) |
48 | : ClientSession(scope, Impl::MakeDefaultSessionOptions(target)) {} |
49 | |
50 | ClientSession::ClientSession(const Scope& scope) : ClientSession(scope, "" ) {} |
51 | |
52 | ClientSession::ClientSession(const Scope& scope, |
53 | const SessionOptions& session_options) { |
54 | Session* new_session; |
55 | Status status = NewSession(session_options, &new_session); |
56 | TF_CHECK_OK(status) << status; |
57 | impl_.reset(new Impl(new_session, scope.graph_as_shared_ptr())); |
58 | CHECK_NOTNULL(impl()->session_.get()); |
59 | } |
60 | |
61 | // Define destructor here so we can forward declare `Impl` in client_session.h. |
62 | // If we define a dtor in the header file or use the default dtor, |
63 | // unique_ptr<Impl> needs the complete type. |
64 | ClientSession::~ClientSession() {} |
65 | |
66 | SessionOptions ClientSession::Impl::MakeDefaultSessionOptions( |
67 | const string& target) { |
68 | SessionOptions options; |
69 | options.env = Env::Default(); |
70 | options.target = target; |
71 | return options; |
72 | } |
73 | |
74 | Status ClientSession::Run(const std::vector<Output>& fetch_outputs, |
75 | std::vector<Tensor>* outputs) const { |
76 | return Run(FeedType{}, fetch_outputs, {}, outputs); |
77 | } |
78 | |
79 | Status ClientSession::Run(const FeedType& inputs, |
80 | const std::vector<Output>& fetch_outputs, |
81 | std::vector<Tensor>* outputs) const { |
82 | return Run(inputs, fetch_outputs, {}, outputs); |
83 | } |
84 | |
85 | Status ClientSession::Run(const FeedType& inputs, |
86 | const std::vector<Output>& fetch_outputs, |
87 | const std::vector<Operation>& run_outputs, |
88 | std::vector<Tensor>* outputs) const { |
89 | return Run(RunOptions(), inputs, fetch_outputs, run_outputs, outputs, |
90 | nullptr); |
91 | } |
92 | |
93 | Status ClientSession::Impl::MaybeExtendGraph() const { |
94 | mutex_lock l(mu_); |
95 | int num_nodes = graph_->num_node_ids(); |
96 | if (num_nodes > last_num_graph_nodes_) { |
97 | GraphDef graph_def; |
98 | graph_->ToGraphDefSubRange(&graph_def, last_num_graph_nodes_); |
99 | last_num_graph_nodes_ = num_nodes; |
100 | return session_->Extend(graph_def); |
101 | } |
102 | return OkStatus(); |
103 | } |
104 | |
105 | Status ClientSession::Run(const RunOptions& run_options, const FeedType& inputs, |
106 | const std::vector<Output>& fetch_outputs, |
107 | const std::vector<Operation>& run_outputs, |
108 | std::vector<Tensor>* outputs, |
109 | RunMetadata* run_metadata) const { |
110 | std::vector<std::pair<string, Tensor>> feeds; |
111 | for (auto const& feed : inputs) { |
112 | TF_RETURN_IF_ERROR(feed.second.status); |
113 | feeds.emplace_back(feed.first.name(), feed.second.tensor); |
114 | } |
115 | std::vector<string> output_tensor_names; |
116 | output_tensor_names.reserve(fetch_outputs.size()); |
117 | for (auto const& output : fetch_outputs) { |
118 | output_tensor_names.push_back(output.name()); |
119 | } |
120 | std::vector<string> target_node_names; |
121 | target_node_names.reserve(run_outputs.size()); |
122 | for (auto const& output : run_outputs) { |
123 | target_node_names.push_back(output.node()->name()); |
124 | } |
125 | TF_RETURN_IF_ERROR(impl()->MaybeExtendGraph()); |
126 | return impl()->session_->Run(run_options, feeds, output_tensor_names, |
127 | target_node_names, outputs, run_metadata); |
128 | } |
129 | |
130 | Status ClientSession::Run( |
131 | const RunOptions& run_options, const FeedType& inputs, |
132 | const std::vector<Output>& fetch_outputs, |
133 | const std::vector<Operation>& run_outputs, std::vector<Tensor>* outputs, |
134 | RunMetadata* run_metadata, |
135 | const thread::ThreadPoolOptions& threadpool_options) const { |
136 | std::vector<std::pair<string, Tensor>> feeds; |
137 | for (auto const& feed : inputs) { |
138 | TF_RETURN_IF_ERROR(feed.second.status); |
139 | feeds.emplace_back(feed.first.name(), feed.second.tensor); |
140 | } |
141 | std::vector<string> output_tensor_names; |
142 | output_tensor_names.reserve(fetch_outputs.size()); |
143 | for (auto const& output : fetch_outputs) { |
144 | output_tensor_names.push_back(output.name()); |
145 | } |
146 | std::vector<string> target_node_names; |
147 | target_node_names.reserve(run_outputs.size()); |
148 | for (auto const& output : run_outputs) { |
149 | target_node_names.push_back(output.node()->name()); |
150 | } |
151 | TF_RETURN_IF_ERROR(impl()->MaybeExtendGraph()); |
152 | return impl()->session_->Run(run_options, feeds, output_tensor_names, |
153 | target_node_names, outputs, run_metadata, |
154 | threadpool_options); |
155 | } |
156 | |
157 | Status ClientSession::MakeCallable(const CallableOptions& callable_options, |
158 | CallableHandle* out_handle) { |
159 | TF_RETURN_IF_ERROR(impl()->MaybeExtendGraph()); |
160 | return impl()->session_->MakeCallable(callable_options, out_handle); |
161 | } |
162 | |
163 | Status ClientSession::RunCallable(CallableHandle handle, |
164 | const std::vector<Tensor>& feed_tensors, |
165 | std::vector<Tensor>* fetch_tensors, |
166 | RunMetadata* run_metadata) { |
167 | return impl()->session_->RunCallable(handle, feed_tensors, fetch_tensors, |
168 | run_metadata); |
169 | } |
170 | |
171 | Status ClientSession::RunCallable(CallableHandle handle, |
172 | const std::vector<Tensor>& feed_tensors, |
173 | std::vector<Tensor>* fetch_tensors, |
174 | RunMetadata* run_metadata, |
175 | const thread::ThreadPoolOptions& options) { |
176 | return impl()->session_->RunCallable(handle, feed_tensors, fetch_tensors, |
177 | run_metadata, options); |
178 | } |
179 | |
180 | Status ClientSession::ReleaseCallable(CallableHandle handle) { |
181 | return impl()->session_->ReleaseCallable(handle); |
182 | } |
183 | |
184 | } // end namespace tensorflow |
185 | |