1 | /* Copyright 2015 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | |
16 | #include "tensorflow/core/common_runtime/immutable_executor_state.h" |
17 | |
18 | #include "absl/memory/memory.h" |
19 | #include "tensorflow/core/framework/function.h" |
20 | #include "tensorflow/core/framework/metrics.h" |
21 | #include "tensorflow/core/framework/node_def_util.h" |
22 | #include "tensorflow/core/graph/edgeset.h" |
23 | #include "tensorflow/core/graph/graph.h" |
24 | #include "tensorflow/core/graph/graph_node_util.h" |
25 | #include "tensorflow/core/platform/errors.h" |
26 | #include "tensorflow/core/platform/logging.h" |
27 | |
28 | namespace tensorflow { |
29 | |
30 | namespace { |
31 | bool IsInitializationOp(const Node* node) { |
32 | return node->op_def().allows_uninitialized_input(); |
33 | } |
34 | } // namespace |
35 | |
36 | ImmutableExecutorState::~ImmutableExecutorState() { |
37 | for (int32_t i = 0; i < gview_.num_nodes(); i++) { |
38 | NodeItem* item = gview_.node(i); |
39 | if (item != nullptr) { |
40 | params_.delete_kernel(item->kernel); |
41 | } |
42 | } |
43 | } |
44 | |
45 | namespace { |
46 | void GetMaxPendingCounts(const Node* n, size_t* max_pending, |
47 | size_t* max_dead_count) { |
48 | const size_t num_in_edges = n->in_edges().size(); |
49 | size_t initial_count; |
50 | if (IsMerge(n)) { |
51 | // merge waits all control inputs so we initialize the pending |
52 | // count to be the number of control edges. |
53 | int32_t num_control_edges = 0; |
54 | for (const Edge* edge : n->in_edges()) { |
55 | if (edge->IsControlEdge()) { |
56 | num_control_edges++; |
57 | } |
58 | } |
59 | // Use bit 0 to indicate if we are waiting for a ready live data input. |
60 | initial_count = 1 + (num_control_edges << 1); |
61 | } else { |
62 | initial_count = num_in_edges; |
63 | } |
64 | |
65 | *max_pending = initial_count; |
66 | *max_dead_count = num_in_edges; |
67 | } |
68 | } // namespace |
69 | |
70 | ImmutableExecutorState::FrameInfo* ImmutableExecutorState::EnsureFrameInfo( |
71 | const string& fname) { |
72 | auto iter = frame_info_.find(fname); |
73 | if (iter != frame_info_.end()) { |
74 | return iter->second.get(); |
75 | } else { |
76 | auto frame_info = std::make_unique<FrameInfo>(fname); |
77 | absl::string_view fname_view = frame_info->name; |
78 | auto emplace_result = |
79 | frame_info_.emplace(fname_view, std::move(frame_info)); |
80 | return emplace_result.first->second.get(); |
81 | } |
82 | } |
83 | |
// Builds the immutable, per-graph executor state:
//  * computes control-flow frame metadata for every node,
//  * creates one OpKernel per node and caches static node attributes,
//  * assigns each node an input-buffer offset within its frame,
//  * allocates pending-count handles used for dataflow scheduling.
// Returns the first error encountered; on error the state is unusable.
Status ImmutableExecutorState::Initialize(const Graph& graph) {
  TF_RETURN_IF_ERROR(gview_.Initialize(&graph));

  // Build the information about frames in this subgraph.
  ControlFlowInfo cf_info;
  TF_RETURN_IF_ERROR(BuildControlFlowInfo(&graph, &cf_info));

  for (auto& it : cf_info.unique_frame_names) {
    EnsureFrameInfo(it)->nodes =
        std::make_unique<std::vector<const NodeItem*>>();
  }
  // The root (outermost) frame is keyed by the empty string.
  root_frame_info_ = frame_info_[""].get();

  pending_ids_.resize(gview_.num_nodes());

  // Preprocess every node in the graph to create an instance of op
  // kernel for each node.
  requires_control_flow_ = false;
  for (const Node* n : graph.nodes()) {
    if (IsSink(n)) continue;
    if (IsSwitch(n) || IsMerge(n) || IsEnter(n) || IsExit(n)) {
      requires_control_flow_ = true;
    } else if (IsRecv(n)) {
      // A Recv node from a different device may produce dead tensors from
      // non-local control-flow nodes.
      //
      // TODO(mrry): Track whether control flow was present in the
      // pre-partitioned graph, and enable the caller (e.g.
      // `DirectSession`) to relax this constraint.
      string send_device;
      string recv_device;
      TF_RETURN_IF_ERROR(GetNodeAttr(n->attrs(), "send_device", &send_device));
      TF_RETURN_IF_ERROR(GetNodeAttr(n->attrs(), "recv_device", &recv_device));
      if (send_device != recv_device) {
        requires_control_flow_ = true;
      }
    }

    const int id = n->id();
    const string& frame_name = cf_info.frame_names[id];
    FrameInfo* frame_info = EnsureFrameInfo(frame_name);

    NodeItem* item = gview_.node(id);
    item->node_id = id;

    // Inputs for all nodes in a frame are stored contiguously; record this
    // node's offset into the frame's input buffer.
    item->input_start = frame_info->total_inputs;
    frame_info->total_inputs += n->num_inputs();

    Status s = params_.create_kernel(n->properties(), &item->kernel);
    if (!s.ok()) {
      // Clean up whatever create_kernel may have produced before failing.
      params_.delete_kernel(item->kernel);
      item->kernel = nullptr;
      s = AttachDef(s, *n);
      return s;
    }
    CHECK(item->kernel);
    item->kernel_is_async = (item->kernel->AsAsync() != nullptr);
    item->is_merge = IsMerge(n);
    item->is_any_consumer_merge_or_control_trigger = false;
    for (const Node* consumer : n->out_nodes()) {
      if (IsMerge(consumer) || IsControlTrigger(consumer)) {
        item->is_any_consumer_merge_or_control_trigger = true;
        break;
      }
    }
    const Tensor* const_tensor = item->kernel->const_tensor();
    if (const_tensor) {
      // Hold onto a shallow copy of the constant tensor in `*this` so that the
      // reference count does not drop to 1. This prevents the constant tensor
      // from being forwarded, and its buffer reused.
      const_tensors_.emplace_back(*const_tensor);
    }
    item->const_tensor = const_tensor;
    item->is_noop = (item->kernel->type_string_view() == "NoOp");
    item->is_enter = IsEnter(n);
    if (item->is_enter) {
      bool is_constant_enter;
      TF_RETURN_IF_ERROR(
          GetNodeAttr(n->attrs(), "is_constant", &is_constant_enter));
      item->is_constant_enter = is_constant_enter;

      // NOTE: this local `frame_name` (the frame this Enter node creates)
      // shadows the outer `frame_name` (the frame the node itself lives in).
      string frame_name;
      TF_RETURN_IF_ERROR(GetNodeAttr(n->attrs(), "frame_name", &frame_name));
      FrameInfo* frame_info = frame_info_[frame_name].get();

      int parallel_iterations;
      TF_RETURN_IF_ERROR(
          GetNodeAttr(n->attrs(), "parallel_iterations", &parallel_iterations));

      // All Enter nodes for the same frame should agree on
      // parallel_iterations; keep the first value seen and warn on mismatch.
      if (frame_info->parallel_iterations == -1) {
        frame_info->parallel_iterations = parallel_iterations;
      } else if (frame_info->parallel_iterations != parallel_iterations) {
        LOG(WARNING) << "Loop frame \"" << frame_name
                     << "\" had two different values for parallel_iterations: "
                     << frame_info->parallel_iterations << " vs. "
                     << parallel_iterations << ".";
      }

      // Map this Enter node's id to the frame it enters.
      if (enter_frame_info_.size() <= id) {
        enter_frame_info_.resize(id + 1);
      }
      enter_frame_info_[id] = frame_info;
    } else {
      item->is_constant_enter = false;
    }
    item->is_exit = IsExit(n);
    item->is_control_trigger = IsControlTrigger(n);
    item->is_source = IsSource(n);
    item->is_enter_exit_or_next_iter =
        (IsEnter(n) || IsExit(n) || IsNextIteration(n));
    item->is_transfer_node = IsTransferNode(n);
    item->is_initialization_op = IsInitializationOp(n);
    item->is_recv_or_switch = IsRecv(n) || IsSwitch(n);
    item->is_next_iteration = IsNextIteration(n);
    item->is_distributed_communication = IsDistributedCommunication(n);

    // Compute the maximum values we'll store for this node in the
    // pending counts data structure, and allocate a handle in
    // that frame's pending counts data structure that has enough
    // space to store these maximal count values.
    size_t max_pending, max_dead;
    GetMaxPendingCounts(n, &max_pending, &max_dead);
    pending_ids_[id] =
        frame_info->pending_counts_layout.CreateHandle(max_pending, max_dead);

    // See if this node is a root node, and if so, add item to root_nodes_.
    if (n->in_edges().empty()) {
      root_nodes_.push_back(item);
    }

    // Initialize static information about the frames in the graph.
    frame_info->nodes->push_back(item);
    if (item->is_enter) {
      string enter_name;
      TF_RETURN_IF_ERROR(GetNodeAttr(n->attrs(), "frame_name", &enter_name));
      EnsureFrameInfo(enter_name)->input_count++;
    }

    // Record information about whether each output of the op is used.
    std::unique_ptr<bool[]> outputs_required(new bool[n->num_outputs()]);
    std::fill(&outputs_required[0], &outputs_required[n->num_outputs()], false);
    int32_t unused_outputs = n->num_outputs();
    for (const Edge* e : n->out_edges()) {
      if (IsSink(e->dst())) continue;
      if (e->src_output() >= 0) {
        if (!outputs_required[e->src_output()]) {
          --unused_outputs;
          outputs_required[e->src_output()] = true;
        }
      }
    }
    if (unused_outputs > 0) {
      for (int i = 0; i < n->num_outputs(); ++i) {
        if (!outputs_required[i]) {
          metrics::RecordUnusedOutput(n->type_string());
        }
      }
      // Only attach the bitmap when at least one output is unused; a null
      // `outputs_required` means "all outputs are required".
      item->outputs_required = std::move(outputs_required);
    }
  }

  // Rewrite each `EdgeInfo::input_slot` member to refer directly to the input
  // location.
  for (const Node* n : graph.nodes()) {
    if (IsSink(n)) continue;
    const int id = n->id();
    NodeItem* item = gview_.node(id);

    for (EdgeInfo& e : item->mutable_output_edges()) {
      const int dst_id = e.dst_id;
      NodeItem* dst_item = gview_.node(dst_id);
      e.input_slot += dst_item->input_start;
    }
  }

  // Initialize PendingCounts only after pending_ids_[node.id] is initialized
  // for all nodes.
  InitializePending(&graph, cf_info);
  return gview_.SetAllocAttrs(&graph, params_.device);
}
264 | |
265 | namespace { |
266 | // If a Node has been marked to use a ScopedAllocator x for output i, then |
267 | // sc_attr will contain the subsequence (i, x) at an even offset. This function |
268 | // extracts and transfers that ScopedAllocator id to alloc_attr. For now, we |
269 | // only allow one ScopedAllocator use per Node. |
270 | bool (const std::vector<int>& sc_attr, |
271 | int output_index, |
272 | AllocatorAttributes* alloc_attr) { |
273 | DCHECK_LE(2, sc_attr.size()); |
274 | for (int i = 0; i < sc_attr.size(); i += 2) { |
275 | if (sc_attr[i] == output_index) { |
276 | CHECK_EQ(alloc_attr->scope_id, 0); |
277 | alloc_attr->scope_id = sc_attr[i + 1]; |
278 | return true; |
279 | } |
280 | } |
281 | return false; |
282 | } |
283 | } // namespace |
284 | |
// Performs a breadth-first traversal from the graph's root nodes to label
// every node with the name of the control-flow frame it executes in, and to
// record each node's parent frame. Enter nodes open a child frame (named by
// their "frame_name" attr); Exit nodes return to the parent frame; all other
// nodes propagate their own frame to their successors.
// Returns InvalidArgument if an Exit node has no corresponding Enter.
Status ImmutableExecutorState::BuildControlFlowInfo(const Graph* g,
                                                    ControlFlowInfo* cf_info) {
  const int num_nodes = g->num_node_ids();
  cf_info->frame_names.resize(num_nodes);
  // parent_nodes[id] is the Enter node of the frame containing node `id`
  // (nullptr for nodes in the root frame).
  std::vector<Node*> parent_nodes;
  parent_nodes.resize(num_nodes);
  std::vector<bool> visited;
  visited.resize(num_nodes);

  string frame_name;
  std::deque<Node*> ready;

  // Initialize with the root nodes.
  for (Node* n : g->nodes()) {
    if (n->in_edges().empty()) {
      visited[n->id()] = true;
      // Root nodes belong to the root frame, whose name is the empty string.
      cf_info->unique_frame_names.insert(frame_name);
      ready.push_back(n);
    }
  }

  while (!ready.empty()) {
    Node* curr_node = ready.front();
    int curr_id = curr_node->id();
    ready.pop_front();

    // `frame_name`/`parent` describe the frame of this node's SUCCESSORS.
    Node* parent = nullptr;
    if (IsEnter(curr_node)) {
      // Enter a child frame.
      TF_RETURN_IF_ERROR(
          GetNodeAttr(curr_node->attrs(), "frame_name", &frame_name));
      parent = curr_node;
    } else if (IsExit(curr_node)) {
      // Exit to the parent frame.
      parent = parent_nodes[curr_id];
      if (!parent) {
        return errors::InvalidArgument(
            "Invalid Exit op: Cannot find a corresponding Enter op.");
      }
      // Successors of the Exit live in the frame of the Enter's own parent.
      frame_name = cf_info->frame_names[parent->id()];
      parent = parent_nodes[parent->id()];
    } else {
      parent = parent_nodes[curr_id];
      frame_name = cf_info->frame_names[curr_id];
    }

    for (const Edge* out_edge : curr_node->out_edges()) {
      Node* out = out_edge->dst();
      if (IsSink(out)) continue;
      const int out_id = out->id();

      // Add to ready queue if not visited.
      bool is_visited = visited[out_id];
      if (!is_visited) {
        ready.push_back(out);
        visited[out_id] = true;

        // Process the node 'out'.
        cf_info->frame_names[out_id] = frame_name;
        parent_nodes[out_id] = parent;
        cf_info->unique_frame_names.insert(frame_name);
      }
    }
  }

  return OkStatus();
}
352 | |
353 | void ImmutableExecutorState::InitializePending(const Graph* graph, |
354 | const ControlFlowInfo& cf_info) { |
355 | for (auto& it : cf_info.unique_frame_names) { |
356 | FrameInfo* finfo = EnsureFrameInfo(it); |
357 | DCHECK_EQ(finfo->pending_counts.get(), nullptr); |
358 | finfo->pending_counts = |
359 | std::make_unique<PendingCounts>(finfo->pending_counts_layout); |
360 | } |
361 | |
362 | if (!requires_control_flow_) { |
363 | atomic_pending_counts_.reset(new std::atomic<int32>[gview_.num_nodes()]); |
364 | std::fill(atomic_pending_counts_.get(), |
365 | atomic_pending_counts_.get() + gview_.num_nodes(), 0); |
366 | } |
367 | |
368 | for (const Node* n : graph->nodes()) { |
369 | if (IsSink(n)) continue; |
370 | const int id = n->id(); |
371 | const string& name = cf_info.frame_names[id]; |
372 | size_t max_pending, max_dead; |
373 | GetMaxPendingCounts(n, &max_pending, &max_dead); |
374 | auto& counts = EnsureFrameInfo(name)->pending_counts; |
375 | counts->set_initial_count(pending_ids_[id], max_pending); |
376 | if (!requires_control_flow_) { |
377 | atomic_pending_counts_[id] = max_pending; |
378 | } |
379 | } |
380 | } |
381 | } // namespace tensorflow |
382 | |