1/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2
3Licensed under the Apache License, Version 2.0 (the "License");
4you may not use this file except in compliance with the License.
5You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9Unless required by applicable law or agreed to in writing, software
10distributed under the License is distributed on an "AS IS" BASIS,
11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12See the License for the specific language governing permissions and
13limitations under the License.
14==============================================================================*/
15
16#include "tensorflow/core/common_runtime/direct_session.h"
17
18#include <algorithm>
19#include <atomic>
20#include <string>
21#include <vector>
22
23#include "absl/container/flat_hash_set.h"
24#include "absl/time/time.h"
25#include "absl/types/optional.h"
26#include "tensorflow/core/common_runtime/collective_executor_mgr.h"
27#include "tensorflow/core/common_runtime/collective_param_resolver_local.h"
28#include "tensorflow/core/common_runtime/constant_folding.h"
29#include "tensorflow/core/common_runtime/debugger_state_interface.h"
30#include "tensorflow/core/common_runtime/device_factory.h"
31#include "tensorflow/core/common_runtime/device_resolver_local.h"
32#include "tensorflow/core/common_runtime/executor.h"
33#include "tensorflow/core/common_runtime/executor_factory.h"
34#include "tensorflow/core/common_runtime/function.h"
35#include "tensorflow/core/common_runtime/graph_constructor.h"
36#include "tensorflow/core/common_runtime/graph_optimizer.h"
37#include "tensorflow/core/common_runtime/local_session_selection.h"
38#include "tensorflow/core/common_runtime/memory_types.h"
39#include "tensorflow/core/common_runtime/optimization_registry.h"
40#include "tensorflow/core/common_runtime/process_util.h"
41#include "tensorflow/core/common_runtime/rendezvous_mgr.h"
42#include "tensorflow/core/common_runtime/scoped_allocator_mgr.h"
43#include "tensorflow/core/common_runtime/step_stats_collector.h"
44#include "tensorflow/core/framework/function.h"
45#include "tensorflow/core/framework/graph.pb.h"
46#include "tensorflow/core/framework/graph_def_util.h"
47#include "tensorflow/core/framework/log_memory.h"
48#include "tensorflow/core/framework/logging.h"
49#include "tensorflow/core/framework/metrics.h"
50#include "tensorflow/core/framework/node_def.pb.h"
51#include "tensorflow/core/framework/run_handler.h"
52#include "tensorflow/core/framework/tensor.h"
53#include "tensorflow/core/framework/versions.pb.h"
54#include "tensorflow/core/graph/algorithm.h"
55#include "tensorflow/core/graph/graph.h"
56#include "tensorflow/core/graph/graph_partition.h"
57#include "tensorflow/core/graph/subgraph.h"
58#include "tensorflow/core/graph/tensor_id.h"
59#include "tensorflow/core/lib/core/errors.h"
60#include "tensorflow/core/lib/core/notification.h"
61#include "tensorflow/core/lib/core/refcount.h"
62#include "tensorflow/core/lib/core/status.h"
63#include "tensorflow/core/lib/core/threadpool.h"
64#include "tensorflow/core/lib/core/threadpool_options.h"
65#include "tensorflow/core/lib/gtl/array_slice.h"
66#include "tensorflow/core/lib/monitoring/counter.h"
67#include "tensorflow/core/lib/random/random.h"
68#include "tensorflow/core/lib/strings/numbers.h"
69#include "tensorflow/core/lib/strings/str_util.h"
70#include "tensorflow/core/lib/strings/strcat.h"
71#include "tensorflow/core/nccl/collective_communicator.h"
72#include "tensorflow/core/platform/byte_order.h"
73#include "tensorflow/core/platform/cpu_info.h"
74#include "tensorflow/core/platform/logging.h"
75#include "tensorflow/core/platform/mutex.h"
76#include "tensorflow/core/platform/tracing.h"
77#include "tensorflow/core/platform/types.h"
78#include "tensorflow/core/profiler/lib/connected_traceme.h"
79#include "tensorflow/core/profiler/lib/device_profiler_session.h"
80#include "tensorflow/core/profiler/lib/traceme_encode.h"
81#include "tensorflow/core/protobuf/config.pb.h"
82#include "tensorflow/core/util/device_name_utils.h"
83#include "tensorflow/core/util/env_var.h"
84
85namespace tensorflow {
86
87namespace {
88
89auto* direct_session_runs = monitoring::Counter<0>::New(
90 "/tensorflow/core/direct_session_runs",
91 "The number of times DirectSession::Run() has been called.");
92
93Status NewThreadPoolFromThreadPoolOptions(
94 const SessionOptions& options,
95 const ThreadPoolOptionProto& thread_pool_options, int pool_number,
96 thread::ThreadPool** pool, bool* owned) {
97 int32_t num_threads = thread_pool_options.num_threads();
98 if (num_threads == 0) {
99 num_threads = NumInterOpThreadsFromSessionOptions(options);
100 }
101 const string& name = thread_pool_options.global_name();
102 if (name.empty()) {
103 // Session-local threadpool.
104 VLOG(1) << "Direct session inter op parallelism threads for pool "
105 << pool_number << ": " << num_threads;
106 *pool = new thread::ThreadPool(
107 options.env, ThreadOptions(), strings::StrCat("Compute", pool_number),
108 num_threads, !options.config.experimental().disable_thread_spinning(),
109 /*allocator=*/nullptr);
110 *owned = true;
111 return OkStatus();
112 }
113
114 // Global, named threadpool.
115 typedef std::pair<int32, thread::ThreadPool*> MapValue;
116 static std::map<string, MapValue>* global_pool_map =
117 new std::map<string, MapValue>;
118 static mutex* mu = new mutex();
119 mutex_lock l(*mu);
120 MapValue* mvalue = &(*global_pool_map)[name];
121 if (mvalue->second == nullptr) {
122 mvalue->first = thread_pool_options.num_threads();
123 mvalue->second = new thread::ThreadPool(
124 options.env, ThreadOptions(), strings::StrCat("Compute", pool_number),
125 num_threads, !options.config.experimental().disable_thread_spinning(),
126 /*allocator=*/nullptr);
127 } else {
128 if (mvalue->first != thread_pool_options.num_threads()) {
129 return errors::InvalidArgument(
130 "Pool ", name,
131 " configured previously with num_threads=", mvalue->first,
132 "; cannot re-configure with num_threads=",
133 thread_pool_options.num_threads());
134 }
135 }
136 *owned = false;
137 *pool = mvalue->second;
138 return OkStatus();
139}
140
141thread::ThreadPool* GlobalThreadPool(const SessionOptions& options) {
142 static thread::ThreadPool* const thread_pool =
143 NewThreadPoolFromSessionOptions(options);
144 return thread_pool;
145}
146
147// TODO(vrv): Figure out how to unify the many different functions
148// that generate RendezvousKey, since many of them have to be
149// consistent with each other.
150string GetRendezvousKey(const string& tensor_name,
151 const DeviceAttributes& device_info,
152 const FrameAndIter& frame_iter) {
153 return strings::StrCat(device_info.name(), ";",
154 strings::FpToString(device_info.incarnation()), ";",
155 device_info.name(), ";", tensor_name, ";",
156 frame_iter.frame_id, ":", frame_iter.iter_id);
157}
158
159} // namespace
160
161class DirectSessionFactory : public SessionFactory {
162 public:
163 DirectSessionFactory() {}
164
165 bool AcceptsOptions(const SessionOptions& options) override {
166 return options.target.empty() &&
167 !options.config.experimental().use_tfrt() &&
168 GetDefaultLocalSessionImpl() == LocalSessionImpl::kDirectSession;
169 }
170
171 Status NewSession(const SessionOptions& options,
172 Session** out_session) override {
173 const auto& experimental_config = options.config.experimental();
174 if (experimental_config.has_session_metadata()) {
175 if (experimental_config.session_metadata().version() < 0) {
176 return errors::InvalidArgument(
177 "Session version shouldn't be negative: ",
178 experimental_config.session_metadata().DebugString());
179 }
180 const string key = GetMetadataKey(experimental_config.session_metadata());
181 mutex_lock l(sessions_lock_);
182 if (!session_metadata_keys_.insert(key).second) {
183 return errors::InvalidArgument(
184 "A session with the same name and version has already been "
185 "created: ",
186 experimental_config.session_metadata().DebugString());
187 }
188 }
189
190 // Must do this before the CPU allocator is created.
191 if (options.config.graph_options().build_cost_model() > 0) {
192 EnableCPUAllocatorFullStats();
193 }
194 std::vector<std::unique_ptr<Device>> devices;
195 TF_RETURN_IF_ERROR(DeviceFactory::AddDevices(
196 options, "/job:localhost/replica:0/task:0", &devices));
197
198 DirectSession* session = new DirectSession(
199 options, new StaticDeviceMgr(std::move(devices)), this);
200 {
201 mutex_lock l(sessions_lock_);
202 sessions_.push_back(session);
203 }
204 *out_session = session;
205 return OkStatus();
206 }
207
208 Status Reset(const SessionOptions& options,
209 const std::vector<string>& containers) override {
210 std::vector<DirectSession*> sessions_to_reset;
211 {
212 mutex_lock l(sessions_lock_);
213 // We create a copy to ensure that we don't have a deadlock when
214 // session->Close calls the DirectSessionFactory.Deregister, which
215 // acquires sessions_lock_.
216 std::swap(sessions_to_reset, sessions_);
217 }
218 Status s;
219 for (auto session : sessions_to_reset) {
220 s.Update(session->Reset(containers));
221 }
222 // TODO(suharshs): Change the Reset behavior of all SessionFactories so that
223 // it doesn't close the sessions?
224 for (auto session : sessions_to_reset) {
225 s.Update(session->Close());
226 }
227 return s;
228 }
229
230 void Deregister(const DirectSession* session) {
231 mutex_lock l(sessions_lock_);
232 sessions_.erase(std::remove(sessions_.begin(), sessions_.end(), session),
233 sessions_.end());
234 if (session->options().config.experimental().has_session_metadata()) {
235 session_metadata_keys_.erase(GetMetadataKey(
236 session->options().config.experimental().session_metadata()));
237 }
238 }
239
240 private:
241 static string GetMetadataKey(const SessionMetadata& metadata) {
242 return absl::StrCat(metadata.name(), "/", metadata.version());
243 }
244
245 mutex sessions_lock_;
246 std::vector<DirectSession*> sessions_ TF_GUARDED_BY(sessions_lock_);
247 absl::flat_hash_set<string> session_metadata_keys_
248 TF_GUARDED_BY(sessions_lock_);
249};
250
251class DirectSessionRegistrar {
252 public:
253 DirectSessionRegistrar() {
254 SessionFactory::Register("DIRECT_SESSION", new DirectSessionFactory());
255 }
256};
257static DirectSessionRegistrar registrar;
258
259std::atomic_int_fast64_t DirectSession::step_id_counter_(1);
260
261static RunHandlerPool* GetOrCreateRunHandlerPool(
262 const SessionOptions& options) {
263 int num_inter_threads = 0;
264 int num_intra_threads = 0;
265 static const int env_num_inter_threads = NumInterOpThreadsFromEnvironment();
266 static const int env_num_intra_threads = NumIntraOpThreadsFromEnvironment();
267 if (env_num_inter_threads > 0) {
268 num_inter_threads = env_num_inter_threads;
269 }
270 if (env_num_intra_threads > 0) {
271 num_intra_threads = env_num_intra_threads;
272 }
273
274 if (num_inter_threads == 0) {
275 if (options.config.session_inter_op_thread_pool_size() > 0) {
276 // Note due to ShouldUseRunHandler we are guaranteed that
277 // run_options.inter_op_thread_pool() == 0
278 num_inter_threads =
279 options.config.session_inter_op_thread_pool(0).num_threads();
280 }
281 if (num_inter_threads == 0) {
282 num_inter_threads = NumInterOpThreadsFromSessionOptions(options);
283 }
284 }
285
286 if (num_intra_threads == 0) {
287 num_intra_threads = options.config.intra_op_parallelism_threads();
288 if (num_intra_threads == 0) {
289 num_intra_threads = port::MaxParallelism();
290 }
291 }
292
293 static RunHandlerPool* pool = [&]() {
294 LOG(INFO) << "Creating run-handler pool with "
295 "[num_inter_threads, num_intra_threads] as ["
296 << num_inter_threads << "," << num_intra_threads << "]";
297 return new RunHandlerPool(num_inter_threads, num_intra_threads);
298 }();
299 return pool;
300}
301
302bool DirectSession::ShouldUseRunHandlerPool(
303 const RunOptions& run_options) const {
304 if (options_.config.use_per_session_threads()) return false;
305 if (options_.config.session_inter_op_thread_pool_size() > 0 &&
306 run_options.inter_op_thread_pool() > 0)
307 return false;
308 // Only use RunHandlerPool when:
309 // a. Single global thread pool is used for inter-op parallelism.
310 // b. When multiple inter_op_thread_pool(s) are created, use it only while
311 // running sessions on the default inter_op_thread_pool=0. Typically,
312 // servo-team uses inter_op_thread_pool > 0 for model loading.
313 // TODO(crk): Revisit whether we'd want to create one (static) RunHandlerPool
314 // per entry in session_inter_op_thread_pool() in the future.
315 return true;
316}
317
318DirectSession::DirectSession(const SessionOptions& options,
319 const DeviceMgr* device_mgr,
320 DirectSessionFactory* const factory)
321 : options_(options),
322 device_mgr_(device_mgr),
323 factory_(factory),
324 cancellation_manager_(new CancellationManager()),
325 operation_timeout_in_ms_(options_.config.operation_timeout_in_ms()) {
326 const int thread_pool_size =
327 options_.config.session_inter_op_thread_pool_size();
328 if (thread_pool_size > 0) {
329 for (int i = 0; i < thread_pool_size; ++i) {
330 thread::ThreadPool* pool = nullptr;
331 bool owned = false;
332 init_error_.Update(NewThreadPoolFromThreadPoolOptions(
333 options_, options_.config.session_inter_op_thread_pool(i), i, &pool,
334 &owned));
335 thread_pools_.emplace_back(pool, owned);
336 }
337 } else if (options_.config.use_per_session_threads()) {
338 thread_pools_.emplace_back(NewThreadPoolFromSessionOptions(options_),
339 true /* owned */);
340 } else {
341 thread_pools_.emplace_back(GlobalThreadPool(options), false /* owned */);
342 // Run locally if environment value of TF_NUM_INTEROP_THREADS is negative
343 // and config.inter_op_parallelism_threads is unspecified or negative.
344 static const int env_num_threads = NumInterOpThreadsFromEnvironment();
345 if (options_.config.inter_op_parallelism_threads() < 0 ||
346 (options_.config.inter_op_parallelism_threads() == 0 &&
347 env_num_threads < 0)) {
348 run_in_caller_thread_ = true;
349 }
350 }
351 // The default value of sync_on_finish will be flipped soon and this
352 // environment variable will be removed as well.
353 const Status status =
354 ReadBoolFromEnvVar("TF_SYNC_ON_FINISH", true, &sync_on_finish_);
355 if (!status.ok()) {
356 LOG(ERROR) << status.error_message();
357 }
358 session_handle_ =
359 strings::StrCat("direct", strings::FpToString(random::New64()));
360 int devices_added = 0;
361 if (options.config.log_device_placement()) {
362 const string mapping_str = device_mgr_->DeviceMappingString();
363 string msg;
364 if (mapping_str.empty()) {
365 msg = "Device mapping: no known devices.";
366 } else {
367 msg = strings::StrCat("Device mapping:\n", mapping_str);
368 }
369 if (!logging::LogToListeners(msg)) {
370 LOG(INFO) << msg;
371 }
372 }
373 for (auto d : device_mgr_->ListDevices()) {
374 devices_.push_back(d);
375 device_set_.AddDevice(d);
376 d->op_segment()->AddHold(session_handle_);
377
378 // The first device added is special: it is the 'client device' (a
379 // CPU device) from which we feed and fetch Tensors.
380 if (devices_added == 0) {
381 device_set_.set_client_device(d);
382 }
383 ++devices_added;
384 }
385}
386
387DirectSession::~DirectSession() {
388 if (!closed_) Close().IgnoreError();
389 for (auto& it : partial_runs_) {
390 it.second.reset(nullptr);
391 }
392 for (auto& it : executors_) {
393 it.second.reset();
394 }
395 callables_.clear();
396 for (auto d : device_mgr_->ListDevices()) {
397 d->op_segment()->RemoveHold(session_handle_);
398 }
399 functions_.clear();
400 delete cancellation_manager_;
401 for (const auto& p_and_owned : thread_pools_) {
402 if (p_and_owned.second) delete p_and_owned.first;
403 }
404
405 execution_state_.reset(nullptr);
406 flib_def_.reset(nullptr);
407}
408
409Status DirectSession::Create(const GraphDef& graph) {
410 return Create(GraphDef(graph));
411}
412
413Status DirectSession::Create(GraphDef&& graph) {
414 TF_RETURN_IF_ERROR(init_error_);
415 if (graph.node_size() > 0) {
416 mutex_lock l(graph_state_lock_);
417 if (graph_created_) {
418 return errors::AlreadyExists(
419 "A Graph has already been created for this session.");
420 }
421 return ExtendLocked(std::move(graph));
422 }
423 return OkStatus();
424}
425
426Status DirectSession::Extend(const GraphDef& graph) {
427 return Extend(GraphDef(graph));
428}
429
430Status DirectSession::Extend(GraphDef&& graph) {
431 TF_RETURN_IF_ERROR(CheckNotClosed());
432 mutex_lock l(graph_state_lock_);
433 return ExtendLocked(std::move(graph));
434}
435
436Status DirectSession::ExtendLocked(GraphDef&& graph) {
437 if (finalized_) {
438 return errors::FailedPrecondition("Session has been finalized.");
439 }
440 if (!(flib_def_ && execution_state_)) {
441 // If this is the first call, we can initialize the execution state
442 // with `graph` and do not need to call `Extend()`.
443 GraphExecutionStateOptions options;
444 options.device_set = &device_set_;
445 options.session_options = &options_;
446 options.session_handle = session_handle_;
447 TF_RETURN_IF_ERROR(GraphExecutionState::MakeForBaseGraph(
448 std::move(graph), options, &execution_state_));
449 // NOTE(mrry): The function library created here will be used for
450 // all subsequent extensions of the graph. Also, note how using the copy
451 // constructor of FunctionLibraryDefinition avoids duplicating the memory
452 // that is occupied by its shared_ptr members.
453 flib_def_.reset(
454 new FunctionLibraryDefinition(execution_state_->flib_def()));
455 graph_created_ = true;
456 } else {
457 std::unique_ptr<GraphExecutionState> state;
458 // TODO(mrry): Rewrite GraphExecutionState::Extend() to take `graph` by
459 // value and move `graph` in here.
460 TF_RETURN_IF_ERROR(execution_state_->Extend(graph, &state));
461 execution_state_.swap(state);
462 TF_RETURN_IF_ERROR(flib_def_->AddLibrary(graph.library()));
463 }
464 return OkStatus();
465}
466
467Status DirectSession::Run(const NamedTensorList& inputs,
468 const std::vector<string>& output_names,
469 const std::vector<string>& target_nodes,
470 std::vector<Tensor>* outputs) {
471 RunMetadata run_metadata;
472 return Run(RunOptions(), inputs, output_names, target_nodes, outputs,
473 &run_metadata);
474}
475
476Status DirectSession::CreateDebuggerState(
477 const CallableOptions& callable_options, int64_t global_step,
478 int64_t session_run_index, int64_t executor_step_index,
479 std::unique_ptr<DebuggerStateInterface>* debugger_state) {
480 TF_RETURN_IF_ERROR(DebuggerStateRegistry::CreateState(
481 callable_options.run_options().debug_options(), debugger_state));
482 std::vector<string> input_names(callable_options.feed().begin(),
483 callable_options.feed().end());
484 std::vector<string> output_names(callable_options.fetch().begin(),
485 callable_options.fetch().end());
486 std::vector<string> target_names(callable_options.target().begin(),
487 callable_options.target().end());
488
489 TF_RETURN_IF_ERROR(debugger_state->get()->PublishDebugMetadata(
490 global_step, session_run_index, executor_step_index, input_names,
491 output_names, target_names));
492 return OkStatus();
493}
494
495Status DirectSession::DecorateAndPublishGraphForDebug(
496 const DebugOptions& debug_options, Graph* graph, Device* device) {
497 std::unique_ptr<DebugGraphDecoratorInterface> decorator;
498 TF_RETURN_IF_ERROR(
499 DebugGraphDecoratorRegistry::CreateDecorator(debug_options, &decorator));
500
501 TF_RETURN_IF_ERROR(decorator->DecorateGraph(graph, device));
502 TF_RETURN_IF_ERROR(decorator->PublishGraph(*graph, device->name()));
503 return OkStatus();
504}
505
506Status DirectSession::RunInternal(
507 int64_t step_id, const RunOptions& run_options,
508 CallFrameInterface* call_frame, ExecutorsAndKeys* executors_and_keys,
509 RunMetadata* run_metadata,
510 const thread::ThreadPoolOptions& threadpool_options) {
511 const uint64 start_time_usecs = options_.env->NowMicros();
512 const int64_t executor_step_count =
513 executors_and_keys->step_count.fetch_add(1);
514 RunState run_state(step_id, &devices_);
515 const size_t num_executors = executors_and_keys->items.size();
516
517 profiler::TraceMeProducer activity(
518 // To TraceMeConsumers in ExecutorState::Process/Finish.
519 [&] {
520 if (options_.config.experimental().has_session_metadata()) {
521 const auto& model_metadata =
522 options_.config.experimental().session_metadata();
523 string model_id = strings::StrCat(model_metadata.name(), ":",
524 model_metadata.version());
525 return profiler::TraceMeEncode("SessionRun",
526 {{"id", step_id},
527 {"_r", 1} /*root_event*/,
528 {"model_id", model_id}});
529 } else {
530 return profiler::TraceMeEncode(
531 "SessionRun", {{"id", step_id}, {"_r", 1} /*root_event*/});
532 }
533 },
534 profiler::ContextType::kTfExecutor, step_id,
535 profiler::TraceMeLevel::kInfo);
536
537 std::unique_ptr<DebuggerStateInterface> debugger_state;
538 if (!run_options.debug_options().debug_tensor_watch_opts().empty()) {
539 TF_RETURN_IF_ERROR(
540 CreateDebuggerState(executors_and_keys->callable_options,
541 run_options.debug_options().global_step(), step_id,
542 executor_step_count, &debugger_state));
543 }
544
545 if (run_metadata != nullptr &&
546 options_.config.experimental().has_session_metadata()) {
547 *run_metadata->mutable_session_metadata() =
548 options_.config.experimental().session_metadata();
549 }
550
551#ifndef __ANDROID__
552 // Set up for collectives if ExecutorsAndKeys declares a key.
553 if (executors_and_keys->collective_graph_key !=
554 BuildGraphOptions::kNoCollectiveGraphKey) {
555 if (run_options.experimental().collective_graph_key() !=
556 BuildGraphOptions::kNoCollectiveGraphKey) {
557 // If a collective_graph_key was specified in run_options, ensure that it
558 // matches what came out of GraphExecutionState::BuildGraph().
559 if (run_options.experimental().collective_graph_key() !=
560 executors_and_keys->collective_graph_key) {
561 return errors::Internal(
562 "collective_graph_key in RunOptions ",
563 run_options.experimental().collective_graph_key(),
564 " should match collective_graph_key from optimized graph ",
565 executors_and_keys->collective_graph_key);
566 }
567 }
568 if (!collective_executor_mgr_) {
569 collective_executor_mgr_ = CreateProdLocalCollectiveExecutorMgr(
570 options_.config, device_mgr_.get(),
571 MaybeCreateNcclCommunicator(options_.config));
572 }
573 run_state.collective_executor.reset(new CollectiveExecutor::Handle(
574 collective_executor_mgr_->FindOrCreate(step_id), true /*inherit_ref*/));
575 }
576#endif
577
578 thread::ThreadPool* pool;
579 // Use std::unique_ptr to ensure garbage collection
580 std::unique_ptr<thread::ThreadPool> threadpool_wrapper;
581
582 const bool inline_execution_requested =
583 run_in_caller_thread_ || run_options.inter_op_thread_pool() == -1;
584
585 if (inline_execution_requested) {
586 // We allow using the caller thread only when having a single executor
587 // specified.
588 if (executors_and_keys->items.size() > 1) {
589 pool = thread_pools_[0].first;
590 } else {
591 VLOG(1) << "Executing Session::Run() synchronously!";
592 pool = nullptr;
593 }
594 } else if (threadpool_options.inter_op_threadpool != nullptr) {
595 threadpool_wrapper = std::make_unique<thread::ThreadPool>(
596 threadpool_options.inter_op_threadpool);
597 pool = threadpool_wrapper.get();
598 } else {
599 if (run_options.inter_op_thread_pool() < -1 ||
600 run_options.inter_op_thread_pool() >=
601 static_cast<int32>(thread_pools_.size())) {
602 return errors::InvalidArgument("Invalid inter_op_thread_pool: ",
603 run_options.inter_op_thread_pool());
604 }
605
606 pool = thread_pools_[run_options.inter_op_thread_pool()].first;
607 }
608
609 const int64_t call_timeout = run_options.timeout_in_ms() > 0
610 ? run_options.timeout_in_ms()
611 : operation_timeout_in_ms_;
612 absl::optional<absl::Time> deadline;
613 if (call_timeout > 0) {
614 deadline = absl::Now() + absl::Milliseconds(call_timeout);
615 }
616
617 std::unique_ptr<RunHandler> handler;
618 if (ShouldUseRunHandlerPool(run_options) &&
619 run_options.experimental().use_run_handler_pool()) {
620 VLOG(1) << "Using RunHandler to scheduler inter-op closures.";
621 handler = GetOrCreateRunHandlerPool(options_)->Get(
622 step_id, call_timeout,
623 run_options.experimental().run_handler_pool_options());
624 if (!handler) {
625 return errors::DeadlineExceeded(
626 "Could not obtain RunHandler for request after waiting for ",
627 call_timeout, "ms.");
628 }
629 }
630 auto* handler_ptr = handler.get();
631
632 Executor::Args::Runner default_runner = nullptr;
633
634 if (pool == nullptr) {
635 default_runner = [](const Executor::Args::Closure& c) { c(); };
636 } else if (handler_ptr != nullptr) {
637 default_runner = [handler_ptr](Executor::Args::Closure c) {
638 handler_ptr->ScheduleInterOpClosure(std::move(c));
639 };
640 } else {
641 default_runner = [pool](Executor::Args::Closure c) {
642 pool->Schedule(std::move(c));
643 };
644 }
645
646 // Start parallel Executors.
647
648 // We can execute this step synchronously on the calling thread whenever
649 // there is a single device and the timeout mechanism is not used.
650 //
651 // When timeouts are used, we must execute the graph(s) asynchronously, in
652 // order to invoke the cancellation manager on the calling thread if the
653 // timeout expires.
654 const bool can_execute_synchronously =
655 executors_and_keys->items.size() == 1 && call_timeout == 0;
656
657 Executor::Args args;
658 args.step_id = step_id;
659 args.call_frame = call_frame;
660 args.collective_executor =
661 (run_state.collective_executor ? run_state.collective_executor->get()
662 : nullptr);
663 args.session_state = &session_state_;
664 args.session_handle = session_handle_;
665 args.tensor_store = &run_state.tensor_store;
666 args.step_container = &run_state.step_container;
667 args.sync_on_finish = sync_on_finish_;
668 args.user_intra_op_threadpool = threadpool_options.intra_op_threadpool;
669 args.run_all_kernels_inline = pool == nullptr;
670 args.start_time_usecs = start_time_usecs;
671 args.deadline = deadline;
672
673 const bool do_trace = (run_options.trace_level() > RunOptions::NO_TRACE);
674
675 bool update_cost_model = false;
676 if (options_.config.graph_options().build_cost_model() > 0) {
677 const int64_t build_cost_model_every =
678 options_.config.graph_options().build_cost_model();
679 const int64_t build_cost_model_after =
680 options_.config.graph_options().build_cost_model_after();
681 int64_t measure_step_count = executor_step_count - build_cost_model_after;
682 if (measure_step_count >= 0) {
683 update_cost_model =
684 ((measure_step_count + 1) % build_cost_model_every == 0);
685 }
686 }
687 if (do_trace || update_cost_model ||
688 run_options.report_tensor_allocations_upon_oom()) {
689 run_state.collector.reset(
690 new StepStatsCollector(run_metadata->mutable_step_stats()));
691 args.stats_collector = run_state.collector.get();
692 }
693
694 std::unique_ptr<DeviceProfilerSession> device_profiler_session;
695 if (run_options.trace_level() >= RunOptions::HARDWARE_TRACE) {
696 device_profiler_session = DeviceProfilerSession::Create();
697 }
698
699 // Register this step with session's cancellation manager, so that
700 // `Session::Close()` will cancel the step.
701 CancellationManager step_cancellation_manager(cancellation_manager_);
702 if (step_cancellation_manager.IsCancelled()) {
703 return errors::Cancelled("Run call was cancelled");
704 }
705 args.cancellation_manager = &step_cancellation_manager;
706
707 Status run_status;
708
709 auto set_threadpool_args_for_item =
710 [&default_runner, &handler](const PerPartitionExecutorsAndLib& item,
711 Executor::Args* args) {
712 // TODO(azaks): support partial run.
713 // TODO(azaks): if the device picks its own threadpool, we need to
714 // assign
715 // less threads to the main compute pool by default.
716 thread::ThreadPool* device_thread_pool =
717 item.device->tensorflow_device_thread_pool();
718 // TODO(crk): Investigate usage of RunHandlerPool when using device
719 // specific thread pool(s).
720 if (!device_thread_pool) {
721 args->runner = default_runner;
722 } else {
723 args->runner = [device_thread_pool](Executor::Args::Closure c) {
724 device_thread_pool->Schedule(std::move(c));
725 };
726 }
727 if (handler != nullptr) {
728 args->user_intra_op_threadpool =
729 handler->AsIntraThreadPoolInterface();
730 }
731 };
732
733 if (can_execute_synchronously) {
734 PrivateIntraProcessRendezvous rendezvous(device_mgr_.get());
735 args.rendezvous = &rendezvous;
736
737 const auto& item = executors_and_keys->items[0];
738 set_threadpool_args_for_item(item, &args);
739 run_status = item.executor->Run(args);
740 } else {
741 core::RefCountPtr<RefCountedIntraProcessRendezvous> rendezvous(
742 new RefCountedIntraProcessRendezvous(device_mgr_.get()));
743 args.rendezvous = rendezvous.get();
744
745 // `barrier` will delete itself after the final executor finishes.
746 Notification executors_done;
747 ExecutorBarrier* barrier =
748 new ExecutorBarrier(num_executors, rendezvous.get(),
749 [&run_state, &executors_done](const Status& ret) {
750 {
751 mutex_lock l(run_state.mu);
752 run_state.status.Update(ret);
753 }
754 executors_done.Notify();
755 });
756
757 for (const auto& item : executors_and_keys->items) {
758 set_threadpool_args_for_item(item, &args);
759 item.executor->RunAsync(args, barrier->Get());
760 }
761
762 WaitForNotification(&executors_done, &run_state, &step_cancellation_manager,
763 call_timeout);
764 {
765 tf_shared_lock l(run_state.mu);
766 run_status = run_state.status;
767 }
768 }
769
770 if (step_cancellation_manager.IsCancelled()) {
771 run_status.Update(errors::Cancelled("Run call was cancelled"));
772 }
773
774 if (device_profiler_session) {
775 TF_RETURN_IF_ERROR(device_profiler_session->CollectData(
776 run_metadata->mutable_step_stats()));
777 }
778
779 TF_RETURN_IF_ERROR(run_status);
780
781 // Save the output tensors of this run we choose to keep.
782 if (!run_state.tensor_store.empty()) {
783 TF_RETURN_IF_ERROR(run_state.tensor_store.SaveTensors(
784 {executors_and_keys->callable_options.fetch().begin(),
785 executors_and_keys->callable_options.fetch().end()},
786 &session_state_));
787 }
788
789 if (run_state.collector) {
790 run_state.collector->Finalize();
791 }
792
793 // Build and return the cost model as instructed.
794 if (update_cost_model) {
795 // Build the cost model
796 std::unordered_map<string, const Graph*> device_to_graph;
797 for (const PerPartitionExecutorsAndLib& partition :
798 executors_and_keys->items) {
799 const Graph* graph = partition.graph.get();
800 const string& device = partition.flib->device()->name();
801 device_to_graph[device] = graph;
802 }
803
804 mutex_lock l(executor_lock_);
805 run_state.collector->BuildCostModel(&cost_model_manager_, device_to_graph);
806
807 // annotate stats onto cost graph.
808 CostGraphDef* cost_graph = run_metadata->mutable_cost_graph();
809 for (const auto& item : executors_and_keys->items) {
810 TF_RETURN_IF_ERROR(
811 cost_model_manager_.AddToCostGraphDef(item.graph.get(), cost_graph));
812 }
813 }
814
815 // If requested via RunOptions, output the partition graphs.
816 if (run_options.output_partition_graphs()) {
817 if (options_.config.experimental().disable_output_partition_graphs()) {
818 return errors::InvalidArgument(
819 "RunOptions.output_partition_graphs() is not supported when "
820 "disable_output_partition_graphs is true.");
821 } else {
822 protobuf::RepeatedPtrField<GraphDef>* partition_graph_defs =
823 run_metadata->mutable_partition_graphs();
824 for (const PerPartitionExecutorsAndLib& exec_and_lib :
825 executors_and_keys->items) {
826 GraphDef* partition_graph_def = partition_graph_defs->Add();
827 exec_and_lib.graph->ToGraphDef(partition_graph_def);
828 }
829 }
830 }
831 metrics::UpdateGraphExecTime(options_.env->NowMicros() - start_time_usecs);
832
833 return OkStatus();
834}
835
836Status DirectSession::Run(const RunOptions& run_options,
837 const NamedTensorList& inputs,
838 const std::vector<string>& output_names,
839 const std::vector<string>& target_nodes,
840 std::vector<Tensor>* outputs,
841 RunMetadata* run_metadata) {
842 return Run(run_options, inputs, output_names, target_nodes, outputs,
843 run_metadata, thread::ThreadPoolOptions());
844}
845
846Status DirectSession::Run(const RunOptions& run_options,
847 const NamedTensorList& inputs,
848 const std::vector<string>& output_names,
849 const std::vector<string>& target_nodes,
850 std::vector<Tensor>* outputs,
851 RunMetadata* run_metadata,
852 const thread::ThreadPoolOptions& threadpool_options) {
853 TF_RETURN_IF_ERROR(CheckNotClosed());
854 TF_RETURN_IF_ERROR(CheckGraphCreated("Run()"));
855 direct_session_runs->GetCell()->IncrementBy(1);
856
857 // Extract the inputs names for this run of the session.
858 std::vector<string> input_tensor_names;
859 input_tensor_names.reserve(inputs.size());
860 size_t input_size = 0;
861 for (const auto& it : inputs) {
862 input_tensor_names.push_back(it.first);
863 input_size += it.second.AllocatedBytes();
864 }
865 metrics::RecordGraphInputTensors(input_size);
866
867 // Check if we already have an executor for these arguments.
868 ExecutorsAndKeys* executors_and_keys;
869 RunStateArgs run_state_args(run_options.debug_options());
870 run_state_args.collective_graph_key =
871 run_options.experimental().collective_graph_key();
872
873 TF_RETURN_IF_ERROR(GetOrCreateExecutors(input_tensor_names, output_names,
874 target_nodes, &executors_and_keys,
875 &run_state_args));
876 {
877 mutex_lock l(collective_graph_key_lock_);
878 collective_graph_key_ = executors_and_keys->collective_graph_key;
879 }
880
881 // Configure a call frame for the step, which we use to feed and
882 // fetch values to and from the executors.
883 FunctionCallFrame call_frame(executors_and_keys->input_types,
884 executors_and_keys->output_types);
885 gtl::InlinedVector<Tensor, 4> feed_args(inputs.size());
886 for (const auto& it : inputs) {
887 if (it.second.dtype() == DT_RESOURCE) {
888 Tensor tensor_from_handle;
889 TF_RETURN_IF_ERROR(
890 ResourceHandleToInputTensor(it.second, &tensor_from_handle));
891 feed_args[executors_and_keys->input_name_to_index[it.first]] =
892 tensor_from_handle;
893 } else {
894 feed_args[executors_and_keys->input_name_to_index[it.first]] = it.second;
895 }
896 }
897 const Status s = call_frame.SetArgs(feed_args);
898 if (errors::IsInternal(s)) {
899 return errors::InvalidArgument(s.error_message());
900 } else if (!s.ok()) {
901 return s;
902 }
903
904 const int64_t step_id = step_id_counter_.fetch_add(1);
905
906 if (LogMemory::IsEnabled()) {
907 LogMemory::RecordStep(step_id, run_state_args.handle);
908 }
909
910 TF_RETURN_IF_ERROR(RunInternal(step_id, run_options, &call_frame,
911 executors_and_keys, run_metadata,
912 threadpool_options));
913
914 // Receive outputs.
915 if (outputs) {
916 std::vector<Tensor> sorted_outputs;
917 const Status s = call_frame.ConsumeRetvals(
918 &sorted_outputs, /* allow_dead_tensors = */ false);
919 if (errors::IsInternal(s)) {
920 return errors::InvalidArgument(s.error_message());
921 } else if (!s.ok()) {
922 return s;
923 }
924 const bool unique_outputs =
925 output_names.size() == executors_and_keys->output_name_to_index.size();
926 // first_indices[i] = j implies that j is the smallest value for which
927 // output_names[i] == output_names[j].
928 std::vector<int> first_indices;
929 if (!unique_outputs) {
930 first_indices.reserve(output_names.size());
931 for (const auto& name : output_names) {
932 first_indices.push_back(
933 std::find(output_names.begin(), output_names.end(), name) -
934 output_names.begin());
935 }
936 }
937 outputs->clear();
938 size_t output_size = 0;
939 outputs->reserve(sorted_outputs.size());
940 for (int i = 0; i < output_names.size(); ++i) {
941 const string& output_name = output_names[i];
942 if (first_indices.empty() || first_indices[i] == i) {
943 outputs->emplace_back(
944 std::move(sorted_outputs[executors_and_keys
945 ->output_name_to_index[output_name]]));
946 } else {
947 outputs->push_back((*outputs)[first_indices[i]]);
948 }
949 output_size += outputs->back().AllocatedBytes();
950 }
951 metrics::RecordGraphOutputTensors(output_size);
952 }
953
954 return OkStatus();
955}
956
957Status DirectSession::PRunSetup(const std::vector<string>& input_names,
958 const std::vector<string>& output_names,
959 const std::vector<string>& target_nodes,
960 string* handle) {
961 TF_RETURN_IF_ERROR(CheckNotClosed());
962 TF_RETURN_IF_ERROR(CheckGraphCreated("PRunSetup()"));
963
964 // RunOptions is not available in PRunSetup, so use thread pool 0.
965 thread::ThreadPool* pool = thread_pools_[0].first;
966
967 // Check if we already have an executor for these arguments.
968 ExecutorsAndKeys* executors_and_keys;
969 // TODO(cais): TFDBG support for partial runs.
970 DebugOptions debug_options;
971 RunStateArgs run_state_args(debug_options);
972 run_state_args.is_partial_run = true;
973 TF_RETURN_IF_ERROR(GetOrCreateExecutors(input_names, output_names,
974 target_nodes, &executors_and_keys,
975 &run_state_args));
976
977 // Create the run state and save it for future PRun calls.
978 Executor::Args args;
979 args.step_id = step_id_counter_.fetch_add(1);
980 PartialRunState* run_state =
981 new PartialRunState(input_names, output_names, args.step_id, &devices_);
982 run_state->rendez.reset(new IntraProcessRendezvous(device_mgr_.get()));
983 {
984 mutex_lock l(executor_lock_);
985 if (!partial_runs_
986 .emplace(run_state_args.handle,
987 std::unique_ptr<PartialRunState>(run_state))
988 .second) {
989 return errors::Internal("The handle '", run_state_args.handle,
990 "' created for this partial run is not unique.");
991 }
992 }
993
994 // Start parallel Executors.
995 const size_t num_executors = executors_and_keys->items.size();
996 ExecutorBarrier* barrier = new ExecutorBarrier(
997 num_executors, run_state->rendez.get(), [run_state](const Status& ret) {
998 if (!ret.ok()) {
999 mutex_lock l(run_state->mu);
1000 run_state->status.Update(ret);
1001 }
1002 run_state->executors_done.Notify();
1003 });
1004
1005 args.rendezvous = run_state->rendez.get();
1006 args.cancellation_manager = cancellation_manager_;
1007 // Note that Collectives are not supported in partial runs
1008 // because RunOptions is not passed in so we can't know whether
1009 // their use is intended.
1010 args.collective_executor = nullptr;
1011 args.runner = [this, pool](Executor::Args::Closure c) {
1012 pool->Schedule(std::move(c));
1013 };
1014 args.session_state = &session_state_;
1015 args.session_handle = session_handle_;
1016 args.tensor_store = &run_state->tensor_store;
1017 args.step_container = &run_state->step_container;
1018 if (LogMemory::IsEnabled()) {
1019 LogMemory::RecordStep(args.step_id, run_state_args.handle);
1020 }
1021 args.sync_on_finish = sync_on_finish_;
1022
1023 if (options_.config.graph_options().build_cost_model()) {
1024 run_state->collector.reset(new StepStatsCollector(nullptr));
1025 args.stats_collector = run_state->collector.get();
1026 }
1027
1028 for (auto& item : executors_and_keys->items) {
1029 item.executor->RunAsync(args, barrier->Get());
1030 }
1031
1032 *handle = run_state_args.handle;
1033 return OkStatus();
1034}
1035
// Feeds `inputs` into, and fetches `output_names` from, the partial run
// identified by `handle`, which must have been returned by PRunSetup().
//
// `handle` has the form "<executors key>;<counter>" (built in
// GetOrCreateExecutors()); the part before ';' selects the cached executors.
// Every feed and fetch must have been declared in PRunSetup(), and each may be
// used only once; otherwise InvalidArgument is returned. If sending, receiving
// or saving tensors fails, or once PendingDone() reports the run complete, this
// waits for the partition executors and erases the run state.
Status DirectSession::PRun(const string& handle, const NamedTensorList& inputs,
                           const std::vector<string>& output_names,
                           std::vector<Tensor>* outputs) {
  TF_RETURN_IF_ERROR(CheckNotClosed());
  std::vector<string> parts = str_util::Split(handle, ';');
  const string& key = parts[0];
  // Get the executors for this partial run.
  ExecutorsAndKeys* executors_and_keys;
  PartialRunState* run_state;
  {
    mutex_lock l(executor_lock_);  // could use reader lock
    auto exc_it = executors_.find(key);
    if (exc_it == executors_.end()) {
      return errors::InvalidArgument(
          "Must run 'setup' before performing partial runs!");
    }
    executors_and_keys = exc_it->second.get();

    auto prun_it = partial_runs_.find(handle);
    if (prun_it == partial_runs_.end()) {
      return errors::InvalidArgument(
          "Must run 'setup' before performing partial runs!");
    }
    run_state = prun_it->second.get();

    // Make sure that this is a new set of feeds that are still pending.
    for (const auto& input : inputs) {
      auto it = run_state->pending_inputs.find(input.first);
      if (it == run_state->pending_inputs.end()) {
        return errors::InvalidArgument(
            "The feed ", input.first,
            " was not specified in partial_run_setup.");
      } else if (it->second) {
        return errors::InvalidArgument("The feed ", input.first,
                                       " has already been fed.");
      }
    }
    // Check that this is a new set of fetches that are still pending.
    for (const auto& output : output_names) {
      auto it = run_state->pending_outputs.find(output);
      if (it == run_state->pending_outputs.end()) {
        return errors::InvalidArgument(
            "The fetch ", output, " was not specified in partial_run_setup.");
      } else if (it->second) {
        return errors::InvalidArgument("The fetch ", output,
                                       " has already been fetched.");
      }
    }
  }

  // Check that this new set of fetches can be computed from all the
  // feeds we have supplied.
  TF_RETURN_IF_ERROR(
      CheckFetch(inputs, output_names, executors_and_keys, run_state));

  // Send inputs.
  Status s =
      SendPRunInputs(inputs, executors_and_keys, run_state->rendez.get());

  // Receive outputs.
  if (s.ok()) {
    s = RecvPRunOutputs(output_names, executors_and_keys, run_state, outputs);
  }

  // Save the output tensors of this run we choose to keep.
  if (s.ok()) {
    s = run_state->tensor_store.SaveTensors(output_names, &session_state_);
  }

  {
    mutex_lock l(executor_lock_);
    // Delete the run state if there is an error or all fetches are done.
    bool done = true;
    if (s.ok()) {
      {
        mutex_lock l(run_state->mu);
        // An executor failure recorded by the PRunSetup() barrier callback is
        // only logged here; this call's own status `s` is still returned.
        if (!run_state->status.ok()) {
          LOG(WARNING) << "An error unrelated to this prun has been detected. "
                       << run_state->status;
        }
      }
      // Mark this call's feeds and fetches as consumed. Each name was found in
      // these maps by the validation block at the top of this function.
      for (const auto& input : inputs) {
        auto it = run_state->pending_inputs.find(input.first);
        it->second = true;
      }
      for (const auto& name : output_names) {
        auto it = run_state->pending_outputs.find(name);
        it->second = true;
      }
      done = run_state->PendingDone();
    }
    if (done) {
      // NOTE(review): this wait happens while executor_lock_ is still held, so
      // other calls on this session that take executor_lock_ (e.g.
      // GetOrCreateExecutors) block until the executors finish or the
      // operation timeout fires. Confirm this is intended before relying on it.
      WaitForNotification(&run_state->executors_done, run_state,
                          cancellation_manager_, operation_timeout_in_ms_);
      partial_runs_.erase(handle);
    }
  }

  return s;
}
1136
1137Status DirectSession::ResourceHandleToInputTensor(const Tensor& resource_tensor,
1138 Tensor* retrieved_tensor) {
1139 if (resource_tensor.dtype() != DT_RESOURCE) {
1140 return errors::InvalidArgument(strings::StrCat(
1141 "ResourceHandleToInputTensor() received non-DT_RESOURCE Tensor: ",
1142 resource_tensor.dtype()));
1143 }
1144
1145 const ResourceHandle& resource_handle =
1146 resource_tensor.scalar<ResourceHandle>()();
1147
1148 if (resource_handle.container() ==
1149 SessionState::kTensorHandleResourceTypeName) {
1150 return session_state_.GetTensor(resource_handle.name(), retrieved_tensor);
1151 } else {
1152 return errors::InvalidArgument(strings::StrCat(
1153 "Invalid resource type hash code: ", resource_handle.hash_code(),
1154 "(name: ", resource_handle.name(),
1155 " type: ", resource_handle.maybe_type_name(),
1156 "). Perhaps a resource tensor was being provided as a feed? That is "
1157 "not currently allowed. Please file an issue at "
1158 "https://github.com/tensorflow/tensorflow/issues/new, ideally with a "
1159 "short code snippet that leads to this error message."));
1160 }
1161}
1162
1163Status DirectSession::SendPRunInputs(const NamedTensorList& inputs,
1164 const ExecutorsAndKeys* executors_and_keys,
1165 IntraProcessRendezvous* rendez) {
1166 Status s;
1167 Rendezvous::ParsedKey parsed;
1168 // Insert the input tensors into the local rendezvous by their
1169 // rendezvous key.
1170 for (const auto& input : inputs) {
1171 auto it =
1172 executors_and_keys->input_name_to_rendezvous_key.find(input.first);
1173 if (it == executors_and_keys->input_name_to_rendezvous_key.end()) {
1174 return errors::Internal("'", input.first, "' is not a pre-defined feed.");
1175 }
1176 const string& input_key = it->second;
1177
1178 s = Rendezvous::ParseKey(input_key, &parsed);
1179 if (!s.ok()) {
1180 rendez->StartAbort(s);
1181 return s;
1182 }
1183
1184 if (input.second.dtype() == DT_RESOURCE) {
1185 Tensor tensor_from_handle;
1186 s = ResourceHandleToInputTensor(input.second, &tensor_from_handle);
1187 if (s.ok()) {
1188 s = rendez->Send(parsed, Rendezvous::Args(), tensor_from_handle, false);
1189 }
1190 } else {
1191 s = rendez->Send(parsed, Rendezvous::Args(), input.second, false);
1192 }
1193
1194 if (!s.ok()) {
1195 rendez->StartAbort(s);
1196 return s;
1197 }
1198 }
1199 return OkStatus();
1200}
1201
1202Status DirectSession::RecvPRunOutputs(
1203 const std::vector<string>& output_names,
1204 const ExecutorsAndKeys* executors_and_keys, PartialRunState* run_state,
1205 std::vector<Tensor>* outputs) {
1206 Status s;
1207 if (!output_names.empty()) {
1208 outputs->resize(output_names.size());
1209 }
1210
1211 Rendezvous::ParsedKey parsed;
1212 // Get the outputs from the rendezvous
1213 for (size_t output_offset = 0; output_offset < output_names.size();
1214 ++output_offset) {
1215 const string& output_name = output_names[output_offset];
1216 auto it =
1217 executors_and_keys->output_name_to_rendezvous_key.find(output_name);
1218 if (it == executors_and_keys->output_name_to_rendezvous_key.end()) {
1219 return errors::Internal("'", output_name,
1220 "' is not a pre-defined fetch.");
1221 }
1222 const string& output_key = it->second;
1223 Tensor output_tensor;
1224 bool is_dead;
1225
1226 s = Rendezvous::ParseKey(output_key, &parsed);
1227 if (s.ok()) {
1228 // Fetch data from the Rendezvous.
1229 s = run_state->rendez->Recv(parsed, Rendezvous::Args(), &output_tensor,
1230 &is_dead, operation_timeout_in_ms_);
1231 if (is_dead && s.ok()) {
1232 s = errors::InvalidArgument("The tensor returned for ", output_name,
1233 " was not valid.");
1234 }
1235 }
1236 if (!s.ok()) {
1237 run_state->rendez->StartAbort(s);
1238 outputs->clear();
1239 return s;
1240 }
1241
1242 (*outputs)[output_offset] = output_tensor;
1243 }
1244 return OkStatus();
1245}
1246
1247Status DirectSession::CheckFetch(const NamedTensorList& feeds,
1248 const std::vector<string>& fetches,
1249 const ExecutorsAndKeys* executors_and_keys,
1250 const PartialRunState* run_state) {
1251 const Graph* graph = executors_and_keys->graph.get();
1252 const NameNodeMap* name_to_node = &executors_and_keys->name_to_node;
1253
1254 // Build the set of pending feeds that we haven't seen.
1255 std::unordered_set<TensorId, TensorId::Hasher> pending_feeds;
1256 {
1257 mutex_lock l(executor_lock_);
1258 for (const auto& input : run_state->pending_inputs) {
1259 // Skip if the feed has already been fed.
1260 if (input.second) continue;
1261 TensorId id(ParseTensorName(input.first));
1262 auto it = name_to_node->find(id.first);
1263 if (it == name_to_node->end()) {
1264 return errors::NotFound("Feed ", input.first, ": not found");
1265 }
1266 pending_feeds.insert(id);
1267 }
1268 }
1269 for (const auto& it : feeds) {
1270 TensorId id(ParseTensorName(it.first));
1271 pending_feeds.erase(id);
1272 }
1273
1274 // Initialize the stack with the fetch nodes.
1275 std::vector<const Node*> stack;
1276 for (const string& fetch : fetches) {
1277 TensorId id(ParseTensorName(fetch));
1278 auto it = name_to_node->find(id.first);
1279 if (it == name_to_node->end()) {
1280 return errors::NotFound("Fetch ", fetch, ": not found");
1281 }
1282 stack.push_back(it->second);
1283 }
1284
1285 // Any tensor needed for fetches can't be in pending_feeds.
1286 std::vector<bool> visited(graph->num_node_ids(), false);
1287 while (!stack.empty()) {
1288 const Node* n = stack.back();
1289 stack.pop_back();
1290
1291 for (const Edge* in_edge : n->in_edges()) {
1292 const Node* in_node = in_edge->src();
1293 if (pending_feeds.count({in_node->name(), in_edge->src_output()}) > 0) {
1294 return errors::InvalidArgument("Fetch ", in_node->name(), ":",
1295 in_edge->src_output(),
1296 " can't be computed from the feeds"
1297 " that have been fed so far.");
1298 }
1299 if (!visited[in_node->id()]) {
1300 visited[in_node->id()] = true;
1301 stack.push_back(in_node);
1302 }
1303 }
1304 }
1305 return OkStatus();
1306}
1307
// Builds the per-device executors for the feeds, fetches and targets in
// `callable_options`, together with the function-library runtime they use.
//
// `run_state_args->is_partial_run` selects the calling convention:
//   - false (Run()): the graph uses the function calling convention, and
//     `input_name_to_index` / `output_name_to_index` map feed/fetch names to
//     call-frame ordinals.
//   - true (PRun()): feeds/fetches are exchanged through a rendezvous, and
//     `input_name_to_rendezvous_key` / `output_name_to_rendezvous_key` hold
//     the keys. In this case `run_state_args->graph` (filled by CreateGraphs())
//     is moved into the result's `graph`, and `name_to_node` is populated for
//     the feed/fetch nodes.
// On success `*out_executors_and_keys` and `*out_func_info` receive the new
// objects; on error they are left untouched.
Status DirectSession::CreateExecutors(
    const CallableOptions& callable_options,
    std::unique_ptr<ExecutorsAndKeys>* out_executors_and_keys,
    std::unique_ptr<FunctionInfo>* out_func_info,
    RunStateArgs* run_state_args) {
  BuildGraphOptions options;
  options.callable_options = callable_options;
  options.use_function_convention = !run_state_args->is_partial_run;
  options.collective_graph_key =
      callable_options.run_options().experimental().collective_graph_key();
  // Collective ordering: deterministic sequential execution takes precedence
  // over the NCCL setting.
  if (options_.config.experimental()
          .collective_deterministic_sequential_execution()) {
    options.collective_order = GraphCollectiveOrder::kEdges;
  } else if (options_.config.experimental().collective_nccl()) {
    options.collective_order = GraphCollectiveOrder::kAttrs;
  }

  std::unique_ptr<FunctionInfo> func_info(new FunctionInfo);
  std::unique_ptr<ExecutorsAndKeys> ek(new ExecutorsAndKeys);

  ek->callable_options = callable_options;

  // One graph per device partition, keyed by device name.
  std::unordered_map<string, std::unique_ptr<Graph>> graphs;
  TF_RETURN_IF_ERROR(CreateGraphs(
      options, &graphs, &func_info->flib_def, run_state_args, &ek->input_types,
      &ek->output_types, &ek->collective_graph_key));

  if (run_state_args->is_partial_run) {
    // Keep the full graph and index its feed/fetch nodes by name so that
    // CheckFetch() can walk it later.
    ek->graph = std::move(run_state_args->graph);
    std::unordered_set<StringPiece, StringPieceHasher> names;
    for (const string& input : callable_options.feed()) {
      TensorId id(ParseTensorName(input));
      names.emplace(id.first);
    }
    for (const string& output : callable_options.fetch()) {
      TensorId id(ParseTensorName(output));
      names.emplace(id.first);
    }
    for (Node* n : ek->graph->nodes()) {
      if (names.count(n->name()) > 0) {
        ek->name_to_node.insert({n->name(), n});
      }
    }
  }
  // Reserved up front so the `item` pointer taken below via back() is not
  // invalidated by later resize() calls.
  ek->items.reserve(graphs.size());
  const auto& optimizer_opts =
      options_.config.graph_options().optimizer_options();

  // NOTE(review): assumes CreateGraphs() produced at least one partition;
  // `graphs.begin()` is dereferenced here without an emptiness check.
  int graph_def_version = graphs.begin()->second->versions().producer();

  const auto* session_metadata =
      options_.config.experimental().has_session_metadata()
          ? &options_.config.experimental().session_metadata()
          : nullptr;
  func_info->proc_flr.reset(new ProcessFunctionLibraryRuntime(
      device_mgr_.get(), options_.env, &options_.config, graph_def_version,
      func_info->flib_def.get(), optimizer_opts, thread_pools_[0].first,
      /*parent=*/nullptr, session_metadata,
      Rendezvous::Factory{
          [](const int64_t, const DeviceMgr* device_mgr, Rendezvous** r) {
            *r = new IntraProcessRendezvous(device_mgr);
            return OkStatus();
          }}));

  GraphOptimizer optimizer(optimizer_opts);
  for (auto iter = graphs.begin(); iter != graphs.end(); ++iter) {
    const string& partition_name = iter->first;
    std::unique_ptr<Graph>& partition_graph = iter->second;

    Device* device;
    TF_RETURN_IF_ERROR(device_mgr_->LookupDevice(partition_name, &device));

    ek->items.resize(ek->items.size() + 1);
    auto* item = &(ek->items.back());
    auto lib = func_info->proc_flr->GetFLR(partition_name);
    if (lib == nullptr) {
      return errors::Internal("Could not find device: ", partition_name);
    }
    item->flib = lib;

    LocalExecutorParams params;
    params.device = device;
    params.session_metadata = session_metadata;
    params.function_library = lib;
    auto opseg = device->op_segment();
    params.create_kernel =
        [this, lib, opseg](const std::shared_ptr<const NodeProperties>& props,
                           OpKernel** kernel) {
          // NOTE(mrry): We must not share function kernels (implemented
          // using `CallOp`) between subgraphs, because `CallOp::handle_`
          // is tied to a particular subgraph. Even if the function itself
          // is stateful, the `CallOp` that invokes it is not.
          if (!OpSegment::ShouldOwnKernel(lib, props->node_def.op())) {
            return lib->CreateKernel(props, kernel);
          }
          auto create_fn = [lib, &props](OpKernel** kernel) {
            return lib->CreateKernel(props, kernel);
          };
          // Kernels created for subgraph nodes need to be cached.  On
          // cache miss, create_fn() is invoked to create a kernel based
          // on the function library here + global op registry.
          return opseg->FindOrCreate(session_handle_, props->node_def.name(),
                                     kernel, create_fn);
        };
    // Only kernels that the op segment does not own are deleted here; owned
    // kernels are left for the op segment to manage.
    params.delete_kernel = [lib](OpKernel* kernel) {
      if (kernel && !OpSegment::ShouldOwnKernel(lib, kernel->type_string()))
        delete kernel;
    };

    optimizer.Optimize(lib, options_.env, device, &partition_graph,
                       GraphOptimizer::Options());

    // TensorFlow Debugger (tfdbg) inserts debug nodes in the graph.
    const DebugOptions& debug_options =
        options.callable_options.run_options().debug_options();
    if (!debug_options.debug_tensor_watch_opts().empty()) {
      TF_RETURN_IF_ERROR(DecorateAndPublishGraphForDebug(
          debug_options, partition_graph.get(), params.device));
    }

    TF_RETURN_IF_ERROR(EnsureMemoryTypes(DeviceType(device->device_type()),
                                         device->name(),
                                         partition_graph.get()));

    item->executor = nullptr;
    item->device = device;
    auto executor_type = options_.config.experimental().executor_type();
    TF_RETURN_IF_ERROR(
        NewExecutor(executor_type, params, *partition_graph, &item->executor));
    // Keep the partition graph only when something reads it later (partition
    // graph output or the cost model); otherwise it is destroyed along with
    // the local `graphs` map when this function returns.
    if (!options_.config.experimental().disable_output_partition_graphs() ||
        options_.config.graph_options().build_cost_model() > 0) {
      item->graph = std::move(partition_graph);
    }
  }

  // Cache the mapping from input/output names to graph elements to
  // avoid recomputing it every time.
  if (!run_state_args->is_partial_run) {
    // For regular `Run()`, we use the function calling convention, and so
    // maintain a mapping from input/output names to
    // argument/return-value ordinal index.
    for (int i = 0; i < callable_options.feed().size(); ++i) {
      const string& input = callable_options.feed(i);
      ek->input_name_to_index[input] = i;
    }
    for (int i = 0; i < callable_options.fetch().size(); ++i) {
      const string& output = callable_options.fetch(i);
      ek->output_name_to_index[output] = i;
    }
  } else {
    // For `PRun()`, we use the rendezvous calling convention, and so
    // maintain a mapping from input/output names to rendezvous keys.
    //
    // We always use the first device as the device name portion of the
    // key, even if we're feeding another graph.
    for (int i = 0; i < callable_options.feed().size(); ++i) {
      const string& input = callable_options.feed(i);
      ek->input_name_to_rendezvous_key[input] = GetRendezvousKey(
          input, device_set_.client_device()->attributes(), FrameAndIter(0, 0));
    }
    for (int i = 0; i < callable_options.fetch().size(); ++i) {
      const string& output = callable_options.fetch(i);
      ek->output_name_to_rendezvous_key[output] =
          GetRendezvousKey(output, device_set_.client_device()->attributes(),
                           FrameAndIter(0, 0));
    }
  }

  *out_executors_and_keys = std::move(ek);
  *out_func_info = std::move(func_info);
  return OkStatus();
}
1480
1481Status DirectSession::GetOrCreateExecutors(
1482 gtl::ArraySlice<string> inputs, gtl::ArraySlice<string> outputs,
1483 gtl::ArraySlice<string> target_nodes, ExecutorsAndKeys** executors_and_keys,
1484 RunStateArgs* run_state_args) {
1485 int64_t handle_name_counter_value = -1;
1486 if (LogMemory::IsEnabled() || run_state_args->is_partial_run) {
1487 handle_name_counter_value = handle_name_counter_.fetch_add(1);
1488 }
1489
1490 string debug_tensor_watches_summary;
1491 if (!run_state_args->debug_options.debug_tensor_watch_opts().empty()) {
1492 debug_tensor_watches_summary = SummarizeDebugTensorWatches(
1493 run_state_args->debug_options.debug_tensor_watch_opts());
1494 }
1495
1496 // Fast lookup path, no sorting.
1497 const string key = strings::StrCat(
1498 absl::StrJoin(inputs, ","), "->", absl::StrJoin(outputs, ","), "/",
1499 absl::StrJoin(target_nodes, ","), "/", run_state_args->is_partial_run,
1500 "/", debug_tensor_watches_summary);
1501 // Set the handle, if it's needed to log memory or for partial run.
1502 if (handle_name_counter_value >= 0) {
1503 run_state_args->handle =
1504 strings::StrCat(key, ";", handle_name_counter_value);
1505 }
1506
1507 // See if we already have the executors for this run.
1508 {
1509 mutex_lock l(executor_lock_); // could use reader lock
1510 auto it = executors_.find(key);
1511 if (it != executors_.end()) {
1512 *executors_and_keys = it->second.get();
1513 return OkStatus();
1514 }
1515 }
1516
1517 // Slow lookup path, the unsorted key missed the cache.
1518 // Sort the inputs and outputs, and look up with the sorted key in case an
1519 // earlier call used a different order of inputs and outputs.
1520 //
1521 // We could consider some other signature instead of sorting that
1522 // preserves the same property to avoid the sort in the future.
1523 std::vector<string> inputs_sorted(inputs.begin(), inputs.end());
1524 std::sort(inputs_sorted.begin(), inputs_sorted.end());
1525 std::vector<string> outputs_sorted(outputs.begin(), outputs.end());
1526 std::sort(outputs_sorted.begin(), outputs_sorted.end());
1527 std::vector<string> tn_sorted(target_nodes.begin(), target_nodes.end());
1528 std::sort(tn_sorted.begin(), tn_sorted.end());
1529
1530 const string sorted_key = strings::StrCat(
1531 absl::StrJoin(inputs_sorted, ","), "->",
1532 absl::StrJoin(outputs_sorted, ","), "/", absl::StrJoin(tn_sorted, ","),
1533 "/", run_state_args->is_partial_run, "/", debug_tensor_watches_summary);
1534 // Set the handle, if its needed to log memory or for partial run.
1535 if (handle_name_counter_value >= 0) {
1536 run_state_args->handle =
1537 strings::StrCat(sorted_key, ";", handle_name_counter_value);
1538 }
1539
1540 // See if we already have the executors for this run.
1541 {
1542 mutex_lock l(executor_lock_);
1543 auto it = executors_.find(sorted_key);
1544 if (it != executors_.end()) {
1545 *executors_and_keys = it->second.get();
1546 return OkStatus();
1547 }
1548 }
1549
1550 // Nothing found, so create the executors and store in the cache.
1551 // The executor_lock_ is intentionally released while executors are
1552 // being created.
1553 CallableOptions callable_options;
1554 callable_options.mutable_feed()->Reserve(inputs_sorted.size());
1555 for (const string& input : inputs_sorted) {
1556 callable_options.add_feed(input);
1557 }
1558 callable_options.mutable_fetch()->Reserve(outputs_sorted.size());
1559 for (const string& output : outputs_sorted) {
1560 callable_options.add_fetch(output);
1561 }
1562 callable_options.mutable_target()->Reserve(tn_sorted.size());
1563 for (const string& target : tn_sorted) {
1564 callable_options.add_target(target);
1565 }
1566 *callable_options.mutable_run_options()->mutable_debug_options() =
1567 run_state_args->debug_options;
1568 callable_options.mutable_run_options()
1569 ->mutable_experimental()
1570 ->set_collective_graph_key(run_state_args->collective_graph_key);
1571 std::unique_ptr<ExecutorsAndKeys> ek;
1572 std::unique_ptr<FunctionInfo> func_info;
1573 TF_RETURN_IF_ERROR(
1574 CreateExecutors(callable_options, &ek, &func_info, run_state_args));
1575
1576 // Reacquire the lock, try to insert into the map.
1577 mutex_lock l(executor_lock_);
1578
1579 // Another thread may have created the entry before us, in which case we will
1580 // reuse the already created one.
1581 auto insert_result = executors_.emplace(
1582 sorted_key, std::shared_ptr<ExecutorsAndKeys>(std::move(ek)));
1583 if (insert_result.second) {
1584 functions_.push_back(std::move(func_info));
1585 }
1586
1587 // Insert the value under the original key, so the fast path lookup will work
1588 // if the user uses the same order of inputs, outputs, and targets again.
1589 executors_.emplace(key, insert_result.first->second);
1590 *executors_and_keys = insert_result.first->second.get();
1591
1592 return OkStatus();
1593}
1594
1595Status DirectSession::CreateGraphs(
1596 const BuildGraphOptions& subgraph_options,
1597 std::unordered_map<string, std::unique_ptr<Graph>>* outputs,
1598 std::unique_ptr<FunctionLibraryDefinition>* flib_def,
1599 RunStateArgs* run_state_args, DataTypeVector* input_types,
1600 DataTypeVector* output_types, int64_t* collective_graph_key) {
1601 mutex_lock l(graph_state_lock_);
1602 if (finalized_) {
1603 return errors::FailedPrecondition("Session has been finalized.");
1604 }
1605
1606 std::unique_ptr<ClientGraph> client_graph;
1607
1608 std::unique_ptr<GraphExecutionState> temp_exec_state_holder;
1609 GraphExecutionState* execution_state = nullptr;
1610 if (options_.config.graph_options().place_pruned_graph()) {
1611 // Because we are placing pruned graphs, we need to create a
1612 // new GraphExecutionState for every new unseen graph,
1613 // and then place it.
1614 GraphExecutionStateOptions prune_options;
1615 prune_options.device_set = &device_set_;
1616 prune_options.session_options = &options_;
1617 prune_options.stateful_placements = stateful_placements_;
1618 prune_options.session_handle = session_handle_;
1619 TF_RETURN_IF_ERROR(GraphExecutionState::MakeForPrunedGraph(
1620 *execution_state_, prune_options, subgraph_options,
1621 &temp_exec_state_holder, &client_graph));
1622 execution_state = temp_exec_state_holder.get();
1623 } else {
1624 execution_state = execution_state_.get();
1625 TF_RETURN_IF_ERROR(
1626 execution_state->BuildGraph(subgraph_options, &client_graph));
1627 }
1628 *collective_graph_key = client_graph->collective_graph_key;
1629
1630 if (subgraph_options.callable_options.feed_size() !=
1631 client_graph->feed_types.size()) {
1632 return errors::Internal(
1633 "Graph pruning failed: requested number of feed endpoints = ",
1634 subgraph_options.callable_options.feed_size(),
1635 " versus number of pruned feed endpoints = ",
1636 client_graph->feed_types.size());
1637 }
1638 if (subgraph_options.callable_options.fetch_size() !=
1639 client_graph->fetch_types.size()) {
1640 return errors::Internal(
1641 "Graph pruning failed: requested number of fetch endpoints = ",
1642 subgraph_options.callable_options.fetch_size(),
1643 " versus number of pruned fetch endpoints = ",
1644 client_graph->fetch_types.size());
1645 }
1646
1647 auto current_stateful_placements = execution_state->GetStatefulPlacements();
1648 // Update our current state based on the execution_state's
1649 // placements. If there are any mismatches for a node,
1650 // we should fail, as this should never happen.
1651 for (const auto& placement_pair : current_stateful_placements) {
1652 const string& node_name = placement_pair.first;
1653 const string& placement = placement_pair.second;
1654 auto iter = stateful_placements_.find(node_name);
1655 if (iter == stateful_placements_.end()) {
1656 stateful_placements_.insert(std::make_pair(node_name, placement));
1657 } else if (iter->second != placement) {
1658 return errors::Internal(
1659 "Stateful placement mismatch. "
1660 "Current assignment of ",
1661 node_name, " to ", iter->second, " does not match ", placement);
1662 }
1663 }
1664
1665 stateful_placements_ = execution_state->GetStatefulPlacements();
1666
1667 // Remember the graph in run state if this is a partial run.
1668 if (run_state_args->is_partial_run) {
1669 run_state_args->graph.reset(new Graph(flib_def_.get()));
1670 CopyGraph(*execution_state->full_graph(), run_state_args->graph.get());
1671 }
1672
1673 // Partition the graph across devices.
1674 PartitionOptions popts;
1675 popts.node_to_loc = [](const Node* node) {
1676 return node->assigned_device_name();
1677 };
1678 popts.new_name = [this](const string& prefix) {
1679 return strings::StrCat(prefix, "/_", edge_name_counter_.fetch_add(1));
1680 };
1681 popts.get_incarnation = [](const string& name) {
1682 // The direct session does not have changing incarnation numbers.
1683 // Just return '1'.
1684 return 1;
1685 };
1686 popts.flib_def = flib_def->get();
1687 popts.control_flow_added = false;
1688
1689 std::unordered_map<string, GraphDef> partitions;
1690 TF_RETURN_IF_ERROR(Partition(popts, &client_graph->graph, &partitions));
1691
1692 std::vector<string> device_names;
1693 for (auto device : devices_) {
1694 // Extract the LocalName from the device.
1695 device_names.push_back(DeviceNameUtils::LocalName(device->name()));
1696 }
1697
1698 // Check for valid partitions.
1699 for (const auto& partition : partitions) {
1700 const string local_partition_name =
1701 DeviceNameUtils::LocalName(partition.first);
1702 if (std::count(device_names.begin(), device_names.end(),
1703 local_partition_name) == 0) {
1704 return errors::InvalidArgument(
1705 "Creating a partition for ", local_partition_name,
1706 " which doesn't exist in the list of available devices. Available "
1707 "devices: ",
1708 absl::StrJoin(device_names, ","));
1709 }
1710 }
1711
1712 for (auto& partition : partitions) {
1713 std::unique_ptr<Graph> device_graph(
1714 new Graph(client_graph->flib_def.get()));
1715 device_graph->SetConstructionContext(ConstructionContext::kDirectSession);
1716 GraphConstructorOptions device_opts;
1717 // There are internal operations (e.g., send/recv) that we now allow.
1718 device_opts.allow_internal_ops = true;
1719 device_opts.expect_device_spec = true;
1720 TF_RETURN_IF_ERROR(ConvertGraphDefToGraph(
1721 device_opts, std::move(partition.second), device_graph.get()));
1722 outputs->emplace(partition.first, std::move(device_graph));
1723 }
1724
1725 GraphOptimizationPassOptions optimization_options;
1726 optimization_options.session_options = &options_;
1727 optimization_options.flib_def = client_graph->flib_def.get();
1728 optimization_options.partition_graphs = outputs;
1729 TF_RETURN_IF_ERROR(OptimizationPassRegistry::Global()->RunGrouping(
1730 OptimizationPassRegistry::POST_PARTITIONING, optimization_options));
1731
1732 Status s;
1733 for (auto& partition : *outputs) {
1734 const string& partition_name = partition.first;
1735 std::unique_ptr<Graph>* graph = &partition.second;
1736
1737 VLOG(2) << "Created " << DebugString(graph->get()) << " for "
1738 << partition_name;
1739
1740 // Give the device an opportunity to rewrite its subgraph.
1741 Device* d;
1742 s = device_mgr_->LookupDevice(partition_name, &d);
1743 if (!s.ok()) break;
1744 s = d->MaybeRewriteGraph(graph);
1745 if (!s.ok()) {
1746 break;
1747 }
1748 }
1749 *flib_def = std::move(client_graph->flib_def);
1750 std::swap(*input_types, client_graph->feed_types);
1751 std::swap(*output_types, client_graph->fetch_types);
1752 return s;
1753}
1754
1755::tensorflow::Status DirectSession::ListDevices(
1756 std::vector<DeviceAttributes>* response) {
1757 response->clear();
1758 response->reserve(devices_.size());
1759 for (Device* d : devices_) {
1760 const DeviceAttributes& attrs = d->attributes();
1761 response->emplace_back(attrs);
1762 }
1763 return OkStatus();
1764}
1765
1766::tensorflow::Status DirectSession::Reset(
1767 const std::vector<string>& containers) {
1768 device_mgr_->ClearContainers(containers);
1769 return OkStatus();
1770}
1771
1772::tensorflow::Status DirectSession::Close() {
1773 cancellation_manager_->StartCancel();
1774 {
1775 mutex_lock l(closed_lock_);
1776 if (closed_) return OkStatus();
1777 closed_ = true;
1778 }
1779 if (factory_ != nullptr) factory_->Deregister(this);
1780 return OkStatus();
1781}
1782
1783DirectSession::RunState::RunState(int64_t step_id,
1784 const std::vector<Device*>* devices)
1785 : step_container(step_id, [devices, step_id](const string& name) {
1786 for (auto d : *devices) {
1787 if (!d->resource_manager()->Cleanup(name).ok()) {
1788 // Do nothing...
1789 }
1790 ScopedAllocatorMgr* sam = d->GetScopedAllocatorMgr();
1791 if (sam) sam->Cleanup(step_id);
1792 }
1793 }) {}
1794
1795DirectSession::PartialRunState::PartialRunState(
1796 const std::vector<string>& pending_input_names,
1797 const std::vector<string>& pending_output_names, int64_t step_id,
1798 const std::vector<Device*>* devices)
1799 : RunState(step_id, devices) {
1800 // Initially all the feeds and fetches are pending.
1801 for (auto& name : pending_input_names) {
1802 pending_inputs[name] = false;
1803 }
1804 for (auto& name : pending_output_names) {
1805 pending_outputs[name] = false;
1806 }
1807}
1808
1809DirectSession::PartialRunState::~PartialRunState() {
1810 if (rendez != nullptr) {
1811 rendez->StartAbort(errors::Cancelled("PRun cancellation"));
1812 executors_done.WaitForNotification();
1813 }
1814}
1815
1816bool DirectSession::PartialRunState::PendingDone() const {
1817 for (const auto& it : pending_inputs) {
1818 if (!it.second) return false;
1819 }
1820 for (const auto& it : pending_outputs) {
1821 if (!it.second) return false;
1822 }
1823 return true;
1824}
1825
1826void DirectSession::WaitForNotification(Notification* n, RunState* run_state,
1827 CancellationManager* cm,
1828 int64_t timeout_in_ms) {
1829 const Status status = WaitForNotification(n, timeout_in_ms);
1830 if (!status.ok()) {
1831 {
1832 mutex_lock l(run_state->mu);
1833 run_state->status.Update(status);
1834 }
1835 cm->StartCancel();
1836 // We must wait for the executors to complete, because they have borrowed
1837 // references to `cm` and other per-step state. After this notification, it
1838 // is safe to clean up the step.
1839 n->WaitForNotification();
1840 }
1841}
1842
1843::tensorflow::Status DirectSession::WaitForNotification(
1844 Notification* notification, int64_t timeout_in_ms) {
1845 if (timeout_in_ms > 0) {
1846 const int64_t timeout_in_us = timeout_in_ms * 1000;
1847 const bool notified =
1848 WaitForNotificationWithTimeout(notification, timeout_in_us);
1849 if (!notified) {
1850 return Status(error::DEADLINE_EXCEEDED,
1851 "Timed out waiting for notification");
1852 }
1853 } else {
1854 notification->WaitForNotification();
1855 }
1856 return OkStatus();
1857}
1858
1859Status DirectSession::MakeCallable(const CallableOptions& callable_options,
1860 CallableHandle* out_handle) {
1861 TF_RETURN_IF_ERROR(CheckNotClosed());
1862 TF_RETURN_IF_ERROR(CheckGraphCreated("MakeCallable()"));
1863
1864 std::unique_ptr<ExecutorsAndKeys> ek;
1865 std::unique_ptr<FunctionInfo> func_info;
1866 RunStateArgs run_state_args(callable_options.run_options().debug_options());
1867 TF_RETURN_IF_ERROR(
1868 CreateExecutors(callable_options, &ek, &func_info, &run_state_args));
1869 {
1870 mutex_lock l(callables_lock_);
1871 *out_handle = next_callable_handle_++;
1872 callables_[*out_handle] = {std::move(ek), std::move(func_info)};
1873 }
1874 return OkStatus();
1875}
1876
1877class DirectSession::RunCallableCallFrame : public CallFrameInterface {
1878 public:
1879 RunCallableCallFrame(DirectSession* session,
1880 ExecutorsAndKeys* executors_and_keys,
1881 const std::vector<Tensor>* feed_tensors,
1882 std::vector<Tensor>* fetch_tensors)
1883 : session_(session),
1884 executors_and_keys_(executors_and_keys),
1885 feed_tensors_(feed_tensors),
1886 fetch_tensors_(fetch_tensors) {}
1887
1888 size_t num_args() const override {
1889 return executors_and_keys_->input_types.size();
1890 }
1891 size_t num_retvals() const override {
1892 return executors_and_keys_->output_types.size();
1893 }
1894
1895 Status GetArg(int index, const Tensor** val) override {
1896 if (TF_PREDICT_FALSE(index > feed_tensors_->size())) {
1897 return errors::Internal("Args index out of bounds: ", index);
1898 } else {
1899 *val = &(*feed_tensors_)[index];
1900 }
1901 return OkStatus();
1902 }
1903
1904 Status SetRetval(int index, const Tensor& val) override {
1905 if (index > fetch_tensors_->size()) {
1906 return errors::Internal("RetVal index out of bounds: ", index);
1907 }
1908 (*fetch_tensors_)[index] = val;
1909 return OkStatus();
1910 }
1911
1912 private:
1913 DirectSession* const session_; // Not owned.
1914 ExecutorsAndKeys* const executors_and_keys_; // Not owned.
1915 const std::vector<Tensor>* const feed_tensors_; // Not owned.
1916 std::vector<Tensor>* const fetch_tensors_; // Not owned.
1917};
1918
1919::tensorflow::Status DirectSession::RunCallable(
1920 CallableHandle handle, const std::vector<Tensor>& feed_tensors,
1921 std::vector<Tensor>* fetch_tensors, RunMetadata* run_metadata) {
1922 return RunCallable(handle, feed_tensors, fetch_tensors, run_metadata,
1923 thread::ThreadPoolOptions());
1924}
1925
1926::tensorflow::Status DirectSession::RunCallable(
1927 CallableHandle handle, const std::vector<Tensor>& feed_tensors,
1928 std::vector<Tensor>* fetch_tensors, RunMetadata* run_metadata,
1929 const thread::ThreadPoolOptions& threadpool_options) {
1930 TF_RETURN_IF_ERROR(CheckNotClosed());
1931 TF_RETURN_IF_ERROR(CheckGraphCreated("RunCallable()"));
1932 direct_session_runs->GetCell()->IncrementBy(1);
1933
1934 // Check if we already have an executor for these arguments.
1935 std::shared_ptr<ExecutorsAndKeys> executors_and_keys;
1936 const int64_t step_id = step_id_counter_.fetch_add(1);
1937
1938 {
1939 tf_shared_lock l(callables_lock_);
1940 if (handle >= next_callable_handle_) {
1941 return errors::InvalidArgument("No such callable handle: ", handle);
1942 }
1943 executors_and_keys = callables_[handle].executors_and_keys;
1944 }
1945
1946 if (!executors_and_keys) {
1947 return errors::InvalidArgument(
1948 "Attempted to run callable after handle was released: ", handle);
1949 }
1950
1951 // NOTE(mrry): Debug options are not currently supported in the
1952 // callable interface.
1953 DebugOptions debug_options;
1954 RunStateArgs run_state_args(debug_options);
1955
1956 // Configure a call frame for the step, which we use to feed and
1957 // fetch values to and from the executors.
1958 if (feed_tensors.size() != executors_and_keys->input_types.size()) {
1959 return errors::InvalidArgument(
1960 "Expected ", executors_and_keys->input_types.size(),
1961 " feed tensors, but got ", feed_tensors.size());
1962 }
1963 if (fetch_tensors != nullptr) {
1964 fetch_tensors->resize(executors_and_keys->output_types.size());
1965 } else if (!executors_and_keys->output_types.empty()) {
1966 return errors::InvalidArgument(
1967 "`fetch_tensors` must be provided when the callable has one or more "
1968 "outputs.");
1969 }
1970
1971 size_t input_size = 0;
1972 bool any_resource_feeds = false;
1973 for (auto& tensor : feed_tensors) {
1974 input_size += tensor.AllocatedBytes();
1975 any_resource_feeds = any_resource_feeds || tensor.dtype() == DT_RESOURCE;
1976 }
1977 metrics::RecordGraphInputTensors(input_size);
1978
1979 std::unique_ptr<std::vector<Tensor>> converted_feed_tensors;
1980 const std::vector<Tensor>* actual_feed_tensors;
1981
1982 if (TF_PREDICT_FALSE(any_resource_feeds)) {
1983 converted_feed_tensors = std::make_unique<std::vector<Tensor>>();
1984 converted_feed_tensors->reserve(feed_tensors.size());
1985 for (const Tensor& t : feed_tensors) {
1986 if (t.dtype() == DT_RESOURCE) {
1987 converted_feed_tensors->emplace_back();
1988 Tensor* tensor_from_handle = &converted_feed_tensors->back();
1989 TF_RETURN_IF_ERROR(ResourceHandleToInputTensor(t, tensor_from_handle));
1990 } else {
1991 converted_feed_tensors->emplace_back(t);
1992 }
1993 }
1994 actual_feed_tensors = converted_feed_tensors.get();
1995 } else {
1996 actual_feed_tensors = &feed_tensors;
1997 }
1998
1999 // A specialized CallFrame implementation that takes advantage of the
2000 // optimized RunCallable interface.
2001 RunCallableCallFrame call_frame(this, executors_and_keys.get(),
2002 actual_feed_tensors, fetch_tensors);
2003
2004 if (LogMemory::IsEnabled()) {
2005 LogMemory::RecordStep(step_id, run_state_args.handle);
2006 }
2007
2008 TF_RETURN_IF_ERROR(RunInternal(
2009 step_id, executors_and_keys->callable_options.run_options(), &call_frame,
2010 executors_and_keys.get(), run_metadata, threadpool_options));
2011
2012 if (fetch_tensors != nullptr) {
2013 size_t output_size = 0;
2014 for (auto& tensor : *fetch_tensors) {
2015 output_size += tensor.AllocatedBytes();
2016 }
2017 metrics::RecordGraphOutputTensors(output_size);
2018 }
2019
2020 return OkStatus();
2021}
2022
2023::tensorflow::Status DirectSession::ReleaseCallable(CallableHandle handle) {
2024 mutex_lock l(callables_lock_);
2025 if (handle >= next_callable_handle_) {
2026 return errors::InvalidArgument("No such callable handle: ", handle);
2027 }
2028 callables_.erase(handle);
2029 return OkStatus();
2030}
2031
2032Status DirectSession::Finalize() {
2033 mutex_lock l(graph_state_lock_);
2034 if (finalized_) {
2035 return errors::FailedPrecondition("Session already finalized.");
2036 }
2037 if (!graph_created_) {
2038 return errors::FailedPrecondition("Session not yet created.");
2039 }
2040 execution_state_.reset();
2041 flib_def_.reset();
2042 finalized_ = true;
2043 return OkStatus();
2044}
2045
// Destroys a registered callable, releasing its executors before the
// function-library state they depend on.
DirectSession::Callable::~Callable() {
  // We must delete the fields in this order, because the destructor
  // of `executors_and_keys` will call into an object owned by
  // `function_info` (in particular, when deleting a kernel, it relies
  // on the `FunctionLibraryRuntime` to know if the kernel is stateful
  // or not).
  // Relying on implicit member destruction would not guarantee this order,
  // so both fields are reset explicitly here.
  executors_and_keys.reset();
  function_info.reset();
}
2055
2056} // namespace tensorflow
2057