1 | /* Copyright 2015 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | #ifndef TENSORFLOW_CORE_COMMON_RUNTIME_SIMPLE_PROPAGATOR_STATE_H_ |
16 | #define TENSORFLOW_CORE_COMMON_RUNTIME_SIMPLE_PROPAGATOR_STATE_H_ |
17 | |
#include <atomic>
#include <cstddef>
#include <memory>
#include <vector>

#include "tensorflow/core/common_runtime/entry.h"
#include "tensorflow/core/common_runtime/immutable_executor_state.h"
#include "tensorflow/core/common_runtime/pending_counts.h"
#include "tensorflow/core/framework/control_flow.h"
#include "tensorflow/core/lib/gtl/inlined_vector.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/thread_annotations.h"
#include "tensorflow/core/platform/types.h"
30 | |
31 | namespace tensorflow { |
32 | |
33 | // Represents the ephemeral "edge state" associated with one invocation of |
34 | // `Executor::Run()`. |
35 | // |
36 | // NOTE: `SimplePropagatorState` does not support "v1-style" control flow, |
37 | // including "dead tensors", "Switch" and "Merge" nodes, and cycles in the |
38 | // graph. Use `PropagatorState` for graphs with those features. |
39 | // `SimplePropagatorState` *does* support "v2-style" or "functional" control |
40 | // flow. |
41 | // |
42 | // `SimplePropagatorState` is responsible for propagating values along dataflow |
43 | // edges in a TensorFlow graph and determining which nodes are runnable. The |
44 | // executor primarily updates `SimplePropagatorState` by calling |
45 | // `PropagateOutputs()` after processing a node, and `SimplePropagatorState` |
46 | // dispatches `TaggedNode`s by adding them to a `TaggedNodeSeq`. |
47 | class SimplePropagatorState { |
48 | public: |
49 | SimplePropagatorState(const ImmutableExecutorState& immutable_state, |
50 | int64_t step_id, bool vlog); |
51 | ~SimplePropagatorState(); |
52 | |
53 | // A `TaggedNode` corresponds to a single invocation of a node's kernel, |
54 | // and it is created when the kernel becomes runnable. |
55 | struct TaggedNode { |
56 | const NodeItem* node_item; |
57 | |
58 | explicit TaggedNode(const NodeItem* node_item) : node_item(node_item) {} |
59 | |
60 | const NodeItem& get_node_item() const { return *node_item; } |
61 | |
62 | bool get_is_dead() const { return false; } |
63 | int64_t get_iter_num() const { return 0; } |
64 | }; |
65 | |
66 | // A drop-in replacement for std::deque<TaggedNode>. We typically don't |
67 | // have that many nodes in the ready queue, so we just use a vector and |
68 | // don't free up memory from the queue as we consume nodes. |
69 | // TODO(mrry): Extract this and share it with the version in |
70 | // `PropagatorState`. The correct constants might be different, since |
71 | // sizeof(TaggedNode) is smaller in this version. |
72 | class TaggedNodeReadyQueue { |
73 | public: |
74 | TaggedNodeReadyQueue() : front_index_(0) {} |
75 | |
76 | void push_back(const TaggedNode& node) { ready_.push_back(node); } |
77 | TaggedNode front() const { |
78 | DCHECK_LT(front_index_, ready_.size()); |
79 | return ready_[front_index_]; |
80 | } |
81 | void pop_front() { |
82 | DCHECK_LT(front_index_, ready_.size()); |
83 | front_index_++; |
84 | if ((front_index_ == ready_.size()) || (front_index_ > kSpillThreshold)) { |
85 | if (front_index_ == ready_.size()) { |
86 | ready_.clear(); |
87 | } else { |
88 | // Lots of unused entries at beginning of vector: move everything |
89 | // down to start of vector. |
90 | ready_.erase(ready_.begin(), ready_.begin() + front_index_); |
91 | } |
92 | front_index_ = 0; |
93 | } |
94 | } |
95 | bool empty() const { return ready_.empty(); } |
96 | int size() const { return ready_.size() - front_index_; } |
97 | |
98 | private: |
99 | // TODO(b/152925936): Re-evaluate these constants with current usage |
100 | // patterns. |
101 | static constexpr int kSpillThreshold = 16384; |
102 | gtl::InlinedVector<TaggedNode, 16> ready_; |
103 | int front_index_; |
104 | }; |
105 | |
106 | // TODO(b/152925936): Re-evaluate this constant with current usage patterns. |
107 | typedef gtl::InlinedVector<TaggedNode, 8> TaggedNodeSeq; |
108 | |
109 | // Creates and adds a `TaggedNode` for each node in `roots` to `*ready`. |
110 | void ActivateRoots(gtl::ArraySlice<const NodeItem*> roots, |
111 | TaggedNodeSeq* ready); |
112 | |
113 | // After processing the outputs, propagates the outputs to their dsts. |
114 | // Contents of *outputs are left in an indeterminate state after |
115 | // returning from this method. |
116 | void PropagateOutputs(const TaggedNode& tagged_node, EntryVector* outputs, |
117 | TaggedNodeSeq* ready); |
118 | |
119 | // Returns an array of `Entry` objects corresponding to the inputs of |
120 | // `tagged_node`. |
121 | Entry* GetInputTensors(const TaggedNode& tagged_node) { |
122 | #if defined(THREAD_SANITIZER) || defined(DEBUG) |
123 | // NOTE: This read of `pending_[...]` works around a limitation in TSAN. |
124 | // To avoid false positive data race reports, we need to perform an atomic |
125 | // object access that will establish the happens-before relation between |
126 | // the write to input_tensors_ in `PropagateOutputs()` and the read in |
127 | // `PrepareInputs()`. |
128 | CHECK_EQ(pending_[tagged_node.node_item->node_id], 0); |
129 | #endif // defined(THREAD_SANITIZER) || defined(DEBUG) |
130 | return input_tensors_.data() + tagged_node.node_item->input_start; |
131 | } |
132 | |
133 | FrameAndIter GetFrameAndIter(const TaggedNode& tagged_node) const { |
134 | return {0, 0}; |
135 | } |
136 | |
137 | // Provide debugging output of the state of the executor. |
138 | void DumpState(); |
139 | |
140 | // For debugging/logging only. |
141 | void MaybeMarkStarted(const TaggedNode& tagged_node) { |
142 | // TODO(misard) Replace with a finer-grain enabling flag once we add better |
143 | // optional debugging support. |
144 | if (TF_PREDICT_FALSE(vlog_) && VLOG_IS_ON(1)) { |
145 | mutex_lock l(mu_); |
146 | (*active_)[tagged_node.node_item->node_id] = true; |
147 | } |
148 | } |
149 | void MaybeMarkCompleted(const TaggedNode& tagged_node) { |
150 | // TODO(misard) Replace with a finer-grain enabling flag once we add better |
151 | // optional debugging support. |
152 | if (TF_PREDICT_FALSE(vlog_) && VLOG_IS_ON(1)) { |
153 | mutex_lock l(mu_); |
154 | (*active_)[tagged_node.node_item->node_id] = false; |
155 | } |
156 | } |
157 | |
158 | private: |
159 | SimplePropagatorState(const ImmutableExecutorState& immutable_state_, |
160 | int64_t step_id, |
161 | const ImmutableExecutorState::FrameInfo& finfo, |
162 | bool vlog); |
163 | |
164 | const ImmutableExecutorState& immutable_state_; |
165 | const int64_t step_id_; |
166 | const bool vlog_; |
167 | |
168 | // The i-th node's j-th input is stored at |
169 | // `input_tensors[impl_->nodes[i].input_start + j]`. |
170 | // |
171 | // NOTE: No need to protect input_tensors[i] by any locks because it |
172 | // is resized once. Each element of input_tensors is written once by the |
173 | // source node of an edge and is cleared by the destination of the same |
174 | // edge. The destination node always runs after the source node, so there |
175 | // is never concurrent access to the same entry. |
176 | std::vector<Entry> input_tensors_; |
177 | |
178 | std::unique_ptr<std::atomic<int32>[]> pending_; |
179 | |
180 | // If `vlog_` is true, this stores a bit vector of active nodes, indexed by |
181 | // node ID. |
182 | mutex mu_; |
183 | std::unique_ptr<std::vector<bool>> active_ TF_GUARDED_BY(mu_); |
184 | |
185 | const std::vector<const NodeItem*>* const nodes_; |
186 | }; |
187 | |
188 | } // namespace tensorflow |
189 | |
190 | #endif // TENSORFLOW_CORE_COMMON_RUNTIME_SIMPLE_PROPAGATOR_STATE_H_ |
191 | |