1 | /* Copyright 2015 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | |
16 | #include "tensorflow/core/common_runtime/executor.h" |
17 | |
18 | #include <algorithm> |
19 | #include <atomic> |
20 | #include <memory> |
21 | #include <vector> |
22 | |
23 | #include "absl/memory/memory.h" |
24 | #include "absl/time/time.h" |
25 | #include "absl/types/optional.h" |
26 | #include "tensorflow/core/activity_watcher/activity.h" |
27 | #include "tensorflow/core/common_runtime/costmodel_manager.h" |
28 | #include "tensorflow/core/common_runtime/entry.h" |
29 | #include "tensorflow/core/common_runtime/executor_factory.h" |
30 | #include "tensorflow/core/common_runtime/graph_view.h" |
31 | #include "tensorflow/core/common_runtime/immutable_executor_state.h" |
32 | #include "tensorflow/core/common_runtime/pending_counts.h" |
33 | #include "tensorflow/core/common_runtime/propagator_state.h" |
34 | #include "tensorflow/core/common_runtime/renamed_device.h" |
35 | #include "tensorflow/core/common_runtime/simple_propagator_state.h" |
36 | #include "tensorflow/core/common_runtime/step_stats_collector.h" |
37 | #include "tensorflow/core/framework/allocator.h" |
38 | #include "tensorflow/core/framework/cancellation.h" |
39 | #include "tensorflow/core/framework/collective.h" |
40 | #include "tensorflow/core/framework/control_flow.h" |
41 | #include "tensorflow/core/framework/device_attributes.pb.h" |
42 | #include "tensorflow/core/framework/log_memory.h" |
43 | #include "tensorflow/core/framework/metrics.h" |
44 | #include "tensorflow/core/framework/node_def_util.h" |
45 | #include "tensorflow/core/framework/op_kernel.h" |
46 | #include "tensorflow/core/framework/op_segment.h" |
47 | #include "tensorflow/core/framework/tensor.h" |
48 | #include "tensorflow/core/framework/tensor_reference.h" |
49 | #include "tensorflow/core/framework/types.h" |
50 | #include "tensorflow/core/framework/types.pb.h" |
51 | #include "tensorflow/core/graph/edgeset.h" |
52 | #include "tensorflow/core/graph/graph.h" |
53 | #include "tensorflow/core/graph/graph_node_util.h" |
54 | #include "tensorflow/core/lib/core/errors.h" |
55 | #include "tensorflow/core/lib/core/notification.h" |
56 | #include "tensorflow/core/lib/core/status.h" |
57 | #include "tensorflow/core/lib/core/threadpool.h" |
58 | #include "tensorflow/core/lib/gtl/flatmap.h" |
59 | #include "tensorflow/core/lib/gtl/inlined_vector.h" |
60 | #include "tensorflow/core/lib/gtl/manual_constructor.h" |
61 | #include "tensorflow/core/lib/hash/hash.h" |
62 | #include "tensorflow/core/platform/context.h" |
63 | #include "tensorflow/core/platform/env.h" |
64 | #include "tensorflow/core/platform/errors.h" |
65 | #include "tensorflow/core/platform/logging.h" |
66 | #include "tensorflow/core/platform/macros.h" |
67 | #include "tensorflow/core/platform/mutex.h" |
68 | #include "tensorflow/core/platform/profile_utils/cpu_utils.h" |
69 | #include "tensorflow/core/platform/status.h" |
70 | #include "tensorflow/core/platform/thread_annotations.h" |
71 | #include "tensorflow/core/platform/tracing.h" |
72 | #include "tensorflow/core/platform/types.h" |
73 | #include "tensorflow/core/profiler/lib/annotated_traceme.h" |
74 | #include "tensorflow/core/profiler/lib/connected_traceme.h" |
75 | #include "tensorflow/core/profiler/lib/scoped_annotation.h" |
76 | #include "tensorflow/core/profiler/lib/traceme.h" |
77 | #include "tensorflow/core/profiler/lib/traceme_encode.h" |
78 | #include "tensorflow/core/protobuf/error_codes.pb.h" |
79 | #include "tensorflow/core/util/determinism.h" |
80 | #include "tensorflow/core/util/managed_stack_trace.h" |
81 | #include "tensorflow/core/util/tensor_slice_reader_cache.h" |
82 | |
83 | namespace tensorflow { |
84 | |
85 | namespace { |
86 | |
87 | // 1-D, 0 element tensor. |
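// Note: intentionally leaked (never deleted), so the singleton outlives all
// executors and no non-trivial destructor runs at process shutdown.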
88 | static const Tensor* const kEmptyTensor = new Tensor; |
89 | |
90 | // Helper routines for collecting step stats. |
91 | namespace nodestats { |
92 | inline int64_t NowInNsec() { return EnvTime::NowNanos(); } |
93 | |
94 | void SetScheduled(NodeExecStatsInterface* stats, int64_t micros) { |
95 | if (!stats) return; |
96 | stats->SetScheduled(micros * EnvTime::kMicrosToNanos); |
97 | } |
98 | |
99 | void SetAllStart(NodeExecStatsInterface* stats) { |
100 | if (!stats) return; |
101 | stats->RecordExecutorStarted(); |
102 | } |
103 | |
104 | void SetOpStart(NodeExecStatsInterface* stats) { |
105 | if (!stats) return; |
106 | stats->RecordComputeStarted(); |
107 | } |
108 | |
109 | void SetOpEnd(NodeExecStatsInterface* stats) { |
110 | if (!stats) return; |
111 | stats->RecordComputeEnded(); |
112 | } |
113 | |
114 | void SetAllEnd(NodeExecStatsInterface* stats) { |
115 | if (!stats) return; |
116 | stats->RecordExecutorEnded(); |
117 | } |
118 | |
119 | void SetOutput(NodeExecStatsInterface* stats, int slot, const Tensor* v) { |
120 | if (!stats) return; |
121 | stats->SetOutput(slot, v); |
122 | } |
123 | |
124 | void SetMemory(NodeExecStatsInterface* stats, OpKernelContext* ctx) { |
125 | if (!stats) return; |
126 | stats->SetMemory(ctx); |
127 | } |
128 | |
129 | } // namespace nodestats |
130 | |
131 | // Time the execution of kernels (in CPU cycles). Used to dynamically identify |
132 | // inexpensive kernels which can be dispatched inline. |
133 | struct KernelTimer { |
134 | uint64 start_cycles = profile_utils::CpuUtils::GetCurrentClockCycle(); |
135 | |
136 | uint64 ElapsedCycles() { |
137 | return profile_utils::CpuUtils::GetCurrentClockCycle() - start_cycles; |
138 | } |
139 | }; |
140 | |
141 | // TODO(b/152925936): Re-evaluate these constants with current usage patterns. |
142 | typedef gtl::InlinedVector<TensorValue, 4> TensorValueVec; |
143 | typedef gtl::InlinedVector<AllocatorAttributes, 4> AllocatorAttributeVec; |
144 | |
145 | class ExecutorImpl : public Executor { |
146 | public: |
147 | explicit ExecutorImpl(const LocalExecutorParams& p) : immutable_state_(p) {} |
148 | |
149 | Status Initialize(const Graph& graph) { |
150 | TF_RETURN_IF_ERROR(immutable_state_.Initialize(graph)); |
151 | kernel_stats_.Initialize(immutable_state_.graph_view()); |
152 | return OkStatus(); |
153 | } |
154 | |
155 | void RunAsync(const Args& args, DoneCallback done) override; |
156 | |
157 | private: |
158 | template <class PropagatorStateType> |
159 | friend class ExecutorState; |
160 | |
161 | // Stores execution time information about the kernels in an executor's graph. |
162 | class KernelStats { |
163 | public: |
164 | KernelStats() = default; |
165 | |
166 | void Initialize(const GraphView& gview) { |
167 | is_expensive_.resize(gview.num_nodes()); |
168 | cost_estimates_ = |
169 | std::make_unique<std::atomic_uint_fast64_t[]>(gview.num_nodes()); |
170 | for (int32_t i = 0; i < gview.num_nodes(); ++i) { |
171 | if (gview.node(i)) { |
172 | is_expensive_[i] = |
173 | gview.node(i)->kernel && gview.node(i)->kernel->IsExpensive(); |
174 | cost_estimates_[i] = kInitialCostEstimateCycles; |
175 | } |
176 | } |
177 | } |
178 | |
179 | // Returns true iff the given node is considered "expensive". The |
180 | // executor uses this flag to optimize graph execution, for example |
181 | // by "inlining" inexpensive kernels. |
182 | bool IsExpensive(const NodeItem& node) const { |
183 | return is_expensive_[node.node_id] && |
184 | (cost_estimates_[node.node_id].load(std::memory_order_relaxed) > |
185 | kOpIsExpensiveThresholdCycles); |
186 | } |
187 | |
188 | // Returns the value of kernel->IsExpensive(). |
189 | bool HasExpensiveMarker(const NodeItem& node) const { |
190 | return is_expensive_[node.node_id]; |
191 | } |
192 | |
193 | // Updates the dynamic cost estimate, which is used to determine whether the |
194 | // given node is expensive. The new cost estimate is a weighted average of |
195 | // the old cost estimate and the latest cost. We only update cost estimates |
    // for kernels for which IsExpensive() returns true.
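    // For example, with kCostDecay = 10 this is an exponential moving average:
    //   new_estimate = (9 * prev_estimate + elapsed_cycles) / 10.
    // Starting from kInitialCostEstimateCycles (1e8 cycles), a kernel that
    // consistently takes ~1000 cycles needs on the order of 90 updates before
    // its estimate falls below kOpIsExpensiveThresholdCycles (8000) and it is
    // treated as inexpensive.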
197 | void UpdateCostEstimate(const NodeItem& node, uint64 elapsed_cycles) { |
198 | // N.B. Updates to `cost_estimate` are atomic but unlocked. Simultaneous |
199 | // updates may result in one or more updates being ignored. This does not |
200 | // affect correctness but may slow down the update frequency. |
201 | std::atomic_uint_fast64_t& cost_estimate = cost_estimates_[node.node_id]; |
202 | auto prev_estimate = cost_estimate.load(std::memory_order_relaxed); |
203 | |
204 | uint64 new_estimate = |
205 | ((kCostDecay - 1) * prev_estimate + elapsed_cycles) / kCostDecay; |
206 | |
207 | cost_estimate.store(new_estimate, std::memory_order_relaxed); |
208 | } |
209 | |
210 | private: |
211 | // Initial time (in CPU cycles) we expect an operation to take. Used to |
    // determine whether an operation should be placed in a threadpool.
213 | // Operations start out "expensive". |
214 | static constexpr uint64 kInitialCostEstimateCycles = 100 * 1000 * 1000; |
215 | static constexpr uint64 kOpIsExpensiveThresholdCycles = 8000; |
216 | static constexpr uint64 kCostDecay = 10; |
217 | |
218 | std::vector<bool> is_expensive_; |
220 | std::unique_ptr<std::atomic_uint_fast64_t[]> cost_estimates_; |
221 | }; |
222 | |
223 | ImmutableExecutorState immutable_state_; |
224 | KernelStats kernel_stats_; |
225 | |
226 | TF_DISALLOW_COPY_AND_ASSIGN(ExecutorImpl); |
227 | }; |
228 | |
229 | // The state associated with one invocation of ExecutorImpl::Run. |
230 | // |
231 | // ExecutorState dispatches nodes when they become ready, and delegates to an |
// instance of `PropagatorStateType` to keep track of how many predecessors of
// a node are still pending.
234 | // |
235 | // The template argument `class PropagatorStateType` must define the following |
236 | // public members: |
237 | // * A type `TaggedNode`, representing a node to be processed, with public |
238 | // members: |
239 | // * `const NodeItem& get_node_item() const` |
240 | // * `bool get_is_dead() const` |
241 | // * A type `TaggedNodeReadyQueue`, representing a queue of nodes to be |
242 | // processed, with public members (having the same meanings as in an |
243 | // `std::vector<TaggedNode>`): |
244 | // * `void push_back(const TaggedNode& node)` |
245 | // * `TaggedNode front() const` |
246 | // * `void pop_front()` |
247 | // * `bool empty() const` |
248 | // * A type `TaggedNodeSeq`, representing a list of nodes to be scheduled, with |
249 | // public members (having the same meanings as in an |
250 | // `std::vector<TaggedNode>`): |
251 | // * `size_t size() const` |
252 | // * `bool empty() const` |
253 | // * `void clear()` |
254 | // * `const_iterator begin() const` |
255 | // * `const_iterator end() const` |
// * A public constructor, `PropagatorStateType(const ImmutableExecutorState&
//   immutable_state, int64 step_id, bool vlog)`.
258 | // * The following public methods: |
259 | // * `void ActivateRoots(gtl::ArraySlice<const NodeItem*> roots, |
260 | // TaggedNodeSeq* ready)`, which creates `TaggedNode` instances for the |
261 | // nodes in `roots` and adds them to `*ready` |
262 | // * `void PropagateOutputs(const TaggedNode& tagged_node, EntryVector* |
263 | // outputs, TaggedNodeSeq* ready)`, which propagates `outputs` from the |
264 | // given `tagged_node` to the destinations of its output edges, and adds |
265 | // any newly runnable nodes to `*ready` |
266 | // * `Entry* GetInputTensors(const TaggedNode& tagged_node) const`, which |
267 | // returns a pointer to the input tensors for the given `tagged_node` |
268 | // * `FrameAndIter GetFrameAndIter(const TaggedNode& tagged_node) const`, |
269 | // which creates a `FrameAndIter` for the given `tagged_node` |
270 | // * `void DumpState()`, which dumps the dynamic state of the executing graph |
271 | // * `void MaybeMarkStarted(const TaggedNode& tagged_node)`, which records |
272 | // that a node has started |
273 | // * `void MaybeMarkCompleted(const TaggedNode& tagged_node)`, which records |
274 | // that a node has completed |
275 | // |
276 | // See `PropagatorState` in "./propagator_state.h" for an example of a type that |
277 | // can be used to instantiate `PropagatorStateType`. |
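//
// As an illustrative sketch only (a hypothetical type, not part of this file),
// a conforming implementation has roughly this shape:
//
//   class MyPropagatorState {
//    public:
//     struct TaggedNode {
//       const NodeItem& get_node_item() const;
//       bool get_is_dead() const;
//     };
//     class TaggedNodeReadyQueue { /* push_back, front, pop_front, empty */ };
//     using TaggedNodeSeq = std::vector<TaggedNode>;
//
//     MyPropagatorState(const ImmutableExecutorState& immutable_state,
//                       int64_t step_id, bool vlog);
//     void ActivateRoots(gtl::ArraySlice<const NodeItem*> roots,
//                        TaggedNodeSeq* ready);
//     void PropagateOutputs(const TaggedNode& tagged_node,
//                           EntryVector* outputs, TaggedNodeSeq* ready);
//     Entry* GetInputTensors(const TaggedNode& tagged_node) const;
//     FrameAndIter GetFrameAndIter(const TaggedNode& tagged_node) const;
//     void DumpState();
//     void MaybeMarkStarted(const TaggedNode& tagged_node);
//     void MaybeMarkCompleted(const TaggedNode& tagged_node);
//   };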
278 | template <class PropagatorStateType> |
279 | class ExecutorState { |
280 | public: |
  ExecutorState(const Executor::Args& args,
                const ImmutableExecutorState& immutable_state,
                ExecutorImpl::KernelStats* kernel_stats);
284 | ~ExecutorState(); |
285 | |
286 | void RunAsync(Executor::DoneCallback done); |
287 | |
288 | private: |
289 | // Use `TaggedNode` types defined by `PropagatorStateType`. |
290 | typedef typename PropagatorStateType::TaggedNode TaggedNode; |
291 | typedef |
292 | typename PropagatorStateType::TaggedNodeReadyQueue TaggedNodeReadyQueue; |
293 | typedef typename PropagatorStateType::TaggedNodeSeq TaggedNodeSeq; |
294 | |
295 | struct AsyncState; |
296 | |
  // Process a ready node in the current thread.
298 | void Process(TaggedNode node, int64_t scheduled_nsec); |
299 | |
300 | void ProcessInline(TaggedNodeReadyQueue* inline_ready, |
301 | int64_t scheduled_nsec); |
302 | |
303 | Status ProcessSync(const NodeItem& item, OpKernelContext::Params* params, |
304 | EntryVector* outputs, NodeExecStatsInterface* stats); |
305 | void ProcessAsync(const NodeItem& item, const OpKernelContext::Params& params, |
306 | const TaggedNode& tagged_node, Entry* first_input, |
307 | NodeExecStatsInterface* stats, |
308 | activity_watcher::ActivityId activity_id); |
309 | void ProcessNoop(NodeExecStatsInterface* stats); |
310 | void ProcessConstTensor(const NodeItem& item, EntryVector* outputs, |
311 | NodeExecStatsInterface* stats); |
312 | |
313 | // Before invoking item->kernel, fills in its "inputs". |
314 | Status PrepareInputs(const NodeItem& item, Entry* first_input, |
315 | TensorValueVec* inputs, |
316 | AllocatorAttributeVec* input_alloc_attrs, |
317 | bool* is_input_dead); |
318 | |
319 | // After item->kernel computation is done, processes its outputs. |
320 | Status ProcessOutputs(const NodeItem& item, OpKernelContext* ctx, |
321 | Entry* outputs, NodeExecStatsInterface* stats); |
322 | |
323 | // Called after each node finishes. Takes ownership of "stats". Returns true |
324 | // if execution has completed. |
325 | // |
326 | // This method will clear `*ready` before returning. |
327 | bool NodeDone(const Status& s, TaggedNodeSeq* ready, |
328 | NodeExecStatsInterface* stats, |
329 | TaggedNodeReadyQueue* inline_ready); |
330 | |
331 | // Schedule all the expensive nodes in '*ready', and put all the inexpensive |
332 | // nodes in 'ready' into 'inline_ready'. |
333 | // |
334 | // This method will clear `*ready` before returning. |
335 | // |
336 | // REQUIRES: `!ready->empty()`. |
337 | void ScheduleReady(TaggedNodeSeq* ready, TaggedNodeReadyQueue* inline_ready); |
338 | |
339 | // A wrapper for runner_ to keep track of the pending queue length. Op |
340 | // execution should dispatch work using this function instead of using runner_ |
341 | // directly. |
342 | template <typename Closure> |
343 | void RunTask(Closure&& c, int sample_rate = 0); |
344 | |
345 | // Clean up when this executor is done. |
346 | void Finish(); |
347 | void ScheduleFinish(); |
348 | |
349 | // Contains the device context assigned by the device at the beginning of a |
350 | // step. |
351 | DeviceContext* device_context_ = nullptr; |
352 | |
353 | const bool vlog_; // true if VLOG_IS_ON(1). Used to check vlog cheaply. |
354 | |
355 | // true if LogMemory::IsEnabled(). Used to check memory enabled cheaply. |
356 | const bool log_memory_; |
357 | |
358 | int64_t step_id_; |
359 | int64_t start_time_usecs_ = 0; |
360 | // The deadline for the session to complete by. Empty if unspecified. |
361 | absl::optional<absl::Time> deadline_; |
362 | |
363 | // Maximum number of kernels that can be scheduled inline. If lots of kernels |
364 | // are ready at the same time, scheduling them in one thread can be very slow. |
365 | // TODO(fishx): Make it configurable if necessary. |
366 | static constexpr uint64 kInlineScheduleReadyThreshold = 500; |
367 | |
368 | // Not owned. |
369 | RendezvousInterface* rendezvous_; |
370 | CollectiveExecutor* collective_executor_ = nullptr; |
371 | SessionState* session_state_; |
372 | string session_handle_; |
373 | const SessionMetadata* session_metadata_ = nullptr; |
374 | TensorStore* tensor_store_; |
375 | // Step-local container. |
376 | ScopedStepContainer* step_container_; |
377 | StepStatsCollectorInterface* const stats_collector_; |
378 | const tracing::EventCollector* const event_collector_; |
379 | Context context_; |
380 | |
381 | // QUESTION: Make it a checkpoint::TensorSliceReaderCacheWrapper |
382 | // instead of a pointer? (avoids having to delete). |
383 | checkpoint::TensorSliceReaderCacheWrapper* slice_reader_cache_; |
384 | CallFrameInterface* call_frame_; |
385 | const ImmutableExecutorState& immutable_state_; |
386 | ExecutorImpl::KernelStats* const kernel_stats_; |
387 | CancellationManager* cancellation_manager_; |
388 | CoordinationServiceAgent* coordination_service_agent_; |
389 | absl::optional<ManagedStackTrace> stack_trace_ = absl::nullopt; |
  // If not null, use this device to schedule intra-op operations.
391 | std::unique_ptr<DeviceBase> user_device_; |
392 | Executor::Args::Runner runner_; |
393 | bool sync_on_finish_; |
394 | const bool run_all_kernels_inline_; |
395 | |
396 | PropagatorStateType propagator_; |
397 | |
398 | // Invoked when the execution finishes. |
399 | Executor::DoneCallback done_cb_; |
400 | |
401 | std::atomic_int_fast32_t num_outstanding_ops_; |
402 | |
403 | // Available via OpKernelContext to every OpKernel invocation. |
404 | mutex num_deferred_ops_mu_; |
405 | int64_t num_deferred_ops_ TF_GUARDED_BY(num_deferred_ops_mu_) = 0; |
406 | bool finish_when_deferred_ops_done_ TF_GUARDED_BY(num_deferred_ops_mu_) = |
407 | false; |
408 | |
409 | mutex mu_; |
410 | Status status_ TF_GUARDED_BY(mu_); |
411 | }; |
412 | |
413 | template <class PropagatorStateType> |
414 | ExecutorState<PropagatorStateType>::ExecutorState( |
415 | const Executor::Args& args, const ImmutableExecutorState& immutable_state, |
416 | ExecutorImpl::KernelStats* kernel_stats) |
417 | : vlog_(VLOG_IS_ON(1)), |
418 | log_memory_(LogMemory::IsEnabled()), |
419 | step_id_(args.step_id), |
420 | start_time_usecs_(args.start_time_usecs), |
421 | deadline_(args.deadline), |
422 | rendezvous_(args.rendezvous), |
423 | collective_executor_(args.collective_executor), |
424 | session_state_(args.session_state), |
425 | session_handle_(args.session_handle), |
426 | session_metadata_(immutable_state.params().session_metadata), |
427 | tensor_store_(args.tensor_store), |
428 | step_container_(args.step_container), |
429 | stats_collector_(args.stats_collector), |
430 | event_collector_( |
431 | tracing::GetEventCollector(tracing::EventCategory::kCompute)), |
432 | context_(ContextKind::kThread), |
433 | slice_reader_cache_(new checkpoint::TensorSliceReaderCacheWrapper), |
434 | call_frame_(args.call_frame), |
435 | immutable_state_(immutable_state), |
436 | kernel_stats_(kernel_stats), |
437 | cancellation_manager_(args.cancellation_manager), |
438 | coordination_service_agent_(args.coordination_service_agent), |
439 | stack_trace_(args.stack_trace), |
440 | runner_(args.runner), |
441 | sync_on_finish_(args.sync_on_finish), |
442 | run_all_kernels_inline_(args.run_all_kernels_inline), |
443 | propagator_(immutable_state, step_id_, vlog_), |
444 | num_outstanding_ops_(0) { |
445 | if (args.user_intra_op_threadpool != nullptr) { |
446 | Device* device = immutable_state_.params().device; |
447 | user_device_ = RenamedDevice::NewRenamedDevice( |
448 | device->name(), device, false, false, args.user_intra_op_threadpool); |
449 | } |
450 | } |
451 | |
452 | template <class PropagatorStateType> |
453 | ExecutorState<PropagatorStateType>::~ExecutorState() { |
454 | if (device_context_) { |
455 | device_context_->Unref(); |
456 | } |
457 | delete slice_reader_cache_; |
458 | } |
459 | |
460 | template <class PropagatorStateType> |
461 | template <typename Closure> |
462 | void ExecutorState<PropagatorStateType>::RunTask(Closure&& c, int sample_rate) { |
463 | // Align the atomic variables at 64 bytes to avoid false-sharing, assuming the |
464 | // cacheline size is 64 bytes or smaller. |
465 | alignas(64) static std::atomic<int64_t> num_enqueue_ops{0}; |
466 | alignas(64) static std::atomic<int64_t> num_dequeue_ops{0}; |
467 | |
468 | auto n_enqueues = num_enqueue_ops.fetch_add(1, std::memory_order_relaxed); |
  // Sample the queue length at most once in every max(16, sample_rate)
  // enqueue operations. This amortizes the cost of metric updates across at
  // least 16 operations.
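  // For example, ScheduleReady passes the number of ready nodes as
  // `sample_rate`, so when 500 nodes are dispatched at once the metric is
  // updated once per 500 enqueues rather than once per 16.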
471 | if (n_enqueues % std::max(16, sample_rate) == 0) { |
472 | auto n_dequeues = num_dequeue_ops.load(std::memory_order_relaxed); |
473 | metrics::UpdateGraphPendingQueueLength(n_enqueues - n_dequeues); |
474 | } |
475 | |
476 | // mutable is needed because std::forward<Closure> in the lambda body may move |
477 | // the Closure `c`. |
478 | runner_([c = std::forward<Closure>(c)]() mutable { |
479 | num_dequeue_ops.fetch_add(1, std::memory_order_relaxed); |
480 | std::forward<Closure>(c)(); |
481 | }); |
482 | } |
483 | |
484 | template <class PropagatorStateType> |
485 | void ExecutorState<PropagatorStateType>::RunAsync(Executor::DoneCallback done) { |
486 | TaggedNodeSeq ready; |
487 | |
488 | // Ask the device to fill in the device context map. |
489 | Device* device = immutable_state_.params().device; |
490 | const Status get_context_status = |
491 | device->TryGetDeviceContext(&device_context_); |
492 | if (!get_context_status.ok()) { |
493 | delete this; |
494 | done(get_context_status); |
495 | return; |
496 | } |
497 | |
498 | // Initialize the ready queue. |
499 | ready.reserve(immutable_state_.root_nodes().size()); |
500 | propagator_.ActivateRoots(immutable_state_.root_nodes(), &ready); |
501 | num_outstanding_ops_ = ready.size(); |
502 | if (ready.empty()) { |
503 | delete this; |
504 | done(OkStatus()); |
505 | } else { |
506 | done_cb_ = std::move(done); |
507 | // Schedule to run all the ready ops in thread pool. |
508 | ScheduleReady(&ready, nullptr); |
509 | } |
510 | } |
511 | |
// State kept alive for executing an asynchronous node in another
// thread. NOTE: We need to make a copy of p.inputs and p.input_alloc_attrs for
// asynchronous kernels because OpKernelContext methods like input_type(i) need
// the params to point to valid input vectors. This is not an issue for
// sync kernels because these vectors are kept on the stack.
517 | template <class PropagatorStateType> |
518 | struct ExecutorState<PropagatorStateType>::AsyncState { |
519 | AsyncState(const OpKernelContext::Params& p, const TaggedNode& _tagged_node, |
520 | const NodeItem* _item, Entry* _first_input, |
521 | NodeExecStatsInterface* _stats) |
522 | : saved_inputs(p.inputs.begin(), p.inputs.end()), |
523 | saved_input_alloc_attrs(p.input_alloc_attrs.begin(), |
524 | p.input_alloc_attrs.end()), |
525 | params(p), |
526 | tagged_node(_tagged_node), |
527 | item(_item), |
528 | first_input(_first_input), |
        // ParamsButClearingEigenGPUDevice does the equivalent of
        //   params.eigen_gpu_device = nullptr;
531 | ctx(ParamsButClearingEigenGPUDevice(¶ms), item->num_outputs), |
532 | stats(_stats) { |
533 | params.inputs = saved_inputs; |
534 | params.input_alloc_attrs = saved_input_alloc_attrs; |
535 | } |
536 | |
537 | TensorValueVec saved_inputs; |
538 | AllocatorAttributeVec saved_input_alloc_attrs; |
539 | OpKernelContext::Params params; |
540 | TaggedNode tagged_node; |
541 | const NodeItem* item; |
542 | Entry* first_input; |
543 | OpKernelContext ctx; |
544 | NodeExecStatsInterface* stats; |
545 | |
546 | private: |
547 | OpKernelContext::Params* ParamsButClearingEigenGPUDevice( |
548 | OpKernelContext::Params* p) { |
549 | // Ensure OpKernelContext constructor will make a new eigen GPU device if |
550 | // necessary. |
551 | p->eigen_gpu_device = nullptr; // Force allocation |
552 | return p; |
553 | } |
554 | }; |
555 | |
// Returns true if the current kernel invocation might be traced by the given
// event collector or the profiler. Returns false only if it definitely will
// not be traced.
558 | bool MightTrace(const tracing::EventCollector* event_collector, |
559 | bool is_expensive) { |
  // Tracing will only be enabled if either `event_collector` is non-null, or
  // profiler annotation/tracing is enabled for this particular kernel.
562 | // Although `profiler::TraceMe`, `profiler::ScopedAnnotation`, and |
563 | // `tracing::ScopedRegion` check subsets of these properties internally in |
564 | // their constructors, the cost of passing the necessary arguments to them can |
565 | // be significant, so we avoid constructing them in the common case (when we |
566 | // know they will not be used). |
567 | if (event_collector != nullptr) { |
568 | return true; |
569 | } |
570 | |
571 | if (profiler::ScopedAnnotation::IsEnabled()) return true; |
572 | |
573 | return profiler::TraceMe::Active(profiler::GetTFTraceMeLevel(is_expensive)); |
574 | } |
575 | |
576 | template <class PropagatorStateType> |
577 | Status ExecutorState<PropagatorStateType>::ProcessSync( |
578 | const NodeItem& item, OpKernelContext::Params* params, EntryVector* outputs, |
579 | NodeExecStatsInterface* stats) { |
580 | Status s; |
581 | OpKernelContext ctx(params, item.num_outputs); |
582 | nodestats::SetOpStart(stats); |
583 | |
584 | OpKernel* op_kernel = item.kernel; |
585 | Device* device = immutable_state_.params().device; |
586 | const bool is_expensive = kernel_stats_->IsExpensive(item); |
587 | |
588 | if (TF_PREDICT_FALSE(MightTrace(event_collector_, is_expensive))) { |
589 | tracing::ScopedRegion region(tracing::EventCategory::kCompute, |
590 | op_kernel->name_view()); |
591 | profiler::AnnotatedTraceMe activity( |
592 | [op_kernel, &ctx] { |
593 | return op_kernel->TraceString( |
594 | ctx, /*verbose=*/profiler::TfOpDetailsEnabled()); |
595 | }, |
596 | profiler::GetTFTraceMeLevel(is_expensive)); |
597 | device->Compute(op_kernel, &ctx); |
598 | } else if (kernel_stats_->HasExpensiveMarker(item)) { |
599 | KernelTimer timer; |
600 | device->Compute(op_kernel, &ctx); |
601 | // For expensive kernels, always update the cost estimate. For inexpensive |
602 | // kernels, update the cost estimate with ~1/16 probability. This assumes |
    // that the last 4 bits of the CPU cycle count are uniformly distributed.
604 | constexpr int kKernelExecutionTrackingInvocationSkipCount = 16; |
605 | if (is_expensive || |
606 | timer.start_cycles % kKernelExecutionTrackingInvocationSkipCount == 0) { |
607 | kernel_stats_->UpdateCostEstimate(item, timer.ElapsedCycles()); |
608 | } |
609 | } else { |
610 | device->Compute(op_kernel, &ctx); |
611 | } |
612 | nodestats::SetOpEnd(stats); |
613 | if (outputs->size() < item.num_outputs) outputs->resize(item.num_outputs); |
614 | s = ProcessOutputs(item, &ctx, outputs->data(), stats); |
615 | nodestats::SetMemory(stats, &ctx); |
616 | return s; |
617 | } |
618 | |
619 | template <class PropagatorStateType> |
620 | void ExecutorState<PropagatorStateType>::ProcessAsync( |
621 | const NodeItem& item, const OpKernelContext::Params& params, |
622 | const TaggedNode& tagged_node, Entry* first_input, |
623 | NodeExecStatsInterface* stats, activity_watcher::ActivityId activity_id) { |
624 | AsyncOpKernel* async_kernel = item.kernel->AsAsync(); |
625 | DCHECK(async_kernel != nullptr); |
626 | AsyncState* state = |
627 | new AsyncState(params, tagged_node, &item, first_input, stats); |
628 | |
629 | auto done = [this, state, activity_id]() { |
630 | Device* device = immutable_state_.params().device; |
631 | NodeExecStatsInterface* stats = state->stats; // Shorthand |
632 | Entry* first_input = state->first_input; // Shorthand |
633 | |
634 | nodestats::SetOpEnd(stats); |
635 | EntryVector outputs(state->item->num_outputs); |
636 | Status s = ProcessOutputs(*state->item, &state->ctx, outputs.data(), stats); |
637 | nodestats::SetMemory(stats, &state->ctx); |
638 | if (vlog_) { |
639 | VLOG(2) << "Async kernel done: " << state->item->node_id << " step " |
640 | << step_id_ << " " << SummarizeNodeDef(state->item->kernel->def()) |
              << (state->tagged_node.get_is_dead() ? " is dead" : "")
642 | << " device: " << device->name(); |
643 | } |
644 | |
645 | // Clears inputs. |
646 | const int num_inputs = state->item->num_inputs; |
647 | for (int i = 0; i < num_inputs; ++i) { |
648 | (first_input + i)->ClearVal(); |
649 | } |
650 | propagator_.MaybeMarkCompleted(state->tagged_node); |
651 | activity_watcher::ActivityEnd(activity_id); |
652 | TaggedNodeSeq ready; |
653 | if (s.ok()) { |
654 | propagator_.PropagateOutputs(state->tagged_node, &outputs, &ready); |
655 | } |
656 | outputs.clear(); |
657 | const bool completed = NodeDone(s, &ready, stats, nullptr); |
658 | delete state; |
659 | if (completed) ScheduleFinish(); |
660 | }; |
661 | nodestats::SetOpStart(stats); |
662 | { |
663 | profiler::AnnotatedTraceMe activity( |
664 | [async_kernel, state] { |
665 | return async_kernel->TraceString( |
666 | state->ctx, /*verbose=*/profiler::TfOpDetailsEnabled()); |
667 | }, |
668 | profiler::GetTFTraceMeLevel(kernel_stats_->IsExpensive(item))); |
669 | immutable_state_.params().device->ComputeAsync(async_kernel, &state->ctx, |
670 | std::move(done)); |
671 | } |
672 | } |
673 | |
674 | template <class PropagatorStateType> |
675 | void ExecutorState<PropagatorStateType>::ProcessNoop( |
676 | NodeExecStatsInterface* stats) { |
677 | nodestats::SetOpStart(stats); |
678 | nodestats::SetOpEnd(stats); |
679 | } |
680 | |
681 | template <class PropagatorStateType> |
682 | void ExecutorState<PropagatorStateType>::ProcessConstTensor( |
683 | const NodeItem& item, EntryVector* outputs, NodeExecStatsInterface* stats) { |
684 | nodestats::SetOpStart(stats); |
685 | nodestats::SetOpEnd(stats); |
686 | Entry& output = (*outputs)[0]; |
687 | output.state = Entry::State::HAS_CONST_TENSOR; |
688 | output.const_tensor = item.const_tensor; |
689 | output.alloc_attr = item.output_attrs()[0]; |
690 | } |
691 | |
692 | template <class PropagatorStateType> |
693 | void ExecutorState<PropagatorStateType>::Process(TaggedNode tagged_node, |
694 | int64_t scheduled_nsec) { |
  profiler::TraceMe traceme("ExecutorState::Process Scheduled",
696 | profiler::TraceMeLevel::kVerbose); |
697 | TaggedNodeReadyQueue inline_ready; |
698 | inline_ready.push_back(tagged_node); |
699 | return ProcessInline(&inline_ready, scheduled_nsec); |
700 | } |
701 | |
702 | template <class PropagatorStateType> |
703 | void ExecutorState<PropagatorStateType>::ProcessInline( |
704 | TaggedNodeReadyQueue* inline_ready, int64_t scheduled_nsec) { |
705 | WithContext wc(context_); |
706 | TaggedNodeSeq ready; |
707 | |
708 | // Parameters passed to OpKernel::Compute. |
709 | TensorValueVec inputs; |
710 | AllocatorAttributeVec input_alloc_attrs; |
711 | |
712 | OpKernelContext::Params params; |
713 | params.step_id = step_id_; |
714 | // Override device's threadpool if user provides an intra_op_threadpool |
715 | Device* device = immutable_state_.params().device; |
716 | if (user_device_) { |
717 | params.device = user_device_.get(); |
718 | } else { |
719 | params.device = device; |
720 | } |
721 | params.start_time_usecs = start_time_usecs_; |
722 | params.deadline = deadline_; |
723 | params.log_memory = log_memory_; |
724 | params.rendezvous = rendezvous_; |
725 | params.collective_executor = collective_executor_; |
726 | params.session_state = session_state_; |
727 | params.session_handle = session_handle_; |
728 | params.session_metadata = session_metadata_; |
729 | params.tensor_store = tensor_store_; |
730 | params.cancellation_manager = cancellation_manager_; |
731 | params.coordination_service_agent = coordination_service_agent_; |
732 | params.stack_trace = stack_trace_; |
733 | params.call_frame = call_frame_; |
734 | params.function_library = immutable_state_.params().function_library; |
735 | params.resource_manager = device->resource_manager(); |
736 | params.step_container = step_container_; |
737 | params.slice_reader_cache = slice_reader_cache_; |
738 | params.runner = &runner_; |
739 | params.run_all_kernels_inline = run_all_kernels_inline_; |
740 | params.stats_collector = stats_collector_; |
741 | params.inc_num_deferred_ops_function = [this]() { |
742 | mutex_lock lock(num_deferred_ops_mu_); |
743 | num_deferred_ops_++; |
744 | }; |
745 | params.dec_num_deferred_ops_function = [this]() { |
746 | bool finish_when_deferred_ops_done = false; |
747 | { |
748 | mutex_lock lock(num_deferred_ops_mu_); |
749 | num_deferred_ops_--; |
750 | if (num_deferred_ops_ == 0) { |
751 | finish_when_deferred_ops_done = finish_when_deferred_ops_done_; |
752 | } |
753 | } |
754 | // Invoke Finish if the graph processing has completed. Finish is always |
755 | // called exactly once per ExecutorState, either here if there are any |
756 | // deferred ops, or in ScheduleFinish if there aren't any deferred ops. |
757 | if (finish_when_deferred_ops_done) Finish(); |
758 | }; |
759 | |
760 | // Set the device_context for this device, if it exists. |
761 | params.op_device_context = device_context_; |
762 | |
763 | Status s; |
764 | NodeExecStatsInterface* stats = nullptr; |
765 | |
766 | EntryVector outputs(1); |
767 | |
768 | bool completed = false; |
769 | int64_t last_iter_num = -1; |
770 | std::unique_ptr<profiler::TraceMeConsumer> iteration_scope; |
771 | while (!inline_ready->empty()) { |
772 | TaggedNode tagged_node = inline_ready->front(); |
773 | |
774 | int64_t current_iter_num = tagged_node.get_iter_num(); |
775 | if (current_iter_num != last_iter_num) { |
776 | iteration_scope = std::make_unique<profiler::TraceMeConsumer>( |
777 | // From TraceMeProducer in DirectSession::RunInternal, |
778 | // GraphMgr::ExecuteAsync, or FunctionLibraryRuntime::Run. |
779 | [&] { |
780 | // NOTE: This tracing uses the iteration number from the first |
781 | // tagged node that executes during this call to `Process()`. In |
782 | // principle, subsequent nodes could have different values of |
783 | // `iter_num` that will not be traced. |
784 | return profiler::TraceMeEncode( |
785 | "ExecutorState::Process" , |
786 | {{"id" , step_id_}, {"iter_num" , tagged_node.get_iter_num()}}); |
787 | }, |
788 | profiler::ContextType::kTfExecutor, step_id_, |
789 | profiler::TraceMeLevel::kInfo); |
790 | last_iter_num = current_iter_num; |
791 | } |
792 | inline_ready->pop_front(); |
793 | const NodeItem& item = tagged_node.get_node_item(); |
794 | const int id = item.node_id; |
795 | |
796 | propagator_.MaybeMarkStarted(tagged_node); |
797 | const activity_watcher::ActivityId activity_id = |
798 | activity_watcher::ActivityStart( |
799 | [&]() { |
800 | return std::make_unique<activity_watcher::Activity>( |
801 | "ExecutorState::Process" , |
802 | activity_watcher::ActivityCategory::kMisc, |
803 | activity_watcher::Activity::Attributes{ |
804 | {"node_name" , item.kernel->def().name()}, |
805 | {"op" , item.kernel->def().op()}, |
806 | {"iter_num" , absl::StrCat(tagged_node.get_iter_num())}, |
807 | {"step_id" , absl::StrCat(params.step_id)}, |
808 | {"node_id" , absl::StrCat(id)}, |
809 | {"device" , device->name()}, |
810 | }); |
811 | }, |
812 | /*level=*/2); |
813 | |
814 | params.track_allocations = false; |
815 | stats = nullptr; |
816 | if (stats_collector_ && !tagged_node.get_is_dead()) { |
817 | stats = stats_collector_->CreateNodeExecStats(&item.kernel->def()); |
      // Track allocations if and only if we are collecting statistics, and
      // the `stats` object is expecting allocations to be tracked.
820 | params.track_allocations = stats ? stats->TrackAllocations() : false; |
821 | nodestats::SetScheduled(stats, scheduled_nsec); |
822 | nodestats::SetAllStart(stats); |
823 | } |
824 | |
825 | if (vlog_) { |
826 | VLOG(1) << "Process node: " << id << " step " << params.step_id << " " |
827 | << SummarizeNodeDef(item.kernel->def()) |
              << (tagged_node.get_is_dead() ? " is dead" : "")
829 | << " device: " << device->name(); |
830 | } |
831 | |
832 | Entry* first_input = propagator_.GetInputTensors(tagged_node); |
833 | |
834 | // Only execute this node if it is not dead or it is a send/recv |
835 | // transfer node. For transfer nodes, we need to propagate the "dead" |
836 | // bit even when the node is dead. |
837 | bool launched_asynchronously = false; |
838 | if (tagged_node.get_is_dead() && !item.is_transfer_node) { |
839 | if (outputs.size() < item.num_outputs) outputs.resize(item.num_outputs); |
840 | } else if (TF_PREDICT_FALSE(item.is_noop)) { |
841 | ProcessNoop(stats); |
842 | } else if (item.const_tensor != nullptr && !params.track_allocations) { |
843 | ProcessConstTensor(item, &outputs, stats); |
844 | } else { |
845 | // Prepares inputs. |
846 | bool is_input_dead = false; |
847 | s = PrepareInputs(item, first_input, &inputs, &input_alloc_attrs, |
848 | &is_input_dead); |
849 | if (!s.ok()) { |
850 | // Clear inputs. |
851 | const int num_inputs = item.num_inputs; |
852 | for (int i = 0; i < num_inputs; ++i) { |
853 | (first_input + i)->ClearVal(); |
854 | } |
855 | propagator_.MaybeMarkCompleted(tagged_node); |
856 | activity_watcher::ActivityEnd(activity_id); |
857 | // Continue to process the nodes in 'inline_ready'. |
858 | completed = NodeDone(s, &ready, stats, inline_ready); |
859 | continue; |
860 | } |
861 | |
862 | // Set up compute params. |
863 | params.op_kernel = item.kernel; |
864 | params.frame_iter = propagator_.GetFrameAndIter(tagged_node); |
865 | params.is_input_dead = is_input_dead; |
866 | params.output_attr_array = item.output_attrs(); |
867 | params.forward_from_array = item.forward_from(); |
868 | params.outputs_required_array = item.outputs_required.get(); |
869 | params.inputs = inputs; |
870 | params.input_alloc_attrs = input_alloc_attrs; |
871 | |
872 | if (item.kernel_is_async) { |
873 | ProcessAsync(item, params, tagged_node, first_input, stats, |
874 | activity_id); |
875 | launched_asynchronously = true; |
876 | } else { |
877 | s = ProcessSync(item, ¶ms, &outputs, stats); |
878 | } |
879 | } |
880 | |
881 | if (!launched_asynchronously) { |
882 | if (vlog_) { |
883 | VLOG(2) << "Synchronous kernel done: " << id << " step " |
884 | << params.step_id << " " << SummarizeNodeDef(item.kernel->def()) |
                << (tagged_node.get_is_dead() ? " is dead: " : "")
886 | << " device: " << device->name(); |
887 | } |
888 | |
889 | // Clears inputs. |
890 | const int num_inputs = item.num_inputs; |
891 | for (int i = 0; i < num_inputs; ++i) { |
892 | (first_input + i)->ClearVal(); |
893 | } |
894 | propagator_.MaybeMarkCompleted(tagged_node); |
895 | activity_watcher::ActivityEnd(activity_id); |
896 | // Propagates outputs. |
897 | if (s.ok()) { |
898 | propagator_.PropagateOutputs(tagged_node, &outputs, &ready); |
899 | } |
900 | |
901 | // Clear outputs without deallocating the `outputs` vector. |
902 | const int num_outputs = item.num_outputs; |
903 | for (int i = 0; i < num_outputs; ++i) { |
904 | outputs[i].ClearVal(); |
905 | } |
906 | |
907 | if (stats) { |
908 | scheduled_nsec = nodestats::NowInNsec(); |
909 | } |
910 | // Postprocess. |
911 | completed = NodeDone(s, &ready, stats, inline_ready); |
912 | } |
913 | } // while !inline_ready.empty() |
914 | |
915 | // This thread of computation is done if completed = true. |
916 | if (completed) ScheduleFinish(); |
917 | } |
918 | |
919 | template <class PropagatorStateType> |
920 | Status ExecutorState<PropagatorStateType>::PrepareInputs( |
921 | const NodeItem& item, Entry* first_input, TensorValueVec* inputs, |
922 | AllocatorAttributeVec* input_alloc_attrs, bool* is_input_dead) { |
923 | inputs->resize(item.num_inputs); |
924 | input_alloc_attrs->resize(item.num_inputs); |
925 | |
926 | *is_input_dead = false; |
927 | |
928 | for (int i = 0; i < item.num_inputs; ++i) { |
929 | const bool expect_ref = TF_PREDICT_FALSE(item.is_any_input_ref_typed) && |
930 | IsRefType(item.input_type(i)); |
931 | Entry* entry = first_input + i; |
932 | (*input_alloc_attrs)[i] = entry->alloc_attr; |
933 | |
934 | // i-th input. |
935 | TensorValue* inp = &(*inputs)[i]; |
936 | |
937 | switch (entry->state) { |
938 | case Entry::State::NO_VALUE: { |
939 | // Only merge and transfer nodes can have no-value inputs. |
940 | inp->mutex_if_ref = nullptr; |
941 | if (item.is_merge) { |
942 | inp->tensor = nullptr; |
943 | } else { |
944 | DCHECK(item.is_transfer_node) |
945 | << item.kernel->name() << " - input " << i; |
946 | entry->state = Entry::State::HAS_CONST_TENSOR; |
947 | entry->const_tensor = kEmptyTensor; |
948 | // NOTE(mrry): This `const_cast` is necessary because `TensorValue` |
949 | // stores a non-const `Tensor*`, and relies on the `OpKernelContext` |
950 | // accessors making dynamic checks that prevent using an immutable |
951 | // tensor as a mutable tensor. |
952 | inp->tensor = const_cast<Tensor*>(kEmptyTensor); |
953 | *is_input_dead = true; |
954 | } |
955 | break; |
956 | } |
957 | |
958 | case Entry::State::HAS_VALUE: { |
959 | if (TF_PREDICT_FALSE(expect_ref)) { |
960 | return AttachDef( |
              errors::InvalidArgument(i, "-th input expects a ref type"),
962 | item.kernel->def()); |
963 | } |
964 | inp->mutex_if_ref = nullptr; |
965 | inp->tensor = entry->val.get(); |
966 | break; |
967 | } |
968 | |
969 | case Entry::State::HAS_CONST_TENSOR: { |
970 | if (TF_PREDICT_FALSE(expect_ref)) { |
971 | return AttachDef( |
              errors::InvalidArgument(i, "-th input expects a ref type"),
973 | item.kernel->def()); |
974 | } |
975 | // NOTE(mrry): This `const_cast` is necessary because `TensorValue` |
976 | // stores a non-const `Tensor*`, and relies on the `OpKernelContext` |
977 | // accessors making dynamic checks that prevent using an immutable |
978 | // tensor as a mutable tensor. |
979 | inp->mutex_if_ref = nullptr; |
980 | inp->tensor = const_cast<Tensor*>(entry->const_tensor); |
981 | break; |
982 | } |
983 | |
984 | case Entry::State::HAS_REF_TENSOR: { |
985 | { |
986 | tf_shared_lock ml(*entry->ref_tensor.mu); |
987 | if (TF_PREDICT_FALSE(!entry->ref_tensor.tensor->IsInitialized() && |
988 | !item.is_initialization_op)) { |
989 | return AttachDef(errors::FailedPrecondition( |
990 | "Attempting to use uninitialized value " , |
991 | item.kernel->requested_input(i)), |
992 | item.kernel->def()); |
993 | } |
994 | } |
995 | |
996 | if (expect_ref) { |
997 | inp->mutex_if_ref = entry->ref_tensor.mu; |
998 | inp->tensor = entry->ref_tensor.tensor; |
999 | } else { |
1000 | // Automatically deref the tensor ref when the op expects a |
1001 | // tensor but is given a ref to a tensor. Need to deref it |
1002 | // under the mutex. |
1003 | { |
1004 | mutex* ref_mu = entry->ref_tensor.mu; |
1005 | Tensor* ref_tensor = entry->ref_tensor.tensor; |
1006 | tf_shared_lock l(*ref_mu); |
1007 | entry->val.Init(*ref_tensor); |
1008 | } |
1009 | entry->state = Entry::State::HAS_VALUE; |
1010 | |
1011 | inp->mutex_if_ref = nullptr; |
1012 | inp->tensor = entry->val.get(); |
1013 | // The dtype of entry->ref_tensor.tensor could have been changed by |
1014 | // another operation that ran after the operation that "produced" it |
1015 | // executed, so re-validate that the type of the dereferenced tensor |
1016 | // matches the expected input type. |
1017 | if (TF_PREDICT_FALSE(item.input_type(i) != inp->tensor->dtype())) { |
1018 | return AttachDef( |
1019 | errors::InvalidArgument( |
1020 | i, "-th input expects type " , |
1021 | DataTypeString(item.input_type(i)), |
1022 | " but automatically dereferenced input tensor has type " , |
1023 | DataTypeString(inp->tensor->dtype())), |
1024 | item.kernel->def()); |
1025 | } |
1026 | } |
1027 | break; |
1028 | } |
1029 | } |
1030 | } |
1031 | return OkStatus(); |
1032 | } |
1033 | |
1034 | template <class PropagatorStateType> |
1035 | Status ExecutorState<PropagatorStateType>::ProcessOutputs( |
1036 | const NodeItem& item, OpKernelContext* ctx, Entry* outputs, |
1037 | NodeExecStatsInterface* stats) { |
1038 | Status s = ctx->status(); |
1039 | if (!s.ok()) { |
1040 | s = AttachDef(s, item.kernel->def()); |
1041 | // TODO(misard) Replace with a finer-grain enabling flag once we |
1042 | // add better optional debugging support. |
1043 | if (vlog_ && VLOG_IS_ON(1)) { |
1044 | LOG(WARNING) << this << " Compute status: " << s; |
1045 | } |
1046 | if (s.code() == error::RESOURCE_EXHAUSTED) { |
1047 | if (stats_collector_) { |
1048 | string err = stats_collector_->ReportAllocsOnResourceExhausted( |
1049 | s.error_message()); |
1050 | s = errors::CreateWithUpdatedMessage( |
1051 | s, strings::StrCat(s.error_message(), err)); |
1052 | } else { |
1053 | s = errors::CreateWithUpdatedMessage( |
1054 | s, |
1055 | strings::StrCat( |
1056 | s.error_message(), |
1057 | "\nHint: If you want to see a list of allocated tensors when " |
1058 | "OOM happens, add report_tensor_allocations_upon_oom " |
1059 | "to RunOptions for current allocation info. This isn't " |
1060 | "available when running in Eager mode.\n" )); |
1061 | } |
1062 | } else if (s.code() == error::UNAVAILABLE && |
1063 | !item.is_distributed_communication) { |
1064 | s = errors::ReplaceErrorFromNonCommunicationOps(s, item.kernel->name()); |
1065 | } |
1066 | return s; |
1067 | } |
1068 | |
1069 | for (int i = 0; i < item.num_outputs; ++i) { |
1070 | const TensorValue val = ctx->release_output(i); |
1071 | Entry* out = &outputs[i]; |
1072 | DCHECK(out->state == Entry::State::NO_VALUE); |
1073 | |
1074 | if (val.tensor == nullptr) { |
1075 | // Unless it's a Switch or a Recv, or the executor has marked the output |
1076 | // as not required, the node must produce a tensor value at i-th output. |
1077 | if (!(item.is_recv_or_switch || |
1078 | (item.outputs_required && !item.outputs_required[i]))) { |
1079 | s.Update(errors::Internal("Missing " , i, "-th output from " , |
1080 | FormatNodeDefForError(item.kernel->def()))); |
1081 | } |
1082 | } else { |
1083 | // Set the allocator attributes of the output entry. |
1084 | out->alloc_attr = ctx->output_alloc_attr(i); |
1085 | |
      // Sanity check of output tensor types. We need to inspect this safely,
      // since the tensor buffer may be uninitialized, so use `dtype_safe()`
      // rather than reading the dtype directly.
1088 | DataType dtype = val.dtype_safe(); |
1089 | if (dtype == item.output_type(i)) { |
1090 | if (stats && val.tensor->IsInitialized()) { |
1091 | nodestats::SetOutput(stats, i, val.tensor); |
1092 | } |
1093 | if (val.is_ref()) { |
1094 | out->state = Entry::State::HAS_REF_TENSOR; |
1095 | out->ref_tensor.tensor = val.tensor; |
1096 | out->ref_tensor.mu = val.mutex_if_ref; |
1097 | if (log_memory_) { |
1098 | Tensor to_log; |
1099 | { |
1100 | // Dereference the tensor under the lock. |
1101 | tf_shared_lock l(*out->ref_tensor.mu); |
1102 | to_log = *out->ref_tensor.tensor; |
1103 | } |
1104 | LogMemory::RecordTensorOutput(ctx->op_kernel().name(), |
1105 | ctx->step_id(), i, to_log); |
1106 | } |
1107 | } else { |
          // NOTE that std::move is used here, so val.tensor goes to
          // uninitialized state (val.tensor->IsInitialized returns false).
1110 | out->state = Entry::State::HAS_VALUE; |
1111 | out->val.Init(std::move(*val.tensor)); |
1112 | if (log_memory_) { |
1113 | LogMemory::RecordTensorOutput(ctx->op_kernel().name(), |
1114 | ctx->step_id(), i, *out->val); |
1115 | } |
1116 | } |
1117 | } else { |
1118 | s.Update( |
1119 | errors::Internal("Output " , i, " of type " , DataTypeString(dtype), |
1120 | " does not match declared output type " , |
1121 | DataTypeString(item.output_type(i)), " for node " , |
1122 | FormatNodeDefForError(item.kernel->def()))); |
1123 | } |
1124 | } |
1125 | if (!val.is_ref()) { |
      // If OpKernelContext returned outputs via pass-by-value, this manual
      // delete would be unnecessary.
1128 | delete val.tensor; |
1129 | } |
1130 | } |
1131 | return s; |
1132 | } |
1133 | |
1134 | template <class PropagatorStateType> |
1135 | bool ExecutorState<PropagatorStateType>::NodeDone( |
1136 | const Status& s, TaggedNodeSeq* ready, NodeExecStatsInterface* stats, |
1137 | TaggedNodeReadyQueue* inline_ready) { |
1138 | if (stats) { |
1139 | nodestats::SetAllEnd(stats); |
1140 | DCHECK_NE(stats_collector_, nullptr); |
1141 | stats->Done(immutable_state_.params().device->name()); |
1142 | } |
1143 | |
1144 | if (TF_PREDICT_TRUE(s.ok())) { |
1145 | const size_t ready_size = ready->size(); |
1146 | if (ready_size == 0) { |
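      // fetch_sub() returns the value prior to the decrement, so a result of 1
      // means this was the last outstanding op in the step.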
1147 | return num_outstanding_ops_.fetch_sub(1) == 1; |
1148 | } else { |
1149 | // NOTE: Avoid touching the atomic counter if only one node becomes ready. |
1150 | if (ready_size > 1) { |
1151 | num_outstanding_ops_.fetch_add(ready_size - 1, |
1152 | std::memory_order_relaxed); |
1153 | } |
1154 | |
1155 | // Schedule the ready nodes in 'ready'. |
1156 | ScheduleReady(ready, inline_ready); |
1157 | |
1158 | return false; |
1159 | } |
1160 | } else { |
1161 | bool abort_run = false; |
1162 | Status maybe_derived_s(s); |
1163 | |
1164 | // Some error happened. This thread of computation is done. |
1165 | { |
1166 | mutex_lock l(mu_); |
1167 | if (status_.ok()) { |
1168 | // If this is the first node to fail in this run, we are responsible for |
1169 | // aborting all other execution in the step. |
1170 | abort_run = true; |
1171 | |
1172 | // If execution has been cancelled, mark cancelled or aborted errors as |
1173 | // being derived. Note that the original node that fails might also |
1174 | // trigger cancellation, and here we make sure the original error is |
1175 | // exposed to users and not buried as a derived error. |
1176 | if (cancellation_manager_ && cancellation_manager_->IsCancelled() && |
1177 | (errors::IsCancelled(s) || errors::IsAborted(s))) { |
1178 | status_ = StatusGroup::MakeDerived(s); |
1179 | maybe_derived_s = status_; |
1180 | } else { |
1181 | status_ = s; |
1182 | } |
1183 | } |
1184 | } |
1185 | |
1186 | if (abort_run) { |
1187 | TRACEPRINTF("StartAbort: %s" , s.ToString().c_str()); |
1188 | if (cancellation_manager_) { |
1189 | // Only log when the abort happens during the actual run time. |
1190 | // Use VLOG instead of LOG(warning) because error status is expected |
1191 | // when the executor is run under the grappler optimization phase or |
1192 | // when iterating through a tf.data input pipeline. |
1193 | VLOG(1) << "[" << immutable_state_.params().device->name() |
1194 | << "] Executor start aborting: " << s; |
1195 | } |
1196 | |
1197 | if (rendezvous_) { |
1198 | rendezvous_->StartAbort(s); |
1199 | } |
1200 | if (cancellation_manager_) { |
1201 | cancellation_manager_->StartCancelWithStatus(maybe_derived_s); |
1202 | } else if (collective_executor_) { |
        // If there's cancellation_manager_, collective ops abort
        // collective_executor_ upon cancellation; otherwise we need to abort
        // here.
1206 | collective_executor_->StartAbort(s); |
1207 | } |
1208 | } |
1209 | |
1210 | return num_outstanding_ops_.fetch_sub(1) == 1; |
1211 | } |
1212 | } |
1213 | |
1214 | template <class PropagatorStateType> |
1215 | void ExecutorState<PropagatorStateType>::ScheduleReady( |
1216 | TaggedNodeSeq* ready, TaggedNodeReadyQueue* inline_ready) { |
1217 | profiler::TraceMe activity( |
1218 | [&]() { |
        return strings::StrCat(
            "ExecutorState::ScheduleReady#",
            "ready_size=", (ready == nullptr ? -1 : ready->size()),
            ",inline_ready_size=",
            (inline_ready == nullptr ? -1 : inline_ready->size()), "#");
1224 | }, |
1225 | profiler::GetTFTraceMeLevel(/*is_expensive=*/false)); |
1226 | DCHECK(!ready->empty()); |
1227 | |
1228 | int64_t scheduled_nsec = 0; |
1229 | if (stats_collector_) { |
1230 | scheduled_nsec = nodestats::NowInNsec(); |
1231 | } |
1232 | |
1233 | if (run_all_kernels_inline_) { |
1234 | if (inline_ready == nullptr) { |
      // Schedule all ready kernels from a single closure. This ensures that,
1236 | // regardless of the `runner_` implementation, all kernels will run |
1237 | // sequentially on the same thread, and thread wakeup overhead and |
1238 | // executor mutex contention will be minimized. |
1239 | RunTask([this, ready = std::move(*ready), scheduled_nsec]() { |
1240 | for (auto& tagged_node : ready) { |
1241 | Process(tagged_node, scheduled_nsec); |
1242 | } |
1243 | }); |
1244 | } else { |
1245 | for (auto& tagged_node : *ready) { |
1246 | inline_ready->push_back(tagged_node); |
1247 | } |
1248 | } |
1249 | } else { |
1250 | const TaggedNode* curr_expensive_node = nullptr; |
1251 | TaggedNodeSeq expensive_nodes; |
1252 | if (inline_ready == nullptr) { |
1253 | // Schedule to run all the ready ops in thread pool. |
1254 | for (auto& tagged_node : *ready) { |
1255 | RunTask([=]() { Process(tagged_node, scheduled_nsec); }, |
1256 | /*sample_rate=*/ready->size()); |
1257 | } |
1258 | } else { |
1259 | for (auto& tagged_node : *ready) { |
1260 | const NodeItem& item = *tagged_node.node_item; |
1261 | if (tagged_node.get_is_dead() || !kernel_stats_->IsExpensive(item)) { |
1262 | // Inline this inexpensive node. |
1263 | inline_ready->push_back(tagged_node); |
1264 | } else { |
1265 | if (curr_expensive_node) { |
1266 | expensive_nodes.push_back(*curr_expensive_node); |
1267 | } |
1268 | curr_expensive_node = &tagged_node; |
1269 | } |
1270 | } |
1271 | } |
1272 | if (curr_expensive_node) { |
1273 | if (inline_ready->empty()) { |
1274 | inline_ready->push_back(*curr_expensive_node); |
1275 | } else { |
1276 | // There are inline nodes to run already. We dispatch this expensive |
1277 | // node to other thread. |
1278 | expensive_nodes.push_back(*curr_expensive_node); |
1279 | } |
1280 | } |
1281 | if (!expensive_nodes.empty()) { |
1282 | if (expensive_nodes.size() < kInlineScheduleReadyThreshold) { |
1283 | for (auto& tagged_node : expensive_nodes) { |
1284 | RunTask(std::bind(&ExecutorState::Process, this, tagged_node, |
1285 | scheduled_nsec), |
1286 | /*sample_rate=*/expensive_nodes.size()); |
1287 | } |
1288 | } else { |
1289 | // There are too many ready expensive nodes. Schedule them in child |
1290 | // threads. |
1291 | // TODO(fishx): Apply the same optimization to cheap ops as well since |
1292 | // executing lots of cheap ops in one thread can potentially be the |
1293 | // bottleneck as well. |
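        // For example, with 1200 ready expensive nodes and a threshold of 500,
        // three closures covering 500, 500, and 200 nodes are enqueued; each
        // closure then re-dispatches its chunk's nodes one at a time via
        // RunTask.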
1294 | auto it = expensive_nodes.begin(); |
1295 | while (it < expensive_nodes.end()) { |
1296 | auto end = it; |
1297 | std::advance(end, kInlineScheduleReadyThreshold); |
1298 | if (end > expensive_nodes.end()) { |
1299 | end = expensive_nodes.end(); |
1300 | } |
1301 | TaggedNodeSeq ready_chunk{it, end}; |
1302 | RunTask( |
1303 | [this, ready_chunk = std::move(ready_chunk), scheduled_nsec]() { |
1304 | profiler::TraceMe activity( |
1305 | [&]() { |
                          return strings::StrCat(
                              "ExecutorState::ScheduleReady::"
                              "ChildThreadExpensiveNodes#",
                              "ready_chunk_size=", ready_chunk.size(), "#");
1310 | }, |
1311 | profiler::GetTFTraceMeLevel(/*is_expensive=*/false)); |
1312 | for (auto& tagged_node : ready_chunk) { |
1313 | RunTask(std::bind(&ExecutorState::Process, this, tagged_node, |
1314 | scheduled_nsec), |
1315 | /*sample_rate=*/ready_chunk.size()); |
1316 | } |
1317 | }); |
1318 | it = end; |
1319 | } |
1320 | } |
1321 | } |
1322 | } |
1323 | ready->clear(); |
1324 | } |
1325 | |
1326 | template <class PropagatorStateType> |
1327 | void ExecutorState<PropagatorStateType>::ScheduleFinish() { |
  // Checks the conditions to decide whether Finish() should be invoked. If
  // there are in-flight deferred ops, wait until `num_deferred_ops_` reaches 0
  // to invoke Finish(). Otherwise, invoke Finish() directly.
1331 | // Note that it is critical that the ScheduleFinish / Finish codepath does not |
1332 | // block, otherwise we might deadlock. See b/124523000 for details. |
1333 | { |
1334 | mutex_lock lock(num_deferred_ops_mu_); |
1335 | if (num_deferred_ops_ > 0) { |
1336 | finish_when_deferred_ops_done_ = true; |
1337 | return; |
1338 | } |
1339 | } |
1340 | // Finish is always called exactly once per ExecutorState, either here if |
1341 | // there aren't any deferred ops, or in the dec_num_deferred_ops_function if |
1342 | // there are deferred ops. |
1343 | Finish(); |
1344 | } |
1345 | |
1346 | template <class PropagatorStateType> |
1347 | void ExecutorState<PropagatorStateType>::Finish() { |
1348 | mu_.lock(); |
1349 | auto status = status_; |
1350 | auto done_cb = std::move(done_cb_); |
1351 | auto runner = std::move(runner_); |
1352 | mu_.unlock(); |
1353 | int64_t step_id = step_id_; |
1354 | CHECK(done_cb != nullptr); |
1355 | Device* device = immutable_state_.params().device; |
1356 | |
1357 | if (vlog_ && !status.ok() && VLOG_IS_ON(1)) { |
1358 | // Logs verbose information about the current state of active and pending |
1359 | // nodes in the propagator. |
1360 | propagator_.DumpState(); |
1361 | } |
1362 | |
1363 | // There are several potential race conditions below. To name a few: |
1364 | // 1. Even if the device's status is OK at the precise moment when |
1365 | // num_deferred_ops_ reaches 0, it could go bad before device->RefreshStatus() |
1366 | // is called below, caused by work enqueued onto the same device by other |
1367 | // concurrent ExecutorState objects. |
  // 2. Some implementations of Device::RefreshStatus, such as
  // XlaDevice::RefreshStatus, may be inherently racy because they release the
  // device mutex after a stream pointer is acquired and before the stream is
  // queried for status.
  // 3. The same applies to some implementations of Device::Sync, such as
  // XlaDevice::Sync.
1374 | // |
1375 | // However, these race conditions are acceptable because a stream (and |
1376 | // therefore an XlaDevice) can only go from OK to not-OK, never the opposite, |
  // which means we will at worst report an error when there isn't one, never
  // miss a real one.
1379 | |
  // An early exit for devices that don't allow sync on completion. Ops that
  // run on these devices should have used num_deferred_ops correctly to
  // ensure the device has finished all relevant work at this point.
1383 | if (!device->AllowsSyncOnCompletion()) { |
1384 | status.Update(device->RefreshStatus()); |
1385 | if (!status.ok()) { |
1386 | // In device async execution mode, it's possible for device execution to |
1387 | // lag behind ExecutorState scheduling so much that this is the first |
1388 | // place a device execution error surfaces. |
1389 | // If so, all ExecutorState::NodeDone calls have already happened with OK |
1390 | // status. This is the last defense where StartCancel must be called to |
1391 | // abort all computation still running on any device. |
1392 | // TODO(b/124523000): Always call Finish in a separate thread, so even if |
1393 | // StartCancel blocks the current thread's execution, we won't encounter |
1394 | // deadlocks caused by inter-op thread exhaustion. |
1395 | if (rendezvous_) { |
1396 | rendezvous_->StartAbort(status); |
1397 | } |
1398 | if (cancellation_manager_) { |
1399 | cancellation_manager_->StartCancelWithStatus(status); |
1400 | } else if (collective_executor_) { |
        // If there's a cancellation_manager_, collective ops abort
        // collective_executor_ upon cancellation; otherwise we need to abort
        // it here.
1404 | collective_executor_->StartAbort(status); |
1405 | } |
1406 | } |
1407 | delete this; |
1408 | runner([step_id, status, done_cb = std::move(done_cb)]() { |
1409 | profiler::TraceMeConsumer activity( |
1410 | // From TraceMeProducer in KernelAndDeviceFunc::RunAsync, |
1411 | // DirectSession::RunInternal or GraphMgr::ExecuteAsync. |
1412 | [&] { |
            return profiler::TraceMeEncode("ExecutorDoneCallback",
                                           {{"id", step_id}});
1415 | }, |
1416 | profiler::ContextType::kTfExecutor, step_id, |
1417 | profiler::TraceMeLevel::kInfo); |
1418 | done_cb(status); |
1419 | }); |
1420 | return; |
1421 | } |
1422 | |
1423 | if (sync_on_finish_ && status.ok()) { |
1424 | // Block until the device has finished all queued operations. For |
1425 | // devices like GPUs that continue to execute Ops after their Compute |
1426 | // methods have completed, this ensures that control is not returned to |
1427 | // the user until the step (and its side-effects) has actually completed. |
1428 | device->Sync([this, step_id, runner = std::move(runner), |
1429 | done_cb = std::move(done_cb)](const Status& status) mutable { |
1430 | delete this; |
1431 | runner([step_id, status, done_cb = std::move(done_cb)]() { |
1432 | profiler::TraceMeConsumer activity( |
1433 | // From TraceMeProducer in KernelAndDeviceFunc::RunAsync, |
1434 | // DirectSession::RunInternal or GraphMgr::ExecuteAsync. |
1435 | [&] { |
              return profiler::TraceMeEncode("ExecutorDoneCallback",
                                             {{"id", step_id}});
1438 | }, |
1439 | profiler::ContextType::kTfExecutor, step_id, |
1440 | profiler::TraceMeLevel::kInfo); |
1441 | done_cb(status); |
1442 | }); |
1443 | }); |
1444 | } else { |
1445 | delete this; |
1446 | runner([step_id, status, done_cb = std::move(done_cb)]() { |
1447 | profiler::TraceMeConsumer activity( |
1448 | // From TraceMeProducer in KernelAndDeviceFunc::RunAsync, |
1449 | // DirectSession::RunInternal or GraphMgr::ExecuteAsync. |
1450 | [&] { |
            return profiler::TraceMeEncode("ExecutorDoneCallback",
                                           {{"id", step_id}});
1453 | }, |
1454 | profiler::ContextType::kTfExecutor, step_id, |
1455 | profiler::TraceMeLevel::kInfo); |
1456 | done_cb(status); |
1457 | }); |
1458 | } |
1459 | } |
1460 | |
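// Dispatches to a propagator implementation for this step:
// OrderedPropagatorState when deterministic op ordering is required, the full
// PropagatorState when the graph needs control-flow support, and the cheaper
// SimplePropagatorState otherwise. Each ExecutorState deletes itself at the
// end of Finish(), so no ownership is retained here.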
1461 | void ExecutorImpl::RunAsync(const Args& args, DoneCallback done) { |
1462 | if (OpOrderDeterminismRequired()) { |
1463 | (new ExecutorState<OrderedPropagatorState>(args, immutable_state_, |
1464 | &kernel_stats_)) |
1465 | ->RunAsync(std::move(done)); |
1466 | } else if (immutable_state_.requires_control_flow_support()) { |
1467 | (new ExecutorState<PropagatorState>(args, immutable_state_, &kernel_stats_)) |
1468 | ->RunAsync(std::move(done)); |
1469 | } else { |
1470 | (new ExecutorState<SimplePropagatorState>(args, immutable_state_, |
1471 | &kernel_stats_)) |
1472 | ->RunAsync(std::move(done)); |
1473 | } |
1474 | } |
1475 | |
1476 | } // namespace |
1477 | |
1478 | Status NewLocalExecutor(const LocalExecutorParams& params, const Graph& graph, |
1479 | Executor** executor) { |
1480 | ExecutorImpl* impl = new ExecutorImpl(params); |
1481 | const Status s = impl->Initialize(graph); |
1482 | if (s.ok()) { |
1483 | *executor = impl; |
1484 | } else { |
1485 | delete impl; |
1486 | } |
1487 | return s; |
1488 | } |
1489 | |
1490 | Status CreateNonCachedKernel(Device* device, FunctionLibraryRuntime* flib, |
1491 | const std::shared_ptr<const NodeProperties>& props, |
1492 | int graph_def_version, OpKernel** kernel) { |
1493 | const auto device_type = DeviceType(device->attributes().device_type()); |
1494 | auto allocator = device->GetAllocator(AllocatorAttributes()); |
1495 | return CreateOpKernel(device_type, device, allocator, flib, |
1496 | device->resource_manager(), props, graph_def_version, |
1497 | kernel); |
1498 | } |
1499 | |
1500 | void DeleteNonCachedKernel(OpKernel* kernel) { delete kernel; } |
1501 | |
1502 | namespace { |
1503 | |
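// Registers the default executor factory under both the empty string and
// "DEFAULT", so callers that do not request a specific executor_type get
// this implementation.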
1504 | class DefaultExecutorRegistrar { |
1505 | public: |
1506 | DefaultExecutorRegistrar() { |
1507 | Factory* factory = new Factory; |
1508 | ExecutorFactory::Register("" , factory); |
1509 | ExecutorFactory::Register("DEFAULT" , factory); |
1510 | } |
1511 | |
1512 | private: |
1513 | class Factory : public ExecutorFactory { |
1514 | Status NewExecutor(const LocalExecutorParams& params, const Graph& graph, |
1515 | std::unique_ptr<Executor>* out_executor) override { |
1516 | Executor* ret = nullptr; |
      // `graph` is a const reference, so moving it would be a no-op.
      TF_RETURN_IF_ERROR(NewLocalExecutor(params, graph, &ret));
1518 | out_executor->reset(ret); |
1519 | return OkStatus(); |
1520 | } |
1521 | }; |
1522 | }; |
1523 | static DefaultExecutorRegistrar registrar; |
1524 | |
1525 | } // namespace |
1526 | |
1527 | } // namespace tensorflow |
1528 | |