1/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2
3Licensed under the Apache License, Version 2.0 (the "License");
4you may not use this file except in compliance with the License.
5You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9Unless required by applicable law or agreed to in writing, software
10distributed under the License is distributed on an "AS IS" BASIS,
11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12See the License for the specific language governing permissions and
13limitations under the License.
14==============================================================================*/
15
16#ifndef TENSORFLOW_CORE_COMMON_RUNTIME_DIRECT_SESSION_H_
17#define TENSORFLOW_CORE_COMMON_RUNTIME_DIRECT_SESSION_H_
18
19#include <atomic>
20#include <memory>
21#include <string>
22#include <unordered_map>
23#include <unordered_set>
24#include <vector>
25
26#include "tensorflow/core/common_runtime/costmodel_manager.h"
27#include "tensorflow/core/common_runtime/debugger_state_interface.h"
28#include "tensorflow/core/common_runtime/device_mgr.h"
29#include "tensorflow/core/common_runtime/device_set.h"
30#include "tensorflow/core/common_runtime/executor.h"
31#include "tensorflow/core/common_runtime/graph_execution_state.h"
32#include "tensorflow/core/common_runtime/process_function_library_runtime.h"
33#include "tensorflow/core/common_runtime/rendezvous_mgr.h"
34#include "tensorflow/core/common_runtime/session_factory.h"
35#include "tensorflow/core/framework/cancellation.h"
36#include "tensorflow/core/framework/collective.h"
37#include "tensorflow/core/framework/graph.pb.h"
38#include "tensorflow/core/framework/session_state.h"
39#include "tensorflow/core/framework/tensor.h"
40#include "tensorflow/core/lib/core/errors.h"
41#include "tensorflow/core/lib/core/status.h"
42#include "tensorflow/core/platform/macros.h"
43#include "tensorflow/core/platform/mutex.h"
44#include "tensorflow/core/platform/thread_annotations.h"
45#include "tensorflow/core/platform/types.h"
46#include "tensorflow/core/public/session.h"
47
48namespace tensorflow {
49
// Forward declarations. Within this header, Device and DirectSessionFactory are
// only used through pointers and DebugGateway only in a friend declaration;
// CostModel is not otherwise named here.
50class CostModel;
51class DebugGateway;
52class Device;
53class DirectSessionFactory;
54
// DirectSession is a Session implementation: it derives from Session and
// overrides its graph-creation, run, partial-run and callable entry points.
// The state declared below includes the owned DeviceMgr, a cache of executors
// keyed by a signature string, the registered callables and the live partial
// runs.
//
// Locking, as expressed by the thread-safety annotations in this class:
//   graph_state_lock_           guards graph_created_, finalized_,
//                               stateful_placements_ and execution_state_.
//   executor_lock_              guards functions_, executors_ and
//                               partial_runs_.
//   callables_lock_             guards next_callable_handle_ and callables_.
//   closed_lock_                guards closed_.
//   collective_graph_key_lock_  guards collective_graph_key_.
//
// NOTE(review): non-static data members are destroyed in reverse declaration
// order, so reordering the data members of this class changes teardown order.
55class DirectSession : public Session {
56 public:
 // Callback type that receives a Session*. It is not referenced anywhere else
 // in this header.
 // NOTE(review): std::function is used here, but <functional> is not among the
 // headers this file includes directly -- presumably it arrives transitively;
 // confirm.
57 typedef std::function<void(Session*)> CloseCallback;
58
59 // Takes ownership of 'device_mgr'.
60 // 'factory' is used to unregister the DirectSession with 'factory' when it's
61 // closed. This ensures that Reset requests from the 'factory' don't get sent
62 // to sessions that are already closed.
 // 'factory' itself is not owned (see factory_ below).
63 DirectSession(const SessionOptions& options, const DeviceMgr* device_mgr,
64 DirectSessionFactory* factory);
65 ~DirectSession() override;
66
 // A list of (name, tensor) pairs; used for the 'inputs' of Run() and PRun().
67 typedef std::vector<std::pair<string, Tensor>> NamedTensorList;
 // Maps a node name, held as a StringPiece, to its Node.
 // NOTE(review): the keys are StringPiece values; presumably the underlying
 // string storage must outlive the map -- confirm.
68 typedef std::unordered_map<StringPiece, Node*, StringPieceHasher> NameNodeMap;
69
 // Session interface overrides. Create and Extend each accept the GraphDef
 // either by const reference or by rvalue reference.
70 ::tensorflow::Status Create(const GraphDef& graph) override;
71 ::tensorflow::Status Create(GraphDef&& graph) override;
72 ::tensorflow::Status Extend(const GraphDef& graph) override;
73 ::tensorflow::Status Extend(GraphDef&& graph) override;
74 ::tensorflow::Status Run(const NamedTensorList& inputs,
75 const std::vector<string>& output_names,
76 const std::vector<string>& target_nodes,
77 std::vector<Tensor>* outputs) override;
78
79 // NOTE: Experimental and subject to change.
80 ::tensorflow::Status Run(const ::tensorflow::RunOptions& run_options,
81 const NamedTensorList& inputs,
82 const std::vector<string>& output_names,
83 const std::vector<string>& target_nodes,
84 std::vector<Tensor>* outputs,
85 RunMetadata* run_metadata) override;
86
87 // NOTE: Experimental and subject to change.
 // Same as the overload above, with an additional 'threadpool_options'
 // argument.
88 ::tensorflow::Status Run(
89 const ::tensorflow::RunOptions& run_options,
90 const NamedTensorList& inputs, const std::vector<string>& output_names,
91 const std::vector<string>& target_nodes, std::vector<Tensor>* outputs,
92 RunMetadata* run_metadata,
93 const thread::ThreadPoolOptions& threadpool_options) override;
94
95 // NOTE: PRunSetup and PRun are added to support partial execution. This
96 // feature is experimental and subject to change.
97 ::tensorflow::Status PRunSetup(const std::vector<string>& input_names,
98 const std::vector<string>& output_names,
99 const std::vector<string>& target_nodes,
100 string* handle) override;
101 ::tensorflow::Status PRun(const string& handle, const NamedTensorList& inputs,
102 const std::vector<string>& output_names,
103 std::vector<Tensor>* outputs) override;
104
105 // Reset clears 'containers' from the device_mgr of the DirectSession.
106 // If 'containers' is empty, then Reset clears the default container.
 // Unlike the neighbouring methods, this declaration is not marked 'override'.
107 ::tensorflow::Status Reset(const std::vector<string>& containers);
108
109 ::tensorflow::Status ListDevices(
110 std::vector<DeviceAttributes>* response) override;
111 ::tensorflow::Status Close() override;
 // Stores the raw pointer held by device_mgr_ into '*output' and returns OK.
 // device_mgr_ remains the owning unique_ptr; 'output' is not null-checked.
112 ::tensorflow::Status LocalDeviceManager(const DeviceMgr** output) override {
113 *output = device_mgr_.get();
114 return OkStatus();
115 }
116
 // Forwards to cost_model_manager_.ExportCostModels(cost_models).
117 void ExportCostModels(CostModelManager::CostModelMap* cost_models) {
118 cost_model_manager_.ExportCostModels(cost_models);
119 }
120
 // Callable API overrides. MakeCallable reports a handle through
 // 'out_handle'; RunCallable and ReleaseCallable take such a handle.
121 ::tensorflow::Status MakeCallable(const CallableOptions& callable_options,
122 CallableHandle* out_handle) override;
123
124 ::tensorflow::Status RunCallable(CallableHandle handle,
125 const std::vector<Tensor>& feed_tensors,
126 std::vector<Tensor>* fetch_tensors,
127 RunMetadata* run_metadata) override;
128
129 ::tensorflow::Status RunCallable(
130 CallableHandle handle, const std::vector<Tensor>& feed_tensors,
131 std::vector<Tensor>* fetch_tensors, RunMetadata* run_metadata,
132 const thread::ThreadPoolOptions& threadpool_options) override;
133
134 ::tensorflow::Status ReleaseCallable(CallableHandle handle) override;
135
136 ::tensorflow::Status Finalize() override;
137
 // Returns the SessionOptions stored in options_.
138 const SessionOptions& options() const { return options_; }
139
140 private:
141 // For access to collective_graph_key_.
142 friend class DirectSessionCollectiveTest;
143
144 // We create one executor and its dependent library runtime for
145 // every partition.
 // 'graph' and 'executor' are owned by this struct; 'device' and 'flib' are
 // not.
146 struct PerPartitionExecutorsAndLib {
147 std::unique_ptr<Graph> graph = nullptr;
148 Device* device = nullptr; // not owned.
149 FunctionLibraryRuntime* flib = nullptr; // not owned.
150 std::unique_ptr<Executor> executor;
151 };
152
153 // An ExecutorsAndKeys is created for a given set of feeds/fetches.
154 // 'step_count' is the number of times this graph is executed.
155 // 'graph' is the entire graph being executed. 'name_to_node'
156 // maps node name to node. We keep 'graph' and 'name_to_node' only in
157 // the case of partial runs. Each item in 'items' is the executor for
158 // a partition of the graph bundled with its dependent library runtime.
159 // 'input_name_to_rendezvous_key' holds the rendezvous keys for the feeds and
160 // 'output_name_to_rendezvous_key' holds the rendezvous keys for the fetches.
161 struct ExecutorsAndKeys {
162 ExecutorsAndKeys() : step_count(0) {}
163
 // Atomic counter, initialized to 0 by the constructor.
164 std::atomic_int_fast64_t step_count;
165 std::unique_ptr<Graph> graph;
166 NameNodeMap name_to_node;
167 std::vector<PerPartitionExecutorsAndLib> items;
168 std::unordered_map<string, size_t> input_name_to_index;
169 std::unordered_map<string, string> input_name_to_rendezvous_key;
170 std::unordered_map<string, size_t> output_name_to_index;
171 std::unordered_map<string, string> output_name_to_rendezvous_key;
172
173 DataTypeVector input_types;
174 DataTypeVector output_types;
175
176 CallableOptions callable_options;
177
 // Defaults to the "no collective graph key" sentinel.
178 int64_t collective_graph_key = BuildGraphOptions::kNoCollectiveGraphKey;
179 };
180
181 // A FunctionInfo object is created for every unique set of feeds/fetches.
182 // This info could be folded into the ExecutorsAndKeys object but we would
183 // like to maintain a deletion order in which the OpKernels (owned by the
184 // executor) should be destroyed first, followed by the resources in the
185 // device and then followed by the function stuff.
186 // TODO(rohanj): Consolidate function library definitions so that we can
187 // instantiate only one ProcFLR and lib_def and make this just a member
188 // variable and not a vector.
189 // 'flib_def' is the function library used.
190 // 'proc_flr' is the collection of FunctionLibraryRuntime objects, one per
191 // device.
192 struct FunctionInfo {
193 std::unique_ptr<FunctionLibraryDefinition> flib_def;
194 std::unique_ptr<ProcessFunctionLibraryRuntime> proc_flr;
195 };
196
197 // For each live Run() call, the session maintains a RunState.
198 // 'status' is the current status of the execution.
 // 'status' is guarded by 'mu'.
199 struct RunState {
200 mutex mu;
201 Status status TF_GUARDED_BY(mu);
202 std::unique_ptr<CollectiveExecutor::Handle> collective_executor;
203 std::unique_ptr<StepStatsCollector> collector;
204 TensorStore tensor_store;
205 ScopedStepContainer step_container;
206
207 RunState(int64_t step_id, const std::vector<Device*>* devices);
208 };
209
210 // For each live partial execution, the session maintains a PartialRunState.
211 // 'executors_done' is "notified" when all executors are done. 'pending_inputs'
212 // are the set of pending feeds and 'pending_outputs' are the set of pending
213 // fetches.
214 struct PartialRunState : public RunState {
215 Notification executors_done;
216 std::unordered_map<string, bool> pending_inputs; // true if fed
217 std::unordered_map<string, bool> pending_outputs; // true if fetched
218 core::RefCountPtr<IntraProcessRendezvous> rendez = nullptr;
219
220 PartialRunState(const std::vector<string>& pending_input_names,
221 const std::vector<string>& pending_output_names,
222 int64_t step_id, const std::vector<Device*>* devices);
223
224 // Returns true if all pending inputs and outputs have been completed.
225 bool PendingDone() const;
226
227 ~PartialRunState();
228 };
229
 // Arguments threaded through GetOrCreateExecutors(), CreateExecutors() and
 // CreateGraphs() (each takes a RunStateArgs*).
 // 'debug_options' is stored as a reference, so the DebugOptions passed to the
 // constructor must outlive this object.
230 struct RunStateArgs {
231 explicit RunStateArgs(const DebugOptions& options)
232 : debug_options(options) {}
233
234 bool is_partial_run = false;
235 string handle;
236 std::unique_ptr<Graph> graph;
237 const DebugOptions& debug_options;
238 int64_t collective_graph_key = BuildGraphOptions::kNoCollectiveGraphKey;
239 };
240
241 // Retrieves an already existing set of executors to run 'inputs' and
242 // 'outputs', or creates and caches them for future use.
243 ::tensorflow::Status GetOrCreateExecutors(
244 gtl::ArraySlice<string> inputs, gtl::ArraySlice<string> outputs,
245 gtl::ArraySlice<string> target_nodes,
246 ExecutorsAndKeys** executors_and_keys, RunStateArgs* run_state_args);
247
248 // Creates a set of executors to run the subgraph defined by
249 // `callable_options`.
250 ::tensorflow::Status CreateExecutors(
251 const CallableOptions& callable_options,
252 std::unique_ptr<ExecutorsAndKeys>* out_executors_and_keys,
253 std::unique_ptr<FunctionInfo>* out_func_info,
254 RunStateArgs* run_state_args);
255
256 // Creates several graphs given the existing graph_def_ and the
257 // input feeds and fetches, given 'devices'. The graphs share a common
258 // function library 'flib_def'.
 // NOTE(review): no member named graph_def_ is declared in this header --
 // confirm what the comment above refers to.
259 ::tensorflow::Status CreateGraphs(
260 const BuildGraphOptions& options,
261 std::unordered_map<string, std::unique_ptr<Graph>>* outputs,
262 std::unique_ptr<FunctionLibraryDefinition>* flib_def,
263 RunStateArgs* run_state_args, DataTypeVector* input_types,
264 DataTypeVector* output_types, int64_t* collective_graph_key);
265
266 ::tensorflow::Status RunInternal(
267 int64_t step_id, const RunOptions& run_options,
268 CallFrameInterface* call_frame, ExecutorsAndKeys* executors_and_keys,
269 RunMetadata* run_metadata,
270 const thread::ThreadPoolOptions& threadpool_options);
271
272 // Returns whether inter-op execution uses a global pool or the input
273 // `run_options` requests being run on inter_op_thread_pool = 0 in case
274 // multiple pools are configured.
275 bool ShouldUseRunHandlerPool(const RunOptions& run_options) const;
276
 // The caller must hold graph_state_lock_ exclusively (see the annotation).
277 ::tensorflow::Status ExtendLocked(GraphDef&& graph)
278 TF_EXCLUSIVE_LOCKS_REQUIRED(graph_state_lock_);
279
280 ::tensorflow::Status ResourceHandleToInputTensor(
281 const Tensor& resource_tensor, Tensor* retrieved_tensor);
282
283 // Feeds more inputs to the executors, triggering further execution.
284 ::tensorflow::Status SendPRunInputs(
285 const std::vector<std::pair<string, Tensor>>& inputs,
286 const ExecutorsAndKeys* executors_and_keys,
287 IntraProcessRendezvous* rendez);
288
289 // Fetches more outputs from the executors. It waits until the output
290 // tensors are computed.
291 ::tensorflow::Status RecvPRunOutputs(
292 const std::vector<string>& output_names,
293 const ExecutorsAndKeys* executors_and_keys, PartialRunState* run_state,
294 std::vector<Tensor>* outputs);
295
296 // Check if the specified fetches can be computed from the feeds
297 // that we have already provided.
298 ::tensorflow::Status CheckFetch(
299 const std::vector<std::pair<string, Tensor>>& feeds,
300 const std::vector<string>& fetches,
301 const ExecutorsAndKeys* executors_and_keys,
302 const PartialRunState* run_state);
303
304 // Use the appropriate WaitForNotification function based on whether
305 // operation_timeout_in_ms is greater than 0.
306 //
307 // If the timeout expires, the `cm->StartCancel()` will be called.
 // Only the second overload takes a CancellationManager ('cm') and a
 // RunState; the first overload returns a Status instead.
308 ::tensorflow::Status WaitForNotification(Notification* n,
309 int64_t timeout_in_ms);
310 void WaitForNotification(Notification* n, RunState* run_state,
311 CancellationManager* cm, int64_t timeout_in_ms);
312
 // Returns a Cancelled error if closed_ is set and OK otherwise. closed_ is
 // read while holding closed_lock_.
313 ::tensorflow::Status CheckNotClosed() {
314 mutex_lock l(closed_lock_);
315 if (closed_) return errors::Cancelled("Session has been closed.");
316 return OkStatus();
317 }
318
 // Returns an InvalidArgument error if graph_created_ is false and OK
 // otherwise. 'method' is inserted into the error message.
 // graph_created_ is read while holding graph_state_lock_.
319 ::tensorflow::Status CheckGraphCreated(const char* method) {
320 mutex_lock l(graph_state_lock_);
321 if (!graph_created_) {
322 return errors::InvalidArgument(
323 "Session was not created with a graph before ", method, "!");
324 }
325 return OkStatus();
326 }
327
328 ::tensorflow::Status CreateDebuggerState(
329 const CallableOptions& options, int64_t global_step,
330 int64_t session_run_index, int64_t executor_step_index,
331 std::unique_ptr<DebuggerStateInterface>* debugger_state);
332
333 ::tensorflow::Status DecorateAndPublishGraphForDebug(
334 const DebugOptions& debug_options, Graph* graph, Device* device);
335
336 const SessionOptions options_;
337
338 // Device structures.
339 const std::unique_ptr<const DeviceMgr> device_mgr_;
340 std::vector<Device*> devices_; // not owned
341 DeviceSet device_set_;
342
343 // Unique session identifier.
344 string session_handle_;
345 mutex graph_state_lock_;
346 bool graph_created_ TF_GUARDED_BY(graph_state_lock_) = false;
347 bool finalized_ TF_GUARDED_BY(graph_state_lock_) = false;
348
349 // The thread-pools to use for running ops, with a bool indicating if the pool
350 // is owned.
351 std::vector<std::pair<thread::ThreadPool*, bool>> thread_pools_;
352
353 Status init_error_; // Set to an error if construction failed.
354
355 // If true, blocks until device has finished all queued operations in a step.
356 bool sync_on_finish_ = true;
357
 // Guarded by executor_lock_, which is declared just below.
358 std::vector<std::unique_ptr<FunctionInfo>> functions_
359 TF_GUARDED_BY(executor_lock_);
360
361 mutex executor_lock_; // protects functions_, executors_, partial_runs_
362 // Holds mappings from signature to the executors that process
363 // it. The reason for a level of indirection around mapped_type is
364 // to guarantee address stability.
365 // The map value is a shared_ptr since multiple map keys can point to the
366 // same ExecutorsAndKeys object.
367 std::unordered_map<string, std::shared_ptr<ExecutorsAndKeys>> executors_
368 TF_GUARDED_BY(executor_lock_);
369
 // Nested class declared here only; its definition is not in this header.
370 class RunCallableCallFrame;
 // Bundles shared ownership of the ExecutorsAndKeys and FunctionInfo for one
 // entry of callables_. The destructor is declared here and defined elsewhere.
371 struct Callable {
372 std::shared_ptr<ExecutorsAndKeys> executors_and_keys;
373 std::shared_ptr<FunctionInfo> function_info;
374 ~Callable();
375 };
376 mutex callables_lock_;
377 int64_t next_callable_handle_ TF_GUARDED_BY(callables_lock_) = 0;
 // Keyed by an int64_t, presumably the callable handle -- confirm.
378 std::unordered_map<int64_t, Callable> callables_
379 TF_GUARDED_BY(callables_lock_);
380
381 // Holds mappings from handle to partial run state.
382 std::unordered_map<string, std::unique_ptr<PartialRunState>> partial_runs_
383 TF_GUARDED_BY(executor_lock_);
384
385 // This holds all the tensors that are currently alive in the session.
386 SessionState session_state_;
387
388 DirectSessionFactory* const factory_; // not owned
 // NOTE(review): raw pointer with no in-class initializer; who allocates and
 // frees it is not visible in this header -- confirm in the implementation.
389 CancellationManager* cancellation_manager_;
390 std::unique_ptr<CollectiveExecutorMgrInterface> collective_executor_mgr_;
391
392 // Map of placed stateful nodes, i.e. nodes for which is_stateful()
393 // is true, such as "params" and "queue" nodes. Once placed these
394 // nodes can not be moved to a different device. Maps node names to
395 // device names.
396 std::unordered_map<string, string> stateful_placements_
397 TF_GUARDED_BY(graph_state_lock_);
398
399 // Execution_state; used when placing the entire graph.
400 std::unique_ptr<GraphExecutionState> execution_state_
401 TF_GUARDED_BY(graph_state_lock_);
402
403 // The function library, before any rewrites or optimizations have been
404 // performed. In particular, CreateGraphs() may need to modify the function
405 // library; it copies and modifies the function library.
 // No thread-safety annotation is attached to this member.
406 std::unique_ptr<FunctionLibraryDefinition> flib_def_;
407
408 // true if the Session has been Closed.
 // (The comment above describes closed_; closed_lock_ is the mutex guarding
 // it.)
409 mutex closed_lock_;
410 bool closed_ TF_GUARDED_BY(closed_lock_) = false;
411
412 // For generating unique names for this session instance.
413 std::atomic<int64_t> edge_name_counter_ = {0};
414 std::atomic<int64_t> handle_name_counter_ = {0};
415
416 // For generating step ids that are unique among all sessions.
 // Static, so a single counter is shared by every DirectSession instance.
417 static std::atomic_int_fast64_t step_id_counter_;
418
419 // Global timeout for all blocking operations in this session.
 // Const; the in-class default is 0.
420 const int64_t operation_timeout_in_ms_ = 0;
421
422 // Manages all the cost models for the graphs executed in this session.
423 CostModelManager cost_model_manager_;
424
425 // For testing collective graph key generation.
426 mutex collective_graph_key_lock_;
427 int64_t collective_graph_key_ TF_GUARDED_BY(collective_graph_key_lock_) = -1;
428
429 // Run in caller's thread if RunOptions.inter_op_thread_pool is negative or
430 // all of the following conditions are met:
431 // 1. This session doesn't own any thread pool.
432 // 2. RunOptions.inter_op_thread_pool is unspecified or 0.
433 // 3. This session has a single executor.
434 // 4. config.inter_op_parallelism_threads is specified to negative explicitly
435 // or through environment variable TF_NUM_INTEROP_THREADS.
436 // 5. RunOptions.experimental.use_run_handler_pool is unspecified or false.
437 // Otherwise run in global thread pool, session owned thread pool or handler
438 // pool according to other specifications of RunOptions and ConfigProto.
439 bool run_in_caller_thread_ = false;
440
441 TF_DISALLOW_COPY_AND_ASSIGN(DirectSession);
442
443 // EXPERIMENTAL: debugger (tfdbg) related
444 friend class DebugGateway;
445};
446
447} // end namespace tensorflow
448
449#endif // TENSORFLOW_CORE_COMMON_RUNTIME_DIRECT_SESSION_H_
450