1 | /* Copyright 2015 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | #ifndef TENSORFLOW_CORE_COMMON_RUNTIME_IMMUTABLE_EXECUTOR_STATE_H_ |
16 | #define TENSORFLOW_CORE_COMMON_RUNTIME_IMMUTABLE_EXECUTOR_STATE_H_ |
17 | |
#include <atomic>
#include <cstring>
#include <deque>
#include <memory>
#include <string>
#include <utility>
#include <vector>

#include "absl/container/flat_hash_map.h"
#include "tensorflow/core/common_runtime/graph_view.h"
#include "tensorflow/core/common_runtime/local_executor_params.h"
#include "tensorflow/core/common_runtime/pending_counts.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/gtl/flatmap.h"
#include "tensorflow/core/lib/gtl/flatset.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/types.h"
33 | |
34 | namespace tensorflow { |
35 | |
36 | class Graph; |
37 | |
38 | // Represents the state of an executor (graph and control flow information) |
39 | // that is immutable throughout execution. |
40 | // |
41 | // TODO(b/152651962): Add independent unit tests for this class. |
42 | class ImmutableExecutorState { |
43 | public: |
44 | struct FrameInfo { |
45 | explicit FrameInfo(string name) |
46 | : name(std::move(name)), |
47 | input_count(0), |
48 | total_inputs(0), |
49 | pending_counts(nullptr), |
50 | nodes(nullptr), |
51 | parallel_iterations(-1) {} |
52 | |
53 | // The name of the frame. |
54 | string name; |
55 | |
56 | // The total number of inputs to a frame. |
57 | int input_count; |
58 | |
59 | // The total number of input tensors of a frame. |
60 | // == sum(nodes[*].num_inputs()) where nodes are the nodes in the frame. |
61 | int total_inputs; |
62 | |
63 | // Used to determine the next place to allocate space in the |
64 | // pending_counts data structure we'll eventually construct |
65 | PendingCounts::Layout pending_counts_layout; |
66 | |
67 | // Each frame has its own PendingCounts only for the nodes in the frame. |
68 | std::unique_ptr<PendingCounts> pending_counts; |
69 | |
70 | // The nodes in a frame. Used only for debugging. |
71 | std::unique_ptr<std::vector<const NodeItem*>> nodes; |
72 | |
73 | // The number of iterations of this frame that can execute concurrently. |
74 | int32 parallel_iterations; |
75 | }; |
76 | |
77 | explicit ImmutableExecutorState(const LocalExecutorParams& p) |
78 | : params_(p), gview_() {} |
79 | ~ImmutableExecutorState(); |
80 | |
81 | Status Initialize(const Graph& graph); |
82 | |
83 | // Process all Nodes in the current graph, attempting to infer the |
84 | // memory allocation attributes to be used wherever they may allocate |
85 | // a tensor buffer. |
86 | Status SetAllocAttrs(); |
87 | |
88 | const LocalExecutorParams& params() const { return params_; } |
89 | const GraphView& graph_view() const { return gview_; } |
90 | const std::vector<PendingCounts::Handle>& pending_ids() const { |
91 | return pending_ids_; |
92 | } |
93 | const std::vector<const NodeItem*>& root_nodes() const { return root_nodes_; } |
94 | |
95 | const FrameInfo& get_root_frame_info() const { return *root_frame_info_; } |
96 | |
97 | const FrameInfo& get_enter_frame_info(const NodeItem& node_item) const { |
98 | DCHECK(node_item.is_enter); |
99 | return *enter_frame_info_[node_item.node_id]; |
100 | } |
101 | |
102 | bool requires_control_flow_support() const { return requires_control_flow_; } |
103 | |
104 | // Copies the pending counts for nodes in this graph to the given array. |
105 | // |
106 | // This method provides a more efficient way of initializing |
107 | // `SimplePropagatorState` than individually accessing the pending counts from |
108 | // `get_root_frame_info().counts`. |
109 | // |
110 | // REQUIRES: `!requires_control_flow_support && len(dest) == |
111 | // graph_view().num_nodes()`. |
112 | void copy_pending_counts(std::atomic<int32>* dest) const { |
113 | DCHECK(!requires_control_flow_); |
114 | memcpy(dest, atomic_pending_counts_.get(), |
115 | graph_view().num_nodes() * sizeof(std::atomic<int32>)); |
116 | std::atomic_thread_fence(std::memory_order_release); |
117 | } |
118 | |
119 | private: |
120 | struct ControlFlowInfo { |
121 | gtl::FlatSet<string> unique_frame_names; |
122 | std::vector<string> frame_names; |
123 | }; |
124 | |
125 | static Status BuildControlFlowInfo(const Graph* graph, |
126 | ControlFlowInfo* cf_info); |
127 | void InitializePending(const Graph* graph, const ControlFlowInfo& cf_info); |
128 | |
129 | FrameInfo* EnsureFrameInfo(const string& fname); |
130 | |
131 | // Owned. |
132 | LocalExecutorParams params_; |
133 | GraphView gview_; |
134 | bool requires_control_flow_; |
135 | std::vector<PendingCounts::Handle> pending_ids_; |
136 | |
137 | // Root nodes (with no in edges) that should form the initial ready queue |
138 | std::vector<const NodeItem*> root_nodes_; |
139 | |
140 | // Mapping from frame name to static information about the frame. |
141 | // TODO(yuanbyu): We could cache it along with the graph so to avoid |
142 | // the overhead of constructing it for each executor instance. |
143 | absl::flat_hash_map<absl::string_view, std::unique_ptr<FrameInfo>> |
144 | frame_info_; |
145 | const FrameInfo* root_frame_info_; // Not owned. |
146 | |
147 | // If the graph contains any "Enter" or "RefEnter" nodes, this vector maps |
148 | // dense node IDs to the corresponding FrameInfo. |
149 | std::vector<FrameInfo*> enter_frame_info_; |
150 | |
151 | // If `requires_control_flow_` is false, this points to an array of initial |
152 | // pending counts for the nodes in the graph, indexed by node ID. |
153 | std::unique_ptr<std::atomic<int32>[]> atomic_pending_counts_; |
154 | |
155 | // Shallow copies of the constant tensors used in the graph. |
156 | std::vector<Tensor> const_tensors_; |
157 | |
158 | TF_DISALLOW_COPY_AND_ASSIGN(ImmutableExecutorState); |
159 | }; |
160 | |
161 | } // namespace tensorflow |
162 | #endif // TENSORFLOW_CORE_COMMON_RUNTIME_IMMUTABLE_EXECUTOR_STATE_H_ |
163 | |