1/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
2
3Licensed under the Apache License, Version 2.0 (the "License");
4you may not use this file except in compliance with the License.
5You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9Unless required by applicable law or agreed to in writing, software
10distributed under the License is distributed on an "AS IS" BASIS,
11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12See the License for the specific language governing permissions and
13limitations under the License.
14==============================================================================*/
15
16#include "tensorflow/core/common_runtime/control_flow_deps_to_chains.h"
17
18#include <algorithm>
19#include <cstdint>
20#include <string>
21
22#include "tensorflow/core/framework/attr_value.pb.h"
23#include "tensorflow/core/framework/node_def.pb.h"
24#include "tensorflow/core/framework/node_def_util.h"
25#include "tensorflow/core/framework/op_def_builder.h"
26#include "tensorflow/core/framework/tensor.pb.h"
27#include "tensorflow/core/platform/errors.h"
28#include "tensorflow/core/platform/strcat.h"
29#include "tensorflow/core/platform/types.h"
30#include "tensorflow/core/util/dump_graph.h"
31
32namespace tensorflow {
33
34// TODO(mdan): Move this into Grappler - cleaner interface.
35Status ControlFlowDepsToChainsPass::Run(
36 const GraphOptimizationPassOptions& options) {
37 VLOG(1) << "ControlFlowDepsToChainsPass::Run";
38
39 if (options.graph == nullptr) {
40 VLOG(1) << "ControlFlowDepsToChainsPass::Run Aborted";
41 return OkStatus();
42 }
43
44 Graph* g = options.graph->get();
45 DCHECK(g != nullptr);
46 FunctionLibraryDefinition* flib_def = options.flib_def;
47 DCHECK(flib_def != nullptr);
48
49 if (VLOG_IS_ON(1)) {
50 DumpGraphToFile("control_flow_deps_to_chains_before", *g, flib_def);
51 }
52
53 for (Node* n : g->nodes()) {
54 if (n == nullptr) {
55 continue;
56 }
57 if (!n->IsWhileNode()) {
58 continue;
59 }
60
61 // TODO(mdan): This breaks encapsulation of Node/Graph. Is there any needed?
62 // TODO(mdan): Consolidate this with AddWhileInputHack.
63 NodeDef* while_node = n->mutable_def();
64 const auto& attrs = while_node->attr();
65 auto* mattrs = while_node->mutable_attr();
66
67 string body_name = attrs.at("body").func().name();
68 auto* body_graph = flib_def->Find(body_name);
69 DCHECK(body_graph != nullptr);
70
71 // Look for required annotations.
72
73 if (attrs.find("_stateful_parallelism") == attrs.end()) {
74 continue;
75 }
76 if (!attrs.at("_stateful_parallelism").b()) {
77 continue;
78 }
79 if (attrs.find("parallel_iterations") != attrs.end()) {
80 if (attrs.at("parallel_iterations").i() < 2) {
81 continue; // Loops which are already sequential are more efficient
82 // without chains.
83 }
84 }
85 // TODO(mdan): We don't really need this attribute.
86 if (attrs.find("_num_original_outputs") == attrs.end()) {
87 continue;
88 }
89 int body_barrier_loc = -1;
90 std::map<string, int> node_index;
91 for (int i = 0, s = body_graph->node_def_size(); i < s; i++) {
92 node_index.emplace(body_graph->node_def(i).name(), i);
93 if (body_barrier_loc < 0) {
94 const auto& node_attr = body_graph->node_def(i).attr();
95 if (node_attr.find("_acd_function_control_output") != node_attr.end()) {
96 body_barrier_loc = i;
97 }
98 }
99 }
100 if (body_barrier_loc < 0) {
101 continue;
102 }
103 bool ok_for_lowering = true;
104 for (int i = 0; i < body_graph->control_ret_size(); i++) {
105 const auto& control_node = body_graph->node_def(
106 node_index[body_graph->signature().control_output(i)]);
107 const auto& control_attr = control_node.attr();
108 if (control_attr.find("_res_first_used_by") == control_attr.end()) {
109 ok_for_lowering = false;
110 break;
111 }
112 }
113 if (!ok_for_lowering) {
114 continue;
115 }
116
117 int num_loop_vars = body_graph->signature().input_arg_size();
118 int num_new_chains = body_graph->control_ret_size();
119 int num_node_inputs = while_node->input_size();
120
121 if (!num_new_chains) {
122 continue; // Nothing to do for stateless loops.
123 }
124
125 // Add extra loop vars to the while node.
126
127 // TODO(mdan): If the loop vars contains the resource, we should reuse it.
128 // Note that stateful ops of resource inputs cause their resources to be
129 // captured into the loop vars (through the body/cond captures). We could
130 // effectively use those as chains.
131
132 // TODO(mdan): Is there a more efficient way to do this?
133 // Insert the new While node inputs: at the end of the loop vars, but before
134 // any non-loop var inputs (like control dependencies). Once the initial
135 // chain values are created below, they will be added to these inputs.
136 for (int i = 0; i < num_new_chains; i++) {
137 while_node->add_input();
138 }
139 for (int i = num_node_inputs - 1; i >= num_loop_vars; i--) {
140 while_node->set_input(i + num_new_chains, while_node->input(i));
141 }
142
143 std::vector<Node*> new_inputs;
144 std::vector<int> new_input_locations;
145 // Set their name to a gensym, type to float and shape to scalar.
146 for (int i = 0; i < num_new_chains; i++) {
147 string c_name = g->NewName("acd__chain");
148
149 // The initial value for the i'th chain loop var.
150 NodeDef new_in;
151 new_in.set_name(c_name);
152 new_in.set_op("Const");
153 AttrValue att_dtype;
154 att_dtype.set_type(DT_FLOAT);
155 new_in.mutable_attr()->insert({"dtype", att_dtype});
156 AttrValue att_value;
157 att_value.mutable_tensor()->set_dtype(DT_FLOAT);
158 att_value.mutable_tensor()->mutable_tensor_shape();
159 att_value.mutable_tensor()->add_int_val(0);
160 new_in.mutable_attr()->insert({"value", att_value});
161 Status status;
162 new_inputs.push_back(g->AddNode(new_in, &status));
163 TF_RETURN_WITH_CONTEXT_IF_ERROR(status, "while creating chain", c_name);
164
165 int loc = num_loop_vars + i;
166 new_input_locations.push_back(loc);
167 while_node->set_input(loc, c_name);
168 mattrs->at("T").mutable_list()->add_type(DT_FLOAT);
169 mattrs->at("output_shapes").mutable_list()->add_shape();
170 }
171
172 // TODO(mdan): This should not be necessary to update. Delete?
173 mattrs->at("_num_original_outputs").set_i(num_loop_vars + num_new_chains);
174 n->UpdateProperties();
175 for (int i = 0; i < num_new_chains; i++) {
176 g->AddEdge(new_inputs[i], 0, n, new_input_locations[i]);
177 }
178
179 // TODO(mdan): This is wasteful. Can we just mutate the original proto?
180 FunctionDef modified_body = *body_graph;
181
182 // Disable the global end-of-body barrier from the body function.
183 // Because removing a node is too inefficient (would have to walk all the
184 // inputs of all graph nodes), we instead clear its control dependencies.
185 modified_body.mutable_node_def(body_barrier_loc)->clear_input();
186
187 // Add extra loop vars to the body function.
188
189 for (int i = 0; i < num_new_chains; i++) {
190 // Input loop vars.
191 // TODO(mdan): Double check that this doesn't clash with names in body.
192 string c_name = g->NewName("acd__chainv");
193 std::replace(c_name.begin(), c_name.end(), '/', '_');
194 auto* new_arg = modified_body.mutable_signature()->add_input_arg();
195 new_arg->set_name(c_name);
196 new_arg->set_type(DT_FLOAT);
197
198 // Output ops. These are copies of the inputs conditioned on the actual
199 // control outputs.
200 string c_out_name = g->NewName("acd__outchain");
201 auto* new_out = modified_body.add_node_def();
202 new_out->set_name(c_out_name);
203 new_out->set_op("Identity");
204 new_out->add_input(c_name);
205 new_out->add_input(
206 strings::StrCat("^", body_graph->signature().control_output(i)));
207 AttrValue attr;
208 attr.set_type(DT_FLOAT);
209 new_out->mutable_attr()->insert({"T", attr});
210
211 // Output loop var declarations.
212 string c_ret_name = c_out_name;
213 std::replace(c_ret_name.begin(), c_ret_name.end(), '/', '_');
214 auto* new_out_arg = modified_body.mutable_signature()->add_output_arg();
215 new_out_arg->set_name(c_ret_name);
216 new_out_arg->set_type(DT_FLOAT);
217
218 // Actual output loop vars.
219 modified_body.mutable_ret()->insert(
220 {c_ret_name, strings::StrCat(c_out_name, ":output:0")});
221 AttrValue attr_val;
222 attr_val.mutable_list()->add_shape();
223 FunctionDef_ArgAttrs arg_attrs;
224 arg_attrs.mutable_attr()->insert({"_output_shapes", attr_val});
225 modified_body.mutable_arg_attr()->insert(
226 {static_cast<uint32_t>(i + num_loop_vars), arg_attrs});
227 }
228
229 // Wire chain loop vars to the ops they need to condition.
230
231 node_index.clear();
232 for (int i = 0; i < modified_body.node_def_size(); i++) {
233 node_index.emplace(modified_body.node_def(i).name(), i);
234 }
235 auto& modified_sig = modified_body.signature();
236 for (int i = 0; i < num_new_chains; i++) {
237 const auto& control_node =
238 modified_body.node_def(node_index[modified_sig.control_output(i)]);
239 for (const auto& r :
240 control_node.attr().at("_res_first_used_by").list().s()) {
241 NodeDef* first_node = modified_body.mutable_node_def(node_index[r]);
242 // This control dependency ensures proper sequencing of stateful ops
243 // upon entry into the loop body, so that they run after the ops
244 // which affected the same resource in the previous iteration.
245 first_node->add_input(strings::StrCat(
246 "^", modified_sig.input_arg(i + num_loop_vars).name()));
247 }
248 }
249
250 // Clear body function's control returns.
251 modified_body.mutable_control_ret()->clear();
252
253 // Add extra loop vars to the cond function.
254
255 // TODO(mdan): This is wasteful. Can't we just mutate the original proto?
256 string cond_name = attrs.at("cond").func().name();
257 auto* cond_graph = flib_def->Find(cond_name);
258 DCHECK(cond_graph != nullptr);
259 FunctionDef modified_cond = *cond_graph;
260
261 int cond_barrier_loc = -1;
262 for (int i = 0, s = cond_graph->node_def_size(); i < s; i++) {
263 if (cond_barrier_loc < 0) {
264 const auto& node_attr = cond_graph->node_def(i).attr();
265 if (node_attr.find("_acd_function_control_output") != node_attr.end()) {
266 cond_barrier_loc = i;
267 }
268 }
269 }
270 if (cond_barrier_loc > 0) {
271 // Disable the global end-of-body barrier from the cond function.
272 // Because removing a node is too inefficient (would have to walk all the
273 // inputs of all graph nodes), we instead clear its control dependencies.
274 modified_cond.mutable_node_def(cond_barrier_loc)->clear_input();
275 }
276
277 for (int i = 0; i < num_new_chains; i++) {
278 // Input loop vars.
279 // TODO(mdan): These should gate the stateful ops in the cond.
280 // Until ACD supplies the necessary information, these are dummies in this
281 // function.
282 string c_name = g->NewName("acd__chain");
283 auto* new_arg = modified_cond.mutable_signature()->add_input_arg();
284 new_arg->set_name(c_name);
285 new_arg->set_type(DT_FLOAT);
286
287 // TODO(mdan): Return values on the cond function? Most likely a bug.
288 AttrValue attr_val;
289 attr_val.mutable_list()->add_shape();
290 FunctionDef_ArgAttrs arg_attrs;
291 arg_attrs.mutable_attr()->insert({"_output_shapes", attr_val});
292 modified_cond.mutable_arg_attr()->insert(
293 {static_cast<uint32_t>(i + num_loop_vars), arg_attrs});
294 }
295
296 // Wire the new cond/body functions to the While node.
297
298 string new_cond_name = g->NewName("acd__while_cond");
299 modified_cond.mutable_signature()->set_name(new_cond_name);
300 mattrs->at("cond").mutable_func()->set_name(new_cond_name);
301
302 string new_body_name = g->NewName("acd__while_body");
303 modified_body.mutable_signature()->set_name(new_body_name);
304 mattrs->at("body").mutable_func()->set_name(new_body_name);
305
306 // Commit the new functions.
307
308 TF_RETURN_WITH_CONTEXT_IF_ERROR(
309 flib_def->AddFunctionDef(modified_body,
310 flib_def->GetStackTraces(body_name)),
311 "while attaching ", new_body_name, " to flib_def");
312 TF_RETURN_WITH_CONTEXT_IF_ERROR(
313 flib_def->AddFunctionDef(modified_cond,
314 flib_def->GetStackTraces(cond_name)),
315 "while attaching ", new_cond_name, " to flib_def");
316
317 // TODO(b/183666205): This should not be necessary.
318 // It's unclear why adding the functions here is also required.
319 // Moreover, it's unclear when graph_lib's parent is flib_def itself.
320 auto* graph_lib = g->mutable_flib_def();
321 if (graph_lib->default_registry() != flib_def) {
322 TF_RETURN_WITH_CONTEXT_IF_ERROR(
323 graph_lib->AddFunctionDef(modified_body,
324 graph_lib->GetStackTraces(body_name)),
325 "while attaching ", new_body_name, " to graph");
326 TF_RETURN_WITH_CONTEXT_IF_ERROR(
327 graph_lib->AddFunctionDef(modified_cond,
328 graph_lib->GetStackTraces(cond_name)),
329 "while attaching ", new_cond_name, " to graph");
330 }
331 }
332
333 if (VLOG_IS_ON(1)) {
334 DumpGraphToFile("control_flow_deps_to_chains_after", *g, flib_def);
335 }
336
337 return OkStatus();
338}
339
// Note: This pass needs to run before functional control flow lowering. It is
// registered here in the PRE_PLACEMENT grouping at phase 9; lowering is at
// phase 10 (presumably in the same grouping - TODO confirm), so the phase
// number used here must stay below that one.
REGISTER_OPTIMIZATION(OptimizationPassRegistry::PRE_PLACEMENT, 9,
                      ControlFlowDepsToChainsPass);
343
344} // namespace tensorflow
345