1 | /* Copyright 2015 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | |
16 | #ifndef TENSORFLOW_CORE_COMMON_RUNTIME_DIRECT_SESSION_H_ |
17 | #define TENSORFLOW_CORE_COMMON_RUNTIME_DIRECT_SESSION_H_ |
18 | |
#include <atomic>
#include <functional>
#include <memory>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <utility>
#include <vector>

#include "tensorflow/core/common_runtime/costmodel_manager.h"
#include "tensorflow/core/common_runtime/debugger_state_interface.h"
#include "tensorflow/core/common_runtime/device_mgr.h"
#include "tensorflow/core/common_runtime/device_set.h"
#include "tensorflow/core/common_runtime/executor.h"
#include "tensorflow/core/common_runtime/graph_execution_state.h"
#include "tensorflow/core/common_runtime/process_function_library_runtime.h"
#include "tensorflow/core/common_runtime/rendezvous_mgr.h"
#include "tensorflow/core/common_runtime/session_factory.h"
#include "tensorflow/core/framework/cancellation.h"
#include "tensorflow/core/framework/collective.h"
#include "tensorflow/core/framework/graph.pb.h"
#include "tensorflow/core/framework/session_state.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/thread_annotations.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/public/session.h"
47 | |
48 | namespace tensorflow { |
49 | |
// Forward declarations. Within this header these types are only named, never
// used as complete types.
class CostModel;             // Not referenced directly in this header.
class DebugGateway;          // Declared a friend of DirectSession.
class Device;                // Used via Device* only.
class DirectSessionFactory;  // Used via DirectSessionFactory* only.
54 | |
55 | class DirectSession : public Session { |
56 | public: |
57 | typedef std::function<void(Session*)> CloseCallback; |
58 | |
59 | // Takes ownership of 'device_mgr'. |
60 | // 'factory' is used to unregister the DirectSession with 'factory' when its |
61 | // closed. This ensures that Reset requests from the 'factory' don't get sent |
62 | // to sessions that are already closed. |
63 | DirectSession(const SessionOptions& options, const DeviceMgr* device_mgr, |
64 | DirectSessionFactory* factory); |
65 | ~DirectSession() override; |
66 | |
67 | typedef std::vector<std::pair<string, Tensor>> NamedTensorList; |
68 | typedef std::unordered_map<StringPiece, Node*, StringPieceHasher> NameNodeMap; |
69 | |
70 | ::tensorflow::Status Create(const GraphDef& graph) override; |
71 | ::tensorflow::Status Create(GraphDef&& graph) override; |
72 | ::tensorflow::Status Extend(const GraphDef& graph) override; |
73 | ::tensorflow::Status Extend(GraphDef&& graph) override; |
74 | ::tensorflow::Status Run(const NamedTensorList& inputs, |
75 | const std::vector<string>& output_names, |
76 | const std::vector<string>& target_nodes, |
77 | std::vector<Tensor>* outputs) override; |
78 | |
79 | // NOTE: Experimental and subject to change. |
80 | ::tensorflow::Status Run(const ::tensorflow::RunOptions& run_options, |
81 | const NamedTensorList& inputs, |
82 | const std::vector<string>& output_names, |
83 | const std::vector<string>& target_nodes, |
84 | std::vector<Tensor>* outputs, |
85 | RunMetadata* run_metadata) override; |
86 | |
87 | // NOTE: Experimental and subject to change. |
88 | ::tensorflow::Status Run( |
89 | const ::tensorflow::RunOptions& run_options, |
90 | const NamedTensorList& inputs, const std::vector<string>& output_names, |
91 | const std::vector<string>& target_nodes, std::vector<Tensor>* outputs, |
92 | RunMetadata* run_metadata, |
93 | const thread::ThreadPoolOptions& threadpool_options) override; |
94 | |
95 | // NOTE: PRunSetup and PRun are added to support partial execution. This |
96 | // feature is experimental and subject to change. |
97 | ::tensorflow::Status PRunSetup(const std::vector<string>& input_names, |
98 | const std::vector<string>& output_names, |
99 | const std::vector<string>& target_nodes, |
100 | string* handle) override; |
101 | ::tensorflow::Status PRun(const string& handle, const NamedTensorList& inputs, |
102 | const std::vector<string>& output_names, |
103 | std::vector<Tensor>* outputs) override; |
104 | |
105 | // Reset clears 'containers' from the device_mgr of the DirectSession. |
106 | // If 'containers' is empty, then Reset clears the default container. |
107 | ::tensorflow::Status Reset(const std::vector<string>& containers); |
108 | |
109 | ::tensorflow::Status ListDevices( |
110 | std::vector<DeviceAttributes>* response) override; |
111 | ::tensorflow::Status Close() override; |
112 | ::tensorflow::Status LocalDeviceManager(const DeviceMgr** output) override { |
113 | *output = device_mgr_.get(); |
114 | return OkStatus(); |
115 | } |
116 | |
117 | void ExportCostModels(CostModelManager::CostModelMap* cost_models) { |
118 | cost_model_manager_.ExportCostModels(cost_models); |
119 | } |
120 | |
121 | ::tensorflow::Status MakeCallable(const CallableOptions& callable_options, |
122 | CallableHandle* out_handle) override; |
123 | |
124 | ::tensorflow::Status RunCallable(CallableHandle handle, |
125 | const std::vector<Tensor>& feed_tensors, |
126 | std::vector<Tensor>* fetch_tensors, |
127 | RunMetadata* run_metadata) override; |
128 | |
129 | ::tensorflow::Status RunCallable( |
130 | CallableHandle handle, const std::vector<Tensor>& feed_tensors, |
131 | std::vector<Tensor>* fetch_tensors, RunMetadata* run_metadata, |
132 | const thread::ThreadPoolOptions& threadpool_options) override; |
133 | |
134 | ::tensorflow::Status ReleaseCallable(CallableHandle handle) override; |
135 | |
136 | ::tensorflow::Status Finalize() override; |
137 | |
138 | const SessionOptions& options() const { return options_; } |
139 | |
140 | private: |
141 | // For access to collective_graph_key_. |
142 | friend class DirectSessionCollectiveTest; |
143 | |
144 | // We create one executor and its dependent library runtime for |
145 | // every partition. |
146 | struct PerPartitionExecutorsAndLib { |
147 | std::unique_ptr<Graph> graph = nullptr; |
148 | Device* device = nullptr; // not owned. |
149 | FunctionLibraryRuntime* flib = nullptr; // not owned. |
150 | std::unique_ptr<Executor> executor; |
151 | }; |
152 | |
153 | // An ExecutorsAndKeys is created for a given set of feeds/fetches. |
154 | // 'step_count' is the number of times this graph is executed. |
155 | // 'graph' is the entire graph being executed. 'name_to_node' |
156 | // maps node name to node. We keep 'graph' and 'name_to_node' only in |
157 | // the case of partial runs. Each item in 'items' is the executor for |
158 | // a partition of the graph bundled with its dependent library runtime. |
159 | // 'input_keys' are the rendezvous keys for the feeds and 'output_keys' |
160 | // are rendezvous keys for the fetches. |
161 | struct ExecutorsAndKeys { |
162 | ExecutorsAndKeys() : step_count(0) {} |
163 | |
164 | std::atomic_int_fast64_t step_count; |
165 | std::unique_ptr<Graph> graph; |
166 | NameNodeMap name_to_node; |
167 | std::vector<PerPartitionExecutorsAndLib> items; |
168 | std::unordered_map<string, size_t> input_name_to_index; |
169 | std::unordered_map<string, string> input_name_to_rendezvous_key; |
170 | std::unordered_map<string, size_t> output_name_to_index; |
171 | std::unordered_map<string, string> output_name_to_rendezvous_key; |
172 | |
173 | DataTypeVector input_types; |
174 | DataTypeVector output_types; |
175 | |
176 | CallableOptions callable_options; |
177 | |
178 | int64_t collective_graph_key = BuildGraphOptions::kNoCollectiveGraphKey; |
179 | }; |
180 | |
181 | // A FunctionInfo object is created for every unique set of feeds/fetches. |
182 | // This info could be folded into the ExecutorsAndKeys object but we would |
183 | // like to maintain a deletion order in which the OpKernels (owned by the |
184 | // executor) should be destroyed first, followed by the resources in the |
185 | // device and then followed by the function stuff. |
186 | // TODO(rohanj): Consolidate function library definitions so that we can |
187 | // instantiate only one ProcFLR and lib_def and make this just a member |
188 | // variable and not a vector. |
189 | // 'flib_def' is the function library used. |
190 | // 'proc_flr' is the collection of FunctionLibraryRuntime objects, one per |
191 | // device. |
192 | struct FunctionInfo { |
193 | std::unique_ptr<FunctionLibraryDefinition> flib_def; |
194 | std::unique_ptr<ProcessFunctionLibraryRuntime> proc_flr; |
195 | }; |
196 | |
197 | // For each live Run() call, the session maintains a RunState. |
198 | // 'status' is the current status of the execution. |
199 | struct RunState { |
200 | mutex mu; |
201 | Status status TF_GUARDED_BY(mu); |
202 | std::unique_ptr<CollectiveExecutor::Handle> collective_executor; |
203 | std::unique_ptr<StepStatsCollector> collector; |
204 | TensorStore tensor_store; |
205 | ScopedStepContainer step_container; |
206 | |
207 | RunState(int64_t step_id, const std::vector<Device*>* devices); |
208 | }; |
209 | |
210 | // For each live partial execution, the session maintains a PartialRunState. |
211 | // 'executor_done' is "notified" when all executors are done. 'pending_inputs' |
212 | // are the set of pending feeds and 'pending_outputs' are the set of pending |
213 | // fetches. |
214 | struct PartialRunState : public RunState { |
215 | Notification executors_done; |
216 | std::unordered_map<string, bool> pending_inputs; // true if fed |
217 | std::unordered_map<string, bool> pending_outputs; // true if fetched |
218 | core::RefCountPtr<IntraProcessRendezvous> rendez = nullptr; |
219 | |
220 | PartialRunState(const std::vector<string>& pending_input_names, |
221 | const std::vector<string>& pending_output_names, |
222 | int64_t step_id, const std::vector<Device*>* devices); |
223 | |
224 | // Returns true if all pending inputs and outputs have been completed. |
225 | bool PendingDone() const; |
226 | |
227 | ~PartialRunState(); |
228 | }; |
229 | |
230 | struct RunStateArgs { |
231 | explicit RunStateArgs(const DebugOptions& options) |
232 | : debug_options(options) {} |
233 | |
234 | bool is_partial_run = false; |
235 | string handle; |
236 | std::unique_ptr<Graph> graph; |
237 | const DebugOptions& debug_options; |
238 | int64_t collective_graph_key = BuildGraphOptions::kNoCollectiveGraphKey; |
239 | }; |
240 | |
241 | // Retrieves an already existing set of executors to run 'inputs' and |
242 | // 'outputs', or creates and caches them for future use. |
243 | ::tensorflow::Status GetOrCreateExecutors( |
244 | gtl::ArraySlice<string> inputs, gtl::ArraySlice<string> outputs, |
245 | gtl::ArraySlice<string> target_nodes, |
246 | ExecutorsAndKeys** executors_and_keys, RunStateArgs* run_state_args); |
247 | |
248 | // Creates a set of executors to run the subgraph defined by |
249 | // `callable_options`. |
250 | ::tensorflow::Status CreateExecutors( |
251 | const CallableOptions& callable_options, |
252 | std::unique_ptr<ExecutorsAndKeys>* out_executors_and_keys, |
253 | std::unique_ptr<FunctionInfo>* out_func_info, |
254 | RunStateArgs* run_state_args); |
255 | |
256 | // Creates several graphs given the existing graph_def_ and the |
257 | // input feeds and fetches, given 'devices'. The graphs share a common |
258 | // function library 'flib_def'. |
259 | ::tensorflow::Status CreateGraphs( |
260 | const BuildGraphOptions& options, |
261 | std::unordered_map<string, std::unique_ptr<Graph>>* outputs, |
262 | std::unique_ptr<FunctionLibraryDefinition>* flib_def, |
263 | RunStateArgs* run_state_args, DataTypeVector* input_types, |
264 | DataTypeVector* output_types, int64_t* collective_graph_key); |
265 | |
266 | ::tensorflow::Status RunInternal( |
267 | int64_t step_id, const RunOptions& run_options, |
268 | CallFrameInterface* call_frame, ExecutorsAndKeys* executors_and_keys, |
269 | RunMetadata* run_metadata, |
270 | const thread::ThreadPoolOptions& threadpool_options); |
271 | |
272 | // Returns whether inter-op execution uses a global pool or the input |
273 | // `run_options` requests being run on inter_op_thread_pool = 0 in case |
274 | // multiple pools are configured. |
275 | bool ShouldUseRunHandlerPool(const RunOptions& run_options) const; |
276 | |
277 | ::tensorflow::Status ExtendLocked(GraphDef&& graph) |
278 | TF_EXCLUSIVE_LOCKS_REQUIRED(graph_state_lock_); |
279 | |
280 | ::tensorflow::Status ResourceHandleToInputTensor( |
281 | const Tensor& resource_tensor, Tensor* retrieved_tensor); |
282 | |
283 | // Feeds more inputs to the executors, triggering further execution. |
284 | ::tensorflow::Status SendPRunInputs( |
285 | const std::vector<std::pair<string, Tensor>>& inputs, |
286 | const ExecutorsAndKeys* executors_and_keys, |
287 | IntraProcessRendezvous* rendez); |
288 | |
289 | // Fetches more outputs from the executors. It waits until the output |
290 | // tensors are computed. |
291 | ::tensorflow::Status RecvPRunOutputs( |
292 | const std::vector<string>& output_names, |
293 | const ExecutorsAndKeys* executors_and_keys, PartialRunState* run_state, |
294 | std::vector<Tensor>* outputs); |
295 | |
296 | // Check if the specified fetches can be computed from the feeds |
297 | // that we have already provided. |
298 | ::tensorflow::Status CheckFetch( |
299 | const std::vector<std::pair<string, Tensor>>& feeds, |
300 | const std::vector<string>& fetches, |
301 | const ExecutorsAndKeys* executors_and_keys, |
302 | const PartialRunState* run_state); |
303 | |
304 | // Use the appropriate WaitForNotification function based on whether |
305 | // operation_timeout_in_ms is greater than 0. |
306 | // |
307 | // If the timeout expires, the `cm->StartCancel()` will be called. |
308 | ::tensorflow::Status WaitForNotification(Notification* n, |
309 | int64_t timeout_in_ms); |
310 | void WaitForNotification(Notification* n, RunState* run_state, |
311 | CancellationManager* cm, int64_t timeout_in_ms); |
312 | |
313 | ::tensorflow::Status CheckNotClosed() { |
314 | mutex_lock l(closed_lock_); |
315 | if (closed_) return errors::Cancelled("Session has been closed." ); |
316 | return OkStatus(); |
317 | } |
318 | |
319 | ::tensorflow::Status CheckGraphCreated(const char* method) { |
320 | mutex_lock l(graph_state_lock_); |
321 | if (!graph_created_) { |
322 | return errors::InvalidArgument( |
323 | "Session was not created with a graph before " , method, "!" ); |
324 | } |
325 | return OkStatus(); |
326 | } |
327 | |
328 | ::tensorflow::Status CreateDebuggerState( |
329 | const CallableOptions& options, int64_t global_step, |
330 | int64_t session_run_index, int64_t executor_step_index, |
331 | std::unique_ptr<DebuggerStateInterface>* debugger_state); |
332 | |
333 | ::tensorflow::Status DecorateAndPublishGraphForDebug( |
334 | const DebugOptions& debug_options, Graph* graph, Device* device); |
335 | |
336 | const SessionOptions options_; |
337 | |
338 | // Device structures. |
339 | const std::unique_ptr<const DeviceMgr> device_mgr_; |
340 | std::vector<Device*> devices_; // not owned |
341 | DeviceSet device_set_; |
342 | |
343 | // Unique session identifier. |
344 | string session_handle_; |
345 | mutex graph_state_lock_; |
346 | bool graph_created_ TF_GUARDED_BY(graph_state_lock_) = false; |
347 | bool finalized_ TF_GUARDED_BY(graph_state_lock_) = false; |
348 | |
349 | // The thread-pools to use for running ops, with a bool indicating if the pool |
350 | // is owned. |
351 | std::vector<std::pair<thread::ThreadPool*, bool>> thread_pools_; |
352 | |
353 | Status init_error_; // Set to an error if construction failed. |
354 | |
355 | // If true, blocks until device has finished all queued operations in a step. |
356 | bool sync_on_finish_ = true; |
357 | |
358 | std::vector<std::unique_ptr<FunctionInfo>> functions_ |
359 | TF_GUARDED_BY(executor_lock_); |
360 | |
361 | mutex executor_lock_; // protects executors_ |
362 | // Holds mappings from signature to the executors that process |
363 | // it. The reason for a level of indirection around mapped_type is |
364 | // to guarantee address stability. |
365 | // The map value is a shared_ptr since multiple map keys can point to the |
366 | // same ExecutorsAndKey object. |
367 | std::unordered_map<string, std::shared_ptr<ExecutorsAndKeys>> executors_ |
368 | TF_GUARDED_BY(executor_lock_); |
369 | |
370 | class RunCallableCallFrame; |
371 | struct Callable { |
372 | std::shared_ptr<ExecutorsAndKeys> executors_and_keys; |
373 | std::shared_ptr<FunctionInfo> function_info; |
374 | ~Callable(); |
375 | }; |
376 | mutex callables_lock_; |
377 | int64_t next_callable_handle_ TF_GUARDED_BY(callables_lock_) = 0; |
378 | std::unordered_map<int64_t, Callable> callables_ |
379 | TF_GUARDED_BY(callables_lock_); |
380 | |
381 | // Holds mappings from handle to partial run state. |
382 | std::unordered_map<string, std::unique_ptr<PartialRunState>> partial_runs_ |
383 | TF_GUARDED_BY(executor_lock_); |
384 | |
385 | // This holds all the tensors that are currently alive in the session. |
386 | SessionState session_state_; |
387 | |
388 | DirectSessionFactory* const factory_; // not owned |
389 | CancellationManager* cancellation_manager_; |
390 | std::unique_ptr<CollectiveExecutorMgrInterface> collective_executor_mgr_; |
391 | |
392 | // Map of placed stateful nodes, i.e. nodes for which is_stateful() |
393 | // is true, such as "params" and "queue" nodes. Once placed these |
394 | // nodes can not be moved to a different device. Maps node names to |
395 | // device names. |
396 | std::unordered_map<string, string> stateful_placements_ |
397 | TF_GUARDED_BY(graph_state_lock_); |
398 | |
399 | // Execution_state; used when placing the entire graph. |
400 | std::unique_ptr<GraphExecutionState> execution_state_ |
401 | TF_GUARDED_BY(graph_state_lock_); |
402 | |
403 | // The function library, before any rewrites or optimizations have been |
404 | // performed. In particular, CreateGraphs() may need to modify the function |
405 | // library; it copies and modifies the function library. |
406 | std::unique_ptr<FunctionLibraryDefinition> flib_def_; |
407 | |
408 | // true if the Session has been Closed. |
409 | mutex closed_lock_; |
410 | bool closed_ TF_GUARDED_BY(closed_lock_) = false; |
411 | |
412 | // For generating unique names for this session instance. |
413 | std::atomic<int64_t> edge_name_counter_ = {0}; |
414 | std::atomic<int64_t> handle_name_counter_ = {0}; |
415 | |
416 | // For generating step ids that are unique among all sessions. |
417 | static std::atomic_int_fast64_t step_id_counter_; |
418 | |
419 | // Global timeout for all blocking operations in this session. |
420 | const int64_t operation_timeout_in_ms_ = 0; |
421 | |
422 | // Manages all the cost models for the graphs executed in this session. |
423 | CostModelManager cost_model_manager_; |
424 | |
425 | // For testing collective graph key generation. |
426 | mutex collective_graph_key_lock_; |
427 | int64_t collective_graph_key_ TF_GUARDED_BY(collective_graph_key_lock_) = -1; |
428 | |
429 | // Run in caller's thread if RunOptions.inter_op_thread_pool is negative or |
430 | // all of following conditions are met: |
431 | // 1. This session doesn't own any thread pool. |
432 | // 2. RunOptions.inter_op_thread_pool is unspecified or 0. |
433 | // 3. This session has a single executor. |
434 | // 4. config.inter_op_parallelism_threads is specified to negative explicitly |
435 | // or through environment variable TF_NUM_INTEROP_THREADS. |
436 | // 5. RunOptions.experimental.use_run_handler_pool is unspecified or false. |
437 | // Otherwise run in global thread pool, session owned thread pool or handler |
438 | // pool according to other specifications of RunOptions and ConfigProto. |
439 | bool run_in_caller_thread_ = false; |
440 | |
441 | TF_DISALLOW_COPY_AND_ASSIGN(DirectSession); |
442 | |
443 | // EXPERIMENTAL: debugger (tfdbg) related |
444 | friend class DebugGateway; |
445 | }; |
446 | |
447 | } // end namespace tensorflow |
448 | |
449 | #endif // TENSORFLOW_CORE_COMMON_RUNTIME_DIRECT_SESSION_H_ |
450 | |