1/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2
3Licensed under the Apache License, Version 2.0 (the "License");
4you may not use this file except in compliance with the License.
5You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9Unless required by applicable law or agreed to in writing, software
10distributed under the License is distributed on an "AS IS" BASIS,
11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12See the License for the specific language governing permissions and
13limitations under the License.
14==============================================================================*/
15#include "tensorflow/core/common_runtime/simple_propagator_state.h"
16
17#include <atomic>
18
19#include "tensorflow/core/common_runtime/propagator_debug_utils.h"
20#include "tensorflow/core/framework/op_kernel.h"
21#include "tensorflow/core/profiler/lib/traceme.h"
22
23namespace tensorflow {
24
// Delegating constructor: looks up the root frame's FrameInfo on the
// immutable executor state and forwards to the main constructor below.
SimplePropagatorState::SimplePropagatorState(
    const ImmutableExecutorState& immutable_state, int64_t step_id, bool vlog)
    : SimplePropagatorState(immutable_state, step_id,
                            immutable_state.get_root_frame_info(), vlog) {}
29
// Initializes per-step propagation state for a graph with a single frame:
// - `input_tensors_` holds one Entry per input slot in the frame.
// - `pending_` is a per-node atomic count of not-yet-delivered inputs,
//   seeded from the counts precomputed in `immutable_state`.
// - `active_` (a node-id-indexed vector<bool> used only by DumpState) is
//   allocated only when verbose logging is on; otherwise it stays null.
// NOTE: member-initializer order matters here — `vlog_` must be computed
// before `active_`, which reads it.
SimplePropagatorState::SimplePropagatorState(
    const ImmutableExecutorState& immutable_state, int64_t step_id,
    const ImmutableExecutorState::FrameInfo& finfo, bool vlog)
    : immutable_state_(immutable_state),
      step_id_(step_id),
      vlog_(vlog || VLOG_IS_ON(1)),
      input_tensors_(finfo.total_inputs),
      pending_(
          new std::atomic<int32>[immutable_state.graph_view().num_nodes()]),
      active_(vlog_ ? new std::vector<bool>(
                          immutable_state.graph_view().num_nodes())
                    : nullptr),
      nodes_(finfo.nodes.get()) {
  // Copy the precomputed initial pending counts into the atomic array.
  immutable_state_.copy_pending_counts(pending_.get());
}
45
46SimplePropagatorState::~SimplePropagatorState() {}
47
48void SimplePropagatorState::ActivateRoots(
49 gtl::ArraySlice<const NodeItem*> roots, TaggedNodeSeq* ready) {
50 for (const NodeItem* item : roots) {
51 DCHECK_EQ(item->num_inputs, 0);
52 ready->push_back(TaggedNode{item});
53 }
54}
55
// Delivers the `outputs` produced by `tagged_node` to its consumers'
// input slots and appends every node whose final pending input was just
// satisfied to `ready`. `ready` must be empty on entry.
void SimplePropagatorState::PropagateOutputs(const TaggedNode& tagged_node,
                                             EntryVector* outputs,
                                             TaggedNodeSeq* ready) {
  // Emit a profiler event describing which kernel's outputs are being
  // propagated; the lambda defers string construction until tracing is on.
  profiler::TraceMe activity(
      [&]() {
        return strings::StrCat(
            "ExecutorPropagateOutputs#", "id=", step_id_,
            ",kernel_name=", tagged_node.node_item->kernel->name_view(),
            ",num_output_edges=", tagged_node.node_item->num_output_edges,
            ",num_output_control_edges=",
            tagged_node.node_item->num_output_control_edges, "#");
      },
      profiler::GetTFTraceMeLevel(/*is_expensive=*/false));

  // Propagates outputs along out edges, and puts newly ready nodes
  // into the ready queue.
  DCHECK(ready->empty());

  const GraphView& gview = immutable_state_.graph_view();
  const NodeItem* item = tagged_node.node_item;

  // Data edges: copy (or, on the last consumer, move) each output value
  // into the destination node's input slot, then decrement the
  // destination's pending count.
  for (const EdgeInfo& e : item->output_edges()) {
    const int dst_id = e.dst_id;
    const int src_slot = e.output_slot;
    const int dst_loc = e.input_slot;

    // NOTE(mrry): The write to `input_tensors_[dst_loc]` must happen before
    // the pending count update, or else one thread might conclude that the
    // count has dropped to zero before another thread finishes updating the
    // input.
    if (e.is_last) {
      input_tensors_[dst_loc] = std::move((*outputs)[src_slot]);
    } else {
      input_tensors_[dst_loc] = (*outputs)[src_slot];
    }

    // Release ordering publishes the input write above to the thread that
    // observes the count reach zero (presumably paired with an acquire on
    // the consumer side — not visible in this file; confirm at call sites).
    int32_t previous_num_pending =
        pending_[dst_id].fetch_sub(1, std::memory_order_release);
    // The thread that drops the count from 1 to 0 owns scheduling the node.
    if (previous_num_pending == 1) ready->emplace_back(&gview.node_ref(dst_id));
  }

  // Control edges: no value to deliver, just decrement the pending count.
  for (const ControlEdgeInfo& e : item->output_control_edges()) {
    const int dst_id = e.dst_id;

    int32_t previous_num_pending =
        pending_[dst_id].fetch_sub(1, std::memory_order_release);
    if (previous_num_pending == 1) ready->emplace_back(&gview.node_ref(dst_id));
  }
}
105
106void SimplePropagatorState::DumpState() {
107 mutex_lock l(mu_);
108 // Dump any waiting nodes that are holding on to tensors.
109 for (const NodeItem* node : *nodes_) {
110 if (pending_[node->node_id]) {
111 DumpPendingNodeState(*node, input_tensors_.data(), false);
112 }
113 }
114 // Then the active nodes.
115 for (const NodeItem* node : *nodes_) {
116 if ((*active_)[node->node_id]) {
117 DumpActiveNodeState(*node, input_tensors_.data());
118 }
119 }
120 // Show all input tensors in use.
121 size_t total_bytes = 0;
122 for (size_t i = 0; i < input_tensors_.size(); ++i) {
123 const Entry& input = input_tensors_[i];
124 const Tensor* tensor = GetTensorValueForDump(input);
125 if (tensor && tensor->IsInitialized()) {
126 LOG(WARNING) << " Input " << i << ": "
127 << strings::StrCat(
128 "Tensor<type: ", DataTypeString(tensor->dtype()),
129 " shape: ", tensor->shape().DebugString(),
130 ", bytes: ", tensor->TotalBytes(), ">");
131 total_bytes += tensor->TotalBytes();
132 }
133 }
134 LOG(WARNING) << " Total bytes " << total_bytes;
135}
136
137} // namespace tensorflow
138