1 | /* Copyright 2017 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | |
16 | #include "tensorflow/cc/ops/while_loop.h" |
17 | |
18 | #include "tensorflow/cc/framework/scope_internal.h" |
19 | #include "tensorflow/cc/ops/control_flow_ops_internal.h" |
20 | #include "tensorflow/cc/ops/standard_ops.h" |
21 | #include "tensorflow/core/common_runtime/shape_refiner.h" |
22 | #include "tensorflow/core/graph/node_builder.h" |
23 | |
24 | namespace tensorflow { |
25 | namespace ops { |
26 | |
27 | namespace { |
28 | |
29 | // Utility function for converting to internal C++ datatypes. |
30 | OutputTensor ToOutputTensor(const Output& output) { |
31 | return OutputTensor(output.node(), output.index()); |
32 | } |
33 | |
34 | // Utility function for converting to internal C++ datatypes. |
35 | std::vector<OutputTensor> ToOutputTensors(const std::vector<Output>& outputs) { |
36 | std::vector<OutputTensor> result(outputs.size()); |
  for (size_t i = 0; i < outputs.size(); ++i) {
38 | result[i] = ToOutputTensor(outputs[i]); |
39 | } |
40 | return result; |
41 | } |
42 | |
43 | // Utility function for converting to internal C++ datatypes. |
44 | std::vector<Node*> ToNodes(const std::vector<Output>& outputs) { |
45 | std::vector<Node*> result(outputs.size()); |
  for (size_t i = 0; i < outputs.size(); ++i) {
47 | result[i] = outputs[i].node(); |
48 | } |
49 | return result; |
50 | } |
51 | |
52 | // Manually generates the name of the `loop_var_idx`-th NextIteration node of a |
53 | // loop being constructed with `scope`. This is used to define the backedge |
54 | // before the NextIteration node is created. |
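// For example, with a scope prefix of "while" this returns
// "while/NextIteration" for loop variable 0 and "while/NextIteration_2" for
// loop variable 2 (illustrative; the actual prefix depends on the caller's
// scope).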
55 | string NextIterationName(const Scope& scope, int loop_var_idx) { |
56 | string result; |
57 | const string& prefix = scope.impl()->name(); |
  if (!prefix.empty()) strings::StrAppend(&result, prefix, "/");
  strings::StrAppend(&result, "NextIteration");
  if (loop_var_idx > 0) strings::StrAppend(&result, "_", loop_var_idx);
61 | return result; |
62 | } |
63 | |
64 | // Creates the `loop_var_idx`-th Merge node of a loop being constructed with |
65 | // `scope`. `enter_output` is the `loop_var_idx`-th Enter node's output. |
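// Note that the Merge node is created with its second input referring to the
// NextIteration node by name only, since that node does not exist yet; the
// real graph edge is added later in BuildWhileLoop() via Graph::AddEdge().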
66 | Status CreateMerge(const Scope& scope, int loop_var_idx, |
67 | const Output& enter_output, Output* merge_output) { |
  // The merge nodes accept the while loop's backedges as inputs (i.e. the
  // not-yet-created NextIteration nodes). Use the underlying NodeBuilder API
  // directly to create the backedge.
71 | NodeBuilder::NodeOut enter_input(enter_output.node(), enter_output.index()); |
72 | |
73 | const int next_output_index = 0; |
74 | DataType dtype = enter_output.node()->output_type(0); |
75 | NodeBuilder::NodeOut next_input(NextIterationName(scope, loop_var_idx), |
76 | next_output_index, dtype); |
77 | |
78 | std::vector<NodeBuilder::NodeOut> input_list({enter_input, next_input}); |
  const string unique_name = scope.GetUniqueNameForOp("Merge");
  NodeBuilder builder = NodeBuilder(unique_name, "Merge").Input(input_list);
81 | scope.UpdateBuilder(&builder); |
82 | |
83 | Node* merge_node; |
84 | TF_RETURN_IF_ERROR(builder.Finalize(scope.graph(), &merge_node)); |
85 | TF_RETURN_IF_ERROR(scope.DoShapeInference(merge_node)); |
86 | *merge_output = Output(merge_node, 0); |
87 | return OkStatus(); |
88 | } |
89 | |
90 | // Creates the condition subgraph defined by `cond`. |
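// An illustrative CondGraphBuilderFn (a hypothetical sketch, not code in this
// file) that continues looping while the first loop variable is less than 10:
//
//   CondGraphBuilderFn cond = [](const Scope& s,
//                                const std::vector<Output>& inputs,
//                                Output* output) {
//     *output = Less(s, inputs[0], 10);
//     return s.status();
//   };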
91 | Status CreateCond(const Scope& scope, const CondGraphBuilderFn& cond, |
92 | const std::vector<Output>& inputs, Output* output) { |
93 | // The control dependency is for constants in the cond graph, and other ops |
94 | // that do not depend on the loop variables. This ensures that these ops are |
95 | // in the while loop frame (since they will indirectly depend on an Enter node |
96 | // defining the frame) and that they are executed once per loop iteration. |
97 | // |
98 | // TODO(skyewm): the control dep will be added to all nodes in the cond graph. |
99 | // This is at best unnecessary, and at worst may prevent different parts of |
100 | // different loop iterations from executing in parallel. |
101 | Scope cond_scope = |
102 | scope.NewSubScope("cond" ).WithControlDependencies(inputs[0]); |
103 | Output raw_cond_out; |
104 | TF_RETURN_IF_ERROR(cond(cond_scope, inputs, &raw_cond_out)); |
105 | |
106 | TF_RETURN_IF_ERROR(scope.graph()->IsValidOutputTensor(raw_cond_out.node(), |
107 | raw_cond_out.index())); |
108 | if (raw_cond_out.type() != DT_BOOL) { |
109 | return errors::InvalidArgument( |
110 | "BuildWhileLoop: 'cond' argument must return a boolean output, got " , |
111 | DataTypeString(raw_cond_out.type())); |
112 | } |
113 | // TODO(skyewm): check that raw_cond_out is scalar |
114 | |
115 | *output = LoopCond(scope, raw_cond_out).output; |
116 | return OkStatus(); |
117 | } |
118 | |
// Creates the body subgraph defined by `body`. `outputs` must be non-null and
120 | // empty. |
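// An illustrative BodyGraphBuilderFn (a hypothetical sketch, not code in this
// file) that increments a single loop variable:
//
//   BodyGraphBuilderFn body = [](const Scope& s,
//                                const std::vector<Output>& inputs,
//                                std::vector<Output>* outputs) {
//     outputs->push_back(Add(s, inputs[0], 1));
//     return s.status();
//   };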
121 | Status CreateBody(const Scope& scope, const BodyGraphBuilderFn& body, |
122 | const std::vector<Output>& inputs, |
123 | std::vector<Output>* outputs) { |
124 | DCHECK(outputs != nullptr); |
125 | DCHECK(outputs->empty()); |
126 | |
127 | // The control dependency is analogous to that in CreateCond(). |
128 | Scope body_scope = |
129 | scope.NewSubScope("body" ).WithControlDependencies(inputs[0]); |
130 | TF_RETURN_IF_ERROR(body(body_scope, inputs, outputs)); |
131 | |
132 | const size_t num_loop_vars = inputs.size(); |
133 | if (outputs->size() != num_loop_vars) { |
134 | return errors::InvalidArgument( |
135 | "BuildWhileLoop: 'body' argument expected to return " , num_loop_vars, |
136 | " output(s), got " , outputs->size()); |
137 | } |
138 | for (const Output& output : *outputs) { |
139 | TF_RETURN_IF_ERROR( |
140 | scope.graph()->IsValidOutputTensor(output.node(), output.index())); |
141 | // TODO(skyewm): check output types/shapes |
142 | } |
143 | return OkStatus(); |
144 | } |
145 | |
146 | } // namespace |
147 | |
148 | // A while loop with a single loop variable looks like this: |
149 | // |
150 | // (output) |
151 | // ^ +---------------+ |
152 | // | | body subgraph +-------------+ |
153 | // Exit +---------------+ | |
154 | // ^ ^ | |
155 | // | | | |
156 | // Switch<--------+ v |
157 | // ^ | NextIteration |
158 | // | +------+--------+ | |
159 | // +---->| cond subgraph | | |
160 | // | +---------------+ | |
161 | // Merge<---------------------------+ |
162 | // ^ |
163 | // | |
164 | // Enter |
165 | // ^ |
166 | // | |
167 | // (input) |
168 | // |
169 | // If there are multiple loop variables, each of the control flow ops is |
170 | // duplicated for each loop variable. |
171 | // TODO(skyewm): link to public version of design doc |
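//
// As a rough usage sketch (variable and frame names here are hypothetical),
// the following builds the equivalent of `i = 0; while (i < 10) i += 1;`:
//
//   Scope root = Scope::NewRootScope();
//   Output i = Const(root, 0);
//   std::vector<Output> outputs;
//   TF_CHECK_OK(BuildWhileLoop(
//       root, {i},
//       [](const Scope& s, const std::vector<Output>& in, Output* out) {
//         *out = Less(s, in[0], 10);
//         return s.status();
//       },
//       [](const Scope& s, const std::vector<Output>& in,
//          std::vector<Output>* out) {
//         out->push_back(Add(s, in[0], 1));
//         return s.status();
//       },
//       "while_loop_frame", &outputs));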
172 | Status BuildWhileLoop(const Scope& scope, const std::vector<Output>& inputs, |
173 | const CondGraphBuilderFn& cond, |
174 | const BodyGraphBuilderFn& body, const string& frame_name, |
175 | OutputList* outputs, bool create_while_ctx, |
176 | Output* cond_output) { |
177 | DCHECK(!inputs.empty()); |
178 | DCHECK(outputs != nullptr); |
179 | DCHECK(outputs->empty()); |
180 | |
181 | TF_RETURN_IF_ERROR(scope.status()); |
182 | const size_t num_loop_vars = inputs.size(); |
183 | |
184 | std::vector<Output> enter_outputs(num_loop_vars); |
185 | for (size_t i = 0; i < num_loop_vars; ++i) { |
186 | enter_outputs[i] = internal::Enter(scope, inputs[i], frame_name); |
187 | } |
188 | TF_RETURN_IF_ERROR(scope.status()); |
189 | |
190 | std::vector<Output> merge_outputs(num_loop_vars); |
191 | for (size_t i = 0; i < num_loop_vars; ++i) { |
192 | TF_RETURN_IF_ERROR( |
193 | CreateMerge(scope, i, enter_outputs[i], &merge_outputs[i])); |
194 | } |
195 | |
196 | Output cond_out; |
197 | TF_RETURN_IF_ERROR(CreateCond(scope, cond, merge_outputs, &cond_out)); |
198 | if (cond_output != nullptr) *cond_output = cond_out; |
199 | |
200 | std::vector<Output> switch_trues(num_loop_vars); |
201 | std::vector<Output> switch_falses(num_loop_vars); |
202 | for (size_t i = 0; i < num_loop_vars; ++i) { |
203 | auto switch_i = Switch(scope, merge_outputs[i], cond_out); |
204 | switch_trues[i] = switch_i.output_true; |
205 | switch_falses[i] = switch_i.output_false; |
206 | } |
207 | TF_RETURN_IF_ERROR(scope.status()); |
208 | |
209 | std::vector<Output> body_outputs; |
210 | TF_RETURN_IF_ERROR(CreateBody(scope, body, switch_trues, &body_outputs)); |
211 | |
212 | std::vector<Output> next_outputs(num_loop_vars); |
213 | for (size_t i = 0; i < num_loop_vars; ++i) { |
214 | next_outputs[i] = NextIteration(scope, body_outputs[i]); |
215 | DCHECK_EQ(next_outputs[i].node()->name(), NextIterationName(scope, i)); |
216 | } |
217 | TF_RETURN_IF_ERROR(scope.status()); |
218 | |
219 | // Create the backedges from the NextIteration nodes to the Merge nodes. |
220 | for (size_t i = 0; i < num_loop_vars; ++i) { |
    // The backedge feeds the Merge node's second input (input index 1); the
    // first input is the Enter output.
    const int merge_backedge_input_index = 1;
    scope.graph()->AddEdge(next_outputs[i].node(), next_outputs[i].index(),
                           merge_outputs[i].node(),
                           merge_backedge_input_index);
225 | } |
226 | |
227 | outputs->resize(num_loop_vars); |
228 | for (size_t i = 0; i < num_loop_vars; ++i) { |
229 | (*outputs)[i] = internal::Exit(scope, switch_falses[i]); |
230 | } |
231 | TF_RETURN_IF_ERROR(scope.status()); |
232 | |
233 | if (create_while_ctx) { |
234 | WhileContext* while_ctx; |
235 | TF_RETURN_IF_ERROR(scope.graph()->AddWhileContext( |
236 | frame_name, ToNodes(enter_outputs), ToNodes(*outputs), |
237 | ToOutputTensor(cond_out), ToOutputTensors(switch_trues), |
238 | ToOutputTensors(body_outputs), &while_ctx)); |
239 | |
240 | // Set while_ctx for all exit nodes. We currently don't require knowing the |
241 | // while_ctx for any other nodes. |
242 | for (size_t i = 0; i < num_loop_vars; ++i) { |
243 | (*outputs)[i].node()->set_while_ctx(while_ctx); |
244 | } |
245 | } |
246 | return OkStatus(); |
247 | } |
248 | |
249 | } // namespace ops |
250 | } // namespace tensorflow |
251 | |