1 | /* Copyright 2015 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | |
16 | #ifndef TENSORFLOW_CORE_COMMON_RUNTIME_GRAPH_VIEW_H_ |
17 | #define TENSORFLOW_CORE_COMMON_RUNTIME_GRAPH_VIEW_H_ |
18 | |
19 | #include <memory> |
20 | #include <vector> |
21 | |
22 | #include "tensorflow/core/framework/allocator.h" |
23 | #include "tensorflow/core/framework/types.h" |
24 | #include "tensorflow/core/lib/core/status.h" |
25 | #include "tensorflow/core/lib/gtl/array_slice.h" |
26 | #include "tensorflow/core/platform/logging.h" |
27 | #include "tensorflow/core/platform/macros.h" |
28 | #include "tensorflow/core/platform/types.h" |
29 | |
30 | namespace tensorflow { |
31 | |
32 | class Device; |
33 | class Graph; |
34 | class Node; |
35 | class OpKernel; |
36 | class Tensor; |
37 | |
// Represents a single data edge in a `NodeItem`.
struct EdgeInfo {
  // The node ID of the destination in the containing `GraphView`.
  int dst_id;
  // The index of the output that produces values on this edge.
  // Packed into 31 bits so that, together with `is_last`, the pair fits in a
  // single 32-bit word (keeps the per-edge footprint small).
  int output_slot : 31;
  // true if this is the last info for output_slot in the EdgeInfo list.
  bool is_last : 1;
  // The index of the input that consumes values on this edge.
  int input_slot;
};
49 | |
// Represents a single control edge in a `NodeItem`.
//
// Control edges carry no tensor values, so only the destination node needs
// to be recorded (compare with `EdgeInfo`, which also tracks slots).
struct ControlEdgeInfo {
  // The node ID of the destination in the containing `GraphView`.
  int dst_id;
};
55 | |
56 | // Compact structure representing a graph node and its associated kernel. |
57 | // |
58 | // Each NodeItem is an element of exactly one GraphView. |
59 | struct NodeItem { |
60 | // The index of this node's item in its GraphView. |
61 | int node_id = -1; |
62 | |
63 | // Cached attributes of this node for fast lookup. |
64 | bool kernel_is_async : 1; // True iff kernel->AsAsync() != nullptr |
65 | bool is_merge : 1; // True iff IsMerge(node) |
66 | bool is_enter : 1; // True iff IsEnter(node) |
67 | bool is_constant_enter : 1; // True iff IsEnter(node) and |
68 | // node->GetAttr("is_constant") == true. |
69 | bool is_exit : 1; // True iff IsExit(node) |
70 | bool is_control_trigger : 1; // True iff IsControlTrigger(node) |
71 | bool is_source : 1; // True iff IsSource(node) |
72 | // True iff IsEnter(node) || IsExit(node) || IsNextIteration(node) |
73 | bool is_enter_exit_or_next_iter : 1; |
74 | bool is_transfer_node : 1; // True iff IsTransferNode(node) |
75 | bool is_initialization_op : 1; // True iff IsInitializationOp(node) |
76 | bool is_recv_or_switch : 1; // True iff IsRecv(node) || IsSwitch(node) |
77 | bool is_next_iteration : 1; // True iff IsNextIteration(node) |
78 | bool is_noop : 1; // True iff item->kernel->type_string_view() == "NoOp") |
79 | bool |
80 | is_any_consumer_merge_or_control_trigger : 1; // True iff the destination |
81 | // of any output edge is a |
82 | // merge or control trigger |
83 | // node. |
84 | bool is_any_input_ref_typed : 1; // True iff any IsRefType(dt) for dt in this |
85 | // node's input types. |
86 | bool is_distributed_communication : 1; // True iff the op is registered to |
87 | // use distributed communication. |
88 | |
89 | // The kernel for this node. |
90 | OpKernel* kernel = nullptr; |
91 | |
92 | // If the kernel is a Const op, this containts points to the constant tensor. |
93 | const Tensor* const_tensor = nullptr; |
94 | |
95 | // Cached values of node->num_inputs() and node->num_outputs(), to |
96 | // avoid levels of indirection. |
97 | int num_inputs; |
98 | int num_outputs; |
99 | |
100 | // ExecutorImpl::tensors_[input_start] is the 1st positional input |
101 | // for this node. |
102 | int input_start = 0; |
103 | |
104 | // Number of output edges, excluding control edges. |
105 | int32 num_output_edges; |
106 | |
107 | // Number of output control edges. |
108 | int32 num_output_control_edges; |
109 | |
110 | // If non-null, contains an array of num_outputs bools, where the ith bool |
111 | // is true if and only if the ith output is consumed by another node. |
112 | std::unique_ptr<bool[]> outputs_required; |
113 | |
114 | gtl::MutableArraySlice<EdgeInfo> mutable_output_edges() { |
115 | return gtl::MutableArraySlice<EdgeInfo>(output_edge_base(), |
116 | num_output_edges); |
117 | } |
118 | |
119 | gtl::ArraySlice<EdgeInfo> output_edges() const { |
120 | return gtl::ArraySlice<EdgeInfo>(output_edge_base(), num_output_edges); |
121 | } |
122 | |
123 | gtl::ArraySlice<ControlEdgeInfo> output_control_edges() const { |
124 | return gtl::ArraySlice<const ControlEdgeInfo>(output_control_edge_base(), |
125 | num_output_control_edges); |
126 | } |
127 | |
128 | DataType input_type(int i) const { |
129 | DCHECK_LT(i, num_inputs); |
130 | return static_cast<DataType>(input_type_base()[i]); |
131 | } |
132 | DataType output_type(int i) const { |
133 | DCHECK_LT(i, num_outputs); |
134 | return static_cast<DataType>(output_type_base()[i]); |
135 | } |
136 | |
137 | // Return array of per-output allocator attributes. |
138 | const AllocatorAttributes* output_attrs() const { return output_attr_base(); } |
139 | |
140 | // Return array of expected input index from which each output should |
141 | // be forwarded: |
142 | // kNeverForward (-2) for DO NOT FORWARD (must allocate). |
143 | // kNoReservation (-1) for no expected forwarding. |
144 | // 0... for forward from that input. |
145 | const int* forward_from() const { return forward_from_base(); } |
146 | |
147 | string DebugString() const; |
148 | |
149 | private: |
150 | friend class GraphView; |
151 | |
152 | NodeItem() {} |
153 | |
154 | // Variable length section starts immediately after *this |
155 | // (uint8 is enough for DataType). |
156 | // EdgeInfo out_edges[num_output_edges]; |
157 | // ControlEdgeInfo out_control_edges[num_output_control_edges]; |
158 | // AllocatorAttributes output_attr[num_outputs]; |
159 | // int forward_from[num_outputs]; |
160 | // uint8 input_type[num_inputs]; |
161 | // uint8 output_type[num_outputs]; |
162 | |
163 | // Return pointer to variable length section. |
164 | char* var() const { |
165 | return const_cast<char*>(reinterpret_cast<const char*>(this) + |
166 | sizeof(NodeItem)); |
167 | } |
168 | |
169 | EdgeInfo* output_edge_base() const { |
170 | return reinterpret_cast<EdgeInfo*>(var()); |
171 | } |
172 | |
173 | ControlEdgeInfo* output_control_edge_base() const { |
174 | return reinterpret_cast<ControlEdgeInfo*>(var() + sizeof(EdgeInfo) * |
175 | num_output_edges); |
176 | } |
177 | |
178 | AllocatorAttributes* output_attr_base() const { |
179 | return reinterpret_cast<AllocatorAttributes*>( |
180 | var() + sizeof(EdgeInfo) * num_output_edges + |
181 | sizeof(ControlEdgeInfo) * num_output_control_edges); |
182 | } |
183 | int* forward_from_base() const { |
184 | return reinterpret_cast<int*>(var() + sizeof(EdgeInfo) * num_output_edges + |
185 | sizeof(ControlEdgeInfo) * |
186 | num_output_control_edges + |
187 | sizeof(AllocatorAttributes) * num_outputs); |
188 | } |
189 | uint8* input_type_base() const { |
190 | return reinterpret_cast<uint8*>( |
191 | var() + sizeof(EdgeInfo) * num_output_edges + |
192 | sizeof(ControlEdgeInfo) * num_output_control_edges + |
193 | sizeof(AllocatorAttributes) * num_outputs + sizeof(int) * num_outputs); |
194 | } |
195 | uint8* output_type_base() const { |
196 | return reinterpret_cast<uint8*>( |
197 | var() + sizeof(EdgeInfo) * num_output_edges + |
198 | sizeof(ControlEdgeInfo) * num_output_control_edges + |
199 | sizeof(AllocatorAttributes) * num_outputs + sizeof(int) * num_outputs + |
200 | sizeof(uint8) * num_inputs); |
201 | } |
202 | |
203 | TF_DISALLOW_COPY_AND_ASSIGN(NodeItem); |
204 | }; |
205 | |
206 | // Immutable view of a Graph organized for efficient execution. |
207 | // |
208 | // TODO(b/152651962): Add independent unit tests for this class. |
209 | class GraphView { |
210 | public: |
211 | GraphView() : space_(nullptr) {} |
212 | ~GraphView(); |
213 | |
214 | Status Initialize(const Graph* g); |
215 | Status SetAllocAttrs(const Graph* g, const Device* device); |
216 | void SetScopedAllocatorAttrs(const std::vector<const Node*>& sa_nodes); |
217 | |
218 | // Returns a mutable pointer to the `NodeItem` with the given `id` if it |
219 | // exists in the graph, or `nullptr` if it does not. |
220 | NodeItem* node(int32_t id) const { |
221 | DCHECK_GE(id, 0); |
222 | DCHECK_LT(id, num_nodes_); |
223 | uint32 offset = node_offsets_[id]; |
224 | return ((offset == kuint32max) |
225 | ? nullptr |
226 | : reinterpret_cast<NodeItem*>(space_ + node_offsets_[id])); |
227 | } |
228 | |
229 | // Returns the `NodeItem` with the given `id`. |
230 | // |
231 | // REQUIRES: `id` must be the ID of a valid node in the graph. |
232 | const NodeItem& node_ref(int32_t id) const { |
233 | DCHECK_GE(id, 0); |
234 | DCHECK_LT(id, num_nodes_); |
235 | uint32 offset = node_offsets_[id]; |
236 | DCHECK_NE(offset, kuint32max); |
237 | return *reinterpret_cast<NodeItem*>(space_ + node_offsets_[id]); |
238 | } |
239 | |
240 | int32 num_nodes() const { return num_nodes_; } |
241 | |
242 | private: |
243 | char* InitializeNode(char* ptr, const Node* n); |
244 | size_t NodeItemBytes(const Node* n); |
245 | |
246 | int32 num_nodes_ = 0; |
247 | uint32* node_offsets_ = nullptr; // array of size "num_nodes_" |
248 | // node_offsets_[id] holds the byte offset for node w/ "id" in space_ |
249 | |
250 | char* space_; // NodeItem objects are allocated here |
251 | |
252 | TF_DISALLOW_COPY_AND_ASSIGN(GraphView); |
253 | }; |
254 | |
255 | } // namespace tensorflow |
256 | |
257 | #endif // TENSORFLOW_CORE_COMMON_RUNTIME_GRAPH_VIEW_H_ |
258 | |