1/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2
3Licensed under the Apache License, Version 2.0 (the "License");
4you may not use this file except in compliance with the License.
5You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9Unless required by applicable law or agreed to in writing, software
10distributed under the License is distributed on an "AS IS" BASIS,
11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12See the License for the specific language governing permissions and
13limitations under the License.
14==============================================================================*/
15
16#include "tensorflow/core/common_runtime/lower_functional_ops.h"
17
18#include <string>
19
20#include "absl/container/flat_hash_set.h"
21#include "tensorflow/core/common_runtime/device_propagation.h"
22#include "tensorflow/core/common_runtime/function_utils.h"
23#include "tensorflow/core/common_runtime/inline_function_utils.h"
24#include "tensorflow/core/common_runtime/lower_case_op.h"
25#include "tensorflow/core/common_runtime/lower_function_call_op.h"
26#include "tensorflow/core/common_runtime/lower_if_op.h"
27#include "tensorflow/core/common_runtime/lower_while_op.h"
28#include "tensorflow/core/framework/node_def_util.h"
29#include "tensorflow/core/graph/graph.h"
30#include "tensorflow/core/graph/graph_node_util.h"
31#include "tensorflow/core/public/session_options.h"
32
33namespace tensorflow {
34
35namespace {
36
// Short aliases for the attribute names declared on
// LowerFunctionalOpsConstants.
constexpr const char* const kLowerUsingSwitchMergeAttr =
    LowerFunctionalOpsConstants::kLowerUsingSwitchMergeAttr;
constexpr const char* const kLowerAsMultiDeviceFunctionAttr =
    LowerFunctionalOpsConstants::kLowerAsMultiDeviceFunctionAttr;

// Nodes carrying these attributes belong to TPU/XLA compilation clusters and
// are skipped by this pass (see `used_by_xla` in Run below).
constexpr const char* const kTpuReplicateAttr = "_tpu_replicate";
constexpr const char* const kXlaClusterAttr = "_xla_compile_id";
44
45// Checks if boolean attribute is defined and it's value is 'true'.
46bool CheckBoolAttr(const Node* n, absl::string_view attr_name) {
47 bool match;
48 bool found = TryGetNodeAttr(n->attrs(), attr_name, &match);
49 return found && match;
50}
51
52// Checks if string attribute is defined and it's not empty.
53bool CheckStringAttr(const Node* n, absl::string_view attr_name) {
54 string match;
55 bool found = TryGetNodeAttr(n->attrs(), attr_name, &match);
56 return found && !match.empty();
57}
58
// Returns true iff the node requests lowering to Switch/Merge control flow
// (boolean attr `kLowerUsingSwitchMergeAttr` is present and true).
bool LowerUsingSwitchMergeIsOn(const Node* n) {
  return CheckBoolAttr(n, kLowerUsingSwitchMergeAttr);
}
62
// Returns true iff the node requests inlining as a multi-device function
// (boolean attr `kLowerAsMultiDeviceFunctionAttr` is present and true).
bool LowerAsMultiDeviceFunctionIsOn(const Node* n) {
  return CheckBoolAttr(n, kLowerAsMultiDeviceFunctionAttr);
}
66
// Returns true iff the node carries a non-empty `_tpu_replicate` attr, i.e.
// it is assigned to a cluster destined for TPU compilation.
bool MarkedForTpuCompilation(const Node* n) {
  return CheckStringAttr(n, kTpuReplicateAttr);
}
70
// Returns true iff the node carries a non-empty `_xla_compile_id` attr, i.e.
// it is assigned to a cluster destined for XLA compilation.
bool MarkedForXlaCompilation(const Node* n) {
  return CheckStringAttr(n, kXlaClusterAttr);
}
74
75bool HasArgsOrRetvals(const Graph& g) {
76 for (const Node* n : g.op_nodes()) {
77 if (n->IsArg() || n->IsRetval()) return true;
78 }
79 return false;
80}
81
82const absl::flat_hash_set<std::string>& DevicePropagationOpList() {
83 // Control flow ops and Identity ops which are inserted by function call
84 // inlining.
85 static const auto op_list = new absl::flat_hash_set<std::string>(
86 {"Identity", "IdentityN", "Enter", "Exit", "Switch", "Merge",
87 "NextIteration"});
88 return *op_list;
89}
90
91bool IsPropagatableDevice(StringPiece device_string) {
92 DeviceNameUtils::ParsedName device;
93 return DeviceNameUtils::ParseFullName(device_string, &device) &&
94 device.type == DEVICE_TPU;
95}
96
97} // namespace
98
// Lowers functional control flow ops (If/Case/While) into Switch/Merge form
// and inlines function call nodes into `options.graph`, then propagates TPU
// device assignments onto the nodes created by lowering.
Status LowerFunctionalOpsPass::Run(
    const GraphOptimizationPassOptions& options) {
  // This pass must see the whole graph: lowering after partitioning would
  // split a single control-flow construct across partitions.
  if (options.partition_graphs != nullptr) {
    return errors::Internal(
        "Lowering If/While ops should happen before partitioning.");
  }
  if (options.graph == nullptr) {
    return OkStatus();
  }

  Graph* g = options.graph->get();
  if (g == nullptr) {
    return errors::Internal(
        "Lowering While op requires a graph to be available.");
  }

  FunctionLibraryDefinition* flib_def = options.flib_def;
  if (flib_def == nullptr) {
    return errors::Internal(
        "Lowering If op requires a FunctionLibraryDefinition to be available.");
  }

  // Lower function calls only if it's explicitly enabled in session options.
  const bool lower_function_calls =
      options.session_options && options.session_options->config.graph_options()
                                     .optimizer_options()
                                     .do_function_inlining();

  // If the graph is a function instantiation, it will have `_Arg` and `_Retval`
  // nodes for input and output tensors. Otherwise it's unsafe to remove any of
  // the nodes, because they might later be used as fetches.
  //
  // When we do not keep lowered nodes fetchable, we still add a NoOp node to
  // the graph with the same name as the lowered node, because it might be used
  // as a control output source, and that is currently not expressed in a graph.
  bool keep_lowered_nodes_fetchable = !HasArgsOrRetvals(*g);

  // We disable lowering control flow to switch/merge variants when requested,
  // and for the single-threaded executor and TFRT runtime, which do not
  // support it.
  const bool functional_control_flow =
      options.session_options &&
      (options.session_options->config.experimental().executor_type() ==
           "SINGLE_THREADED_EXECUTOR" ||
       options.session_options->config.experimental().use_tfrt() ||
       options.session_options->config.experimental()
           .disable_functional_ops_lowering());

  // Returns true if `node` will be used for XLA compilation.
  const auto used_by_xla = [](Node* node) -> bool {
    return MarkedForTpuCompilation(node) || MarkedForXlaCompilation(node);
  };

  // Returns true if control flow `node` should be lowered to Switch/Merge.
  const auto lower_control_flow = [&](Node* node) -> bool {
    return LowerUsingSwitchMergeIsOn(node) && !used_by_xla(node);
  };

  // Lower all If, Case, While ops that have the `kLowerUsingSwitchMergeAttr`
  // attr set and inline all function calls into the graph.
  // We start at `i` = 2 to skip the source and sink nodes.
  // Note that `g->num_node_ids()` may change in the for body if a matching If,
  // Case, While node is lowered. Since new graph nodes are always added to the
  // end of the list of nodes it is ensured that nested If/Case/While nodes will
  // be lowered as well.
  int num_node_ids_before_lowering = g->num_node_ids();
  for (int i = 2; i < g->num_node_ids(); ++i) {
    Node* n = g->FindNodeId(i);
    if (n == nullptr) continue;  // deleted node

    // Always lower function calls produced by lowering If/While nodes.
    if (IsFunctionCall(*flib_def, *n) && !used_by_xla(n) &&
        (lower_function_calls || LowerAsMultiDeviceFunctionIsOn(n))) {
      TF_RETURN_IF_ERROR(RewriteFunctionCallNode(n, g, *flib_def,
                                                 keep_lowered_nodes_fetchable));
      continue;
    }

    // If we are allowed to use functional control flow, we do not need to
    // check for If/While/Case nodes in the graph.
    if (functional_control_flow) continue;

    if (n->IsIfNode() && lower_control_flow(n)) {
      TF_RETURN_IF_ERROR(RewriteIfNode(n, g, keep_lowered_nodes_fetchable));

    } else if (n->IsCaseNode() && lower_control_flow(n)) {
      TF_RETURN_IF_ERROR(RewriteCaseNode(n, g, keep_lowered_nodes_fetchable));

    } else if (n->IsWhileNode() && lower_control_flow(n)) {
      TF_RETURN_IF_ERROR(
          RewriteWhileNode(n, g, flib_def, keep_lowered_nodes_fetchable));

    } else {
      // Any other node type with the lowering attr set is a caller bug.
      DCHECK(!lower_control_flow(n))
          << "Node " << FormatNodeForError(*n) << " of type "
          << n->type_string() << " has '"
          << LowerFunctionalOpsConstants::kLowerUsingSwitchMergeAttr
          << "' attr set but it does not support lowering.\n";
    }
  }

  // Propagates device assignments inside a function call to control flow ops
  // after the function call is lowered, because If/Case/While node lowering
  // happens before function call lowering.
  PropagateDevices(
      [num_node_ids_before_lowering](const Node& n) {
        return DevicePropagationOpList().contains(n.type_string()) &&
               n.id() >= num_node_ids_before_lowering;  // Newly created nodes.
      },
      IsPropagatableDevice, g);

  return OkStatus();
}
212
// Register this pass in the PRE_PLACEMENT phase (priority 10) so that
// functional ops are lowered before the graph is placed and partitioned.
REGISTER_OPTIMIZATION(OptimizationPassRegistry::PRE_PLACEMENT, 10,
                      LowerFunctionalOpsPass);
215
216} // namespace tensorflow
217