1/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2
3Licensed under the Apache License, Version 2.0 (the "License");
4you may not use this file except in compliance with the License.
5You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9Unless required by applicable law or agreed to in writing, software
10distributed under the License is distributed on an "AS IS" BASIS,
11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12See the License for the specific language governing permissions and
13limitations under the License.
14==============================================================================*/
15
16#include "tensorflow/core/common_runtime/executor.h"
17
18#include <algorithm>
19#include <atomic>
20#include <memory>
21#include <vector>
22
23#include "absl/memory/memory.h"
24#include "absl/time/time.h"
25#include "absl/types/optional.h"
26#include "tensorflow/core/activity_watcher/activity.h"
27#include "tensorflow/core/common_runtime/costmodel_manager.h"
28#include "tensorflow/core/common_runtime/entry.h"
29#include "tensorflow/core/common_runtime/executor_factory.h"
30#include "tensorflow/core/common_runtime/graph_view.h"
31#include "tensorflow/core/common_runtime/immutable_executor_state.h"
32#include "tensorflow/core/common_runtime/pending_counts.h"
33#include "tensorflow/core/common_runtime/propagator_state.h"
34#include "tensorflow/core/common_runtime/renamed_device.h"
35#include "tensorflow/core/common_runtime/simple_propagator_state.h"
36#include "tensorflow/core/common_runtime/step_stats_collector.h"
37#include "tensorflow/core/framework/allocator.h"
38#include "tensorflow/core/framework/cancellation.h"
39#include "tensorflow/core/framework/collective.h"
40#include "tensorflow/core/framework/control_flow.h"
41#include "tensorflow/core/framework/device_attributes.pb.h"
42#include "tensorflow/core/framework/log_memory.h"
43#include "tensorflow/core/framework/metrics.h"
44#include "tensorflow/core/framework/node_def_util.h"
45#include "tensorflow/core/framework/op_kernel.h"
46#include "tensorflow/core/framework/op_segment.h"
47#include "tensorflow/core/framework/tensor.h"
48#include "tensorflow/core/framework/tensor_reference.h"
49#include "tensorflow/core/framework/types.h"
50#include "tensorflow/core/framework/types.pb.h"
51#include "tensorflow/core/graph/edgeset.h"
52#include "tensorflow/core/graph/graph.h"
53#include "tensorflow/core/graph/graph_node_util.h"
54#include "tensorflow/core/lib/core/errors.h"
55#include "tensorflow/core/lib/core/notification.h"
56#include "tensorflow/core/lib/core/status.h"
57#include "tensorflow/core/lib/core/threadpool.h"
58#include "tensorflow/core/lib/gtl/flatmap.h"
59#include "tensorflow/core/lib/gtl/inlined_vector.h"
60#include "tensorflow/core/lib/gtl/manual_constructor.h"
61#include "tensorflow/core/lib/hash/hash.h"
62#include "tensorflow/core/platform/context.h"
63#include "tensorflow/core/platform/env.h"
64#include "tensorflow/core/platform/errors.h"
65#include "tensorflow/core/platform/logging.h"
66#include "tensorflow/core/platform/macros.h"
67#include "tensorflow/core/platform/mutex.h"
68#include "tensorflow/core/platform/profile_utils/cpu_utils.h"
69#include "tensorflow/core/platform/status.h"
70#include "tensorflow/core/platform/thread_annotations.h"
71#include "tensorflow/core/platform/tracing.h"
72#include "tensorflow/core/platform/types.h"
73#include "tensorflow/core/profiler/lib/annotated_traceme.h"
74#include "tensorflow/core/profiler/lib/connected_traceme.h"
75#include "tensorflow/core/profiler/lib/scoped_annotation.h"
76#include "tensorflow/core/profiler/lib/traceme.h"
77#include "tensorflow/core/profiler/lib/traceme_encode.h"
78#include "tensorflow/core/protobuf/error_codes.pb.h"
79#include "tensorflow/core/util/determinism.h"
80#include "tensorflow/core/util/managed_stack_trace.h"
81#include "tensorflow/core/util/tensor_slice_reader_cache.h"
82
83namespace tensorflow {
84
85namespace {
86
87// 1-D, 0 element tensor.
88static const Tensor* const kEmptyTensor = new Tensor;
89
90// Helper routines for collecting step stats.
91namespace nodestats {
92inline int64_t NowInNsec() { return EnvTime::NowNanos(); }
93
94void SetScheduled(NodeExecStatsInterface* stats, int64_t micros) {
95 if (!stats) return;
96 stats->SetScheduled(micros * EnvTime::kMicrosToNanos);
97}
98
99void SetAllStart(NodeExecStatsInterface* stats) {
100 if (!stats) return;
101 stats->RecordExecutorStarted();
102}
103
104void SetOpStart(NodeExecStatsInterface* stats) {
105 if (!stats) return;
106 stats->RecordComputeStarted();
107}
108
109void SetOpEnd(NodeExecStatsInterface* stats) {
110 if (!stats) return;
111 stats->RecordComputeEnded();
112}
113
114void SetAllEnd(NodeExecStatsInterface* stats) {
115 if (!stats) return;
116 stats->RecordExecutorEnded();
117}
118
119void SetOutput(NodeExecStatsInterface* stats, int slot, const Tensor* v) {
120 if (!stats) return;
121 stats->SetOutput(slot, v);
122}
123
124void SetMemory(NodeExecStatsInterface* stats, OpKernelContext* ctx) {
125 if (!stats) return;
126 stats->SetMemory(ctx);
127}
128
129} // namespace nodestats
130
131// Time the execution of kernels (in CPU cycles). Used to dynamically identify
132// inexpensive kernels which can be dispatched inline.
133struct KernelTimer {
134 uint64 start_cycles = profile_utils::CpuUtils::GetCurrentClockCycle();
135
136 uint64 ElapsedCycles() {
137 return profile_utils::CpuUtils::GetCurrentClockCycle() - start_cycles;
138 }
139};
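
// A minimal usage sketch (this is how the timer is used in ProcessSync()
// below), shown here as an illustrative comment only:
//
//   KernelTimer timer;
//   device->Compute(op_kernel, &ctx);
//   kernel_stats_->UpdateCostEstimate(item, timer.ElapsedCycles());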
140
141// TODO(b/152925936): Re-evaluate these constants with current usage patterns.
142typedef gtl::InlinedVector<TensorValue, 4> TensorValueVec;
143typedef gtl::InlinedVector<AllocatorAttributes, 4> AllocatorAttributeVec;
144
145class ExecutorImpl : public Executor {
146 public:
147 explicit ExecutorImpl(const LocalExecutorParams& p) : immutable_state_(p) {}
148
149 Status Initialize(const Graph& graph) {
150 TF_RETURN_IF_ERROR(immutable_state_.Initialize(graph));
151 kernel_stats_.Initialize(immutable_state_.graph_view());
152 return OkStatus();
153 }
154
155 void RunAsync(const Args& args, DoneCallback done) override;
156
157 private:
158 template <class PropagatorStateType>
159 friend class ExecutorState;
160
161 // Stores execution time information about the kernels in an executor's graph.
162 class KernelStats {
163 public:
164 KernelStats() = default;
165
166 void Initialize(const GraphView& gview) {
167 is_expensive_.resize(gview.num_nodes());
168 cost_estimates_ =
169 std::make_unique<std::atomic_uint_fast64_t[]>(gview.num_nodes());
170 for (int32_t i = 0; i < gview.num_nodes(); ++i) {
171 if (gview.node(i)) {
172 is_expensive_[i] =
173 gview.node(i)->kernel && gview.node(i)->kernel->IsExpensive();
174 cost_estimates_[i] = kInitialCostEstimateCycles;
175 }
176 }
177 }
178
179 // Returns true iff the given node is considered "expensive". The
180 // executor uses this flag to optimize graph execution, for example
181 // by "inlining" inexpensive kernels.
182 bool IsExpensive(const NodeItem& node) const {
183 return is_expensive_[node.node_id] &&
184 (cost_estimates_[node.node_id].load(std::memory_order_relaxed) >
185 kOpIsExpensiveThresholdCycles);
186 }
187
188 // Returns the value of kernel->IsExpensive().
189 bool HasExpensiveMarker(const NodeItem& node) const {
190 return is_expensive_[node.node_id];
191 }
192
    // Updates the dynamic cost estimate, which is used to determine whether the
    // given node is expensive. The new cost estimate is a weighted average of
    // the old cost estimate and the latest cost. We only update cost estimates
    // for kernels for which IsExpensive() returns true.
197 void UpdateCostEstimate(const NodeItem& node, uint64 elapsed_cycles) {
198 // N.B. Updates to `cost_estimate` are atomic but unlocked. Simultaneous
199 // updates may result in one or more updates being ignored. This does not
200 // affect correctness but may slow down the update frequency.
201 std::atomic_uint_fast64_t& cost_estimate = cost_estimates_[node.node_id];
202 auto prev_estimate = cost_estimate.load(std::memory_order_relaxed);
203
204 uint64 new_estimate =
205 ((kCostDecay - 1) * prev_estimate + elapsed_cycles) / kCostDecay;
206
207 cost_estimate.store(new_estimate, std::memory_order_relaxed);
208 }
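
    // Worked example of the decaying average above, using the constants
    // declared below (kCostDecay = 10, kInitialCostEstimateCycles = 1e8,
    // kOpIsExpensiveThresholdCycles = 8000):
    //
    //   prev_estimate = 100,000,000 cycles, elapsed_cycles = 5,000 cycles
    //   new_estimate  = (9 * 100,000,000 + 5,000) / 10 = 90,000,500 cycles
    //
    // Each update moves the estimate ~10% of the way toward the latest
    // measurement, so a kernel that consistently runs in ~5,000 cycles needs
    // on the order of 100 updates before its estimate drops below the 8,000
    // cycle threshold and IsExpensive() starts returning false.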
209
210 private:
211 // Initial time (in CPU cycles) we expect an operation to take. Used to
    // determine whether an operation should be placed in a threadpool.
213 // Operations start out "expensive".
214 static constexpr uint64 kInitialCostEstimateCycles = 100 * 1000 * 1000;
215 static constexpr uint64 kOpIsExpensiveThresholdCycles = 8000;
216 static constexpr uint64 kCostDecay = 10;
217
218 std::vector<bool> is_expensive_;
219 // std::unique_ptr<std::atomic<bool>[]> is_expensive_;
220 std::unique_ptr<std::atomic_uint_fast64_t[]> cost_estimates_;
221 };
222
223 ImmutableExecutorState immutable_state_;
224 KernelStats kernel_stats_;
225
226 TF_DISALLOW_COPY_AND_ASSIGN(ExecutorImpl);
227};
228
229// The state associated with one invocation of ExecutorImpl::Run.
230//
231// ExecutorState dispatches nodes when they become ready, and delegates to an
// instance of `PropagatorStateType` to keep track of how many predecessors of
// a node are still pending.
234//
235// The template argument `class PropagatorStateType` must define the following
236// public members:
237// * A type `TaggedNode`, representing a node to be processed, with public
238// members:
239// * `const NodeItem& get_node_item() const`
240// * `bool get_is_dead() const`
241// * A type `TaggedNodeReadyQueue`, representing a queue of nodes to be
242// processed, with public members (having the same meanings as in an
243// `std::vector<TaggedNode>`):
244// * `void push_back(const TaggedNode& node)`
245// * `TaggedNode front() const`
246// * `void pop_front()`
247// * `bool empty() const`
248// * A type `TaggedNodeSeq`, representing a list of nodes to be scheduled, with
249// public members (having the same meanings as in an
250// `std::vector<TaggedNode>`):
251// * `size_t size() const`
252// * `bool empty() const`
253// * `void clear()`
254// * `const_iterator begin() const`
255// * `const_iterator end() const`
256// * A public constructor, `PropagatorStateType(const ImmutableExecutorState&
257// immutable_state, int64 step_id)`.
258// * The following public methods:
259// * `void ActivateRoots(gtl::ArraySlice<const NodeItem*> roots,
260// TaggedNodeSeq* ready)`, which creates `TaggedNode` instances for the
261// nodes in `roots` and adds them to `*ready`
262// * `void PropagateOutputs(const TaggedNode& tagged_node, EntryVector*
263// outputs, TaggedNodeSeq* ready)`, which propagates `outputs` from the
264// given `tagged_node` to the destinations of its output edges, and adds
265// any newly runnable nodes to `*ready`
266// * `Entry* GetInputTensors(const TaggedNode& tagged_node) const`, which
267// returns a pointer to the input tensors for the given `tagged_node`
268// * `FrameAndIter GetFrameAndIter(const TaggedNode& tagged_node) const`,
269// which creates a `FrameAndIter` for the given `tagged_node`
270// * `void DumpState()`, which dumps the dynamic state of the executing graph
271// * `void MaybeMarkStarted(const TaggedNode& tagged_node)`, which records
272// that a node has started
273// * `void MaybeMarkCompleted(const TaggedNode& tagged_node)`, which records
274// that a node has completed
275//
// See `PropagatorState` in "./propagator_state.h" for an example of a type that
// can be used to instantiate `PropagatorStateType`; a minimal outline of the
// required interface is sketched below.
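//
// For illustration only, a conforming type has roughly the following shape
// (a hedged, non-functional outline, not an implementation; the names here
// are placeholders):
//
//   class MyPropagatorState {
//    public:
//     struct TaggedNode {
//       const NodeItem& get_node_item() const;
//       bool get_is_dead() const;
//     };
//     using TaggedNodeReadyQueue = /* deque-like container of TaggedNode */;
//     using TaggedNodeSeq = /* vector-like container of TaggedNode */;
//
//     MyPropagatorState(const ImmutableExecutorState& immutable_state,
//                       int64_t step_id);
//
//     void ActivateRoots(gtl::ArraySlice<const NodeItem*> roots,
//                        TaggedNodeSeq* ready);
//     void PropagateOutputs(const TaggedNode& tagged_node, EntryVector* outputs,
//                           TaggedNodeSeq* ready);
//     Entry* GetInputTensors(const TaggedNode& tagged_node) const;
//     FrameAndIter GetFrameAndIter(const TaggedNode& tagged_node) const;
//     void DumpState();
//     void MaybeMarkStarted(const TaggedNode& tagged_node);
//     void MaybeMarkCompleted(const TaggedNode& tagged_node);
//   };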
278template <class PropagatorStateType>
279class ExecutorState {
280 public:
281 ExecutorState(const Executor::Args& args,
282 const ImmutableExecutorState& immutable_state_,
283 ExecutorImpl::KernelStats* kernel_stats_);
284 ~ExecutorState();
285
286 void RunAsync(Executor::DoneCallback done);
287
288 private:
289 // Use `TaggedNode` types defined by `PropagatorStateType`.
290 typedef typename PropagatorStateType::TaggedNode TaggedNode;
291 typedef
292 typename PropagatorStateType::TaggedNodeReadyQueue TaggedNodeReadyQueue;
293 typedef typename PropagatorStateType::TaggedNodeSeq TaggedNodeSeq;
294
295 struct AsyncState;
296
  // Process a ready node in the current thread.
298 void Process(TaggedNode node, int64_t scheduled_nsec);
299
300 void ProcessInline(TaggedNodeReadyQueue* inline_ready,
301 int64_t scheduled_nsec);
302
303 Status ProcessSync(const NodeItem& item, OpKernelContext::Params* params,
304 EntryVector* outputs, NodeExecStatsInterface* stats);
305 void ProcessAsync(const NodeItem& item, const OpKernelContext::Params& params,
306 const TaggedNode& tagged_node, Entry* first_input,
307 NodeExecStatsInterface* stats,
308 activity_watcher::ActivityId activity_id);
309 void ProcessNoop(NodeExecStatsInterface* stats);
310 void ProcessConstTensor(const NodeItem& item, EntryVector* outputs,
311 NodeExecStatsInterface* stats);
312
313 // Before invoking item->kernel, fills in its "inputs".
314 Status PrepareInputs(const NodeItem& item, Entry* first_input,
315 TensorValueVec* inputs,
316 AllocatorAttributeVec* input_alloc_attrs,
317 bool* is_input_dead);
318
319 // After item->kernel computation is done, processes its outputs.
320 Status ProcessOutputs(const NodeItem& item, OpKernelContext* ctx,
321 Entry* outputs, NodeExecStatsInterface* stats);
322
323 // Called after each node finishes. Takes ownership of "stats". Returns true
324 // if execution has completed.
325 //
326 // This method will clear `*ready` before returning.
327 bool NodeDone(const Status& s, TaggedNodeSeq* ready,
328 NodeExecStatsInterface* stats,
329 TaggedNodeReadyQueue* inline_ready);
330
331 // Schedule all the expensive nodes in '*ready', and put all the inexpensive
332 // nodes in 'ready' into 'inline_ready'.
333 //
334 // This method will clear `*ready` before returning.
335 //
336 // REQUIRES: `!ready->empty()`.
337 void ScheduleReady(TaggedNodeSeq* ready, TaggedNodeReadyQueue* inline_ready);
338
339 // A wrapper for runner_ to keep track of the pending queue length. Op
340 // execution should dispatch work using this function instead of using runner_
341 // directly.
342 template <typename Closure>
343 void RunTask(Closure&& c, int sample_rate = 0);
344
345 // Clean up when this executor is done.
346 void Finish();
347 void ScheduleFinish();
348
349 // Contains the device context assigned by the device at the beginning of a
350 // step.
351 DeviceContext* device_context_ = nullptr;
352
353 const bool vlog_; // true if VLOG_IS_ON(1). Used to check vlog cheaply.
354
  // true if LogMemory::IsEnabled(). Used to check cheaply whether memory
  // logging is enabled.
356 const bool log_memory_;
357
358 int64_t step_id_;
359 int64_t start_time_usecs_ = 0;
360 // The deadline for the session to complete by. Empty if unspecified.
361 absl::optional<absl::Time> deadline_;
362
363 // Maximum number of kernels that can be scheduled inline. If lots of kernels
364 // are ready at the same time, scheduling them in one thread can be very slow.
365 // TODO(fishx): Make it configurable if necessary.
366 static constexpr uint64 kInlineScheduleReadyThreshold = 500;
367
368 // Not owned.
369 RendezvousInterface* rendezvous_;
370 CollectiveExecutor* collective_executor_ = nullptr;
371 SessionState* session_state_;
372 string session_handle_;
373 const SessionMetadata* session_metadata_ = nullptr;
374 TensorStore* tensor_store_;
375 // Step-local container.
376 ScopedStepContainer* step_container_;
377 StepStatsCollectorInterface* const stats_collector_;
378 const tracing::EventCollector* const event_collector_;
379 Context context_;
380
381 // QUESTION: Make it a checkpoint::TensorSliceReaderCacheWrapper
382 // instead of a pointer? (avoids having to delete).
383 checkpoint::TensorSliceReaderCacheWrapper* slice_reader_cache_;
384 CallFrameInterface* call_frame_;
385 const ImmutableExecutorState& immutable_state_;
386 ExecutorImpl::KernelStats* const kernel_stats_;
387 CancellationManager* cancellation_manager_;
388 CoordinationServiceAgent* coordination_service_agent_;
389 absl::optional<ManagedStackTrace> stack_trace_ = absl::nullopt;
  // If not null, use this device to schedule intra-op operations.
391 std::unique_ptr<DeviceBase> user_device_;
392 Executor::Args::Runner runner_;
393 bool sync_on_finish_;
394 const bool run_all_kernels_inline_;
395
396 PropagatorStateType propagator_;
397
398 // Invoked when the execution finishes.
399 Executor::DoneCallback done_cb_;
400
401 std::atomic_int_fast32_t num_outstanding_ops_;
402
403 // Available via OpKernelContext to every OpKernel invocation.
404 mutex num_deferred_ops_mu_;
405 int64_t num_deferred_ops_ TF_GUARDED_BY(num_deferred_ops_mu_) = 0;
406 bool finish_when_deferred_ops_done_ TF_GUARDED_BY(num_deferred_ops_mu_) =
407 false;
408
409 mutex mu_;
410 Status status_ TF_GUARDED_BY(mu_);
411};
412
413template <class PropagatorStateType>
414ExecutorState<PropagatorStateType>::ExecutorState(
415 const Executor::Args& args, const ImmutableExecutorState& immutable_state,
416 ExecutorImpl::KernelStats* kernel_stats)
417 : vlog_(VLOG_IS_ON(1)),
418 log_memory_(LogMemory::IsEnabled()),
419 step_id_(args.step_id),
420 start_time_usecs_(args.start_time_usecs),
421 deadline_(args.deadline),
422 rendezvous_(args.rendezvous),
423 collective_executor_(args.collective_executor),
424 session_state_(args.session_state),
425 session_handle_(args.session_handle),
426 session_metadata_(immutable_state.params().session_metadata),
427 tensor_store_(args.tensor_store),
428 step_container_(args.step_container),
429 stats_collector_(args.stats_collector),
430 event_collector_(
431 tracing::GetEventCollector(tracing::EventCategory::kCompute)),
432 context_(ContextKind::kThread),
433 slice_reader_cache_(new checkpoint::TensorSliceReaderCacheWrapper),
434 call_frame_(args.call_frame),
435 immutable_state_(immutable_state),
436 kernel_stats_(kernel_stats),
437 cancellation_manager_(args.cancellation_manager),
438 coordination_service_agent_(args.coordination_service_agent),
439 stack_trace_(args.stack_trace),
440 runner_(args.runner),
441 sync_on_finish_(args.sync_on_finish),
442 run_all_kernels_inline_(args.run_all_kernels_inline),
443 propagator_(immutable_state, step_id_, vlog_),
444 num_outstanding_ops_(0) {
445 if (args.user_intra_op_threadpool != nullptr) {
446 Device* device = immutable_state_.params().device;
447 user_device_ = RenamedDevice::NewRenamedDevice(
448 device->name(), device, false, false, args.user_intra_op_threadpool);
449 }
450}
451
452template <class PropagatorStateType>
453ExecutorState<PropagatorStateType>::~ExecutorState() {
454 if (device_context_) {
455 device_context_->Unref();
456 }
457 delete slice_reader_cache_;
458}
459
460template <class PropagatorStateType>
461template <typename Closure>
462void ExecutorState<PropagatorStateType>::RunTask(Closure&& c, int sample_rate) {
463 // Align the atomic variables at 64 bytes to avoid false-sharing, assuming the
464 // cacheline size is 64 bytes or smaller.
465 alignas(64) static std::atomic<int64_t> num_enqueue_ops{0};
466 alignas(64) static std::atomic<int64_t> num_dequeue_ops{0};
467
468 auto n_enqueues = num_enqueue_ops.fetch_add(1, std::memory_order_relaxed);
  // Sample the queue length at most once every std::max(16, sample_rate)
  // enqueue operations. This amortizes the cost of metric updates across many
  // operations.
471 if (n_enqueues % std::max(16, sample_rate) == 0) {
472 auto n_dequeues = num_dequeue_ops.load(std::memory_order_relaxed);
473 metrics::UpdateGraphPendingQueueLength(n_enqueues - n_dequeues);
474 }
475
476 // mutable is needed because std::forward<Closure> in the lambda body may move
477 // the Closure `c`.
478 runner_([c = std::forward<Closure>(c)]() mutable {
479 num_dequeue_ops.fetch_add(1, std::memory_order_relaxed);
480 std::forward<Closure>(c)();
481 });
482}
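
// Note on the sampling arithmetic in RunTask() above: with the default
// sample_rate of 0, std::max(16, sample_rate) == 16, so the pending-queue
// length metric is updated on roughly every 16th enqueue. Callers that enqueue
// large batches (e.g. ScheduleReady() below) pass the batch size as
// sample_rate, which spaces the metric updates out further whenever the batch
// exceeds 16 tasks.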
483
484template <class PropagatorStateType>
485void ExecutorState<PropagatorStateType>::RunAsync(Executor::DoneCallback done) {
486 TaggedNodeSeq ready;
487
  // Ask the device to supply the device context for this step.
489 Device* device = immutable_state_.params().device;
490 const Status get_context_status =
491 device->TryGetDeviceContext(&device_context_);
492 if (!get_context_status.ok()) {
493 delete this;
494 done(get_context_status);
495 return;
496 }
497
498 // Initialize the ready queue.
499 ready.reserve(immutable_state_.root_nodes().size());
500 propagator_.ActivateRoots(immutable_state_.root_nodes(), &ready);
501 num_outstanding_ops_ = ready.size();
502 if (ready.empty()) {
503 delete this;
504 done(OkStatus());
505 } else {
506 done_cb_ = std::move(done);
507 // Schedule to run all the ready ops in thread pool.
508 ScheduleReady(&ready, nullptr);
509 }
510}
511
// State kept alive for executing an asynchronous node in another
// thread. NOTE: We need to make a copy of p.inputs and p.input_alloc_attrs for
// asynchronous kernels because OpKernelContext methods like input_type(i) need
// the params to point to a valid input type vector. This is not an issue for
// sync kernels because these vectors are kept on the stack.
517template <class PropagatorStateType>
518struct ExecutorState<PropagatorStateType>::AsyncState {
519 AsyncState(const OpKernelContext::Params& p, const TaggedNode& _tagged_node,
520 const NodeItem* _item, Entry* _first_input,
521 NodeExecStatsInterface* _stats)
522 : saved_inputs(p.inputs.begin(), p.inputs.end()),
523 saved_input_alloc_attrs(p.input_alloc_attrs.begin(),
524 p.input_alloc_attrs.end()),
525 params(p),
526 tagged_node(_tagged_node),
527 item(_item),
528 first_input(_first_input),
529 // ParamsButClearingEigenGPUDevice does equivalent of
530 // params.eigen_gpu_device = nullptr;
531 ctx(ParamsButClearingEigenGPUDevice(&params), item->num_outputs),
532 stats(_stats) {
533 params.inputs = saved_inputs;
534 params.input_alloc_attrs = saved_input_alloc_attrs;
535 }
536
537 TensorValueVec saved_inputs;
538 AllocatorAttributeVec saved_input_alloc_attrs;
539 OpKernelContext::Params params;
540 TaggedNode tagged_node;
541 const NodeItem* item;
542 Entry* first_input;
543 OpKernelContext ctx;
544 NodeExecStatsInterface* stats;
545
546 private:
547 OpKernelContext::Params* ParamsButClearingEigenGPUDevice(
548 OpKernelContext::Params* p) {
549 // Ensure OpKernelContext constructor will make a new eigen GPU device if
550 // necessary.
551 p->eigen_gpu_device = nullptr; // Force allocation
552 return p;
553 }
554};
555
556// Returns true if `item` might be traced by the given trace and event
557// collectors. Returns false only if `item` definitely will not be traced.
558bool MightTrace(const tracing::EventCollector* event_collector,
559 bool is_expensive) {
  // Tracing will only be enabled if either `event_collector` is non-null, or
  // the profiler (scoped annotations or TraceMe) is enabled for this
  // particular kernel.
562 // Although `profiler::TraceMe`, `profiler::ScopedAnnotation`, and
563 // `tracing::ScopedRegion` check subsets of these properties internally in
564 // their constructors, the cost of passing the necessary arguments to them can
565 // be significant, so we avoid constructing them in the common case (when we
566 // know they will not be used).
567 if (event_collector != nullptr) {
568 return true;
569 }
570
571 if (profiler::ScopedAnnotation::IsEnabled()) return true;
572
573 return profiler::TraceMe::Active(profiler::GetTFTraceMeLevel(is_expensive));
574}
575
576template <class PropagatorStateType>
577Status ExecutorState<PropagatorStateType>::ProcessSync(
578 const NodeItem& item, OpKernelContext::Params* params, EntryVector* outputs,
579 NodeExecStatsInterface* stats) {
580 Status s;
581 OpKernelContext ctx(params, item.num_outputs);
582 nodestats::SetOpStart(stats);
583
584 OpKernel* op_kernel = item.kernel;
585 Device* device = immutable_state_.params().device;
586 const bool is_expensive = kernel_stats_->IsExpensive(item);
587
588 if (TF_PREDICT_FALSE(MightTrace(event_collector_, is_expensive))) {
589 tracing::ScopedRegion region(tracing::EventCategory::kCompute,
590 op_kernel->name_view());
591 profiler::AnnotatedTraceMe activity(
592 [op_kernel, &ctx] {
593 return op_kernel->TraceString(
594 ctx, /*verbose=*/profiler::TfOpDetailsEnabled());
595 },
596 profiler::GetTFTraceMeLevel(is_expensive));
597 device->Compute(op_kernel, &ctx);
598 } else if (kernel_stats_->HasExpensiveMarker(item)) {
599 KernelTimer timer;
600 device->Compute(op_kernel, &ctx);
    // For expensive kernels, always update the cost estimate. For inexpensive
    // kernels, update the cost estimate with ~1/16 probability. This assumes
    // that the last 4 bits of the CPU cycle count are uniformly distributed.
604 constexpr int kKernelExecutionTrackingInvocationSkipCount = 16;
605 if (is_expensive ||
606 timer.start_cycles % kKernelExecutionTrackingInvocationSkipCount == 0) {
607 kernel_stats_->UpdateCostEstimate(item, timer.ElapsedCycles());
608 }
609 } else {
610 device->Compute(op_kernel, &ctx);
611 }
612 nodestats::SetOpEnd(stats);
613 if (outputs->size() < item.num_outputs) outputs->resize(item.num_outputs);
614 s = ProcessOutputs(item, &ctx, outputs->data(), stats);
615 nodestats::SetMemory(stats, &ctx);
616 return s;
617}
618
619template <class PropagatorStateType>
620void ExecutorState<PropagatorStateType>::ProcessAsync(
621 const NodeItem& item, const OpKernelContext::Params& params,
622 const TaggedNode& tagged_node, Entry* first_input,
623 NodeExecStatsInterface* stats, activity_watcher::ActivityId activity_id) {
624 AsyncOpKernel* async_kernel = item.kernel->AsAsync();
625 DCHECK(async_kernel != nullptr);
626 AsyncState* state =
627 new AsyncState(params, tagged_node, &item, first_input, stats);
628
629 auto done = [this, state, activity_id]() {
630 Device* device = immutable_state_.params().device;
631 NodeExecStatsInterface* stats = state->stats; // Shorthand
632 Entry* first_input = state->first_input; // Shorthand
633
634 nodestats::SetOpEnd(stats);
635 EntryVector outputs(state->item->num_outputs);
636 Status s = ProcessOutputs(*state->item, &state->ctx, outputs.data(), stats);
637 nodestats::SetMemory(stats, &state->ctx);
638 if (vlog_) {
639 VLOG(2) << "Async kernel done: " << state->item->node_id << " step "
640 << step_id_ << " " << SummarizeNodeDef(state->item->kernel->def())
641 << (state->tagged_node.get_is_dead() ? " is dead" : "")
642 << " device: " << device->name();
643 }
644
645 // Clears inputs.
646 const int num_inputs = state->item->num_inputs;
647 for (int i = 0; i < num_inputs; ++i) {
648 (first_input + i)->ClearVal();
649 }
650 propagator_.MaybeMarkCompleted(state->tagged_node);
651 activity_watcher::ActivityEnd(activity_id);
652 TaggedNodeSeq ready;
653 if (s.ok()) {
654 propagator_.PropagateOutputs(state->tagged_node, &outputs, &ready);
655 }
656 outputs.clear();
657 const bool completed = NodeDone(s, &ready, stats, nullptr);
658 delete state;
659 if (completed) ScheduleFinish();
660 };
661 nodestats::SetOpStart(stats);
662 {
663 profiler::AnnotatedTraceMe activity(
664 [async_kernel, state] {
665 return async_kernel->TraceString(
666 state->ctx, /*verbose=*/profiler::TfOpDetailsEnabled());
667 },
668 profiler::GetTFTraceMeLevel(kernel_stats_->IsExpensive(item)));
669 immutable_state_.params().device->ComputeAsync(async_kernel, &state->ctx,
670 std::move(done));
671 }
672}
673
674template <class PropagatorStateType>
675void ExecutorState<PropagatorStateType>::ProcessNoop(
676 NodeExecStatsInterface* stats) {
677 nodestats::SetOpStart(stats);
678 nodestats::SetOpEnd(stats);
679}
680
681template <class PropagatorStateType>
682void ExecutorState<PropagatorStateType>::ProcessConstTensor(
683 const NodeItem& item, EntryVector* outputs, NodeExecStatsInterface* stats) {
684 nodestats::SetOpStart(stats);
685 nodestats::SetOpEnd(stats);
686 Entry& output = (*outputs)[0];
687 output.state = Entry::State::HAS_CONST_TENSOR;
688 output.const_tensor = item.const_tensor;
689 output.alloc_attr = item.output_attrs()[0];
690}
691
692template <class PropagatorStateType>
693void ExecutorState<PropagatorStateType>::Process(TaggedNode tagged_node,
694 int64_t scheduled_nsec) {
695 profiler::TraceMe traceme("ExecutorState::Process Scheduled",
696 profiler::TraceMeLevel::kVerbose);
697 TaggedNodeReadyQueue inline_ready;
698 inline_ready.push_back(tagged_node);
699 return ProcessInline(&inline_ready, scheduled_nsec);
700}
701
702template <class PropagatorStateType>
703void ExecutorState<PropagatorStateType>::ProcessInline(
704 TaggedNodeReadyQueue* inline_ready, int64_t scheduled_nsec) {
705 WithContext wc(context_);
706 TaggedNodeSeq ready;
707
708 // Parameters passed to OpKernel::Compute.
709 TensorValueVec inputs;
710 AllocatorAttributeVec input_alloc_attrs;
711
712 OpKernelContext::Params params;
713 params.step_id = step_id_;
714 // Override device's threadpool if user provides an intra_op_threadpool
715 Device* device = immutable_state_.params().device;
716 if (user_device_) {
717 params.device = user_device_.get();
718 } else {
719 params.device = device;
720 }
721 params.start_time_usecs = start_time_usecs_;
722 params.deadline = deadline_;
723 params.log_memory = log_memory_;
724 params.rendezvous = rendezvous_;
725 params.collective_executor = collective_executor_;
726 params.session_state = session_state_;
727 params.session_handle = session_handle_;
728 params.session_metadata = session_metadata_;
729 params.tensor_store = tensor_store_;
730 params.cancellation_manager = cancellation_manager_;
731 params.coordination_service_agent = coordination_service_agent_;
732 params.stack_trace = stack_trace_;
733 params.call_frame = call_frame_;
734 params.function_library = immutable_state_.params().function_library;
735 params.resource_manager = device->resource_manager();
736 params.step_container = step_container_;
737 params.slice_reader_cache = slice_reader_cache_;
738 params.runner = &runner_;
739 params.run_all_kernels_inline = run_all_kernels_inline_;
740 params.stats_collector = stats_collector_;
741 params.inc_num_deferred_ops_function = [this]() {
742 mutex_lock lock(num_deferred_ops_mu_);
743 num_deferred_ops_++;
744 };
745 params.dec_num_deferred_ops_function = [this]() {
746 bool finish_when_deferred_ops_done = false;
747 {
748 mutex_lock lock(num_deferred_ops_mu_);
749 num_deferred_ops_--;
750 if (num_deferred_ops_ == 0) {
751 finish_when_deferred_ops_done = finish_when_deferred_ops_done_;
752 }
753 }
754 // Invoke Finish if the graph processing has completed. Finish is always
755 // called exactly once per ExecutorState, either here if there are any
756 // deferred ops, or in ScheduleFinish if there aren't any deferred ops.
757 if (finish_when_deferred_ops_done) Finish();
758 };
759
760 // Set the device_context for this device, if it exists.
761 params.op_device_context = device_context_;
762
763 Status s;
764 NodeExecStatsInterface* stats = nullptr;
765
766 EntryVector outputs(1);
767
768 bool completed = false;
769 int64_t last_iter_num = -1;
770 std::unique_ptr<profiler::TraceMeConsumer> iteration_scope;
771 while (!inline_ready->empty()) {
772 TaggedNode tagged_node = inline_ready->front();
773
774 int64_t current_iter_num = tagged_node.get_iter_num();
775 if (current_iter_num != last_iter_num) {
776 iteration_scope = std::make_unique<profiler::TraceMeConsumer>(
777 // From TraceMeProducer in DirectSession::RunInternal,
778 // GraphMgr::ExecuteAsync, or FunctionLibraryRuntime::Run.
779 [&] {
780 // NOTE: This tracing uses the iteration number from the first
781 // tagged node that executes during this call to `Process()`. In
782 // principle, subsequent nodes could have different values of
783 // `iter_num` that will not be traced.
784 return profiler::TraceMeEncode(
785 "ExecutorState::Process",
786 {{"id", step_id_}, {"iter_num", tagged_node.get_iter_num()}});
787 },
788 profiler::ContextType::kTfExecutor, step_id_,
789 profiler::TraceMeLevel::kInfo);
790 last_iter_num = current_iter_num;
791 }
792 inline_ready->pop_front();
793 const NodeItem& item = tagged_node.get_node_item();
794 const int id = item.node_id;
795
796 propagator_.MaybeMarkStarted(tagged_node);
797 const activity_watcher::ActivityId activity_id =
798 activity_watcher::ActivityStart(
799 [&]() {
800 return std::make_unique<activity_watcher::Activity>(
801 "ExecutorState::Process",
802 activity_watcher::ActivityCategory::kMisc,
803 activity_watcher::Activity::Attributes{
804 {"node_name", item.kernel->def().name()},
805 {"op", item.kernel->def().op()},
806 {"iter_num", absl::StrCat(tagged_node.get_iter_num())},
807 {"step_id", absl::StrCat(params.step_id)},
808 {"node_id", absl::StrCat(id)},
809 {"device", device->name()},
810 });
811 },
812 /*level=*/2);
813
814 params.track_allocations = false;
815 stats = nullptr;
816 if (stats_collector_ && !tagged_node.get_is_dead()) {
817 stats = stats_collector_->CreateNodeExecStats(&item.kernel->def());
818 // Track allocations if and only if we are collecting statistics, and
819 // `stats` object is expecting allocations to be tracked.
820 params.track_allocations = stats ? stats->TrackAllocations() : false;
821 nodestats::SetScheduled(stats, scheduled_nsec);
822 nodestats::SetAllStart(stats);
823 }
824
825 if (vlog_) {
826 VLOG(1) << "Process node: " << id << " step " << params.step_id << " "
827 << SummarizeNodeDef(item.kernel->def())
828 << (tagged_node.get_is_dead() ? " is dead" : "")
829 << " device: " << device->name();
830 }
831
832 Entry* first_input = propagator_.GetInputTensors(tagged_node);
833
834 // Only execute this node if it is not dead or it is a send/recv
835 // transfer node. For transfer nodes, we need to propagate the "dead"
836 // bit even when the node is dead.
837 bool launched_asynchronously = false;
838 if (tagged_node.get_is_dead() && !item.is_transfer_node) {
839 if (outputs.size() < item.num_outputs) outputs.resize(item.num_outputs);
840 } else if (TF_PREDICT_FALSE(item.is_noop)) {
841 ProcessNoop(stats);
842 } else if (item.const_tensor != nullptr && !params.track_allocations) {
843 ProcessConstTensor(item, &outputs, stats);
844 } else {
845 // Prepares inputs.
846 bool is_input_dead = false;
847 s = PrepareInputs(item, first_input, &inputs, &input_alloc_attrs,
848 &is_input_dead);
849 if (!s.ok()) {
850 // Clear inputs.
851 const int num_inputs = item.num_inputs;
852 for (int i = 0; i < num_inputs; ++i) {
853 (first_input + i)->ClearVal();
854 }
855 propagator_.MaybeMarkCompleted(tagged_node);
856 activity_watcher::ActivityEnd(activity_id);
857 // Continue to process the nodes in 'inline_ready'.
858 completed = NodeDone(s, &ready, stats, inline_ready);
859 continue;
860 }
861
862 // Set up compute params.
863 params.op_kernel = item.kernel;
864 params.frame_iter = propagator_.GetFrameAndIter(tagged_node);
865 params.is_input_dead = is_input_dead;
866 params.output_attr_array = item.output_attrs();
867 params.forward_from_array = item.forward_from();
868 params.outputs_required_array = item.outputs_required.get();
869 params.inputs = inputs;
870 params.input_alloc_attrs = input_alloc_attrs;
871
872 if (item.kernel_is_async) {
873 ProcessAsync(item, params, tagged_node, first_input, stats,
874 activity_id);
875 launched_asynchronously = true;
876 } else {
877 s = ProcessSync(item, &params, &outputs, stats);
878 }
879 }
880
881 if (!launched_asynchronously) {
882 if (vlog_) {
883 VLOG(2) << "Synchronous kernel done: " << id << " step "
884 << params.step_id << " " << SummarizeNodeDef(item.kernel->def())
885 << (tagged_node.get_is_dead() ? " is dead: " : "")
886 << " device: " << device->name();
887 }
888
889 // Clears inputs.
890 const int num_inputs = item.num_inputs;
891 for (int i = 0; i < num_inputs; ++i) {
892 (first_input + i)->ClearVal();
893 }
894 propagator_.MaybeMarkCompleted(tagged_node);
895 activity_watcher::ActivityEnd(activity_id);
896 // Propagates outputs.
897 if (s.ok()) {
898 propagator_.PropagateOutputs(tagged_node, &outputs, &ready);
899 }
900
901 // Clear outputs without deallocating the `outputs` vector.
902 const int num_outputs = item.num_outputs;
903 for (int i = 0; i < num_outputs; ++i) {
904 outputs[i].ClearVal();
905 }
906
907 if (stats) {
908 scheduled_nsec = nodestats::NowInNsec();
909 }
910 // Postprocess.
911 completed = NodeDone(s, &ready, stats, inline_ready);
912 }
913 } // while !inline_ready.empty()
914
915 // This thread of computation is done if completed = true.
916 if (completed) ScheduleFinish();
917}
918
919template <class PropagatorStateType>
920Status ExecutorState<PropagatorStateType>::PrepareInputs(
921 const NodeItem& item, Entry* first_input, TensorValueVec* inputs,
922 AllocatorAttributeVec* input_alloc_attrs, bool* is_input_dead) {
923 inputs->resize(item.num_inputs);
924 input_alloc_attrs->resize(item.num_inputs);
925
926 *is_input_dead = false;
927
928 for (int i = 0; i < item.num_inputs; ++i) {
929 const bool expect_ref = TF_PREDICT_FALSE(item.is_any_input_ref_typed) &&
930 IsRefType(item.input_type(i));
931 Entry* entry = first_input + i;
932 (*input_alloc_attrs)[i] = entry->alloc_attr;
933
934 // i-th input.
935 TensorValue* inp = &(*inputs)[i];
936
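    // The entry can be in one of four states (see common_runtime/entry.h):
    //   NO_VALUE         - no tensor was produced (a dead input); only merge
    //                      and transfer nodes accept this.
    //   HAS_VALUE        - an ordinary tensor value held directly in the entry.
    //   HAS_CONST_TENSOR - a pointer to an immutable tensor; it is handed to
    //                      the kernel through a const_cast but must never be
    //                      mutated (see the NOTE(mrry) comments below).
    //   HAS_REF_TENSOR   - a ref-typed tensor guarded by a mutex, dereferenced
    //                      here when the kernel expects a non-ref input.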
937 switch (entry->state) {
938 case Entry::State::NO_VALUE: {
939 // Only merge and transfer nodes can have no-value inputs.
940 inp->mutex_if_ref = nullptr;
941 if (item.is_merge) {
942 inp->tensor = nullptr;
943 } else {
944 DCHECK(item.is_transfer_node)
945 << item.kernel->name() << " - input " << i;
946 entry->state = Entry::State::HAS_CONST_TENSOR;
947 entry->const_tensor = kEmptyTensor;
948 // NOTE(mrry): This `const_cast` is necessary because `TensorValue`
949 // stores a non-const `Tensor*`, and relies on the `OpKernelContext`
950 // accessors making dynamic checks that prevent using an immutable
951 // tensor as a mutable tensor.
952 inp->tensor = const_cast<Tensor*>(kEmptyTensor);
953 *is_input_dead = true;
954 }
955 break;
956 }
957
958 case Entry::State::HAS_VALUE: {
959 if (TF_PREDICT_FALSE(expect_ref)) {
960 return AttachDef(
961 errors::InvalidArgument(i, "-th input expects a ref type"),
962 item.kernel->def());
963 }
964 inp->mutex_if_ref = nullptr;
965 inp->tensor = entry->val.get();
966 break;
967 }
968
969 case Entry::State::HAS_CONST_TENSOR: {
970 if (TF_PREDICT_FALSE(expect_ref)) {
971 return AttachDef(
972 errors::InvalidArgument(i, "-th input expects a ref type"),
973 item.kernel->def());
974 }
975 // NOTE(mrry): This `const_cast` is necessary because `TensorValue`
976 // stores a non-const `Tensor*`, and relies on the `OpKernelContext`
977 // accessors making dynamic checks that prevent using an immutable
978 // tensor as a mutable tensor.
979 inp->mutex_if_ref = nullptr;
980 inp->tensor = const_cast<Tensor*>(entry->const_tensor);
981 break;
982 }
983
984 case Entry::State::HAS_REF_TENSOR: {
985 {
986 tf_shared_lock ml(*entry->ref_tensor.mu);
987 if (TF_PREDICT_FALSE(!entry->ref_tensor.tensor->IsInitialized() &&
988 !item.is_initialization_op)) {
989 return AttachDef(errors::FailedPrecondition(
990 "Attempting to use uninitialized value ",
991 item.kernel->requested_input(i)),
992 item.kernel->def());
993 }
994 }
995
996 if (expect_ref) {
997 inp->mutex_if_ref = entry->ref_tensor.mu;
998 inp->tensor = entry->ref_tensor.tensor;
999 } else {
1000 // Automatically deref the tensor ref when the op expects a
1001 // tensor but is given a ref to a tensor. Need to deref it
1002 // under the mutex.
1003 {
1004 mutex* ref_mu = entry->ref_tensor.mu;
1005 Tensor* ref_tensor = entry->ref_tensor.tensor;
1006 tf_shared_lock l(*ref_mu);
1007 entry->val.Init(*ref_tensor);
1008 }
1009 entry->state = Entry::State::HAS_VALUE;
1010
1011 inp->mutex_if_ref = nullptr;
1012 inp->tensor = entry->val.get();
1013 // The dtype of entry->ref_tensor.tensor could have been changed by
1014 // another operation that ran after the operation that "produced" it
1015 // executed, so re-validate that the type of the dereferenced tensor
1016 // matches the expected input type.
1017 if (TF_PREDICT_FALSE(item.input_type(i) != inp->tensor->dtype())) {
1018 return AttachDef(
1019 errors::InvalidArgument(
1020 i, "-th input expects type ",
1021 DataTypeString(item.input_type(i)),
1022 " but automatically dereferenced input tensor has type ",
1023 DataTypeString(inp->tensor->dtype())),
1024 item.kernel->def());
1025 }
1026 }
1027 break;
1028 }
1029 }
1030 }
1031 return OkStatus();
1032}
1033
1034template <class PropagatorStateType>
1035Status ExecutorState<PropagatorStateType>::ProcessOutputs(
1036 const NodeItem& item, OpKernelContext* ctx, Entry* outputs,
1037 NodeExecStatsInterface* stats) {
1038 Status s = ctx->status();
1039 if (!s.ok()) {
1040 s = AttachDef(s, item.kernel->def());
1041 // TODO(misard) Replace with a finer-grain enabling flag once we
1042 // add better optional debugging support.
1043 if (vlog_ && VLOG_IS_ON(1)) {
1044 LOG(WARNING) << this << " Compute status: " << s;
1045 }
1046 if (s.code() == error::RESOURCE_EXHAUSTED) {
1047 if (stats_collector_) {
1048 string err = stats_collector_->ReportAllocsOnResourceExhausted(
1049 s.error_message());
1050 s = errors::CreateWithUpdatedMessage(
1051 s, strings::StrCat(s.error_message(), err));
1052 } else {
1053 s = errors::CreateWithUpdatedMessage(
1054 s,
1055 strings::StrCat(
1056 s.error_message(),
1057 "\nHint: If you want to see a list of allocated tensors when "
1058 "OOM happens, add report_tensor_allocations_upon_oom "
1059 "to RunOptions for current allocation info. This isn't "
1060 "available when running in Eager mode.\n"));
1061 }
1062 } else if (s.code() == error::UNAVAILABLE &&
1063 !item.is_distributed_communication) {
1064 s = errors::ReplaceErrorFromNonCommunicationOps(s, item.kernel->name());
1065 }
1066 return s;
1067 }
1068
1069 for (int i = 0; i < item.num_outputs; ++i) {
1070 const TensorValue val = ctx->release_output(i);
1071 Entry* out = &outputs[i];
1072 DCHECK(out->state == Entry::State::NO_VALUE);
1073
1074 if (val.tensor == nullptr) {
      // Unless it's a Switch or a Recv, or the executor has marked the output
      // as not required, the node must produce a tensor value at the i-th
      // output.
1077 if (!(item.is_recv_or_switch ||
1078 (item.outputs_required && !item.outputs_required[i]))) {
1079 s.Update(errors::Internal("Missing ", i, "-th output from ",
1080 FormatNodeDefForError(item.kernel->def())));
1081 }
1082 } else {
1083 // Set the allocator attributes of the output entry.
1084 out->alloc_attr = ctx->output_alloc_attr(i);
1085
1086 // Sanity check of output tensor types. We need to inspect this safely as
1087 // we are in the tensor buffer.
1088 DataType dtype = val.dtype_safe();
1089 if (dtype == item.output_type(i)) {
1090 if (stats && val.tensor->IsInitialized()) {
1091 nodestats::SetOutput(stats, i, val.tensor);
1092 }
1093 if (val.is_ref()) {
1094 out->state = Entry::State::HAS_REF_TENSOR;
1095 out->ref_tensor.tensor = val.tensor;
1096 out->ref_tensor.mu = val.mutex_if_ref;
1097 if (log_memory_) {
1098 Tensor to_log;
1099 {
1100 // Dereference the tensor under the lock.
1101 tf_shared_lock l(*out->ref_tensor.mu);
1102 to_log = *out->ref_tensor.tensor;
1103 }
1104 LogMemory::RecordTensorOutput(ctx->op_kernel().name(),
1105 ctx->step_id(), i, to_log);
1106 }
1107 } else {
          // NOTE that std::move is used here, so val.tensor goes to
          // uninitialized state (val.tensor->IsInitialized() returns false).
1110 out->state = Entry::State::HAS_VALUE;
1111 out->val.Init(std::move(*val.tensor));
1112 if (log_memory_) {
1113 LogMemory::RecordTensorOutput(ctx->op_kernel().name(),
1114 ctx->step_id(), i, *out->val);
1115 }
1116 }
1117 } else {
1118 s.Update(
1119 errors::Internal("Output ", i, " of type ", DataTypeString(dtype),
1120 " does not match declared output type ",
1121 DataTypeString(item.output_type(i)), " for node ",
1122 FormatNodeDefForError(item.kernel->def())));
1123 }
1124 }
1125 if (!val.is_ref()) {
1126 // If OpKernelContext returns outputs via pass-by-value, we
1127 // don't need this trouble.
1128 delete val.tensor;
1129 }
1130 }
1131 return s;
1132}
1133
1134template <class PropagatorStateType>
1135bool ExecutorState<PropagatorStateType>::NodeDone(
1136 const Status& s, TaggedNodeSeq* ready, NodeExecStatsInterface* stats,
1137 TaggedNodeReadyQueue* inline_ready) {
1138 if (stats) {
1139 nodestats::SetAllEnd(stats);
1140 DCHECK_NE(stats_collector_, nullptr);
1141 stats->Done(immutable_state_.params().device->name());
1142 }
1143
1144 if (TF_PREDICT_TRUE(s.ok())) {
1145 const size_t ready_size = ready->size();
1146 if (ready_size == 0) {
1147 return num_outstanding_ops_.fetch_sub(1) == 1;
1148 } else {
1149 // NOTE: Avoid touching the atomic counter if only one node becomes ready.
1150 if (ready_size > 1) {
1151 num_outstanding_ops_.fetch_add(ready_size - 1,
1152 std::memory_order_relaxed);
1153 }
1154
1155 // Schedule the ready nodes in 'ready'.
1156 ScheduleReady(ready, inline_ready);
1157
1158 return false;
1159 }
1160 } else {
1161 bool abort_run = false;
1162 Status maybe_derived_s(s);
1163
1164 // Some error happened. This thread of computation is done.
1165 {
1166 mutex_lock l(mu_);
1167 if (status_.ok()) {
1168 // If this is the first node to fail in this run, we are responsible for
1169 // aborting all other execution in the step.
1170 abort_run = true;
1171
1172 // If execution has been cancelled, mark cancelled or aborted errors as
1173 // being derived. Note that the original node that fails might also
1174 // trigger cancellation, and here we make sure the original error is
1175 // exposed to users and not buried as a derived error.
1176 if (cancellation_manager_ && cancellation_manager_->IsCancelled() &&
1177 (errors::IsCancelled(s) || errors::IsAborted(s))) {
1178 status_ = StatusGroup::MakeDerived(s);
1179 maybe_derived_s = status_;
1180 } else {
1181 status_ = s;
1182 }
1183 }
1184 }
1185
1186 if (abort_run) {
1187 TRACEPRINTF("StartAbort: %s", s.ToString().c_str());
1188 if (cancellation_manager_) {
1189 // Only log when the abort happens during the actual run time.
1190 // Use VLOG instead of LOG(warning) because error status is expected
1191 // when the executor is run under the grappler optimization phase or
1192 // when iterating through a tf.data input pipeline.
1193 VLOG(1) << "[" << immutable_state_.params().device->name()
1194 << "] Executor start aborting: " << s;
1195 }
1196
1197 if (rendezvous_) {
1198 rendezvous_->StartAbort(s);
1199 }
1200 if (cancellation_manager_) {
1201 cancellation_manager_->StartCancelWithStatus(maybe_derived_s);
1202 } else if (collective_executor_) {
        // If there's a cancellation_manager_, collective ops abort
        // collective_executor_ upon cancellation; otherwise we need to abort
        // here.
1206 collective_executor_->StartAbort(s);
1207 }
1208 }
1209
1210 return num_outstanding_ops_.fetch_sub(1) == 1;
1211 }
1212}
1213
1214template <class PropagatorStateType>
1215void ExecutorState<PropagatorStateType>::ScheduleReady(
1216 TaggedNodeSeq* ready, TaggedNodeReadyQueue* inline_ready) {
1217 profiler::TraceMe activity(
1218 [&]() {
1219 return strings::StrCat(
1220 "ExecutorState::ScheduleReady#",
1221 "ready_size=", (ready == nullptr ? -1 : ready->size()),
1222 ",inline_ready_size=",
1223 (inline_ready == nullptr ? -1 : inline_ready->size()), "#");
1224 },
1225 profiler::GetTFTraceMeLevel(/*is_expensive=*/false));
1226 DCHECK(!ready->empty());
1227
1228 int64_t scheduled_nsec = 0;
1229 if (stats_collector_) {
1230 scheduled_nsec = nodestats::NowInNsec();
1231 }
1232
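  // Summary of the dispatch policy implemented below:
  //  * If run_all_kernels_inline_ is set, all ready kernels run sequentially:
  //    either from a single RunTask closure (when there is no inline_ready
  //    queue, i.e. when called from RunAsync) or by appending them to the
  //    caller's inline_ready queue.
  //  * Otherwise, if there is no inline_ready queue, every ready node is
  //    dispatched to the threadpool via RunTask.
  //  * Otherwise, dead and inexpensive nodes are appended to inline_ready and
  //    expensive nodes are dispatched via RunTask, except that one expensive
  //    node may be kept for the current thread when there is nothing else to
  //    run inline. Very large batches of expensive nodes are split into chunks
  //    of kInlineScheduleReadyThreshold and fanned out from child tasks.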
1233 if (run_all_kernels_inline_) {
1234 if (inline_ready == nullptr) {
      // Schedule all ready kernels from a single closure. This ensures that,
      // regardless of the `runner_` implementation, all kernels will run
      // sequentially on the same thread, and thread wakeup overhead and
      // executor mutex contention will be minimized.
1239 RunTask([this, ready = std::move(*ready), scheduled_nsec]() {
1240 for (auto& tagged_node : ready) {
1241 Process(tagged_node, scheduled_nsec);
1242 }
1243 });
1244 } else {
1245 for (auto& tagged_node : *ready) {
1246 inline_ready->push_back(tagged_node);
1247 }
1248 }
1249 } else {
1250 const TaggedNode* curr_expensive_node = nullptr;
1251 TaggedNodeSeq expensive_nodes;
1252 if (inline_ready == nullptr) {
1253 // Schedule to run all the ready ops in thread pool.
1254 for (auto& tagged_node : *ready) {
1255 RunTask([=]() { Process(tagged_node, scheduled_nsec); },
1256 /*sample_rate=*/ready->size());
1257 }
1258 } else {
1259 for (auto& tagged_node : *ready) {
1260 const NodeItem& item = *tagged_node.node_item;
1261 if (tagged_node.get_is_dead() || !kernel_stats_->IsExpensive(item)) {
1262 // Inline this inexpensive node.
1263 inline_ready->push_back(tagged_node);
1264 } else {
1265 if (curr_expensive_node) {
1266 expensive_nodes.push_back(*curr_expensive_node);
1267 }
1268 curr_expensive_node = &tagged_node;
1269 }
1270 }
1271 }
1272 if (curr_expensive_node) {
1273 if (inline_ready->empty()) {
1274 inline_ready->push_back(*curr_expensive_node);
1275 } else {
        // There are inline nodes to run already. We dispatch this expensive
        // node to another thread.
1278 expensive_nodes.push_back(*curr_expensive_node);
1279 }
1280 }
1281 if (!expensive_nodes.empty()) {
1282 if (expensive_nodes.size() < kInlineScheduleReadyThreshold) {
1283 for (auto& tagged_node : expensive_nodes) {
1284 RunTask(std::bind(&ExecutorState::Process, this, tagged_node,
1285 scheduled_nsec),
1286 /*sample_rate=*/expensive_nodes.size());
1287 }
1288 } else {
1289 // There are too many ready expensive nodes. Schedule them in child
1290 // threads.
1291 // TODO(fishx): Apply the same optimization to cheap ops as well since
1292 // executing lots of cheap ops in one thread can potentially be the
1293 // bottleneck as well.
1294 auto it = expensive_nodes.begin();
1295 while (it < expensive_nodes.end()) {
1296 auto end = it;
1297 std::advance(end, kInlineScheduleReadyThreshold);
1298 if (end > expensive_nodes.end()) {
1299 end = expensive_nodes.end();
1300 }
1301 TaggedNodeSeq ready_chunk{it, end};
1302 RunTask(
1303 [this, ready_chunk = std::move(ready_chunk), scheduled_nsec]() {
1304 profiler::TraceMe activity(
1305 [&]() {
1306 return strings::StrCat(
1307 "ExecutorState::ScheduleReady::"
1308 "ChildThreadExpensiveNodes#",
1309 "ready_chunk_size=", ready_chunk.size(), "#");
1310 },
1311 profiler::GetTFTraceMeLevel(/*is_expensive=*/false));
1312 for (auto& tagged_node : ready_chunk) {
1313 RunTask(std::bind(&ExecutorState::Process, this, tagged_node,
1314 scheduled_nsec),
1315 /*sample_rate=*/ready_chunk.size());
1316 }
1317 });
1318 it = end;
1319 }
1320 }
1321 }
1322 }
1323 ready->clear();
1324}
1325
1326template <class PropagatorStateType>
1327void ExecutorState<PropagatorStateType>::ScheduleFinish() {
  // Checks the condition to decide whether Finish() needs to be invoked. If
  // there are in-flight deferred ops, wait until `num_deferred_ops_` reaches 0
  // to invoke Finish(). Otherwise, invoke Finish() directly.
1331 // Note that it is critical that the ScheduleFinish / Finish codepath does not
1332 // block, otherwise we might deadlock. See b/124523000 for details.
1333 {
1334 mutex_lock lock(num_deferred_ops_mu_);
1335 if (num_deferred_ops_ > 0) {
1336 finish_when_deferred_ops_done_ = true;
1337 return;
1338 }
1339 }
1340 // Finish is always called exactly once per ExecutorState, either here if
1341 // there aren't any deferred ops, or in the dec_num_deferred_ops_function if
1342 // there are deferred ops.
1343 Finish();
1344}
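
// Illustrative sketch of how the deferred-op accounting interacts with a
// kernel that hands work off to another device or thread. This is a hedged,
// comment-only example: `MyDeferringKernel` and `EnqueueOnDevice` are
// hypothetical, and the accessor names are assumptions; the authoritative
// hooks are the params.inc_num_deferred_ops_function /
// params.dec_num_deferred_ops_function callbacks set up in ProcessInline().
//
//   void MyDeferringKernel::Compute(OpKernelContext* ctx) {
//     auto inc = ctx->inc_num_deferred_ops_function();
//     auto dec = ctx->dec_num_deferred_ops_function();
//     inc();  // Tell the executor not to Finish() until the work is done.
//     EnqueueOnDevice(ctx, [dec]() {
//       // ... produce outputs / record status ...
//       dec();  // May trigger ExecutorState::Finish() if it was deferred.
//     });
//   }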
1345
1346template <class PropagatorStateType>
1347void ExecutorState<PropagatorStateType>::Finish() {
1348 mu_.lock();
1349 auto status = status_;
1350 auto done_cb = std::move(done_cb_);
1351 auto runner = std::move(runner_);
1352 mu_.unlock();
1353 int64_t step_id = step_id_;
1354 CHECK(done_cb != nullptr);
1355 Device* device = immutable_state_.params().device;
1356
1357 if (vlog_ && !status.ok() && VLOG_IS_ON(1)) {
1358 // Logs verbose information about the current state of active and pending
1359 // nodes in the propagator.
1360 propagator_.DumpState();
1361 }
1362
1363 // There are several potential race conditions below. To name a few:
1364 // 1. Even if the device's status is OK at the precise moment when
1365 // num_deferred_ops_ reaches 0, it could go bad before device->RefreshStatus()
1366 // is called below, caused by work enqueued onto the same device by other
1367 // concurrent ExecutorState objects.
1368 // 2. Some implementations of Device::RefreshStatus, such as
1369 // XlaDevice::RefreshStatus, may be inherently racy because it releases the
1370 // device mutex after a stream pointer is acquired and before the stream is
1371 // queried for status.
1372 // 3. It's the same for some implementations of Device::Sync, such as
1373 // XlaDevice::Sync.
1374 //
1375 // However, these race conditions are acceptable because a stream (and
1376 // therefore an XlaDevice) can only go from OK to not-OK, never the opposite,
1377 // which means we will at worst report errors when there isn't any, never the
1378 // opposite.
1379
  // An early exit for devices that don't allow sync on completion. Ops that
  // run on these devices should have used num_deferred_ops correctly to ensure
  // the device has finished all relevant work at this point.
1383 if (!device->AllowsSyncOnCompletion()) {
1384 status.Update(device->RefreshStatus());
1385 if (!status.ok()) {
1386 // In device async execution mode, it's possible for device execution to
1387 // lag behind ExecutorState scheduling so much that this is the first
1388 // place a device execution error surfaces.
1389 // If so, all ExecutorState::NodeDone calls have already happened with OK
1390 // status. This is the last defense where StartCancel must be called to
1391 // abort all computation still running on any device.
1392 // TODO(b/124523000): Always call Finish in a separate thread, so even if
1393 // StartCancel blocks the current thread's execution, we won't encounter
1394 // deadlocks caused by inter-op thread exhaustion.
1395 if (rendezvous_) {
1396 rendezvous_->StartAbort(status);
1397 }
1398 if (cancellation_manager_) {
1399 cancellation_manager_->StartCancelWithStatus(status);
1400 } else if (collective_executor_) {
        // If there's a cancellation_manager_, collective ops abort
        // collective_executor_ upon cancellation; otherwise we need to abort
        // here.
1404 collective_executor_->StartAbort(status);
1405 }
1406 }
1407 delete this;
1408 runner([step_id, status, done_cb = std::move(done_cb)]() {
1409 profiler::TraceMeConsumer activity(
1410 // From TraceMeProducer in KernelAndDeviceFunc::RunAsync,
1411 // DirectSession::RunInternal or GraphMgr::ExecuteAsync.
1412 [&] {
1413 return profiler::TraceMeEncode("ExecutorDoneCallback",
1414 {{"id", step_id}});
1415 },
1416 profiler::ContextType::kTfExecutor, step_id,
1417 profiler::TraceMeLevel::kInfo);
1418 done_cb(status);
1419 });
1420 return;
1421 }
1422
1423 if (sync_on_finish_ && status.ok()) {
1424 // Block until the device has finished all queued operations. For
1425 // devices like GPUs that continue to execute Ops after their Compute
1426 // methods have completed, this ensures that control is not returned to
1427 // the user until the step (and its side-effects) has actually completed.
1428 device->Sync([this, step_id, runner = std::move(runner),
1429 done_cb = std::move(done_cb)](const Status& status) mutable {
1430 delete this;
1431 runner([step_id, status, done_cb = std::move(done_cb)]() {
1432 profiler::TraceMeConsumer activity(
1433 // From TraceMeProducer in KernelAndDeviceFunc::RunAsync,
1434 // DirectSession::RunInternal or GraphMgr::ExecuteAsync.
1435 [&] {
1436 return profiler::TraceMeEncode("ExecutorDoneCallback",
1437 {{"id", step_id}});
1438 },
1439 profiler::ContextType::kTfExecutor, step_id,
1440 profiler::TraceMeLevel::kInfo);
1441 done_cb(status);
1442 });
1443 });
1444 } else {
1445 delete this;
1446 runner([step_id, status, done_cb = std::move(done_cb)]() {
1447 profiler::TraceMeConsumer activity(
1448 // From TraceMeProducer in KernelAndDeviceFunc::RunAsync,
1449 // DirectSession::RunInternal or GraphMgr::ExecuteAsync.
1450 [&] {
1451 return profiler::TraceMeEncode("ExecutorDoneCallback",
1452 {{"id", step_id}});
1453 },
1454 profiler::ContextType::kTfExecutor, step_id,
1455 profiler::TraceMeLevel::kInfo);
1456 done_cb(status);
1457 });
1458 }
1459}
1460
1461void ExecutorImpl::RunAsync(const Args& args, DoneCallback done) {
1462 if (OpOrderDeterminismRequired()) {
1463 (new ExecutorState<OrderedPropagatorState>(args, immutable_state_,
1464 &kernel_stats_))
1465 ->RunAsync(std::move(done));
1466 } else if (immutable_state_.requires_control_flow_support()) {
1467 (new ExecutorState<PropagatorState>(args, immutable_state_, &kernel_stats_))
1468 ->RunAsync(std::move(done));
1469 } else {
1470 (new ExecutorState<SimplePropagatorState>(args, immutable_state_,
1471 &kernel_stats_))
1472 ->RunAsync(std::move(done));
1473 }
1474}
1475
1476} // namespace
1477
1478Status NewLocalExecutor(const LocalExecutorParams& params, const Graph& graph,
1479 Executor** executor) {
1480 ExecutorImpl* impl = new ExecutorImpl(params);
1481 const Status s = impl->Initialize(graph);
1482 if (s.ok()) {
1483 *executor = impl;
1484 } else {
1485 delete impl;
1486 }
1487 return s;
1488}
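
// A hedged usage sketch for NewLocalExecutor. This is illustrative only: the
// LocalExecutorParams / Executor::Args fields shown are the commonly used ones
// declared in executor.h, and `device`, `flib_runtime`, `pool`,
// `graph_def_version` and `step_id` are assumed to exist in the caller:
//
//   LocalExecutorParams params;
//   params.device = device;
//   params.function_library = flib_runtime;
//   params.create_kernel =
//       [=](const std::shared_ptr<const NodeProperties>& props,
//           OpKernel** kernel) {
//         return CreateNonCachedKernel(device, flib_runtime, props,
//                                      graph_def_version, kernel);
//       };
//   params.delete_kernel = [](OpKernel* kernel) {
//     DeleteNonCachedKernel(kernel);
//   };
//
//   Executor* executor = nullptr;
//   TF_RETURN_IF_ERROR(NewLocalExecutor(params, graph, &executor));
//
//   Executor::Args args;
//   args.step_id = step_id;
//   args.runner = [pool](Executor::Args::Closure c) {
//     pool->Schedule(std::move(c));
//   };
//   executor->RunAsync(args, [](const Status& s) { /* step finished */ });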
1489
1490Status CreateNonCachedKernel(Device* device, FunctionLibraryRuntime* flib,
1491 const std::shared_ptr<const NodeProperties>& props,
1492 int graph_def_version, OpKernel** kernel) {
1493 const auto device_type = DeviceType(device->attributes().device_type());
1494 auto allocator = device->GetAllocator(AllocatorAttributes());
1495 return CreateOpKernel(device_type, device, allocator, flib,
1496 device->resource_manager(), props, graph_def_version,
1497 kernel);
1498}
1499
1500void DeleteNonCachedKernel(OpKernel* kernel) { delete kernel; }
1501
1502namespace {
1503
1504class DefaultExecutorRegistrar {
1505 public:
1506 DefaultExecutorRegistrar() {
1507 Factory* factory = new Factory;
1508 ExecutorFactory::Register("", factory);
1509 ExecutorFactory::Register("DEFAULT", factory);
1510 }
1511
1512 private:
1513 class Factory : public ExecutorFactory {
1514 Status NewExecutor(const LocalExecutorParams& params, const Graph& graph,
1515 std::unique_ptr<Executor>* out_executor) override {
1516 Executor* ret = nullptr;
      TF_RETURN_IF_ERROR(NewLocalExecutor(params, graph, &ret));
1518 out_executor->reset(ret);
1519 return OkStatus();
1520 }
1521 };
1522};
1523static DefaultExecutorRegistrar registrar;
1524
1525} // namespace
1526
1527} // namespace tensorflow
1528