1 | /* Copyright 2021 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | |
16 | #include "tensorflow/core/common_runtime/control_flow_deps_to_chains.h" |
17 | |
#include <algorithm>
#include <cstdint>
#include <map>
#include <string>
#include <vector>
21 | |
22 | #include "tensorflow/core/framework/attr_value.pb.h" |
23 | #include "tensorflow/core/framework/node_def.pb.h" |
24 | #include "tensorflow/core/framework/node_def_util.h" |
25 | #include "tensorflow/core/framework/op_def_builder.h" |
26 | #include "tensorflow/core/framework/tensor.pb.h" |
27 | #include "tensorflow/core/platform/errors.h" |
28 | #include "tensorflow/core/platform/strcat.h" |
29 | #include "tensorflow/core/platform/types.h" |
30 | #include "tensorflow/core/util/dump_graph.h" |
31 | |
32 | namespace tensorflow { |
33 | |
34 | // TODO(mdan): Move this into Grappler - cleaner interface. |
35 | Status ControlFlowDepsToChainsPass::Run( |
36 | const GraphOptimizationPassOptions& options) { |
37 | VLOG(1) << "ControlFlowDepsToChainsPass::Run" ; |
38 | |
39 | if (options.graph == nullptr) { |
40 | VLOG(1) << "ControlFlowDepsToChainsPass::Run Aborted" ; |
41 | return OkStatus(); |
42 | } |
43 | |
44 | Graph* g = options.graph->get(); |
45 | DCHECK(g != nullptr); |
46 | FunctionLibraryDefinition* flib_def = options.flib_def; |
47 | DCHECK(flib_def != nullptr); |
48 | |
49 | if (VLOG_IS_ON(1)) { |
50 | DumpGraphToFile("control_flow_deps_to_chains_before" , *g, flib_def); |
51 | } |
52 | |
53 | for (Node* n : g->nodes()) { |
54 | if (n == nullptr) { |
55 | continue; |
56 | } |
57 | if (!n->IsWhileNode()) { |
58 | continue; |
59 | } |
60 | |
61 | // TODO(mdan): This breaks encapsulation of Node/Graph. Is there any needed? |
62 | // TODO(mdan): Consolidate this with AddWhileInputHack. |
63 | NodeDef* while_node = n->mutable_def(); |
64 | const auto& attrs = while_node->attr(); |
65 | auto* mattrs = while_node->mutable_attr(); |
66 | |
67 | string body_name = attrs.at("body" ).func().name(); |
68 | auto* body_graph = flib_def->Find(body_name); |
69 | DCHECK(body_graph != nullptr); |
70 | |
71 | // Look for required annotations. |
72 | |
73 | if (attrs.find("_stateful_parallelism" ) == attrs.end()) { |
74 | continue; |
75 | } |
76 | if (!attrs.at("_stateful_parallelism" ).b()) { |
77 | continue; |
78 | } |
79 | if (attrs.find("parallel_iterations" ) != attrs.end()) { |
80 | if (attrs.at("parallel_iterations" ).i() < 2) { |
81 | continue; // Loops which are already sequential are more efficient |
82 | // without chains. |
83 | } |
84 | } |
85 | // TODO(mdan): We don't really need this attribute. |
86 | if (attrs.find("_num_original_outputs" ) == attrs.end()) { |
87 | continue; |
88 | } |
89 | int body_barrier_loc = -1; |
90 | std::map<string, int> node_index; |
91 | for (int i = 0, s = body_graph->node_def_size(); i < s; i++) { |
92 | node_index.emplace(body_graph->node_def(i).name(), i); |
93 | if (body_barrier_loc < 0) { |
94 | const auto& node_attr = body_graph->node_def(i).attr(); |
95 | if (node_attr.find("_acd_function_control_output" ) != node_attr.end()) { |
96 | body_barrier_loc = i; |
97 | } |
98 | } |
99 | } |
100 | if (body_barrier_loc < 0) { |
101 | continue; |
102 | } |
103 | bool ok_for_lowering = true; |
104 | for (int i = 0; i < body_graph->control_ret_size(); i++) { |
105 | const auto& control_node = body_graph->node_def( |
106 | node_index[body_graph->signature().control_output(i)]); |
107 | const auto& control_attr = control_node.attr(); |
108 | if (control_attr.find("_res_first_used_by" ) == control_attr.end()) { |
109 | ok_for_lowering = false; |
110 | break; |
111 | } |
112 | } |
113 | if (!ok_for_lowering) { |
114 | continue; |
115 | } |
116 | |
117 | int num_loop_vars = body_graph->signature().input_arg_size(); |
118 | int num_new_chains = body_graph->control_ret_size(); |
119 | int num_node_inputs = while_node->input_size(); |
120 | |
121 | if (!num_new_chains) { |
122 | continue; // Nothing to do for stateless loops. |
123 | } |
124 | |
125 | // Add extra loop vars to the while node. |
126 | |
127 | // TODO(mdan): If the loop vars contains the resource, we should reuse it. |
128 | // Note that stateful ops of resource inputs cause their resources to be |
129 | // captured into the loop vars (through the body/cond captures). We could |
130 | // effectively use those as chains. |
131 | |
132 | // TODO(mdan): Is there a more efficient way to do this? |
133 | // Insert the new While node inputs: at the end of the loop vars, but before |
134 | // any non-loop var inputs (like control dependencies). Once the initial |
135 | // chain values are created below, they will be added to these inputs. |
136 | for (int i = 0; i < num_new_chains; i++) { |
137 | while_node->add_input(); |
138 | } |
139 | for (int i = num_node_inputs - 1; i >= num_loop_vars; i--) { |
140 | while_node->set_input(i + num_new_chains, while_node->input(i)); |
141 | } |
142 | |
143 | std::vector<Node*> new_inputs; |
144 | std::vector<int> new_input_locations; |
145 | // Set their name to a gensym, type to float and shape to scalar. |
146 | for (int i = 0; i < num_new_chains; i++) { |
147 | string c_name = g->NewName("acd__chain" ); |
148 | |
149 | // The initial value for the i'th chain loop var. |
150 | NodeDef new_in; |
151 | new_in.set_name(c_name); |
152 | new_in.set_op("Const" ); |
153 | AttrValue att_dtype; |
154 | att_dtype.set_type(DT_FLOAT); |
155 | new_in.mutable_attr()->insert({"dtype" , att_dtype}); |
156 | AttrValue att_value; |
157 | att_value.mutable_tensor()->set_dtype(DT_FLOAT); |
158 | att_value.mutable_tensor()->mutable_tensor_shape(); |
159 | att_value.mutable_tensor()->add_int_val(0); |
160 | new_in.mutable_attr()->insert({"value" , att_value}); |
161 | Status status; |
162 | new_inputs.push_back(g->AddNode(new_in, &status)); |
163 | TF_RETURN_WITH_CONTEXT_IF_ERROR(status, "while creating chain" , c_name); |
164 | |
165 | int loc = num_loop_vars + i; |
166 | new_input_locations.push_back(loc); |
167 | while_node->set_input(loc, c_name); |
168 | mattrs->at("T" ).mutable_list()->add_type(DT_FLOAT); |
169 | mattrs->at("output_shapes" ).mutable_list()->add_shape(); |
170 | } |
171 | |
172 | // TODO(mdan): This should not be necessary to update. Delete? |
173 | mattrs->at("_num_original_outputs" ).set_i(num_loop_vars + num_new_chains); |
174 | n->UpdateProperties(); |
175 | for (int i = 0; i < num_new_chains; i++) { |
176 | g->AddEdge(new_inputs[i], 0, n, new_input_locations[i]); |
177 | } |
178 | |
179 | // TODO(mdan): This is wasteful. Can we just mutate the original proto? |
180 | FunctionDef modified_body = *body_graph; |
181 | |
182 | // Disable the global end-of-body barrier from the body function. |
183 | // Because removing a node is too inefficient (would have to walk all the |
184 | // inputs of all graph nodes), we instead clear its control dependencies. |
185 | modified_body.mutable_node_def(body_barrier_loc)->clear_input(); |
186 | |
187 | // Add extra loop vars to the body function. |
188 | |
189 | for (int i = 0; i < num_new_chains; i++) { |
190 | // Input loop vars. |
191 | // TODO(mdan): Double check that this doesn't clash with names in body. |
192 | string c_name = g->NewName("acd__chainv" ); |
193 | std::replace(c_name.begin(), c_name.end(), '/', '_'); |
194 | auto* new_arg = modified_body.mutable_signature()->add_input_arg(); |
195 | new_arg->set_name(c_name); |
196 | new_arg->set_type(DT_FLOAT); |
197 | |
198 | // Output ops. These are copies of the inputs conditioned on the actual |
199 | // control outputs. |
200 | string c_out_name = g->NewName("acd__outchain" ); |
201 | auto* new_out = modified_body.add_node_def(); |
202 | new_out->set_name(c_out_name); |
203 | new_out->set_op("Identity" ); |
204 | new_out->add_input(c_name); |
205 | new_out->add_input( |
206 | strings::StrCat("^" , body_graph->signature().control_output(i))); |
207 | AttrValue attr; |
208 | attr.set_type(DT_FLOAT); |
209 | new_out->mutable_attr()->insert({"T" , attr}); |
210 | |
211 | // Output loop var declarations. |
212 | string c_ret_name = c_out_name; |
213 | std::replace(c_ret_name.begin(), c_ret_name.end(), '/', '_'); |
214 | auto* new_out_arg = modified_body.mutable_signature()->add_output_arg(); |
215 | new_out_arg->set_name(c_ret_name); |
216 | new_out_arg->set_type(DT_FLOAT); |
217 | |
218 | // Actual output loop vars. |
219 | modified_body.mutable_ret()->insert( |
220 | {c_ret_name, strings::StrCat(c_out_name, ":output:0" )}); |
221 | AttrValue attr_val; |
222 | attr_val.mutable_list()->add_shape(); |
223 | FunctionDef_ArgAttrs arg_attrs; |
224 | arg_attrs.mutable_attr()->insert({"_output_shapes" , attr_val}); |
225 | modified_body.mutable_arg_attr()->insert( |
226 | {static_cast<uint32_t>(i + num_loop_vars), arg_attrs}); |
227 | } |
228 | |
229 | // Wire chain loop vars to the ops they need to condition. |
230 | |
231 | node_index.clear(); |
232 | for (int i = 0; i < modified_body.node_def_size(); i++) { |
233 | node_index.emplace(modified_body.node_def(i).name(), i); |
234 | } |
235 | auto& modified_sig = modified_body.signature(); |
236 | for (int i = 0; i < num_new_chains; i++) { |
237 | const auto& control_node = |
238 | modified_body.node_def(node_index[modified_sig.control_output(i)]); |
239 | for (const auto& r : |
240 | control_node.attr().at("_res_first_used_by" ).list().s()) { |
241 | NodeDef* first_node = modified_body.mutable_node_def(node_index[r]); |
242 | // This control dependency ensures proper sequencing of stateful ops |
243 | // upon entry into the loop body, so that they run after the ops |
244 | // which affected the same resource in the previous iteration. |
245 | first_node->add_input(strings::StrCat( |
246 | "^" , modified_sig.input_arg(i + num_loop_vars).name())); |
247 | } |
248 | } |
249 | |
250 | // Clear body function's control returns. |
251 | modified_body.mutable_control_ret()->clear(); |
252 | |
253 | // Add extra loop vars to the cond function. |
254 | |
255 | // TODO(mdan): This is wasteful. Can't we just mutate the original proto? |
256 | string cond_name = attrs.at("cond" ).func().name(); |
257 | auto* cond_graph = flib_def->Find(cond_name); |
258 | DCHECK(cond_graph != nullptr); |
259 | FunctionDef modified_cond = *cond_graph; |
260 | |
261 | int cond_barrier_loc = -1; |
262 | for (int i = 0, s = cond_graph->node_def_size(); i < s; i++) { |
263 | if (cond_barrier_loc < 0) { |
264 | const auto& node_attr = cond_graph->node_def(i).attr(); |
265 | if (node_attr.find("_acd_function_control_output" ) != node_attr.end()) { |
266 | cond_barrier_loc = i; |
267 | } |
268 | } |
269 | } |
270 | if (cond_barrier_loc > 0) { |
271 | // Disable the global end-of-body barrier from the cond function. |
272 | // Because removing a node is too inefficient (would have to walk all the |
273 | // inputs of all graph nodes), we instead clear its control dependencies. |
274 | modified_cond.mutable_node_def(cond_barrier_loc)->clear_input(); |
275 | } |
276 | |
277 | for (int i = 0; i < num_new_chains; i++) { |
278 | // Input loop vars. |
279 | // TODO(mdan): These should gate the stateful ops in the cond. |
280 | // Until ACD supplies the necessary information, these are dummies in this |
281 | // function. |
282 | string c_name = g->NewName("acd__chain" ); |
283 | auto* new_arg = modified_cond.mutable_signature()->add_input_arg(); |
284 | new_arg->set_name(c_name); |
285 | new_arg->set_type(DT_FLOAT); |
286 | |
287 | // TODO(mdan): Return values on the cond function? Most likely a bug. |
288 | AttrValue attr_val; |
289 | attr_val.mutable_list()->add_shape(); |
290 | FunctionDef_ArgAttrs arg_attrs; |
291 | arg_attrs.mutable_attr()->insert({"_output_shapes" , attr_val}); |
292 | modified_cond.mutable_arg_attr()->insert( |
293 | {static_cast<uint32_t>(i + num_loop_vars), arg_attrs}); |
294 | } |
295 | |
296 | // Wire the new cond/body functions to the While node. |
297 | |
298 | string new_cond_name = g->NewName("acd__while_cond" ); |
299 | modified_cond.mutable_signature()->set_name(new_cond_name); |
300 | mattrs->at("cond" ).mutable_func()->set_name(new_cond_name); |
301 | |
302 | string new_body_name = g->NewName("acd__while_body" ); |
303 | modified_body.mutable_signature()->set_name(new_body_name); |
304 | mattrs->at("body" ).mutable_func()->set_name(new_body_name); |
305 | |
306 | // Commit the new functions. |
307 | |
308 | TF_RETURN_WITH_CONTEXT_IF_ERROR( |
309 | flib_def->AddFunctionDef(modified_body, |
310 | flib_def->GetStackTraces(body_name)), |
311 | "while attaching " , new_body_name, " to flib_def" ); |
312 | TF_RETURN_WITH_CONTEXT_IF_ERROR( |
313 | flib_def->AddFunctionDef(modified_cond, |
314 | flib_def->GetStackTraces(cond_name)), |
315 | "while attaching " , new_cond_name, " to flib_def" ); |
316 | |
317 | // TODO(b/183666205): This should not be necessary. |
318 | // It's unclear why adding the functions here is also required. |
319 | // Moreover, it's unclear when graph_lib's parent is flib_def itself. |
320 | auto* graph_lib = g->mutable_flib_def(); |
321 | if (graph_lib->default_registry() != flib_def) { |
322 | TF_RETURN_WITH_CONTEXT_IF_ERROR( |
323 | graph_lib->AddFunctionDef(modified_body, |
324 | graph_lib->GetStackTraces(body_name)), |
325 | "while attaching " , new_body_name, " to graph" ); |
326 | TF_RETURN_WITH_CONTEXT_IF_ERROR( |
327 | graph_lib->AddFunctionDef(modified_cond, |
328 | graph_lib->GetStackTraces(cond_name)), |
329 | "while attaching " , new_cond_name, " to graph" ); |
330 | } |
331 | } |
332 | |
333 | if (VLOG_IS_ON(1)) { |
334 | DumpGraphToFile("control_flow_deps_to_chains_after" , *g, flib_def); |
335 | } |
336 | |
337 | return OkStatus(); |
338 | } |
339 | |
// Registers the pass in the PRE_PLACEMENT phase at priority 9.
// Note: This needs to run before functional control flow lowering, which is 10,
// because it rewrites the While node's body/cond functions and loop vars,
// which must happen while the loop is still in functional form.
REGISTER_OPTIMIZATION(OptimizationPassRegistry::PRE_PLACEMENT, 9,
                      ControlFlowDepsToChainsPass);
343 | |
344 | } // namespace tensorflow |
345 | |