1/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2
3Licensed under the Apache License, Version 2.0 (the "License");
4you may not use this file except in compliance with the License.
5You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9Unless required by applicable law or agreed to in writing, software
10distributed under the License is distributed on an "AS IS" BASIS,
11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12See the License for the specific language governing permissions and
13limitations under the License.
14==============================================================================*/
15
16#ifndef TENSORFLOW_CORE_COMMON_RUNTIME_GRAPH_VIEW_H_
17#define TENSORFLOW_CORE_COMMON_RUNTIME_GRAPH_VIEW_H_
18
19#include <memory>
20#include <vector>
21
22#include "tensorflow/core/framework/allocator.h"
23#include "tensorflow/core/framework/types.h"
24#include "tensorflow/core/lib/core/status.h"
25#include "tensorflow/core/lib/gtl/array_slice.h"
26#include "tensorflow/core/platform/logging.h"
27#include "tensorflow/core/platform/macros.h"
28#include "tensorflow/core/platform/types.h"
29
30namespace tensorflow {
31
32class Device;
33class Graph;
34class Node;
35class OpKernel;
36class Tensor;
37
// One outgoing data edge of a `NodeItem`.
//
// Field order and bit-field widths are part of the in-memory layout used by
// `NodeItem`'s variable-length section; do not reorder.
struct EdgeInfo {
  // ID, within the owning `GraphView`, of the node this edge points to.
  int dst_id;

  // Which output of the source node feeds this edge.
  int output_slot : 31;

  // Set on the final entry for `output_slot` in the EdgeInfo list.
  bool is_last : 1;

  // Which input of the destination node receives the value.
  int input_slot;
};
49
// One outgoing control edge of a `NodeItem`. A control edge carries no slot
// information, only the node it points to.
struct ControlEdgeInfo {
  // ID, within the owning `GraphView`, of the node this edge points to.
  int dst_id;
};
55
56// Compact structure representing a graph node and its associated kernel.
57//
58// Each NodeItem is an element of exactly one GraphView.
59struct NodeItem {
60 // The index of this node's item in its GraphView.
61 int node_id = -1;
62
63 // Cached attributes of this node for fast lookup.
64 bool kernel_is_async : 1; // True iff kernel->AsAsync() != nullptr
65 bool is_merge : 1; // True iff IsMerge(node)
66 bool is_enter : 1; // True iff IsEnter(node)
67 bool is_constant_enter : 1; // True iff IsEnter(node) and
68 // node->GetAttr("is_constant") == true.
69 bool is_exit : 1; // True iff IsExit(node)
70 bool is_control_trigger : 1; // True iff IsControlTrigger(node)
71 bool is_source : 1; // True iff IsSource(node)
72 // True iff IsEnter(node) || IsExit(node) || IsNextIteration(node)
73 bool is_enter_exit_or_next_iter : 1;
74 bool is_transfer_node : 1; // True iff IsTransferNode(node)
75 bool is_initialization_op : 1; // True iff IsInitializationOp(node)
76 bool is_recv_or_switch : 1; // True iff IsRecv(node) || IsSwitch(node)
77 bool is_next_iteration : 1; // True iff IsNextIteration(node)
78 bool is_noop : 1; // True iff item->kernel->type_string_view() == "NoOp")
79 bool
80 is_any_consumer_merge_or_control_trigger : 1; // True iff the destination
81 // of any output edge is a
82 // merge or control trigger
83 // node.
84 bool is_any_input_ref_typed : 1; // True iff any IsRefType(dt) for dt in this
85 // node's input types.
86 bool is_distributed_communication : 1; // True iff the op is registered to
87 // use distributed communication.
88
89 // The kernel for this node.
90 OpKernel* kernel = nullptr;
91
92 // If the kernel is a Const op, this containts points to the constant tensor.
93 const Tensor* const_tensor = nullptr;
94
95 // Cached values of node->num_inputs() and node->num_outputs(), to
96 // avoid levels of indirection.
97 int num_inputs;
98 int num_outputs;
99
100 // ExecutorImpl::tensors_[input_start] is the 1st positional input
101 // for this node.
102 int input_start = 0;
103
104 // Number of output edges, excluding control edges.
105 int32 num_output_edges;
106
107 // Number of output control edges.
108 int32 num_output_control_edges;
109
110 // If non-null, contains an array of num_outputs bools, where the ith bool
111 // is true if and only if the ith output is consumed by another node.
112 std::unique_ptr<bool[]> outputs_required;
113
114 gtl::MutableArraySlice<EdgeInfo> mutable_output_edges() {
115 return gtl::MutableArraySlice<EdgeInfo>(output_edge_base(),
116 num_output_edges);
117 }
118
119 gtl::ArraySlice<EdgeInfo> output_edges() const {
120 return gtl::ArraySlice<EdgeInfo>(output_edge_base(), num_output_edges);
121 }
122
123 gtl::ArraySlice<ControlEdgeInfo> output_control_edges() const {
124 return gtl::ArraySlice<const ControlEdgeInfo>(output_control_edge_base(),
125 num_output_control_edges);
126 }
127
128 DataType input_type(int i) const {
129 DCHECK_LT(i, num_inputs);
130 return static_cast<DataType>(input_type_base()[i]);
131 }
132 DataType output_type(int i) const {
133 DCHECK_LT(i, num_outputs);
134 return static_cast<DataType>(output_type_base()[i]);
135 }
136
137 // Return array of per-output allocator attributes.
138 const AllocatorAttributes* output_attrs() const { return output_attr_base(); }
139
140 // Return array of expected input index from which each output should
141 // be forwarded:
142 // kNeverForward (-2) for DO NOT FORWARD (must allocate).
143 // kNoReservation (-1) for no expected forwarding.
144 // 0... for forward from that input.
145 const int* forward_from() const { return forward_from_base(); }
146
147 string DebugString() const;
148
149 private:
150 friend class GraphView;
151
152 NodeItem() {}
153
154 // Variable length section starts immediately after *this
155 // (uint8 is enough for DataType).
156 // EdgeInfo out_edges[num_output_edges];
157 // ControlEdgeInfo out_control_edges[num_output_control_edges];
158 // AllocatorAttributes output_attr[num_outputs];
159 // int forward_from[num_outputs];
160 // uint8 input_type[num_inputs];
161 // uint8 output_type[num_outputs];
162
163 // Return pointer to variable length section.
164 char* var() const {
165 return const_cast<char*>(reinterpret_cast<const char*>(this) +
166 sizeof(NodeItem));
167 }
168
169 EdgeInfo* output_edge_base() const {
170 return reinterpret_cast<EdgeInfo*>(var());
171 }
172
173 ControlEdgeInfo* output_control_edge_base() const {
174 return reinterpret_cast<ControlEdgeInfo*>(var() + sizeof(EdgeInfo) *
175 num_output_edges);
176 }
177
178 AllocatorAttributes* output_attr_base() const {
179 return reinterpret_cast<AllocatorAttributes*>(
180 var() + sizeof(EdgeInfo) * num_output_edges +
181 sizeof(ControlEdgeInfo) * num_output_control_edges);
182 }
183 int* forward_from_base() const {
184 return reinterpret_cast<int*>(var() + sizeof(EdgeInfo) * num_output_edges +
185 sizeof(ControlEdgeInfo) *
186 num_output_control_edges +
187 sizeof(AllocatorAttributes) * num_outputs);
188 }
189 uint8* input_type_base() const {
190 return reinterpret_cast<uint8*>(
191 var() + sizeof(EdgeInfo) * num_output_edges +
192 sizeof(ControlEdgeInfo) * num_output_control_edges +
193 sizeof(AllocatorAttributes) * num_outputs + sizeof(int) * num_outputs);
194 }
195 uint8* output_type_base() const {
196 return reinterpret_cast<uint8*>(
197 var() + sizeof(EdgeInfo) * num_output_edges +
198 sizeof(ControlEdgeInfo) * num_output_control_edges +
199 sizeof(AllocatorAttributes) * num_outputs + sizeof(int) * num_outputs +
200 sizeof(uint8) * num_inputs);
201 }
202
203 TF_DISALLOW_COPY_AND_ASSIGN(NodeItem);
204};
205
206// Immutable view of a Graph organized for efficient execution.
207//
208// TODO(b/152651962): Add independent unit tests for this class.
209class GraphView {
210 public:
211 GraphView() : space_(nullptr) {}
212 ~GraphView();
213
214 Status Initialize(const Graph* g);
215 Status SetAllocAttrs(const Graph* g, const Device* device);
216 void SetScopedAllocatorAttrs(const std::vector<const Node*>& sa_nodes);
217
218 // Returns a mutable pointer to the `NodeItem` with the given `id` if it
219 // exists in the graph, or `nullptr` if it does not.
220 NodeItem* node(int32_t id) const {
221 DCHECK_GE(id, 0);
222 DCHECK_LT(id, num_nodes_);
223 uint32 offset = node_offsets_[id];
224 return ((offset == kuint32max)
225 ? nullptr
226 : reinterpret_cast<NodeItem*>(space_ + node_offsets_[id]));
227 }
228
229 // Returns the `NodeItem` with the given `id`.
230 //
231 // REQUIRES: `id` must be the ID of a valid node in the graph.
232 const NodeItem& node_ref(int32_t id) const {
233 DCHECK_GE(id, 0);
234 DCHECK_LT(id, num_nodes_);
235 uint32 offset = node_offsets_[id];
236 DCHECK_NE(offset, kuint32max);
237 return *reinterpret_cast<NodeItem*>(space_ + node_offsets_[id]);
238 }
239
240 int32 num_nodes() const { return num_nodes_; }
241
242 private:
243 char* InitializeNode(char* ptr, const Node* n);
244 size_t NodeItemBytes(const Node* n);
245
246 int32 num_nodes_ = 0;
247 uint32* node_offsets_ = nullptr; // array of size "num_nodes_"
248 // node_offsets_[id] holds the byte offset for node w/ "id" in space_
249
250 char* space_; // NodeItem objects are allocated here
251
252 TF_DISALLOW_COPY_AND_ASSIGN(GraphView);
253};
254
255} // namespace tensorflow
256
257#endif // TENSORFLOW_CORE_COMMON_RUNTIME_GRAPH_VIEW_H_
258