1 | /* Copyright 2015 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | #ifndef TENSORFLOW_CORE_COMMON_RUNTIME_PROPAGATOR_STATE_H_ |
16 | #define TENSORFLOW_CORE_COMMON_RUNTIME_PROPAGATOR_STATE_H_ |
17 | |
18 | #include <queue> |
19 | #include <vector> |
20 | |
21 | #include "tensorflow/core/common_runtime/entry.h" |
22 | #include "tensorflow/core/common_runtime/immutable_executor_state.h" |
23 | #include "tensorflow/core/common_runtime/pending_counts.h" |
24 | #include "tensorflow/core/framework/allocator.h" |
25 | #include "tensorflow/core/framework/control_flow.h" |
26 | #include "tensorflow/core/lib/gtl/flatmap.h" |
27 | #include "tensorflow/core/lib/gtl/inlined_vector.h" |
28 | #include "tensorflow/core/platform/env.h" |
29 | #include "tensorflow/core/platform/logging.h" |
30 | #include "tensorflow/core/platform/macros.h" |
31 | #include "tensorflow/core/platform/mutex.h" |
32 | #include "tensorflow/core/platform/thread_annotations.h" |
33 | #include "tensorflow/core/platform/types.h" |
34 | |
35 | namespace tensorflow { |
36 | |
37 | typedef gtl::InlinedVector<AllocatorAttributes, 4> AllocatorAttributeVec; |
38 | |
39 | // Represents the ephemeral "edge state" associated with one invocation of |
40 | // `Executor::Run()`. |
41 | // |
42 | // `PropagatorState` is responsible for propagating values along dataflow |
43 | // edges in a TensorFlow graph and determining which nodes are runnable. The |
44 | // executor primarily updates `PropagatorState` by calling `PropagateOutputs()` |
45 | // after processing a node, and `PropagatorState` dispatches `TaggedNode`s by |
46 | // adding them to a `TaggedNodeSeq`. |
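//
// A minimal sketch of how an executor might drive this class; `RunKernel()`
// is hypothetical and stands in for the executor's kernel-execution logic:
//
//   PropagatorState propagator(immutable_state, step_id, /*vlog=*/false);
//   PropagatorState::TaggedNodeSeq ready;
//   propagator.ActivateRoots(roots, &ready);
//   while (!ready.empty()) {
//     PropagatorState::TaggedNode tagged_node = ready.back();
//     ready.pop_back();
//     EntryVector outputs;
//     RunKernel(tagged_node, propagator.GetInputTensors(tagged_node),
//               &outputs);
//     propagator.PropagateOutputs(tagged_node, &outputs, &ready);
//   }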
47 | class PropagatorState { |
48 | public: |
49 | PropagatorState(const ImmutableExecutorState& immutable_state, |
50 | int64_t step_id, bool vlog); |
51 | ~PropagatorState(); |
52 | |
53 | private: |
  // Forward declarations so that `TaggedNode` can include a `FrameState*` and
  // an `IterationState*`.
56 | struct FrameState; |
57 | struct IterationState; |
58 | |
59 | public: |
60 | // A `TaggedNode` corresponds to a single invocation of a node's kernel, |
61 | // and it is created when the kernel becomes runnable (in a particular |
62 | // iteration of a particular frame). |
63 | struct TaggedNode { |
64 | const NodeItem* node_item; |
65 | FrameState* input_frame; |
66 | IterationState* input_iter; |
67 | bool is_dead; |
68 | |
69 | TaggedNode() = default; |
70 | TaggedNode(const NodeItem* node_item, FrameState* in_frame, |
71 | IterationState* in_iter, bool dead) |
72 | : node_item(node_item), |
73 | input_frame(in_frame), |
74 | input_iter(in_iter), |
75 | is_dead(dead) {} |
76 | |
77 | const NodeItem& get_node_item() const { return *node_item; } |
78 | |
79 | bool get_is_dead() const { return is_dead; } |
80 | int64_t get_iter_num() const; |
81 | }; |
82 | |
83 | // A drop-in replacement for std::deque<TaggedNode>. We typically don't |
84 | // have that many nodes in the ready queue, so we just use a vector and |
85 | // don't free up memory from the queue as we consume nodes. |
86 | class TaggedNodeReadyQueue { |
87 | public: |
88 | TaggedNodeReadyQueue() : front_index_(0) {} |
89 | |
90 | void push_back(const TaggedNode& node) { ready_.push_back(node); } |
91 | TaggedNode front() const { |
92 | DCHECK_LT(front_index_, ready_.size()); |
93 | return ready_[front_index_]; |
94 | } |
95 | void pop_front() { |
96 | DCHECK_LT(front_index_, ready_.size()); |
97 | front_index_++; |
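      // Reclaim memory when the queue has fully drained, or compact it when
      // too many consumed entries have accumulated at the front.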
98 | if ((front_index_ == ready_.size()) || (front_index_ > kSpillThreshold)) { |
99 | if (front_index_ == ready_.size()) { |
100 | ready_.clear(); |
101 | } else { |
102 | // Lots of unused entries at beginning of vector: move everything |
103 | // down to start of vector. |
104 | ready_.erase(ready_.begin(), ready_.begin() + front_index_); |
105 | } |
106 | front_index_ = 0; |
107 | } |
108 | } |
109 | bool empty() const { return ready_.empty(); } |
110 | int size() const { return ready_.size() - front_index_; } |
111 | |
112 | private: |
113 | // TODO(b/152925936): Re-evaluate these constants with current usage |
114 | // patterns. |
115 | static constexpr int kSpillThreshold = 16384; |
116 | gtl::InlinedVector<TaggedNode, 16> ready_; |
117 | int front_index_; |
118 | }; |
119 | |
120 | // TODO(b/152925936): Re-evaluate this constant with current usage patterns. |
121 | typedef gtl::InlinedVector<TaggedNode, 8> TaggedNodeSeq; |
122 | |
123 | private: |
124 | // The state of an iteration in a particular frame. |
125 | struct IterationState { |
126 | explicit IterationState(int64_t iter_num, |
127 | const PendingCounts* pending_counts, |
128 | int total_input_tensors) |
129 | : iter_num(iter_num), |
130 | input_tensors(new Entry[total_input_tensors]), |
131 | outstanding_ops(0), |
132 | outstanding_frame_count(0), |
133 | counts(*pending_counts) { // Initialize with copy of *pending_counts |
134 | } |
135 | |
    // The index of this iteration in the enclosing loop.
    const int64_t iter_num;
138 | |
139 | // One copy per iteration. For iteration k, i-th node's j-th input is in |
140 | // input_tensors[k][immutable_state_.nodes[i].input_start + j]. An entry is |
141 | // either a tensor pointer (pass-by-reference) or a tensor (pass-by-value). |
142 | // |
    // NOTE: No need to protect input_tensors[i] with any locks because the
    // array is allocated once. Each element of input_tensors is written once
    // by the source node of an edge and is cleared by the destination node of
    // the same edge. The latter node is never run concurrently with the
    // former node.
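    //
    // For example, a node `item` with two inputs that runs in this iteration
    // reads them from the two consecutive entries
    // input_tensors[item.input_start] and input_tensors[item.input_start + 1].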
147 | Entry* input_tensors; |
148 | |
    // The number of ops that are still outstanding in this iteration.
    std::atomic<size_t> outstanding_ops;

    // The number of frames that are still outstanding in this iteration.
    int outstanding_frame_count;

    int pending(PendingCounts::Handle h) { return counts.pending(h); }
155 | int decrement_pending(PendingCounts::Handle h, int v) { |
156 | return counts.decrement_pending(h, v); |
157 | } |
    // Mark a merge node as live.
159 | // REQUIRES: Node corresponding to "h" is a merge node |
160 | void mark_live(PendingCounts::Handle h) { counts.mark_live(h); } |
161 | // Mark a node to show that processing has started. |
162 | void mark_started(PendingCounts::Handle h) { counts.mark_started(h); } |
163 | // Mark a node to show that processing has completed. |
164 | void mark_completed(PendingCounts::Handle h) { counts.mark_completed(h); } |
165 | PendingCounts::NodeState node_state(PendingCounts::Handle h) { |
166 | return counts.node_state(h); |
167 | } |
168 | |
169 | int dead_count(PendingCounts::Handle h) { return counts.dead_count(h); } |
170 | void increment_dead_count(PendingCounts::Handle h) { |
171 | counts.increment_dead_count(h); |
172 | } |
173 | // REQUIRES: Node corresponding to "h" is a merge node |
174 | PendingCounts::AdjustResult adjust_for_mark_live(PendingCounts::Handle h) { |
175 | return counts.adjust_for_mark_live(h); |
176 | } |
177 | // REQUIRES: Node corresponding to "h" is a merge node |
178 | PendingCounts::AdjustResult adjust_for_mark_live_atomic( |
179 | PendingCounts::Handle h) { |
180 | return counts.adjust_for_mark_live_atomic(h); |
181 | } |
182 | PendingCounts::AdjustResult adjust_for_decrement_pending( |
183 | PendingCounts::Handle h, int decrement_pending) { |
184 | return counts.adjust_for_decrement_pending(h, decrement_pending); |
185 | } |
186 | PendingCounts::AdjustResult adjust_for_decrement_pending_atomic( |
187 | PendingCounts::Handle h, int decrement_pending) { |
188 | return counts.adjust_for_decrement_pending_atomic(h, decrement_pending); |
189 | } |
190 | PendingCounts::AdjustResult adjust_for_increment_dead( |
191 | PendingCounts::Handle h) { |
192 | return counts.adjust_for_increment_dead(h); |
193 | } |
194 | PendingCounts::AdjustResult adjust_for_increment_dead_atomic( |
195 | PendingCounts::Handle h) { |
196 | return counts.adjust_for_increment_dead_atomic(h); |
197 | } |
198 | PendingCounts::AdjustResult adjust_for_activation(PendingCounts::Handle h, |
199 | bool increment_dead) { |
200 | return counts.adjust_for_activation(h, increment_dead); |
201 | } |
202 | PendingCounts::AdjustResult adjust_for_activation_atomic( |
203 | PendingCounts::Handle h, bool increment_dead) { |
204 | return counts.adjust_for_activation_atomic(h, increment_dead); |
205 | } |
206 | |
207 | ~IterationState() { delete[] input_tensors; } |
208 | |
209 | private: |
210 | PendingCounts counts; |
211 | }; |
212 | |
213 | struct FrameState { |
214 | explicit FrameState(const ImmutableExecutorState& immutable_state, |
215 | int parallel_iters) |
216 | : immutable_state(immutable_state), |
217 | max_parallel_iterations(parallel_iters), |
218 | num_outstanding_iterations(1), |
219 | iterations(parallel_iters + 1), |
220 | iterations_raw(iterations.data()) {} |
221 | |
222 | // A new frame is created for each loop. Execution starts at iteration 0. |
223 | // When a value at iteration 0 passes through a NextIteration node, |
224 | // iteration 1 is created and starts running. Note that iteration 0 may |
225 | // still be running so multiple iterations may run in parallel. The |
226 | // frame maintains the state of iterations in several data structures |
    // such as pending_counts and input_tensors. When iteration 0 completes,
228 | // we garbage collect the state of iteration 0. |
229 | // |
230 | // A frame instance is considered "done" and can be garbage collected |
231 | // if all its inputs have entered and all its iterations are "done". |
232 | // |
233 | // A frame manages the live iterations of an iterative computation. |
234 | // Iteration i is considered "done" when there are no outstanding ops, |
235 | // frames at iteration i are done, all recvs for this iteration are |
236 | // completed, and iteration i-1 is done. For iteration 0, we instead |
237 | // wait for there to be no more pending inputs of the frame. |
238 | // |
239 | // Frames and iterations are garbage collected once they are done. |
240 | // The state we need to keep around is highly dependent on the |
241 | // parallelism enabled by the scheduler. We may want to have the |
242 | // scheduler dynamically control the outstanding number of live |
243 | // parallel frames and iterations. To reduce the state space, the |
244 | // scheduler might want to schedule ops in inner frames first and |
245 | // lower iterations first. |
246 | // |
247 | // This frame state is mostly initialized lazily on demand so we |
248 | // don't introduce unnecessary overhead. |
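    //
    // For example, a loop built with tf.while_loop lowers to roughly
    // Enter -> Merge -> Switch -> (loop body) -> NextIteration -> Merge,
    // with Exit leaving the frame: entering the loop finds or creates one
    // FrameState, and each value passing through NextIteration may start a
    // new IterationState, up to max_parallel_iterations at a time.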
249 | |
250 | // The immutable state of the executor the frame is in. |
251 | const ImmutableExecutorState& immutable_state; |
252 | |
253 | // The name of this frame, which is the concatenation of its parent |
254 | // frame name, the iteration of the parent frame when this frame was |
255 | // created, and the value of the attr 'frame_name'. |
256 | string frame_name; |
257 | |
258 | // The unique id for this frame. Generated by fingerprinting |
259 | // frame_name. |
260 | uint64 frame_id; |
261 | |
262 | // The iteration state of its parent frame when this frame is created. |
263 | // nullptr if there is no parent frame. The frame_name/parent_iter pair |
264 | // uniquely identifies this FrameState. |
265 | IterationState* parent_iter = nullptr; |
266 | |
267 | // The FrameState of its parent frame. |
268 | FrameState* parent_frame = nullptr; |
269 | |
270 | // The maximum allowed number of parallel iterations. |
271 | const int max_parallel_iterations; |
272 | |
    // The number of inputs this frame is still waiting for.
274 | int num_pending_inputs = 0; |
275 | |
276 | // The highest iteration number we have reached so far in this frame. |
277 | int64_t iteration_count TF_GUARDED_BY(mu) = 0; |
278 | |
279 | // The number of outstanding iterations. |
280 | int num_outstanding_iterations TF_GUARDED_BY(mu) = 1; |
281 | |
282 | private: |
283 | // The active iteration states of this frame. |
284 | gtl::InlinedVector<IterationState*, 12> iterations; |
285 | IterationState** const iterations_raw TF_GUARDED_BY(mu); |
286 | IterationState* iterations_first TF_GUARDED_BY(mu); |
287 | |
288 | public: |
289 | // The NextIteration nodes to enter a new iteration. If the number of |
290 | // outstanding iterations reaches the limit, we will defer the start of |
291 | // the next iteration until the number of outstanding iterations falls |
292 | // below the limit. |
293 | std::vector<std::pair<const NodeItem*, Entry>> next_iter_roots |
294 | TF_GUARDED_BY(mu); |
295 | |
296 | // The values of the loop invariants for this loop. They are added into |
297 | // this list as they "enter" the frame. When a loop invariant enters, |
298 | // we make it available to all active iterations. When the frame starts |
299 | // a new iteration, we make all the current loop invariants available |
300 | // to the new iteration. |
301 | std::vector<std::pair<const NodeItem*, Entry>> inv_values |
302 | TF_GUARDED_BY(iter_mu); |
303 | |
304 | // The list of dead exit node items for the current highest iteration. We |
305 | // will only "execute" the dead exits of the final iteration. |
306 | std::vector<const NodeItem*> dead_exits TF_GUARDED_BY(iter_mu); |
307 | |
308 | // Static information specific to this frame. |
309 | PendingCounts* pending_counts = nullptr; |
310 | int total_input_tensors = 0; |
311 | std::vector<const NodeItem*>* nodes = nullptr; |
312 | |
313 | // Lock ordering: ExecutorState.mu_ < mu < iter_mu; |
314 | // during structured traversal: parent_frame->mu < mu. |
315 | mutex mu; |
316 | |
    // This mutex should only be held when entering the next iteration.
318 | mutex iter_mu; |
319 | |
320 | void InitializeFrameInfo(const ImmutableExecutorState::FrameInfo& finfo); |
321 | |
322 | inline IterationState* GetIteration(int64_t iter) |
323 | TF_SHARED_LOCKS_REQUIRED(mu) { |
324 | if (TF_PREDICT_TRUE(iter == 0)) { |
325 | return iterations_first; |
326 | } else { |
327 | size_t index = iter % (max_parallel_iterations + 1); |
328 | return iterations_raw[index]; |
329 | } |
330 | } |
331 | |
332 | void SetIteration(int64_t iter, IterationState* state); |
333 | |
    // Adjust the outstanding op count by 'delta' and clean up the iterations
    // in the frame if no more ops are outstanding. Returns true iff the
    // execution of the frame is done.
337 | // |
338 | // Avoids acquiring the lock in the common case that the frame is not done. |
339 | bool AdjustOutstandingOps(IterationState* iter_state, int delta, |
340 | TaggedNodeSeq* ready); |
341 | |
342 | bool AdjustOutstandingOpsLocked(IterationState* iter_state, int delta, |
343 | TaggedNodeSeq* ready) |
344 | TF_EXCLUSIVE_LOCKS_REQUIRED(mu); |
345 | |
346 | bool AdjustOutstandingOpsFastPath(IterationState* iter_state, int delta) |
347 | TF_SHARED_LOCKS_REQUIRED(mu); |
348 | |
349 | // Convenience methods for the above 'Adjust' calls where delta takes the |
350 | // common value of -1. |
351 | bool DecrementOutstandingOps(IterationState* iter_state, |
352 | TaggedNodeSeq* ready); |
353 | |
354 | bool DecrementOutstandingOpsLocked(IterationState* iter_state, |
355 | TaggedNodeSeq* ready); |
356 | |
357 | // Returns true if the computation in the frame is completed. |
358 | bool IsFrameDone(); |
359 | |
360 | // Returns true if the iteration of the frame is completed. |
361 | bool IsIterationDone(IterationState* iter_state) |
362 | TF_SHARED_LOCKS_REQUIRED(mu); |
363 | |
364 | // Increments the iteration id. If this is a new iteration, initialize it. |
365 | // |
366 | // Returns a pointer to the new iteration. |
367 | IterationState* IncrementIteration(TaggedNodeSeq* ready) |
368 | TF_EXCLUSIVE_LOCKS_REQUIRED(mu); |
369 | |
370 | // Activate all the deferred NextIteration nodes in a new iteration. |
371 | void ActivateNexts(IterationState* iter_state, TaggedNodeSeq* ready) |
372 | TF_EXCLUSIVE_LOCKS_REQUIRED(mu); |
373 | |
374 | // Activate all the current loop invariants in a new iteration. |
375 | void ActivateLoopInvs(IterationState* iter_state, TaggedNodeSeq* ready) |
376 | TF_EXCLUSIVE_LOCKS_REQUIRED(mu); |
377 | |
378 | // Add a new loop invariant and make it available to all active |
379 | // iterations. |
380 | void AddLoopInv(const NodeItem* item, const Entry& entry, |
381 | TaggedNodeSeq* ready) TF_EXCLUSIVE_LOCKS_REQUIRED(mu); |
382 | |
383 | // Activate the successors of a node. Contents of *outputs are left in an |
384 | // indeterminate state after returning from this method. |
385 | // |
386 | // In the case that 'item' is a simple node (no merge/control outputs) this |
387 | // will acquire a shared lock and can run concurrently with other |
388 | // invocations. |
389 | // |
390 | // Return true if the frame is done after activation. |
391 | bool ActivateNodesAndAdjustOutstanding( |
392 | const NodeItem* item, const bool is_dead, IterationState* iter_state, |
393 | EntryVector* outputs, TaggedNodeSeq* ready, int decrement_activation); |
394 | |
395 | // Same as the above, but requires 'mu' already held in exclusive mode. |
396 | int ActivateNodesLocked(const NodeItem* item, const bool is_dead, |
397 | IterationState* iter_state, EntryVector* outputs, |
398 | TaggedNodeSeq* ready) |
399 | TF_EXCLUSIVE_LOCKS_REQUIRED(mu); |
400 | |
401 | // Cleanup iterations of this frame starting from the given iteration. |
402 | bool CleanupIterations(IterationState* iter_state, TaggedNodeSeq* ready) |
403 | TF_EXCLUSIVE_LOCKS_REQUIRED(mu); |
404 | |
405 | void DumpIterationState(PropagatorState* parent) { |
406 | mutex_lock l(mu); |
407 | for (IterationState* iteration : iterations) { |
408 | if (iteration) { |
          LOG(WARNING) << " Iteration:";
410 | parent->DumpIterationState(this, iteration); |
411 | } |
412 | } |
413 | } |
414 | |
415 | ~FrameState() { |
416 | for (size_t i = 0; i < iterations.size(); ++i) { |
417 | delete iterations[i]; |
418 | iterations[i] = nullptr; |
419 | } |
420 | } |
421 | |
422 | private: |
423 | // REQUIRES: `!item->is_any_consumer_merge_or_control_trigger`. |
424 | // This variant does not use atomic operations to modify the pending counts |
425 | // and thus must hold the exclusive lock. |
426 | int ActivateNodesFastPathLocked(const NodeItem* item, const bool is_dead, |
427 | IterationState* iter_state, |
428 | EntryVector* outputs, TaggedNodeSeq* ready) |
429 | TF_EXCLUSIVE_LOCKS_REQUIRED(mu) { |
430 | return ActivateNodesFastPathInternal<false>(item, is_dead, iter_state, |
431 | outputs, ready); |
432 | } |
433 | |
434 | // REQUIRES: `!item->is_any_consumer_merge_or_control_trigger`. |
435 | // This variant uses atomic operations to modify the pending counts. |
436 | int ActivateNodesFastPathShared(const NodeItem* item, const bool is_dead, |
437 | IterationState* iter_state, |
438 | EntryVector* outputs, TaggedNodeSeq* ready) |
439 | TF_SHARED_LOCKS_REQUIRED(mu) { |
440 | return ActivateNodesFastPathInternal<true>(item, is_dead, iter_state, |
441 | outputs, ready); |
442 | } |
443 | |
444 | template <bool atomic> |
445 | int ActivateNodesFastPathInternal(const NodeItem* item, const bool is_dead, |
446 | IterationState* iter_state, |
447 | EntryVector* outputs, |
448 | TaggedNodeSeq* ready); |
449 | |
450 | int ActivateNodesSlowPathLocked(const NodeItem* item, const bool is_dead, |
451 | IterationState* iter_state, |
452 | EntryVector* outputs, TaggedNodeSeq* ready) |
453 | TF_EXCLUSIVE_LOCKS_REQUIRED(mu) { |
454 | return ActivateNodesSlowPathInternal<false>(item, is_dead, iter_state, |
455 | outputs, ready); |
456 | } |
457 | |
458 | int ActivateNodesSlowPathShared(const NodeItem* item, const bool is_dead, |
459 | IterationState* iter_state, |
460 | EntryVector* outputs, TaggedNodeSeq* ready) |
461 | TF_SHARED_LOCKS_REQUIRED(mu) { |
462 | return ActivateNodesSlowPathInternal<true>(item, is_dead, iter_state, |
463 | outputs, ready); |
464 | } |
465 | |
466 | template <bool atomic> |
467 | int ActivateNodesSlowPathInternal(const NodeItem* item, const bool is_dead, |
468 | IterationState* iter_state, |
469 | EntryVector* outputs, |
470 | TaggedNodeSeq* ready); |
471 | }; |
472 | |
473 | public: |
474 | // Creates and adds a `TaggedNode` for each node in `roots` to `*ready`. |
475 | void ActivateRoots(gtl::ArraySlice<const NodeItem*> roots, |
476 | TaggedNodeSeq* ready); |
477 | |
  // After a node has been processed, propagates its outputs to the
  // destination nodes, adding any nodes that become runnable to `*ready`.
  // Contents of *outputs are left in an indeterminate state after
  // returning from this method.
481 | void PropagateOutputs(const TaggedNode& tagged_node, EntryVector* outputs, |
482 | TaggedNodeSeq* ready); |
483 | |
484 | // Returns an array of `Entry` objects corresponding to the inputs of |
485 | // `tagged_node`. |
486 | // |
487 | // NOTE: Thread safety analysis is disabled on this method, because the |
488 | // underlying `IterationState` and its array of `input_tensors` retain the |
489 | // same address while the iteration is live. |
490 | Entry* GetInputTensors(const TaggedNode& tagged_node) const |
491 | TF_NO_THREAD_SAFETY_ANALYSIS { |
492 | return tagged_node.input_iter->input_tensors + |
493 | tagged_node.node_item->input_start; |
494 | } |
495 | |
496 | FrameAndIter GetFrameAndIter(const TaggedNode& tagged_node) const { |
497 | return {tagged_node.input_frame->frame_id, |
498 | tagged_node.input_iter->iter_num}; |
499 | } |
500 | |
501 | // Provide debugging output of the state of the executor. |
502 | void DumpState(); |
503 | |
504 | // For debugging/logging only. |
505 | void MaybeMarkStarted(const TaggedNode& tagged_node) { |
506 | // TODO(misard) Replace with a finer-grain enabling flag once we add better |
507 | // optional debugging support. |
508 | if (TF_PREDICT_FALSE(vlog_) && VLOG_IS_ON(1)) { |
509 | mutex_lock l(tagged_node.input_frame->mu); |
510 | tagged_node.input_iter->mark_started( |
511 | immutable_state_.pending_ids()[tagged_node.node_item->node_id]); |
512 | } |
513 | } |
514 | |
515 | void MaybeMarkCompleted(const TaggedNode& tagged_node) { |
516 | // TODO(misard) Replace with a finer-grain enabling flag once we add better |
517 | // optional debugging support. |
518 | if (TF_PREDICT_FALSE(vlog_) && VLOG_IS_ON(1)) { |
519 | mutex_lock l(tagged_node.input_frame->mu); |
520 | tagged_node.input_iter->mark_completed( |
521 | immutable_state_.pending_ids()[tagged_node.node_item->node_id]); |
522 | } |
523 | } |
524 | |
525 | private: |
  // Finds an existing child frame, or creates a new one, in frame 'frame' at
  // the iteration given by 'iter_state'.
528 | void FindOrCreateChildFrame(FrameState* frame, IterationState* iter_state, |
529 | const NodeItem& node_item, FrameState** child); |
530 | |
531 | // Delete a frame. Called when the frame is done. |
532 | void DeleteFrame(FrameState* frame, TaggedNodeSeq* ready); |
533 | |
534 | // Cleanup frames and iterations starting from frame/iter. Called when |
535 | // a child frame is done. |
536 | void CleanupFramesIterations(FrameState* frame, IterationState* iter_state, |
537 | TaggedNodeSeq* ready); |
538 | |
539 | // Provide debugging output about an outstanding iteration in the executor. |
540 | void DumpIterationState(const FrameState* frame, IterationState* iteration); |
541 | |
542 | const ImmutableExecutorState& immutable_state_; |
543 | const int64_t step_id_; |
544 | const bool vlog_; |
545 | |
546 | mutex mu_; |
547 | |
548 | // The root frame in which the execution of this step is started. |
549 | FrameState* root_frame_; |
550 | |
  // Mapping from frame ID to outstanding frames. A new frame is created
  // at some iteration of an active frame. So the unique key for the new
  // child frame is a hash composed of the ID of the parent frame, the
  // iteration number at which the parent frame is creating the new frame,
  // and the name of the new frame from the NodeDef.
556 | absl::flat_hash_map<uint64, FrameState*> outstanding_frames_ |
557 | TF_GUARDED_BY(mu_); |
558 | |
559 | TF_DISALLOW_COPY_AND_ASSIGN(PropagatorState); |
560 | }; |
561 | |
562 | inline int64_t PropagatorState::TaggedNode::get_iter_num() const { |
563 | return input_iter->iter_num; |
564 | } |
565 | |
// `OrderedPropagatorState` replaces `PropagatorState`'s `TaggedNodeReadyQueue`
// with a priority queue. This ensures that the order in which we dequeue
// `TaggedNode&`s is stable with respect to ASLR.
//
// This is not always needed: in a multithreaded environment, executions are
// expected to happen nondeterministically. However, this nondeterminism can be
// a problem. For example, in use cases that run close to the RAM limit of a
// device, reordering ops can increase memory fragmentation and cause an OOM.
// This codepath is enabled via the TF_DETERMINISTIC_ORDER=1 environment
// variable, which is read in executor.cc.
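//
// An example of opting in (a sketch; assumes a POSIX shell, and
// `./my_tf_program` stands for any binary that runs a TensorFlow graph):
//
//   $ TF_DETERMINISTIC_ORDER=1 ./my_tf_program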
576 | class OrderedPropagatorState : public PropagatorState { |
577 | using PropagatorState::PropagatorState; |
578 | |
579 | public: |
580 | class TaggedNodeReadyQueue : PropagatorState::TaggedNodeReadyQueue { |
581 | public: |
582 | TaggedNodeReadyQueue() : readyp_(compare) {} |
583 | void push_back(const TaggedNode& node) { readyp_.push(node); } |
584 | TaggedNode front() const { return readyp_.top(); } |
585 | void pop_front() { readyp_.pop(); } |
586 | bool empty() const { return readyp_.empty(); } |
587 | int size() const { return readyp_.size(); } |
588 | |
589 | private: |
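    // Orders nodes by the (node_id, frame_id, iter_num) tuple, all components
    // of which are stable across runs, rather than by pointer values, which
    // vary under ASLR.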
590 | static bool compare(TaggedNode const& lhs, TaggedNode const& rhs) { |
591 | std::tuple<int, uint64, int64_t> lhs_prio{lhs.node_item->node_id, |
592 | lhs.input_frame->frame_id, |
593 | lhs.input_iter->iter_num}; |
594 | std::tuple<int, uint64, int64_t> rhs_prio{rhs.node_item->node_id, |
595 | rhs.input_frame->frame_id, |
596 | rhs.input_iter->iter_num}; |
597 | return lhs_prio < rhs_prio; |
598 | } |
599 | |
600 | std::priority_queue<TaggedNode, std::vector<TaggedNode>, decltype(&compare)> |
601 | readyp_; |
602 | }; |
603 | }; |
604 | |
605 | } // namespace tensorflow |
606 | |
607 | #endif // TENSORFLOW_CORE_COMMON_RUNTIME_PROPAGATOR_STATE_H_ |
608 | |