1 | /* Copyright 2018 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | |
16 | #include "tensorflow/core/common_runtime/lower_if_op.h" |
17 | |
18 | #include "tensorflow/core/common_runtime/inline_function_utils.h" |
19 | #include "tensorflow/core/framework/node_def_builder.h" |
20 | #include "tensorflow/core/graph/graph.h" |
21 | #include "tensorflow/core/graph/node_builder.h" |
22 | |
23 | namespace tensorflow { |
24 | namespace { |
25 | |
26 | using NodeOut = NodeBuilder::NodeOut; |
27 | |
28 | constexpr const char* const kLowerAsMultiDeviceFunctionAttr = |
29 | LowerFunctionalOpsConstants::kLowerAsMultiDeviceFunctionAttr; |
30 | |
31 | // Convenience builder to make it easy to construct a conditional with a single |
32 | // function call in the then and else branch. This first converts the if node |
33 | // into switches (for inputs) and merges (for outputs) around a function call |
34 | // per branch. |
35 | class CondBuilder { |
36 | public: |
37 | enum Branch { kElseBranch = 0, kThenBranch = 1 }; |
38 | |
39 | // Create a CondBuilder to create the lowered form of `if_op` with then and |
40 | // else functions `then_fn` and `else_fn` respectively in the `graph`. The |
41 | // functions should be available in `flib`. |
42 | CondBuilder(Node* if_op, const NameAttrList& then_fn, |
43 | const NameAttrList& else_fn, bool keep_node_fetchable, |
44 | Graph* graph); |
45 | |
46 | // Constructs the basic conditional control flow using switch and merge nodes. |
47 | Status CreatePivotNodes(); |
48 | |
49 | // Adds the inputs from the if node to the merge nodes of the lowered if. |
50 | Status AddInputs(); |
51 | |
52 | // Adds the outputs from the if node to the merge nodes of the lowered if. |
53 | // Note: no inputs can be added once outputs are added as the then and else |
54 | // nodes are finalized while adding outputs. |
55 | Status AddOutputs(); |
56 | |
57 | // Builds an identity node with the same outputs as If. |
58 | Status BuildLoweredIfOutput(); |
59 | |
60 | private: |
61 | // Returns unique name containing the name of the If op being rewritten |
62 | // (name_), infix and a suffix to ensure it is unique within the graph. |
63 | string NewName(const string& infix); |
64 | |
65 | // Adds input to both the then and else nodes from src:src_output. |
66 | Status AddInput(Node* src, int src_output); |
67 | |
68 | // Finalizes the node described by `node_builder`. If `coloc_attr_` is not |
69 | // nullptr, adds the colocation attr to the node before finalizing it. |
70 | Status SetColocationAndFinalize(NodeBuilder node_builder, Graph* graph, |
71 | Node** created_node); |
72 | |
73 | // The merged outputs of the then and else nodes. |
74 | std::vector<NodeOut> outputs_; |
75 | |
76 | // The node that dominates all execution of the then and else body nodes. |
77 | Node* control_predecessor_; |
78 | // The original If op. |
79 | Node* if_op_; |
80 | // The colocation attr on the original If op. If it exists, control flow nodes |
81 | // created in the lowering (except the data Switch nodes) will inherit this |
82 | // attribute. |
83 | const AttrValue* coloc_attr_; |
84 | // The node with the same name as the original If op: |
85 | // (a) IdentityN node with same outputs if 'keep_node_fetchable_ == true' |
86 | // and if the original If op had non-zero data outputs. |
87 | // (b) NoOp node with control edge from 'branch_executed_node_' otherwise. |
88 | Node* lowered_if_output_; |
89 | // The predicate of the conditional. |
90 | OutputTensor pred_; |
91 | // Node corresponding to pivot_f branch of predicate switch which is |
92 | // the pivot node that dominates all nodes in the false/else branch. |
93 | Node* pivot_f_; |
94 | // Node corresponding to pivot_t branch of predicate switch which is |
95 | // the pivot node that dominates all nodes in the true/then branch. |
96 | Node* pivot_t_; |
97 | Node* then_call_node_; |
98 | Node* else_call_node_; |
99 | // Merge node that has inputs from [pivot_t, pivot_f] and control edges from |
100 | // [^then_call_node_, ^else_call_node_]. This node will guarantee that even |
101 | // when then/else branch functions do not have outputs, they still will be |
102 | // executed for the side effects. |
103 | Node* branch_executed_node_; |
104 | Graph* graph_; |
105 | string name_; |
106 | bool keep_node_fetchable_; |
107 | |
108 | NodeDebugInfo debug_info_; |
109 | NodeBuilder then_call_builder_; |
110 | NodeBuilder else_call_builder_; |
111 | }; |
112 | |
113 | CondBuilder::CondBuilder(Node* if_op, const NameAttrList& then_fn, |
114 | const NameAttrList& else_fn, bool keep_node_fetchable, |
115 | Graph* graph) |
116 | : if_op_(if_op), |
117 | coloc_attr_(if_op_->attrs().Find(kColocationAttrName)), |
118 | graph_(graph), |
119 | name_(if_op->name()), |
120 | keep_node_fetchable_(keep_node_fetchable), |
121 | debug_info_(*if_op_), |
122 | then_call_builder_(NewName("then" ), then_fn.name(), graph->op_registry(), |
123 | &debug_info_), |
124 | else_call_builder_(NewName("else" ), else_fn.name(), graph->op_registry(), |
125 | &debug_info_) { |
126 | TF_CHECK_OK(if_op_->input_tensor(0, &pred_)); |
127 | then_call_builder_.Device(if_op_->requested_device()); |
128 | then_call_builder_.Attr(kLowerAsMultiDeviceFunctionAttr, true); |
129 | for (const auto& i : then_fn.attr()) { |
130 | then_call_builder_.Attr(i.first, i.second); |
131 | } |
132 | else_call_builder_.Device(if_op_->requested_device()); |
133 | else_call_builder_.Attr(kLowerAsMultiDeviceFunctionAttr, true); |
134 | for (const auto& i : else_fn.attr()) { |
135 | else_call_builder_.Attr(i.first, i.second); |
136 | } |
137 | } |
138 | |
139 | Status CondBuilder::SetColocationAndFinalize(NodeBuilder node_builder, |
140 | Graph* graph, |
141 | Node** created_node) { |
142 | if (coloc_attr_ != nullptr) { |
143 | node_builder = node_builder.Attr(kColocationAttrName, *coloc_attr_); |
144 | } |
145 | return node_builder.Finalize(graph, created_node); |
146 | } |
147 | |
148 | Status CondBuilder::CreatePivotNodes() { |
149 | // Construct the basic cond body (consisting of feeding in the predicate to |
150 | // create pivot nodes). |
151 | Node* switch_pred; |
152 | TF_RETURN_IF_ERROR( |
153 | SetColocationAndFinalize(NodeBuilder(NewName("switch_pred" ), "Switch" , |
154 | graph_->op_registry(), &debug_info_) |
155 | .Input(NodeOut(pred_)) |
156 | .Input(NodeOut(pred_)) |
157 | .Device(if_op_->requested_device()), |
158 | graph_, &switch_pred)); |
159 | control_predecessor_ = switch_pred; |
160 | TF_RETURN_IF_ERROR( |
161 | SetColocationAndFinalize(NodeBuilder(NewName("pivot_f" ), "Identity" , |
162 | graph_->op_registry(), &debug_info_) |
163 | .Input(switch_pred, kElseBranch) |
164 | .Device(if_op_->requested_device()), |
165 | graph_, &pivot_f_)); |
166 | TF_RETURN_IF_ERROR( |
167 | SetColocationAndFinalize(NodeBuilder(NewName("pivot_t" ), "Identity" , |
168 | graph_->op_registry(), &debug_info_) |
169 | .Input(switch_pred, kThenBranch) |
170 | .Device(if_op_->requested_device()), |
171 | graph_, &pivot_t_)); |
172 | return OkStatus(); |
173 | } |
174 | |
175 | string CondBuilder::NewName(const string& infix) { |
176 | return graph_->NewName(strings::StrCat(name_, "/" , infix)); |
177 | } |
178 | |
179 | Status CondBuilder::AddInput(Node* src, int src_output) { |
180 | Node* input; |
181 | NodeDebugInfo debug_info(*src); |
182 | // Colocate the Switch node with the `src` node. |
183 | // |
184 | // This is to avoid unnecessary Host<->Device copies between src and the |
185 | // Switch node. |
186 | // |
187 | // NOTE(rachelim): Here, we don't use `CondBuilder::SetColocationAndFinalize`, |
188 | // and instead ignore the existing colocation stack. This is aligned with the |
189 | // legacy impl in control_flow_ops.py. The legacy impl colocates this Switch |
190 | // with the input tensor which resets the device stack and forces the Switch |
191 | // to have the same device as the input node (if set) and sets the colocation |
192 | // _class attr. It also ignores the existing colocation stack in the context |
193 | // by using colocate_with(ignore_existing=True). |
194 | TF_RETURN_IF_ERROR( |
195 | NodeBuilder(NewName(src->name()), "Switch" , graph_->op_registry(), |
196 | &debug_info) |
197 | .Input(src, src_output) |
198 | .Input(pred_) |
199 | .Device(src->requested_device()) |
200 | .Attr(kColocationAttrName, |
201 | {absl::StrCat(kColocationGroupPrefix, src->name())}) |
202 | .Finalize(graph_, &input)); |
203 | then_call_builder_.Input(input, kThenBranch); |
204 | else_call_builder_.Input(input, kElseBranch); |
205 | return OkStatus(); |
206 | } |
207 | |
208 | Status CondBuilder::AddInputs() { |
209 | // Add input data edges. |
210 | std::vector<const Edge*> edges; |
211 | TF_RETURN_IF_ERROR(if_op_->input_edges(&edges)); |
212 | // Start at index 1 as the first input is the predicate. |
213 | for (int i = 1; i < edges.size(); ++i) { |
214 | const Edge* e = edges[i]; |
215 | TF_RETURN_IF_ERROR(AddInput(e->src(), e->src_output())); |
216 | } |
217 | // Add input control edges. |
218 | for (const Edge* e : if_op_->in_edges()) { |
219 | if (e->IsControlEdge()) { |
220 | graph_->AddControlEdge(e->src(), control_predecessor_); |
221 | } |
222 | } |
223 | return OkStatus(); |
224 | } |
225 | |
226 | Status CondBuilder::AddOutputs() { |
227 | // Construct the then and else nodes. |
228 | // NOTE(rachelim): Here, we don't use `CondBuilder::SetColocationAndFinalize` |
229 | // because the colocation for branch nodes is applied in python. |
230 | TF_RETURN_IF_ERROR(then_call_builder_.Finalize(graph_, &then_call_node_)); |
231 | graph_->AddControlEdge(pivot_t_, then_call_node_); |
232 | TF_RETURN_IF_ERROR(else_call_builder_.Finalize(graph_, &else_call_node_)); |
233 | graph_->AddControlEdge(pivot_f_, else_call_node_); |
234 | |
235 | // Add Merge node for each data output of the If node. |
236 | std::vector<Node*> merges(then_call_node_->num_outputs()); |
237 | outputs_.resize(merges.size()); |
238 | for (int i = 0; i < then_call_node_->num_outputs(); ++i) { |
239 | TF_RETURN_IF_ERROR(SetColocationAndFinalize( |
240 | NodeBuilder(NewName("output" ), "Merge" , graph_->op_registry(), |
241 | &debug_info_) |
242 | .Input({NodeOut(then_call_node_, i), NodeOut(else_call_node_, i)}) |
243 | .Device(if_op_->requested_device()), |
244 | graph_, &merges[i])); |
245 | outputs_[i] = NodeOut(merges[i], 0); |
246 | } |
247 | |
248 | // Add a Merge node that will be used as a control dependency source for the |
249 | // lowered output node. This Merge node will guarantee that lowered else/then |
250 | // function calls will be executed even if they do not have data outputs. |
251 | // |
252 | // Furthermore it will guarantee that all function side effects will be |
253 | // executed, if the function will be inlined into the graph. Having data |
254 | // outputs is not enough, because they might become unused after inlining. |
255 | // |
256 | // We will use this node to rewrite outgoing control edges from lowered 'If' |
257 | // node. All data edges will read tensors directly from Merge nodes. |
258 | TF_RETURN_IF_ERROR(SetColocationAndFinalize( |
259 | NodeBuilder(NewName("branch_executed" ), "Merge" , graph_->op_registry(), |
260 | &debug_info_) |
261 | .Input({pivot_t_, pivot_f_}) |
262 | .ControlInputs({then_call_node_, else_call_node_}) |
263 | .Device(if_op_->requested_device()), |
264 | graph_, &branch_executed_node_)); |
265 | |
266 | TF_RETURN_IF_ERROR(BuildLoweredIfOutput()); |
267 | |
268 | // Add outputs. |
269 | for (const Edge* e : if_op_->out_edges()) { |
270 | if (e->IsControlEdge()) { |
271 | graph_->AddControlEdge(branch_executed_node_, e->dst()); |
272 | } else { |
273 | // Feed the outputs directly from the merge nodes so that downstream ops |
274 | // can start before all the outputs have been computed. |
275 | graph_->AddEdge(merges[e->src_output()], 0, e->dst(), e->dst_input()); |
276 | } |
277 | } |
278 | |
279 | return OkStatus(); |
280 | } |
281 | |
282 | Status CondBuilder::BuildLoweredIfOutput() { |
283 | // If outputs are empty, it means that we might have only output control |
284 | // edges (already connected to the `branch_executed_node`). Furthermore it's |
285 | // illegal to have an IdentityN with empty inputs. |
286 | // |
287 | // We still must keep lowered If node as a valid source of control edges, |
288 | // because it might be a part of function control output set. |
289 | NodeBuilder builder = keep_node_fetchable_ && !outputs_.empty() |
290 | ? NodeBuilder(name_, "IdentityN" ).Input(outputs_) |
291 | : NodeBuilder(name_, "NoOp" ); |
292 | |
293 | return builder.Device(if_op_->requested_device()) |
294 | .ControlInput(branch_executed_node_) |
295 | .Finalize(graph_, &lowered_if_output_); |
296 | } |
297 | |
298 | } // namespace |
299 | |
300 | Status RewriteIfNode(Node* n, Graph* g, bool keep_node_fetchable) { |
301 | VLOG(2) << "Lower If node (keep_node_fetchable=" << keep_node_fetchable |
302 | << "): " << SummarizeNode(*n); |
303 | |
304 | const AttrValue* then_attr = n->attrs().Find("then_branch" ); |
305 | if (then_attr == nullptr) { |
306 | return errors::InvalidArgument("Then branch function missing" ); |
307 | } |
308 | const AttrValue* else_attr = n->attrs().Find("else_branch" ); |
309 | if (else_attr == nullptr) { |
310 | return errors::InvalidArgument("Else branch function missing" ); |
311 | } |
312 | |
313 | CondBuilder cb(n, then_attr->func(), else_attr->func(), keep_node_fetchable, |
314 | g); |
315 | TF_RETURN_IF_ERROR(cb.CreatePivotNodes()); |
316 | TF_RETURN_IF_ERROR(cb.AddInputs()); |
317 | TF_RETURN_IF_ERROR(cb.AddOutputs()); |
318 | g->RemoveNode(n); |
319 | |
320 | return OkStatus(); |
321 | } |
322 | |
323 | } // namespace tensorflow |
324 | |