1 | /* Copyright 2015 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | |
16 | #include "tensorflow/core/common_runtime/inline_function_utils.h" |
17 | |
18 | #include <deque> |
19 | #include <vector> |
20 | |
21 | #include "absl/algorithm/container.h" |
22 | #include "absl/memory/memory.h" |
23 | #include "absl/strings/str_cat.h" |
24 | #include "absl/strings/string_view.h" |
25 | #include "tensorflow/core/common_runtime/device.h" |
26 | #include "tensorflow/core/common_runtime/function_utils.h" |
27 | #include "tensorflow/core/common_runtime/graph_constructor.h" |
28 | #include "tensorflow/core/framework/collective.h" |
29 | #include "tensorflow/core/framework/function.h" |
30 | #include "tensorflow/core/framework/node_def.pb.h" |
31 | #include "tensorflow/core/framework/node_def_util.h" |
32 | #include "tensorflow/core/framework/op.h" |
33 | #include "tensorflow/core/framework/op_kernel.h" |
34 | #include "tensorflow/core/framework/versions.pb.h" |
35 | #include "tensorflow/core/graph/algorithm.h" |
36 | #include "tensorflow/core/graph/control_flow.h" |
37 | #include "tensorflow/core/graph/node_builder.h" |
38 | #include "tensorflow/core/graph/optimizer_cse.h" |
39 | #include "tensorflow/core/lib/core/threadpool.h" |
40 | #include "tensorflow/core/lib/gtl/map_util.h" |
41 | #include "tensorflow/core/platform/macros.h" |
42 | #include "tensorflow/core/profiler/lib/traceme.h" |
43 | #include "tensorflow/core/protobuf/config.pb.h" |
44 | #include "tensorflow/core/util/device_name_utils.h" |
45 | |
46 | namespace tensorflow { |
47 | |
// Namespace-scope definitions of the static constexpr string members of
// LowerFunctionalOpsConstants. No initializer is given here, so the values
// come from the in-class declarations. Before C++17 such definitions are
// required whenever the members are odr-used; from C++17 on the members are
// implicitly inline and these definitions are redundant.
/*static*/ constexpr const char* const
    LowerFunctionalOpsConstants::kLowerUsingSwitchMergeAttr;
/*static*/ constexpr const char* const
    LowerFunctionalOpsConstants::kLowerAsMultiDeviceFunctionAttr;
52 | |
53 | namespace { |
54 | // A few string constant used throughout this module. |
55 | static constexpr const char* const kArgOp = FunctionLibraryDefinition::kArgOp; |
56 | static constexpr const char* const kDeviceArgOp = |
57 | FunctionLibraryDefinition::kDeviceArgOp; |
58 | static constexpr const char* const kRetOp = FunctionLibraryDefinition::kRetOp; |
59 | static constexpr const char* const kDeviceRetOp = |
60 | FunctionLibraryDefinition::kDeviceRetOp; |
61 | static constexpr const char* const kGradientOp = |
62 | FunctionLibraryDefinition::kGradientOp; |
63 | static constexpr const char* const kNodeLabel = "Func" ; |
64 | static constexpr const char* const kFuncAttr = |
65 | FunctionLibraryDefinition::kFuncAttr; |
66 | |
67 | // Represents the index-th output of a node. |
68 | struct Endpoint { |
69 | Node* node; |
70 | int index; |
71 | |
72 | // Returns the string name represents this endpoint. |
73 | string name() const { |
74 | if (index == 0) { |
75 | return node->name(); |
76 | } else { |
77 | return strings::StrCat(node->name(), ":" , index); |
78 | } |
79 | } |
80 | |
81 | DataType dtype() const { return node->output_type(index); } |
82 | }; |
83 | |
84 | struct EndpointHash { |
85 | uint64 operator()(const Endpoint& x) const { |
86 | return Hash64(reinterpret_cast<const char*>(&x.node), sizeof(Node*), |
87 | x.index); |
88 | } |
89 | }; |
90 | |
91 | struct EndpointEq { |
92 | bool operator()(const Endpoint& x, const Endpoint& y) const { |
93 | return (x.node == y.node) && (x.index == y.index); |
94 | } |
95 | }; |
96 | |
97 | // The following Add* routines are used to add a few graph nodes while |
98 | // functions are transformed. |
99 | static Node* AddNoOp(StringPiece name, Graph* g) { |
100 | NodeDef ndef; |
101 | ndef.set_name(g->NewName(absl::StrCat(kNodeLabel, "/" , name))); |
102 | ndef.set_op("NoOp" ); |
103 | Status s; |
104 | Node* ret = g->AddNode(ndef, &s); |
105 | TF_CHECK_OK(s); |
106 | return ret; |
107 | } |
108 | |
109 | static Node* AddIdentity(StringPiece name, Graph* g, Endpoint input) { |
110 | DCHECK_LT(0, input.dtype()); |
111 | NodeDef ndef; |
112 | ndef.set_name(g->NewName(absl::StrCat(kNodeLabel, "/" , name))); |
113 | ndef.set_op("Identity" ); |
114 | ndef.add_input(input.name()); |
115 | AddNodeAttr("T" , BaseType(input.dtype()), &ndef); |
116 | Status s; |
117 | Node* ret = g->AddNode(ndef, &s); |
118 | TF_CHECK_OK(s); |
119 | g->AddEdge(input.node, input.index, ret, 0); |
120 | return ret; |
121 | } |
122 | |
123 | std::vector<string> InputDevices(const Node& caller) { |
124 | std::vector<string> input_devices(caller.in_edges().size()); |
125 | std::vector<string> input_tensors(caller.in_edges().size()); |
126 | |
127 | for (const Edge* edge : caller.in_edges()) { |
128 | if (edge->IsControlEdge()) continue; |
129 | const string& input_device = edge->src()->has_assigned_device_name() |
130 | ? edge->src()->assigned_device_name() |
131 | : edge->src()->requested_device(); |
132 | input_devices[edge->dst_input()] = input_device; |
133 | input_tensors[edge->dst_input()] = |
134 | absl::StrCat(edge->src()->name(), ":" , edge->src_output()); |
135 | } |
136 | |
137 | if (VLOG_IS_ON(4)) { |
138 | VLOG(4) << "Function instantiation input devices:" ; |
139 | for (int i = 0; i < input_devices.size(); ++i) { |
140 | if (input_tensors[i].empty()) continue; // skip control edges |
141 | VLOG(4) << " [index " << i << "]" |
142 | << " device: " << input_devices[i] |
143 | << " (input: " << input_tensors[i] << ")" ; |
144 | } |
145 | } |
146 | |
147 | return input_devices; |
148 | } |
149 | |
150 | // Place input nodes on the same device as the corresponding caller input |
151 | // node. Do not specify any placement for all other nodes. |
152 | class DefaultFunctionBodyPlacer : public InlinedFunctionBodyPlacer { |
153 | public: |
154 | explicit DefaultFunctionBodyPlacer(const Node& caller) |
155 | : input_devices_(InputDevices(caller)) {} |
156 | |
157 | absl::optional<string> InputNodeDevice(int input_index) const override { |
158 | return input_devices_[input_index]; |
159 | } |
160 | absl::optional<string> OutputNodeDevice(int output_index) const override { |
161 | return absl::nullopt; |
162 | } |
163 | bool ColocateInputOutputIdentities() const override { return false; } |
164 | absl::optional<string> ControlNodeDevice() const override { |
165 | return absl::nullopt; |
166 | } |
167 | absl::optional<string> BodyNodeDevice(const NodeDef& ndef) const override { |
168 | return absl::nullopt; |
169 | } |
170 | |
171 | private: |
172 | const std::vector<string> input_devices_; |
173 | }; |
174 | |
175 | // Place all nodes on the same device as caller node. |
176 | class SingleDeviceFunctionBodyPlacer : public InlinedFunctionBodyPlacer { |
177 | public: |
178 | explicit SingleDeviceFunctionBodyPlacer(const Node& caller) |
179 | : caller_device_(caller.def().device()) {} |
180 | |
181 | absl::optional<string> InputNodeDevice(int input_index) const override { |
182 | return caller_device_; |
183 | } |
184 | absl::optional<string> OutputNodeDevice(int output_index) const override { |
185 | return caller_device_; |
186 | } |
187 | bool ColocateInputOutputIdentities() const override { return false; } |
188 | absl::optional<string> ControlNodeDevice() const override { |
189 | return caller_device_; |
190 | } |
191 | absl::optional<string> BodyNodeDevice(const NodeDef& ndef) const override { |
192 | return caller_device_; |
193 | } |
194 | |
195 | private: |
196 | const string caller_device_; |
197 | }; |
198 | |
199 | // Place input nodes on the same device as the corresponding caller input |
200 | // node. Do not place output node. Place control nodes on the same device as |
201 | // caller node. For all function body nodes overrides job, replica and task |
202 | // parts of the device assignment to match function caller node. |
203 | class MultiDeviceFunctionBodyPlacer : public InlinedFunctionBodyPlacer { |
204 | public: |
205 | explicit MultiDeviceFunctionBodyPlacer(const Node& caller) |
206 | : caller_device_(caller.def().device()), |
207 | input_devices_(InputDevices(caller)) { |
208 | has_parsed_caller_device_ = |
209 | DeviceNameUtils::ParseFullName(caller_device_, &caller_parsed_device_); |
210 | } |
211 | |
212 | absl::optional<string> InputNodeDevice(int input_index) const override { |
213 | return input_devices_[input_index]; |
214 | } |
215 | absl::optional<string> OutputNodeDevice(int output_index) const override { |
216 | return absl::nullopt; |
217 | } |
218 | bool ColocateInputOutputIdentities() const override { return true; } |
219 | absl::optional<string> ControlNodeDevice() const override { |
220 | return caller_device_; |
221 | } |
222 | absl::optional<string> BodyNodeDevice(const NodeDef& ndef) const override { |
223 | // LINT.IfChange |
224 | // TODO(ezhulenev): If function would have been instantiated as a |
225 | // multi-device function and executed via FunctionLibraryRuntime, it could |
226 | // be potentially placed on any available device. However there are multiple |
227 | // tests relying on this assumption. Fix them, and remove this line. |
228 | if (ndef.device().empty()) return caller_device_; |
229 | |
230 | if (!has_parsed_caller_device_) return ndef.device(); |
231 | |
232 | DeviceNameUtils::ParsedName ndef_parsed_device; |
233 | if (!DeviceNameUtils::ParseFullName(ndef.device(), &ndef_parsed_device)) |
234 | return ndef.device(); |
235 | |
236 | DeviceNameUtils::MergeUnsetDevNames(&ndef_parsed_device, |
237 | caller_parsed_device_); |
238 | return DeviceNameUtils::ParsedNameToString(ndef_parsed_device); |
239 | // LINT.ThenChange(../../compiler/mlir/tensorflow/ir/tf_ops.cc) |
240 | } |
241 | |
242 | private: |
243 | string caller_device_; |
244 | bool has_parsed_caller_device_; |
245 | DeviceNameUtils::ParsedName caller_parsed_device_; |
246 | std::vector<string> input_devices_; |
247 | }; |
248 | |
249 | } // namespace |
250 | |
251 | std::unique_ptr<InlinedFunctionBodyPlacer> |
252 | InlinedFunctionBodyPlacer::DefaultPlacer(const Graph& graph, |
253 | const Node& caller) { |
254 | VLOG(3) << "Create default placer for inlined function body." ; |
255 | return std::make_unique<DefaultFunctionBodyPlacer>(caller); |
256 | } |
257 | |
258 | std::unique_ptr<InlinedFunctionBodyPlacer> |
259 | InlinedFunctionBodyPlacer::SingleDevicePlacer(const Graph& graph, |
260 | const Node& caller) { |
261 | VLOG(3) << "Create single device placer for inlined function body." ; |
262 | return std::make_unique<SingleDeviceFunctionBodyPlacer>(caller); |
263 | } |
264 | |
265 | std::unique_ptr<InlinedFunctionBodyPlacer> |
266 | InlinedFunctionBodyPlacer::MultiDevicePlacer(const Graph& graph, |
267 | const Node& caller) { |
268 | VLOG(3) << "Create multi device placer for inlined function body." ; |
269 | return std::make_unique<MultiDeviceFunctionBodyPlacer>(caller); |
270 | } |
271 | |
272 | namespace { |
273 | |
274 | Status ValidateNoInline(const FunctionBody* fbody) { |
275 | const auto attr = AttrSlice(&fbody->fdef.attr()); |
276 | bool noinline = false; |
277 | if (TryGetNodeAttr(attr, kNoInlineAttr, &noinline) && noinline) { |
278 | return errors::InvalidArgument( |
279 | "Can't inline function marked with '_noinline'" ); |
280 | } |
281 | return OkStatus(); |
282 | } |
283 | |
// Short alias for the enum type of
// `InlineFunctionBodyOptions::output_control_src`.
using OutputControlSrc = InlineFunctionBodyOptions::OutputControlSource;
285 | |
286 | // Propagate the debug info of `nodes` in function `func` to the `target` node. |
287 | // If the debug info of any node is missing, its node name and function name |
288 | // is used. |
289 | void PropagateDebugInfoToNode(const string& func, |
290 | const std::vector<const Node*>& nodes, |
291 | NodeDef* target) { |
292 | if (nodes.empty() || target->has_experimental_debug_info()) { |
293 | return; |
294 | } |
295 | for (const Node* node : nodes) { |
296 | const auto& node_def = node->def(); |
297 | if (node_def.has_experimental_debug_info()) { |
298 | target->mutable_experimental_debug_info()->MergeFrom( |
299 | node_def.experimental_debug_info()); |
300 | } else { |
301 | target->mutable_experimental_debug_info()->add_original_node_names( |
302 | node_def.name()); |
303 | target->mutable_experimental_debug_info()->add_original_func_names(func); |
304 | } |
305 | } |
306 | } |
307 | } // namespace |
308 | |
309 | string InlineFunctionBodyOptions::DebugString() const { |
310 | const auto true_false = [](bool b) { return b ? "true" : "false" ; }; |
311 | |
312 | const auto keep_caller_node_str = [this]() -> string { |
313 | switch (keep_caller_node) { |
314 | case KeepCallerNode::kDoNotKeep: |
315 | return "DoNotKeep" ; |
316 | case KeepCallerNode::kFetchable: |
317 | return "Fetchable" ; |
318 | case KeepCallerNode::kTargetable: |
319 | return "Targetable" ; |
320 | } |
321 | }; |
322 | |
323 | return absl::StrCat( |
324 | "disable_inlining=" , true_false(disable_inlining), |
325 | ", ignore_noinline=" , true_false(ignore_noinline), |
326 | ", inline_impl_selection_group_functions=" , |
327 | true_false(inline_impl_selection_group_functions), |
328 | ", keep_caller_node=" , keep_caller_node_str(), ", output_control_src=" , |
329 | output_control_src == OutputControlSrc::kDataOutputs ? "DataOutputs" |
330 | : "ControlOutputs" , |
331 | ", inlined_function_body_placer=" , inlined_function_body_placer.name, |
332 | ", uniquify_frame_names=" , true_false(uniquify_frame_names)); |
333 | } |
334 | |
335 | Status ValidateInlining(const Node* node, const FunctionBody* fbody, |
336 | const InlineFunctionBodyOptions& options) { |
337 | // TODO(ezhulenev): Currently common_runtime function inlining can't guarantee |
338 | // that all side-effectful ops will be executed after inlining. See Grappler |
339 | // function_optimizer for details. Unify all function inlining mechanism. |
340 | // Do not inline if `!fbody->control_ret_nodes.empty()`. |
341 | |
342 | const auto num_node_inputs = static_cast<size_t>(node->num_inputs()); |
343 | const auto num_node_outputs = static_cast<size_t>(node->num_outputs()); |
344 | |
345 | if (num_node_inputs != fbody->arg_types.size() || |
346 | num_node_inputs != fbody->arg_nodes.size()) { |
347 | return errors::InvalidArgument( |
348 | "Node inputs do not match function arguments: inputs=" , num_node_inputs, |
349 | " arg_types=" , fbody->arg_types.size(), |
350 | " arg_nodes=" , fbody->arg_nodes.size()); |
351 | } |
352 | |
353 | if (num_node_outputs != fbody->ret_types.size() || |
354 | num_node_outputs != fbody->ret_nodes.size()) { |
355 | return errors::InvalidArgument( |
356 | "Node outputs do not match function returns: outputs=" , |
357 | num_node_outputs, " ret_types=" , fbody->ret_types.size(), |
358 | " ret_nodes=" , fbody->ret_nodes.size()); |
359 | } |
360 | |
361 | for (int i = 0; i < node->num_inputs(); ++i) { |
362 | if (node->input_type(i) != fbody->arg_types[i]) { |
363 | return errors::InvalidArgument( |
364 | "Node input type doesn't match function argument type: " , |
365 | node->input_type(i), " != " , fbody->arg_types[i], " @ index=" , i); |
366 | } |
367 | } |
368 | for (int i = 0; i < node->num_outputs(); ++i) { |
369 | if (node->output_type(i) != fbody->ret_types[i]) { |
370 | return errors::InvalidArgument( |
371 | "Node output type doesn't match function return type: " , |
372 | node->output_type(i), " != " , fbody->ret_types[i], " @ index=" , i); |
373 | } |
374 | } |
375 | |
376 | if (options.disable_inlining) { |
377 | return errors::InvalidArgument( |
378 | "Function inlining explicitly disabled by 'options.disable_inlining'" ); |
379 | } |
380 | |
381 | if (!options.inline_impl_selection_group_functions) { |
382 | bool is_impl_selection_group_function = |
383 | fbody->fdef.attr().find("api_implements" ) != fbody->fdef.attr().end(); |
384 | if (is_impl_selection_group_function) { |
385 | return errors::InvalidArgument( |
386 | "Inlining of implementation selection group function " , |
387 | fbody->fdef.signature().name(), |
388 | " is disabled by options.inline_impl_selection_group_functions" ); |
389 | } |
390 | } |
391 | |
392 | if (!options.ignore_noinline) { |
393 | TF_RETURN_IF_ERROR(ValidateNoInline(fbody)); |
394 | } |
395 | |
396 | return OkStatus(); |
397 | } |
398 | |
399 | // Function inlining must preserve function execution semantics with regards to |
400 | // side-effects visibility. Tensorflow in Eager mode has an automatic control |
401 | // dependencies tracking mechanism, which enforces well-defined execution order |
402 | // of all side-effects. Any other frontend (e.g. Swift) must produce graphs |
403 | // following the same rules, to ensure that function inlining works correctly. |
404 | // |
405 | // IMPORTANT: Currently we do not have a true notion of "side-effectful" node, |
406 | // we assume that all stateful nodes might have side-effects, though it's not |
407 | // true in practice, e.g. `ReadVariableOp` doesn't have an observable |
408 | // side-effect. |
409 | // |
410 | // Automatic control dependency rules in Tensorflow 2.0 (python in eager mode): |
411 | // |
412 | // 1) When a function has a resource (DT_RESOURCE data type) input argument it |
413 | // "captures" the mutable resource. This is implemented by automatically |
// adding an incoming control edge from the previous side-effectful op
415 | // touching that resource, and an outgoing control edge to the next |
416 | // side-effectful op using the same resource. This serializes the mutations |
417 | // of the resource to make graph execution deterministic. |
418 | // |
419 | // 2) All stateful ops inside a function body are guaranteed to execute in |
420 | // program order, this is achieved by adding control edges between stateful |
421 | // ops at graph construction time. Stateful ops (or ops that must execute) |
422 | // should be in the function control return set. Having a data edge to the |
423 | // regular function output might be not enough, because after function |
424 | // inlining it might happen that data output is unused. |
425 | // |
426 | // 3) Furthermore, all ops accepting the same resource as an input are |
427 | // guaranteed to run in program order. This is also done by adding control |
428 | // edges at graph construction time. The last op touching the resource |
429 | // must be in a control return set, which will guarantee that all side |
430 | // effects to the resource will happen before function completion. |
431 | // |
432 | // Function inlining must preserve side-effect visibility: |
433 | // |
434 | // 1) All side-effects to the captured resources, that happened before function |
// call must be visible to the function body nodes using those resources.
436 | // |
437 | // 2) All side-effects to the captured resources, that happened inside function |
438 | // body, must be visible to every op/function using that resource after the |
439 | // function call completed. |
440 | // |
441 | // To guarantee that these properties are preserved after inlining we: |
442 | // |
443 | // 1) Create "input_control_node" NoOp. Function call node incoming control |
444 | // edges will be forwarded *to* this node. Function inputs (Identity nodes) |
445 | // will have a control edge *from* this node. If function body has nodes |
446 | // without inputs, they will have a control edge *from* this node. |
447 | // |
448 | // 2) Create "output_control_node" NoOp. All nodes that have incoming control |
449 | // edge *from* the function call node, will be forwarded to this node. |
450 | // |
451 | // We have two options for choosing which nodes will have a control edge *to* |
452 | // the "output control node": |
453 | // a) control returns (`control_ret` field in FunctionDef) |
454 | // b) data returns (`ret` field in FunctionDef) |
455 | // |
456 | // We do a) for multi-device function calls in Tensorflow v2 and b) |
457 | // for the rest for compatibility with Tensorflow v1. |
458 | // |
459 | // Following the automatic control dependencies tracking rules, a node that |
460 | // has an incoming control edge from the function call node is dependent on |
461 | // the side-effects happening inside the function body. The output control |
462 | // node will guarantee side-effects execution order. |
463 | // |
464 | // If function call node doesn't have an outgoing control edge, it means that |
465 | // no one is interested in observing side-effects that might have happened. |
466 | // |
467 | // Function inlining might leave the graph in partially-placed state. Function |
468 | // inlining caller must call Placer to guarantee that all nodes are placed. |
469 | // |
470 | // Function inlining with `options.override_device=true` will leave graph in |
471 | // fully placed state, by overriding all inlined nodes devices with the caller |
472 | // node device, but it will make functions always single-device. These functions |
473 | // after inlining will not be able to handle resources on multiple devices. This |
474 | // is currently acceptable for XLA use cases (XLA cluster is always executed on |
475 | // a single device). |
476 | // |
477 | // TODO(ezhulenev): Documentation above is ahead of implementation below. |
478 | Status InlineFunctionBody(const FunctionLibraryDefinition& flib_def, Graph* g, |
479 | Node* caller, const FunctionBody* fbody, |
480 | const InlineFunctionBodyOptions& options) { |
481 | VLOG(3) << "Inline function call: " << SummarizeNode(*caller) << " [" |
482 | << options.DebugString() << "]" ; |
483 | VLOG(4) << "Inlining function: " << fbody->fdef.DebugString(); |
484 | VLOG(4) << "Current graphdef: " << g->ToGraphDefDebug().DebugString(); |
485 | VLOG(4) << "Caller: " << caller->DebugString(); |
486 | |
487 | Status validation = ValidateInlining(caller, fbody, options); |
488 | if (!validation.ok()) { |
489 | return errors::Internal("Inlining mismatch: " , validation.error_message()); |
490 | } |
491 | |
492 | // Placer is responsible for assigning devices for all nodes that we will add |
493 | // to the graph. |
494 | const std::unique_ptr<InlinedFunctionBodyPlacer> placer = |
495 | options.inlined_function_body_placer.get(*g, *caller); |
496 | |
497 | // We can't possibly introduce a duplicate control edge during function |
498 | // inlining, so we skip this check in calls to the 'g->AddControlEdge(...)'. |
499 | static constexpr bool kDoNotCheckDuplicates = true; |
500 | |
501 | // ------------------------------------------------------------------------ // |
502 | // Helper functions to create `NoOp` and `Identity` nodes for auxiliary |
503 | // control nodes and inlined function inputs and outputs. |
504 | |
505 | // Add a NoOp node for function control inputs/outputs. |
506 | const auto no_op = [&](StringPiece name) -> Node* { |
507 | Node* node = AddNoOp(absl::StrCat(caller->name(), "/" , name), g); |
508 | const absl::optional<string> device = placer->ControlNodeDevice(); |
509 | if (device.has_value()) node->set_requested_device(*device); |
510 | return node; |
511 | }; |
512 | |
513 | // Add an Identity node for function input. |
514 | const auto input_identity = [&](StringPiece name, Endpoint input, |
515 | int index) -> Node* { |
516 | Node* node = AddIdentity(absl::StrCat(caller->name(), "/" , name), g, input); |
517 | const absl::optional<string> device = placer->InputNodeDevice(index); |
518 | if (device.has_value()) node->set_requested_device(*device); |
519 | bool colocate_identity = placer->ColocateInputOutputIdentities(); |
520 | if (colocate_identity) { |
521 | node->AddAttr(kColocationAttrName, |
522 | std::vector<string>{absl::StrCat(kColocationGroupPrefix, |
523 | input.node->name())}); |
524 | } |
525 | return node; |
526 | }; |
527 | |
528 | // Add an Identity node for function output. |
529 | const auto output_identity = [&](StringPiece name, Endpoint input, |
530 | int index) -> Node* { |
531 | Node* node = AddIdentity(absl::StrCat(caller->name(), "/" , name), g, input); |
532 | const absl::optional<string> device = placer->OutputNodeDevice(index); |
533 | if (device.has_value()) node->set_requested_device(*device); |
534 | bool colocate_identity = placer->ColocateInputOutputIdentities(); |
535 | if (colocate_identity) { |
536 | node->AddAttr(kColocationAttrName, |
537 | std::vector<string>{absl::StrCat(kColocationGroupPrefix, |
538 | input.node->name())}); |
539 | } |
540 | return node; |
541 | }; |
542 | |
543 | // ------------------------------------------------------------------------ // |
  // Helper function to get an input/output argument name by index. For
  // functions instantiated from SymbolicGradient the corresponding FunctionDef
  // is empty, and the argument name is unknown.
547 | |
548 | auto arg_name = [&](auto& args, size_t i) -> absl::string_view { |
549 | if (i < args.size()) { |
550 | return args[i].name(); |
551 | } else { |
552 | return "<unknown>" ; |
553 | } |
554 | }; |
555 | |
556 | // ------------------------------------------------------------------------ // |
557 | // Input edges. For data edges coming into "caller", we first compute the |
558 | // <src>:<src_output> for the i-th input in "inputs". |
559 | // If "caller" has any input control dependencies, we add a NoOp |
560 | // node "input_control_node", which depends on "caller"'s control inputs. |
561 | std::vector<Endpoint> inputs(caller->num_inputs()); |
562 | Node* input_control_node = nullptr; |
563 | for (const Edge* e : caller->in_edges()) { |
564 | if (e->IsControlEdge()) { |
565 | if (input_control_node == nullptr) { |
566 | input_control_node = no_op("input_control_node" ); |
567 | } |
568 | g->AddControlEdge(e->src(), input_control_node, kDoNotCheckDuplicates); |
569 | } else { |
570 | inputs[e->dst_input()] = {e->src(), e->src_output()}; |
571 | } |
572 | } |
573 | if (input_control_node != nullptr) { |
574 | VLOG(3) << "Created input control node: " << input_control_node->name(); |
575 | } |
576 | |
577 | // We create one Identity node for each input. |
578 | std::vector<Node*> input_nodes; |
579 | std::map<absl::string_view, absl::string_view> input_node_name_map; |
580 | for (std::size_t i = 0; i < fbody->arg_nodes.size(); ++i) { |
581 | Node* n = input_identity("input" , inputs[i], i); |
582 | input_node_name_map[arg_name(fbody->fdef.signature().input_arg(), i)] = |
583 | n->name(); |
584 | input_nodes.push_back(n); |
585 | } |
586 | |
587 | // ------------------------------------------------------------------------ // |
588 | // Duplicate fbody->graph into 'g'. First, we copy the nodes of |
589 | // fbody->graph into 'g' except the source and sink nodes. We copy |
590 | // edges among nodes in 'fbody->graph'. |
591 | // |
592 | // If 'x' is a node in fbody->graph and its copy in 'g' is 'y', we |
593 | // remember 'y' in node_map[x->id()]. |
594 | std::unordered_set<string> fn_nodes; |
595 | for (Node* n : fbody->graph->op_nodes()) { |
596 | fn_nodes.insert(n->name()); |
597 | } |
598 | std::vector<Node*> node_map(fbody->graph->num_node_ids()); |
599 | for (Node* n : fbody->graph->op_nodes()) { |
600 | NodeDef ndef = n->def(); |
601 | |
602 | // Maybe override requested node device assignment. |
603 | const absl::optional<string> device = placer->BodyNodeDevice(ndef); |
604 | if (device.has_value()) ndef.set_device(*device); |
605 | |
606 | // Add inlined function name to inlined node debug information. |
607 | PropagateDebugInfoToNode(fbody->fdef.signature().name(), {n}, &ndef); |
608 | |
609 | // Add the function node name as a prefix: |
610 | // 1) to node name to avoid collisions |
611 | // 2) to frame name to avoid multiple LoopCond nodes in one frame |
612 | // 3) to colocation attribute |
613 | const string prefix = strings::StrCat(caller->name(), "/" ); |
614 | TF_RETURN_IF_ERROR(AddPrefixAndSuffixToNode(prefix, /*suffix=*/"" , &ndef, |
615 | options.uniquify_frame_names)); |
616 | |
617 | // If the colocation attribute is an input arg, we need to change it to the |
618 | // new input (Identity) node now. |
619 | TF_RETURN_IF_ERROR( |
620 | MaybeUpdateColocationConstraintsWithMap(input_node_name_map, &ndef)); |
621 | |
622 | TF_RETURN_IF_ERROR( |
623 | MaybeAddPrefixToColocationConstraints(fn_nodes, prefix, &ndef)); |
624 | |
625 | Status added_node; |
626 | Node* clone = g->AddNode(std::move(ndef), &added_node); |
627 | TF_CHECK_OK(added_node); |
628 | node_map[n->id()] = clone; |
629 | clone->SetStackTrace(n->GetStackTrace()); |
630 | |
631 | // If there is an input control node, and one of: |
632 | // a) the node has no data or control inputs, or |
633 | // b) the node is a function call (including SymbolicGradient), |
634 | // then add a control edge from the input control node to the clone (only |
635 | // if it does not already have a control input). |
636 | // |
637 | // We must not execute any nodes if the original function call would not |
638 | // have executed. This is especially critical when the function call is |
639 | // inside a control-flow construct like tf.cond(). Case (a) ensures that |
640 | // such nodes do not run. |
641 | // |
642 | // The purpose of case (b) is to ensure that instances of case (a) created |
643 | // by further inlining steps also receive the control dependency. |
644 | // |
645 | // This edge is required to transfer execution frame down to all function |
646 | // body nodes of inlined nested function calls. |
647 | if (input_control_node) { |
648 | const auto is_input_edge = [](const Edge* e) -> bool { |
649 | return !e->src()->IsSource(); |
650 | }; |
651 | const auto is_control_edge = [](const Edge* e) -> bool { |
652 | return !e->src()->IsSource() && e->IsControlEdge(); |
653 | }; |
654 | |
655 | // Forward execution frame if: |
656 | // |
657 | // a) The node has no data or control inputs. |
658 | // b) OR the node is a function call without control inputs (control edge |
659 | // will be used in nested function inlining to forward execution frame |
660 | // to constants inside the function body). |
661 | // |
662 | // c) Do not forward control frame to function argument nodes, they will |
663 | // be connected to the corresponding function input later. |
664 | const bool forward_execution_frame = |
665 | (absl::c_none_of(n->in_edges(), is_input_edge) || // (a) |
666 | (n->IsFunctionCall() && // (b) |
667 | absl::c_none_of(n->in_edges(), is_control_edge))) && // |
668 | !n->IsArg(); // (c) |
669 | |
670 | if (forward_execution_frame) { |
671 | VLOG(4) << "Add control edge from input control node to: " |
672 | << clone->name(); |
673 | g->AddControlEdge(input_control_node, clone, kDoNotCheckDuplicates); |
674 | } |
675 | } |
676 | } |
677 | for (const Edge* e : fbody->graph->edges()) { |
678 | if (e->src()->IsSource() || e->src()->IsSink() || e->dst()->IsSource() || |
679 | e->dst()->IsSink()) { |
680 | continue; |
681 | } |
682 | Node* src_copy = node_map[e->src()->id()]; |
683 | Node* dst_copy = node_map[e->dst()->id()]; |
684 | g->AddEdge(src_copy, e->src_output(), dst_copy, e->dst_input()); |
685 | } |
686 | |
687 | // ------------------------------------------------------------------------ // |
688 | // Connect input edges. |
689 | // |
690 | // Then, we connect inputs[i] to the i-th identity node added. The nodes that |
691 | // previously connected to the j-th output of i-th arg node are reconnected |
692 | // to the i-th identity node. |
693 | // |
694 | // The added identity nodes depend on "input_control_node". |
695 | VLOG(4) << "Add input Identity nodes for each function argument:" ; |
696 | |
697 | for (std::size_t i = 0; i < fbody->arg_nodes.size(); ++i) { |
698 | Node* arg = node_map[fbody->arg_nodes[i]->id()]; |
699 | Node* n = input_nodes[i]; |
700 | VLOG(4) << " [index " << i << "] " |
701 | << arg_name(fbody->fdef.signature().input_arg(), i) << " as " |
702 | << n->name() << " (input: " << inputs[i].name() |
703 | << ", requested_device: " << n->requested_device() << ")" ; |
704 | |
705 | if (input_control_node) { |
706 | g->AddControlEdge(input_control_node, n, kDoNotCheckDuplicates); |
707 | } |
708 | for (const Edge* e : arg->out_edges()) { |
709 | if (e->IsControlEdge()) { |
710 | g->AddControlEdge(n, e->dst(), kDoNotCheckDuplicates); |
711 | } else { |
712 | g->AddEdge(n, 0, e->dst(), e->dst_input()); |
713 | } |
714 | } |
715 | node_map[fbody->arg_nodes[i]->id()] = n; |
716 | g->RemoveNode(arg); // 'arg' is disconnected. |
717 | } |
718 | |
719 | // ------------------------------------------------------------------------ // |
720 | // Connect output edges. |
721 | // |
722 | // For i-th return node in fbody->graph, we add in "g" an identity node |
723 | // (outputs[i-th]). We then reconnect every incoming edge into the i-th return |
724 | // node to the added identity node. |
725 | // |
726 | // For every data edge coming out of "callee"s i-th output, we reconnect it to |
727 | // the i-th identity added above. |
728 | // |
729 | // If "callee" is control-depended upon by any other nodes, we add a NoOp node |
730 | // "output_control_node". "output_control_node" depends on all identity nodes |
731 | // added above or on all control return nodes (controlled by |
732 | // `options.output_control_src` value). And nodes previously depend on |
733 | // "callee" is changed to depend on "output_control_node". |
734 | // |
735 | // If `keep_node_fetchable` is `true` we always add an output control node, to |
736 | // guarantee that executing a fetchable node will execute all side-effects. |
737 | VLOG(4) << "Add output Identity nodes for each function output argument:" ; |
738 | |
739 | std::vector<Node*> outputs(caller->num_outputs()); |
740 | for (std::size_t i = 0; i < fbody->ret_nodes.size(); ++i) { |
741 | Node* ret = node_map[fbody->ret_nodes[i]->id()]; |
742 | Endpoint data; // Data input for the ret node. |
743 | for (const Edge* e : ret->in_edges()) { |
744 | if (!e->IsControlEdge()) { |
745 | data = {e->src(), e->src_output()}; |
746 | break; |
747 | } |
748 | } |
749 | CHECK(data.node != nullptr); |
750 | Node* n = output_identity("output" , data, i); |
751 | outputs[i] = n; |
752 | VLOG(4) << " [index " << i << "] " |
753 | << arg_name(fbody->fdef.signature().output_arg(), i) << " as " |
754 | << n->name() << " (ret: " << data.node->name() << ":" << data.index |
755 | << ", requested_device: " << n->requested_device() << ")" ; |
756 | for (const Edge* e : ret->in_edges()) { |
757 | if (e->IsControlEdge()) { |
758 | g->AddControlEdge(e->src(), n, kDoNotCheckDuplicates); |
759 | } |
760 | } |
761 | g->RemoveNode(ret); // 'ret' is disconnected. |
762 | } |
763 | |
764 | Node* output_control_node = nullptr; |
765 | const bool has_control_outputs = absl::c_any_of( |
766 | caller->out_edges(), [](const Edge* e) { return e->IsControlEdge(); }); |
767 | |
768 | using KeepCallerNode = InlineFunctionBodyOptions::KeepCallerNode; |
769 | const bool keep_caller_node = |
770 | options.keep_caller_node == KeepCallerNode::kFetchable || |
771 | options.keep_caller_node == KeepCallerNode::kTargetable; |
772 | |
773 | if (has_control_outputs || keep_caller_node) { |
774 | output_control_node = no_op("output_control_node" ); |
775 | VLOG(4) << "Add output control node: " << output_control_node->name(); |
776 | if (options.output_control_src == OutputControlSrc::kDataOutputs) { |
777 | for (Node* n : outputs) { |
778 | VLOG(4) << " [data output] add control edge from: " << n->name(); |
779 | g->AddControlEdge(n, output_control_node, kDoNotCheckDuplicates); |
780 | } |
781 | } else { |
782 | for (Node* fbody_node : fbody->control_ret_nodes) { |
783 | Node* n = node_map[fbody_node->id()]; |
784 | VLOG(4) << " [control output] add control edge from: " << n->name(); |
785 | g->AddControlEdge(n, output_control_node, kDoNotCheckDuplicates); |
786 | } |
787 | } |
788 | } |
789 | |
  // We can't leave output control node without incoming control edges, because
  // in this case outgoing control edge will lose execution frame information.
792 | // We connect input_control_node and output_control_node with a control edge |
793 | // to forward execution frame to the controlled nodes. Above we add a control |
794 | // edge to all function calls inside function body, to guarantee that we will |
795 | // always have input_control_node when we need it. |
796 | if (output_control_node && output_control_node->in_edges().empty()) { |
797 | if (input_control_node) { |
798 | VLOG(4) << "Add a control edge between input and output control nodes: " |
799 | << input_control_node->name() << " to " |
800 | << output_control_node->name(); |
801 | g->AddControlEdge(input_control_node, output_control_node, |
802 | kDoNotCheckDuplicates); |
803 | } else { |
804 | VLOG(4) << "Function inlining potentially dropped execution frame " |
805 | "information from outgoing control edges." ; |
806 | } |
807 | } |
808 | |
809 | for (const Edge* e : caller->out_edges()) { |
810 | if (e->IsControlEdge()) { |
811 | g->AddControlEdge(output_control_node, e->dst(), kDoNotCheckDuplicates); |
812 | } else { |
813 | g->AddEdge(outputs[e->src_output()], 0, e->dst(), e->dst_input()); |
814 | } |
815 | } |
816 | |
817 | // ------------------------------------------------------------------------ // |
818 | // Add an IdentityN or NoOp node in-place of caller node to keep `caller` |
819 | // fetchable or targetable. |
820 | |
821 | if (keep_caller_node) { |
822 | std::vector<NodeBuilder::NodeOut> output_tensors; |
823 | absl::c_transform(outputs, std::back_inserter(output_tensors), |
824 | [](Node* n) { return NodeBuilder::NodeOut(n, 0); }); |
825 | |
826 | Node* caller_substitute_node; |
827 | if (options.keep_caller_node == KeepCallerNode::kTargetable || |
828 | output_tensors.empty()) { |
829 | // IdentityN node must have at least one data input. If function has no |
830 | // data outputs, we can't keep it fetchable. |
831 | TF_CHECK_OK(NodeBuilder(caller->name(), "NoOp" ) |
832 | .Device(caller->requested_device()) |
833 | .ControlInput(output_control_node) |
834 | .Finalize(g, &caller_substitute_node)); |
835 | |
836 | } else if (options.keep_caller_node == KeepCallerNode::kFetchable) { |
837 | TF_CHECK_OK(NodeBuilder(caller->name(), "IdentityN" ) |
838 | .Device(caller->requested_device()) |
839 | .Input(output_tensors) |
840 | .ControlInput(output_control_node) |
841 | .Finalize(g, &caller_substitute_node)); |
842 | } |
843 | } |
844 | |
845 | // ------------------------------------------------------------------------ // |
846 | // 'caller' is replaced with inlined function body nodes and maybe IdentityN |
847 | // to keep it fetchable. |
848 | VLOG(3) << "Successfully inlined function call node: " << caller->name(); |
849 | g->RemoveNode(caller); |
850 | |
851 | VLOG(4) << "Final graph: " << g->ToGraphDefDebug().DebugString(); |
852 | |
853 | return OkStatus(); |
854 | } |
855 | |
856 | bool ExpandInlineFunctions(FunctionLibraryRuntime* lib, Graph* graph, |
857 | const ExpandInlineFunctionsOptions& options) { |
858 | std::vector<std::pair<Node*, const FunctionBody*>> candidates; |
859 | |
860 | const FunctionLibraryDefinition* fld = lib->GetFunctionLibraryDefinition(); |
861 | |
862 | for (Node* node : graph->nodes()) { |
863 | // Skip nodes that are not function calls or SymbolicGradient calls. |
864 | if (!IsFunctionCall(*lib->GetFunctionLibraryDefinition(), *node)) { |
865 | continue; |
866 | } |
867 | // Skip function calls that marked noinline. |
868 | bool noinline; |
869 | if (fld->GetAttr(*node, kNoInlineAttr, &noinline).ok() && noinline) { |
870 | VLOG(3) << "noinline: " << SummarizeNode(*node); |
871 | continue; |
872 | } |
873 | FunctionLibraryRuntime::Handle handle; |
874 | Status s = InstantiateFunctionCall(node->def(), lib, &handle); |
875 | if (!s.ok()) { |
876 | LOG(ERROR) << "Failed to instantiate a function: " << s.error_message(); |
877 | continue; |
878 | } |
879 | const FunctionBody* fbody = lib->GetFunctionBody(handle); |
880 | CHECK_NOTNULL(fbody); |
881 | candidates.emplace_back(node, fbody); |
882 | } |
883 | |
884 | bool inlined_any = false; |
885 | for (const auto& p : candidates) { |
886 | Status inlined = InlineFunctionBody(*fld, graph, p.first, p.second, |
887 | p.first->IsPartitionedCall() |
888 | ? options.multi_device_options |
889 | : options.native_options); |
890 | if (inlined.ok()) { |
891 | inlined_any = true; |
892 | } else { |
893 | VLOG(1) << "Failed to inline function call: node=" << p.first->name() |
894 | << " error=" << inlined.error_message(); |
895 | } |
896 | } |
897 | |
898 | // TODO(ezhulenev): Release handles for inlined function calls. |
899 | |
900 | return inlined_any; |
901 | } |
902 | |
903 | } // end namespace tensorflow |
904 | |