1 | /* Copyright 2018 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | |
16 | #include "tensorflow/core/common_runtime/lower_functional_ops.h" |
17 | |
18 | #include <string> |
19 | |
20 | #include "absl/container/flat_hash_set.h" |
21 | #include "tensorflow/core/common_runtime/device_propagation.h" |
22 | #include "tensorflow/core/common_runtime/function_utils.h" |
23 | #include "tensorflow/core/common_runtime/inline_function_utils.h" |
24 | #include "tensorflow/core/common_runtime/lower_case_op.h" |
25 | #include "tensorflow/core/common_runtime/lower_function_call_op.h" |
26 | #include "tensorflow/core/common_runtime/lower_if_op.h" |
27 | #include "tensorflow/core/common_runtime/lower_while_op.h" |
28 | #include "tensorflow/core/framework/node_def_util.h" |
29 | #include "tensorflow/core/graph/graph.h" |
30 | #include "tensorflow/core/graph/graph_node_util.h" |
31 | #include "tensorflow/core/public/session_options.h" |
32 | |
33 | namespace tensorflow { |
34 | |
35 | namespace { |
36 | |
// Local aliases for the attribute names declared on
// LowerFunctionalOpsConstants, to keep call sites in this file short.
constexpr const char* const kLowerUsingSwitchMergeAttr =
    LowerFunctionalOpsConstants::kLowerUsingSwitchMergeAttr;
constexpr const char* const kLowerAsMultiDeviceFunctionAttr =
    LowerFunctionalOpsConstants::kLowerAsMultiDeviceFunctionAttr;

// Nodes carrying either of these attributes are destined for TPU/XLA
// compilation, so this pass must leave them untouched (see used_by_xla below).
constexpr const char* const kTpuReplicateAttr = "_tpu_replicate" ;
constexpr const char* const kXlaClusterAttr = "_xla_compile_id" ;
44 | |
45 | // Checks if boolean attribute is defined and it's value is 'true'. |
46 | bool CheckBoolAttr(const Node* n, absl::string_view attr_name) { |
47 | bool match; |
48 | bool found = TryGetNodeAttr(n->attrs(), attr_name, &match); |
49 | return found && match; |
50 | } |
51 | |
52 | // Checks if string attribute is defined and it's not empty. |
53 | bool CheckStringAttr(const Node* n, absl::string_view attr_name) { |
54 | string match; |
55 | bool found = TryGetNodeAttr(n->attrs(), attr_name, &match); |
56 | return found && !match.empty(); |
57 | } |
58 | |
59 | bool LowerUsingSwitchMergeIsOn(const Node* n) { |
60 | return CheckBoolAttr(n, kLowerUsingSwitchMergeAttr); |
61 | } |
62 | |
63 | bool LowerAsMultiDeviceFunctionIsOn(const Node* n) { |
64 | return CheckBoolAttr(n, kLowerAsMultiDeviceFunctionAttr); |
65 | } |
66 | |
67 | bool MarkedForTpuCompilation(const Node* n) { |
68 | return CheckStringAttr(n, kTpuReplicateAttr); |
69 | } |
70 | |
71 | bool MarkedForXlaCompilation(const Node* n) { |
72 | return CheckStringAttr(n, kXlaClusterAttr); |
73 | } |
74 | |
75 | bool HasArgsOrRetvals(const Graph& g) { |
76 | for (const Node* n : g.op_nodes()) { |
77 | if (n->IsArg() || n->IsRetval()) return true; |
78 | } |
79 | return false; |
80 | } |
81 | |
82 | const absl::flat_hash_set<std::string>& DevicePropagationOpList() { |
83 | // Control flow ops and Identity ops which are inserted by function call |
84 | // inlining. |
85 | static const auto op_list = new absl::flat_hash_set<std::string>( |
86 | {"Identity" , "IdentityN" , "Enter" , "Exit" , "Switch" , "Merge" , |
87 | "NextIteration" }); |
88 | return *op_list; |
89 | } |
90 | |
91 | bool IsPropagatableDevice(StringPiece device_string) { |
92 | DeviceNameUtils::ParsedName device; |
93 | return DeviceNameUtils::ParseFullName(device_string, &device) && |
94 | device.type == DEVICE_TPU; |
95 | } |
96 | |
97 | } // namespace |
98 | |
// Lowers functional control-flow ops (If/Case/While) that request it into
// Switch/Merge-based control flow, and inlines function call nodes into the
// graph. Runs before placement/partitioning (see REGISTER_OPTIMIZATION below);
// returns an Internal error if invoked after partitioning.
Status LowerFunctionalOpsPass::Run(
    const GraphOptimizationPassOptions& options) {
  if (options.partition_graphs != nullptr) {
    return errors::Internal(
        "Lowering If/While ops should happen before partitioning." );
  }
  // No graph to operate on is not an error: there is simply nothing to lower.
  if (options.graph == nullptr) {
    return OkStatus();
  }

  Graph* g = options.graph->get();
  if (g == nullptr) {
    return errors::Internal(
        "Lowering While op requires a graph to be available." );
  }

  // The function library is required to recognize function call nodes and to
  // resolve the functions referenced by If/Case/While nodes.
  FunctionLibraryDefinition* flib_def = options.flib_def;
  if (flib_def == nullptr) {
    return errors::Internal(
        "Lowering If op requires a FunctionLibraryDefinition to be available." );
  }

  // Lower function calls only if it's explicitly enabled in session options.
  const bool lower_function_calls =
      options.session_options && options.session_options->config.graph_options()
                                     .optimizer_options()
                                     .do_function_inlining();

  // If graph is a function instantiation, it will have `_Arg` and `_Retval`
  // nodes for input and output tensors. Otherwise it's unsafe to remove any of
  // the nodes, because they might be later used as fetches.
  //
  // When we do not keep lowered nodes fetchable, we still add a NoOp node to
  // the graph with the same name as lowered node, because it might be used as a
  // control output source, and it's currently not expressed in a graph.
  bool keep_lowered_nodes_fetchable = !HasArgsOrRetvals(*g);

  // We disable lowering control flow to switch/merge variants when requested,
  // and for the single-threaded executor and TFRT runtime, which does not
  // support it.
  const bool functional_control_flow =
      options.session_options &&
      (options.session_options->config.experimental().executor_type() ==
           "SINGLE_THREADED_EXECUTOR" ||
       options.session_options->config.experimental().use_tfrt() ||
       options.session_options->config.experimental()
           .disable_functional_ops_lowering());

  // Returns true if `node` will be used for XLA compilation.
  const auto used_by_xla = [](Node* node) -> bool {
    return MarkedForTpuCompilation(node) || MarkedForXlaCompilation(node);
  };

  // Returns true if control flow `node` should be lowered to Switch/Merge.
  const auto lower_control_flow = [&](Node* node) -> bool {
    return LowerUsingSwitchMergeIsOn(node) && !used_by_xla(node);
  };

  // Lower all If, Case, While ops that have the `kLowerUsingSwitchMergeAttr`
  // attr set and inline all function calls into the graph.
  // We start at `i` = 2 to skip the source and sink nodes.
  // Note that `g->num_node_ids()` may change in the for body if a matching If,
  // Case, While node is lowered. Since new graph nodes are always added to the
  // end of the list of nodes it is ensured that nested If/Case/While nodes will
  // be lowered as well.
  int num_node_ids_before_lowering = g->num_node_ids();
  for (int i = 2; i < g->num_node_ids(); ++i) {
    Node* n = g->FindNodeId(i);
    if (n == nullptr) continue;  // deleted node

    // Always lower function calls produced by lowering If/While nodes.
    if (IsFunctionCall(*flib_def, *n) && !used_by_xla(n) &&
        (lower_function_calls || LowerAsMultiDeviceFunctionIsOn(n))) {
      TF_RETURN_IF_ERROR(RewriteFunctionCallNode(n, g, *flib_def,
                                                 keep_lowered_nodes_fetchable));
      continue;
    }

    // If we are allowed to use functional control flow, we do not need to
    // check for If/While/Case nodes in the graph.
    if (functional_control_flow) continue;

    if (n->IsIfNode() && lower_control_flow(n)) {
      TF_RETURN_IF_ERROR(RewriteIfNode(n, g, keep_lowered_nodes_fetchable));

    } else if (n->IsCaseNode() && lower_control_flow(n)) {
      TF_RETURN_IF_ERROR(RewriteCaseNode(n, g, keep_lowered_nodes_fetchable));

    } else if (n->IsWhileNode() && lower_control_flow(n)) {
      TF_RETURN_IF_ERROR(
          RewriteWhileNode(n, g, flib_def, keep_lowered_nodes_fetchable));

    } else {
      // Any other node type with the lowering attr set is a caller bug; only
      // If/Case/While (and function calls, handled above) support lowering.
      DCHECK(!lower_control_flow(n))
          << "Node " << FormatNodeForError(*n) << " of type "
          << n->type_string() << " has '"
          << LowerFunctionalOpsConstants::kLowerUsingSwitchMergeAttr
          << "' attr set but it does not support lowering.\n" ;
    }
  }

  // Propagates device assignments made inside a function body to the control
  // flow ops created by lowering that function call. This runs after the loop
  // above because If/Case/While node lowering happens before function call
  // lowering.
  PropagateDevices(
      [num_node_ids_before_lowering](const Node& n) {
        return DevicePropagationOpList().contains(n.type_string()) &&
               n.id() >= num_node_ids_before_lowering;  // Newly created nodes.
      },
      IsPropagatableDevice, g);

  return OkStatus();
}
212 | |
// Register at the PRE_PLACEMENT grouping (phase 10) so lowering runs before
// device placement and graph partitioning, as Run() requires.
REGISTER_OPTIMIZATION(OptimizationPassRegistry::PRE_PLACEMENT, 10,
                      LowerFunctionalOpsPass);
215 | |
216 | } // namespace tensorflow |
217 | |