1/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2
3Licensed under the Apache License, Version 2.0 (the "License");
4you may not use this file except in compliance with the License.
5You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9Unless required by applicable law or agreed to in writing, software
10distributed under the License is distributed on an "AS IS" BASIS,
11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12See the License for the specific language governing permissions and
13limitations under the License.
14==============================================================================*/
15
16#include "tensorflow/core/common_runtime/lower_while_op.h"
17
18#include "tensorflow/core/common_runtime/inline_function_utils.h"
19#include "tensorflow/core/framework/node_def_builder.h"
20#include "tensorflow/core/framework/types.pb.h"
21#include "tensorflow/core/graph/graph.h"
22#include "tensorflow/core/graph/node_builder.h"
23
24namespace tensorflow {
25
26namespace {
27
// Shorthand for a (node, output-index) pair used as a NodeBuilder input.
using NodeOut = NodeBuilder::NodeOut;

// Attr set on the generated cond/body call nodes so that later inlining
// treats them as multi-device function calls.
constexpr const char* const kLowerAsMultiDeviceFunctionAttr =
    LowerFunctionalOpsConstants::kLowerAsMultiDeviceFunctionAttr;
32
33// Helper to convert a functional While op to its lowered form.
34//
35// Example:
36//
37// Input graph:
38//
39// loop_var -> WhileOp<cond_func, body_func> -> consumer
40//
41// Output graph(top to down flow):
42//
43// loop_var
44// |
45// Enter
46// |
47// cond_func ---<--- Merge ---<--- NextIteration
48// | | |
49// V V ^
50// | | |
51// LoopCond --->--- Switch --->--- body_func
52// |
53// Exit
54// |
55// consumer
56//
57// DT_RESOURCE tensors are handled specially:
58//
59// resource_loop_var -> Enter[is_constant=True] -> cond_func and body_func
60// |
61// V
62// consumer
class LowerWhileHelper {
 public:
  // One-shot entry point: constructs a helper for `while_op` and performs the
  // full lowering into the graph. Returns the first error encountered.
  static Status Run(Node* while_op, const NameAttrList& cond_fn,
                    const NameAttrList& body_fn, int parallel_iterations,
                    Graph* graph, const FunctionLibraryDefinition* flib_def,
                    bool keep_node_fetchable) {
    LowerWhileHelper helper(while_op, cond_fn, body_fn, parallel_iterations,
                            graph, flib_def, keep_node_fetchable);
    return helper.RunInternal();
  }

 private:
  // Create a LowerWhileHelper to create the lowering of While op that has cond
  // and body functions named `cond_fn_name` and `body_fn_name` respectively in
  // the given graph.
  LowerWhileHelper(Node* while_op, const NameAttrList& cond_fn,
                   const NameAttrList& body_fn, int parallel_iterations,
                   Graph* graph, const FunctionLibraryDefinition* flib_def,
                   bool keep_node_fetchable);

  // Runs all the Create*/Update* stages below, in order.
  Status RunInternal();

  // Fills `op_input_output_to_lowered_node_`: dense indices for non-resource
  // loop vars, -1 for resources.
  void InitializeInputOutputToLoweredNodeMap();

  // Creates an Enter node for each `while_op_` input and adds them to
  // `enter_nodes_`. If the `while_op_` has an incoming control edge from a
  // `src` node we add a control edge from `src` to each Enter node.
  Status CreateEnterNodes();

  // Creates a Merge node for each Enter node and adds to `merge_nodes_`.
  // Initially both inputs of a Merge node are the Enter node. Input at
  // index 1 is later updated to the output of NextIteration node in
  // `UpdateMergeNodes`.
  Status CreateMergeNodes();

  // Creates the call node for cond func and stores in `cond_call_node_`.
  Status CreateCondFuncCallNode();

  // Creates a Switch node for each loop var and adds to `switch_nodes_`.
  // Output at index 1(true) of a Switch node is fed into the loop body.
  // Output at index 0(false) of a Switch node is fed into the Exit nodes.
  Status CreateSwitchNodes();

  // Creates the call node for body func and stores in `body_call_node_`.
  Status CreateBodyFuncCallNode();

  // Creates an Exit node for each loop var and adds to `exit_nodes_`. These
  // are fed into the consumers of the `while_op_`.
  Status CreateExitNodes();

  // Creates an NextIteration node for each loop var and adds to
  // `next_iteration_nodes_`.
  Status CreateNextIterationNodes();

  // Updates input at index 1 of each merge node created in `CreateMergeNodes`
  // to use the output of NextIteration node created in
  // `CreateNextIterationNodes` instead.
  Status UpdateMergeNodes();

  // Updates consumers of the original `while_op_` to instead use the outputs
  // from the exit nodes in `exit_nodes_`. Also updates any outgoing control
  // edges to depend on `lowered_while_executed_` instead.
  Status UpdateConsumers();

  // Returns unique name containing the name of the While op being rewritten
  // (name_), infix and a suffix to ensure it is unique within the graph.
  string NewName(const string& infix);

  // Returns whether the While op's input/output at `index` is a `DT_RESOURCE`.
  bool IsResource(int index);

  // The original While op.
  Node* while_op_;
  // The call node for the cond branch.
  Node* cond_call_node_;
  // The LoopCond node specifying the loop termination condition.
  Node* loop_cond_node_;
  // The call node for the body branch.
  Node* body_call_node_;
  // The node with the same name as the original While op:
  // (a) IdentityN node with same outputs if 'keep_node_fetchable_ == true'.
  // (b) NoOp node with control edge from 'lowered_while_executed_' otherwise.
  Node* lowered_while_output_;
  // The NoOp node with control edges from all Exit nodes. This node will be
  // used as a source of outgoing control edges from lowered While node.
  Node* lowered_while_executed_;
  Graph* graph_;
  const FunctionLibraryDefinition* flib_def_;
  // Name of the `while_op_`.
  string name_;
  // Max number of parallel_iterations for the while loop.
  const int parallel_iterations_;
  bool keep_node_fetchable_;

  NodeDebugInfo debug_info_;
  NodeBuilder cond_call_builder_;
  NodeBuilder body_call_builder_;

  // `Enter` nodes, one per loop input/output.
  // Note: `Enter` nodes with type `DT_RESOURCE` have attr `is_constant=True`.
  std::vector<Node*> enter_nodes_;

  // Merge/Switch/NextIteration/Exit nodes, one per non-resource loop
  // input/output.
  std::vector<Node*> merge_nodes_;
  std::vector<Node*> switch_nodes_;
  std::vector<Node*> exit_nodes_;
  std::vector<Node*> next_iterations_nodes_;
  // Maps from the loop input/output indices to their corresponding
  // Merge/Switch/NextIteration/Exit node indices. For inputs/outputs of
  // `DT_RESOURCE` type there are no Merge/Switch/NextIteration/Exit nodes
  // in which case the mapping contains -1.
  std::vector<int> op_input_output_to_lowered_node_;

  // Number of data inputs of the original While op (== number of loop vars).
  size_t num_loop_inputs_;
};
179
// NOTE: members are initialized in declaration order, not init-list order.
// `NewName` (used below for the call builders) reads `graph_` and `name_`,
// both of which are declared before the builders, so this is safe.
LowerWhileHelper::LowerWhileHelper(Node* while_op, const NameAttrList& cond_fn,
                                   const NameAttrList& body_fn,
                                   int parallel_iterations, Graph* graph,
                                   const FunctionLibraryDefinition* flib_def,
                                   bool keep_node_fetchable)
    : while_op_(while_op),
      graph_(graph),
      flib_def_(flib_def),
      name_(while_op->name()),
      parallel_iterations_(parallel_iterations),
      keep_node_fetchable_(keep_node_fetchable),
      debug_info_(*while_op_),
      cond_call_builder_(NewName("cond"), cond_fn.name(), flib_def,
                         &debug_info_),
      body_call_builder_(NewName("body"), body_fn.name(), flib_def,
                         &debug_info_),
      num_loop_inputs_(while_op_->num_inputs()) {
  // Mark both calls for multi-device lowering and forward the function attrs.
  cond_call_builder_.Attr(kLowerAsMultiDeviceFunctionAttr, true);
  for (const auto& i : cond_fn.attr()) {
    cond_call_builder_.Attr(i.first, i.second);
  }
  body_call_builder_.Attr(kLowerAsMultiDeviceFunctionAttr, true);
  for (const auto& i : body_fn.attr()) {
    body_call_builder_.Attr(i.first, i.second);
  }
  // We intentionally `resize` instead of `reserve` space in `enter_nodes_`
  // because we need to set its elements out of order in `CreateEnterNodes`.
  enter_nodes_.resize(num_loop_inputs_);
  merge_nodes_.reserve(num_loop_inputs_);
  switch_nodes_.reserve(num_loop_inputs_);
  exit_nodes_.reserve(num_loop_inputs_);
  next_iterations_nodes_.reserve(num_loop_inputs_);
  // -1 == "no lowered node" (resource loop vars); see
  // InitializeInputOutputToLoweredNodeMap.
  op_input_output_to_lowered_node_.resize(num_loop_inputs_, -1);
}
214
// Runs the lowering stages. The order is load-bearing: Switch nodes need the
// LoopCond produced with the cond call, NextIteration needs the body call,
// and the Merge back-edges can only be patched once NextIteration exists.
Status LowerWhileHelper::RunInternal() {
  InitializeInputOutputToLoweredNodeMap();
  TF_RETURN_IF_ERROR(CreateEnterNodes());
  TF_RETURN_IF_ERROR(CreateMergeNodes());
  TF_RETURN_IF_ERROR(CreateCondFuncCallNode());
  TF_RETURN_IF_ERROR(CreateSwitchNodes());
  TF_RETURN_IF_ERROR(CreateBodyFuncCallNode());
  TF_RETURN_IF_ERROR(CreateExitNodes());
  TF_RETURN_IF_ERROR(CreateNextIterationNodes());
  TF_RETURN_IF_ERROR(UpdateMergeNodes());
  TF_RETURN_IF_ERROR(UpdateConsumers());
  return OkStatus();
}
228
229void LowerWhileHelper::InitializeInputOutputToLoweredNodeMap() {
230 int counter = 0;
231 for (int i = 0; i < num_loop_inputs_; i++) {
232 if (!IsResource(i)) {
233 op_input_output_to_lowered_node_[i] = counter++;
234 }
235 }
236}
237
Status LowerWhileHelper::CreateEnterNodes() {
  // Note: `Node::input_edge` runs in O(num_inputs) so we use
  // `Node::input_edges` instead so that below loop runs in O(num_inputs) time
  // and not O(num_inputs^2).
  std::vector<const Edge*> edges;
  TF_RETURN_IF_ERROR(while_op_->input_edges(&edges));
  for (const Edge* edge : edges) {
    Node* enter_node;
    // Each Enter is placed on its producer's device so the loop-var data does
    // not hop devices just to enter the frame.
    NodeBuilder builder =
        NodeBuilder(NewName("enter"), "Enter", flib_def_, &debug_info_)
            .Input(NodeOut(edge->src(), edge->src_output()))
            .Attr("frame_name", name_)
            .Attr("parallel_iterations", parallel_iterations_)
            .Device(edge->src()->requested_device())
            .AssignedDevice(edge->src()->assigned_device_name());
    if (IsResource(edge->dst_input())) {
      // Resource loop vars are loop invariants; `is_constant=True` lets them
      // be read inside the frame without Merge/Switch plumbing.
      builder.Attr("is_constant", true);
    }
    TF_RETURN_IF_ERROR(builder.Finalize(graph_, &enter_node));
    // Indexed by input position, not edge order — this is why `enter_nodes_`
    // was resized (not reserved) up front.
    enter_nodes_[edge->dst_input()] = enter_node;
  }
  // Create a NoOp node that takes incoming control inputs of the original While
  // op as control inputs and use it as a control input for all Enter nodes.
  std::vector<Node*> control_inputs;
  for (const Edge* e : while_op_->in_edges()) {
    if (e->IsControlEdge()) {
      control_inputs.push_back(e->src());
    }
  }
  if (!control_inputs.empty()) {
    Node* incoming_control_node;
    TF_RETURN_IF_ERROR(NodeBuilder(NewName("LoopControlInputs"), "NoOp",
                                   flib_def_, &debug_info_)
                           .ControlInputs(control_inputs)
                           .Device(while_op_->requested_device())
                           .Finalize(graph_, &incoming_control_node));
    for (Node* n : enter_nodes_) {
      graph_->AddControlEdge(incoming_control_node, n);
    }
  }
  return OkStatus();
}
280
281Status LowerWhileHelper::CreateMergeNodes() {
282 for (Node* enter_node : enter_nodes_) {
283 if (enter_node->output_type(0) == DT_RESOURCE) {
284 continue;
285 }
286 Node* merge_node;
287 TF_RETURN_IF_ERROR(
288 NodeBuilder(NewName("merge"), "Merge", flib_def_, &debug_info_)
289 .Input({NodeOut(enter_node, 0), NodeOut(enter_node, 0)})
290 .Device(enter_node->requested_device())
291 .AssignedDevice(enter_node->assigned_device_name())
292 .Finalize(graph_, &merge_node));
293 merge_nodes_.emplace_back(merge_node);
294 }
295 return OkStatus();
296}
297
Status LowerWhileHelper::CreateCondFuncCallNode() {
  for (int i = 0; i < num_loop_inputs_; i++) {
    if (IsResource(i)) {
      // Resources are read straight from their (constant) Enter node.
      cond_call_builder_.Input(NodeOut(enter_nodes_[i], 0));
    } else {
      // Non-resource loop vars are read from the Merge node so the cond sees
      // the current iteration's value.
      cond_call_builder_.Input(
          NodeOut(merge_nodes_[op_input_output_to_lowered_node_[i]], 0));
    }
  }
  cond_call_builder_.Device(while_op_->requested_device());
  TF_RETURN_IF_ERROR(cond_call_builder_.Finalize(graph_, &cond_call_node_));
  // Add a control edge to make sure the Const nodes in the cond function
  // are in the same frame as the rest of the function, otherwise
  // `BuildControlFlowInfo` throws an error.
  // NOTE(review): indexing merge_nodes_[0] assumes at least one non-resource
  // loop var exists — confirm callers guarantee this.
  graph_->AddControlEdge(merge_nodes_[0], cond_call_node_);
  TF_RETURN_IF_ERROR(
      NodeBuilder(NewName("LoopCond"), "LoopCond", flib_def_, &debug_info_)
          .Input(NodeOut(cond_call_node_, 0))
          .Device(while_op_->requested_device())
          .Finalize(graph_, &loop_cond_node_));
  return OkStatus();
}
320
321Status LowerWhileHelper::CreateSwitchNodes() {
322 for (int i = 0; i < num_loop_inputs_; i++) {
323 if (IsResource(i)) {
324 continue;
325 }
326 string op_name;
327 {
328 const Node* input_node;
329 TF_RETURN_IF_ERROR(while_op_->input_node(i, &input_node));
330 op_name = strings::StrCat(input_node->name(), "_switch");
331 }
332 Node* merge_node = merge_nodes_[op_input_output_to_lowered_node_[i]];
333 Node* switch_node;
334 string op_type = "Switch";
335 if (IsRefType(merge_node->output_type(0))) {
336 op_type = "RefSwitch";
337 }
338 TF_RETURN_IF_ERROR(
339 NodeBuilder(NewName(op_name), op_type, flib_def_, &debug_info_)
340 .Input(NodeOut(merge_node, 0))
341 .Input(NodeOut(loop_cond_node_, 0))
342 .Device(merge_node->requested_device())
343 .AssignedDevice(merge_node->assigned_device_name())
344 .Finalize(graph_, &switch_node));
345 switch_nodes_.emplace_back(switch_node);
346 }
347 return OkStatus();
348}
349
350Status LowerWhileHelper::CreateBodyFuncCallNode() {
351 for (int i = 0; i < num_loop_inputs_; i++) {
352 if (IsResource(i)) {
353 body_call_builder_.Input(NodeOut(enter_nodes_[i], 0));
354 } else {
355 body_call_builder_.Input(
356 NodeOut(switch_nodes_[op_input_output_to_lowered_node_[i]], 1));
357 }
358 }
359 body_call_builder_.Device(while_op_->requested_device());
360 TF_RETURN_IF_ERROR(body_call_builder_.Finalize(graph_, &body_call_node_));
361 // Add a control edge to make sure the Const nodes in the body function
362 // are in the same frame as the rest of the function, otherwise
363 // `BuildControlFlowInfo` throws an error.
364 // TODO(srbs): The choice of input at index 0 seems arbitrary(is it?) however
365 // this is how tf.while_loop does it. Can this affect performance if the 0th
366 // node is not the first one to be ready? Can we speed that case up using some
367 // sort of multi-input Merge?
368 Node* body_control_node_;
369 string op_type = "Identity";
370 if (IsRefType(switch_nodes_[0]->output_type(1))) {
371 op_type = "RefIdentity";
372 }
373 TF_RETURN_IF_ERROR(NodeBuilder(NewName("loop_body_control"), op_type,
374 flib_def_, &debug_info_)
375 .Input(NodeOut(switch_nodes_[0], 1))
376 .Device(while_op_->requested_device())
377 .Finalize(graph_, &body_control_node_));
378 graph_->AddControlEdge(body_control_node_, body_call_node_);
379 return OkStatus();
380}
381
Status LowerWhileHelper::CreateExitNodes() {
  // Collects one NodeOut per loop var for the fetchable IdentityN below.
  std::vector<NodeOut> outputs;
  outputs.reserve(num_loop_inputs_);
  for (int i = 0; i < num_loop_inputs_; i++) {
    if (IsResource(i)) {
      // Note(srbs): A resource output of this While should never be used but we
      // need this for the IdentityN node below.
      // The tensor recorded here is the Enter node's *input*, i.e. the
      // original resource fed to the While op.
      OutputTensor resource_tensor;
      TF_RETURN_IF_ERROR(enter_nodes_[i]->input_tensor(0, &resource_tensor));
      outputs.emplace_back(resource_tensor);
    } else {
      Node* exit_node;
      // Exit consumes the "false" (index 0) branch of the Switch: the value
      // when the loop condition no longer holds.
      TF_RETURN_IF_ERROR(
          NodeBuilder(NewName("exit"), "Exit", flib_def_, &debug_info_)
              .Input(NodeOut(switch_nodes_[op_input_output_to_lowered_node_[i]],
                             0))
              .Device(switch_nodes_[op_input_output_to_lowered_node_[i]]
                          ->requested_device())
              .AssignedDevice(switch_nodes_[op_input_output_to_lowered_node_[i]]
                                  ->assigned_device_name())
              .Finalize(graph_, &exit_node));
      exit_nodes_.emplace_back(exit_node);
      outputs.emplace_back(NodeOut(exit_node, 0));
    }
  }

  // We split data and control outputs of lowered while op, because otherwise
  // after lowering of multi-device loop body we might end up with DT_RESOURCE
  // inputs from multiple devices coming into IdentityN.

  // Add a NoOp node that has control edges from all Exit nodes. This node is
  // used for rewriting control edges with the original while op as src.
  TF_RETURN_IF_ERROR(NodeBuilder(NewName("LoopExecuted"), "NoOp",
                                 OpRegistry::Global(), &debug_info_)
                         .ControlInputs(exit_nodes_)
                         .Device(while_op_->requested_device())
                         .Finalize(graph_, &lowered_while_executed_));

  if (keep_node_fetchable_) {
    // Add an IdentityN node that has the same outputs and same name as the
    // original functional While op. This is used for fetching the output of the
    // While node by name in calls to sess.run.
    TF_RETURN_IF_ERROR(
        NodeBuilder(name_, "IdentityN", OpRegistry::Global(), &debug_info_)
            .Input(outputs)
            .Device(while_op_->requested_device())
            .Finalize(graph_, &lowered_while_output_));
  } else {
    // Even if we don't plan to fetch tensors from the lowered While op, we must
    // keep it a valid source of control edges, because it might be a part of
    // function control output set.
    TF_RETURN_IF_ERROR(
        NodeBuilder(name_, "NoOp", OpRegistry::Global(), &debug_info_)
            .ControlInput(lowered_while_executed_)
            .Device(while_op_->requested_device())
            .Finalize(graph_, &lowered_while_output_));
  }

  return OkStatus();
}
442
443Status LowerWhileHelper::CreateNextIterationNodes() {
444 for (int i = 0; i < num_loop_inputs_; i++) {
445 Node* next_iteration;
446 if (IsResource(i)) {
447 continue;
448 }
449 Node* merge_node = merge_nodes_[op_input_output_to_lowered_node_[i]];
450 TF_RETURN_IF_ERROR(NodeBuilder(NewName("next_iteration"), "NextIteration",
451 flib_def_, &debug_info_)
452 .Input(NodeOut(body_call_node_, i))
453 .ControlInput(body_call_node_)
454 .Device(merge_node->requested_device())
455 .AssignedDevice(merge_node->assigned_device_name())
456 .Finalize(graph_, &next_iteration));
457 next_iterations_nodes_.emplace_back(next_iteration);
458 }
459 return OkStatus();
460}
461
462Status LowerWhileHelper::UpdateMergeNodes() {
463 for (int i = 0; i < merge_nodes_.size(); i++) {
464 TF_RETURN_IF_ERROR(
465 graph_->UpdateEdge(next_iterations_nodes_[i], 0, merge_nodes_[i], 1));
466 }
467 return OkStatus();
468}
469
470Status LowerWhileHelper::UpdateConsumers() {
471 for (const Edge* e : while_op_->out_edges()) {
472 if (e->IsControlEdge()) {
473 graph_->AddControlEdge(lowered_while_executed_, e->dst());
474 } else {
475 if (IsResource(e->src_output())) {
476 OutputTensor resource;
477 TF_RETURN_IF_ERROR(
478 enter_nodes_[e->src_output()]->input_tensor(0, &resource));
479 graph_->AddEdge(resource.node, resource.index, e->dst(),
480 e->dst_input());
481 } else {
482 // Feed the outputs directly from the exit nodes so that downstream ops
483 // can start before all the outputs have been computed.
484 int exit_node_index = op_input_output_to_lowered_node_[e->src_output()];
485 if (exit_node_index < 0) {
486 return errors::Internal(
487 "Expecting an Exit node for a Resource tensor.");
488 }
489 graph_->AddEdge(exit_nodes_[exit_node_index], 0, e->dst(),
490 e->dst_input());
491 }
492 }
493 }
494 return OkStatus();
495}
496
497string LowerWhileHelper::NewName(const string& infix) {
498 return graph_->NewName(strings::StrCat(name_, "/", infix));
499}
500
501bool LowerWhileHelper::IsResource(int index) {
502 return while_op_->input_type(index) == DT_RESOURCE;
503}
504
505} // namespace
506
507Status RewriteWhileNode(Node* n, Graph* g,
508 const FunctionLibraryDefinition* flib_def,
509 bool keep_node_fetchable) {
510 VLOG(2) << "Lower While node (keep_node_fetchable=" << keep_node_fetchable
511 << "): " << SummarizeNode(*n);
512
513 const AttrValue* cond_attr = n->attrs().Find("cond");
514 if (cond_attr == nullptr) {
515 return errors::InvalidArgument("While cond function missing");
516 }
517 const AttrValue* body_attr = n->attrs().Find("body");
518 if (body_attr == nullptr) {
519 return errors::InvalidArgument("While body function missing");
520 }
521 const AttrValue* parallel_iterations_attr =
522 n->attrs().Find("parallel_iterations");
523 if (parallel_iterations_attr == nullptr) {
524 return errors::InvalidArgument("parallel_iterations attr missing");
525 }
526 if (parallel_iterations_attr->i() < 1) {
527 return errors::InvalidArgument("parallel_iterations must be > 0");
528 }
529
530 TF_RETURN_IF_ERROR(LowerWhileHelper::Run(
531 n, cond_attr->func(), body_attr->func(), parallel_iterations_attr->i(), g,
532 flib_def, keep_node_fetchable));
533 g->RemoveNode(n);
534
535 return OkStatus();
536}
537
538} // namespace tensorflow
539