/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/core/common_runtime/immutable_executor_state.h"

#include "absl/memory/memory.h"
#include "tensorflow/core/framework/function.h"
#include "tensorflow/core/framework/metrics.h"
#include "tensorflow/core/framework/node_def_util.h"
#include "tensorflow/core/graph/edgeset.h"
#include "tensorflow/core/graph/graph.h"
#include "tensorflow/core/graph/graph_node_util.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/platform/logging.h"

namespace tensorflow {

namespace {
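// Returns true if `node` may be run with uninitialized inputs, i.e. its OpDef
// sets allows_uninitialized_input() (as variable initialization ops do).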
bool IsInitializationOp(const Node* node) {
  return node->op_def().allows_uninitialized_input();
}
}  // namespace

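// Deletes the per-node kernels that were instantiated during Initialize(),
// using the deleter supplied in `params_`.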
ImmutableExecutorState::~ImmutableExecutorState() {
  for (int32_t i = 0; i < gview_.num_nodes(); i++) {
    NodeItem* item = gview_.node(i);
    if (item != nullptr) {
      params_.delete_kernel(item->kernel);
    }
  }
}

namespace {
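// Computes the upper bounds for the pending-count and dead-input-count values
// that the executor may store for node `n`. For Merge nodes, the pending count
// packs a "live data input ready" flag in bit 0 and the number of control
// edges in the remaining bits.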
void GetMaxPendingCounts(const Node* n, size_t* max_pending,
                         size_t* max_dead_count) {
  const size_t num_in_edges = n->in_edges().size();
  size_t initial_count;
  if (IsMerge(n)) {
    // Merge waits on all of its control inputs, so we initialize the pending
    // count to the number of control edges.
    int32_t num_control_edges = 0;
    for (const Edge* edge : n->in_edges()) {
      if (edge->IsControlEdge()) {
        num_control_edges++;
      }
    }
    // Use bit 0 to indicate whether we are waiting for a ready live data
    // input.
    initial_count = 1 + (num_control_edges << 1);
  } else {
    initial_count = num_in_edges;
  }

  *max_pending = initial_count;
  *max_dead_count = num_in_edges;
}
}  // namespace

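// Returns the FrameInfo for the frame named `fname`, creating it on first
// use. The map key is a string_view that points into the FrameInfo's own
// `name` member, so it remains valid for as long as the entry exists.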
ImmutableExecutorState::FrameInfo* ImmutableExecutorState::EnsureFrameInfo(
    const string& fname) {
  auto iter = frame_info_.find(fname);
  if (iter != frame_info_.end()) {
    return iter->second.get();
  } else {
    auto frame_info = std::make_unique<FrameInfo>(fname);
    absl::string_view fname_view = frame_info->name;
    auto emplace_result =
        frame_info_.emplace(fname_view, std::move(frame_info));
    return emplace_result.first->second.get();
  }
}

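// Builds the per-node NodeItem metadata for `graph`: instantiates an OpKernel
// for every node, assigns input buffer offsets within each frame, and sets up
// the pending-count bookkeeping used by the executor at run time.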
Status ImmutableExecutorState::Initialize(const Graph& graph) {
  TF_RETURN_IF_ERROR(gview_.Initialize(&graph));

  // Build the information about frames in this subgraph.
  ControlFlowInfo cf_info;
  TF_RETURN_IF_ERROR(BuildControlFlowInfo(&graph, &cf_info));

  for (auto& it : cf_info.unique_frame_names) {
    EnsureFrameInfo(it)->nodes =
        std::make_unique<std::vector<const NodeItem*>>();
  }
  root_frame_info_ = frame_info_[""].get();

  pending_ids_.resize(gview_.num_nodes());

  // Preprocess every node in the graph, creating an OpKernel instance for
  // each node.
  requires_control_flow_ = false;
  for (const Node* n : graph.nodes()) {
    if (IsSink(n)) continue;
    if (IsSwitch(n) || IsMerge(n) || IsEnter(n) || IsExit(n)) {
      requires_control_flow_ = true;
    } else if (IsRecv(n)) {
      // A Recv node from a different device may produce dead tensors from
      // non-local control-flow nodes.
      //
      // TODO(mrry): Track whether control flow was present in the
      // pre-partitioned graph, and enable the caller (e.g.
      // `DirectSession`) to relax this constraint.
      string send_device;
      string recv_device;
      TF_RETURN_IF_ERROR(GetNodeAttr(n->attrs(), "send_device", &send_device));
      TF_RETURN_IF_ERROR(GetNodeAttr(n->attrs(), "recv_device", &recv_device));
      if (send_device != recv_device) {
        requires_control_flow_ = true;
      }
    }

    const int id = n->id();
    const string& frame_name = cf_info.frame_names[id];
    FrameInfo* frame_info = EnsureFrameInfo(frame_name);

    NodeItem* item = gview_.node(id);
    item->node_id = id;

    item->input_start = frame_info->total_inputs;
    frame_info->total_inputs += n->num_inputs();

    Status s = params_.create_kernel(n->properties(), &item->kernel);
    if (!s.ok()) {
      params_.delete_kernel(item->kernel);
      item->kernel = nullptr;
      s = AttachDef(s, *n);
      return s;
    }
    CHECK(item->kernel);
    item->kernel_is_async = (item->kernel->AsAsync() != nullptr);
    item->is_merge = IsMerge(n);
    item->is_any_consumer_merge_or_control_trigger = false;
    for (const Node* consumer : n->out_nodes()) {
      if (IsMerge(consumer) || IsControlTrigger(consumer)) {
        item->is_any_consumer_merge_or_control_trigger = true;
        break;
      }
    }
    const Tensor* const_tensor = item->kernel->const_tensor();
    if (const_tensor) {
      // Hold onto a shallow copy of the constant tensor in `*this` so that
      // its reference count does not drop to 1. This prevents the constant
      // tensor from being forwarded and its buffer from being reused.
      const_tensors_.emplace_back(*const_tensor);
    }
    item->const_tensor = const_tensor;
    item->is_noop = (item->kernel->type_string_view() == "NoOp");
    item->is_enter = IsEnter(n);
    if (item->is_enter) {
      bool is_constant_enter;
      TF_RETURN_IF_ERROR(
          GetNodeAttr(n->attrs(), "is_constant", &is_constant_enter));
      item->is_constant_enter = is_constant_enter;

      string frame_name;
      TF_RETURN_IF_ERROR(GetNodeAttr(n->attrs(), "frame_name", &frame_name));
      FrameInfo* frame_info = frame_info_[frame_name].get();

      int parallel_iterations;
      TF_RETURN_IF_ERROR(
          GetNodeAttr(n->attrs(), "parallel_iterations", &parallel_iterations));

      if (frame_info->parallel_iterations == -1) {
        frame_info->parallel_iterations = parallel_iterations;
      } else if (frame_info->parallel_iterations != parallel_iterations) {
        LOG(WARNING) << "Loop frame \"" << frame_name
                     << "\" had two different values for parallel_iterations: "
                     << frame_info->parallel_iterations << " vs. "
                     << parallel_iterations << ".";
      }

      if (enter_frame_info_.size() <= id) {
        enter_frame_info_.resize(id + 1);
      }
      enter_frame_info_[id] = frame_info;
    } else {
      item->is_constant_enter = false;
    }
    item->is_exit = IsExit(n);
    item->is_control_trigger = IsControlTrigger(n);
    item->is_source = IsSource(n);
    item->is_enter_exit_or_next_iter =
        (IsEnter(n) || IsExit(n) || IsNextIteration(n));
    item->is_transfer_node = IsTransferNode(n);
    item->is_initialization_op = IsInitializationOp(n);
    item->is_recv_or_switch = IsRecv(n) || IsSwitch(n);
    item->is_next_iteration = IsNextIteration(n);
    item->is_distributed_communication = IsDistributedCommunication(n);

    // Compute the maximum values we'll store for this node in the
    // pending counts data structure, and allocate a handle in
    // that frame's pending counts data structure that has enough
    // space to store these maximal count values.
    size_t max_pending, max_dead;
    GetMaxPendingCounts(n, &max_pending, &max_dead);
    pending_ids_[id] =
        frame_info->pending_counts_layout.CreateHandle(max_pending, max_dead);

    // See if this node is a root node, and if so, add item to root_nodes_.
    if (n->in_edges().empty()) {
      root_nodes_.push_back(item);
    }

    // Initialize static information about the frames in the graph.
    frame_info->nodes->push_back(item);
    if (item->is_enter) {
      string enter_name;
      TF_RETURN_IF_ERROR(GetNodeAttr(n->attrs(), "frame_name", &enter_name));
      EnsureFrameInfo(enter_name)->input_count++;
    }

    // Record information about whether each output of the op is used.
    std::unique_ptr<bool[]> outputs_required(new bool[n->num_outputs()]);
    std::fill(&outputs_required[0], &outputs_required[n->num_outputs()],
              false);
    int32_t unused_outputs = n->num_outputs();
    for (const Edge* e : n->out_edges()) {
      if (IsSink(e->dst())) continue;
      if (e->src_output() >= 0) {
        if (!outputs_required[e->src_output()]) {
          --unused_outputs;
          outputs_required[e->src_output()] = true;
        }
      }
    }
    if (unused_outputs > 0) {
      for (int i = 0; i < n->num_outputs(); ++i) {
        if (!outputs_required[i]) {
          metrics::RecordUnusedOutput(n->type_string());
        }
      }
      item->outputs_required = std::move(outputs_required);
    }
  }

  // Rewrite each `EdgeInfo::input_slot` member to refer directly to the input
  // location.
  for (const Node* n : graph.nodes()) {
    if (IsSink(n)) continue;
    const int id = n->id();
    NodeItem* item = gview_.node(id);

    for (EdgeInfo& e : item->mutable_output_edges()) {
      const int dst_id = e.dst_id;
      NodeItem* dst_item = gview_.node(dst_id);
      e.input_slot += dst_item->input_start;
    }
  }

  // Initialize PendingCounts only after pending_ids_[node.id] is initialized
  // for all nodes.
  InitializePending(&graph, cf_info);
  return gview_.SetAllocAttrs(&graph, params_.device);
}

namespace {
// If a Node has been marked to use a ScopedAllocator x for output i, then
// sc_attr will contain the subsequence (i, x) at an even offset. This function
// extracts and transfers that ScopedAllocator id to alloc_attr. For now, we
// only allow one ScopedAllocator use per Node.
bool ExtractScopedAllocatorAttr(const std::vector<int>& sc_attr,
                                int output_index,
                                AllocatorAttributes* alloc_attr) {
  DCHECK_LE(2, sc_attr.size());
  for (int i = 0; i < sc_attr.size(); i += 2) {
    if (sc_attr[i] == output_index) {
      CHECK_EQ(alloc_attr->scope_id, 0);
      alloc_attr->scope_id = sc_attr[i + 1];
      return true;
    }
  }
  return false;
}
}  // namespace

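// Performs a breadth-first traversal from the root nodes (nodes with no input
// edges) and records, for every reachable node, the name of the frame it
// belongs to. Enter nodes switch to their child frame and Exit nodes return
// to the parent frame; all frame names seen are collected in
// `cf_info->unique_frame_names`.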
Status ImmutableExecutorState::BuildControlFlowInfo(const Graph* g,
                                                    ControlFlowInfo* cf_info) {
  const int num_nodes = g->num_node_ids();
  cf_info->frame_names.resize(num_nodes);
  std::vector<Node*> parent_nodes;
  parent_nodes.resize(num_nodes);
  std::vector<bool> visited;
  visited.resize(num_nodes);

  string frame_name;
  std::deque<Node*> ready;

  // Initialize with the root nodes.
  for (Node* n : g->nodes()) {
    if (n->in_edges().empty()) {
      visited[n->id()] = true;
      cf_info->unique_frame_names.insert(frame_name);
      ready.push_back(n);
    }
  }

  while (!ready.empty()) {
    Node* curr_node = ready.front();
    int curr_id = curr_node->id();
    ready.pop_front();

    Node* parent = nullptr;
    if (IsEnter(curr_node)) {
      // Enter a child frame.
      TF_RETURN_IF_ERROR(
          GetNodeAttr(curr_node->attrs(), "frame_name", &frame_name));
      parent = curr_node;
    } else if (IsExit(curr_node)) {
      // Exit to the parent frame.
      parent = parent_nodes[curr_id];
      if (!parent) {
        return errors::InvalidArgument(
            "Invalid Exit op: Cannot find a corresponding Enter op.");
      }
      frame_name = cf_info->frame_names[parent->id()];
      parent = parent_nodes[parent->id()];
    } else {
      parent = parent_nodes[curr_id];
      frame_name = cf_info->frame_names[curr_id];
    }

    for (const Edge* out_edge : curr_node->out_edges()) {
      Node* out = out_edge->dst();
      if (IsSink(out)) continue;
      const int out_id = out->id();

      // Add to ready queue if not visited.
      bool is_visited = visited[out_id];
      if (!is_visited) {
        ready.push_back(out);
        visited[out_id] = true;

        // Process the node 'out'.
        cf_info->frame_names[out_id] = frame_name;
        parent_nodes[out_id] = parent;
        cf_info->unique_frame_names.insert(frame_name);
      }
    }
  }

  return OkStatus();
}

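// Allocates the PendingCounts structure for every frame and sets each node's
// initial pending count. When the graph requires no control flow, a flat
// array of atomic pending counts is also initialized so that pending-count
// updates can bypass the per-frame structures.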
void ImmutableExecutorState::InitializePending(const Graph* graph,
                                               const ControlFlowInfo& cf_info) {
  for (auto& it : cf_info.unique_frame_names) {
    FrameInfo* finfo = EnsureFrameInfo(it);
    DCHECK_EQ(finfo->pending_counts.get(), nullptr);
    finfo->pending_counts =
        std::make_unique<PendingCounts>(finfo->pending_counts_layout);
  }

  if (!requires_control_flow_) {
    atomic_pending_counts_.reset(new std::atomic<int32>[gview_.num_nodes()]);
    std::fill(atomic_pending_counts_.get(),
              atomic_pending_counts_.get() + gview_.num_nodes(), 0);
  }

  for (const Node* n : graph->nodes()) {
    if (IsSink(n)) continue;
    const int id = n->id();
    const string& name = cf_info.frame_names[id];
    size_t max_pending, max_dead;
    GetMaxPendingCounts(n, &max_pending, &max_dead);
    auto& counts = EnsureFrameInfo(name)->pending_counts;
    counts->set_initial_count(pending_ids_[id], max_pending);
    if (!requires_control_flow_) {
      atomic_pending_counts_[id] = max_pending;
    }
  }
}

}  // namespace tensorflow