1/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2
3Licensed under the Apache License, Version 2.0 (the "License");
4you may not use this file except in compliance with the License.
5You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9Unless required by applicable law or agreed to in writing, software
10distributed under the License is distributed on an "AS IS" BASIS,
11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12See the License for the specific language governing permissions and
13limitations under the License.
14==============================================================================*/
15
16#include "tensorflow/core/common_runtime/lower_case_op.h"
17
18#include "tensorflow/core/common_runtime/inline_function_utils.h"
19#include "tensorflow/core/framework/node_def_builder.h"
20#include "tensorflow/core/graph/graph.h"
21#include "tensorflow/core/graph/node_builder.h"
22#include "tensorflow/core/lib/core/errors.h"
23
24namespace tensorflow {
25
26namespace {
27
28using NodeOut = NodeBuilder::NodeOut;
29
30constexpr const char* const kLowerAsMultiDeviceFunctionAttr =
31 LowerFunctionalOpsConstants::kLowerAsMultiDeviceFunctionAttr;
32
33// Convenience builder to make it easy to construct a case with a single
34// function call in each branch. This first converts the Case node
35// into switches (for inputs) and merges (for outputs) around a function call
36// per branch.
37class CaseBuilder {
38 public:
39 // Create a CaseBuilder to create the lowered form of `case` with branch
40 // functions identified by `branch_fn_names` in the `graph`.
41 CaseBuilder(Node* case_op, const std::vector<string>& branch_fn_names,
42 bool keep_node_fetchable, Graph* graph);
43
44 // Constructs the basic conditional control flow using switch and merge nodes.
45 Status CreatePivotNodes();
46
47 // Adds the inputs from the if node to the merge nodes of the lowered if.
48 Status AddInputs();
49
50 // Adds the outputs from the if node to the merge nodes of the lowered if.
51 // Note: no inputs can be added once outputs are added as the then and else
52 // nodes are finalized while adding outputs.
53 Status AddOutputs();
54
55 // Builds an identity node with the same outputs as Case.
56 Status BuildLoweredCaseOutput();
57
58 private:
59 // Returns unique name containing the name of the Case op being rewritten
60 // (name_), infix and a suffix to ensure it is unique within the graph.
61 string NewName(const string& infix);
62
63 // Adds input to both the then and else nodes from src:src_output.
64 Status AddInput(Node* src, int src_output);
65
66 // The merged outputs of the then and else nodes.
67 std::vector<NodeOut> outputs_;
68
69 // The node that dominates all execution of the then and else body nodes.
70 Node* control_predecessor_;
71 // The original Case op.
72 Node* case_op_;
73 // The node with the same name as the original Case op:
74 // (a) IdentityN node with same outputs if 'keep_node_fetchable_ == true'
75 // and if the original Case op had non-zero data outputs.
76 // (b) NoOp node with control edge from 'branch_executed_node_' otherwise.
77 Node* lowered_case_output_;
78 // The branch selector of the case.
79 OutputTensor branch_index_;
80 int num_branches_;
81 // Nodes corresponding to pivot branch of branch_index _SwitchN, which is
82 // the pivot node that dominates all nodes in the i'th branch.
83 std::vector<Node*> pivots_;
84 std::vector<Node*> call_nodes_;
85 // Merge node that has inputs from each of pivots_ and control edges from
86 // [^call_node for call_node in call_nodes_]. This node will guarantee that
87 // even when branch functions do not have outputs, they still will be executed
88 // for the side effects.
89 Node* branch_executed_node_;
90 Graph* graph_;
91 string name_;
92 bool keep_node_fetchable_;
93
94 NodeDebugInfo debug_info_;
95 std::vector<NodeBuilder> branch_call_builders_;
96};
97
98CaseBuilder::CaseBuilder(Node* case_op,
99 const std::vector<string>& branch_fn_names,
100 bool keep_node_fetchable, Graph* graph)
101 : case_op_(case_op),
102 num_branches_(branch_fn_names.size()),
103 graph_(graph),
104 name_(case_op->name()),
105 keep_node_fetchable_(keep_node_fetchable),
106 debug_info_(*case_op_) {
107 branch_call_builders_.reserve(num_branches_);
108 for (int b = 0; b < num_branches_; b++) {
109 branch_call_builders_.emplace_back(NewName(strings::StrCat("branch", b)),
110 branch_fn_names[b], graph->op_registry(),
111 &debug_info_);
112 branch_call_builders_[b].Device(case_op_->requested_device());
113 branch_call_builders_[b].Attr(kLowerAsMultiDeviceFunctionAttr, true);
114 }
115 TF_CHECK_OK(case_op_->input_tensor(0, &branch_index_));
116}
117
118Status CaseBuilder::CreatePivotNodes() {
119 // Construct the basic case body (consisting of feeding in the val to
120 // create pivot nodes).
121 Node* branch_index;
122 TF_RETURN_IF_ERROR(NodeBuilder(NewName("branch_index"), "_SwitchN",
123 graph_->op_registry(), &debug_info_)
124 .Input(NodeOut(branch_index_))
125 .Input(NodeOut(branch_index_))
126 .Attr("num_outs", num_branches_)
127 .Device(case_op_->requested_device())
128 .Finalize(graph_, &branch_index));
129 control_predecessor_ = branch_index;
130 pivots_.resize(num_branches_, nullptr);
131 for (int b = 0; b < num_branches_; b++) {
132 TF_RETURN_IF_ERROR(NodeBuilder(NewName(strings::StrCat("pivot_", b)),
133 "Identity", graph_->op_registry(),
134 &debug_info_)
135 .Input(branch_index, b)
136 .Device(case_op_->requested_device())
137 .Finalize(graph_, &pivots_[b]));
138 }
139 return OkStatus();
140}
141
142string CaseBuilder::NewName(const string& infix) {
143 return graph_->NewName(strings::StrCat(name_, "/", infix));
144}
145
146Status CaseBuilder::AddInput(Node* src, int src_output) {
147 Node* input;
148 NodeDebugInfo debug_info(*src);
149 // Colocate the Switch node with the `src` node.
150 //
151 // This is to avoid unnecessary Host<->Device copies between src and the
152 // _SwitchN node. This aligns with the implementation of legacy tf.cond in
153 // control_flow_ops.py. The legacy impl colocates the Switch with the
154 // input tensor which resets the device stack and forces the Switch to have
155 // the same device as the input node (if set) and sets the colocation _class
156 // attr. It also ignores the existing colocation constraints on the input node
157 // using colocate_with(ignore_existing=True).
158 TF_RETURN_IF_ERROR(NodeBuilder(NewName(src->name()), "_SwitchN",
159 graph_->op_registry(), &debug_info)
160 .Input(src, src_output)
161 .Input(branch_index_)
162 .Device(src->requested_device())
163 .Attr("_class", {src->name()})
164 .Attr("num_outs", num_branches_)
165 .Finalize(graph_, &input));
166 for (int b = 0; b < num_branches_; b++) {
167 branch_call_builders_[b].Input(input, b);
168 }
169 return OkStatus();
170}
171
172Status CaseBuilder::AddInputs() {
173 // Add input data edges.
174 std::vector<const Edge*> edges;
175 TF_RETURN_IF_ERROR(case_op_->input_edges(&edges));
176 // Start at index 1 as the first input is the branch index.
177 for (int i = 1; i < edges.size(); ++i) {
178 const Edge* e = edges[i];
179 TF_RETURN_IF_ERROR(AddInput(e->src(), e->src_output()));
180 }
181 // Add input control edges.
182 for (const Edge* e : case_op_->in_edges()) {
183 if (e->IsControlEdge()) {
184 graph_->AddControlEdge(e->src(), control_predecessor_);
185 }
186 }
187 return OkStatus();
188}
189
190Status CaseBuilder::AddOutputs() {
191 // Construct the call nodes for each branch.
192 call_nodes_.resize(num_branches_, nullptr);
193 for (int b = 0; b < num_branches_; b++) {
194 TF_RETURN_IF_ERROR(
195 branch_call_builders_[b].Finalize(graph_, &call_nodes_[b]));
196 graph_->AddControlEdge(pivots_[b], call_nodes_[b]);
197 }
198
199 // Merge the outputs from the N branches (all branches have matching outputs).
200 const int num_outputs = call_nodes_[0]->num_outputs();
201 std::vector<Node*> merges(num_outputs);
202 outputs_.resize(merges.size());
203 for (int i = 0; i < num_outputs; ++i) {
204 std::vector<NodeOut> merge_input;
205 merge_input.reserve(num_branches_);
206 for (int j = 0; j < num_branches_; j++) {
207 merge_input.emplace_back(call_nodes_[j], i);
208 }
209 TF_RETURN_IF_ERROR(NodeBuilder(NewName("merge"), "Merge",
210 graph_->op_registry(), &debug_info_)
211 .Input(merge_input)
212 .Device(case_op_->requested_device())
213 .Finalize(graph_, &merges[i]));
214 outputs_[i] = NodeOut(merges[i], 0);
215 }
216
217 // Add a Merge node that will be used as a control dependency source for the
218 // lowered output node. This Merge node will guarantee that lowered else/then
219 // function calls will be executed even if they do not have data outputs.
220 //
221 // Furthermore it will guarantee that all function side effects will be
222 // executed, if the function will be inlined into the graph. Having data
223 // outputs is not enough, because they might become unused after inlining.
224 //
225 // We will use this node to rewrite outgoing control edges from lowered 'Case'
226 // node. All data edges will read tensors directly from Merge nodes.
227 std::vector<NodeOut> pivots(num_branches_);
228 for (int j = 0; j < num_branches_; j++) {
229 pivots[j] = NodeOut(pivots_[j]);
230 }
231 TF_RETURN_IF_ERROR(NodeBuilder(NewName("branch_executed"), "Merge",
232 graph_->op_registry(), &debug_info_)
233 .Input(pivots)
234 .ControlInputs(call_nodes_)
235 .Device(case_op_->requested_device())
236 .Finalize(graph_, &branch_executed_node_));
237
238 TF_RETURN_IF_ERROR(BuildLoweredCaseOutput());
239
240 // Add outputs.
241 for (const Edge* e : case_op_->out_edges()) {
242 if (e->IsControlEdge()) {
243 graph_->AddControlEdge(branch_executed_node_, e->dst());
244 } else {
245 // Feed the outputs directly from the merge nodes so that downstream ops
246 // can start before all the outputs have been computed.
247 graph_->AddEdge(merges[e->src_output()], 0, e->dst(), e->dst_input());
248 }
249 }
250 return OkStatus();
251}
252
253Status CaseBuilder::BuildLoweredCaseOutput() {
254 // If outputs are empty, it means that we might have only output control
255 // edges (already connected to the `branch_executed_node`). Furthermore it's
256 // illegal to have an IdentityN with empty inputs.
257 //
258 // We still must keep lowered Case node as a valid source of control edges,
259 // because it might be a part of function control output set.
260 NodeBuilder builder = keep_node_fetchable_ && !outputs_.empty()
261 ? NodeBuilder(name_, "IdentityN").Input(outputs_)
262 : NodeBuilder(name_, "NoOp");
263 return builder.Device(case_op_->requested_device())
264 .ControlInput(branch_executed_node_)
265 .Finalize(graph_, &lowered_case_output_);
266}
267
268} // namespace
269
270Status RewriteCaseNode(Node* n, Graph* g, bool keep_node_fetchable) {
271 VLOG(2) << "Lower Case node (keep_node_fetchable=" << keep_node_fetchable
272 << "): " << SummarizeNode(*n);
273 const AttrValue* branches_attr = n->attrs().Find("branches");
274 if (branches_attr == nullptr) {
275 return errors::InvalidArgument("branch functions missing");
276 }
277
278 int num_branches = branches_attr->list().func_size();
279 std::vector<string> branch_fn_names;
280 branch_fn_names.reserve(num_branches);
281 for (int b = 0; b < num_branches; b++) {
282 branch_fn_names.emplace_back(branches_attr->list().func(b).name());
283 }
284 CaseBuilder cb(n, branch_fn_names, keep_node_fetchable, g);
285 TF_RETURN_IF_ERROR(cb.CreatePivotNodes());
286 TF_RETURN_IF_ERROR(cb.AddInputs());
287 TF_RETURN_IF_ERROR(cb.AddOutputs());
288 g->RemoveNode(n);
289
290 return OkStatus();
291}
292
293} // namespace tensorflow
294