1/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3Licensed under the Apache License, Version 2.0 (the "License");
4you may not use this file except in compliance with the License.
5You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9Unless required by applicable law or agreed to in writing, software
10distributed under the License is distributed on an "AS IS" BASIS,
11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12See the License for the specific language governing permissions and
13limitations under the License.
14==============================================================================*/
15
16#include "tensorflow/cc/ops/while_loop.h"
17
18#include "tensorflow/cc/framework/scope_internal.h"
19#include "tensorflow/cc/ops/control_flow_ops_internal.h"
20#include "tensorflow/cc/ops/standard_ops.h"
21#include "tensorflow/core/common_runtime/shape_refiner.h"
22#include "tensorflow/core/graph/node_builder.h"
23
24namespace tensorflow {
25namespace ops {
26
27namespace {
28
29// Utility function for converting to internal C++ datatypes.
30OutputTensor ToOutputTensor(const Output& output) {
31 return OutputTensor(output.node(), output.index());
32}
33
34// Utility function for converting to internal C++ datatypes.
35std::vector<OutputTensor> ToOutputTensors(const std::vector<Output>& outputs) {
36 std::vector<OutputTensor> result(outputs.size());
37 for (int i = 0; i < outputs.size(); ++i) {
38 result[i] = ToOutputTensor(outputs[i]);
39 }
40 return result;
41}
42
43// Utility function for converting to internal C++ datatypes.
44std::vector<Node*> ToNodes(const std::vector<Output>& outputs) {
45 std::vector<Node*> result(outputs.size());
46 for (int i = 0; i < outputs.size(); ++i) {
47 result[i] = outputs[i].node();
48 }
49 return result;
50}
51
52// Manually generates the name of the `loop_var_idx`-th NextIteration node of a
53// loop being constructed with `scope`. This is used to define the backedge
54// before the NextIteration node is created.
55string NextIterationName(const Scope& scope, int loop_var_idx) {
56 string result;
57 const string& prefix = scope.impl()->name();
58 if (!prefix.empty()) strings::StrAppend(&result, prefix, "/");
59 strings::StrAppend(&result, "NextIteration");
60 if (loop_var_idx > 0) strings::StrAppend(&result, "_", loop_var_idx);
61 return result;
62}
63
64// Creates the `loop_var_idx`-th Merge node of a loop being constructed with
65// `scope`. `enter_output` is the `loop_var_idx`-th Enter node's output.
66Status CreateMerge(const Scope& scope, int loop_var_idx,
67 const Output& enter_output, Output* merge_output) {
68 // The merge nodes accept the while loop's back edges as an input (i.e. the
69 // not-yet-created next iteration nodes). Use the underlying NodeBuilder API
70 // directly to create the back edge.
71 NodeBuilder::NodeOut enter_input(enter_output.node(), enter_output.index());
72
73 const int next_output_index = 0;
74 DataType dtype = enter_output.node()->output_type(0);
75 NodeBuilder::NodeOut next_input(NextIterationName(scope, loop_var_idx),
76 next_output_index, dtype);
77
78 std::vector<NodeBuilder::NodeOut> input_list({enter_input, next_input});
79 const string unique_name = scope.GetUniqueNameForOp("Merge");
80 NodeBuilder builder = NodeBuilder(unique_name, "Merge").Input(input_list);
81 scope.UpdateBuilder(&builder);
82
83 Node* merge_node;
84 TF_RETURN_IF_ERROR(builder.Finalize(scope.graph(), &merge_node));
85 TF_RETURN_IF_ERROR(scope.DoShapeInference(merge_node));
86 *merge_output = Output(merge_node, 0);
87 return OkStatus();
88}
89
90// Creates the condition subgraph defined by `cond`.
91Status CreateCond(const Scope& scope, const CondGraphBuilderFn& cond,
92 const std::vector<Output>& inputs, Output* output) {
93 // The control dependency is for constants in the cond graph, and other ops
94 // that do not depend on the loop variables. This ensures that these ops are
95 // in the while loop frame (since they will indirectly depend on an Enter node
96 // defining the frame) and that they are executed once per loop iteration.
97 //
98 // TODO(skyewm): the control dep will be added to all nodes in the cond graph.
99 // This is at best unnecessary, and at worst may prevent different parts of
100 // different loop iterations from executing in parallel.
101 Scope cond_scope =
102 scope.NewSubScope("cond").WithControlDependencies(inputs[0]);
103 Output raw_cond_out;
104 TF_RETURN_IF_ERROR(cond(cond_scope, inputs, &raw_cond_out));
105
106 TF_RETURN_IF_ERROR(scope.graph()->IsValidOutputTensor(raw_cond_out.node(),
107 raw_cond_out.index()));
108 if (raw_cond_out.type() != DT_BOOL) {
109 return errors::InvalidArgument(
110 "BuildWhileLoop: 'cond' argument must return a boolean output, got ",
111 DataTypeString(raw_cond_out.type()));
112 }
113 // TODO(skyewm): check that raw_cond_out is scalar
114
115 *output = LoopCond(scope, raw_cond_out).output;
116 return OkStatus();
117}
118
119// Create the body subgraph defined by `body`. `outputs` must be non-null and
120// empty.
121Status CreateBody(const Scope& scope, const BodyGraphBuilderFn& body,
122 const std::vector<Output>& inputs,
123 std::vector<Output>* outputs) {
124 DCHECK(outputs != nullptr);
125 DCHECK(outputs->empty());
126
127 // The control dependency is analogous to that in CreateCond().
128 Scope body_scope =
129 scope.NewSubScope("body").WithControlDependencies(inputs[0]);
130 TF_RETURN_IF_ERROR(body(body_scope, inputs, outputs));
131
132 const size_t num_loop_vars = inputs.size();
133 if (outputs->size() != num_loop_vars) {
134 return errors::InvalidArgument(
135 "BuildWhileLoop: 'body' argument expected to return ", num_loop_vars,
136 " output(s), got ", outputs->size());
137 }
138 for (const Output& output : *outputs) {
139 TF_RETURN_IF_ERROR(
140 scope.graph()->IsValidOutputTensor(output.node(), output.index()));
141 // TODO(skyewm): check output types/shapes
142 }
143 return OkStatus();
144}
145
146} // namespace
147
// A while loop with a single loop variable looks like this:
//
// (output)
//     ^   +---------------+
//     |   | body subgraph +-------------+
//    Exit +---------------+             |
//     ^    ^                            |
//     |    |                            |
//     Switch<--------+                  v
//     ^              |            NextIteration
//     |       +------+--------+         |
//     +------>| cond subgraph |         |
//     |       +---------------+         |
//     Merge<----------------------------+
//     ^
//     |
//     Enter
//     ^
//     |
//  (input)
//
// If there are multiple loop variables, each of the control flow ops is
// duplicated for each loop variable.
// TODO(skyewm): link to public version of design doc
172Status BuildWhileLoop(const Scope& scope, const std::vector<Output>& inputs,
173 const CondGraphBuilderFn& cond,
174 const BodyGraphBuilderFn& body, const string& frame_name,
175 OutputList* outputs, bool create_while_ctx,
176 Output* cond_output) {
177 DCHECK(!inputs.empty());
178 DCHECK(outputs != nullptr);
179 DCHECK(outputs->empty());
180
181 TF_RETURN_IF_ERROR(scope.status());
182 const size_t num_loop_vars = inputs.size();
183
184 std::vector<Output> enter_outputs(num_loop_vars);
185 for (size_t i = 0; i < num_loop_vars; ++i) {
186 enter_outputs[i] = internal::Enter(scope, inputs[i], frame_name);
187 }
188 TF_RETURN_IF_ERROR(scope.status());
189
190 std::vector<Output> merge_outputs(num_loop_vars);
191 for (size_t i = 0; i < num_loop_vars; ++i) {
192 TF_RETURN_IF_ERROR(
193 CreateMerge(scope, i, enter_outputs[i], &merge_outputs[i]));
194 }
195
196 Output cond_out;
197 TF_RETURN_IF_ERROR(CreateCond(scope, cond, merge_outputs, &cond_out));
198 if (cond_output != nullptr) *cond_output = cond_out;
199
200 std::vector<Output> switch_trues(num_loop_vars);
201 std::vector<Output> switch_falses(num_loop_vars);
202 for (size_t i = 0; i < num_loop_vars; ++i) {
203 auto switch_i = Switch(scope, merge_outputs[i], cond_out);
204 switch_trues[i] = switch_i.output_true;
205 switch_falses[i] = switch_i.output_false;
206 }
207 TF_RETURN_IF_ERROR(scope.status());
208
209 std::vector<Output> body_outputs;
210 TF_RETURN_IF_ERROR(CreateBody(scope, body, switch_trues, &body_outputs));
211
212 std::vector<Output> next_outputs(num_loop_vars);
213 for (size_t i = 0; i < num_loop_vars; ++i) {
214 next_outputs[i] = NextIteration(scope, body_outputs[i]);
215 DCHECK_EQ(next_outputs[i].node()->name(), NextIterationName(scope, i));
216 }
217 TF_RETURN_IF_ERROR(scope.status());
218
219 // Create the backedges from the NextIteration nodes to the Merge nodes.
220 for (size_t i = 0; i < num_loop_vars; ++i) {
221 const int merge_backedge_output_index = 1;
222 scope.graph()->AddEdge(next_outputs[i].node(), next_outputs[i].index(),
223 merge_outputs[i].node(),
224 merge_backedge_output_index);
225 }
226
227 outputs->resize(num_loop_vars);
228 for (size_t i = 0; i < num_loop_vars; ++i) {
229 (*outputs)[i] = internal::Exit(scope, switch_falses[i]);
230 }
231 TF_RETURN_IF_ERROR(scope.status());
232
233 if (create_while_ctx) {
234 WhileContext* while_ctx;
235 TF_RETURN_IF_ERROR(scope.graph()->AddWhileContext(
236 frame_name, ToNodes(enter_outputs), ToNodes(*outputs),
237 ToOutputTensor(cond_out), ToOutputTensors(switch_trues),
238 ToOutputTensors(body_outputs), &while_ctx));
239
240 // Set while_ctx for all exit nodes. We currently don't require knowing the
241 // while_ctx for any other nodes.
242 for (size_t i = 0; i < num_loop_vars; ++i) {
243 (*outputs)[i].node()->set_while_ctx(while_ctx);
244 }
245 }
246 return OkStatus();
247}
248
249} // namespace ops
250} // namespace tensorflow
251