1/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2
3Licensed under the Apache License, Version 2.0 (the "License");
4you may not use this file except in compliance with the License.
5You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9Unless required by applicable law or agreed to in writing, software
10distributed under the License is distributed on an "AS IS" BASIS,
11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12See the License for the specific language governing permissions and
13limitations under the License.
14==============================================================================*/
15
16#include "tensorflow/core/common_runtime/lower_if_op.h"
17
18#include "tensorflow/core/common_runtime/inline_function_utils.h"
19#include "tensorflow/core/framework/node_def_builder.h"
20#include "tensorflow/core/graph/graph.h"
21#include "tensorflow/core/graph/node_builder.h"
22
23namespace tensorflow {
24namespace {
25
26using NodeOut = NodeBuilder::NodeOut;
27
28constexpr const char* const kLowerAsMultiDeviceFunctionAttr =
29 LowerFunctionalOpsConstants::kLowerAsMultiDeviceFunctionAttr;
30
31// Convenience builder to make it easy to construct a conditional with a single
32// function call in the then and else branch. This first converts the if node
33// into switches (for inputs) and merges (for outputs) around a function call
34// per branch.
35class CondBuilder {
36 public:
37 enum Branch { kElseBranch = 0, kThenBranch = 1 };
38
39 // Create a CondBuilder to create the lowered form of `if_op` with then and
40 // else functions `then_fn` and `else_fn` respectively in the `graph`. The
41 // functions should be available in `flib`.
42 CondBuilder(Node* if_op, const NameAttrList& then_fn,
43 const NameAttrList& else_fn, bool keep_node_fetchable,
44 Graph* graph);
45
46 // Constructs the basic conditional control flow using switch and merge nodes.
47 Status CreatePivotNodes();
48
49 // Adds the inputs from the if node to the merge nodes of the lowered if.
50 Status AddInputs();
51
52 // Adds the outputs from the if node to the merge nodes of the lowered if.
53 // Note: no inputs can be added once outputs are added as the then and else
54 // nodes are finalized while adding outputs.
55 Status AddOutputs();
56
57 // Builds an identity node with the same outputs as If.
58 Status BuildLoweredIfOutput();
59
60 private:
61 // Returns unique name containing the name of the If op being rewritten
62 // (name_), infix and a suffix to ensure it is unique within the graph.
63 string NewName(const string& infix);
64
65 // Adds input to both the then and else nodes from src:src_output.
66 Status AddInput(Node* src, int src_output);
67
68 // Finalizes the node described by `node_builder`. If `coloc_attr_` is not
69 // nullptr, adds the colocation attr to the node before finalizing it.
70 Status SetColocationAndFinalize(NodeBuilder node_builder, Graph* graph,
71 Node** created_node);
72
73 // The merged outputs of the then and else nodes.
74 std::vector<NodeOut> outputs_;
75
76 // The node that dominates all execution of the then and else body nodes.
77 Node* control_predecessor_;
78 // The original If op.
79 Node* if_op_;
80 // The colocation attr on the original If op. If it exists, control flow nodes
81 // created in the lowering (except the data Switch nodes) will inherit this
82 // attribute.
83 const AttrValue* coloc_attr_;
84 // The node with the same name as the original If op:
85 // (a) IdentityN node with same outputs if 'keep_node_fetchable_ == true'
86 // and if the original If op had non-zero data outputs.
87 // (b) NoOp node with control edge from 'branch_executed_node_' otherwise.
88 Node* lowered_if_output_;
89 // The predicate of the conditional.
90 OutputTensor pred_;
91 // Node corresponding to pivot_f branch of predicate switch which is
92 // the pivot node that dominates all nodes in the false/else branch.
93 Node* pivot_f_;
94 // Node corresponding to pivot_t branch of predicate switch which is
95 // the pivot node that dominates all nodes in the true/then branch.
96 Node* pivot_t_;
97 Node* then_call_node_;
98 Node* else_call_node_;
99 // Merge node that has inputs from [pivot_t, pivot_f] and control edges from
100 // [^then_call_node_, ^else_call_node_]. This node will guarantee that even
101 // when then/else branch functions do not have outputs, they still will be
102 // executed for the side effects.
103 Node* branch_executed_node_;
104 Graph* graph_;
105 string name_;
106 bool keep_node_fetchable_;
107
108 NodeDebugInfo debug_info_;
109 NodeBuilder then_call_builder_;
110 NodeBuilder else_call_builder_;
111};
112
113CondBuilder::CondBuilder(Node* if_op, const NameAttrList& then_fn,
114 const NameAttrList& else_fn, bool keep_node_fetchable,
115 Graph* graph)
116 : if_op_(if_op),
117 coloc_attr_(if_op_->attrs().Find(kColocationAttrName)),
118 graph_(graph),
119 name_(if_op->name()),
120 keep_node_fetchable_(keep_node_fetchable),
121 debug_info_(*if_op_),
122 then_call_builder_(NewName("then"), then_fn.name(), graph->op_registry(),
123 &debug_info_),
124 else_call_builder_(NewName("else"), else_fn.name(), graph->op_registry(),
125 &debug_info_) {
126 TF_CHECK_OK(if_op_->input_tensor(0, &pred_));
127 then_call_builder_.Device(if_op_->requested_device());
128 then_call_builder_.Attr(kLowerAsMultiDeviceFunctionAttr, true);
129 for (const auto& i : then_fn.attr()) {
130 then_call_builder_.Attr(i.first, i.second);
131 }
132 else_call_builder_.Device(if_op_->requested_device());
133 else_call_builder_.Attr(kLowerAsMultiDeviceFunctionAttr, true);
134 for (const auto& i : else_fn.attr()) {
135 else_call_builder_.Attr(i.first, i.second);
136 }
137}
138
139Status CondBuilder::SetColocationAndFinalize(NodeBuilder node_builder,
140 Graph* graph,
141 Node** created_node) {
142 if (coloc_attr_ != nullptr) {
143 node_builder = node_builder.Attr(kColocationAttrName, *coloc_attr_);
144 }
145 return node_builder.Finalize(graph, created_node);
146}
147
148Status CondBuilder::CreatePivotNodes() {
149 // Construct the basic cond body (consisting of feeding in the predicate to
150 // create pivot nodes).
151 Node* switch_pred;
152 TF_RETURN_IF_ERROR(
153 SetColocationAndFinalize(NodeBuilder(NewName("switch_pred"), "Switch",
154 graph_->op_registry(), &debug_info_)
155 .Input(NodeOut(pred_))
156 .Input(NodeOut(pred_))
157 .Device(if_op_->requested_device()),
158 graph_, &switch_pred));
159 control_predecessor_ = switch_pred;
160 TF_RETURN_IF_ERROR(
161 SetColocationAndFinalize(NodeBuilder(NewName("pivot_f"), "Identity",
162 graph_->op_registry(), &debug_info_)
163 .Input(switch_pred, kElseBranch)
164 .Device(if_op_->requested_device()),
165 graph_, &pivot_f_));
166 TF_RETURN_IF_ERROR(
167 SetColocationAndFinalize(NodeBuilder(NewName("pivot_t"), "Identity",
168 graph_->op_registry(), &debug_info_)
169 .Input(switch_pred, kThenBranch)
170 .Device(if_op_->requested_device()),
171 graph_, &pivot_t_));
172 return OkStatus();
173}
174
175string CondBuilder::NewName(const string& infix) {
176 return graph_->NewName(strings::StrCat(name_, "/", infix));
177}
178
179Status CondBuilder::AddInput(Node* src, int src_output) {
180 Node* input;
181 NodeDebugInfo debug_info(*src);
182 // Colocate the Switch node with the `src` node.
183 //
184 // This is to avoid unnecessary Host<->Device copies between src and the
185 // Switch node.
186 //
187 // NOTE(rachelim): Here, we don't use `CondBuilder::SetColocationAndFinalize`,
188 // and instead ignore the existing colocation stack. This is aligned with the
189 // legacy impl in control_flow_ops.py. The legacy impl colocates this Switch
190 // with the input tensor which resets the device stack and forces the Switch
191 // to have the same device as the input node (if set) and sets the colocation
192 // _class attr. It also ignores the existing colocation stack in the context
193 // by using colocate_with(ignore_existing=True).
194 TF_RETURN_IF_ERROR(
195 NodeBuilder(NewName(src->name()), "Switch", graph_->op_registry(),
196 &debug_info)
197 .Input(src, src_output)
198 .Input(pred_)
199 .Device(src->requested_device())
200 .Attr(kColocationAttrName,
201 {absl::StrCat(kColocationGroupPrefix, src->name())})
202 .Finalize(graph_, &input));
203 then_call_builder_.Input(input, kThenBranch);
204 else_call_builder_.Input(input, kElseBranch);
205 return OkStatus();
206}
207
208Status CondBuilder::AddInputs() {
209 // Add input data edges.
210 std::vector<const Edge*> edges;
211 TF_RETURN_IF_ERROR(if_op_->input_edges(&edges));
212 // Start at index 1 as the first input is the predicate.
213 for (int i = 1; i < edges.size(); ++i) {
214 const Edge* e = edges[i];
215 TF_RETURN_IF_ERROR(AddInput(e->src(), e->src_output()));
216 }
217 // Add input control edges.
218 for (const Edge* e : if_op_->in_edges()) {
219 if (e->IsControlEdge()) {
220 graph_->AddControlEdge(e->src(), control_predecessor_);
221 }
222 }
223 return OkStatus();
224}
225
226Status CondBuilder::AddOutputs() {
227 // Construct the then and else nodes.
228 // NOTE(rachelim): Here, we don't use `CondBuilder::SetColocationAndFinalize`
229 // because the colocation for branch nodes is applied in python.
230 TF_RETURN_IF_ERROR(then_call_builder_.Finalize(graph_, &then_call_node_));
231 graph_->AddControlEdge(pivot_t_, then_call_node_);
232 TF_RETURN_IF_ERROR(else_call_builder_.Finalize(graph_, &else_call_node_));
233 graph_->AddControlEdge(pivot_f_, else_call_node_);
234
235 // Add Merge node for each data output of the If node.
236 std::vector<Node*> merges(then_call_node_->num_outputs());
237 outputs_.resize(merges.size());
238 for (int i = 0; i < then_call_node_->num_outputs(); ++i) {
239 TF_RETURN_IF_ERROR(SetColocationAndFinalize(
240 NodeBuilder(NewName("output"), "Merge", graph_->op_registry(),
241 &debug_info_)
242 .Input({NodeOut(then_call_node_, i), NodeOut(else_call_node_, i)})
243 .Device(if_op_->requested_device()),
244 graph_, &merges[i]));
245 outputs_[i] = NodeOut(merges[i], 0);
246 }
247
248 // Add a Merge node that will be used as a control dependency source for the
249 // lowered output node. This Merge node will guarantee that lowered else/then
250 // function calls will be executed even if they do not have data outputs.
251 //
252 // Furthermore it will guarantee that all function side effects will be
253 // executed, if the function will be inlined into the graph. Having data
254 // outputs is not enough, because they might become unused after inlining.
255 //
256 // We will use this node to rewrite outgoing control edges from lowered 'If'
257 // node. All data edges will read tensors directly from Merge nodes.
258 TF_RETURN_IF_ERROR(SetColocationAndFinalize(
259 NodeBuilder(NewName("branch_executed"), "Merge", graph_->op_registry(),
260 &debug_info_)
261 .Input({pivot_t_, pivot_f_})
262 .ControlInputs({then_call_node_, else_call_node_})
263 .Device(if_op_->requested_device()),
264 graph_, &branch_executed_node_));
265
266 TF_RETURN_IF_ERROR(BuildLoweredIfOutput());
267
268 // Add outputs.
269 for (const Edge* e : if_op_->out_edges()) {
270 if (e->IsControlEdge()) {
271 graph_->AddControlEdge(branch_executed_node_, e->dst());
272 } else {
273 // Feed the outputs directly from the merge nodes so that downstream ops
274 // can start before all the outputs have been computed.
275 graph_->AddEdge(merges[e->src_output()], 0, e->dst(), e->dst_input());
276 }
277 }
278
279 return OkStatus();
280}
281
282Status CondBuilder::BuildLoweredIfOutput() {
283 // If outputs are empty, it means that we might have only output control
284 // edges (already connected to the `branch_executed_node`). Furthermore it's
285 // illegal to have an IdentityN with empty inputs.
286 //
287 // We still must keep lowered If node as a valid source of control edges,
288 // because it might be a part of function control output set.
289 NodeBuilder builder = keep_node_fetchable_ && !outputs_.empty()
290 ? NodeBuilder(name_, "IdentityN").Input(outputs_)
291 : NodeBuilder(name_, "NoOp");
292
293 return builder.Device(if_op_->requested_device())
294 .ControlInput(branch_executed_node_)
295 .Finalize(graph_, &lowered_if_output_);
296}
297
298} // namespace
299
300Status RewriteIfNode(Node* n, Graph* g, bool keep_node_fetchable) {
301 VLOG(2) << "Lower If node (keep_node_fetchable=" << keep_node_fetchable
302 << "): " << SummarizeNode(*n);
303
304 const AttrValue* then_attr = n->attrs().Find("then_branch");
305 if (then_attr == nullptr) {
306 return errors::InvalidArgument("Then branch function missing");
307 }
308 const AttrValue* else_attr = n->attrs().Find("else_branch");
309 if (else_attr == nullptr) {
310 return errors::InvalidArgument("Else branch function missing");
311 }
312
313 CondBuilder cb(n, then_attr->func(), else_attr->func(), keep_node_fetchable,
314 g);
315 TF_RETURN_IF_ERROR(cb.CreatePivotNodes());
316 TF_RETURN_IF_ERROR(cb.AddInputs());
317 TF_RETURN_IF_ERROR(cb.AddOutputs());
318 g->RemoveNode(n);
319
320 return OkStatus();
321}
322
323} // namespace tensorflow
324