1/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2
3Licensed under the Apache License, Version 2.0 (the "License");
4you may not use this file except in compliance with the License.
5You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9Unless required by applicable law or agreed to in writing, software
10distributed under the License is distributed on an "AS IS" BASIS,
11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12See the License for the specific language governing permissions and
13limitations under the License.
14==============================================================================*/
15#ifndef TENSORFLOW_CORE_COMMON_RUNTIME_SIMPLE_PROPAGATOR_STATE_H_
16#define TENSORFLOW_CORE_COMMON_RUNTIME_SIMPLE_PROPAGATOR_STATE_H_
17
#include <atomic>
#include <memory>
#include <vector>

#include "tensorflow/core/common_runtime/entry.h"
#include "tensorflow/core/common_runtime/immutable_executor_state.h"
#include "tensorflow/core/common_runtime/pending_counts.h"
#include "tensorflow/core/framework/control_flow.h"
#include "tensorflow/core/lib/gtl/inlined_vector.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/thread_annotations.h"
#include "tensorflow/core/platform/types.h"
30
31namespace tensorflow {
32
33// Represents the ephemeral "edge state" associated with one invocation of
34// `Executor::Run()`.
35//
36// NOTE: `SimplePropagatorState` does not support "v1-style" control flow,
37// including "dead tensors", "Switch" and "Merge" nodes, and cycles in the
38// graph. Use `PropagatorState` for graphs with those features.
39// `SimplePropagatorState` *does* support "v2-style" or "functional" control
40// flow.
41//
42// `SimplePropagatorState` is responsible for propagating values along dataflow
43// edges in a TensorFlow graph and determining which nodes are runnable. The
44// executor primarily updates `SimplePropagatorState` by calling
45// `PropagateOutputs()` after processing a node, and `SimplePropagatorState`
46// dispatches `TaggedNode`s by adding them to a `TaggedNodeSeq`.
47class SimplePropagatorState {
48 public:
49 SimplePropagatorState(const ImmutableExecutorState& immutable_state,
50 int64_t step_id, bool vlog);
51 ~SimplePropagatorState();
52
53 // A `TaggedNode` corresponds to a single invocation of a node's kernel,
54 // and it is created when the kernel becomes runnable.
55 struct TaggedNode {
56 const NodeItem* node_item;
57
58 explicit TaggedNode(const NodeItem* node_item) : node_item(node_item) {}
59
60 const NodeItem& get_node_item() const { return *node_item; }
61
62 bool get_is_dead() const { return false; }
63 int64_t get_iter_num() const { return 0; }
64 };
65
66 // A drop-in replacement for std::deque<TaggedNode>. We typically don't
67 // have that many nodes in the ready queue, so we just use a vector and
68 // don't free up memory from the queue as we consume nodes.
69 // TODO(mrry): Extract this and share it with the version in
70 // `PropagatorState`. The correct constants might be different, since
71 // sizeof(TaggedNode) is smaller in this version.
72 class TaggedNodeReadyQueue {
73 public:
74 TaggedNodeReadyQueue() : front_index_(0) {}
75
76 void push_back(const TaggedNode& node) { ready_.push_back(node); }
77 TaggedNode front() const {
78 DCHECK_LT(front_index_, ready_.size());
79 return ready_[front_index_];
80 }
81 void pop_front() {
82 DCHECK_LT(front_index_, ready_.size());
83 front_index_++;
84 if ((front_index_ == ready_.size()) || (front_index_ > kSpillThreshold)) {
85 if (front_index_ == ready_.size()) {
86 ready_.clear();
87 } else {
88 // Lots of unused entries at beginning of vector: move everything
89 // down to start of vector.
90 ready_.erase(ready_.begin(), ready_.begin() + front_index_);
91 }
92 front_index_ = 0;
93 }
94 }
95 bool empty() const { return ready_.empty(); }
96 int size() const { return ready_.size() - front_index_; }
97
98 private:
99 // TODO(b/152925936): Re-evaluate these constants with current usage
100 // patterns.
101 static constexpr int kSpillThreshold = 16384;
102 gtl::InlinedVector<TaggedNode, 16> ready_;
103 int front_index_;
104 };
105
106 // TODO(b/152925936): Re-evaluate this constant with current usage patterns.
107 typedef gtl::InlinedVector<TaggedNode, 8> TaggedNodeSeq;
108
109 // Creates and adds a `TaggedNode` for each node in `roots` to `*ready`.
110 void ActivateRoots(gtl::ArraySlice<const NodeItem*> roots,
111 TaggedNodeSeq* ready);
112
113 // After processing the outputs, propagates the outputs to their dsts.
114 // Contents of *outputs are left in an indeterminate state after
115 // returning from this method.
116 void PropagateOutputs(const TaggedNode& tagged_node, EntryVector* outputs,
117 TaggedNodeSeq* ready);
118
119 // Returns an array of `Entry` objects corresponding to the inputs of
120 // `tagged_node`.
121 Entry* GetInputTensors(const TaggedNode& tagged_node) {
122#if defined(THREAD_SANITIZER) || defined(DEBUG)
123 // NOTE: This read of `pending_[...]` works around a limitation in TSAN.
124 // To avoid false positive data race reports, we need to perform an atomic
125 // object access that will establish the happens-before relation between
126 // the write to input_tensors_ in `PropagateOutputs()` and the read in
127 // `PrepareInputs()`.
128 CHECK_EQ(pending_[tagged_node.node_item->node_id], 0);
129#endif // defined(THREAD_SANITIZER) || defined(DEBUG)
130 return input_tensors_.data() + tagged_node.node_item->input_start;
131 }
132
133 FrameAndIter GetFrameAndIter(const TaggedNode& tagged_node) const {
134 return {0, 0};
135 }
136
137 // Provide debugging output of the state of the executor.
138 void DumpState();
139
140 // For debugging/logging only.
141 void MaybeMarkStarted(const TaggedNode& tagged_node) {
142 // TODO(misard) Replace with a finer-grain enabling flag once we add better
143 // optional debugging support.
144 if (TF_PREDICT_FALSE(vlog_) && VLOG_IS_ON(1)) {
145 mutex_lock l(mu_);
146 (*active_)[tagged_node.node_item->node_id] = true;
147 }
148 }
149 void MaybeMarkCompleted(const TaggedNode& tagged_node) {
150 // TODO(misard) Replace with a finer-grain enabling flag once we add better
151 // optional debugging support.
152 if (TF_PREDICT_FALSE(vlog_) && VLOG_IS_ON(1)) {
153 mutex_lock l(mu_);
154 (*active_)[tagged_node.node_item->node_id] = false;
155 }
156 }
157
158 private:
159 SimplePropagatorState(const ImmutableExecutorState& immutable_state_,
160 int64_t step_id,
161 const ImmutableExecutorState::FrameInfo& finfo,
162 bool vlog);
163
164 const ImmutableExecutorState& immutable_state_;
165 const int64_t step_id_;
166 const bool vlog_;
167
168 // The i-th node's j-th input is stored at
169 // `input_tensors[impl_->nodes[i].input_start + j]`.
170 //
171 // NOTE: No need to protect input_tensors[i] by any locks because it
172 // is resized once. Each element of input_tensors is written once by the
173 // source node of an edge and is cleared by the destination of the same
174 // edge. The destination node always runs after the source node, so there
175 // is never concurrent access to the same entry.
176 std::vector<Entry> input_tensors_;
177
178 std::unique_ptr<std::atomic<int32>[]> pending_;
179
180 // If `vlog_` is true, this stores a bit vector of active nodes, indexed by
181 // node ID.
182 mutex mu_;
183 std::unique_ptr<std::vector<bool>> active_ TF_GUARDED_BY(mu_);
184
185 const std::vector<const NodeItem*>* const nodes_;
186};
187
188} // namespace tensorflow
189
190#endif // TENSORFLOW_CORE_COMMON_RUNTIME_SIMPLE_PROPAGATOR_STATE_H_
191