1 | /* Copyright 2015 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | |
16 | #include "tensorflow/core/common_runtime/direct_session.h" |
17 | |
18 | #include <algorithm> |
19 | #include <atomic> |
20 | #include <string> |
21 | #include <vector> |
22 | |
23 | #include "absl/container/flat_hash_set.h" |
24 | #include "absl/time/time.h" |
25 | #include "absl/types/optional.h" |
26 | #include "tensorflow/core/common_runtime/collective_executor_mgr.h" |
27 | #include "tensorflow/core/common_runtime/collective_param_resolver_local.h" |
28 | #include "tensorflow/core/common_runtime/constant_folding.h" |
29 | #include "tensorflow/core/common_runtime/debugger_state_interface.h" |
30 | #include "tensorflow/core/common_runtime/device_factory.h" |
31 | #include "tensorflow/core/common_runtime/device_resolver_local.h" |
32 | #include "tensorflow/core/common_runtime/executor.h" |
33 | #include "tensorflow/core/common_runtime/executor_factory.h" |
34 | #include "tensorflow/core/common_runtime/function.h" |
35 | #include "tensorflow/core/common_runtime/graph_constructor.h" |
36 | #include "tensorflow/core/common_runtime/graph_optimizer.h" |
37 | #include "tensorflow/core/common_runtime/local_session_selection.h" |
38 | #include "tensorflow/core/common_runtime/memory_types.h" |
39 | #include "tensorflow/core/common_runtime/optimization_registry.h" |
40 | #include "tensorflow/core/common_runtime/process_util.h" |
41 | #include "tensorflow/core/common_runtime/rendezvous_mgr.h" |
42 | #include "tensorflow/core/common_runtime/scoped_allocator_mgr.h" |
43 | #include "tensorflow/core/common_runtime/step_stats_collector.h" |
44 | #include "tensorflow/core/framework/function.h" |
45 | #include "tensorflow/core/framework/graph.pb.h" |
46 | #include "tensorflow/core/framework/graph_def_util.h" |
47 | #include "tensorflow/core/framework/log_memory.h" |
48 | #include "tensorflow/core/framework/logging.h" |
49 | #include "tensorflow/core/framework/metrics.h" |
50 | #include "tensorflow/core/framework/node_def.pb.h" |
51 | #include "tensorflow/core/framework/run_handler.h" |
52 | #include "tensorflow/core/framework/tensor.h" |
53 | #include "tensorflow/core/framework/versions.pb.h" |
54 | #include "tensorflow/core/graph/algorithm.h" |
55 | #include "tensorflow/core/graph/graph.h" |
56 | #include "tensorflow/core/graph/graph_partition.h" |
57 | #include "tensorflow/core/graph/subgraph.h" |
58 | #include "tensorflow/core/graph/tensor_id.h" |
59 | #include "tensorflow/core/lib/core/errors.h" |
60 | #include "tensorflow/core/lib/core/notification.h" |
61 | #include "tensorflow/core/lib/core/refcount.h" |
62 | #include "tensorflow/core/lib/core/status.h" |
63 | #include "tensorflow/core/lib/core/threadpool.h" |
64 | #include "tensorflow/core/lib/core/threadpool_options.h" |
65 | #include "tensorflow/core/lib/gtl/array_slice.h" |
66 | #include "tensorflow/core/lib/monitoring/counter.h" |
67 | #include "tensorflow/core/lib/random/random.h" |
68 | #include "tensorflow/core/lib/strings/numbers.h" |
69 | #include "tensorflow/core/lib/strings/str_util.h" |
70 | #include "tensorflow/core/lib/strings/strcat.h" |
71 | #include "tensorflow/core/nccl/collective_communicator.h" |
72 | #include "tensorflow/core/platform/byte_order.h" |
73 | #include "tensorflow/core/platform/cpu_info.h" |
74 | #include "tensorflow/core/platform/logging.h" |
75 | #include "tensorflow/core/platform/mutex.h" |
76 | #include "tensorflow/core/platform/tracing.h" |
77 | #include "tensorflow/core/platform/types.h" |
78 | #include "tensorflow/core/profiler/lib/connected_traceme.h" |
79 | #include "tensorflow/core/profiler/lib/device_profiler_session.h" |
80 | #include "tensorflow/core/profiler/lib/traceme_encode.h" |
81 | #include "tensorflow/core/protobuf/config.pb.h" |
82 | #include "tensorflow/core/util/device_name_utils.h" |
83 | #include "tensorflow/core/util/env_var.h" |
84 | |
85 | namespace tensorflow { |
86 | |
87 | namespace { |
88 | |
89 | auto* direct_session_runs = monitoring::Counter<0>::New( |
90 | "/tensorflow/core/direct_session_runs" , |
91 | "The number of times DirectSession::Run() has been called." ); |
92 | |
93 | Status NewThreadPoolFromThreadPoolOptions( |
94 | const SessionOptions& options, |
95 | const ThreadPoolOptionProto& thread_pool_options, int pool_number, |
96 | thread::ThreadPool** pool, bool* owned) { |
97 | int32_t num_threads = thread_pool_options.num_threads(); |
98 | if (num_threads == 0) { |
99 | num_threads = NumInterOpThreadsFromSessionOptions(options); |
100 | } |
101 | const string& name = thread_pool_options.global_name(); |
102 | if (name.empty()) { |
103 | // Session-local threadpool. |
104 | VLOG(1) << "Direct session inter op parallelism threads for pool " |
105 | << pool_number << ": " << num_threads; |
106 | *pool = new thread::ThreadPool( |
107 | options.env, ThreadOptions(), strings::StrCat("Compute" , pool_number), |
108 | num_threads, !options.config.experimental().disable_thread_spinning(), |
109 | /*allocator=*/nullptr); |
110 | *owned = true; |
111 | return OkStatus(); |
112 | } |
113 | |
114 | // Global, named threadpool. |
115 | typedef std::pair<int32, thread::ThreadPool*> MapValue; |
116 | static std::map<string, MapValue>* global_pool_map = |
117 | new std::map<string, MapValue>; |
118 | static mutex* mu = new mutex(); |
119 | mutex_lock l(*mu); |
120 | MapValue* mvalue = &(*global_pool_map)[name]; |
121 | if (mvalue->second == nullptr) { |
122 | mvalue->first = thread_pool_options.num_threads(); |
123 | mvalue->second = new thread::ThreadPool( |
124 | options.env, ThreadOptions(), strings::StrCat("Compute" , pool_number), |
125 | num_threads, !options.config.experimental().disable_thread_spinning(), |
126 | /*allocator=*/nullptr); |
127 | } else { |
128 | if (mvalue->first != thread_pool_options.num_threads()) { |
129 | return errors::InvalidArgument( |
130 | "Pool " , name, |
131 | " configured previously with num_threads=" , mvalue->first, |
132 | "; cannot re-configure with num_threads=" , |
133 | thread_pool_options.num_threads()); |
134 | } |
135 | } |
136 | *owned = false; |
137 | *pool = mvalue->second; |
138 | return OkStatus(); |
139 | } |
140 | |
141 | thread::ThreadPool* GlobalThreadPool(const SessionOptions& options) { |
142 | static thread::ThreadPool* const thread_pool = |
143 | NewThreadPoolFromSessionOptions(options); |
144 | return thread_pool; |
145 | } |
146 | |
147 | // TODO(vrv): Figure out how to unify the many different functions |
148 | // that generate RendezvousKey, since many of them have to be |
149 | // consistent with each other. |
150 | string GetRendezvousKey(const string& tensor_name, |
151 | const DeviceAttributes& device_info, |
152 | const FrameAndIter& frame_iter) { |
153 | return strings::StrCat(device_info.name(), ";" , |
154 | strings::FpToString(device_info.incarnation()), ";" , |
155 | device_info.name(), ";" , tensor_name, ";" , |
156 | frame_iter.frame_id, ":" , frame_iter.iter_id); |
157 | } |
158 | |
159 | } // namespace |
160 | |
161 | class DirectSessionFactory : public SessionFactory { |
162 | public: |
163 | DirectSessionFactory() {} |
164 | |
165 | bool AcceptsOptions(const SessionOptions& options) override { |
166 | return options.target.empty() && |
167 | !options.config.experimental().use_tfrt() && |
168 | GetDefaultLocalSessionImpl() == LocalSessionImpl::kDirectSession; |
169 | } |
170 | |
171 | Status NewSession(const SessionOptions& options, |
172 | Session** out_session) override { |
173 | const auto& experimental_config = options.config.experimental(); |
174 | if (experimental_config.has_session_metadata()) { |
175 | if (experimental_config.session_metadata().version() < 0) { |
176 | return errors::InvalidArgument( |
177 | "Session version shouldn't be negative: " , |
178 | experimental_config.session_metadata().DebugString()); |
179 | } |
180 | const string key = GetMetadataKey(experimental_config.session_metadata()); |
181 | mutex_lock l(sessions_lock_); |
182 | if (!session_metadata_keys_.insert(key).second) { |
183 | return errors::InvalidArgument( |
184 | "A session with the same name and version has already been " |
185 | "created: " , |
186 | experimental_config.session_metadata().DebugString()); |
187 | } |
188 | } |
189 | |
190 | // Must do this before the CPU allocator is created. |
191 | if (options.config.graph_options().build_cost_model() > 0) { |
192 | EnableCPUAllocatorFullStats(); |
193 | } |
194 | std::vector<std::unique_ptr<Device>> devices; |
195 | TF_RETURN_IF_ERROR(DeviceFactory::AddDevices( |
196 | options, "/job:localhost/replica:0/task:0" , &devices)); |
197 | |
198 | DirectSession* session = new DirectSession( |
199 | options, new StaticDeviceMgr(std::move(devices)), this); |
200 | { |
201 | mutex_lock l(sessions_lock_); |
202 | sessions_.push_back(session); |
203 | } |
204 | *out_session = session; |
205 | return OkStatus(); |
206 | } |
207 | |
208 | Status Reset(const SessionOptions& options, |
209 | const std::vector<string>& containers) override { |
210 | std::vector<DirectSession*> sessions_to_reset; |
211 | { |
212 | mutex_lock l(sessions_lock_); |
213 | // We create a copy to ensure that we don't have a deadlock when |
214 | // session->Close calls the DirectSessionFactory.Deregister, which |
215 | // acquires sessions_lock_. |
216 | std::swap(sessions_to_reset, sessions_); |
217 | } |
218 | Status s; |
219 | for (auto session : sessions_to_reset) { |
220 | s.Update(session->Reset(containers)); |
221 | } |
222 | // TODO(suharshs): Change the Reset behavior of all SessionFactories so that |
223 | // it doesn't close the sessions? |
224 | for (auto session : sessions_to_reset) { |
225 | s.Update(session->Close()); |
226 | } |
227 | return s; |
228 | } |
229 | |
230 | void Deregister(const DirectSession* session) { |
231 | mutex_lock l(sessions_lock_); |
232 | sessions_.erase(std::remove(sessions_.begin(), sessions_.end(), session), |
233 | sessions_.end()); |
234 | if (session->options().config.experimental().has_session_metadata()) { |
235 | session_metadata_keys_.erase(GetMetadataKey( |
236 | session->options().config.experimental().session_metadata())); |
237 | } |
238 | } |
239 | |
240 | private: |
241 | static string GetMetadataKey(const SessionMetadata& metadata) { |
242 | return absl::StrCat(metadata.name(), "/" , metadata.version()); |
243 | } |
244 | |
245 | mutex sessions_lock_; |
246 | std::vector<DirectSession*> sessions_ TF_GUARDED_BY(sessions_lock_); |
247 | absl::flat_hash_set<string> session_metadata_keys_ |
248 | TF_GUARDED_BY(sessions_lock_); |
249 | }; |
250 | |
251 | class DirectSessionRegistrar { |
252 | public: |
253 | DirectSessionRegistrar() { |
254 | SessionFactory::Register("DIRECT_SESSION" , new DirectSessionFactory()); |
255 | } |
256 | }; |
257 | static DirectSessionRegistrar registrar; |
258 | |
259 | std::atomic_int_fast64_t DirectSession::step_id_counter_(1); |
260 | |
261 | static RunHandlerPool* GetOrCreateRunHandlerPool( |
262 | const SessionOptions& options) { |
263 | int num_inter_threads = 0; |
264 | int num_intra_threads = 0; |
265 | static const int env_num_inter_threads = NumInterOpThreadsFromEnvironment(); |
266 | static const int env_num_intra_threads = NumIntraOpThreadsFromEnvironment(); |
267 | if (env_num_inter_threads > 0) { |
268 | num_inter_threads = env_num_inter_threads; |
269 | } |
270 | if (env_num_intra_threads > 0) { |
271 | num_intra_threads = env_num_intra_threads; |
272 | } |
273 | |
274 | if (num_inter_threads == 0) { |
275 | if (options.config.session_inter_op_thread_pool_size() > 0) { |
276 | // Note due to ShouldUseRunHandler we are guaranteed that |
277 | // run_options.inter_op_thread_pool() == 0 |
278 | num_inter_threads = |
279 | options.config.session_inter_op_thread_pool(0).num_threads(); |
280 | } |
281 | if (num_inter_threads == 0) { |
282 | num_inter_threads = NumInterOpThreadsFromSessionOptions(options); |
283 | } |
284 | } |
285 | |
286 | if (num_intra_threads == 0) { |
287 | num_intra_threads = options.config.intra_op_parallelism_threads(); |
288 | if (num_intra_threads == 0) { |
289 | num_intra_threads = port::MaxParallelism(); |
290 | } |
291 | } |
292 | |
293 | static RunHandlerPool* pool = [&]() { |
294 | LOG(INFO) << "Creating run-handler pool with " |
295 | "[num_inter_threads, num_intra_threads] as [" |
296 | << num_inter_threads << "," << num_intra_threads << "]" ; |
297 | return new RunHandlerPool(num_inter_threads, num_intra_threads); |
298 | }(); |
299 | return pool; |
300 | } |
301 | |
302 | bool DirectSession::ShouldUseRunHandlerPool( |
303 | const RunOptions& run_options) const { |
304 | if (options_.config.use_per_session_threads()) return false; |
305 | if (options_.config.session_inter_op_thread_pool_size() > 0 && |
306 | run_options.inter_op_thread_pool() > 0) |
307 | return false; |
308 | // Only use RunHandlerPool when: |
309 | // a. Single global thread pool is used for inter-op parallelism. |
310 | // b. When multiple inter_op_thread_pool(s) are created, use it only while |
311 | // running sessions on the default inter_op_thread_pool=0. Typically, |
312 | // servo-team uses inter_op_thread_pool > 0 for model loading. |
313 | // TODO(crk): Revisit whether we'd want to create one (static) RunHandlerPool |
314 | // per entry in session_inter_op_thread_pool() in the future. |
315 | return true; |
316 | } |
317 | |
318 | DirectSession::DirectSession(const SessionOptions& options, |
319 | const DeviceMgr* device_mgr, |
320 | DirectSessionFactory* const factory) |
321 | : options_(options), |
322 | device_mgr_(device_mgr), |
323 | factory_(factory), |
324 | cancellation_manager_(new CancellationManager()), |
325 | operation_timeout_in_ms_(options_.config.operation_timeout_in_ms()) { |
326 | const int thread_pool_size = |
327 | options_.config.session_inter_op_thread_pool_size(); |
328 | if (thread_pool_size > 0) { |
329 | for (int i = 0; i < thread_pool_size; ++i) { |
330 | thread::ThreadPool* pool = nullptr; |
331 | bool owned = false; |
332 | init_error_.Update(NewThreadPoolFromThreadPoolOptions( |
333 | options_, options_.config.session_inter_op_thread_pool(i), i, &pool, |
334 | &owned)); |
335 | thread_pools_.emplace_back(pool, owned); |
336 | } |
337 | } else if (options_.config.use_per_session_threads()) { |
338 | thread_pools_.emplace_back(NewThreadPoolFromSessionOptions(options_), |
339 | true /* owned */); |
340 | } else { |
341 | thread_pools_.emplace_back(GlobalThreadPool(options), false /* owned */); |
342 | // Run locally if environment value of TF_NUM_INTEROP_THREADS is negative |
343 | // and config.inter_op_parallelism_threads is unspecified or negative. |
344 | static const int env_num_threads = NumInterOpThreadsFromEnvironment(); |
345 | if (options_.config.inter_op_parallelism_threads() < 0 || |
346 | (options_.config.inter_op_parallelism_threads() == 0 && |
347 | env_num_threads < 0)) { |
348 | run_in_caller_thread_ = true; |
349 | } |
350 | } |
351 | // The default value of sync_on_finish will be flipped soon and this |
352 | // environment variable will be removed as well. |
353 | const Status status = |
354 | ReadBoolFromEnvVar("TF_SYNC_ON_FINISH" , true, &sync_on_finish_); |
355 | if (!status.ok()) { |
356 | LOG(ERROR) << status.error_message(); |
357 | } |
358 | session_handle_ = |
359 | strings::StrCat("direct" , strings::FpToString(random::New64())); |
360 | int devices_added = 0; |
361 | if (options.config.log_device_placement()) { |
362 | const string mapping_str = device_mgr_->DeviceMappingString(); |
363 | string msg; |
364 | if (mapping_str.empty()) { |
365 | msg = "Device mapping: no known devices." ; |
366 | } else { |
367 | msg = strings::StrCat("Device mapping:\n" , mapping_str); |
368 | } |
369 | if (!logging::LogToListeners(msg)) { |
370 | LOG(INFO) << msg; |
371 | } |
372 | } |
373 | for (auto d : device_mgr_->ListDevices()) { |
374 | devices_.push_back(d); |
375 | device_set_.AddDevice(d); |
376 | d->op_segment()->AddHold(session_handle_); |
377 | |
378 | // The first device added is special: it is the 'client device' (a |
379 | // CPU device) from which we feed and fetch Tensors. |
380 | if (devices_added == 0) { |
381 | device_set_.set_client_device(d); |
382 | } |
383 | ++devices_added; |
384 | } |
385 | } |
386 | |
// Destroys the session. It first closes the session if it was not already
// closed, then releases per-session state in this order: partial runs,
// executors, callables, op-segment holds on every device, functions, the
// cancellation manager, owned thread pools, the execution state, and the
// function library.
// NOTE(review): the ordering looks deliberate (executors released before the
// op-segment holds and thread pools they may reference); keep it unless
// verified otherwise.
DirectSession::~DirectSession() {
  if (!closed_) Close().IgnoreError();
  for (auto& it : partial_runs_) {
    it.second.reset(nullptr);
  }
  for (auto& it : executors_) {
    it.second.reset();
  }
  callables_.clear();
  // Drop the hold that the constructor placed on each device's op segment.
  for (auto d : device_mgr_->ListDevices()) {
    d->op_segment()->RemoveHold(session_handle_);
  }
  functions_.clear();
  delete cancellation_manager_;
  // Only pools created for this session are deleted; global pools are shared.
  for (const auto& p_and_owned : thread_pools_) {
    if (p_and_owned.second) delete p_and_owned.first;
  }

  execution_state_.reset(nullptr);
  flib_def_.reset(nullptr);
}
408 | |
409 | Status DirectSession::Create(const GraphDef& graph) { |
410 | return Create(GraphDef(graph)); |
411 | } |
412 | |
413 | Status DirectSession::Create(GraphDef&& graph) { |
414 | TF_RETURN_IF_ERROR(init_error_); |
415 | if (graph.node_size() > 0) { |
416 | mutex_lock l(graph_state_lock_); |
417 | if (graph_created_) { |
418 | return errors::AlreadyExists( |
419 | "A Graph has already been created for this session." ); |
420 | } |
421 | return ExtendLocked(std::move(graph)); |
422 | } |
423 | return OkStatus(); |
424 | } |
425 | |
426 | Status DirectSession::Extend(const GraphDef& graph) { |
427 | return Extend(GraphDef(graph)); |
428 | } |
429 | |
430 | Status DirectSession::Extend(GraphDef&& graph) { |
431 | TF_RETURN_IF_ERROR(CheckNotClosed()); |
432 | mutex_lock l(graph_state_lock_); |
433 | return ExtendLocked(std::move(graph)); |
434 | } |
435 | |
436 | Status DirectSession::ExtendLocked(GraphDef&& graph) { |
437 | if (finalized_) { |
438 | return errors::FailedPrecondition("Session has been finalized." ); |
439 | } |
440 | if (!(flib_def_ && execution_state_)) { |
441 | // If this is the first call, we can initialize the execution state |
442 | // with `graph` and do not need to call `Extend()`. |
443 | GraphExecutionStateOptions options; |
444 | options.device_set = &device_set_; |
445 | options.session_options = &options_; |
446 | options.session_handle = session_handle_; |
447 | TF_RETURN_IF_ERROR(GraphExecutionState::MakeForBaseGraph( |
448 | std::move(graph), options, &execution_state_)); |
449 | // NOTE(mrry): The function library created here will be used for |
450 | // all subsequent extensions of the graph. Also, note how using the copy |
451 | // constructor of FunctionLibraryDefinition avoids duplicating the memory |
452 | // that is occupied by its shared_ptr members. |
453 | flib_def_.reset( |
454 | new FunctionLibraryDefinition(execution_state_->flib_def())); |
455 | graph_created_ = true; |
456 | } else { |
457 | std::unique_ptr<GraphExecutionState> state; |
458 | // TODO(mrry): Rewrite GraphExecutionState::Extend() to take `graph` by |
459 | // value and move `graph` in here. |
460 | TF_RETURN_IF_ERROR(execution_state_->Extend(graph, &state)); |
461 | execution_state_.swap(state); |
462 | TF_RETURN_IF_ERROR(flib_def_->AddLibrary(graph.library())); |
463 | } |
464 | return OkStatus(); |
465 | } |
466 | |
467 | Status DirectSession::Run(const NamedTensorList& inputs, |
468 | const std::vector<string>& output_names, |
469 | const std::vector<string>& target_nodes, |
470 | std::vector<Tensor>* outputs) { |
471 | RunMetadata run_metadata; |
472 | return Run(RunOptions(), inputs, output_names, target_nodes, outputs, |
473 | &run_metadata); |
474 | } |
475 | |
476 | Status DirectSession::CreateDebuggerState( |
477 | const CallableOptions& callable_options, int64_t global_step, |
478 | int64_t session_run_index, int64_t executor_step_index, |
479 | std::unique_ptr<DebuggerStateInterface>* debugger_state) { |
480 | TF_RETURN_IF_ERROR(DebuggerStateRegistry::CreateState( |
481 | callable_options.run_options().debug_options(), debugger_state)); |
482 | std::vector<string> input_names(callable_options.feed().begin(), |
483 | callable_options.feed().end()); |
484 | std::vector<string> output_names(callable_options.fetch().begin(), |
485 | callable_options.fetch().end()); |
486 | std::vector<string> target_names(callable_options.target().begin(), |
487 | callable_options.target().end()); |
488 | |
489 | TF_RETURN_IF_ERROR(debugger_state->get()->PublishDebugMetadata( |
490 | global_step, session_run_index, executor_step_index, input_names, |
491 | output_names, target_names)); |
492 | return OkStatus(); |
493 | } |
494 | |
495 | Status DirectSession::DecorateAndPublishGraphForDebug( |
496 | const DebugOptions& debug_options, Graph* graph, Device* device) { |
497 | std::unique_ptr<DebugGraphDecoratorInterface> decorator; |
498 | TF_RETURN_IF_ERROR( |
499 | DebugGraphDecoratorRegistry::CreateDecorator(debug_options, &decorator)); |
500 | |
501 | TF_RETURN_IF_ERROR(decorator->DecorateGraph(graph, device)); |
502 | TF_RETURN_IF_ERROR(decorator->PublishGraph(*graph, device->name())); |
503 | return OkStatus(); |
504 | } |
505 | |
506 | Status DirectSession::RunInternal( |
507 | int64_t step_id, const RunOptions& run_options, |
508 | CallFrameInterface* call_frame, ExecutorsAndKeys* executors_and_keys, |
509 | RunMetadata* run_metadata, |
510 | const thread::ThreadPoolOptions& threadpool_options) { |
511 | const uint64 start_time_usecs = options_.env->NowMicros(); |
512 | const int64_t executor_step_count = |
513 | executors_and_keys->step_count.fetch_add(1); |
514 | RunState run_state(step_id, &devices_); |
515 | const size_t num_executors = executors_and_keys->items.size(); |
516 | |
517 | profiler::TraceMeProducer activity( |
518 | // To TraceMeConsumers in ExecutorState::Process/Finish. |
519 | [&] { |
520 | if (options_.config.experimental().has_session_metadata()) { |
521 | const auto& model_metadata = |
522 | options_.config.experimental().session_metadata(); |
523 | string model_id = strings::StrCat(model_metadata.name(), ":" , |
524 | model_metadata.version()); |
525 | return profiler::TraceMeEncode("SessionRun" , |
526 | {{"id" , step_id}, |
527 | {"_r" , 1} /*root_event*/, |
528 | {"model_id" , model_id}}); |
529 | } else { |
530 | return profiler::TraceMeEncode( |
531 | "SessionRun" , {{"id" , step_id}, {"_r" , 1} /*root_event*/}); |
532 | } |
533 | }, |
534 | profiler::ContextType::kTfExecutor, step_id, |
535 | profiler::TraceMeLevel::kInfo); |
536 | |
537 | std::unique_ptr<DebuggerStateInterface> debugger_state; |
538 | if (!run_options.debug_options().debug_tensor_watch_opts().empty()) { |
539 | TF_RETURN_IF_ERROR( |
540 | CreateDebuggerState(executors_and_keys->callable_options, |
541 | run_options.debug_options().global_step(), step_id, |
542 | executor_step_count, &debugger_state)); |
543 | } |
544 | |
545 | if (run_metadata != nullptr && |
546 | options_.config.experimental().has_session_metadata()) { |
547 | *run_metadata->mutable_session_metadata() = |
548 | options_.config.experimental().session_metadata(); |
549 | } |
550 | |
551 | #ifndef __ANDROID__ |
552 | // Set up for collectives if ExecutorsAndKeys declares a key. |
553 | if (executors_and_keys->collective_graph_key != |
554 | BuildGraphOptions::kNoCollectiveGraphKey) { |
555 | if (run_options.experimental().collective_graph_key() != |
556 | BuildGraphOptions::kNoCollectiveGraphKey) { |
557 | // If a collective_graph_key was specified in run_options, ensure that it |
558 | // matches what came out of GraphExecutionState::BuildGraph(). |
559 | if (run_options.experimental().collective_graph_key() != |
560 | executors_and_keys->collective_graph_key) { |
561 | return errors::Internal( |
562 | "collective_graph_key in RunOptions " , |
563 | run_options.experimental().collective_graph_key(), |
564 | " should match collective_graph_key from optimized graph " , |
565 | executors_and_keys->collective_graph_key); |
566 | } |
567 | } |
568 | if (!collective_executor_mgr_) { |
569 | collective_executor_mgr_ = CreateProdLocalCollectiveExecutorMgr( |
570 | options_.config, device_mgr_.get(), |
571 | MaybeCreateNcclCommunicator(options_.config)); |
572 | } |
573 | run_state.collective_executor.reset(new CollectiveExecutor::Handle( |
574 | collective_executor_mgr_->FindOrCreate(step_id), true /*inherit_ref*/)); |
575 | } |
576 | #endif |
577 | |
578 | thread::ThreadPool* pool; |
579 | // Use std::unique_ptr to ensure garbage collection |
580 | std::unique_ptr<thread::ThreadPool> threadpool_wrapper; |
581 | |
582 | const bool inline_execution_requested = |
583 | run_in_caller_thread_ || run_options.inter_op_thread_pool() == -1; |
584 | |
585 | if (inline_execution_requested) { |
586 | // We allow using the caller thread only when having a single executor |
587 | // specified. |
588 | if (executors_and_keys->items.size() > 1) { |
589 | pool = thread_pools_[0].first; |
590 | } else { |
591 | VLOG(1) << "Executing Session::Run() synchronously!" ; |
592 | pool = nullptr; |
593 | } |
594 | } else if (threadpool_options.inter_op_threadpool != nullptr) { |
595 | threadpool_wrapper = std::make_unique<thread::ThreadPool>( |
596 | threadpool_options.inter_op_threadpool); |
597 | pool = threadpool_wrapper.get(); |
598 | } else { |
599 | if (run_options.inter_op_thread_pool() < -1 || |
600 | run_options.inter_op_thread_pool() >= |
601 | static_cast<int32>(thread_pools_.size())) { |
602 | return errors::InvalidArgument("Invalid inter_op_thread_pool: " , |
603 | run_options.inter_op_thread_pool()); |
604 | } |
605 | |
606 | pool = thread_pools_[run_options.inter_op_thread_pool()].first; |
607 | } |
608 | |
609 | const int64_t call_timeout = run_options.timeout_in_ms() > 0 |
610 | ? run_options.timeout_in_ms() |
611 | : operation_timeout_in_ms_; |
612 | absl::optional<absl::Time> deadline; |
613 | if (call_timeout > 0) { |
614 | deadline = absl::Now() + absl::Milliseconds(call_timeout); |
615 | } |
616 | |
617 | std::unique_ptr<RunHandler> handler; |
618 | if (ShouldUseRunHandlerPool(run_options) && |
619 | run_options.experimental().use_run_handler_pool()) { |
620 | VLOG(1) << "Using RunHandler to scheduler inter-op closures." ; |
621 | handler = GetOrCreateRunHandlerPool(options_)->Get( |
622 | step_id, call_timeout, |
623 | run_options.experimental().run_handler_pool_options()); |
624 | if (!handler) { |
625 | return errors::DeadlineExceeded( |
626 | "Could not obtain RunHandler for request after waiting for " , |
627 | call_timeout, "ms." ); |
628 | } |
629 | } |
630 | auto* handler_ptr = handler.get(); |
631 | |
632 | Executor::Args::Runner default_runner = nullptr; |
633 | |
634 | if (pool == nullptr) { |
635 | default_runner = [](const Executor::Args::Closure& c) { c(); }; |
636 | } else if (handler_ptr != nullptr) { |
637 | default_runner = [handler_ptr](Executor::Args::Closure c) { |
638 | handler_ptr->ScheduleInterOpClosure(std::move(c)); |
639 | }; |
640 | } else { |
641 | default_runner = [pool](Executor::Args::Closure c) { |
642 | pool->Schedule(std::move(c)); |
643 | }; |
644 | } |
645 | |
646 | // Start parallel Executors. |
647 | |
648 | // We can execute this step synchronously on the calling thread whenever |
649 | // there is a single device and the timeout mechanism is not used. |
650 | // |
651 | // When timeouts are used, we must execute the graph(s) asynchronously, in |
652 | // order to invoke the cancellation manager on the calling thread if the |
653 | // timeout expires. |
654 | const bool can_execute_synchronously = |
655 | executors_and_keys->items.size() == 1 && call_timeout == 0; |
656 | |
657 | Executor::Args args; |
658 | args.step_id = step_id; |
659 | args.call_frame = call_frame; |
660 | args.collective_executor = |
661 | (run_state.collective_executor ? run_state.collective_executor->get() |
662 | : nullptr); |
663 | args.session_state = &session_state_; |
664 | args.session_handle = session_handle_; |
665 | args.tensor_store = &run_state.tensor_store; |
666 | args.step_container = &run_state.step_container; |
667 | args.sync_on_finish = sync_on_finish_; |
668 | args.user_intra_op_threadpool = threadpool_options.intra_op_threadpool; |
669 | args.run_all_kernels_inline = pool == nullptr; |
670 | args.start_time_usecs = start_time_usecs; |
671 | args.deadline = deadline; |
672 | |
673 | const bool do_trace = (run_options.trace_level() > RunOptions::NO_TRACE); |
674 | |
675 | bool update_cost_model = false; |
676 | if (options_.config.graph_options().build_cost_model() > 0) { |
677 | const int64_t build_cost_model_every = |
678 | options_.config.graph_options().build_cost_model(); |
679 | const int64_t build_cost_model_after = |
680 | options_.config.graph_options().build_cost_model_after(); |
681 | int64_t measure_step_count = executor_step_count - build_cost_model_after; |
682 | if (measure_step_count >= 0) { |
683 | update_cost_model = |
684 | ((measure_step_count + 1) % build_cost_model_every == 0); |
685 | } |
686 | } |
687 | if (do_trace || update_cost_model || |
688 | run_options.report_tensor_allocations_upon_oom()) { |
689 | run_state.collector.reset( |
690 | new StepStatsCollector(run_metadata->mutable_step_stats())); |
691 | args.stats_collector = run_state.collector.get(); |
692 | } |
693 | |
694 | std::unique_ptr<DeviceProfilerSession> device_profiler_session; |
695 | if (run_options.trace_level() >= RunOptions::HARDWARE_TRACE) { |
696 | device_profiler_session = DeviceProfilerSession::Create(); |
697 | } |
698 | |
699 | // Register this step with session's cancellation manager, so that |
700 | // `Session::Close()` will cancel the step. |
701 | CancellationManager step_cancellation_manager(cancellation_manager_); |
702 | if (step_cancellation_manager.IsCancelled()) { |
703 | return errors::Cancelled("Run call was cancelled" ); |
704 | } |
705 | args.cancellation_manager = &step_cancellation_manager; |
706 | |
707 | Status run_status; |
708 | |
709 | auto set_threadpool_args_for_item = |
710 | [&default_runner, &handler](const PerPartitionExecutorsAndLib& item, |
711 | Executor::Args* args) { |
712 | // TODO(azaks): support partial run. |
713 | // TODO(azaks): if the device picks its own threadpool, we need to |
714 | // assign |
715 | // less threads to the main compute pool by default. |
716 | thread::ThreadPool* device_thread_pool = |
717 | item.device->tensorflow_device_thread_pool(); |
718 | // TODO(crk): Investigate usage of RunHandlerPool when using device |
719 | // specific thread pool(s). |
720 | if (!device_thread_pool) { |
721 | args->runner = default_runner; |
722 | } else { |
723 | args->runner = [device_thread_pool](Executor::Args::Closure c) { |
724 | device_thread_pool->Schedule(std::move(c)); |
725 | }; |
726 | } |
727 | if (handler != nullptr) { |
728 | args->user_intra_op_threadpool = |
729 | handler->AsIntraThreadPoolInterface(); |
730 | } |
731 | }; |
732 | |
733 | if (can_execute_synchronously) { |
734 | PrivateIntraProcessRendezvous rendezvous(device_mgr_.get()); |
735 | args.rendezvous = &rendezvous; |
736 | |
737 | const auto& item = executors_and_keys->items[0]; |
738 | set_threadpool_args_for_item(item, &args); |
739 | run_status = item.executor->Run(args); |
740 | } else { |
741 | core::RefCountPtr<RefCountedIntraProcessRendezvous> rendezvous( |
742 | new RefCountedIntraProcessRendezvous(device_mgr_.get())); |
743 | args.rendezvous = rendezvous.get(); |
744 | |
745 | // `barrier` will delete itself after the final executor finishes. |
746 | Notification executors_done; |
747 | ExecutorBarrier* barrier = |
748 | new ExecutorBarrier(num_executors, rendezvous.get(), |
749 | [&run_state, &executors_done](const Status& ret) { |
750 | { |
751 | mutex_lock l(run_state.mu); |
752 | run_state.status.Update(ret); |
753 | } |
754 | executors_done.Notify(); |
755 | }); |
756 | |
757 | for (const auto& item : executors_and_keys->items) { |
758 | set_threadpool_args_for_item(item, &args); |
759 | item.executor->RunAsync(args, barrier->Get()); |
760 | } |
761 | |
762 | WaitForNotification(&executors_done, &run_state, &step_cancellation_manager, |
763 | call_timeout); |
764 | { |
765 | tf_shared_lock l(run_state.mu); |
766 | run_status = run_state.status; |
767 | } |
768 | } |
769 | |
770 | if (step_cancellation_manager.IsCancelled()) { |
771 | run_status.Update(errors::Cancelled("Run call was cancelled" )); |
772 | } |
773 | |
774 | if (device_profiler_session) { |
775 | TF_RETURN_IF_ERROR(device_profiler_session->CollectData( |
776 | run_metadata->mutable_step_stats())); |
777 | } |
778 | |
779 | TF_RETURN_IF_ERROR(run_status); |
780 | |
781 | // Save the output tensors of this run we choose to keep. |
782 | if (!run_state.tensor_store.empty()) { |
783 | TF_RETURN_IF_ERROR(run_state.tensor_store.SaveTensors( |
784 | {executors_and_keys->callable_options.fetch().begin(), |
785 | executors_and_keys->callable_options.fetch().end()}, |
786 | &session_state_)); |
787 | } |
788 | |
789 | if (run_state.collector) { |
790 | run_state.collector->Finalize(); |
791 | } |
792 | |
793 | // Build and return the cost model as instructed. |
794 | if (update_cost_model) { |
795 | // Build the cost model |
796 | std::unordered_map<string, const Graph*> device_to_graph; |
797 | for (const PerPartitionExecutorsAndLib& partition : |
798 | executors_and_keys->items) { |
799 | const Graph* graph = partition.graph.get(); |
800 | const string& device = partition.flib->device()->name(); |
801 | device_to_graph[device] = graph; |
802 | } |
803 | |
804 | mutex_lock l(executor_lock_); |
805 | run_state.collector->BuildCostModel(&cost_model_manager_, device_to_graph); |
806 | |
807 | // annotate stats onto cost graph. |
808 | CostGraphDef* cost_graph = run_metadata->mutable_cost_graph(); |
809 | for (const auto& item : executors_and_keys->items) { |
810 | TF_RETURN_IF_ERROR( |
811 | cost_model_manager_.AddToCostGraphDef(item.graph.get(), cost_graph)); |
812 | } |
813 | } |
814 | |
815 | // If requested via RunOptions, output the partition graphs. |
816 | if (run_options.output_partition_graphs()) { |
817 | if (options_.config.experimental().disable_output_partition_graphs()) { |
818 | return errors::InvalidArgument( |
819 | "RunOptions.output_partition_graphs() is not supported when " |
820 | "disable_output_partition_graphs is true." ); |
821 | } else { |
822 | protobuf::RepeatedPtrField<GraphDef>* partition_graph_defs = |
823 | run_metadata->mutable_partition_graphs(); |
824 | for (const PerPartitionExecutorsAndLib& exec_and_lib : |
825 | executors_and_keys->items) { |
826 | GraphDef* partition_graph_def = partition_graph_defs->Add(); |
827 | exec_and_lib.graph->ToGraphDef(partition_graph_def); |
828 | } |
829 | } |
830 | } |
831 | metrics::UpdateGraphExecTime(options_.env->NowMicros() - start_time_usecs); |
832 | |
833 | return OkStatus(); |
834 | } |
835 | |
836 | Status DirectSession::Run(const RunOptions& run_options, |
837 | const NamedTensorList& inputs, |
838 | const std::vector<string>& output_names, |
839 | const std::vector<string>& target_nodes, |
840 | std::vector<Tensor>* outputs, |
841 | RunMetadata* run_metadata) { |
842 | return Run(run_options, inputs, output_names, target_nodes, outputs, |
843 | run_metadata, thread::ThreadPoolOptions()); |
844 | } |
845 | |
846 | Status DirectSession::Run(const RunOptions& run_options, |
847 | const NamedTensorList& inputs, |
848 | const std::vector<string>& output_names, |
849 | const std::vector<string>& target_nodes, |
850 | std::vector<Tensor>* outputs, |
851 | RunMetadata* run_metadata, |
852 | const thread::ThreadPoolOptions& threadpool_options) { |
853 | TF_RETURN_IF_ERROR(CheckNotClosed()); |
854 | TF_RETURN_IF_ERROR(CheckGraphCreated("Run()" )); |
855 | direct_session_runs->GetCell()->IncrementBy(1); |
856 | |
857 | // Extract the inputs names for this run of the session. |
858 | std::vector<string> input_tensor_names; |
859 | input_tensor_names.reserve(inputs.size()); |
860 | size_t input_size = 0; |
861 | for (const auto& it : inputs) { |
862 | input_tensor_names.push_back(it.first); |
863 | input_size += it.second.AllocatedBytes(); |
864 | } |
865 | metrics::RecordGraphInputTensors(input_size); |
866 | |
867 | // Check if we already have an executor for these arguments. |
868 | ExecutorsAndKeys* executors_and_keys; |
869 | RunStateArgs run_state_args(run_options.debug_options()); |
870 | run_state_args.collective_graph_key = |
871 | run_options.experimental().collective_graph_key(); |
872 | |
873 | TF_RETURN_IF_ERROR(GetOrCreateExecutors(input_tensor_names, output_names, |
874 | target_nodes, &executors_and_keys, |
875 | &run_state_args)); |
876 | { |
877 | mutex_lock l(collective_graph_key_lock_); |
878 | collective_graph_key_ = executors_and_keys->collective_graph_key; |
879 | } |
880 | |
881 | // Configure a call frame for the step, which we use to feed and |
882 | // fetch values to and from the executors. |
883 | FunctionCallFrame call_frame(executors_and_keys->input_types, |
884 | executors_and_keys->output_types); |
885 | gtl::InlinedVector<Tensor, 4> feed_args(inputs.size()); |
886 | for (const auto& it : inputs) { |
887 | if (it.second.dtype() == DT_RESOURCE) { |
888 | Tensor tensor_from_handle; |
889 | TF_RETURN_IF_ERROR( |
890 | ResourceHandleToInputTensor(it.second, &tensor_from_handle)); |
891 | feed_args[executors_and_keys->input_name_to_index[it.first]] = |
892 | tensor_from_handle; |
893 | } else { |
894 | feed_args[executors_and_keys->input_name_to_index[it.first]] = it.second; |
895 | } |
896 | } |
897 | const Status s = call_frame.SetArgs(feed_args); |
898 | if (errors::IsInternal(s)) { |
899 | return errors::InvalidArgument(s.error_message()); |
900 | } else if (!s.ok()) { |
901 | return s; |
902 | } |
903 | |
904 | const int64_t step_id = step_id_counter_.fetch_add(1); |
905 | |
906 | if (LogMemory::IsEnabled()) { |
907 | LogMemory::RecordStep(step_id, run_state_args.handle); |
908 | } |
909 | |
910 | TF_RETURN_IF_ERROR(RunInternal(step_id, run_options, &call_frame, |
911 | executors_and_keys, run_metadata, |
912 | threadpool_options)); |
913 | |
914 | // Receive outputs. |
915 | if (outputs) { |
916 | std::vector<Tensor> sorted_outputs; |
917 | const Status s = call_frame.ConsumeRetvals( |
918 | &sorted_outputs, /* allow_dead_tensors = */ false); |
919 | if (errors::IsInternal(s)) { |
920 | return errors::InvalidArgument(s.error_message()); |
921 | } else if (!s.ok()) { |
922 | return s; |
923 | } |
924 | const bool unique_outputs = |
925 | output_names.size() == executors_and_keys->output_name_to_index.size(); |
926 | // first_indices[i] = j implies that j is the smallest value for which |
927 | // output_names[i] == output_names[j]. |
928 | std::vector<int> first_indices; |
929 | if (!unique_outputs) { |
930 | first_indices.reserve(output_names.size()); |
931 | for (const auto& name : output_names) { |
932 | first_indices.push_back( |
933 | std::find(output_names.begin(), output_names.end(), name) - |
934 | output_names.begin()); |
935 | } |
936 | } |
937 | outputs->clear(); |
938 | size_t output_size = 0; |
939 | outputs->reserve(sorted_outputs.size()); |
940 | for (int i = 0; i < output_names.size(); ++i) { |
941 | const string& output_name = output_names[i]; |
942 | if (first_indices.empty() || first_indices[i] == i) { |
943 | outputs->emplace_back( |
944 | std::move(sorted_outputs[executors_and_keys |
945 | ->output_name_to_index[output_name]])); |
946 | } else { |
947 | outputs->push_back((*outputs)[first_indices[i]]); |
948 | } |
949 | output_size += outputs->back().AllocatedBytes(); |
950 | } |
951 | metrics::RecordGraphOutputTensors(output_size); |
952 | } |
953 | |
954 | return OkStatus(); |
955 | } |
956 | |
957 | Status DirectSession::PRunSetup(const std::vector<string>& input_names, |
958 | const std::vector<string>& output_names, |
959 | const std::vector<string>& target_nodes, |
960 | string* handle) { |
961 | TF_RETURN_IF_ERROR(CheckNotClosed()); |
962 | TF_RETURN_IF_ERROR(CheckGraphCreated("PRunSetup()" )); |
963 | |
964 | // RunOptions is not available in PRunSetup, so use thread pool 0. |
965 | thread::ThreadPool* pool = thread_pools_[0].first; |
966 | |
967 | // Check if we already have an executor for these arguments. |
968 | ExecutorsAndKeys* executors_and_keys; |
969 | // TODO(cais): TFDBG support for partial runs. |
970 | DebugOptions debug_options; |
971 | RunStateArgs run_state_args(debug_options); |
972 | run_state_args.is_partial_run = true; |
973 | TF_RETURN_IF_ERROR(GetOrCreateExecutors(input_names, output_names, |
974 | target_nodes, &executors_and_keys, |
975 | &run_state_args)); |
976 | |
977 | // Create the run state and save it for future PRun calls. |
978 | Executor::Args args; |
979 | args.step_id = step_id_counter_.fetch_add(1); |
980 | PartialRunState* run_state = |
981 | new PartialRunState(input_names, output_names, args.step_id, &devices_); |
982 | run_state->rendez.reset(new IntraProcessRendezvous(device_mgr_.get())); |
983 | { |
984 | mutex_lock l(executor_lock_); |
985 | if (!partial_runs_ |
986 | .emplace(run_state_args.handle, |
987 | std::unique_ptr<PartialRunState>(run_state)) |
988 | .second) { |
989 | return errors::Internal("The handle '" , run_state_args.handle, |
990 | "' created for this partial run is not unique." ); |
991 | } |
992 | } |
993 | |
994 | // Start parallel Executors. |
995 | const size_t num_executors = executors_and_keys->items.size(); |
996 | ExecutorBarrier* barrier = new ExecutorBarrier( |
997 | num_executors, run_state->rendez.get(), [run_state](const Status& ret) { |
998 | if (!ret.ok()) { |
999 | mutex_lock l(run_state->mu); |
1000 | run_state->status.Update(ret); |
1001 | } |
1002 | run_state->executors_done.Notify(); |
1003 | }); |
1004 | |
1005 | args.rendezvous = run_state->rendez.get(); |
1006 | args.cancellation_manager = cancellation_manager_; |
1007 | // Note that Collectives are not supported in partial runs |
1008 | // because RunOptions is not passed in so we can't know whether |
1009 | // their use is intended. |
1010 | args.collective_executor = nullptr; |
1011 | args.runner = [this, pool](Executor::Args::Closure c) { |
1012 | pool->Schedule(std::move(c)); |
1013 | }; |
1014 | args.session_state = &session_state_; |
1015 | args.session_handle = session_handle_; |
1016 | args.tensor_store = &run_state->tensor_store; |
1017 | args.step_container = &run_state->step_container; |
1018 | if (LogMemory::IsEnabled()) { |
1019 | LogMemory::RecordStep(args.step_id, run_state_args.handle); |
1020 | } |
1021 | args.sync_on_finish = sync_on_finish_; |
1022 | |
1023 | if (options_.config.graph_options().build_cost_model()) { |
1024 | run_state->collector.reset(new StepStatsCollector(nullptr)); |
1025 | args.stats_collector = run_state->collector.get(); |
1026 | } |
1027 | |
1028 | for (auto& item : executors_and_keys->items) { |
1029 | item.executor->RunAsync(args, barrier->Get()); |
1030 | } |
1031 | |
1032 | *handle = run_state_args.handle; |
1033 | return OkStatus(); |
1034 | } |
1035 | |
1036 | Status DirectSession::PRun(const string& handle, const NamedTensorList& inputs, |
1037 | const std::vector<string>& output_names, |
1038 | std::vector<Tensor>* outputs) { |
1039 | TF_RETURN_IF_ERROR(CheckNotClosed()); |
1040 | std::vector<string> parts = str_util::Split(handle, ';'); |
1041 | const string& key = parts[0]; |
1042 | // Get the executors for this partial run. |
1043 | ExecutorsAndKeys* executors_and_keys; |
1044 | PartialRunState* run_state; |
1045 | { |
1046 | mutex_lock l(executor_lock_); // could use reader lock |
1047 | auto exc_it = executors_.find(key); |
1048 | if (exc_it == executors_.end()) { |
1049 | return errors::InvalidArgument( |
1050 | "Must run 'setup' before performing partial runs!" ); |
1051 | } |
1052 | executors_and_keys = exc_it->second.get(); |
1053 | |
1054 | auto prun_it = partial_runs_.find(handle); |
1055 | if (prun_it == partial_runs_.end()) { |
1056 | return errors::InvalidArgument( |
1057 | "Must run 'setup' before performing partial runs!" ); |
1058 | } |
1059 | run_state = prun_it->second.get(); |
1060 | |
1061 | // Make sure that this is a new set of feeds that are still pending. |
1062 | for (const auto& input : inputs) { |
1063 | auto it = run_state->pending_inputs.find(input.first); |
1064 | if (it == run_state->pending_inputs.end()) { |
1065 | return errors::InvalidArgument( |
1066 | "The feed " , input.first, |
1067 | " was not specified in partial_run_setup." ); |
1068 | } else if (it->second) { |
1069 | return errors::InvalidArgument("The feed " , input.first, |
1070 | " has already been fed." ); |
1071 | } |
1072 | } |
1073 | // Check that this is a new set of fetches that are still pending. |
1074 | for (const auto& output : output_names) { |
1075 | auto it = run_state->pending_outputs.find(output); |
1076 | if (it == run_state->pending_outputs.end()) { |
1077 | return errors::InvalidArgument( |
1078 | "The fetch " , output, " was not specified in partial_run_setup." ); |
1079 | } else if (it->second) { |
1080 | return errors::InvalidArgument("The fetch " , output, |
1081 | " has already been fetched." ); |
1082 | } |
1083 | } |
1084 | } |
1085 | |
1086 | // Check that this new set of fetches can be computed from all the |
1087 | // feeds we have supplied. |
1088 | TF_RETURN_IF_ERROR( |
1089 | CheckFetch(inputs, output_names, executors_and_keys, run_state)); |
1090 | |
1091 | // Send inputs. |
1092 | Status s = |
1093 | SendPRunInputs(inputs, executors_and_keys, run_state->rendez.get()); |
1094 | |
1095 | // Receive outputs. |
1096 | if (s.ok()) { |
1097 | s = RecvPRunOutputs(output_names, executors_and_keys, run_state, outputs); |
1098 | } |
1099 | |
1100 | // Save the output tensors of this run we choose to keep. |
1101 | if (s.ok()) { |
1102 | s = run_state->tensor_store.SaveTensors(output_names, &session_state_); |
1103 | } |
1104 | |
1105 | { |
1106 | mutex_lock l(executor_lock_); |
1107 | // Delete the run state if there is an error or all fetches are done. |
1108 | bool done = true; |
1109 | if (s.ok()) { |
1110 | { |
1111 | mutex_lock l(run_state->mu); |
1112 | if (!run_state->status.ok()) { |
1113 | LOG(WARNING) << "An error unrelated to this prun has been detected. " |
1114 | << run_state->status; |
1115 | } |
1116 | } |
1117 | for (const auto& input : inputs) { |
1118 | auto it = run_state->pending_inputs.find(input.first); |
1119 | it->second = true; |
1120 | } |
1121 | for (const auto& name : output_names) { |
1122 | auto it = run_state->pending_outputs.find(name); |
1123 | it->second = true; |
1124 | } |
1125 | done = run_state->PendingDone(); |
1126 | } |
1127 | if (done) { |
1128 | WaitForNotification(&run_state->executors_done, run_state, |
1129 | cancellation_manager_, operation_timeout_in_ms_); |
1130 | partial_runs_.erase(handle); |
1131 | } |
1132 | } |
1133 | |
1134 | return s; |
1135 | } |
1136 | |
1137 | Status DirectSession::ResourceHandleToInputTensor(const Tensor& resource_tensor, |
1138 | Tensor* retrieved_tensor) { |
1139 | if (resource_tensor.dtype() != DT_RESOURCE) { |
1140 | return errors::InvalidArgument(strings::StrCat( |
1141 | "ResourceHandleToInputTensor() received non-DT_RESOURCE Tensor: " , |
1142 | resource_tensor.dtype())); |
1143 | } |
1144 | |
1145 | const ResourceHandle& resource_handle = |
1146 | resource_tensor.scalar<ResourceHandle>()(); |
1147 | |
1148 | if (resource_handle.container() == |
1149 | SessionState::kTensorHandleResourceTypeName) { |
1150 | return session_state_.GetTensor(resource_handle.name(), retrieved_tensor); |
1151 | } else { |
1152 | return errors::InvalidArgument(strings::StrCat( |
1153 | "Invalid resource type hash code: " , resource_handle.hash_code(), |
1154 | "(name: " , resource_handle.name(), |
1155 | " type: " , resource_handle.maybe_type_name(), |
1156 | "). Perhaps a resource tensor was being provided as a feed? That is " |
1157 | "not currently allowed. Please file an issue at " |
1158 | "https://github.com/tensorflow/tensorflow/issues/new, ideally with a " |
1159 | "short code snippet that leads to this error message." )); |
1160 | } |
1161 | } |
1162 | |
1163 | Status DirectSession::SendPRunInputs(const NamedTensorList& inputs, |
1164 | const ExecutorsAndKeys* executors_and_keys, |
1165 | IntraProcessRendezvous* rendez) { |
1166 | Status s; |
1167 | Rendezvous::ParsedKey parsed; |
1168 | // Insert the input tensors into the local rendezvous by their |
1169 | // rendezvous key. |
1170 | for (const auto& input : inputs) { |
1171 | auto it = |
1172 | executors_and_keys->input_name_to_rendezvous_key.find(input.first); |
1173 | if (it == executors_and_keys->input_name_to_rendezvous_key.end()) { |
1174 | return errors::Internal("'" , input.first, "' is not a pre-defined feed." ); |
1175 | } |
1176 | const string& input_key = it->second; |
1177 | |
1178 | s = Rendezvous::ParseKey(input_key, &parsed); |
1179 | if (!s.ok()) { |
1180 | rendez->StartAbort(s); |
1181 | return s; |
1182 | } |
1183 | |
1184 | if (input.second.dtype() == DT_RESOURCE) { |
1185 | Tensor tensor_from_handle; |
1186 | s = ResourceHandleToInputTensor(input.second, &tensor_from_handle); |
1187 | if (s.ok()) { |
1188 | s = rendez->Send(parsed, Rendezvous::Args(), tensor_from_handle, false); |
1189 | } |
1190 | } else { |
1191 | s = rendez->Send(parsed, Rendezvous::Args(), input.second, false); |
1192 | } |
1193 | |
1194 | if (!s.ok()) { |
1195 | rendez->StartAbort(s); |
1196 | return s; |
1197 | } |
1198 | } |
1199 | return OkStatus(); |
1200 | } |
1201 | |
1202 | Status DirectSession::RecvPRunOutputs( |
1203 | const std::vector<string>& output_names, |
1204 | const ExecutorsAndKeys* executors_and_keys, PartialRunState* run_state, |
1205 | std::vector<Tensor>* outputs) { |
1206 | Status s; |
1207 | if (!output_names.empty()) { |
1208 | outputs->resize(output_names.size()); |
1209 | } |
1210 | |
1211 | Rendezvous::ParsedKey parsed; |
1212 | // Get the outputs from the rendezvous |
1213 | for (size_t output_offset = 0; output_offset < output_names.size(); |
1214 | ++output_offset) { |
1215 | const string& output_name = output_names[output_offset]; |
1216 | auto it = |
1217 | executors_and_keys->output_name_to_rendezvous_key.find(output_name); |
1218 | if (it == executors_and_keys->output_name_to_rendezvous_key.end()) { |
1219 | return errors::Internal("'" , output_name, |
1220 | "' is not a pre-defined fetch." ); |
1221 | } |
1222 | const string& output_key = it->second; |
1223 | Tensor output_tensor; |
1224 | bool is_dead; |
1225 | |
1226 | s = Rendezvous::ParseKey(output_key, &parsed); |
1227 | if (s.ok()) { |
1228 | // Fetch data from the Rendezvous. |
1229 | s = run_state->rendez->Recv(parsed, Rendezvous::Args(), &output_tensor, |
1230 | &is_dead, operation_timeout_in_ms_); |
1231 | if (is_dead && s.ok()) { |
1232 | s = errors::InvalidArgument("The tensor returned for " , output_name, |
1233 | " was not valid." ); |
1234 | } |
1235 | } |
1236 | if (!s.ok()) { |
1237 | run_state->rendez->StartAbort(s); |
1238 | outputs->clear(); |
1239 | return s; |
1240 | } |
1241 | |
1242 | (*outputs)[output_offset] = output_tensor; |
1243 | } |
1244 | return OkStatus(); |
1245 | } |
1246 | |
1247 | Status DirectSession::CheckFetch(const NamedTensorList& feeds, |
1248 | const std::vector<string>& fetches, |
1249 | const ExecutorsAndKeys* executors_and_keys, |
1250 | const PartialRunState* run_state) { |
1251 | const Graph* graph = executors_and_keys->graph.get(); |
1252 | const NameNodeMap* name_to_node = &executors_and_keys->name_to_node; |
1253 | |
1254 | // Build the set of pending feeds that we haven't seen. |
1255 | std::unordered_set<TensorId, TensorId::Hasher> pending_feeds; |
1256 | { |
1257 | mutex_lock l(executor_lock_); |
1258 | for (const auto& input : run_state->pending_inputs) { |
1259 | // Skip if the feed has already been fed. |
1260 | if (input.second) continue; |
1261 | TensorId id(ParseTensorName(input.first)); |
1262 | auto it = name_to_node->find(id.first); |
1263 | if (it == name_to_node->end()) { |
1264 | return errors::NotFound("Feed " , input.first, ": not found" ); |
1265 | } |
1266 | pending_feeds.insert(id); |
1267 | } |
1268 | } |
1269 | for (const auto& it : feeds) { |
1270 | TensorId id(ParseTensorName(it.first)); |
1271 | pending_feeds.erase(id); |
1272 | } |
1273 | |
1274 | // Initialize the stack with the fetch nodes. |
1275 | std::vector<const Node*> stack; |
1276 | for (const string& fetch : fetches) { |
1277 | TensorId id(ParseTensorName(fetch)); |
1278 | auto it = name_to_node->find(id.first); |
1279 | if (it == name_to_node->end()) { |
1280 | return errors::NotFound("Fetch " , fetch, ": not found" ); |
1281 | } |
1282 | stack.push_back(it->second); |
1283 | } |
1284 | |
1285 | // Any tensor needed for fetches can't be in pending_feeds. |
1286 | std::vector<bool> visited(graph->num_node_ids(), false); |
1287 | while (!stack.empty()) { |
1288 | const Node* n = stack.back(); |
1289 | stack.pop_back(); |
1290 | |
1291 | for (const Edge* in_edge : n->in_edges()) { |
1292 | const Node* in_node = in_edge->src(); |
1293 | if (pending_feeds.count({in_node->name(), in_edge->src_output()}) > 0) { |
1294 | return errors::InvalidArgument("Fetch " , in_node->name(), ":" , |
1295 | in_edge->src_output(), |
1296 | " can't be computed from the feeds" |
1297 | " that have been fed so far." ); |
1298 | } |
1299 | if (!visited[in_node->id()]) { |
1300 | visited[in_node->id()] = true; |
1301 | stack.push_back(in_node); |
1302 | } |
1303 | } |
1304 | } |
1305 | return OkStatus(); |
1306 | } |
1307 | |
// Builds the per-device executors for one feed/fetch/target signature.
//
// `callable_options` lists the feeds, fetches and targets, and carries the
// RunOptions whose debug options and collective graph key are used here.
// CreateGraphs() produces one graph per partition, keyed by device name. For
// each partition this function:
//   1. configures a LocalExecutorParams for the device,
//   2. optimizes the partition graph,
//   3. decorates it for tfdbg when debug tensor watches are requested,
//   4. checks memory types, and
//   5. creates an executor.
// It then caches the mapping from feed/fetch names to call-frame indices
// (regular Run()) or rendezvous keys (partial runs). The out-params are
// written only at the very end, so nothing is written to them on error.
Status DirectSession::CreateExecutors(
    const CallableOptions& callable_options,
    std::unique_ptr<ExecutorsAndKeys>* out_executors_and_keys,
    std::unique_ptr<FunctionInfo>* out_func_info,
    RunStateArgs* run_state_args) {
  BuildGraphOptions options;
  options.callable_options = callable_options;
  // Partial runs exchange feeds/fetches through rendezvous keys instead of
  // the function (call-frame) calling convention.
  options.use_function_convention = !run_state_args->is_partial_run;
  options.collective_graph_key =
      callable_options.run_options().experimental().collective_graph_key();
  if (options_.config.experimental()
          .collective_deterministic_sequential_execution()) {
    options.collective_order = GraphCollectiveOrder::kEdges;
  } else if (options_.config.experimental().collective_nccl()) {
    options.collective_order = GraphCollectiveOrder::kAttrs;
  }

  std::unique_ptr<FunctionInfo> func_info(new FunctionInfo);
  std::unique_ptr<ExecutorsAndKeys> ek(new ExecutorsAndKeys);

  ek->callable_options = callable_options;

  std::unordered_map<string, std::unique_ptr<Graph>> graphs;
  TF_RETURN_IF_ERROR(CreateGraphs(
      options, &graphs, &func_info->flib_def, run_state_args, &ek->input_types,
      &ek->output_types, &ek->collective_graph_key));

  if (run_state_args->is_partial_run) {
    // Keep the full client graph and index only the feed and fetch nodes by
    // name; partial runs look these nodes up when validating fetches.
    ek->graph = std::move(run_state_args->graph);
    std::unordered_set<StringPiece, StringPieceHasher> names;
    for (const string& input : callable_options.feed()) {
      TensorId id(ParseTensorName(input));
      names.emplace(id.first);
    }
    for (const string& output : callable_options.fetch()) {
      TensorId id(ParseTensorName(output));
      names.emplace(id.first);
    }
    for (Node* n : ek->graph->nodes()) {
      if (names.count(n->name()) > 0) {
        ek->name_to_node.insert({n->name(), n});
      }
    }
  }
  ek->items.reserve(graphs.size());
  const auto& optimizer_opts =
      options_.config.graph_options().optimizer_options();

  // NOTE(review): dereferences graphs.begin() unconditionally, so this assumes
  // CreateGraphs() produced at least one partition — confirm.
  int graph_def_version = graphs.begin()->second->versions().producer();

  const auto* session_metadata =
      options_.config.experimental().has_session_metadata()
          ? &options_.config.experimental().session_metadata()
          : nullptr;
  // Function library runtime shared by all partitions; function calls it
  // issues get a fresh IntraProcessRendezvous from the factory below.
  func_info->proc_flr.reset(new ProcessFunctionLibraryRuntime(
      device_mgr_.get(), options_.env, &options_.config, graph_def_version,
      func_info->flib_def.get(), optimizer_opts, thread_pools_[0].first,
      /*parent=*/nullptr, session_metadata,
      Rendezvous::Factory{
          [](const int64_t, const DeviceMgr* device_mgr, Rendezvous** r) {
            *r = new IntraProcessRendezvous(device_mgr);
            return OkStatus();
          }}));

  GraphOptimizer optimizer(optimizer_opts);
  for (auto iter = graphs.begin(); iter != graphs.end(); ++iter) {
    const string& partition_name = iter->first;
    std::unique_ptr<Graph>& partition_graph = iter->second;

    Device* device;
    TF_RETURN_IF_ERROR(device_mgr_->LookupDevice(partition_name, &device));

    // Append a new item and fill it in place; `items` was reserved above.
    ek->items.resize(ek->items.size() + 1);
    auto* item = &(ek->items.back());
    auto lib = func_info->proc_flr->GetFLR(partition_name);
    if (lib == nullptr) {
      return errors::Internal("Could not find device: " , partition_name);
    }
    item->flib = lib;

    LocalExecutorParams params;
    params.device = device;
    params.session_metadata = session_metadata;
    params.function_library = lib;
    auto opseg = device->op_segment();
    params.create_kernel =
        [this, lib, opseg](const std::shared_ptr<const NodeProperties>& props,
                           OpKernel** kernel) {
          // NOTE(mrry): We must not share function kernels (implemented
          // using `CallOp`) between subgraphs, because `CallOp::handle_`
          // is tied to a particular subgraph. Even if the function itself
          // is stateful, the `CallOp` that invokes it is not.
          if (!OpSegment::ShouldOwnKernel(lib, props->node_def.op())) {
            return lib->CreateKernel(props, kernel);
          }
          auto create_fn = [lib, &props](OpKernel** kernel) {
            return lib->CreateKernel(props, kernel);
          };
          // Kernels created for subgraph nodes need to be cached.  On
          // cache miss, create_fn() is invoked to create a kernel based
          // on the function library here + global op registry.
          return opseg->FindOrCreate(session_handle_, props->node_def.name(),
                                     kernel, create_fn);
        };
    // Kernels cached in the OpSegment are owned by it; only delete kernels
    // that were created directly (the non-cached branch above).
    params.delete_kernel = [lib](OpKernel* kernel) {
      if (kernel && !OpSegment::ShouldOwnKernel(lib, kernel->type_string()))
        delete kernel;
    };

    optimizer.Optimize(lib, options_.env, device, &partition_graph,
                       GraphOptimizer::Options());

    // TensorFlow Debugger (tfdbg) inserts debug nodes in the graph.
    const DebugOptions& debug_options =
        options.callable_options.run_options().debug_options();
    if (!debug_options.debug_tensor_watch_opts().empty()) {
      TF_RETURN_IF_ERROR(DecorateAndPublishGraphForDebug(
          debug_options, partition_graph.get(), params.device));
    }

    TF_RETURN_IF_ERROR(EnsureMemoryTypes(DeviceType(device->device_type()),
                                         device->name(),
                                         partition_graph.get()));

    item->executor = nullptr;
    item->device = device;
    auto executor_type = options_.config.experimental().executor_type();
    TF_RETURN_IF_ERROR(
        NewExecutor(executor_type, params, *partition_graph, &item->executor));
    // Retain the partition graph only if it may be needed after execution:
    // to emit partition graphs in RunMetadata or to build the cost model.
    // This must come after NewExecutor(), which reads *partition_graph.
    if (!options_.config.experimental().disable_output_partition_graphs() ||
        options_.config.graph_options().build_cost_model() > 0) {
      item->graph = std::move(partition_graph);
    }
  }

  // Cache the mapping from input/output names to graph elements to
  // avoid recomputing it every time.
  if (!run_state_args->is_partial_run) {
    // For regular `Run()`, we use the function calling convention, and so
    // maintain a mapping from input/output names to
    // argument/return-value ordinal index.
    for (int i = 0; i < callable_options.feed().size(); ++i) {
      const string& input = callable_options.feed(i);
      ek->input_name_to_index[input] = i;
    }
    for (int i = 0; i < callable_options.fetch().size(); ++i) {
      const string& output = callable_options.fetch(i);
      ek->output_name_to_index[output] = i;
    }
  } else {
    // For `PRun()`, we use the rendezvous calling convention, and so
    // maintain a mapping from input/output names to rendezvous keys.
    //
    // We always use the first device as the device name portion of the
    // key, even if we're feeding another graph.
    for (int i = 0; i < callable_options.feed().size(); ++i) {
      const string& input = callable_options.feed(i);
      ek->input_name_to_rendezvous_key[input] = GetRendezvousKey(
          input, device_set_.client_device()->attributes(), FrameAndIter(0, 0));
    }
    for (int i = 0; i < callable_options.fetch().size(); ++i) {
      const string& output = callable_options.fetch(i);
      ek->output_name_to_rendezvous_key[output] =
          GetRendezvousKey(output, device_set_.client_device()->attributes(),
                           FrameAndIter(0, 0));
    }
  }

  *out_executors_and_keys = std::move(ek);
  *out_func_info = std::move(func_info);
  return OkStatus();
}
1480 | |
// Looks up the cached executors for the given feeds (`inputs`), fetches
// (`outputs`) and `target_nodes`, creating and caching them on a miss. On
// success `*executors_and_keys` points at the cached entry (owned by
// `executors_`).
//
// The lookup is two-level. It first tries a key built from the arguments in
// the order the caller supplied them (no sorting). It then tries a key built
// from sorted copies. On a full miss, executors are built from the sorted
// lists and cached under both keys, so the caller's ordering hits the fast
// path next time.
//
// `executor_lock_` is NOT held while CreateExecutors() runs. Two threads can
// therefore both build executors for the same key. The first insertion wins
// and the other thread returns the already-cached entry.
//
// When memory logging is enabled or this is a partial run,
// `run_state_args->handle` is set to "<key>;<counter>". It uses the sorted key
// if the unsorted lookup missed.
Status DirectSession::GetOrCreateExecutors(
    gtl::ArraySlice<string> inputs, gtl::ArraySlice<string> outputs,
    gtl::ArraySlice<string> target_nodes, ExecutorsAndKeys** executors_and_keys,
    RunStateArgs* run_state_args) {
  // A unique counter value is only drawn when a handle will be built below.
  int64_t handle_name_counter_value = -1;
  if (LogMemory::IsEnabled() || run_state_args->is_partial_run) {
    handle_name_counter_value = handle_name_counter_.fetch_add(1);
  }

  // The debug tensor watches are folded into the cache key, so runs that
  // watch different tensors do not share executors.
  string debug_tensor_watches_summary;
  if (!run_state_args->debug_options.debug_tensor_watch_opts().empty()) {
    debug_tensor_watches_summary = SummarizeDebugTensorWatches(
        run_state_args->debug_options.debug_tensor_watch_opts());
  }

  // Fast lookup path, no sorting.
  const string key = strings::StrCat(
      absl::StrJoin(inputs, ","), "->", absl::StrJoin(outputs, ","), "/",
      absl::StrJoin(target_nodes, ","), "/", run_state_args->is_partial_run,
      "/", debug_tensor_watches_summary);
  // Set the handle, if it's needed to log memory or for partial run.
  if (handle_name_counter_value >= 0) {
    run_state_args->handle =
        strings::StrCat(key, ";", handle_name_counter_value);
  }

  // See if we already have the executors for this run.
  {
    mutex_lock l(executor_lock_);  // could use reader lock
    auto it = executors_.find(key);
    if (it != executors_.end()) {
      *executors_and_keys = it->second.get();
      return OkStatus();
    }
  }

  // Slow lookup path, the unsorted key missed the cache.
  // Sort the inputs and outputs, and look up with the sorted key in case an
  // earlier call used a different order of inputs and outputs.
  //
  // We could consider some other signature instead of sorting that
  // preserves the same property to avoid the sort in the future.
  std::vector<string> inputs_sorted(inputs.begin(), inputs.end());
  std::sort(inputs_sorted.begin(), inputs_sorted.end());
  std::vector<string> outputs_sorted(outputs.begin(), outputs.end());
  std::sort(outputs_sorted.begin(), outputs_sorted.end());
  std::vector<string> tn_sorted(target_nodes.begin(), target_nodes.end());
  std::sort(tn_sorted.begin(), tn_sorted.end());

  const string sorted_key = strings::StrCat(
      absl::StrJoin(inputs_sorted, ","), "->",
      absl::StrJoin(outputs_sorted, ","), "/", absl::StrJoin(tn_sorted, ","),
      "/", run_state_args->is_partial_run, "/", debug_tensor_watches_summary);
  // Set the handle, if it's needed to log memory or for partial run.
  // This replaces the handle derived from the unsorted key above.
  if (handle_name_counter_value >= 0) {
    run_state_args->handle =
        strings::StrCat(sorted_key, ";", handle_name_counter_value);
  }

  // See if we already have the executors for this run.
  {
    mutex_lock l(executor_lock_);
    auto it = executors_.find(sorted_key);
    if (it != executors_.end()) {
      *executors_and_keys = it->second.get();
      return OkStatus();
    }
  }

  // Nothing found, so create the executors and store in the cache.
  // The executor_lock_ is intentionally released while executors are
  // being created.
  CallableOptions callable_options;
  callable_options.mutable_feed()->Reserve(inputs_sorted.size());
  for (const string& input : inputs_sorted) {
    callable_options.add_feed(input);
  }
  callable_options.mutable_fetch()->Reserve(outputs_sorted.size());
  for (const string& output : outputs_sorted) {
    callable_options.add_fetch(output);
  }
  callable_options.mutable_target()->Reserve(tn_sorted.size());
  for (const string& target : tn_sorted) {
    callable_options.add_target(target);
  }
  *callable_options.mutable_run_options()->mutable_debug_options() =
      run_state_args->debug_options;
  callable_options.mutable_run_options()
      ->mutable_experimental()
      ->set_collective_graph_key(run_state_args->collective_graph_key);
  std::unique_ptr<ExecutorsAndKeys> ek;
  std::unique_ptr<FunctionInfo> func_info;
  // On failure nothing has been inserted into the cache.
  TF_RETURN_IF_ERROR(
      CreateExecutors(callable_options, &ek, &func_info, run_state_args));

  // Reacquire the lock, try to insert into the map.
  mutex_lock l(executor_lock_);

  // Another thread may have created the entry before us, in which case we will
  // reuse the already created one. In that case our `ek` and `func_info` are
  // simply dropped; only the winning entry's FunctionInfo is kept in
  // `functions_`.
  auto insert_result = executors_.emplace(
      sorted_key, std::shared_ptr<ExecutorsAndKeys>(std::move(ek)));
  if (insert_result.second) {
    functions_.push_back(std::move(func_info));
  }

  // Insert the value under the original key, so the fast path lookup will work
  // if the user uses the same order of inputs, outputs, and targets again.
  executors_.emplace(key, insert_result.first->second);
  *executors_and_keys = insert_result.first->second.get();

  return OkStatus();
}
1594 | |
// Builds the client graph described by `subgraph_options` and partitions it
// by assigned device. Each partition is converted into a Graph and stored in
// `*outputs`, keyed by partition (device) name.
//
// Outputs:
//   outputs              - one Graph per partition, after the
//                          POST_PARTITIONING optimization passes and each
//                          device's MaybeRewriteGraph().
//   flib_def             - the client graph's function library.
//   input_types          - feed types of the pruned client graph.
//   output_types         - fetch types of the pruned client graph.
//   collective_graph_key - the client graph's collective_graph_key.
// For a partial run, a copy of the execution state's full graph is also kept
// in `run_state_args->graph`.
//
// Error returns:
//   FailedPrecondition - the session has been finalized.
//   Internal           - pruning yields a different number of feeds or
//                        fetches than requested, or a stateful node's
//                        placement conflicts with a previously recorded one.
//   InvalidArgument    - a partition names a device that is not in
//                        `devices_`.
// Errors from the helpers called here are propagated. If a per-device rewrite
// fails, `flib_def`, `input_types` and `output_types` are still filled in
// before the error is returned.
//
// `graph_state_lock_` is held for the whole call.
Status DirectSession::CreateGraphs(
    const BuildGraphOptions& subgraph_options,
    std::unordered_map<string, std::unique_ptr<Graph>>* outputs,
    std::unique_ptr<FunctionLibraryDefinition>* flib_def,
    RunStateArgs* run_state_args, DataTypeVector* input_types,
    DataTypeVector* output_types, int64_t* collective_graph_key) {
  mutex_lock l(graph_state_lock_);
  if (finalized_) {
    return errors::FailedPrecondition("Session has been finalized.");
  }

  std::unique_ptr<ClientGraph> client_graph;

  std::unique_ptr<GraphExecutionState> temp_exec_state_holder;
  GraphExecutionState* execution_state = nullptr;
  if (options_.config.graph_options().place_pruned_graph()) {
    // Because we are placing pruned graphs, we need to create a
    // new GraphExecutionState for every new unseen graph,
    // and then place it.
    GraphExecutionStateOptions prune_options;
    prune_options.device_set = &device_set_;
    prune_options.session_options = &options_;
    prune_options.stateful_placements = stateful_placements_;
    prune_options.session_handle = session_handle_;
    TF_RETURN_IF_ERROR(GraphExecutionState::MakeForPrunedGraph(
        *execution_state_, prune_options, subgraph_options,
        &temp_exec_state_holder, &client_graph));
    execution_state = temp_exec_state_holder.get();
  } else {
    execution_state = execution_state_.get();
    TF_RETURN_IF_ERROR(
        execution_state->BuildGraph(subgraph_options, &client_graph));
  }
  *collective_graph_key = client_graph->collective_graph_key;

  // Pruning must keep every requested feed and fetch endpoint.
  if (subgraph_options.callable_options.feed_size() !=
      client_graph->feed_types.size()) {
    return errors::Internal(
        "Graph pruning failed: requested number of feed endpoints = ",
        subgraph_options.callable_options.feed_size(),
        " versus number of pruned feed endpoints = ",
        client_graph->feed_types.size());
  }
  if (subgraph_options.callable_options.fetch_size() !=
      client_graph->fetch_types.size()) {
    return errors::Internal(
        "Graph pruning failed: requested number of fetch endpoints = ",
        subgraph_options.callable_options.fetch_size(),
        " versus number of pruned fetch endpoints = ",
        client_graph->fetch_types.size());
  }

  auto current_stateful_placements = execution_state->GetStatefulPlacements();
  // Update our current state based on the execution_state's
  // placements. If there are any mismatches for a node,
  // we should fail, as this should never happen.
  for (const auto& placement_pair : current_stateful_placements) {
    const string& node_name = placement_pair.first;
    const string& placement = placement_pair.second;
    auto iter = stateful_placements_.find(node_name);
    if (iter == stateful_placements_.end()) {
      stateful_placements_.insert(std::make_pair(node_name, placement));
    } else if (iter->second != placement) {
      return errors::Internal(
          "Stateful placement mismatch. "
          "Current assignment of ",
          node_name, " to ", iter->second, " does not match ", placement);
    }
  }

  // NOTE(review): this replaces the map that the loop above just merged into.
  // Any entry that was only in the previous `stateful_placements_` (and not
  // reported by `execution_state`) is dropped. Confirm this is intended.
  stateful_placements_ = execution_state->GetStatefulPlacements();

  // Remember the graph in run state if this is a partial run.
  if (run_state_args->is_partial_run) {
    run_state_args->graph.reset(new Graph(flib_def_.get()));
    CopyGraph(*execution_state->full_graph(), run_state_args->graph.get());
  }

  // Partition the graph across devices.
  PartitionOptions popts;
  popts.node_to_loc = [](const Node* node) {
    return node->assigned_device_name();
  };
  popts.new_name = [this](const string& prefix) {
    return strings::StrCat(prefix, "/_", edge_name_counter_.fetch_add(1));
  };
  popts.get_incarnation = [](const string& name) {
    // The direct session does not have changing incarnation numbers.
    // Just return '1'.
    return 1;
  };
  popts.flib_def = flib_def->get();
  popts.control_flow_added = false;

  std::unordered_map<string, GraphDef> partitions;
  TF_RETURN_IF_ERROR(Partition(popts, &client_graph->graph, &partitions));

  std::vector<string> device_names;
  for (auto device : devices_) {
    // Extract the LocalName from the device.
    device_names.push_back(DeviceNameUtils::LocalName(device->name()));
  }

  // Check for valid partitions: every partition must map (by local name) to
  // one of this session's devices.
  for (const auto& partition : partitions) {
    const string local_partition_name =
        DeviceNameUtils::LocalName(partition.first);
    if (std::count(device_names.begin(), device_names.end(),
                   local_partition_name) == 0) {
      return errors::InvalidArgument(
          "Creating a partition for ", local_partition_name,
          " which doesn't exist in the list of available devices. Available "
          "devices: ",
          absl::StrJoin(device_names, ","));
    }
  }

  // Convert each partition's GraphDef into a Graph. The GraphDef is moved
  // from.
  for (auto& partition : partitions) {
    std::unique_ptr<Graph> device_graph(
        new Graph(client_graph->flib_def.get()));
    device_graph->SetConstructionContext(ConstructionContext::kDirectSession);
    GraphConstructorOptions device_opts;
    // There are internal operations (e.g., send/recv) that we now allow.
    device_opts.allow_internal_ops = true;
    device_opts.expect_device_spec = true;
    TF_RETURN_IF_ERROR(ConvertGraphDefToGraph(
        device_opts, std::move(partition.second), device_graph.get()));
    outputs->emplace(partition.first, std::move(device_graph));
  }

  GraphOptimizationPassOptions optimization_options;
  optimization_options.session_options = &options_;
  optimization_options.flib_def = client_graph->flib_def.get();
  optimization_options.partition_graphs = outputs;
  TF_RETURN_IF_ERROR(OptimizationPassRegistry::Global()->RunGrouping(
      OptimizationPassRegistry::POST_PARTITIONING, optimization_options));

  // Stop at the first device lookup or rewrite failure. The outputs below are
  // still populated before that status is returned.
  Status s;
  for (auto& partition : *outputs) {
    const string& partition_name = partition.first;
    std::unique_ptr<Graph>* graph = &partition.second;

    VLOG(2) << "Created " << DebugString(graph->get()) << " for "
            << partition_name;

    // Give the device an opportunity to rewrite its subgraph.
    Device* d;
    s = device_mgr_->LookupDevice(partition_name, &d);
    if (!s.ok()) break;
    s = d->MaybeRewriteGraph(graph);
    if (!s.ok()) {
      break;
    }
  }
  *flib_def = std::move(client_graph->flib_def);
  std::swap(*input_types, client_graph->feed_types);
  std::swap(*output_types, client_graph->fetch_types);
  return s;
}
1754 | |
1755 | ::tensorflow::Status DirectSession::ListDevices( |
1756 | std::vector<DeviceAttributes>* response) { |
1757 | response->clear(); |
1758 | response->reserve(devices_.size()); |
1759 | for (Device* d : devices_) { |
1760 | const DeviceAttributes& attrs = d->attributes(); |
1761 | response->emplace_back(attrs); |
1762 | } |
1763 | return OkStatus(); |
1764 | } |
1765 | |
1766 | ::tensorflow::Status DirectSession::Reset( |
1767 | const std::vector<string>& containers) { |
1768 | device_mgr_->ClearContainers(containers); |
1769 | return OkStatus(); |
1770 | } |
1771 | |
1772 | ::tensorflow::Status DirectSession::Close() { |
1773 | cancellation_manager_->StartCancel(); |
1774 | { |
1775 | mutex_lock l(closed_lock_); |
1776 | if (closed_) return OkStatus(); |
1777 | closed_ = true; |
1778 | } |
1779 | if (factory_ != nullptr) factory_->Deregister(this); |
1780 | return OkStatus(); |
1781 | } |
1782 | |
1783 | DirectSession::RunState::RunState(int64_t step_id, |
1784 | const std::vector<Device*>* devices) |
1785 | : step_container(step_id, [devices, step_id](const string& name) { |
1786 | for (auto d : *devices) { |
1787 | if (!d->resource_manager()->Cleanup(name).ok()) { |
1788 | // Do nothing... |
1789 | } |
1790 | ScopedAllocatorMgr* sam = d->GetScopedAllocatorMgr(); |
1791 | if (sam) sam->Cleanup(step_id); |
1792 | } |
1793 | }) {} |
1794 | |
1795 | DirectSession::PartialRunState::PartialRunState( |
1796 | const std::vector<string>& pending_input_names, |
1797 | const std::vector<string>& pending_output_names, int64_t step_id, |
1798 | const std::vector<Device*>* devices) |
1799 | : RunState(step_id, devices) { |
1800 | // Initially all the feeds and fetches are pending. |
1801 | for (auto& name : pending_input_names) { |
1802 | pending_inputs[name] = false; |
1803 | } |
1804 | for (auto& name : pending_output_names) { |
1805 | pending_outputs[name] = false; |
1806 | } |
1807 | } |
1808 | |
1809 | DirectSession::PartialRunState::~PartialRunState() { |
1810 | if (rendez != nullptr) { |
1811 | rendez->StartAbort(errors::Cancelled("PRun cancellation" )); |
1812 | executors_done.WaitForNotification(); |
1813 | } |
1814 | } |
1815 | |
1816 | bool DirectSession::PartialRunState::PendingDone() const { |
1817 | for (const auto& it : pending_inputs) { |
1818 | if (!it.second) return false; |
1819 | } |
1820 | for (const auto& it : pending_outputs) { |
1821 | if (!it.second) return false; |
1822 | } |
1823 | return true; |
1824 | } |
1825 | |
1826 | void DirectSession::WaitForNotification(Notification* n, RunState* run_state, |
1827 | CancellationManager* cm, |
1828 | int64_t timeout_in_ms) { |
1829 | const Status status = WaitForNotification(n, timeout_in_ms); |
1830 | if (!status.ok()) { |
1831 | { |
1832 | mutex_lock l(run_state->mu); |
1833 | run_state->status.Update(status); |
1834 | } |
1835 | cm->StartCancel(); |
1836 | // We must wait for the executors to complete, because they have borrowed |
1837 | // references to `cm` and other per-step state. After this notification, it |
1838 | // is safe to clean up the step. |
1839 | n->WaitForNotification(); |
1840 | } |
1841 | } |
1842 | |
1843 | ::tensorflow::Status DirectSession::WaitForNotification( |
1844 | Notification* notification, int64_t timeout_in_ms) { |
1845 | if (timeout_in_ms > 0) { |
1846 | const int64_t timeout_in_us = timeout_in_ms * 1000; |
1847 | const bool notified = |
1848 | WaitForNotificationWithTimeout(notification, timeout_in_us); |
1849 | if (!notified) { |
1850 | return Status(error::DEADLINE_EXCEEDED, |
1851 | "Timed out waiting for notification" ); |
1852 | } |
1853 | } else { |
1854 | notification->WaitForNotification(); |
1855 | } |
1856 | return OkStatus(); |
1857 | } |
1858 | |
1859 | Status DirectSession::MakeCallable(const CallableOptions& callable_options, |
1860 | CallableHandle* out_handle) { |
1861 | TF_RETURN_IF_ERROR(CheckNotClosed()); |
1862 | TF_RETURN_IF_ERROR(CheckGraphCreated("MakeCallable()" )); |
1863 | |
1864 | std::unique_ptr<ExecutorsAndKeys> ek; |
1865 | std::unique_ptr<FunctionInfo> func_info; |
1866 | RunStateArgs run_state_args(callable_options.run_options().debug_options()); |
1867 | TF_RETURN_IF_ERROR( |
1868 | CreateExecutors(callable_options, &ek, &func_info, &run_state_args)); |
1869 | { |
1870 | mutex_lock l(callables_lock_); |
1871 | *out_handle = next_callable_handle_++; |
1872 | callables_[*out_handle] = {std::move(ek), std::move(func_info)}; |
1873 | } |
1874 | return OkStatus(); |
1875 | } |
1876 | |
1877 | class DirectSession::RunCallableCallFrame : public CallFrameInterface { |
1878 | public: |
1879 | RunCallableCallFrame(DirectSession* session, |
1880 | ExecutorsAndKeys* executors_and_keys, |
1881 | const std::vector<Tensor>* feed_tensors, |
1882 | std::vector<Tensor>* fetch_tensors) |
1883 | : session_(session), |
1884 | executors_and_keys_(executors_and_keys), |
1885 | feed_tensors_(feed_tensors), |
1886 | fetch_tensors_(fetch_tensors) {} |
1887 | |
1888 | size_t num_args() const override { |
1889 | return executors_and_keys_->input_types.size(); |
1890 | } |
1891 | size_t num_retvals() const override { |
1892 | return executors_and_keys_->output_types.size(); |
1893 | } |
1894 | |
1895 | Status GetArg(int index, const Tensor** val) override { |
1896 | if (TF_PREDICT_FALSE(index > feed_tensors_->size())) { |
1897 | return errors::Internal("Args index out of bounds: " , index); |
1898 | } else { |
1899 | *val = &(*feed_tensors_)[index]; |
1900 | } |
1901 | return OkStatus(); |
1902 | } |
1903 | |
1904 | Status SetRetval(int index, const Tensor& val) override { |
1905 | if (index > fetch_tensors_->size()) { |
1906 | return errors::Internal("RetVal index out of bounds: " , index); |
1907 | } |
1908 | (*fetch_tensors_)[index] = val; |
1909 | return OkStatus(); |
1910 | } |
1911 | |
1912 | private: |
1913 | DirectSession* const session_; // Not owned. |
1914 | ExecutorsAndKeys* const executors_and_keys_; // Not owned. |
1915 | const std::vector<Tensor>* const feed_tensors_; // Not owned. |
1916 | std::vector<Tensor>* const fetch_tensors_; // Not owned. |
1917 | }; |
1918 | |
1919 | ::tensorflow::Status DirectSession::RunCallable( |
1920 | CallableHandle handle, const std::vector<Tensor>& feed_tensors, |
1921 | std::vector<Tensor>* fetch_tensors, RunMetadata* run_metadata) { |
1922 | return RunCallable(handle, feed_tensors, fetch_tensors, run_metadata, |
1923 | thread::ThreadPoolOptions()); |
1924 | } |
1925 | |
1926 | ::tensorflow::Status DirectSession::RunCallable( |
1927 | CallableHandle handle, const std::vector<Tensor>& feed_tensors, |
1928 | std::vector<Tensor>* fetch_tensors, RunMetadata* run_metadata, |
1929 | const thread::ThreadPoolOptions& threadpool_options) { |
1930 | TF_RETURN_IF_ERROR(CheckNotClosed()); |
1931 | TF_RETURN_IF_ERROR(CheckGraphCreated("RunCallable()" )); |
1932 | direct_session_runs->GetCell()->IncrementBy(1); |
1933 | |
1934 | // Check if we already have an executor for these arguments. |
1935 | std::shared_ptr<ExecutorsAndKeys> executors_and_keys; |
1936 | const int64_t step_id = step_id_counter_.fetch_add(1); |
1937 | |
1938 | { |
1939 | tf_shared_lock l(callables_lock_); |
1940 | if (handle >= next_callable_handle_) { |
1941 | return errors::InvalidArgument("No such callable handle: " , handle); |
1942 | } |
1943 | executors_and_keys = callables_[handle].executors_and_keys; |
1944 | } |
1945 | |
1946 | if (!executors_and_keys) { |
1947 | return errors::InvalidArgument( |
1948 | "Attempted to run callable after handle was released: " , handle); |
1949 | } |
1950 | |
1951 | // NOTE(mrry): Debug options are not currently supported in the |
1952 | // callable interface. |
1953 | DebugOptions debug_options; |
1954 | RunStateArgs run_state_args(debug_options); |
1955 | |
1956 | // Configure a call frame for the step, which we use to feed and |
1957 | // fetch values to and from the executors. |
1958 | if (feed_tensors.size() != executors_and_keys->input_types.size()) { |
1959 | return errors::InvalidArgument( |
1960 | "Expected " , executors_and_keys->input_types.size(), |
1961 | " feed tensors, but got " , feed_tensors.size()); |
1962 | } |
1963 | if (fetch_tensors != nullptr) { |
1964 | fetch_tensors->resize(executors_and_keys->output_types.size()); |
1965 | } else if (!executors_and_keys->output_types.empty()) { |
1966 | return errors::InvalidArgument( |
1967 | "`fetch_tensors` must be provided when the callable has one or more " |
1968 | "outputs." ); |
1969 | } |
1970 | |
1971 | size_t input_size = 0; |
1972 | bool any_resource_feeds = false; |
1973 | for (auto& tensor : feed_tensors) { |
1974 | input_size += tensor.AllocatedBytes(); |
1975 | any_resource_feeds = any_resource_feeds || tensor.dtype() == DT_RESOURCE; |
1976 | } |
1977 | metrics::RecordGraphInputTensors(input_size); |
1978 | |
1979 | std::unique_ptr<std::vector<Tensor>> converted_feed_tensors; |
1980 | const std::vector<Tensor>* actual_feed_tensors; |
1981 | |
1982 | if (TF_PREDICT_FALSE(any_resource_feeds)) { |
1983 | converted_feed_tensors = std::make_unique<std::vector<Tensor>>(); |
1984 | converted_feed_tensors->reserve(feed_tensors.size()); |
1985 | for (const Tensor& t : feed_tensors) { |
1986 | if (t.dtype() == DT_RESOURCE) { |
1987 | converted_feed_tensors->emplace_back(); |
1988 | Tensor* tensor_from_handle = &converted_feed_tensors->back(); |
1989 | TF_RETURN_IF_ERROR(ResourceHandleToInputTensor(t, tensor_from_handle)); |
1990 | } else { |
1991 | converted_feed_tensors->emplace_back(t); |
1992 | } |
1993 | } |
1994 | actual_feed_tensors = converted_feed_tensors.get(); |
1995 | } else { |
1996 | actual_feed_tensors = &feed_tensors; |
1997 | } |
1998 | |
1999 | // A specialized CallFrame implementation that takes advantage of the |
2000 | // optimized RunCallable interface. |
2001 | RunCallableCallFrame call_frame(this, executors_and_keys.get(), |
2002 | actual_feed_tensors, fetch_tensors); |
2003 | |
2004 | if (LogMemory::IsEnabled()) { |
2005 | LogMemory::RecordStep(step_id, run_state_args.handle); |
2006 | } |
2007 | |
2008 | TF_RETURN_IF_ERROR(RunInternal( |
2009 | step_id, executors_and_keys->callable_options.run_options(), &call_frame, |
2010 | executors_and_keys.get(), run_metadata, threadpool_options)); |
2011 | |
2012 | if (fetch_tensors != nullptr) { |
2013 | size_t output_size = 0; |
2014 | for (auto& tensor : *fetch_tensors) { |
2015 | output_size += tensor.AllocatedBytes(); |
2016 | } |
2017 | metrics::RecordGraphOutputTensors(output_size); |
2018 | } |
2019 | |
2020 | return OkStatus(); |
2021 | } |
2022 | |
2023 | ::tensorflow::Status DirectSession::ReleaseCallable(CallableHandle handle) { |
2024 | mutex_lock l(callables_lock_); |
2025 | if (handle >= next_callable_handle_) { |
2026 | return errors::InvalidArgument("No such callable handle: " , handle); |
2027 | } |
2028 | callables_.erase(handle); |
2029 | return OkStatus(); |
2030 | } |
2031 | |
2032 | Status DirectSession::Finalize() { |
2033 | mutex_lock l(graph_state_lock_); |
2034 | if (finalized_) { |
2035 | return errors::FailedPrecondition("Session already finalized." ); |
2036 | } |
2037 | if (!graph_created_) { |
2038 | return errors::FailedPrecondition("Session not yet created." ); |
2039 | } |
2040 | execution_state_.reset(); |
2041 | flib_def_.reset(); |
2042 | finalized_ = true; |
2043 | return OkStatus(); |
2044 | } |
2045 | |
// Releases this callable's executors before its function-library state.
DirectSession::Callable::~Callable() {
  // We must delete the fields in this order, because the destructor
  // of `executors_and_keys` will call into an object owned by
  // `function_info` (in particular, when deleting a kernel, it relies
  // on the `FunctionLibraryRuntime` to know if the kernel is stateful
  // or not).
  // Resetting explicitly makes the order independent of the members'
  // declaration order. Implicit member destruction runs in reverse
  // declaration order.
  // NOTE(review): `executors_and_keys` is a shared_ptr (RunCallable copies
  // it). This reset therefore only drops this Callable's reference; the
  // executors may outlive it if a RunCallable() call still holds a copy.
  executors_and_keys.reset();
  function_info.reset();
}
2055 | |
2056 | } // namespace tensorflow |
2057 | |