1 | /* Copyright 2015 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | |
16 | #include "tensorflow/core/common_runtime/propagator_state.h" |
17 | |
18 | #include "tensorflow/core/common_runtime/graph_view.h" |
19 | #include "tensorflow/core/common_runtime/immutable_executor_state.h" |
20 | #include "tensorflow/core/common_runtime/propagator_debug_utils.h" |
21 | #include "tensorflow/core/framework/op_kernel.h" |
22 | #include "tensorflow/core/lib/hash/hash.h" |
23 | #include "tensorflow/core/platform/hash.h" |
24 | #include "tensorflow/core/profiler/lib/traceme.h" |
25 | |
26 | namespace tensorflow { |
27 | |
28 | PropagatorState::PropagatorState(const ImmutableExecutorState& immutable_state, |
29 | int64_t step_id, bool vlog) |
30 | : immutable_state_(immutable_state), |
31 | step_id_(step_id), |
32 | vlog_(vlog || VLOG_IS_ON(1)) { |
33 | // We start the entire execution in iteration 0 of the root frame |
34 | // so let us create the root frame and the state for iteration 0. |
35 | // We assume root_frame_->frame_name.empty(). |
36 | root_frame_ = new FrameState(immutable_state_, 1); |
37 | root_frame_->frame_id = 0; // must be 0 |
38 | root_frame_->InitializeFrameInfo(immutable_state_.get_root_frame_info()); |
39 | |
40 | // Initialize iteration 0. |
41 | root_frame_->SetIteration( |
42 | 0, new PropagatorState::IterationState(0, root_frame_->pending_counts, |
43 | root_frame_->total_input_tensors)); |
44 | |
45 | outstanding_frames_.emplace(root_frame_->frame_id, root_frame_); |
46 | } |
47 | |
48 | PropagatorState::~PropagatorState() { |
49 | for (auto name_frame : outstanding_frames_) { |
50 | delete name_frame.second; |
51 | } |
52 | } |
53 | |
54 | void PropagatorState::ActivateRoots(gtl::ArraySlice<const NodeItem*> roots, |
55 | TaggedNodeSeq* ready) { |
56 | mutex_lock l(root_frame_->mu); |
57 | IterationState* root_iter = root_frame_->GetIteration(0); |
58 | for (const NodeItem* item : roots) { |
59 | DCHECK_EQ(item->num_inputs, 0); |
60 | ready->emplace_back(item, root_frame_, root_iter, false); |
61 | } |
62 | root_iter->outstanding_ops = ready->size(); |
63 | } |
64 | |
65 | void PropagatorState::PropagateOutputs(const TaggedNode& tagged_node, |
66 | EntryVector* outputs, |
67 | TaggedNodeSeq* ready) { |
68 | profiler::TraceMe activity( |
69 | [&]() { |
70 | return strings::StrCat( |
71 | "ExecutorPropagateOutputs#" , "id=" , step_id_, |
72 | ",kernel_name=" , tagged_node.node_item->kernel->name_view(), |
73 | ",num_output_edges=" , tagged_node.node_item->num_output_edges, |
74 | ",num_output_control_edges=" , |
75 | tagged_node.node_item->num_output_control_edges, |
76 | ",input_frame=" , tagged_node.input_frame->frame_id, "#" ); |
77 | }, |
78 | profiler::GetTFTraceMeLevel(/*is_expensive=*/false)); |
79 | |
80 | const NodeItem* const item = tagged_node.node_item; |
81 | FrameState* const input_frame = tagged_node.input_frame; |
82 | IterationState* const input_iter = tagged_node.input_iter; |
83 | const bool is_dead = tagged_node.is_dead; |
84 | |
85 | // Propagates outputs along out edges, and puts newly ready nodes |
86 | // into the ready queue. |
87 | DCHECK(ready->empty()); |
88 | bool is_frame_done = false; |
89 | FrameState* output_frame = input_frame; |
90 | IterationState* output_iter = input_iter; |
91 | |
92 | if (!item->is_enter_exit_or_next_iter) { |
93 | // Fast path for node types that don't need special handling. |
94 | // This is the case for most nodes. |
95 | DCHECK_EQ(input_frame, output_frame); |
96 | FrameState* frame = input_frame; |
97 | is_frame_done = frame->ActivateNodesAndAdjustOutstanding( |
98 | item, is_dead, output_iter, outputs, ready, /*decrement_activation=*/1); |
99 | } else if (item->is_enter) { |
100 | FindOrCreateChildFrame(input_frame, input_iter, *item, &output_frame); |
101 | { |
102 | mutex_lock l(output_frame->mu); |
103 | output_iter = output_frame->GetIteration(0); |
104 | if (item->is_constant_enter) { |
105 | // Propagate to all active iterations if this is a loop invariant. |
106 | output_frame->AddLoopInv(item, (*outputs)[0], ready); |
107 | } else { |
108 | int activated = output_frame->ActivateNodesLocked( |
109 | item, is_dead, output_iter, outputs, ready); |
110 | output_frame->AdjustOutstandingOpsLocked(output_iter, activated, ready); |
111 | } |
112 | output_frame->num_pending_inputs--; |
113 | } |
114 | is_frame_done = input_frame->DecrementOutstandingOps(input_iter, ready); |
115 | } else if (item->is_exit) { |
116 | if (is_dead) { |
117 | { |
118 | tf_shared_lock l(input_frame->mu); |
119 | // Stop and remember this node if it is a dead exit. |
120 | if (input_iter->iter_num == input_frame->iteration_count) { |
121 | mutex_lock l(input_frame->iter_mu); |
122 | input_frame->dead_exits.push_back(item); |
123 | } |
124 | } |
125 | is_frame_done = input_frame->DecrementOutstandingOps(input_iter, ready); |
126 | } else { |
127 | output_frame = input_frame->parent_frame; |
128 | output_iter = input_frame->parent_iter; |
129 | output_frame->ActivateNodesAndAdjustOutstanding( |
130 | item, is_dead, output_iter, outputs, ready, |
131 | /*decrement_activation=*/0); |
132 | is_frame_done = input_frame->DecrementOutstandingOps(input_iter, ready); |
133 | } |
134 | } else { |
135 | DCHECK(item->is_next_iteration); |
136 | if (is_dead) { |
137 | // Stop the deadness propagation. |
138 | output_frame = nullptr; |
139 | } else { |
140 | bool need_create_iter = false; |
141 | { |
142 | tf_shared_lock l(input_frame->mu); |
143 | if (input_iter->iter_num == input_frame->iteration_count) { |
144 | if (input_frame->num_outstanding_iterations == |
145 | input_frame->max_parallel_iterations) { |
146 | // Reached the maximum for parallel iterations. |
147 | output_frame = nullptr; |
148 | mutex_lock l(input_frame->iter_mu); |
149 | input_frame->next_iter_roots.push_back({item, (*outputs)[0]}); |
150 | } else { |
151 | // Need to create iteration state after acquiring mutex lock. |
152 | need_create_iter = true; |
153 | } |
154 | } else { |
155 | output_iter = input_frame->GetIteration(input_iter->iter_num + 1); |
156 | } |
157 | } |
158 | if (output_frame != nullptr) { |
159 | if (need_create_iter) { |
160 | profiler::TraceMe activit1y( |
161 | [&]() { |
162 | return strings::StrCat( |
163 | "PropagateOutputs::NextIteration::CreateIterationState" ); |
164 | }, |
165 | profiler::GetTFTraceMeLevel(/*is_expensive=*/false)); |
166 | mutex_lock l(input_frame->mu); |
167 | if (input_iter->iter_num == input_frame->iteration_count) { |
168 | // Check another time since another thread may create the required |
169 | // iteration state. |
170 | // TODO(fishx): This may cause contention since multiple threads may |
171 | // race for this mutex lock. Further improve this if needed. |
172 | output_iter = input_frame->IncrementIteration(ready); |
173 | } else { |
174 | output_iter = input_frame->GetIteration(input_iter->iter_num + 1); |
175 | } |
176 | DCHECK(input_frame == output_frame); |
177 | int activated = output_frame->ActivateNodesLocked( |
178 | item, is_dead, output_iter, outputs, ready); |
179 | output_frame->AdjustOutstandingOpsLocked(output_iter, activated, |
180 | ready); |
181 | } else { |
182 | DCHECK(input_frame == output_frame); |
183 | output_frame->ActivateNodesAndAdjustOutstanding( |
184 | item, is_dead, output_iter, outputs, ready, |
185 | /*decrement_activation=*/0); |
186 | } |
187 | } |
188 | } |
189 | is_frame_done = input_frame->DecrementOutstandingOps(input_iter, ready); |
190 | } |
191 | |
192 | // At this point, this node is completely done. We also know if the |
193 | // completion of this node makes its frame completed. |
194 | if (is_frame_done) { |
195 | FrameState* parent_frame = input_frame->parent_frame; |
196 | IterationState* parent_iter = input_frame->parent_iter; |
197 | DeleteFrame(input_frame, ready); |
198 | if (parent_frame != nullptr) { |
199 | // The completion of frame may cause completions in its parent frame. |
200 | // So clean things up recursively. |
201 | CleanupFramesIterations(parent_frame, parent_iter, ready); |
202 | } |
203 | } |
204 | } |
205 | |
206 | void PropagatorState::DumpIterationState(const FrameState* frame, |
207 | IterationState* iteration) { |
208 | const std::vector<const NodeItem*>* nodes = frame->nodes; |
209 | // Dump any waiting nodes that are holding on to tensors. |
210 | for (const NodeItem* node : *nodes) { |
211 | PendingCounts::Handle pending_id = |
212 | immutable_state_.pending_ids()[node->node_id]; |
213 | if (iteration->node_state(pending_id) == PendingCounts::PENDING_NOTREADY || |
214 | iteration->node_state(pending_id) == PendingCounts::PENDING_READY) { |
215 | DumpPendingNodeState(*node, iteration->input_tensors, false); |
216 | } |
217 | } |
218 | // Then the active nodes. |
219 | for (const NodeItem* node : *nodes) { |
220 | PendingCounts::Handle pending_id = |
221 | immutable_state_.pending_ids()[node->node_id]; |
222 | if (iteration->node_state(pending_id) == PendingCounts::STARTED) { |
223 | DumpActiveNodeState(*node, iteration->input_tensors); |
224 | } |
225 | } |
226 | // Show all input tensors in use. |
227 | const int total_input_tensors = frame->total_input_tensors; |
228 | size_t total_bytes = 0; |
229 | for (int i = 0; i < total_input_tensors; ++i) { |
230 | const Entry& input = iteration->input_tensors[i]; |
231 | const Tensor* tensor = GetTensorValueForDump(input); |
232 | if (tensor->IsInitialized()) { |
233 | LOG(WARNING) << " Input " << i << ": " |
234 | << strings::StrCat( |
235 | "Tensor<type: " , DataTypeString(tensor->dtype()), |
236 | " shape: " , tensor->shape().DebugString(), |
237 | ", bytes: " , tensor->TotalBytes(), ">" ); |
238 | total_bytes += tensor->TotalBytes(); |
239 | } |
240 | } |
241 | LOG(WARNING) << " Total bytes " << total_bytes; |
242 | } |
243 | |
244 | void PropagatorState::DumpState() { |
245 | mutex_lock l(mu_); |
246 | LOG(WARNING) << "Dumping state" ; |
247 | for (auto& frame : outstanding_frames_) { |
248 | LOG(WARNING) << frame.first; |
249 | FrameState* frame_state = frame.second; |
250 | frame_state->DumpIterationState(this); |
251 | } |
252 | } |
253 | |
// Returns in `*child` the frame entered by `node_item` from iteration
// `iter_state` of `frame`, creating and registering it on first use.
// Uses a double-checked pattern: a shared-lock lookup first, then (on miss)
// an unlocked construction followed by an exclusive-lock insert that handles
// a concurrent creation of the same frame.
void PropagatorState::FindOrCreateChildFrame(FrameState* frame,
                                             IterationState* iter_state,
                                             const NodeItem& node_item,
                                             FrameState** child) {
  // Get the child frame name.
  const ImmutableExecutorState::FrameInfo& frame_info =
      immutable_state_.get_enter_frame_info(node_item);

  // A child frame is uniquely identified by (parent frame id, parent
  // iteration number, frame name), combined into a single hash.
  const uint64 child_id = Hash64Combine(
      frame->frame_id,
      Hash64Combine(iter_state->iter_num, Hash64(frame_info.name)));

  {
    // Fast path: the frame already exists, so a shared lock suffices.
    tf_shared_lock executor_lock(mu_);
    auto it = outstanding_frames_.find(child_id);
    if (it != outstanding_frames_.end()) {
      *child = it->second;
      return;
    }
  }

  // Need to create a new frame instance.
  // Note that this new frame instance is created without any locks.
  if (vlog_) {
    const string child_name = strings::StrCat(
        frame->frame_name, ";" , iter_state->iter_num, ";" , frame_info.name);
    VLOG(2) << "Create frame: " << child_name << " id: " << child_id;
  }

  FrameState* temp =
      new FrameState(immutable_state_, frame_info.parallel_iterations);
  temp->frame_id = child_id;
  temp->parent_frame = frame;
  temp->parent_iter = iter_state;
  temp->InitializeFrameInfo(frame_info);

  // Initialize iteration 0.
  {
    mutex_lock l(temp->mu);
    temp->SetIteration(0, new IterationState(0, temp->pending_counts,
                                             temp->total_input_tensors));
  }

  {
    // Second check under the exclusive lock: another thread may have
    // registered the same child frame while we were constructing ours.
    mutex_lock executor_lock(mu_);
    auto it = outstanding_frames_.find(child_id);
    if (it != outstanding_frames_.end()) {
      *child = it->second;
    } else {
      mutex_lock frame_lock(frame->mu);
      // The parent iteration cannot be retired while a child frame exists.
      iter_state->outstanding_frame_count++;
      outstanding_frames_[child_id] = temp;
      *child = temp;
      temp = nullptr;  // Ownership transferred to outstanding_frames_.
    }
  }
  delete temp;  // Not used so delete it.
}
312 | |
// Destroys a completed frame. Before deletion, any dead Exit nodes that were
// recorded for the final iteration are propagated into the parent frame's
// iteration as dead inputs, which may make parent nodes runnable (added to
// `ready`). Locks are taken parent-first, then this frame, then iter_mu.
void PropagatorState::DeleteFrame(FrameState* frame, TaggedNodeSeq* ready) {
  // First, propagate dead_exits (if any) to the parent frame.
  FrameState* parent_frame = frame->parent_frame;
  IterationState* parent_iter_state = frame->parent_iter;
  if (parent_frame != nullptr) {
    mutex_lock parent_frame_lock(parent_frame->mu);
    // Propagate all the dead exits to the parent frame.
    mutex_lock this_frame_lock(frame->mu);
    mutex_lock iter_lock(frame->iter_mu);

    for (const NodeItem* item : frame->dead_exits) {
      // Queues `dst_item` in the parent iteration when it became ready,
      // bumping the parent's outstanding-op count to match.
      auto maybe_add_to_ready = [&](const NodeItem& dst_item, bool dst_ready,
                                    bool dst_dead) {
        if (dst_ready) {
          if (dst_item.is_control_trigger) dst_dead = false;
          ready->emplace_back(&dst_item, parent_frame, parent_iter_state,
                              dst_dead);
          parent_iter_state->outstanding_ops++;
        }
      };

      // For non-merge destinations a dead input both raises the dead count
      // and consumes one pending slot; ready when pending reaches zero.
      auto propagate_to_non_merge = [&](PendingCounts::Handle dst_pending_id) {
        parent_iter_state->increment_dead_count(dst_pending_id);
        return parent_iter_state->decrement_pending(dst_pending_id, 1) == 0;
      };

      // Data out-edges of the dead exit.
      for (const EdgeInfo& e : item->output_edges()) {
        const NodeItem& dst_item =
            immutable_state_.graph_view().node_ref(e.dst_id);
        const auto dst_pending_id = immutable_state_.pending_ids()[e.dst_id];

        bool dst_dead = true;
        bool dst_ready;
        // We know this is a dead input to dst.
        if (dst_item.is_merge) {
          // Merge is dead only once ALL of its data inputs are dead, and
          // ready (dead) when only its live-data pending bit remains.
          parent_iter_state->increment_dead_count(dst_pending_id);
          const int dead_cnt = parent_iter_state->dead_count(dst_pending_id);
          dst_dead = (dead_cnt == dst_item.num_inputs);
          dst_ready =
              (parent_iter_state->pending(dst_pending_id) == 1) && dst_dead;
        } else {
          dst_ready = propagate_to_non_merge(dst_pending_id);
        }
        maybe_add_to_ready(dst_item, dst_ready, dst_dead);
      }

      // Control out-edges of the dead exit.
      for (const ControlEdgeInfo& e : item->output_control_edges()) {
        const NodeItem& dst_item =
            immutable_state_.graph_view().node_ref(e.dst_id);
        const auto dst_pending_id = immutable_state_.pending_ids()[e.dst_id];

        bool dst_dead;
        bool dst_ready;
        // We know this is a dead input to dst.
        if (dst_item.is_merge) {
          // For Merge each control edge accounts for 2 in the pending count,
          // so a completed control edge decrements by 2.
          parent_iter_state->decrement_pending(dst_pending_id, 2);
          int count = parent_iter_state->pending(dst_pending_id);
          int dead_cnt = parent_iter_state->dead_count(dst_pending_id);
          dst_dead = (dead_cnt == dst_item.num_inputs);
          dst_ready = (count == 0) || ((count == 1) && dst_dead);
        } else {
          dst_dead = true;
          dst_ready = propagate_to_non_merge(dst_pending_id);
        }
        maybe_add_to_ready(dst_item, dst_ready, dst_dead);
      }
    }
  }

  // Delete the frame.
  if (vlog_) VLOG(2) << "Delete frame " << frame->frame_id;
  {
    mutex_lock executor_lock(mu_);
    outstanding_frames_.erase(frame->frame_id);
  }
  delete frame;
}
390 | |
391 | void PropagatorState::CleanupFramesIterations(FrameState* frame, |
392 | IterationState* iter_state, |
393 | TaggedNodeSeq* ready) { |
394 | bool is_frame_done = false; |
395 | { |
396 | mutex_lock frame_lock(frame->mu); |
397 | iter_state->outstanding_frame_count--; |
398 | is_frame_done = frame->CleanupIterations(iter_state, ready); |
399 | } |
400 | if (is_frame_done) { |
401 | FrameState* parent_frame = frame->parent_frame; |
402 | IterationState* parent_iter = frame->parent_iter; |
403 | DeleteFrame(frame, ready); |
404 | if (parent_frame != nullptr) { |
405 | // The completion of frame may cause completions in its parent frame. |
406 | // So clean things up recursively. |
407 | CleanupFramesIterations(parent_frame, parent_iter, ready); |
408 | } |
409 | } |
410 | } |
411 | |
// Fast-path activation: used when no consumer of `item` is a Merge or
// ControlTrigger node, so each destination can be handled without reading
// its NodeItem. Writes outputs into the destination input slots, adjusts
// each destination's pending/dead counts (atomically iff `atomic`), appends
// newly-ready nodes to `ready`, and returns how many were appended.
template <bool atomic>
int PropagatorState::FrameState::ActivateNodesFastPathInternal(
    const NodeItem* item, const bool is_dead, IterationState* iter_state,
    EntryVector* outputs, TaggedNodeSeq* ready) {
  // If we know that none of the item's edge destinations require special
  // handling (i.e. none of the nodes is a merge or control trigger node), we
  // can take a fast path that avoids accessing the destination NodeItem.
  const GraphView& gview = immutable_state.graph_view();
  int new_outstanding = 0;

// Add dst to the ready queue if it's ready
//
// NOTE(mrry): Use a macro here instead of a lambda, because this method is
// performance-critical and we need to ensure that the code is inlined.
// A destination is ready when its pending count has dropped to zero; it is
// dead if any dead input was recorded for it.
#define MAYBE_ADD_TO_READY(dst_id, adjust_result)     \
  do {                                                \
    if (!(adjust_result.pending_count > 0)) {         \
      const NodeItem* dst_item = &gview.node_ref(dst_id); \
      TaggedNode& t = ready->emplace_back();          \
      t.node_item = dst_item;                         \
      t.input_frame = this;                           \
      t.input_iter = iter_state;                      \
      t.is_dead = adjust_result.dead_count > 0;       \
      new_outstanding++;                              \
    }                                                 \
  } while (0);

  // Data edges: move/copy the produced entry into the destination's input
  // slot BEFORE adjusting its activation, then check readiness.
  Entry* input_tensors = iter_state->input_tensors;
  for (const EdgeInfo& e : item->output_edges()) {
    const int dst_id = e.dst_id;
    const PendingCounts::Handle dst_pending_id =
        immutable_state.pending_ids()[dst_id];
    const int src_slot = e.output_slot;

    // The destination input is dead if this node is dead or this particular
    // output slot carries no value.
    const bool increment_dead =
        (is_dead || ((*outputs)[src_slot].state == Entry::State::NO_VALUE));
    const int dst_loc = e.input_slot;
    if (e.is_last) {
      // Last consumer of this output slot: the entry can be moved.
      input_tensors[dst_loc] = std::move((*outputs)[src_slot]);
    } else {
      input_tensors[dst_loc] = (*outputs)[src_slot];
    }
    const PendingCounts::AdjustResult adjust_result =
        atomic
            ? iter_state->adjust_for_activation_atomic(dst_pending_id,
                                                       increment_dead)
            : iter_state->adjust_for_activation(dst_pending_id, increment_dead);
    MAYBE_ADD_TO_READY(dst_id, adjust_result);
  }

  // Control edges: no tensor transfer, only pending/dead bookkeeping.
  for (const ControlEdgeInfo& e : item->output_control_edges()) {
    const int dst_id = e.dst_id;
    const PendingCounts::Handle dst_pending_id =
        immutable_state.pending_ids()[dst_id];
    const PendingCounts::AdjustResult adjust_result =
        atomic
            ? iter_state->adjust_for_activation_atomic(dst_pending_id, is_dead)
            : iter_state->adjust_for_activation(dst_pending_id, is_dead);
    MAYBE_ADD_TO_READY(dst_id, adjust_result);
  }

  return new_outstanding;
#undef MAYBE_ADD_TO_READY
}
476 | |
// Slow-path activation: used when at least one consumer of `item` is a Merge
// or ControlTrigger node, so each destination's NodeItem must be consulted.
// Same contract as the fast path: writes outputs into destination input
// slots, adjusts pending/dead counts (atomically iff `atomic`), appends
// newly-ready nodes to `ready`, and returns how many were appended.
template <bool atomic>
int PropagatorState::FrameState::ActivateNodesSlowPathInternal(
    const NodeItem* item, const bool is_dead, IterationState* iter_state,
    EntryVector* outputs, TaggedNodeSeq* ready) {
  // If any of the edge destinations is a merge or a control trigger node,
  // we need to read each destination NodeItem to determine what action
  // to take.
  const GraphView& gview = immutable_state.graph_view();
  int activated = 0;
  auto maybe_add_to_ready = [&](int dst_id, const NodeItem* dst_item,
                                bool dst_ready, bool dst_dead) {
    // Add dst to the ready queue if it's ready
    if (dst_ready) {
      // A ControlTrigger fires even on dead inputs; it is never dead itself.
      if (dst_item->is_control_trigger) dst_dead = false;
      ready->emplace_back(dst_item, this, iter_state, dst_dead);
      activated++;
    }
  };

  Entry* input_tensors = iter_state->input_tensors;

  // Data out-edges.
  for (const EdgeInfo& e : item->output_edges()) {
    const int dst_id = e.dst_id;
    const NodeItem* dst_item = &gview.node_ref(dst_id);
    const PendingCounts::Handle dst_pending_id =
        immutable_state.pending_ids()[dst_id];
    const int src_slot = e.output_slot;

    bool dst_dead = false;
    bool dst_ready = false;

    if (dst_item->is_merge) {
      // A merge node is ready if all control inputs have arrived and either
      // a) a live data input becomes available or b) all data inputs are
      // dead. For Merge, pending's LSB is set iff a live data input has
      // arrived.
      if ((*outputs)[src_slot].state != Entry::State::NO_VALUE) {
        // This is a live data input.

        // We have an assumption that merge op has only one live edge. Based on
        // this assumption, we set the total pending count of a merge op to be
        // 2 * control_edge + 1. So if we got a live data input of merge op, it
        // is fine to directly set the input.
        // NOTE(fishx): If there are multiple live edge for merge op, it will
        // be a race condition and the last live edge will override the input
        // of merge op. This should indicates a graph level issue.
        const int dst_loc = e.input_slot;
        if (e.is_last) {
          input_tensors[dst_loc] = std::move((*outputs)[src_slot]);
        } else {
          input_tensors[dst_loc] = (*outputs)[src_slot];
        }

        const PendingCounts::AdjustResult adjust_result =
            atomic ? iter_state->adjust_for_mark_live_atomic(dst_pending_id)
                   : iter_state->adjust_for_mark_live(dst_pending_id);

        // The low bit of count is set if and only if no live input has been
        // used yet (mark_live clears it). The node should be started if and
        // only if this is the first live input and there are no pending control
        // edges, i.e. count == 1.
        dst_ready = (adjust_result.pending_count == 1);
      } else {
        // This is a dead data input. Note that dst_node is dead if node is
        // a dead enter. We need this to handle properly a while loop on
        // the untaken branch of a conditional.
        // TODO(yuanbyu): This is a bit hacky, but a good solution for
        // now.
        const PendingCounts::AdjustResult adjust_result =
            atomic
                ? iter_state->adjust_for_increment_dead_atomic(dst_pending_id)
                : iter_state->adjust_for_increment_dead(dst_pending_id);
        dst_dead = (adjust_result.dead_count == dst_item->num_inputs) ||
                   item->is_enter;
        dst_ready = (adjust_result.pending_count == 1) && dst_dead;
      }
    } else {
      // Handle all other (non-merge) nodes.

      // We need to set the input of the op before adjusting activation.
      // Otherwise it may have race condition since another thread may execute
      // the op after adjusted activation.
      const int dst_loc = e.input_slot;
      if (e.is_last) {
        input_tensors[dst_loc] = std::move((*outputs)[src_slot]);
      } else {
        input_tensors[dst_loc] = (*outputs)[src_slot];
      }

      const bool increment_dead =
          (is_dead || ((*outputs)[src_slot].state == Entry::State::NO_VALUE));
      const PendingCounts::AdjustResult adjust_result =
          atomic ? iter_state->adjust_for_activation_atomic(dst_pending_id,
                                                            increment_dead)
                 : iter_state->adjust_for_activation(dst_pending_id,
                                                     increment_dead);
      dst_dead = adjust_result.dead_count > 0;
      dst_ready = !(adjust_result.pending_count > 0);
    }

    maybe_add_to_ready(dst_id, dst_item, dst_ready, dst_dead);
  }

  // Control out-edges.
  for (const ControlEdgeInfo& e : item->output_control_edges()) {
    const int dst_id = e.dst_id;
    const NodeItem* dst_item = &gview.node_ref(dst_id);
    const PendingCounts::Handle dst_pending_id =
        immutable_state.pending_ids()[dst_id];

    bool dst_dead;
    bool dst_ready;
    if (dst_item->is_merge) {
      // A merge node is ready if all control inputs have arrived and either
      // a) a live data input becomes available or b) all data inputs are
      // dead. For Merge, pending's LSB is set iff a live data input has
      // arrived. Each control edge counts as 2 in the pending count.
      const PendingCounts::AdjustResult adjust_result =
          atomic ? iter_state->adjust_for_decrement_pending_atomic(
                       dst_pending_id, /*decrement_pending=*/2)
                 : iter_state->adjust_for_decrement_pending(
                       dst_pending_id,
                       /*decrement_pending=*/2);
      dst_dead = (adjust_result.dead_count == dst_item->num_inputs);
      dst_ready = (adjust_result.pending_count == 0) ||
                  ((adjust_result.pending_count == 1) && dst_dead);
    } else {
      // Handle all other (non-merge) nodes.
      const PendingCounts::AdjustResult adjust_result =
          atomic ? iter_state->adjust_for_activation_atomic(dst_pending_id,
                                                            is_dead)
                 : iter_state->adjust_for_activation(dst_pending_id, is_dead);
      dst_dead = adjust_result.dead_count > 0;
      dst_ready = adjust_result.pending_count == 0;
    }
    maybe_add_to_ready(dst_id, dst_item, dst_ready, dst_dead);
  }

  return activated;
}
616 | |
617 | bool PropagatorState::FrameState::ActivateNodesAndAdjustOutstanding( |
618 | const NodeItem* item, const bool is_dead, IterationState* iter_state, |
619 | EntryVector* outputs, TaggedNodeSeq* ready, int decrement_activation) { |
620 | if (TF_PREDICT_FALSE(item->is_any_consumer_merge_or_control_trigger)) { |
621 | tf_shared_lock l(mu); |
622 | int activated = |
623 | ActivateNodesSlowPathShared(item, is_dead, iter_state, outputs, ready); |
624 | bool iter_done = AdjustOutstandingOpsFastPath( |
625 | iter_state, activated - decrement_activation); |
626 | if (!iter_done) return false; |
627 | } else { |
628 | tf_shared_lock l(mu); |
629 | int activated = |
630 | ActivateNodesFastPathShared(item, is_dead, iter_state, outputs, ready); |
631 | bool iter_done = AdjustOutstandingOpsFastPath( |
632 | iter_state, activated - decrement_activation); |
633 | if (!iter_done) return false; |
634 | } |
635 | if (decrement_activation > 0) { |
636 | mutex_lock l(mu); |
637 | return CleanupIterations(iter_state, ready); |
638 | } else { |
639 | return true; |
640 | } |
641 | } |
642 | |
643 | int PropagatorState::FrameState::ActivateNodesLocked(const NodeItem* item, |
644 | const bool is_dead, |
645 | IterationState* iter_state, |
646 | EntryVector* outputs, |
647 | TaggedNodeSeq* ready) { |
648 | if (TF_PREDICT_FALSE(item->is_any_consumer_merge_or_control_trigger)) { |
649 | return ActivateNodesSlowPathLocked(item, is_dead, iter_state, outputs, |
650 | ready); |
651 | } else { |
652 | return ActivateNodesFastPathLocked(item, is_dead, iter_state, outputs, |
653 | ready); |
654 | } |
655 | } |
656 | |
657 | void PropagatorState::FrameState::ActivateNexts(IterationState* iter_state, |
658 | TaggedNodeSeq* ready) { |
659 | int activated = 0; |
660 | // Propagate the deferred NextIteration nodes to the new iteration. |
661 | for (auto& node_entry : next_iter_roots) { |
662 | const NodeItem* item = node_entry.first; |
663 | const Entry& entry = node_entry.second; |
664 | const bool is_dead = entry.state == Entry::State::NO_VALUE; |
665 | EntryVector outputs{entry}; |
666 | activated += |
667 | ActivateNodesLocked(item, is_dead, iter_state, &outputs, ready); |
668 | } |
669 | next_iter_roots.clear(); |
670 | AdjustOutstandingOpsLocked(iter_state, activated, ready); |
671 | } |
672 | |
673 | void PropagatorState::FrameState::ActivateLoopInvs(IterationState* iter_state, |
674 | TaggedNodeSeq* ready) { |
675 | // Propagate loop invariants to the new iteration. |
676 | mutex_lock l(iter_mu); |
677 | int activated = 0; |
678 | for (auto& node_entry : inv_values) { |
679 | const NodeItem* item = node_entry.first; |
680 | const Entry& entry = node_entry.second; |
681 | const bool is_dead = entry.state == Entry::State::NO_VALUE; |
682 | EntryVector outputs{entry}; |
683 | activated += |
684 | ActivateNodesLocked(item, is_dead, iter_state, &outputs, ready); |
685 | } |
686 | AdjustOutstandingOpsLocked(iter_state, activated, ready); |
687 | } |
688 | |
689 | void PropagatorState::FrameState::AddLoopInv(const NodeItem* item, |
690 | const Entry& entry, |
691 | TaggedNodeSeq* ready) { |
692 | { |
693 | mutex_lock l(iter_mu); |
694 | // Store this value. |
695 | inv_values.push_back({item, entry}); |
696 | } |
697 | |
698 | // Make this value available to all iterations. |
699 | const bool is_dead = entry.state == Entry::State::NO_VALUE; |
700 | for (int i = 0; i <= iteration_count; ++i) { |
701 | EntryVector outputs{entry}; |
702 | IterationState* iter_state = GetIteration(i); |
703 | int activated = |
704 | ActivateNodesLocked(item, is_dead, iter_state, &outputs, ready); |
705 | AdjustOutstandingOpsLocked(iter_state, activated, ready); |
706 | } |
707 | } |
708 | |
709 | bool PropagatorState::FrameState::IsIterationDone(IterationState* iter_state) { |
710 | if (iter_state->outstanding_ops == 0 && |
711 | iter_state->outstanding_frame_count == 0) { |
712 | if (iter_state->iter_num == 0) { |
713 | // The enclosing frame has no pending input. |
714 | return num_pending_inputs == 0; |
715 | } else { |
716 | // The preceding iteration is deleted (and therefore done). |
717 | return (GetIteration(iter_state->iter_num - 1) == nullptr); |
718 | } |
719 | } |
720 | return false; |
721 | } |
722 | |
// Creates, registers, and seeds the next iteration of this frame, returning
// its state. Seeds are the deferred NextIteration values and the recorded
// loop invariants. SetIteration requires the caller to hold this frame's
// `mu` exclusively.
PropagatorState::IterationState*
PropagatorState::FrameState::IncrementIteration(TaggedNodeSeq* ready) {
  iteration_count++;

  // Initialize the next iteration.
  IterationState* next_iter =
      new IterationState(iteration_count, pending_counts, total_input_tensors);
  SetIteration(iteration_count, next_iter);
  num_outstanding_iterations++;
  {
    mutex_lock l(iter_mu);
    // Dead exits recorded for the previous last iteration no longer apply
    // now that a newer iteration exists.
    dead_exits.clear();
  }

  // Activate the successors of the deferred roots in the new iteration.
  ActivateNexts(next_iter, ready);

  // Activate the loop invariants in the new iteration.
  ActivateLoopInvs(next_iter, ready);

  return next_iter;
}
745 | |
// Retires `iter_state` and every consecutive successor iteration that is
// already done, starting any deferred iteration as slots free up. Returns
// true iff the whole frame is now done. Caller holds this frame's `mu`
// exclusively (required by SetIteration/IncrementIteration).
bool PropagatorState::FrameState::CleanupIterations(IterationState* iter_state,
                                                    TaggedNodeSeq* ready) {
  int64_t curr_iter = iter_state->iter_num;
  while (curr_iter <= iteration_count && IsIterationDone(iter_state)) {
    delete iter_state;
    SetIteration(curr_iter, nullptr);
    --num_outstanding_iterations;
    ++curr_iter;

    // When one iteration is completed, we check for deferred iteration,
    // and start it if there is one.
    bool increment_iteration = false;
    {
      tf_shared_lock l(iter_mu);
      increment_iteration = !next_iter_roots.empty();
    }
    if (increment_iteration) {
      IncrementIteration(ready);
    }

    // Advance to the next iteration; it may have become done because its
    // predecessor (the one just deleted) was the only thing blocking it.
    if (curr_iter <= iteration_count) {
      iter_state = GetIteration(curr_iter);
    }
  }
  return IsFrameDone();
}
772 | |
773 | void PropagatorState::FrameState::InitializeFrameInfo( |
774 | const ImmutableExecutorState::FrameInfo& finfo) { |
775 | pending_counts = finfo.pending_counts.get(); |
776 | total_input_tensors = finfo.total_inputs; |
777 | num_pending_inputs = finfo.input_count; |
778 | nodes = finfo.nodes.get(); |
779 | } |
780 | |
781 | void PropagatorState::FrameState::SetIteration(int64_t iter, |
782 | IterationState* state) |
783 | TF_EXCLUSIVE_LOCKS_REQUIRED(mu) { |
784 | size_t index = iter % (max_parallel_iterations + 1); |
785 | DCHECK(state == nullptr || iterations[index] == nullptr); |
786 | iterations_raw[index] = state; |
787 | if (index == 0) { |
788 | iterations_first = state; |
789 | } |
790 | } |
791 | |
792 | // Decrement the outstanding op count and clean up the iterations in the |
793 | // frame. Return true iff the execution of the frame is done. |
794 | bool PropagatorState::FrameState::DecrementOutstandingOps( |
795 | IterationState* iter_state, TaggedNodeSeq* ready) { |
796 | return AdjustOutstandingOps(iter_state, -1, ready); |
797 | } |
798 | |
// Adjusts the iteration's outstanding-op count by `delta`. Returns true iff
// the adjustment completed the iteration AND cleaning it up completed the
// whole frame. The common case is handled under a shared lock; only when the
// iteration actually finishes do we take the exclusive lock to retire it.
bool PropagatorState::FrameState::AdjustOutstandingOps(
    IterationState* iter_state, int delta, TaggedNodeSeq* ready) {
  // Given the following profile of values of 'delta' for wide_deep model from
  // the TF model garden:
  //
  //  Count  Value
  //  ---------------
  //  757938  delta=0x0
  //  541713  delta=0xffffffff
  //  138115  delta=0x1
  //  58770  delta=0x2
  //  5394  delta=0x3
  //  4669  delta=0x4
  //  2037  delta=0xa
  //  1646  delta=0x7
  //  1632  delta=0x6
  //  1613  delta=0x6c
  //  1224  delta=0x5
  //  409  delta=0x53
  //  17  delta=0x86
  //
  // ... it's worth no-opping out when delta == 0 to avoid the atomic
  // instruction.
  if (delta == 0) {
    return false;
  }
  {
    // Fast path: apply the delta atomically under a shared lock; if the
    // iteration is not yet done there is nothing else to do.
    tf_shared_lock sl(mu);
    if (TF_PREDICT_TRUE(!AdjustOutstandingOpsFastPath(iter_state, delta))) {
      return false;
    }
  }
  // The iteration finished; retiring it mutates the iteration ring buffer
  // and so needs the exclusive lock.
  mutex_lock l(mu);
  DCHECK(IsIterationDone(iter_state));
  return CleanupIterations(iter_state, ready);
}
835 | |
836 | bool PropagatorState::FrameState::AdjustOutstandingOpsFastPath( |
837 | IterationState* iter_state, int delta) { |
838 | auto old_val = iter_state->outstanding_ops.fetch_add(delta); |
839 | return (old_val + delta == 0) && IsIterationDone(iter_state); |
840 | } |
841 | |
842 | // Decrement the outstanding op count and clean up the iterations in the |
843 | // frame. Return true iff the execution of the frame is done. |
844 | bool PropagatorState::FrameState::DecrementOutstandingOpsLocked( |
845 | IterationState* iter_state, TaggedNodeSeq* ready) |
846 | TF_EXCLUSIVE_LOCKS_REQUIRED(mu) { |
847 | return AdjustOutstandingOpsLocked(iter_state, -1, ready); |
848 | } |
849 | |
850 | bool PropagatorState::FrameState::AdjustOutstandingOpsLocked( |
851 | IterationState* iter_state, int delta, TaggedNodeSeq* ready) { |
852 | // We hold the lock, so we don't need to use an atomic modification. |
853 | auto cur_val = iter_state->outstanding_ops.load(std::memory_order_relaxed); |
854 | DCHECK(delta >= 0 || cur_val >= -delta) |
855 | << "cannot adjust outstanding_ops by " << delta |
856 | << " when current value is " << cur_val; |
857 | auto new_val = cur_val + delta; |
858 | iter_state->outstanding_ops.store(new_val, std::memory_order_relaxed); |
859 | if (new_val != 0) { |
860 | return false; |
861 | } |
862 | return CleanupIterations(iter_state, ready); |
863 | } |
864 | |
865 | // Returns true if the computation in the frame is completed. |
866 | bool PropagatorState::FrameState::IsFrameDone() |
867 | TF_EXCLUSIVE_LOCKS_REQUIRED(mu) { |
868 | return (num_pending_inputs == 0 && num_outstanding_iterations == 0); |
869 | } |
870 | |
871 | } // namespace tensorflow |
872 | |