/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/core/common_runtime/inline_function_utils.h"

#include <deque>
#include <vector>

#include "absl/algorithm/container.h"
#include "absl/memory/memory.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/string_view.h"
#include "tensorflow/core/common_runtime/device.h"
#include "tensorflow/core/common_runtime/function_utils.h"
#include "tensorflow/core/common_runtime/graph_constructor.h"
#include "tensorflow/core/framework/collective.h"
#include "tensorflow/core/framework/function.h"
#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/framework/node_def_util.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/versions.pb.h"
#include "tensorflow/core/graph/algorithm.h"
#include "tensorflow/core/graph/control_flow.h"
#include "tensorflow/core/graph/node_builder.h"
#include "tensorflow/core/graph/optimizer_cse.h"
#include "tensorflow/core/lib/core/threadpool.h"
#include "tensorflow/core/lib/gtl/map_util.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/profiler/lib/traceme.h"
#include "tensorflow/core/protobuf/config.pb.h"
#include "tensorflow/core/util/device_name_utils.h"

namespace tensorflow {

/*static*/ constexpr const char* const
    LowerFunctionalOpsConstants::kLowerUsingSwitchMergeAttr;
/*static*/ constexpr const char* const
    LowerFunctionalOpsConstants::kLowerAsMultiDeviceFunctionAttr;

namespace {
// A few string constants used throughout this module.
static constexpr const char* const kArgOp = FunctionLibraryDefinition::kArgOp;
static constexpr const char* const kDeviceArgOp =
    FunctionLibraryDefinition::kDeviceArgOp;
static constexpr const char* const kRetOp = FunctionLibraryDefinition::kRetOp;
static constexpr const char* const kDeviceRetOp =
    FunctionLibraryDefinition::kDeviceRetOp;
static constexpr const char* const kGradientOp =
    FunctionLibraryDefinition::kGradientOp;
static constexpr const char* const kNodeLabel = "Func";
static constexpr const char* const kFuncAttr =
    FunctionLibraryDefinition::kFuncAttr;

// Represents the index-th output of a node.
struct Endpoint {
  Node* node;
  int index;

  // Returns the string name that represents this endpoint.
  string name() const {
    if (index == 0) {
      return node->name();
    } else {
      return strings::StrCat(node->name(), ":", index);
    }
  }

  DataType dtype() const { return node->output_type(index); }
};

struct EndpointHash {
  uint64 operator()(const Endpoint& x) const {
    return Hash64(reinterpret_cast<const char*>(&x.node), sizeof(Node*),
                  x.index);
  }
};

struct EndpointEq {
  bool operator()(const Endpoint& x, const Endpoint& y) const {
    return (x.node == y.node) && (x.index == y.index);
  }
};

// The following Add* routines are used to add a few graph nodes while
// functions are transformed.
static Node* AddNoOp(StringPiece name, Graph* g) {
  NodeDef ndef;
  ndef.set_name(g->NewName(absl::StrCat(kNodeLabel, "/", name)));
  ndef.set_op("NoOp");
  Status s;
  Node* ret = g->AddNode(ndef, &s);
  TF_CHECK_OK(s);
  return ret;
}

static Node* AddIdentity(StringPiece name, Graph* g, Endpoint input) {
  DCHECK_LT(0, input.dtype());
  NodeDef ndef;
  ndef.set_name(g->NewName(absl::StrCat(kNodeLabel, "/", name)));
  ndef.set_op("Identity");
  ndef.add_input(input.name());
  AddNodeAttr("T", BaseType(input.dtype()), &ndef);
  Status s;
  Node* ret = g->AddNode(ndef, &s);
  TF_CHECK_OK(s);
  g->AddEdge(input.node, input.index, ret, 0);
  return ret;
}

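// Returns the devices of the caller node's data inputs, indexed by input
// position; an assigned device takes precedence over a requested device.
// Entries that do not correspond to a data input are left empty.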
std::vector<string> InputDevices(const Node& caller) {
  std::vector<string> input_devices(caller.in_edges().size());
  std::vector<string> input_tensors(caller.in_edges().size());

  for (const Edge* edge : caller.in_edges()) {
    if (edge->IsControlEdge()) continue;
    const string& input_device = edge->src()->has_assigned_device_name()
                                     ? edge->src()->assigned_device_name()
                                     : edge->src()->requested_device();
    input_devices[edge->dst_input()] = input_device;
    input_tensors[edge->dst_input()] =
        absl::StrCat(edge->src()->name(), ":", edge->src_output());
  }

  if (VLOG_IS_ON(4)) {
    VLOG(4) << "Function instantiation input devices:";
    for (int i = 0; i < input_devices.size(); ++i) {
      if (input_tensors[i].empty()) continue;  // skip control edges
      VLOG(4) << " [index " << i << "]"
              << " device: " << input_devices[i]
              << " (input: " << input_tensors[i] << ")";
    }
  }

  return input_devices;
}

// Place input nodes on the same device as the corresponding caller input
// node. Do not specify any placement for the other nodes.
class DefaultFunctionBodyPlacer : public InlinedFunctionBodyPlacer {
 public:
  explicit DefaultFunctionBodyPlacer(const Node& caller)
      : input_devices_(InputDevices(caller)) {}

  absl::optional<string> InputNodeDevice(int input_index) const override {
    return input_devices_[input_index];
  }
  absl::optional<string> OutputNodeDevice(int output_index) const override {
    return absl::nullopt;
  }
  bool ColocateInputOutputIdentities() const override { return false; }
  absl::optional<string> ControlNodeDevice() const override {
    return absl::nullopt;
  }
  absl::optional<string> BodyNodeDevice(const NodeDef& ndef) const override {
    return absl::nullopt;
  }

 private:
  const std::vector<string> input_devices_;
};

// Place all nodes on the same device as the caller node.
class SingleDeviceFunctionBodyPlacer : public InlinedFunctionBodyPlacer {
 public:
  explicit SingleDeviceFunctionBodyPlacer(const Node& caller)
      : caller_device_(caller.def().device()) {}

  absl::optional<string> InputNodeDevice(int input_index) const override {
    return caller_device_;
  }
  absl::optional<string> OutputNodeDevice(int output_index) const override {
    return caller_device_;
  }
  bool ColocateInputOutputIdentities() const override { return false; }
  absl::optional<string> ControlNodeDevice() const override {
    return caller_device_;
  }
  absl::optional<string> BodyNodeDevice(const NodeDef& ndef) const override {
    return caller_device_;
  }

 private:
  const string caller_device_;
};

// Place input nodes on the same device as the corresponding caller input
// node. Do not place output nodes. Place control nodes on the same device as
// the caller node. For all function body nodes, override the job, replica and
// task parts of the device assignment to match the function caller node.
class MultiDeviceFunctionBodyPlacer : public InlinedFunctionBodyPlacer {
 public:
  explicit MultiDeviceFunctionBodyPlacer(const Node& caller)
      : caller_device_(caller.def().device()),
        input_devices_(InputDevices(caller)) {
    has_parsed_caller_device_ =
        DeviceNameUtils::ParseFullName(caller_device_, &caller_parsed_device_);
  }

  absl::optional<string> InputNodeDevice(int input_index) const override {
    return input_devices_[input_index];
  }
  absl::optional<string> OutputNodeDevice(int output_index) const override {
    return absl::nullopt;
  }
  bool ColocateInputOutputIdentities() const override { return true; }
  absl::optional<string> ControlNodeDevice() const override {
    return caller_device_;
  }
  absl::optional<string> BodyNodeDevice(const NodeDef& ndef) const override {
    // LINT.IfChange
    // TODO(ezhulenev): If the function had been instantiated as a multi-device
    // function and executed via FunctionLibraryRuntime, it could potentially
    // be placed on any available device. However, there are multiple tests
    // relying on this assumption. Fix them, and remove this line.
    if (ndef.device().empty()) return caller_device_;

    if (!has_parsed_caller_device_) return ndef.device();

    DeviceNameUtils::ParsedName ndef_parsed_device;
    if (!DeviceNameUtils::ParseFullName(ndef.device(), &ndef_parsed_device))
      return ndef.device();

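    // Fill in the parts of the body node device that are unset (typically the
    // job, replica and task) from the caller device. For example (illustrative
    // values only): a body device of "/device:GPU:0" combined with a caller on
    // "/job:worker/replica:0/task:1/device:CPU:0" becomes
    // "/job:worker/replica:0/task:1/device:GPU:0".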
    DeviceNameUtils::MergeUnsetDevNames(&ndef_parsed_device,
                                        caller_parsed_device_);
    return DeviceNameUtils::ParsedNameToString(ndef_parsed_device);
    // LINT.ThenChange(../../compiler/mlir/tensorflow/ir/tf_ops.cc)
  }

 private:
  string caller_device_;
  bool has_parsed_caller_device_;
  DeviceNameUtils::ParsedName caller_parsed_device_;
  std::vector<string> input_devices_;
};

}  // namespace

std::unique_ptr<InlinedFunctionBodyPlacer>
InlinedFunctionBodyPlacer::DefaultPlacer(const Graph& graph,
                                         const Node& caller) {
  VLOG(3) << "Create default placer for inlined function body.";
  return std::make_unique<DefaultFunctionBodyPlacer>(caller);
}

std::unique_ptr<InlinedFunctionBodyPlacer>
InlinedFunctionBodyPlacer::SingleDevicePlacer(const Graph& graph,
                                              const Node& caller) {
  VLOG(3) << "Create single device placer for inlined function body.";
  return std::make_unique<SingleDeviceFunctionBodyPlacer>(caller);
}

std::unique_ptr<InlinedFunctionBodyPlacer>
InlinedFunctionBodyPlacer::MultiDevicePlacer(const Graph& graph,
                                             const Node& caller) {
  VLOG(3) << "Create multi device placer for inlined function body.";
  return std::make_unique<MultiDeviceFunctionBodyPlacer>(caller);
}

namespace {

Status ValidateNoInline(const FunctionBody* fbody) {
  const auto attr = AttrSlice(&fbody->fdef.attr());
  bool noinline = false;
  if (TryGetNodeAttr(attr, kNoInlineAttr, &noinline) && noinline) {
    return errors::InvalidArgument(
        "Can't inline function marked with '_noinline'");
  }
  return OkStatus();
}

using OutputControlSrc = InlineFunctionBodyOptions::OutputControlSource;

// Propagate the debug info of `nodes` in function `func` to the `target` node.
// If the debug info of any node is missing, its node name and function name
// are used.
void PropagateDebugInfoToNode(const string& func,
                              const std::vector<const Node*>& nodes,
                              NodeDef* target) {
  if (nodes.empty() || target->has_experimental_debug_info()) {
    return;
  }
  for (const Node* node : nodes) {
    const auto& node_def = node->def();
    if (node_def.has_experimental_debug_info()) {
      target->mutable_experimental_debug_info()->MergeFrom(
          node_def.experimental_debug_info());
    } else {
      target->mutable_experimental_debug_info()->add_original_node_names(
          node_def.name());
      target->mutable_experimental_debug_info()->add_original_func_names(func);
    }
  }
}
}  // namespace

string InlineFunctionBodyOptions::DebugString() const {
  const auto true_false = [](bool b) { return b ? "true" : "false"; };

  const auto keep_caller_node_str = [this]() -> string {
    switch (keep_caller_node) {
      case KeepCallerNode::kDoNotKeep:
        return "DoNotKeep";
      case KeepCallerNode::kFetchable:
        return "Fetchable";
      case KeepCallerNode::kTargetable:
        return "Targetable";
    }
  };

  return absl::StrCat(
      "disable_inlining=", true_false(disable_inlining),
      ", ignore_noinline=", true_false(ignore_noinline),
      ", inline_impl_selection_group_functions=",
      true_false(inline_impl_selection_group_functions),
      ", keep_caller_node=", keep_caller_node_str(), ", output_control_src=",
      output_control_src == OutputControlSrc::kDataOutputs ? "DataOutputs"
                                                           : "ControlOutputs",
      ", inlined_function_body_placer=", inlined_function_body_placer.name,
      ", uniquify_frame_names=", true_false(uniquify_frame_names));
}

Status ValidateInlining(const Node* node, const FunctionBody* fbody,
                        const InlineFunctionBodyOptions& options) {
  // TODO(ezhulenev): Currently common_runtime function inlining can't
  // guarantee that all side-effectful ops will be executed after inlining.
  // See Grappler function_optimizer for details. Unify all function inlining
  // mechanisms. Do not inline if `!fbody->control_ret_nodes.empty()`.

  const auto num_node_inputs = static_cast<size_t>(node->num_inputs());
  const auto num_node_outputs = static_cast<size_t>(node->num_outputs());

  if (num_node_inputs != fbody->arg_types.size() ||
      num_node_inputs != fbody->arg_nodes.size()) {
    return errors::InvalidArgument(
        "Node inputs do not match function arguments: inputs=", num_node_inputs,
        " arg_types=", fbody->arg_types.size(),
        " arg_nodes=", fbody->arg_nodes.size());
  }

  if (num_node_outputs != fbody->ret_types.size() ||
      num_node_outputs != fbody->ret_nodes.size()) {
    return errors::InvalidArgument(
        "Node outputs do not match function returns: outputs=",
        num_node_outputs, " ret_types=", fbody->ret_types.size(),
        " ret_nodes=", fbody->ret_nodes.size());
  }

  for (int i = 0; i < node->num_inputs(); ++i) {
    if (node->input_type(i) != fbody->arg_types[i]) {
      return errors::InvalidArgument(
          "Node input type doesn't match function argument type: ",
          node->input_type(i), " != ", fbody->arg_types[i], " @ index=", i);
    }
  }
  for (int i = 0; i < node->num_outputs(); ++i) {
    if (node->output_type(i) != fbody->ret_types[i]) {
      return errors::InvalidArgument(
          "Node output type doesn't match function return type: ",
          node->output_type(i), " != ", fbody->ret_types[i], " @ index=", i);
    }
  }

  if (options.disable_inlining) {
    return errors::InvalidArgument(
        "Function inlining explicitly disabled by 'options.disable_inlining'");
  }

  if (!options.inline_impl_selection_group_functions) {
    bool is_impl_selection_group_function =
        fbody->fdef.attr().find("api_implements") != fbody->fdef.attr().end();
    if (is_impl_selection_group_function) {
      return errors::InvalidArgument(
          "Inlining of implementation selection group function ",
          fbody->fdef.signature().name(),
          " is disabled by options.inline_impl_selection_group_functions");
    }
  }

  if (!options.ignore_noinline) {
    TF_RETURN_IF_ERROR(ValidateNoInline(fbody));
  }

  return OkStatus();
}

// Function inlining must preserve function execution semantics with regard to
// side-effect visibility. TensorFlow in eager mode has an automatic control
// dependency tracking mechanism, which enforces a well-defined execution order
// of all side effects. Any other frontend (e.g. Swift) must produce graphs
// following the same rules, to ensure that function inlining works correctly.
//
// IMPORTANT: Currently we do not have a true notion of a "side-effectful"
// node; we assume that all stateful nodes might have side effects, though this
// is not true in practice, e.g. `ReadVariableOp` doesn't have an observable
// side effect.
//
// Automatic control dependency rules in TensorFlow 2.0 (Python in eager mode):
//
// 1) When a function has a resource (DT_RESOURCE data type) input argument it
//    "captures" the mutable resource. This is implemented by automatically
//    adding an incoming control edge from the previous side-effectful op
//    touching that resource, and an outgoing control edge to the next
//    side-effectful op using the same resource. This serializes the mutations
//    of the resource to make graph execution deterministic.
//
// 2) All stateful ops inside a function body are guaranteed to execute in
//    program order; this is achieved by adding control edges between stateful
//    ops at graph construction time. Stateful ops (or ops that must execute)
//    should be in the function control return set. Having a data edge to a
//    regular function output might not be enough, because after function
//    inlining the data output might be unused.
//
// 3) Furthermore, all ops accepting the same resource as an input are
//    guaranteed to run in program order. This is also done by adding control
//    edges at graph construction time. The last op touching the resource
//    must be in the control return set, which guarantees that all side
//    effects to the resource happen before function completion.
//
// Function inlining must preserve side-effect visibility:
//
// 1) All side effects to the captured resources that happened before the
//    function call must be visible to the function body nodes that use those
//    resources.
//
// 2) All side effects to the captured resources that happened inside the
//    function body must be visible to every op/function using that resource
//    after the function call has completed.
//
// To guarantee that these properties are preserved after inlining we:
//
// 1) Create an "input_control_node" NoOp. The function call node's incoming
//    control edges will be forwarded *to* this node. Function inputs (Identity
//    nodes) will have a control edge *from* this node. If the function body
//    has nodes without inputs, they will also have a control edge *from* this
//    node.
//
// 2) Create an "output_control_node" NoOp. All nodes that have an incoming
//    control edge *from* the function call node will instead receive a control
//    edge from this node.
//
// We have two options for choosing which nodes will have a control edge *to*
// the "output control node":
//   a) control returns (`control_ret` field in FunctionDef)
//   b) data returns (`ret` field in FunctionDef)
//
// We do (a) for multi-device function calls in TensorFlow v2 and (b) for the
// rest, for compatibility with TensorFlow v1.
//
// Following the automatic control dependency tracking rules, a node that
// has an incoming control edge from the function call node is dependent on
// the side effects happening inside the function body. The output control
// node guarantees the side-effect execution order.
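//
// As a rough illustration (schematic only: real node names are derived from
// the caller node name, and the exact edges depend on the options described
// above), inlining a call node `f` rewrites the surrounding control edges
// approximately as follows, where `..>` denotes a control edge:
//
//   before:  c ..> f ..> d
//   after:   c ..> f/input_control_node ..> (function body nodes)
//            (data outputs or control returns) ..> f/output_control_node ..> d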
//
// If the function call node doesn't have an outgoing control edge, it means
// that no one is interested in observing the side effects that might have
// happened.
//
// Function inlining might leave the graph in a partially-placed state. The
// caller of function inlining must run the Placer to guarantee that all nodes
// are placed.
//
// Function inlining with `options.override_device=true` will leave the graph
// in a fully placed state, by overriding all inlined node devices with the
// caller node device, but it makes the inlined functions always single-device.
// After inlining, these functions will not be able to handle resources on
// multiple devices. This is currently acceptable for XLA use cases (an XLA
// cluster is always executed on a single device).
//
// TODO(ezhulenev): Documentation above is ahead of implementation below.
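//
// Illustrative usage (a sketch only; exact option values depend on the call
// site, and `flib_def`, `graph`, `caller` and `fbody` are assumed to be
// provided by the caller):
//
//   InlineFunctionBodyOptions opts;
//   opts.output_control_src =
//       InlineFunctionBodyOptions::OutputControlSource::kControlOutputs;
//   opts.keep_caller_node =
//       InlineFunctionBodyOptions::KeepCallerNode::kFetchable;
//   TF_RETURN_IF_ERROR(
//       InlineFunctionBody(flib_def, graph, caller, fbody, opts));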
Status InlineFunctionBody(const FunctionLibraryDefinition& flib_def, Graph* g,
                          Node* caller, const FunctionBody* fbody,
                          const InlineFunctionBodyOptions& options) {
  VLOG(3) << "Inline function call: " << SummarizeNode(*caller) << " ["
          << options.DebugString() << "]";
  VLOG(4) << "Inlining function: " << fbody->fdef.DebugString();
  VLOG(4) << "Current graphdef: " << g->ToGraphDefDebug().DebugString();
  VLOG(4) << "Caller: " << caller->DebugString();

  Status validation = ValidateInlining(caller, fbody, options);
  if (!validation.ok()) {
    return errors::Internal("Inlining mismatch: ", validation.error_message());
  }

  // Placer is responsible for assigning devices for all nodes that we will add
  // to the graph.
  const std::unique_ptr<InlinedFunctionBodyPlacer> placer =
      options.inlined_function_body_placer.get(*g, *caller);

  // We can't possibly introduce a duplicate control edge during function
  // inlining, so we skip this check in calls to 'g->AddControlEdge(...)'.
  static constexpr bool kDoNotCheckDuplicates = true;

  // ------------------------------------------------------------------------ //
  // Helper functions to create `NoOp` and `Identity` nodes for auxiliary
  // control nodes and inlined function inputs and outputs.

  // Add a NoOp node for function control inputs/outputs.
  const auto no_op = [&](StringPiece name) -> Node* {
    Node* node = AddNoOp(absl::StrCat(caller->name(), "/", name), g);
    const absl::optional<string> device = placer->ControlNodeDevice();
    if (device.has_value()) node->set_requested_device(*device);
    return node;
  };

  // Add an Identity node for function input.
  const auto input_identity = [&](StringPiece name, Endpoint input,
                                  int index) -> Node* {
    Node* node = AddIdentity(absl::StrCat(caller->name(), "/", name), g, input);
    const absl::optional<string> device = placer->InputNodeDevice(index);
    if (device.has_value()) node->set_requested_device(*device);
    bool colocate_identity = placer->ColocateInputOutputIdentities();
    if (colocate_identity) {
      node->AddAttr(kColocationAttrName,
                    std::vector<string>{absl::StrCat(kColocationGroupPrefix,
                                                     input.node->name())});
    }
    return node;
  };

  // Add an Identity node for function output.
  const auto output_identity = [&](StringPiece name, Endpoint input,
                                   int index) -> Node* {
    Node* node = AddIdentity(absl::StrCat(caller->name(), "/", name), g, input);
    const absl::optional<string> device = placer->OutputNodeDevice(index);
    if (device.has_value()) node->set_requested_device(*device);
    bool colocate_identity = placer->ColocateInputOutputIdentities();
    if (colocate_identity) {
      node->AddAttr(kColocationAttrName,
                    std::vector<string>{absl::StrCat(kColocationGroupPrefix,
                                                     input.node->name())});
    }
    return node;
  };

  // ------------------------------------------------------------------------ //
  // Helper function to get an input/output argument name by index. For
  // functions instantiated from SymbolicGradient the corresponding FunctionDef
  // is empty, and the argument name is unknown.

  auto arg_name = [&](auto& args, size_t i) -> absl::string_view {
    if (i < args.size()) {
      return args[i].name();
    } else {
      return "<unknown>";
    }
  };

  // ------------------------------------------------------------------------ //
  // Input edges. For data edges coming into "caller", we first compute the
  // <src>:<src_output> for the i-th input in "inputs".
  // If "caller" has any input control dependencies, we add a NoOp
  // node "input_control_node", which depends on "caller"'s control inputs.
  std::vector<Endpoint> inputs(caller->num_inputs());
  Node* input_control_node = nullptr;
  for (const Edge* e : caller->in_edges()) {
    if (e->IsControlEdge()) {
      if (input_control_node == nullptr) {
        input_control_node = no_op("input_control_node");
      }
      g->AddControlEdge(e->src(), input_control_node, kDoNotCheckDuplicates);
    } else {
      inputs[e->dst_input()] = {e->src(), e->src_output()};
    }
  }
  if (input_control_node != nullptr) {
    VLOG(3) << "Created input control node: " << input_control_node->name();
  }

  // We create one Identity node for each input.
  std::vector<Node*> input_nodes;
  std::map<absl::string_view, absl::string_view> input_node_name_map;
  for (std::size_t i = 0; i < fbody->arg_nodes.size(); ++i) {
    Node* n = input_identity("input", inputs[i], i);
    input_node_name_map[arg_name(fbody->fdef.signature().input_arg(), i)] =
        n->name();
    input_nodes.push_back(n);
  }

  // ------------------------------------------------------------------------ //
  // Duplicate fbody->graph into 'g'. First, we copy the nodes of
  // fbody->graph into 'g' except the source and sink nodes. We copy
  // edges among nodes in 'fbody->graph'.
  //
  // If 'x' is a node in fbody->graph and its copy in 'g' is 'y', we
  // remember 'y' in node_map[x->id()].
  std::unordered_set<string> fn_nodes;
  for (Node* n : fbody->graph->op_nodes()) {
    fn_nodes.insert(n->name());
  }
  std::vector<Node*> node_map(fbody->graph->num_node_ids());
  for (Node* n : fbody->graph->op_nodes()) {
    NodeDef ndef = n->def();

    // Maybe override requested node device assignment.
    const absl::optional<string> device = placer->BodyNodeDevice(ndef);
    if (device.has_value()) ndef.set_device(*device);

    // Add inlined function name to inlined node debug information.
    PropagateDebugInfoToNode(fbody->fdef.signature().name(), {n}, &ndef);

    // Add the function node name as a prefix:
    // 1) to node name to avoid collisions
    // 2) to frame name to avoid multiple LoopCond nodes in one frame
    // 3) to colocation attribute
    const string prefix = strings::StrCat(caller->name(), "/");
    TF_RETURN_IF_ERROR(AddPrefixAndSuffixToNode(prefix, /*suffix=*/"", &ndef,
                                                options.uniquify_frame_names));

    // If the colocation attribute is an input arg, we need to change it to the
    // new input (Identity) node now.
    TF_RETURN_IF_ERROR(
        MaybeUpdateColocationConstraintsWithMap(input_node_name_map, &ndef));

    TF_RETURN_IF_ERROR(
        MaybeAddPrefixToColocationConstraints(fn_nodes, prefix, &ndef));

    Status added_node;
    Node* clone = g->AddNode(std::move(ndef), &added_node);
    TF_CHECK_OK(added_node);
    node_map[n->id()] = clone;
    clone->SetStackTrace(n->GetStackTrace());

    // If there is an input control node, and one of:
    // a) the node has no data or control inputs, or
    // b) the node is a function call (including SymbolicGradient),
    // then add a control edge from the input control node to the clone (only
    // if it does not already have a control input).
    //
    // We must not execute any nodes if the original function call would not
    // have executed. This is especially critical when the function call is
    // inside a control-flow construct like tf.cond(). Case (a) ensures that
    // such nodes do not run.
    //
    // The purpose of case (b) is to ensure that instances of case (a) created
    // by further inlining steps also receive the control dependency.
    //
    // This edge is required to transfer execution frame down to all function
    // body nodes of inlined nested function calls.
    if (input_control_node) {
      const auto is_input_edge = [](const Edge* e) -> bool {
        return !e->src()->IsSource();
      };
      const auto is_control_edge = [](const Edge* e) -> bool {
        return !e->src()->IsSource() && e->IsControlEdge();
      };

      // Forward execution frame if:
      //
      // a) The node has no data or control inputs.
      // b) OR the node is a function call without control inputs (control edge
      //    will be used in nested function inlining to forward execution frame
      //    to constants inside the function body).
      //
      // c) Do not forward control frame to function argument nodes, they will
      //    be connected to the corresponding function input later.
      const bool forward_execution_frame =
          (absl::c_none_of(n->in_edges(), is_input_edge) ||       // (a)
           (n->IsFunctionCall() &&                                // (b)
            absl::c_none_of(n->in_edges(), is_control_edge))) &&  //
          !n->IsArg();                                            // (c)

      if (forward_execution_frame) {
        VLOG(4) << "Add control edge from input control node to: "
                << clone->name();
        g->AddControlEdge(input_control_node, clone, kDoNotCheckDuplicates);
      }
    }
  }
  for (const Edge* e : fbody->graph->edges()) {
    if (e->src()->IsSource() || e->src()->IsSink() || e->dst()->IsSource() ||
        e->dst()->IsSink()) {
      continue;
    }
    Node* src_copy = node_map[e->src()->id()];
    Node* dst_copy = node_map[e->dst()->id()];
    g->AddEdge(src_copy, e->src_output(), dst_copy, e->dst_input());
  }

  // ------------------------------------------------------------------------ //
  // Connect input edges.
  //
  // We connect inputs[i] to the i-th Identity node added above. The nodes that
  // were previously connected to the output of the i-th arg node are
  // reconnected to the i-th Identity node.
  //
  // The added Identity nodes depend on "input_control_node".
  VLOG(4) << "Add input Identity nodes for each function argument:";

  for (std::size_t i = 0; i < fbody->arg_nodes.size(); ++i) {
    Node* arg = node_map[fbody->arg_nodes[i]->id()];
    Node* n = input_nodes[i];
    VLOG(4) << " [index " << i << "] "
            << arg_name(fbody->fdef.signature().input_arg(), i) << " as "
            << n->name() << " (input: " << inputs[i].name()
            << ", requested_device: " << n->requested_device() << ")";

    if (input_control_node) {
      g->AddControlEdge(input_control_node, n, kDoNotCheckDuplicates);
    }
    for (const Edge* e : arg->out_edges()) {
      if (e->IsControlEdge()) {
        g->AddControlEdge(n, e->dst(), kDoNotCheckDuplicates);
      } else {
        g->AddEdge(n, 0, e->dst(), e->dst_input());
      }
    }
    node_map[fbody->arg_nodes[i]->id()] = n;
    g->RemoveNode(arg);  // 'arg' is disconnected.
  }

  // ------------------------------------------------------------------------ //
  // Connect output edges.
  //
  // For the i-th return node in fbody->graph, we add an Identity node
  // (outputs[i]) in "g", and reconnect every incoming edge of the i-th return
  // node to that Identity node.
  //
  // For every data edge coming out of "callee"'s i-th output, we reconnect it
  // to the i-th Identity node added above.
  //
  // If "callee" is control-depended upon by any other nodes, we add a NoOp
  // node "output_control_node". "output_control_node" depends on all Identity
  // nodes added above, or on all control return nodes (controlled by the
  // `options.output_control_src` value), and nodes that previously depended on
  // "callee" are changed to depend on "output_control_node".
  //
  // If the caller node is kept (see `options.keep_caller_node`), we always add
  // an output control node, to guarantee that executing the kept node will
  // execute all side effects.
  VLOG(4) << "Add output Identity nodes for each function output argument:";

  std::vector<Node*> outputs(caller->num_outputs());
  for (std::size_t i = 0; i < fbody->ret_nodes.size(); ++i) {
    Node* ret = node_map[fbody->ret_nodes[i]->id()];
    Endpoint data;  // Data input for the ret node.
    for (const Edge* e : ret->in_edges()) {
      if (!e->IsControlEdge()) {
        data = {e->src(), e->src_output()};
        break;
      }
    }
    CHECK(data.node != nullptr);
    Node* n = output_identity("output", data, i);
    outputs[i] = n;
    VLOG(4) << " [index " << i << "] "
            << arg_name(fbody->fdef.signature().output_arg(), i) << " as "
            << n->name() << " (ret: " << data.node->name() << ":" << data.index
            << ", requested_device: " << n->requested_device() << ")";
    for (const Edge* e : ret->in_edges()) {
      if (e->IsControlEdge()) {
        g->AddControlEdge(e->src(), n, kDoNotCheckDuplicates);
      }
    }
    g->RemoveNode(ret);  // 'ret' is disconnected.
  }

  Node* output_control_node = nullptr;
  const bool has_control_outputs = absl::c_any_of(
      caller->out_edges(), [](const Edge* e) { return e->IsControlEdge(); });

  using KeepCallerNode = InlineFunctionBodyOptions::KeepCallerNode;
  const bool keep_caller_node =
      options.keep_caller_node == KeepCallerNode::kFetchable ||
      options.keep_caller_node == KeepCallerNode::kTargetable;

  if (has_control_outputs || keep_caller_node) {
    output_control_node = no_op("output_control_node");
    VLOG(4) << "Add output control node: " << output_control_node->name();
    if (options.output_control_src == OutputControlSrc::kDataOutputs) {
      for (Node* n : outputs) {
        VLOG(4) << " [data output] add control edge from: " << n->name();
        g->AddControlEdge(n, output_control_node, kDoNotCheckDuplicates);
      }
    } else {
      for (Node* fbody_node : fbody->control_ret_nodes) {
        Node* n = node_map[fbody_node->id()];
        VLOG(4) << " [control output] add control edge from: " << n->name();
        g->AddControlEdge(n, output_control_node, kDoNotCheckDuplicates);
      }
    }
  }

  // We can't leave the output control node without incoming control edges,
  // because in that case the outgoing control edges would lose execution frame
  // information. We connect input_control_node and output_control_node with a
  // control edge to forward the execution frame to the controlled nodes. Above
  // we added a control edge to all function calls inside the function body, to
  // guarantee that we will always have an input_control_node when we need it.
  if (output_control_node && output_control_node->in_edges().empty()) {
    if (input_control_node) {
      VLOG(4) << "Add a control edge between input and output control nodes: "
              << input_control_node->name() << " to "
              << output_control_node->name();
      g->AddControlEdge(input_control_node, output_control_node,
                        kDoNotCheckDuplicates);
    } else {
      VLOG(4) << "Function inlining potentially dropped execution frame "
                 "information from outgoing control edges.";
    }
  }

  for (const Edge* e : caller->out_edges()) {
    if (e->IsControlEdge()) {
      g->AddControlEdge(output_control_node, e->dst(), kDoNotCheckDuplicates);
    } else {
      g->AddEdge(outputs[e->src_output()], 0, e->dst(), e->dst_input());
    }
  }

  // ------------------------------------------------------------------------ //
  // Add an IdentityN or NoOp node in place of the caller node to keep `caller`
  // fetchable or targetable.

  if (keep_caller_node) {
    std::vector<NodeBuilder::NodeOut> output_tensors;
    absl::c_transform(outputs, std::back_inserter(output_tensors),
                      [](Node* n) { return NodeBuilder::NodeOut(n, 0); });

    Node* caller_substitute_node;
    if (options.keep_caller_node == KeepCallerNode::kTargetable ||
        output_tensors.empty()) {
      // An IdentityN node must have at least one data input. If the function
      // has no data outputs, we can't keep it fetchable.
      TF_CHECK_OK(NodeBuilder(caller->name(), "NoOp")
                      .Device(caller->requested_device())
                      .ControlInput(output_control_node)
                      .Finalize(g, &caller_substitute_node));

    } else if (options.keep_caller_node == KeepCallerNode::kFetchable) {
      TF_CHECK_OK(NodeBuilder(caller->name(), "IdentityN")
                      .Device(caller->requested_device())
                      .Input(output_tensors)
                      .ControlInput(output_control_node)
                      .Finalize(g, &caller_substitute_node));
    }
  }

  // ------------------------------------------------------------------------ //
  // 'caller' is replaced with the inlined function body nodes, and maybe an
  // IdentityN (or NoOp) node to keep it fetchable (or targetable).
  VLOG(3) << "Successfully inlined function call node: " << caller->name();
  g->RemoveNode(caller);

  VLOG(4) << "Final graph: " << g->ToGraphDefDebug().DebugString();

  return OkStatus();
}

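// ExpandInlineFunctions returns true if at least one function call was
// inlined. A caller that wants nested function calls inlined as well would
// typically run it to a fixed point, e.g. (illustrative sketch only; `lib`,
// `graph` and `opts` are assumed to be provided by the caller):
//
//   ExpandInlineFunctionsOptions opts;
//   while (ExpandInlineFunctions(lib, graph, opts)) {
//   }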
bool ExpandInlineFunctions(FunctionLibraryRuntime* lib, Graph* graph,
                           const ExpandInlineFunctionsOptions& options) {
  std::vector<std::pair<Node*, const FunctionBody*>> candidates;

  const FunctionLibraryDefinition* fld = lib->GetFunctionLibraryDefinition();

  for (Node* node : graph->nodes()) {
    // Skip nodes that are not function calls or SymbolicGradient calls.
    if (!IsFunctionCall(*lib->GetFunctionLibraryDefinition(), *node)) {
      continue;
    }
    // Skip function calls that are marked noinline.
    bool noinline;
    if (fld->GetAttr(*node, kNoInlineAttr, &noinline).ok() && noinline) {
      VLOG(3) << "noinline: " << SummarizeNode(*node);
      continue;
    }
    FunctionLibraryRuntime::Handle handle;
    Status s = InstantiateFunctionCall(node->def(), lib, &handle);
    if (!s.ok()) {
      LOG(ERROR) << "Failed to instantiate a function: " << s.error_message();
      continue;
    }
    const FunctionBody* fbody = lib->GetFunctionBody(handle);
    CHECK_NOTNULL(fbody);
    candidates.emplace_back(node, fbody);
  }

  bool inlined_any = false;
  for (const auto& p : candidates) {
    Status inlined = InlineFunctionBody(*fld, graph, p.first, p.second,
                                        p.first->IsPartitionedCall()
                                            ? options.multi_device_options
                                            : options.native_options);
    if (inlined.ok()) {
      inlined_any = true;
    } else {
      VLOG(1) << "Failed to inline function call: node=" << p.first->name()
              << " error=" << inlined.error_message();
    }
  }

  // TODO(ezhulenev): Release handles for inlined function calls.

  return inlined_any;
}

}  // end namespace tensorflow