1 | /* Copyright 2019 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | |
16 | #include "tensorflow/core/common_runtime/lower_case_op.h" |
17 | |
18 | #include "tensorflow/core/common_runtime/inline_function_utils.h" |
19 | #include "tensorflow/core/framework/node_def_builder.h" |
20 | #include "tensorflow/core/graph/graph.h" |
21 | #include "tensorflow/core/graph/node_builder.h" |
22 | #include "tensorflow/core/lib/core/errors.h" |
23 | |
24 | namespace tensorflow { |
25 | |
26 | namespace { |
27 | |
28 | using NodeOut = NodeBuilder::NodeOut; |
29 | |
30 | constexpr const char* const kLowerAsMultiDeviceFunctionAttr = |
31 | LowerFunctionalOpsConstants::kLowerAsMultiDeviceFunctionAttr; |
32 | |
33 | // Convenience builder to make it easy to construct a case with a single |
34 | // function call in each branch. This first converts the Case node |
35 | // into switches (for inputs) and merges (for outputs) around a function call |
36 | // per branch. |
37 | class CaseBuilder { |
38 | public: |
39 | // Create a CaseBuilder to create the lowered form of `case` with branch |
40 | // functions identified by `branch_fn_names` in the `graph`. |
41 | CaseBuilder(Node* case_op, const std::vector<string>& branch_fn_names, |
42 | bool keep_node_fetchable, Graph* graph); |
43 | |
44 | // Constructs the basic conditional control flow using switch and merge nodes. |
45 | Status CreatePivotNodes(); |
46 | |
47 | // Adds the inputs from the if node to the merge nodes of the lowered if. |
48 | Status AddInputs(); |
49 | |
50 | // Adds the outputs from the if node to the merge nodes of the lowered if. |
51 | // Note: no inputs can be added once outputs are added as the then and else |
52 | // nodes are finalized while adding outputs. |
53 | Status AddOutputs(); |
54 | |
55 | // Builds an identity node with the same outputs as Case. |
56 | Status BuildLoweredCaseOutput(); |
57 | |
58 | private: |
59 | // Returns unique name containing the name of the Case op being rewritten |
60 | // (name_), infix and a suffix to ensure it is unique within the graph. |
61 | string NewName(const string& infix); |
62 | |
63 | // Adds input to both the then and else nodes from src:src_output. |
64 | Status AddInput(Node* src, int src_output); |
65 | |
66 | // The merged outputs of the then and else nodes. |
67 | std::vector<NodeOut> outputs_; |
68 | |
69 | // The node that dominates all execution of the then and else body nodes. |
70 | Node* control_predecessor_; |
71 | // The original Case op. |
72 | Node* case_op_; |
73 | // The node with the same name as the original Case op: |
74 | // (a) IdentityN node with same outputs if 'keep_node_fetchable_ == true' |
75 | // and if the original Case op had non-zero data outputs. |
76 | // (b) NoOp node with control edge from 'branch_executed_node_' otherwise. |
77 | Node* lowered_case_output_; |
78 | // The branch selector of the case. |
79 | OutputTensor branch_index_; |
80 | int num_branches_; |
81 | // Nodes corresponding to pivot branch of branch_index _SwitchN, which is |
82 | // the pivot node that dominates all nodes in the i'th branch. |
83 | std::vector<Node*> pivots_; |
84 | std::vector<Node*> call_nodes_; |
85 | // Merge node that has inputs from each of pivots_ and control edges from |
86 | // [^call_node for call_node in call_nodes_]. This node will guarantee that |
87 | // even when branch functions do not have outputs, they still will be executed |
88 | // for the side effects. |
89 | Node* branch_executed_node_; |
90 | Graph* graph_; |
91 | string name_; |
92 | bool keep_node_fetchable_; |
93 | |
94 | NodeDebugInfo debug_info_; |
95 | std::vector<NodeBuilder> branch_call_builders_; |
96 | }; |
97 | |
98 | CaseBuilder::CaseBuilder(Node* case_op, |
99 | const std::vector<string>& branch_fn_names, |
100 | bool keep_node_fetchable, Graph* graph) |
101 | : case_op_(case_op), |
102 | num_branches_(branch_fn_names.size()), |
103 | graph_(graph), |
104 | name_(case_op->name()), |
105 | keep_node_fetchable_(keep_node_fetchable), |
106 | debug_info_(*case_op_) { |
107 | branch_call_builders_.reserve(num_branches_); |
108 | for (int b = 0; b < num_branches_; b++) { |
109 | branch_call_builders_.emplace_back(NewName(strings::StrCat("branch" , b)), |
110 | branch_fn_names[b], graph->op_registry(), |
111 | &debug_info_); |
112 | branch_call_builders_[b].Device(case_op_->requested_device()); |
113 | branch_call_builders_[b].Attr(kLowerAsMultiDeviceFunctionAttr, true); |
114 | } |
115 | TF_CHECK_OK(case_op_->input_tensor(0, &branch_index_)); |
116 | } |
117 | |
118 | Status CaseBuilder::CreatePivotNodes() { |
119 | // Construct the basic case body (consisting of feeding in the val to |
120 | // create pivot nodes). |
121 | Node* branch_index; |
122 | TF_RETURN_IF_ERROR(NodeBuilder(NewName("branch_index" ), "_SwitchN" , |
123 | graph_->op_registry(), &debug_info_) |
124 | .Input(NodeOut(branch_index_)) |
125 | .Input(NodeOut(branch_index_)) |
126 | .Attr("num_outs" , num_branches_) |
127 | .Device(case_op_->requested_device()) |
128 | .Finalize(graph_, &branch_index)); |
129 | control_predecessor_ = branch_index; |
130 | pivots_.resize(num_branches_, nullptr); |
131 | for (int b = 0; b < num_branches_; b++) { |
132 | TF_RETURN_IF_ERROR(NodeBuilder(NewName(strings::StrCat("pivot_" , b)), |
133 | "Identity" , graph_->op_registry(), |
134 | &debug_info_) |
135 | .Input(branch_index, b) |
136 | .Device(case_op_->requested_device()) |
137 | .Finalize(graph_, &pivots_[b])); |
138 | } |
139 | return OkStatus(); |
140 | } |
141 | |
142 | string CaseBuilder::NewName(const string& infix) { |
143 | return graph_->NewName(strings::StrCat(name_, "/" , infix)); |
144 | } |
145 | |
146 | Status CaseBuilder::AddInput(Node* src, int src_output) { |
147 | Node* input; |
148 | NodeDebugInfo debug_info(*src); |
149 | // Colocate the Switch node with the `src` node. |
150 | // |
151 | // This is to avoid unnecessary Host<->Device copies between src and the |
152 | // _SwitchN node. This aligns with the implementation of legacy tf.cond in |
153 | // control_flow_ops.py. The legacy impl colocates the Switch with the |
154 | // input tensor which resets the device stack and forces the Switch to have |
155 | // the same device as the input node (if set) and sets the colocation _class |
156 | // attr. It also ignores the existing colocation constraints on the input node |
157 | // using colocate_with(ignore_existing=True). |
158 | TF_RETURN_IF_ERROR(NodeBuilder(NewName(src->name()), "_SwitchN" , |
159 | graph_->op_registry(), &debug_info) |
160 | .Input(src, src_output) |
161 | .Input(branch_index_) |
162 | .Device(src->requested_device()) |
163 | .Attr("_class" , {src->name()}) |
164 | .Attr("num_outs" , num_branches_) |
165 | .Finalize(graph_, &input)); |
166 | for (int b = 0; b < num_branches_; b++) { |
167 | branch_call_builders_[b].Input(input, b); |
168 | } |
169 | return OkStatus(); |
170 | } |
171 | |
172 | Status CaseBuilder::AddInputs() { |
173 | // Add input data edges. |
174 | std::vector<const Edge*> edges; |
175 | TF_RETURN_IF_ERROR(case_op_->input_edges(&edges)); |
176 | // Start at index 1 as the first input is the branch index. |
177 | for (int i = 1; i < edges.size(); ++i) { |
178 | const Edge* e = edges[i]; |
179 | TF_RETURN_IF_ERROR(AddInput(e->src(), e->src_output())); |
180 | } |
181 | // Add input control edges. |
182 | for (const Edge* e : case_op_->in_edges()) { |
183 | if (e->IsControlEdge()) { |
184 | graph_->AddControlEdge(e->src(), control_predecessor_); |
185 | } |
186 | } |
187 | return OkStatus(); |
188 | } |
189 | |
190 | Status CaseBuilder::AddOutputs() { |
191 | // Construct the call nodes for each branch. |
192 | call_nodes_.resize(num_branches_, nullptr); |
193 | for (int b = 0; b < num_branches_; b++) { |
194 | TF_RETURN_IF_ERROR( |
195 | branch_call_builders_[b].Finalize(graph_, &call_nodes_[b])); |
196 | graph_->AddControlEdge(pivots_[b], call_nodes_[b]); |
197 | } |
198 | |
199 | // Merge the outputs from the N branches (all branches have matching outputs). |
200 | const int num_outputs = call_nodes_[0]->num_outputs(); |
201 | std::vector<Node*> merges(num_outputs); |
202 | outputs_.resize(merges.size()); |
203 | for (int i = 0; i < num_outputs; ++i) { |
204 | std::vector<NodeOut> merge_input; |
205 | merge_input.reserve(num_branches_); |
206 | for (int j = 0; j < num_branches_; j++) { |
207 | merge_input.emplace_back(call_nodes_[j], i); |
208 | } |
209 | TF_RETURN_IF_ERROR(NodeBuilder(NewName("merge" ), "Merge" , |
210 | graph_->op_registry(), &debug_info_) |
211 | .Input(merge_input) |
212 | .Device(case_op_->requested_device()) |
213 | .Finalize(graph_, &merges[i])); |
214 | outputs_[i] = NodeOut(merges[i], 0); |
215 | } |
216 | |
217 | // Add a Merge node that will be used as a control dependency source for the |
218 | // lowered output node. This Merge node will guarantee that lowered else/then |
219 | // function calls will be executed even if they do not have data outputs. |
220 | // |
221 | // Furthermore it will guarantee that all function side effects will be |
222 | // executed, if the function will be inlined into the graph. Having data |
223 | // outputs is not enough, because they might become unused after inlining. |
224 | // |
225 | // We will use this node to rewrite outgoing control edges from lowered 'Case' |
226 | // node. All data edges will read tensors directly from Merge nodes. |
227 | std::vector<NodeOut> pivots(num_branches_); |
228 | for (int j = 0; j < num_branches_; j++) { |
229 | pivots[j] = NodeOut(pivots_[j]); |
230 | } |
231 | TF_RETURN_IF_ERROR(NodeBuilder(NewName("branch_executed" ), "Merge" , |
232 | graph_->op_registry(), &debug_info_) |
233 | .Input(pivots) |
234 | .ControlInputs(call_nodes_) |
235 | .Device(case_op_->requested_device()) |
236 | .Finalize(graph_, &branch_executed_node_)); |
237 | |
238 | TF_RETURN_IF_ERROR(BuildLoweredCaseOutput()); |
239 | |
240 | // Add outputs. |
241 | for (const Edge* e : case_op_->out_edges()) { |
242 | if (e->IsControlEdge()) { |
243 | graph_->AddControlEdge(branch_executed_node_, e->dst()); |
244 | } else { |
245 | // Feed the outputs directly from the merge nodes so that downstream ops |
246 | // can start before all the outputs have been computed. |
247 | graph_->AddEdge(merges[e->src_output()], 0, e->dst(), e->dst_input()); |
248 | } |
249 | } |
250 | return OkStatus(); |
251 | } |
252 | |
253 | Status CaseBuilder::BuildLoweredCaseOutput() { |
254 | // If outputs are empty, it means that we might have only output control |
255 | // edges (already connected to the `branch_executed_node`). Furthermore it's |
256 | // illegal to have an IdentityN with empty inputs. |
257 | // |
258 | // We still must keep lowered Case node as a valid source of control edges, |
259 | // because it might be a part of function control output set. |
260 | NodeBuilder builder = keep_node_fetchable_ && !outputs_.empty() |
261 | ? NodeBuilder(name_, "IdentityN" ).Input(outputs_) |
262 | : NodeBuilder(name_, "NoOp" ); |
263 | return builder.Device(case_op_->requested_device()) |
264 | .ControlInput(branch_executed_node_) |
265 | .Finalize(graph_, &lowered_case_output_); |
266 | } |
267 | |
268 | } // namespace |
269 | |
270 | Status RewriteCaseNode(Node* n, Graph* g, bool keep_node_fetchable) { |
271 | VLOG(2) << "Lower Case node (keep_node_fetchable=" << keep_node_fetchable |
272 | << "): " << SummarizeNode(*n); |
273 | const AttrValue* branches_attr = n->attrs().Find("branches" ); |
274 | if (branches_attr == nullptr) { |
275 | return errors::InvalidArgument("branch functions missing" ); |
276 | } |
277 | |
278 | int num_branches = branches_attr->list().func_size(); |
279 | std::vector<string> branch_fn_names; |
280 | branch_fn_names.reserve(num_branches); |
281 | for (int b = 0; b < num_branches; b++) { |
282 | branch_fn_names.emplace_back(branches_attr->list().func(b).name()); |
283 | } |
284 | CaseBuilder cb(n, branch_fn_names, keep_node_fetchable, g); |
285 | TF_RETURN_IF_ERROR(cb.CreatePivotNodes()); |
286 | TF_RETURN_IF_ERROR(cb.AddInputs()); |
287 | TF_RETURN_IF_ERROR(cb.AddOutputs()); |
288 | g->RemoveNode(n); |
289 | |
290 | return OkStatus(); |
291 | } |
292 | |
293 | } // namespace tensorflow |
294 | |