1/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2
3Licensed under the Apache License, Version 2.0 (the "License");
4you may not use this file except in compliance with the License.
5You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9Unless required by applicable law or agreed to in writing, software
10distributed under the License is distributed on an "AS IS" BASIS,
11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12See the License for the specific language governing permissions and
13limitations under the License.
14==============================================================================*/
15
16#include "tensorflow/cc/client/client_session.h"
17
18#include <unordered_map>
19#include <utility>
20#include <vector>
21
22#include "tensorflow/core/platform/env.h"
23#include "tensorflow/core/platform/mutex.h"
24#include "tensorflow/core/protobuf/config.pb.h"
25#include "tensorflow/core/public/session.h"
26#include "tensorflow/core/public/session_options.h"
27
28namespace tensorflow {
29
// Private implementation object behind ClientSession. It holds the Session
// and Graph pointers and the bookkeeping used to push newly created graph
// nodes into the session.
class ClientSession::Impl {
 private:
  friend class ClientSession;

  // Takes ownership of `session` (stored in a unique_ptr) and shares ownership
  // of `graph`.
  Impl(Session* session, std::shared_ptr<Graph> graph)
      : session_(session), graph_(std::move(graph)) {}

  // Returns SessionOptions using Env::Default() and the given `target`.
  static SessionOptions MakeDefaultSessionOptions(const string& target);
  // Passes graph nodes that session_ has not yet been given to
  // session_->Extend(). Declared const so the const Run() overloads can call
  // it; the state it updates is therefore `mutable`.
  Status MaybeExtendGraph() const;

  // Declared before graph_, so it is constructed first and destroyed last.
  std::unique_ptr<Session> session_;
  std::shared_ptr<Graph> graph_;

  // Serializes MaybeExtendGraph() calls and guards last_num_graph_nodes_.
  mutable mutex mu_;
  // Node-id watermark: nodes with an id below this value have already been
  // handed to session_->Extend().
  mutable int last_num_graph_nodes_ TF_GUARDED_BY(mu_) = 0;
};
46
47ClientSession::ClientSession(const Scope& scope, const string& target)
48 : ClientSession(scope, Impl::MakeDefaultSessionOptions(target)) {}
49
50ClientSession::ClientSession(const Scope& scope) : ClientSession(scope, "") {}
51
52ClientSession::ClientSession(const Scope& scope,
53 const SessionOptions& session_options) {
54 Session* new_session;
55 Status status = NewSession(session_options, &new_session);
56 TF_CHECK_OK(status) << status;
57 impl_.reset(new Impl(new_session, scope.graph_as_shared_ptr()));
58 CHECK_NOTNULL(impl()->session_.get());
59}
60
61// Define destructor here so we can forward declare `Impl` in client_session.h.
62// If we define a dtor in the header file or use the default dtor,
63// unique_ptr<Impl> needs the complete type.
64ClientSession::~ClientSession() {}
65
66SessionOptions ClientSession::Impl::MakeDefaultSessionOptions(
67 const string& target) {
68 SessionOptions options;
69 options.env = Env::Default();
70 options.target = target;
71 return options;
72}
73
74Status ClientSession::Run(const std::vector<Output>& fetch_outputs,
75 std::vector<Tensor>* outputs) const {
76 return Run(FeedType{}, fetch_outputs, {}, outputs);
77}
78
79Status ClientSession::Run(const FeedType& inputs,
80 const std::vector<Output>& fetch_outputs,
81 std::vector<Tensor>* outputs) const {
82 return Run(inputs, fetch_outputs, {}, outputs);
83}
84
85Status ClientSession::Run(const FeedType& inputs,
86 const std::vector<Output>& fetch_outputs,
87 const std::vector<Operation>& run_outputs,
88 std::vector<Tensor>* outputs) const {
89 return Run(RunOptions(), inputs, fetch_outputs, run_outputs, outputs,
90 nullptr);
91}
92
93Status ClientSession::Impl::MaybeExtendGraph() const {
94 mutex_lock l(mu_);
95 int num_nodes = graph_->num_node_ids();
96 if (num_nodes > last_num_graph_nodes_) {
97 GraphDef graph_def;
98 graph_->ToGraphDefSubRange(&graph_def, last_num_graph_nodes_);
99 last_num_graph_nodes_ = num_nodes;
100 return session_->Extend(graph_def);
101 }
102 return OkStatus();
103}
104
105Status ClientSession::Run(const RunOptions& run_options, const FeedType& inputs,
106 const std::vector<Output>& fetch_outputs,
107 const std::vector<Operation>& run_outputs,
108 std::vector<Tensor>* outputs,
109 RunMetadata* run_metadata) const {
110 std::vector<std::pair<string, Tensor>> feeds;
111 for (auto const& feed : inputs) {
112 TF_RETURN_IF_ERROR(feed.second.status);
113 feeds.emplace_back(feed.first.name(), feed.second.tensor);
114 }
115 std::vector<string> output_tensor_names;
116 output_tensor_names.reserve(fetch_outputs.size());
117 for (auto const& output : fetch_outputs) {
118 output_tensor_names.push_back(output.name());
119 }
120 std::vector<string> target_node_names;
121 target_node_names.reserve(run_outputs.size());
122 for (auto const& output : run_outputs) {
123 target_node_names.push_back(output.node()->name());
124 }
125 TF_RETURN_IF_ERROR(impl()->MaybeExtendGraph());
126 return impl()->session_->Run(run_options, feeds, output_tensor_names,
127 target_node_names, outputs, run_metadata);
128}
129
130Status ClientSession::Run(
131 const RunOptions& run_options, const FeedType& inputs,
132 const std::vector<Output>& fetch_outputs,
133 const std::vector<Operation>& run_outputs, std::vector<Tensor>* outputs,
134 RunMetadata* run_metadata,
135 const thread::ThreadPoolOptions& threadpool_options) const {
136 std::vector<std::pair<string, Tensor>> feeds;
137 for (auto const& feed : inputs) {
138 TF_RETURN_IF_ERROR(feed.second.status);
139 feeds.emplace_back(feed.first.name(), feed.second.tensor);
140 }
141 std::vector<string> output_tensor_names;
142 output_tensor_names.reserve(fetch_outputs.size());
143 for (auto const& output : fetch_outputs) {
144 output_tensor_names.push_back(output.name());
145 }
146 std::vector<string> target_node_names;
147 target_node_names.reserve(run_outputs.size());
148 for (auto const& output : run_outputs) {
149 target_node_names.push_back(output.node()->name());
150 }
151 TF_RETURN_IF_ERROR(impl()->MaybeExtendGraph());
152 return impl()->session_->Run(run_options, feeds, output_tensor_names,
153 target_node_names, outputs, run_metadata,
154 threadpool_options);
155}
156
157Status ClientSession::MakeCallable(const CallableOptions& callable_options,
158 CallableHandle* out_handle) {
159 TF_RETURN_IF_ERROR(impl()->MaybeExtendGraph());
160 return impl()->session_->MakeCallable(callable_options, out_handle);
161}
162
163Status ClientSession::RunCallable(CallableHandle handle,
164 const std::vector<Tensor>& feed_tensors,
165 std::vector<Tensor>* fetch_tensors,
166 RunMetadata* run_metadata) {
167 return impl()->session_->RunCallable(handle, feed_tensors, fetch_tensors,
168 run_metadata);
169}
170
171Status ClientSession::RunCallable(CallableHandle handle,
172 const std::vector<Tensor>& feed_tensors,
173 std::vector<Tensor>* fetch_tensors,
174 RunMetadata* run_metadata,
175 const thread::ThreadPoolOptions& options) {
176 return impl()->session_->RunCallable(handle, feed_tensors, fetch_tensors,
177 run_metadata, options);
178}
179
180Status ClientSession::ReleaseCallable(CallableHandle handle) {
181 return impl()->session_->ReleaseCallable(handle);
182}
183
184} // end namespace tensorflow
185