/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/core/common_runtime/propagator_state.h"

#include "tensorflow/core/common_runtime/graph_view.h"
#include "tensorflow/core/common_runtime/immutable_executor_state.h"
#include "tensorflow/core/common_runtime/propagator_debug_utils.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/lib/hash/hash.h"
#include "tensorflow/core/platform/hash.h"
#include "tensorflow/core/profiler/lib/traceme.h"

namespace tensorflow {

PropagatorState::PropagatorState(const ImmutableExecutorState& immutable_state,
                                 int64_t step_id, bool vlog)
    : immutable_state_(immutable_state),
      step_id_(step_id),
      vlog_(vlog || VLOG_IS_ON(1)) {
  // We start the entire execution in iteration 0 of the root frame
  // so let us create the root frame and the state for iteration 0.
  // We assume root_frame_->frame_name.empty().
  root_frame_ = new FrameState(immutable_state_, 1);
  root_frame_->frame_id = 0;  // must be 0
  root_frame_->InitializeFrameInfo(immutable_state_.get_root_frame_info());

  // Initialize iteration 0.
  root_frame_->SetIteration(
      0, new PropagatorState::IterationState(0, root_frame_->pending_counts,
                                             root_frame_->total_input_tensors));

  outstanding_frames_.emplace(root_frame_->frame_id, root_frame_);
}

PropagatorState::~PropagatorState() {
  for (auto name_frame : outstanding_frames_) {
    delete name_frame.second;
  }
}
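
// Seeds the ready queue with the graph's root nodes (nodes with no inputs),
// each tagged with iteration 0 of the root frame.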
void PropagatorState::ActivateRoots(gtl::ArraySlice<const NodeItem*> roots,
                                    TaggedNodeSeq* ready) {
  mutex_lock l(root_frame_->mu);
  IterationState* root_iter = root_frame_->GetIteration(0);
  for (const NodeItem* item : roots) {
    DCHECK_EQ(item->num_inputs, 0);
    ready->emplace_back(item, root_frame_, root_iter, false);
  }
  root_iter->outstanding_ops = ready->size();
}

void PropagatorState::PropagateOutputs(const TaggedNode& tagged_node,
                                       EntryVector* outputs,
                                       TaggedNodeSeq* ready) {
  profiler::TraceMe activity(
      [&]() {
        return strings::StrCat(
            "ExecutorPropagateOutputs#", "id=", step_id_,
            ",kernel_name=", tagged_node.node_item->kernel->name_view(),
            ",num_output_edges=", tagged_node.node_item->num_output_edges,
            ",num_output_control_edges=",
            tagged_node.node_item->num_output_control_edges,
            ",input_frame=", tagged_node.input_frame->frame_id, "#");
      },
      profiler::GetTFTraceMeLevel(/*is_expensive=*/false));

  const NodeItem* const item = tagged_node.node_item;
  FrameState* const input_frame = tagged_node.input_frame;
  IterationState* const input_iter = tagged_node.input_iter;
  const bool is_dead = tagged_node.is_dead;

  // Propagates outputs along out edges, and puts newly ready nodes
  // into the ready queue.
  DCHECK(ready->empty());
  bool is_frame_done = false;
  FrameState* output_frame = input_frame;
  IterationState* output_iter = input_iter;

  if (!item->is_enter_exit_or_next_iter) {
    // Fast path for node types that don't need special handling.
    // This is the case for most nodes.
    DCHECK_EQ(input_frame, output_frame);
    FrameState* frame = input_frame;
    is_frame_done = frame->ActivateNodesAndAdjustOutstanding(
        item, is_dead, output_iter, outputs, ready, /*decrement_activation=*/1);
  } else if (item->is_enter) {
    FindOrCreateChildFrame(input_frame, input_iter, *item, &output_frame);
    {
      mutex_lock l(output_frame->mu);
      output_iter = output_frame->GetIteration(0);
      if (item->is_constant_enter) {
        // Propagate to all active iterations if this is a loop invariant.
        output_frame->AddLoopInv(item, (*outputs)[0], ready);
      } else {
        int activated = output_frame->ActivateNodesLocked(
            item, is_dead, output_iter, outputs, ready);
        output_frame->AdjustOutstandingOpsLocked(output_iter, activated, ready);
      }
      output_frame->num_pending_inputs--;
    }
    is_frame_done = input_frame->DecrementOutstandingOps(input_iter, ready);
  } else if (item->is_exit) {
    if (is_dead) {
      {
        tf_shared_lock l(input_frame->mu);
        // Stop and remember this node if it is a dead exit.
        if (input_iter->iter_num == input_frame->iteration_count) {
          mutex_lock l(input_frame->iter_mu);
          input_frame->dead_exits.push_back(item);
        }
      }
      is_frame_done = input_frame->DecrementOutstandingOps(input_iter, ready);
    } else {
      output_frame = input_frame->parent_frame;
      output_iter = input_frame->parent_iter;
      output_frame->ActivateNodesAndAdjustOutstanding(
          item, is_dead, output_iter, outputs, ready,
          /*decrement_activation=*/0);
      is_frame_done = input_frame->DecrementOutstandingOps(input_iter, ready);
    }
  } else {
    DCHECK(item->is_next_iteration);
    if (is_dead) {
      // Stop the deadness propagation.
      output_frame = nullptr;
    } else {
      bool need_create_iter = false;
      {
        tf_shared_lock l(input_frame->mu);
        if (input_iter->iter_num == input_frame->iteration_count) {
          if (input_frame->num_outstanding_iterations ==
              input_frame->max_parallel_iterations) {
            // Reached the maximum for parallel iterations.
            output_frame = nullptr;
            mutex_lock l(input_frame->iter_mu);
            input_frame->next_iter_roots.push_back({item, (*outputs)[0]});
          } else {
            // Need to create iteration state after acquiring mutex lock.
            need_create_iter = true;
          }
        } else {
          output_iter = input_frame->GetIteration(input_iter->iter_num + 1);
        }
      }
      if (output_frame != nullptr) {
        if (need_create_iter) {
          profiler::TraceMe activity(
              [&]() {
                return strings::StrCat(
                    "PropagateOutputs::NextIteration::CreateIterationState");
              },
              profiler::GetTFTraceMeLevel(/*is_expensive=*/false));
          mutex_lock l(input_frame->mu);
          if (input_iter->iter_num == input_frame->iteration_count) {
            // Check another time since another thread may create the required
            // iteration state.
            // TODO(fishx): This may cause contention since multiple threads may
            // race for this mutex lock. Further improve this if needed.
            output_iter = input_frame->IncrementIteration(ready);
          } else {
            output_iter = input_frame->GetIteration(input_iter->iter_num + 1);
          }
          DCHECK(input_frame == output_frame);
          int activated = output_frame->ActivateNodesLocked(
              item, is_dead, output_iter, outputs, ready);
          output_frame->AdjustOutstandingOpsLocked(output_iter, activated,
                                                   ready);
        } else {
          DCHECK(input_frame == output_frame);
          output_frame->ActivateNodesAndAdjustOutstanding(
              item, is_dead, output_iter, outputs, ready,
              /*decrement_activation=*/0);
        }
      }
    }
    is_frame_done = input_frame->DecrementOutstandingOps(input_iter, ready);
  }

  // At this point, this node is completely done. We also know if the
  // completion of this node makes its frame completed.
  if (is_frame_done) {
    FrameState* parent_frame = input_frame->parent_frame;
    IterationState* parent_iter = input_frame->parent_iter;
    DeleteFrame(input_frame, ready);
    if (parent_frame != nullptr) {
      // The completion of frame may cause completions in its parent frame.
      // So clean things up recursively.
      CleanupFramesIterations(parent_frame, parent_iter, ready);
    }
  }
}

void PropagatorState::DumpIterationState(const FrameState* frame,
                                         IterationState* iteration) {
  const std::vector<const NodeItem*>* nodes = frame->nodes;
  // Dump any waiting nodes that are holding on to tensors.
  for (const NodeItem* node : *nodes) {
    PendingCounts::Handle pending_id =
        immutable_state_.pending_ids()[node->node_id];
    if (iteration->node_state(pending_id) == PendingCounts::PENDING_NOTREADY ||
        iteration->node_state(pending_id) == PendingCounts::PENDING_READY) {
      DumpPendingNodeState(*node, iteration->input_tensors, false);
    }
  }
  // Then the active nodes.
  for (const NodeItem* node : *nodes) {
    PendingCounts::Handle pending_id =
        immutable_state_.pending_ids()[node->node_id];
    if (iteration->node_state(pending_id) == PendingCounts::STARTED) {
      DumpActiveNodeState(*node, iteration->input_tensors);
    }
  }
  // Show all input tensors in use.
  const int total_input_tensors = frame->total_input_tensors;
  size_t total_bytes = 0;
  for (int i = 0; i < total_input_tensors; ++i) {
    const Entry& input = iteration->input_tensors[i];
    const Tensor* tensor = GetTensorValueForDump(input);
    if (tensor->IsInitialized()) {
      LOG(WARNING) << " Input " << i << ": "
                   << strings::StrCat(
                          "Tensor<type: ", DataTypeString(tensor->dtype()),
                          " shape: ", tensor->shape().DebugString(),
                          ", bytes: ", tensor->TotalBytes(), ">");
      total_bytes += tensor->TotalBytes();
    }
  }
  LOG(WARNING) << " Total bytes " << total_bytes;
}

void PropagatorState::DumpState() {
  mutex_lock l(mu_);
  LOG(WARNING) << "Dumping state";
  for (auto& frame : outstanding_frames_) {
    LOG(WARNING) << frame.first;
    FrameState* frame_state = frame.second;
    frame_state->DumpIterationState(this);
  }
}

void PropagatorState::FindOrCreateChildFrame(FrameState* frame,
                                             IterationState* iter_state,
                                             const NodeItem& node_item,
                                             FrameState** child) {
  // Get the child frame name.
  const ImmutableExecutorState::FrameInfo& frame_info =
      immutable_state_.get_enter_frame_info(node_item);
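
  // A child frame is uniquely identified by its parent frame id, the parent
  // iteration number, and the name of the frame being entered.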
  const uint64 child_id = Hash64Combine(
      frame->frame_id,
      Hash64Combine(iter_state->iter_num, Hash64(frame_info.name)));

  {
    tf_shared_lock executor_lock(mu_);
    auto it = outstanding_frames_.find(child_id);
    if (it != outstanding_frames_.end()) {
      *child = it->second;
      return;
    }
  }

  // Need to create a new frame instance.
  // Note that this new frame instance is created without any locks.
  if (vlog_) {
    const string child_name = strings::StrCat(
        frame->frame_name, ";", iter_state->iter_num, ";", frame_info.name);
    VLOG(2) << "Create frame: " << child_name << " id: " << child_id;
  }

  FrameState* temp =
      new FrameState(immutable_state_, frame_info.parallel_iterations);
  temp->frame_id = child_id;
  temp->parent_frame = frame;
  temp->parent_iter = iter_state;
  temp->InitializeFrameInfo(frame_info);

  // Initialize iteration 0.
  {
    mutex_lock l(temp->mu);
    temp->SetIteration(0, new IterationState(0, temp->pending_counts,
                                             temp->total_input_tensors));
  }

  {
    mutex_lock executor_lock(mu_);
    auto it = outstanding_frames_.find(child_id);
    if (it != outstanding_frames_.end()) {
      *child = it->second;
    } else {
      mutex_lock frame_lock(frame->mu);
      iter_state->outstanding_frame_count++;
      outstanding_frames_[child_id] = temp;
      *child = temp;
      temp = nullptr;
    }
  }
  delete temp;  // Not used so delete it.
}

void PropagatorState::DeleteFrame(FrameState* frame, TaggedNodeSeq* ready) {
  // First, propagate dead_exits (if any) to the parent frame.
  FrameState* parent_frame = frame->parent_frame;
  IterationState* parent_iter_state = frame->parent_iter;
  if (parent_frame != nullptr) {
    mutex_lock parent_frame_lock(parent_frame->mu);
    // Propagate all the dead exits to the parent frame.
    mutex_lock this_frame_lock(frame->mu);
    mutex_lock iter_lock(frame->iter_mu);

    for (const NodeItem* item : frame->dead_exits) {
      auto maybe_add_to_ready = [&](const NodeItem& dst_item, bool dst_ready,
                                    bool dst_dead) {
        if (dst_ready) {
          if (dst_item.is_control_trigger) dst_dead = false;
          ready->emplace_back(&dst_item, parent_frame, parent_iter_state,
                              dst_dead);
          parent_iter_state->outstanding_ops++;
        }
      };

      auto propagate_to_non_merge = [&](PendingCounts::Handle dst_pending_id) {
        parent_iter_state->increment_dead_count(dst_pending_id);
        return parent_iter_state->decrement_pending(dst_pending_id, 1) == 0;
      };

      for (const EdgeInfo& e : item->output_edges()) {
        const NodeItem& dst_item =
            immutable_state_.graph_view().node_ref(e.dst_id);
        const auto dst_pending_id = immutable_state_.pending_ids()[e.dst_id];

        bool dst_dead = true;
        bool dst_ready;
        // We know this is a dead input to dst.
        if (dst_item.is_merge) {
          parent_iter_state->increment_dead_count(dst_pending_id);
          const int dead_cnt = parent_iter_state->dead_count(dst_pending_id);
          dst_dead = (dead_cnt == dst_item.num_inputs);
          dst_ready =
              (parent_iter_state->pending(dst_pending_id) == 1) && dst_dead;
        } else {
          dst_ready = propagate_to_non_merge(dst_pending_id);
        }
        maybe_add_to_ready(dst_item, dst_ready, dst_dead);
      }

      for (const ControlEdgeInfo& e : item->output_control_edges()) {
        const NodeItem& dst_item =
            immutable_state_.graph_view().node_ref(e.dst_id);
        const auto dst_pending_id = immutable_state_.pending_ids()[e.dst_id];

        bool dst_dead;
        bool dst_ready;
        // We know this is a dead input to dst.
        if (dst_item.is_merge) {
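          // For a merge, each control edge contributes 2 to the pending
          // count, so a control input decrements it by 2.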
          parent_iter_state->decrement_pending(dst_pending_id, 2);
          int count = parent_iter_state->pending(dst_pending_id);
          int dead_cnt = parent_iter_state->dead_count(dst_pending_id);
          dst_dead = (dead_cnt == dst_item.num_inputs);
          dst_ready = (count == 0) || ((count == 1) && dst_dead);
        } else {
          dst_dead = true;
          dst_ready = propagate_to_non_merge(dst_pending_id);
        }
        maybe_add_to_ready(dst_item, dst_ready, dst_dead);
      }
    }
  }

  // Delete the frame.
  if (vlog_) VLOG(2) << "Delete frame " << frame->frame_id;
  {
    mutex_lock executor_lock(mu_);
    outstanding_frames_.erase(frame->frame_id);
  }
  delete frame;
}

void PropagatorState::CleanupFramesIterations(FrameState* frame,
                                              IterationState* iter_state,
                                              TaggedNodeSeq* ready) {
  bool is_frame_done = false;
  {
    mutex_lock frame_lock(frame->mu);
    iter_state->outstanding_frame_count--;
    is_frame_done = frame->CleanupIterations(iter_state, ready);
  }
  if (is_frame_done) {
    FrameState* parent_frame = frame->parent_frame;
    IterationState* parent_iter = frame->parent_iter;
    DeleteFrame(frame, ready);
    if (parent_frame != nullptr) {
      // The completion of frame may cause completions in its parent frame.
      // So clean things up recursively.
      CleanupFramesIterations(parent_frame, parent_iter, ready);
    }
  }
}
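
// When 'atomic' is true, pending/dead counts are adjusted with atomic
// operations so that concurrent activations are safe under a shared lock on
// the frame; when false, plain updates are used and the caller must serialize
// access to the iteration's pending counts (e.g. by holding the frame mutex).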
template <bool atomic>
int PropagatorState::FrameState::ActivateNodesFastPathInternal(
    const NodeItem* item, const bool is_dead, IterationState* iter_state,
    EntryVector* outputs, TaggedNodeSeq* ready) {
  // If we know that none of the item's edge destinations require special
  // handling (i.e. none of the nodes is a merge or control trigger node), we
  // can take a fast path that avoids accessing the destination NodeItem.
  const GraphView& gview = immutable_state.graph_view();
  int new_outstanding = 0;

// Add dst to the ready queue if it's ready
//
// NOTE(mrry): Use a macro here instead of a lambda, because this method is
// performance-critical and we need to ensure that the code is inlined.
#define MAYBE_ADD_TO_READY(dst_id, adjust_result)         \
  do {                                                    \
    if (!(adjust_result.pending_count > 0)) {             \
      const NodeItem* dst_item = &gview.node_ref(dst_id); \
      TaggedNode& t = ready->emplace_back();              \
      t.node_item = dst_item;                             \
      t.input_frame = this;                               \
      t.input_iter = iter_state;                          \
      t.is_dead = adjust_result.dead_count > 0;           \
      new_outstanding++;                                  \
    }                                                     \
  } while (0);

  Entry* input_tensors = iter_state->input_tensors;
  for (const EdgeInfo& e : item->output_edges()) {
    const int dst_id = e.dst_id;
    const PendingCounts::Handle dst_pending_id =
        immutable_state.pending_ids()[dst_id];
    const int src_slot = e.output_slot;

    const bool increment_dead =
        (is_dead || ((*outputs)[src_slot].state == Entry::State::NO_VALUE));
    const int dst_loc = e.input_slot;
    if (e.is_last) {
      input_tensors[dst_loc] = std::move((*outputs)[src_slot]);
    } else {
      input_tensors[dst_loc] = (*outputs)[src_slot];
    }
    const PendingCounts::AdjustResult adjust_result =
        atomic
            ? iter_state->adjust_for_activation_atomic(dst_pending_id,
                                                       increment_dead)
            : iter_state->adjust_for_activation(dst_pending_id, increment_dead);
    MAYBE_ADD_TO_READY(dst_id, adjust_result);
  }

  for (const ControlEdgeInfo& e : item->output_control_edges()) {
    const int dst_id = e.dst_id;
    const PendingCounts::Handle dst_pending_id =
        immutable_state.pending_ids()[dst_id];
    const PendingCounts::AdjustResult adjust_result =
        atomic
            ? iter_state->adjust_for_activation_atomic(dst_pending_id, is_dead)
            : iter_state->adjust_for_activation(dst_pending_id, is_dead);
    MAYBE_ADD_TO_READY(dst_id, adjust_result);
  }

  return new_outstanding;
#undef MAYBE_ADD_TO_READY
}

template <bool atomic>
int PropagatorState::FrameState::ActivateNodesSlowPathInternal(
    const NodeItem* item, const bool is_dead, IterationState* iter_state,
    EntryVector* outputs, TaggedNodeSeq* ready) {
  // If any of the edge destinations is a merge or a control trigger node,
  // we need to read each destination NodeItem to determine what action
  // to take.
  const GraphView& gview = immutable_state.graph_view();
  int activated = 0;
  auto maybe_add_to_ready = [&](int dst_id, const NodeItem* dst_item,
                                bool dst_ready, bool dst_dead) {
    // Add dst to the ready queue if it's ready
    if (dst_ready) {
      if (dst_item->is_control_trigger) dst_dead = false;
      ready->emplace_back(dst_item, this, iter_state, dst_dead);
      activated++;
    }
  };

  Entry* input_tensors = iter_state->input_tensors;

  for (const EdgeInfo& e : item->output_edges()) {
    const int dst_id = e.dst_id;
    const NodeItem* dst_item = &gview.node_ref(dst_id);
    const PendingCounts::Handle dst_pending_id =
        immutable_state.pending_ids()[dst_id];
    const int src_slot = e.output_slot;

    bool dst_dead = false;
    bool dst_ready = false;

    if (dst_item->is_merge) {
      // A merge node is ready if all control inputs have arrived and either
      // a) a live data input becomes available or b) all data inputs are
      // dead. For Merge, pending's LSB is set iff a live data input has
      // arrived.
      if ((*outputs)[src_slot].state != Entry::State::NO_VALUE) {
        // This is a live data input.

        // We assume that a merge op has only one live input edge. Based on
        // this assumption, the total pending count of a merge op is set to
        // 2 * control_edge + 1, so when a live data input of a merge op
        // arrives it is fine to directly set the input.
        // NOTE(fishx): If there are multiple live edges into a merge op,
        // there is a race condition and the last live edge will overwrite
        // the merge op's input. That would indicate a graph-level issue.
        const int dst_loc = e.input_slot;
        if (e.is_last) {
          input_tensors[dst_loc] = std::move((*outputs)[src_slot]);
        } else {
          input_tensors[dst_loc] = (*outputs)[src_slot];
        }

        const PendingCounts::AdjustResult adjust_result =
            atomic ? iter_state->adjust_for_mark_live_atomic(dst_pending_id)
                   : iter_state->adjust_for_mark_live(dst_pending_id);

        // The low bit of count is set if and only if no live input has been
        // used yet (mark_live clears it). The node should be started if and
        // only if this is the first live input and there are no pending
        // control edges, i.e. count == 1.
        dst_ready = (adjust_result.pending_count == 1);
      } else {
        // This is a dead data input. Note that dst_node is dead if node is
        // a dead enter. We need this to handle properly a while loop on
        // the untaken branch of a conditional.
        // TODO(yuanbyu): This is a bit hacky, but a good solution for
        // now.
        const PendingCounts::AdjustResult adjust_result =
            atomic
                ? iter_state->adjust_for_increment_dead_atomic(dst_pending_id)
                : iter_state->adjust_for_increment_dead(dst_pending_id);
        dst_dead = (adjust_result.dead_count == dst_item->num_inputs) ||
                   item->is_enter;
        dst_ready = (adjust_result.pending_count == 1) && dst_dead;
      }
    } else {
      // Handle all other (non-merge) nodes.

      // We need to set the input of the op before adjusting its activation;
      // otherwise there is a race condition, since another thread may execute
      // the op as soon as its activation has been adjusted.
      const int dst_loc = e.input_slot;
      if (e.is_last) {
        input_tensors[dst_loc] = std::move((*outputs)[src_slot]);
      } else {
        input_tensors[dst_loc] = (*outputs)[src_slot];
      }

      const bool increment_dead =
          (is_dead || ((*outputs)[src_slot].state == Entry::State::NO_VALUE));
      const PendingCounts::AdjustResult adjust_result =
          atomic ? iter_state->adjust_for_activation_atomic(dst_pending_id,
                                                            increment_dead)
                 : iter_state->adjust_for_activation(dst_pending_id,
                                                     increment_dead);
      dst_dead = adjust_result.dead_count > 0;
      dst_ready = !(adjust_result.pending_count > 0);
    }

    maybe_add_to_ready(dst_id, dst_item, dst_ready, dst_dead);
  }

  for (const ControlEdgeInfo& e : item->output_control_edges()) {
    const int dst_id = e.dst_id;
    const NodeItem* dst_item = &gview.node_ref(dst_id);
    const PendingCounts::Handle dst_pending_id =
        immutable_state.pending_ids()[dst_id];

    bool dst_dead;
    bool dst_ready;
    if (dst_item->is_merge) {
      // A merge node is ready if all control inputs have arrived and either
      // a) a live data input becomes available or b) all data inputs are
      // dead. For Merge, pending's LSB is set iff a live data input has
      // arrived.
      const PendingCounts::AdjustResult adjust_result =
          atomic ? iter_state->adjust_for_decrement_pending_atomic(
                       dst_pending_id, /*decrement_pending=*/2)
                 : iter_state->adjust_for_decrement_pending(
                       dst_pending_id,
                       /*decrement_pending=*/2);
      dst_dead = (adjust_result.dead_count == dst_item->num_inputs);
      dst_ready = (adjust_result.pending_count == 0) ||
                  ((adjust_result.pending_count == 1) && dst_dead);
    } else {
      // Handle all other (non-merge) nodes.
      const PendingCounts::AdjustResult adjust_result =
          atomic ? iter_state->adjust_for_activation_atomic(dst_pending_id,
                                                            is_dead)
                 : iter_state->adjust_for_activation(dst_pending_id, is_dead);
      dst_dead = adjust_result.dead_count > 0;
      dst_ready = adjust_result.pending_count == 0;
    }
    maybe_add_to_ready(dst_id, dst_item, dst_ready, dst_dead);
  }

  return activated;
}
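
// Activates the destinations of 'item's output edges in 'iter_state' and
// adjusts the iteration's outstanding op count by the number of newly
// activated nodes minus 'decrement_activation', which lets the caller fold
// the completion of 'item' itself into the same update.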
bool PropagatorState::FrameState::ActivateNodesAndAdjustOutstanding(
    const NodeItem* item, const bool is_dead, IterationState* iter_state,
    EntryVector* outputs, TaggedNodeSeq* ready, int decrement_activation) {
  if (TF_PREDICT_FALSE(item->is_any_consumer_merge_or_control_trigger)) {
    tf_shared_lock l(mu);
    int activated =
        ActivateNodesSlowPathShared(item, is_dead, iter_state, outputs, ready);
    bool iter_done = AdjustOutstandingOpsFastPath(
        iter_state, activated - decrement_activation);
    if (!iter_done) return false;
  } else {
    tf_shared_lock l(mu);
    int activated =
        ActivateNodesFastPathShared(item, is_dead, iter_state, outputs, ready);
    bool iter_done = AdjustOutstandingOpsFastPath(
        iter_state, activated - decrement_activation);
    if (!iter_done) return false;
  }
  if (decrement_activation > 0) {
    mutex_lock l(mu);
    return CleanupIterations(iter_state, ready);
  } else {
    return true;
  }
}

int PropagatorState::FrameState::ActivateNodesLocked(const NodeItem* item,
                                                     const bool is_dead,
                                                     IterationState* iter_state,
                                                     EntryVector* outputs,
                                                     TaggedNodeSeq* ready) {
  if (TF_PREDICT_FALSE(item->is_any_consumer_merge_or_control_trigger)) {
    return ActivateNodesSlowPathLocked(item, is_dead, iter_state, outputs,
                                       ready);
  } else {
    return ActivateNodesFastPathLocked(item, is_dead, iter_state, outputs,
                                       ready);
  }
}

void PropagatorState::FrameState::ActivateNexts(IterationState* iter_state,
                                                TaggedNodeSeq* ready) {
  int activated = 0;
  // Propagate the deferred NextIteration nodes to the new iteration.
  for (auto& node_entry : next_iter_roots) {
    const NodeItem* item = node_entry.first;
    const Entry& entry = node_entry.second;
    const bool is_dead = entry.state == Entry::State::NO_VALUE;
    EntryVector outputs{entry};
    activated +=
        ActivateNodesLocked(item, is_dead, iter_state, &outputs, ready);
  }
  next_iter_roots.clear();
  AdjustOutstandingOpsLocked(iter_state, activated, ready);
}

void PropagatorState::FrameState::ActivateLoopInvs(IterationState* iter_state,
                                                   TaggedNodeSeq* ready) {
  // Propagate loop invariants to the new iteration.
  mutex_lock l(iter_mu);
  int activated = 0;
  for (auto& node_entry : inv_values) {
    const NodeItem* item = node_entry.first;
    const Entry& entry = node_entry.second;
    const bool is_dead = entry.state == Entry::State::NO_VALUE;
    EntryVector outputs{entry};
    activated +=
        ActivateNodesLocked(item, is_dead, iter_state, &outputs, ready);
  }
  AdjustOutstandingOpsLocked(iter_state, activated, ready);
}

void PropagatorState::FrameState::AddLoopInv(const NodeItem* item,
                                             const Entry& entry,
                                             TaggedNodeSeq* ready) {
  {
    mutex_lock l(iter_mu);
    // Store this value.
    inv_values.push_back({item, entry});
  }

  // Make this value available to all iterations.
  const bool is_dead = entry.state == Entry::State::NO_VALUE;
  for (int i = 0; i <= iteration_count; ++i) {
    EntryVector outputs{entry};
    IterationState* iter_state = GetIteration(i);
    int activated =
        ActivateNodesLocked(item, is_dead, iter_state, &outputs, ready);
    AdjustOutstandingOpsLocked(iter_state, activated, ready);
  }
}

bool PropagatorState::FrameState::IsIterationDone(IterationState* iter_state) {
  if (iter_state->outstanding_ops == 0 &&
      iter_state->outstanding_frame_count == 0) {
    if (iter_state->iter_num == 0) {
      // The enclosing frame has no pending input.
      return num_pending_inputs == 0;
    } else {
      // The preceding iteration is deleted (and therefore done).
      return (GetIteration(iter_state->iter_num - 1) == nullptr);
    }
  }
  return false;
}

PropagatorState::IterationState*
PropagatorState::FrameState::IncrementIteration(TaggedNodeSeq* ready) {
  iteration_count++;

  // Initialize the next iteration.
  IterationState* next_iter =
      new IterationState(iteration_count, pending_counts, total_input_tensors);
  SetIteration(iteration_count, next_iter);
  num_outstanding_iterations++;
  {
    mutex_lock l(iter_mu);
    dead_exits.clear();
  }

  // Activate the successors of the deferred roots in the new iteration.
  ActivateNexts(next_iter, ready);

  // Activate the loop invariants in the new iteration.
  ActivateLoopInvs(next_iter, ready);

  return next_iter;
}

bool PropagatorState::FrameState::CleanupIterations(IterationState* iter_state,
                                                    TaggedNodeSeq* ready) {
  int64_t curr_iter = iter_state->iter_num;
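  // Delete consecutive completed iterations, starting at 'iter_state'. A
  // later iteration can only be deleted once all earlier iterations are gone
  // (see IsIterationDone).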
  while (curr_iter <= iteration_count && IsIterationDone(iter_state)) {
    delete iter_state;
    SetIteration(curr_iter, nullptr);
    --num_outstanding_iterations;
    ++curr_iter;

    // When one iteration is completed, we check for deferred iteration,
    // and start it if there is one.
    bool increment_iteration = false;
    {
      tf_shared_lock l(iter_mu);
      increment_iteration = !next_iter_roots.empty();
    }
    if (increment_iteration) {
      IncrementIteration(ready);
    }

    if (curr_iter <= iteration_count) {
      iter_state = GetIteration(curr_iter);
    }
  }
  return IsFrameDone();
}

void PropagatorState::FrameState::InitializeFrameInfo(
    const ImmutableExecutorState::FrameInfo& finfo) {
  pending_counts = finfo.pending_counts.get();
  total_input_tensors = finfo.total_inputs;
  num_pending_inputs = finfo.input_count;
  nodes = finfo.nodes.get();
}
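
// Iterations are kept in a ring buffer of size max_parallel_iterations + 1,
// indexed by iter % (max_parallel_iterations + 1), so at most that many
// iterations of a frame can be live at the same time.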
void PropagatorState::FrameState::SetIteration(int64_t iter,
                                               IterationState* state)
    TF_EXCLUSIVE_LOCKS_REQUIRED(mu) {
  size_t index = iter % (max_parallel_iterations + 1);
  DCHECK(state == nullptr || iterations[index] == nullptr);
  iterations_raw[index] = state;
  if (index == 0) {
    iterations_first = state;
  }
}

// Decrement the outstanding op count and clean up the iterations in the
// frame. Return true iff the execution of the frame is done.
bool PropagatorState::FrameState::DecrementOutstandingOps(
    IterationState* iter_state, TaggedNodeSeq* ready) {
  return AdjustOutstandingOps(iter_state, -1, ready);
}

bool PropagatorState::FrameState::AdjustOutstandingOps(
    IterationState* iter_state, int delta, TaggedNodeSeq* ready) {
  // Given the following profile of values of 'delta' for wide_deep model from
  // the TF model garden:
  //
  //   Count   Value
  //   -----------------------
  //   757938  delta=0x0
  //   541713  delta=0xffffffff
  //   138115  delta=0x1
  //    58770  delta=0x2
  //     5394  delta=0x3
  //     4669  delta=0x4
  //     2037  delta=0xa
  //     1646  delta=0x7
  //     1632  delta=0x6
  //     1613  delta=0x6c
  //     1224  delta=0x5
  //      409  delta=0x53
  //       17  delta=0x86
  //
  // ... it's worth no-opping out when delta == 0 to avoid the atomic
  // instruction.
  if (delta == 0) {
    return false;
  }
  {
    tf_shared_lock sl(mu);
    if (TF_PREDICT_TRUE(!AdjustOutstandingOpsFastPath(iter_state, delta))) {
      return false;
    }
  }
  mutex_lock l(mu);
  DCHECK(IsIterationDone(iter_state));
  return CleanupIterations(iter_state, ready);
}
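
// Atomically adds 'delta' to the iteration's outstanding op count. Returns
// true iff the adjusted count reaches zero and the iteration is done (see
// IsIterationDone); cleaning up the iteration is left to the caller.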
bool PropagatorState::FrameState::AdjustOutstandingOpsFastPath(
    IterationState* iter_state, int delta) {
  auto old_val = iter_state->outstanding_ops.fetch_add(delta);
  return (old_val + delta == 0) && IsIterationDone(iter_state);
}

// Decrement the outstanding op count and clean up the iterations in the
// frame. Return true iff the execution of the frame is done.
bool PropagatorState::FrameState::DecrementOutstandingOpsLocked(
    IterationState* iter_state, TaggedNodeSeq* ready)
    TF_EXCLUSIVE_LOCKS_REQUIRED(mu) {
  return AdjustOutstandingOpsLocked(iter_state, -1, ready);
}

bool PropagatorState::FrameState::AdjustOutstandingOpsLocked(
    IterationState* iter_state, int delta, TaggedNodeSeq* ready) {
  // We hold the lock, so we don't need to use an atomic modification.
  auto cur_val = iter_state->outstanding_ops.load(std::memory_order_relaxed);
  DCHECK(delta >= 0 || cur_val >= -delta)
      << "cannot adjust outstanding_ops by " << delta
      << " when current value is " << cur_val;
  auto new_val = cur_val + delta;
  iter_state->outstanding_ops.store(new_val, std::memory_order_relaxed);
  if (new_val != 0) {
    return false;
  }
  return CleanupIterations(iter_state, ready);
}

// Returns true if the computation in the frame is completed.
bool PropagatorState::FrameState::IsFrameDone()
    TF_EXCLUSIVE_LOCKS_REQUIRED(mu) {
  return (num_pending_inputs == 0 && num_outstanding_iterations == 0);
}

}  // namespace tensorflow