1 | /* Copyright 2018 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | |
#include "tensorflow/core/common_runtime/lower_while_op.h"

#include <limits>

#include "tensorflow/core/common_runtime/inline_function_utils.h"
#include "tensorflow/core/framework/node_def_builder.h"
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/graph/graph.h"
#include "tensorflow/core/graph/node_builder.h"
23 | |
24 | namespace tensorflow { |
25 | |
26 | namespace { |
27 | |
28 | using NodeOut = NodeBuilder::NodeOut; |
29 | |
30 | constexpr const char* const kLowerAsMultiDeviceFunctionAttr = |
31 | LowerFunctionalOpsConstants::kLowerAsMultiDeviceFunctionAttr; |
32 | |
33 | // Helper to convert a functional While op to its lowered form. |
34 | // |
35 | // Example: |
36 | // |
37 | // Input graph: |
38 | // |
39 | // loop_var -> WhileOp<cond_func, body_func> -> consumer |
40 | // |
41 | // Output graph(top to down flow): |
42 | // |
43 | // loop_var |
44 | // | |
45 | // Enter |
46 | // | |
47 | // cond_func ---<--- Merge ---<--- NextIteration |
48 | // | | | |
49 | // V V ^ |
50 | // | | | |
51 | // LoopCond --->--- Switch --->--- body_func |
52 | // | |
53 | // Exit |
54 | // | |
55 | // consumer |
56 | // |
57 | // DT_RESOURCE tensors are handled specially: |
58 | // |
59 | // resource_loop_var -> Enter[is_constant=True] -> cond_func and body_func |
60 | // | |
61 | // V |
62 | // consumer |
63 | class LowerWhileHelper { |
64 | public: |
65 | static Status Run(Node* while_op, const NameAttrList& cond_fn, |
66 | const NameAttrList& body_fn, int parallel_iterations, |
67 | Graph* graph, const FunctionLibraryDefinition* flib_def, |
68 | bool keep_node_fetchable) { |
69 | LowerWhileHelper helper(while_op, cond_fn, body_fn, parallel_iterations, |
70 | graph, flib_def, keep_node_fetchable); |
71 | return helper.RunInternal(); |
72 | } |
73 | |
74 | private: |
75 | // Create a LowerWhileHelper to create the lowering of While op that has cond |
76 | // and body functions named `cond_fn_name` and `body_fn_name` respectively in |
77 | // the given graph. |
78 | LowerWhileHelper(Node* while_op, const NameAttrList& cond_fn, |
79 | const NameAttrList& body_fn, int parallel_iterations, |
80 | Graph* graph, const FunctionLibraryDefinition* flib_def, |
81 | bool keep_node_fetchable); |
82 | |
83 | Status RunInternal(); |
84 | |
85 | void InitializeInputOutputToLoweredNodeMap(); |
86 | |
87 | // Creates an Enter node for each `while_op_` input and adds them to |
88 | // `enter_nodes_`. If the `while_op_` has an incoming control edge from a |
89 | // `src` node we add a control edge from `src` to each Enter node. |
90 | Status CreateEnterNodes(); |
91 | |
92 | // Creates a Merge node for each Enter node and adds to `merge_nodes_`. |
93 | // Initially now both inputs of a Merge node are the Enter node. Input at |
94 | // index 1 is later updated to the output of NextIteration node in |
95 | // `UpdateMergeNodes`. |
96 | Status CreateMergeNodes(); |
97 | |
98 | // Creates the call node for cond func and stores in `cond_call_node_`. |
99 | Status CreateCondFuncCallNode(); |
100 | |
101 | // Creates a Switch node for each loop var and adds to `switch_nodes_`. |
102 | // Output at index 1(true) of a Switch node is fed into the loop body. |
103 | // Output at index 0(false) of a Switch node is fed into the Exit nodes. |
104 | Status CreateSwitchNodes(); |
105 | |
106 | // Creates the call node for body func and stores in `body_call_node_`. |
107 | Status CreateBodyFuncCallNode(); |
108 | |
109 | // Creates an Exit node for each loop var and adds to `exit_nodes_`. These |
110 | // are fed into the consumers of the `while_op_`. |
111 | Status CreateExitNodes(); |
112 | |
113 | // Creates an NextIteration node for each loop var and adds to |
114 | // `next_iteration_nodes_`. |
115 | Status CreateNextIterationNodes(); |
116 | |
117 | // Updates input at index 1 of each merge node created in `CreateMergeNodes` |
118 | // to use the output of NextIteration node created in |
119 | // `CreateNextIterationNodes` instead. |
120 | Status UpdateMergeNodes(); |
121 | |
122 | // Updates consumers of the original `while_op_` to instead use the outputs |
123 | // from the exit nodes in `exit_nodes_`. Also updates any outgoing control |
124 | // edges to depend on `lowered_while_executed_` instead. |
125 | Status UpdateConsumers(); |
126 | |
127 | // Returns unique name containing the name of the While op being rewritten |
128 | // (name_), infix and a suffix to ensure it is unique within the graph. |
129 | string NewName(const string& infix); |
130 | |
131 | // Returns whether the While op's input/output at `index` is a `DT_RESOURCE`. |
132 | bool IsResource(int index); |
133 | |
134 | // The original While op. |
135 | Node* while_op_; |
136 | // The call node for the cond branch. |
137 | Node* cond_call_node_; |
138 | // The LoopCond node specifying the loop termination condition. |
139 | Node* loop_cond_node_; |
140 | // The call node for the body branch. |
141 | Node* body_call_node_; |
142 | // The node with the same name as the original While op: |
143 | // (a) IdentityN node with same outputs if 'keep_node_fetchable_ == true'. |
144 | // (b) NoOp node with control edge from 'lowered_while_executed_' otherwise. |
145 | Node* lowered_while_output_; |
146 | // The NoOp node with control edges from all Exit nodes. This node will be |
147 | // used as a source of outgoing control edges from lowered While node. |
148 | Node* lowered_while_executed_; |
149 | Graph* graph_; |
150 | const FunctionLibraryDefinition* flib_def_; |
151 | // Name of the `while_op_`. |
152 | string name_; |
153 | // Max number of parallel_iterations for the while loop. |
154 | const int parallel_iterations_; |
155 | bool keep_node_fetchable_; |
156 | |
157 | NodeDebugInfo debug_info_; |
158 | NodeBuilder cond_call_builder_; |
159 | NodeBuilder body_call_builder_; |
160 | |
161 | // `Enter` nodes, one per loop input/output. |
162 | // Note: `Enter` nodes with type `DT_RESOURCE` have attr `is_constant=True`. |
163 | std::vector<Node*> enter_nodes_; |
164 | |
165 | // Merge/Switch/NextIteration/Exit nodes, one per non-resource loop |
166 | // input/output. |
167 | std::vector<Node*> merge_nodes_; |
168 | std::vector<Node*> switch_nodes_; |
169 | std::vector<Node*> exit_nodes_; |
170 | std::vector<Node*> next_iterations_nodes_; |
171 | // Maps from the loop input/output indices to their corresponding |
172 | // Merge/Switch/NextIteration/Exit node indices. For inputs/outputs of |
173 | // `DT_RESOURCE` type there are no Merge/Switch/NextIteration/Exit nodes |
174 | // in which case the mapping contains -1. |
175 | std::vector<int> op_input_output_to_lowered_node_; |
176 | |
177 | size_t num_loop_inputs_; |
178 | }; |
179 | |
180 | LowerWhileHelper::LowerWhileHelper(Node* while_op, const NameAttrList& cond_fn, |
181 | const NameAttrList& body_fn, |
182 | int parallel_iterations, Graph* graph, |
183 | const FunctionLibraryDefinition* flib_def, |
184 | bool keep_node_fetchable) |
185 | : while_op_(while_op), |
186 | graph_(graph), |
187 | flib_def_(flib_def), |
188 | name_(while_op->name()), |
189 | parallel_iterations_(parallel_iterations), |
190 | keep_node_fetchable_(keep_node_fetchable), |
191 | debug_info_(*while_op_), |
192 | cond_call_builder_(NewName("cond" ), cond_fn.name(), flib_def, |
193 | &debug_info_), |
194 | body_call_builder_(NewName("body" ), body_fn.name(), flib_def, |
195 | &debug_info_), |
196 | num_loop_inputs_(while_op_->num_inputs()) { |
197 | cond_call_builder_.Attr(kLowerAsMultiDeviceFunctionAttr, true); |
198 | for (const auto& i : cond_fn.attr()) { |
199 | cond_call_builder_.Attr(i.first, i.second); |
200 | } |
201 | body_call_builder_.Attr(kLowerAsMultiDeviceFunctionAttr, true); |
202 | for (const auto& i : body_fn.attr()) { |
203 | body_call_builder_.Attr(i.first, i.second); |
204 | } |
205 | // We intentionally `resize` instead of `reserve` space in `enter_nodes_` |
206 | // because we need to set it's elements out of order in `CreateEnterNodes`. |
207 | enter_nodes_.resize(num_loop_inputs_); |
208 | merge_nodes_.reserve(num_loop_inputs_); |
209 | switch_nodes_.reserve(num_loop_inputs_); |
210 | exit_nodes_.reserve(num_loop_inputs_); |
211 | next_iterations_nodes_.reserve(num_loop_inputs_); |
212 | op_input_output_to_lowered_node_.resize(num_loop_inputs_, -1); |
213 | } |
214 | |
215 | Status LowerWhileHelper::RunInternal() { |
216 | InitializeInputOutputToLoweredNodeMap(); |
217 | TF_RETURN_IF_ERROR(CreateEnterNodes()); |
218 | TF_RETURN_IF_ERROR(CreateMergeNodes()); |
219 | TF_RETURN_IF_ERROR(CreateCondFuncCallNode()); |
220 | TF_RETURN_IF_ERROR(CreateSwitchNodes()); |
221 | TF_RETURN_IF_ERROR(CreateBodyFuncCallNode()); |
222 | TF_RETURN_IF_ERROR(CreateExitNodes()); |
223 | TF_RETURN_IF_ERROR(CreateNextIterationNodes()); |
224 | TF_RETURN_IF_ERROR(UpdateMergeNodes()); |
225 | TF_RETURN_IF_ERROR(UpdateConsumers()); |
226 | return OkStatus(); |
227 | } |
228 | |
229 | void LowerWhileHelper::InitializeInputOutputToLoweredNodeMap() { |
230 | int counter = 0; |
231 | for (int i = 0; i < num_loop_inputs_; i++) { |
232 | if (!IsResource(i)) { |
233 | op_input_output_to_lowered_node_[i] = counter++; |
234 | } |
235 | } |
236 | } |
237 | |
238 | Status LowerWhileHelper::CreateEnterNodes() { |
239 | // Note: `Node::input_edge` runs in O(num_inputs) so we use |
240 | // `Node::input_edges` instead so that below loop runs in O(num_inputs) time |
241 | // and not O(num_inputs^2). |
242 | std::vector<const Edge*> edges; |
243 | TF_RETURN_IF_ERROR(while_op_->input_edges(&edges)); |
244 | for (const Edge* edge : edges) { |
245 | Node* enter_node; |
246 | NodeBuilder builder = |
247 | NodeBuilder(NewName("enter" ), "Enter" , flib_def_, &debug_info_) |
248 | .Input(NodeOut(edge->src(), edge->src_output())) |
249 | .Attr("frame_name" , name_) |
250 | .Attr("parallel_iterations" , parallel_iterations_) |
251 | .Device(edge->src()->requested_device()) |
252 | .AssignedDevice(edge->src()->assigned_device_name()); |
253 | if (IsResource(edge->dst_input())) { |
254 | builder.Attr("is_constant" , true); |
255 | } |
256 | TF_RETURN_IF_ERROR(builder.Finalize(graph_, &enter_node)); |
257 | enter_nodes_[edge->dst_input()] = enter_node; |
258 | } |
259 | // Create a NoOp node that takes incoming control inputs of the original While |
260 | // op as control inputs and use it as a control input for all Enter nodes. |
261 | std::vector<Node*> control_inputs; |
262 | for (const Edge* e : while_op_->in_edges()) { |
263 | if (e->IsControlEdge()) { |
264 | control_inputs.push_back(e->src()); |
265 | } |
266 | } |
267 | if (!control_inputs.empty()) { |
268 | Node* incoming_control_node; |
269 | TF_RETURN_IF_ERROR(NodeBuilder(NewName("LoopControlInputs" ), "NoOp" , |
270 | flib_def_, &debug_info_) |
271 | .ControlInputs(control_inputs) |
272 | .Device(while_op_->requested_device()) |
273 | .Finalize(graph_, &incoming_control_node)); |
274 | for (Node* n : enter_nodes_) { |
275 | graph_->AddControlEdge(incoming_control_node, n); |
276 | } |
277 | } |
278 | return OkStatus(); |
279 | } |
280 | |
281 | Status LowerWhileHelper::CreateMergeNodes() { |
282 | for (Node* enter_node : enter_nodes_) { |
283 | if (enter_node->output_type(0) == DT_RESOURCE) { |
284 | continue; |
285 | } |
286 | Node* merge_node; |
287 | TF_RETURN_IF_ERROR( |
288 | NodeBuilder(NewName("merge" ), "Merge" , flib_def_, &debug_info_) |
289 | .Input({NodeOut(enter_node, 0), NodeOut(enter_node, 0)}) |
290 | .Device(enter_node->requested_device()) |
291 | .AssignedDevice(enter_node->assigned_device_name()) |
292 | .Finalize(graph_, &merge_node)); |
293 | merge_nodes_.emplace_back(merge_node); |
294 | } |
295 | return OkStatus(); |
296 | } |
297 | |
298 | Status LowerWhileHelper::CreateCondFuncCallNode() { |
299 | for (int i = 0; i < num_loop_inputs_; i++) { |
300 | if (IsResource(i)) { |
301 | cond_call_builder_.Input(NodeOut(enter_nodes_[i], 0)); |
302 | } else { |
303 | cond_call_builder_.Input( |
304 | NodeOut(merge_nodes_[op_input_output_to_lowered_node_[i]], 0)); |
305 | } |
306 | } |
307 | cond_call_builder_.Device(while_op_->requested_device()); |
308 | TF_RETURN_IF_ERROR(cond_call_builder_.Finalize(graph_, &cond_call_node_)); |
309 | // Add a control edge to make sure the Const nodes in the cond function |
310 | // are in the same frame as the rest of the function, otherwise |
311 | // `BuildControlFlowInfo` throws an error. |
312 | graph_->AddControlEdge(merge_nodes_[0], cond_call_node_); |
313 | TF_RETURN_IF_ERROR( |
314 | NodeBuilder(NewName("LoopCond" ), "LoopCond" , flib_def_, &debug_info_) |
315 | .Input(NodeOut(cond_call_node_, 0)) |
316 | .Device(while_op_->requested_device()) |
317 | .Finalize(graph_, &loop_cond_node_)); |
318 | return OkStatus(); |
319 | } |
320 | |
321 | Status LowerWhileHelper::CreateSwitchNodes() { |
322 | for (int i = 0; i < num_loop_inputs_; i++) { |
323 | if (IsResource(i)) { |
324 | continue; |
325 | } |
326 | string op_name; |
327 | { |
328 | const Node* input_node; |
329 | TF_RETURN_IF_ERROR(while_op_->input_node(i, &input_node)); |
330 | op_name = strings::StrCat(input_node->name(), "_switch" ); |
331 | } |
332 | Node* merge_node = merge_nodes_[op_input_output_to_lowered_node_[i]]; |
333 | Node* switch_node; |
334 | string op_type = "Switch" ; |
335 | if (IsRefType(merge_node->output_type(0))) { |
336 | op_type = "RefSwitch" ; |
337 | } |
338 | TF_RETURN_IF_ERROR( |
339 | NodeBuilder(NewName(op_name), op_type, flib_def_, &debug_info_) |
340 | .Input(NodeOut(merge_node, 0)) |
341 | .Input(NodeOut(loop_cond_node_, 0)) |
342 | .Device(merge_node->requested_device()) |
343 | .AssignedDevice(merge_node->assigned_device_name()) |
344 | .Finalize(graph_, &switch_node)); |
345 | switch_nodes_.emplace_back(switch_node); |
346 | } |
347 | return OkStatus(); |
348 | } |
349 | |
350 | Status LowerWhileHelper::CreateBodyFuncCallNode() { |
351 | for (int i = 0; i < num_loop_inputs_; i++) { |
352 | if (IsResource(i)) { |
353 | body_call_builder_.Input(NodeOut(enter_nodes_[i], 0)); |
354 | } else { |
355 | body_call_builder_.Input( |
356 | NodeOut(switch_nodes_[op_input_output_to_lowered_node_[i]], 1)); |
357 | } |
358 | } |
359 | body_call_builder_.Device(while_op_->requested_device()); |
360 | TF_RETURN_IF_ERROR(body_call_builder_.Finalize(graph_, &body_call_node_)); |
361 | // Add a control edge to make sure the Const nodes in the body function |
362 | // are in the same frame as the rest of the function, otherwise |
363 | // `BuildControlFlowInfo` throws an error. |
364 | // TODO(srbs): The choice of input at index 0 seems arbitrary(is it?) however |
365 | // this is how tf.while_loop does it. Can this affect performance if the 0th |
366 | // node is not the first one to be ready? Can we speed that case up using some |
367 | // sort of multi-input Merge? |
368 | Node* body_control_node_; |
369 | string op_type = "Identity" ; |
370 | if (IsRefType(switch_nodes_[0]->output_type(1))) { |
371 | op_type = "RefIdentity" ; |
372 | } |
373 | TF_RETURN_IF_ERROR(NodeBuilder(NewName("loop_body_control" ), op_type, |
374 | flib_def_, &debug_info_) |
375 | .Input(NodeOut(switch_nodes_[0], 1)) |
376 | .Device(while_op_->requested_device()) |
377 | .Finalize(graph_, &body_control_node_)); |
378 | graph_->AddControlEdge(body_control_node_, body_call_node_); |
379 | return OkStatus(); |
380 | } |
381 | |
382 | Status LowerWhileHelper::CreateExitNodes() { |
383 | std::vector<NodeOut> outputs; |
384 | outputs.reserve(num_loop_inputs_); |
385 | for (int i = 0; i < num_loop_inputs_; i++) { |
386 | if (IsResource(i)) { |
387 | // Note(srbs): A resource output of this While should never be used but we |
388 | // need this for the IdentityN node below. |
389 | OutputTensor resource_tensor; |
390 | TF_RETURN_IF_ERROR(enter_nodes_[i]->input_tensor(0, &resource_tensor)); |
391 | outputs.emplace_back(resource_tensor); |
392 | } else { |
393 | Node* exit_node; |
394 | TF_RETURN_IF_ERROR( |
395 | NodeBuilder(NewName("exit" ), "Exit" , flib_def_, &debug_info_) |
396 | .Input(NodeOut(switch_nodes_[op_input_output_to_lowered_node_[i]], |
397 | 0)) |
398 | .Device(switch_nodes_[op_input_output_to_lowered_node_[i]] |
399 | ->requested_device()) |
400 | .AssignedDevice(switch_nodes_[op_input_output_to_lowered_node_[i]] |
401 | ->assigned_device_name()) |
402 | .Finalize(graph_, &exit_node)); |
403 | exit_nodes_.emplace_back(exit_node); |
404 | outputs.emplace_back(NodeOut(exit_node, 0)); |
405 | } |
406 | } |
407 | |
408 | // We split data and control outputs of lowered while op, because otherwise |
409 | // after lowering of multi-device loop body we might end up with DT_RESOURCE |
410 | // inputs from multiple devices coming into IdentityN. |
411 | |
412 | // Add a NoOp node that has control edges from all Exit nodes. This node is |
413 | // used for rewriting control edges with the original while op as src. |
414 | TF_RETURN_IF_ERROR(NodeBuilder(NewName("LoopExecuted" ), "NoOp" , |
415 | OpRegistry::Global(), &debug_info_) |
416 | .ControlInputs(exit_nodes_) |
417 | .Device(while_op_->requested_device()) |
418 | .Finalize(graph_, &lowered_while_executed_)); |
419 | |
420 | if (keep_node_fetchable_) { |
421 | // Add an IdentityN node that has the same outputs and same name as the |
422 | // original functional While op. This is used for fetching the output of the |
423 | // While node by name in calls to sess.run. |
424 | TF_RETURN_IF_ERROR( |
425 | NodeBuilder(name_, "IdentityN" , OpRegistry::Global(), &debug_info_) |
426 | .Input(outputs) |
427 | .Device(while_op_->requested_device()) |
428 | .Finalize(graph_, &lowered_while_output_)); |
429 | } else { |
430 | // Even if we don't plan to fetch tensors from the lowered While op, we must |
431 | // keep it a valid source of control edges, because it might be a part of |
432 | // function control output set. |
433 | TF_RETURN_IF_ERROR( |
434 | NodeBuilder(name_, "NoOp" , OpRegistry::Global(), &debug_info_) |
435 | .ControlInput(lowered_while_executed_) |
436 | .Device(while_op_->requested_device()) |
437 | .Finalize(graph_, &lowered_while_output_)); |
438 | } |
439 | |
440 | return OkStatus(); |
441 | } |
442 | |
443 | Status LowerWhileHelper::CreateNextIterationNodes() { |
444 | for (int i = 0; i < num_loop_inputs_; i++) { |
445 | Node* next_iteration; |
446 | if (IsResource(i)) { |
447 | continue; |
448 | } |
449 | Node* merge_node = merge_nodes_[op_input_output_to_lowered_node_[i]]; |
450 | TF_RETURN_IF_ERROR(NodeBuilder(NewName("next_iteration" ), "NextIteration" , |
451 | flib_def_, &debug_info_) |
452 | .Input(NodeOut(body_call_node_, i)) |
453 | .ControlInput(body_call_node_) |
454 | .Device(merge_node->requested_device()) |
455 | .AssignedDevice(merge_node->assigned_device_name()) |
456 | .Finalize(graph_, &next_iteration)); |
457 | next_iterations_nodes_.emplace_back(next_iteration); |
458 | } |
459 | return OkStatus(); |
460 | } |
461 | |
462 | Status LowerWhileHelper::UpdateMergeNodes() { |
463 | for (int i = 0; i < merge_nodes_.size(); i++) { |
464 | TF_RETURN_IF_ERROR( |
465 | graph_->UpdateEdge(next_iterations_nodes_[i], 0, merge_nodes_[i], 1)); |
466 | } |
467 | return OkStatus(); |
468 | } |
469 | |
470 | Status LowerWhileHelper::UpdateConsumers() { |
471 | for (const Edge* e : while_op_->out_edges()) { |
472 | if (e->IsControlEdge()) { |
473 | graph_->AddControlEdge(lowered_while_executed_, e->dst()); |
474 | } else { |
475 | if (IsResource(e->src_output())) { |
476 | OutputTensor resource; |
477 | TF_RETURN_IF_ERROR( |
478 | enter_nodes_[e->src_output()]->input_tensor(0, &resource)); |
479 | graph_->AddEdge(resource.node, resource.index, e->dst(), |
480 | e->dst_input()); |
481 | } else { |
482 | // Feed the outputs directly from the exit nodes so that downstream ops |
483 | // can start before all the outputs have been computed. |
484 | int exit_node_index = op_input_output_to_lowered_node_[e->src_output()]; |
485 | if (exit_node_index < 0) { |
486 | return errors::Internal( |
487 | "Expecting an Exit node for a Resource tensor." ); |
488 | } |
489 | graph_->AddEdge(exit_nodes_[exit_node_index], 0, e->dst(), |
490 | e->dst_input()); |
491 | } |
492 | } |
493 | } |
494 | return OkStatus(); |
495 | } |
496 | |
497 | string LowerWhileHelper::NewName(const string& infix) { |
498 | return graph_->NewName(strings::StrCat(name_, "/" , infix)); |
499 | } |
500 | |
501 | bool LowerWhileHelper::IsResource(int index) { |
502 | return while_op_->input_type(index) == DT_RESOURCE; |
503 | } |
504 | |
505 | } // namespace |
506 | |
507 | Status RewriteWhileNode(Node* n, Graph* g, |
508 | const FunctionLibraryDefinition* flib_def, |
509 | bool keep_node_fetchable) { |
510 | VLOG(2) << "Lower While node (keep_node_fetchable=" << keep_node_fetchable |
511 | << "): " << SummarizeNode(*n); |
512 | |
513 | const AttrValue* cond_attr = n->attrs().Find("cond" ); |
514 | if (cond_attr == nullptr) { |
515 | return errors::InvalidArgument("While cond function missing" ); |
516 | } |
517 | const AttrValue* body_attr = n->attrs().Find("body" ); |
518 | if (body_attr == nullptr) { |
519 | return errors::InvalidArgument("While body function missing" ); |
520 | } |
521 | const AttrValue* parallel_iterations_attr = |
522 | n->attrs().Find("parallel_iterations" ); |
523 | if (parallel_iterations_attr == nullptr) { |
524 | return errors::InvalidArgument("parallel_iterations attr missing" ); |
525 | } |
526 | if (parallel_iterations_attr->i() < 1) { |
527 | return errors::InvalidArgument("parallel_iterations must be > 0" ); |
528 | } |
529 | |
530 | TF_RETURN_IF_ERROR(LowerWhileHelper::Run( |
531 | n, cond_attr->func(), body_attr->func(), parallel_iterations_attr->i(), g, |
532 | flib_def, keep_node_fetchable)); |
533 | g->RemoveNode(n); |
534 | |
535 | return OkStatus(); |
536 | } |
537 | |
538 | } // namespace tensorflow |
539 | |