1/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2
3Licensed under the Apache License, Version 2.0 (the "License");
4you may not use this file except in compliance with the License.
5You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9Unless required by applicable law or agreed to in writing, software
10distributed under the License is distributed on an "AS IS" BASIS,
11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12See the License for the specific language governing permissions and
13limitations under the License.
14==============================================================================*/
15
16#include "tensorflow/core/common_runtime/lower_function_call_op.h"
17
18#include "absl/algorithm/container.h"
19#include "tensorflow/core/common_runtime/function_def_utils.h"
20#include "tensorflow/core/common_runtime/inline_function_utils.h"
21#include "tensorflow/core/common_runtime/lower_function_call_inline_policy.h"
22#include "tensorflow/core/framework/node_def_util.h"
23#include "tensorflow/core/graph/graph.h"
24#include "tensorflow/core/graph/graph_node_util.h"
25#include "tensorflow/core/platform/errors.h"
26
27namespace tensorflow {
28
29using KeepCallerNode = InlineFunctionBodyOptions::KeepCallerNode;
30using OutputControlSrc = InlineFunctionBodyOptions::OutputControlSource;
31
32Status RewriteFunctionCallNode(Node* n, Graph* g,
33 const FunctionLibraryDefinition& flib_def,
34 bool keep_caller_fetchable) {
35 VLOG(2) << "Lower function call node: " << SummarizeNode(*n);
36
37 // We support lowering of two types of functions that could be invoked by the
38 // node `n`: 1) native functions and 2) multi-device functions.
39 // NOTE(ezhulenev): We explicitly choose not to deal with SymbolicGradient,
40 // because it has been deprecated for a long time.
41 InlineFunctionBodyOptions inline_options;
42 inline_options.keep_caller_node = keep_caller_fetchable
43 ? KeepCallerNode::kFetchable
44 : KeepCallerNode::kTargetable;
45
46 FunctionCallInlinePolicy policy = GetFunctionCallInlinePolicy(n);
47 if (policy == FunctionCallInlinePolicy::kMultiDevicePlacer) {
48 // Multi-device function calls (PartitionedCall or StatefulPartitionedCall
49 // ops) can execute on multiple devices and accept DT_RESOURCE inputs that
50 // belong to different devices. This type of functions was added in
51 // Tensorflow 2.0 Eager mode, and it has control outputs to represent
52 // side-effects that must always execute (see `control_ret` in FunctionDef).
53 inline_options.output_control_src = OutputControlSrc::kControlOutputs;
54 inline_options.inlined_function_body_placer =
55 InlinedFunctionBodyPlacer::MultiDevice();
56 } else if (policy == FunctionCallInlinePolicy::kSingleDevicePlacer) {
57 // Native function call (node.type_string() is the function name). These
58 // functions are always executed on a single-device, which is the device of
59 // the function call node.
60 inline_options.output_control_src = OutputControlSrc::kDataOutputs;
61 inline_options.inlined_function_body_placer =
62 InlinedFunctionBodyPlacer::SingleDevice();
63 } else {
64 return errors::InvalidArgument("Unsupported function inlining policy");
65 }
66
67 const FunctionDef* fdef;
68 if (n->IsPartitionedCall()) {
69 NameAttrList func;
70 TF_RETURN_IF_ERROR(GetNodeAttr(n->attrs(), "f", &func));
71 fdef = flib_def.Find(func.name());
72 } else if (n->type_string() == FunctionLibraryDefinition::kGradientOp) {
73 VLOG(2) << "Skip SymbolicGradient lowering";
74 return OkStatus();
75 } else {
76 fdef = flib_def.Find(n->type_string());
77 }
78
79 if (fdef == nullptr) {
80 return errors::Internal("Can't find a function: node=", SummarizeNode(*n));
81 }
82
83 std::unique_ptr<FunctionBody> fbody;
84 TF_RETURN_IF_ERROR(
85 FunctionDefToBodyHelper(*fdef, n->attrs(), &flib_def, &fbody));
86
87 Status can_inline_function_call =
88 ValidateInlining(n, fbody.get(), inline_options);
89 if (can_inline_function_call.ok()) {
90 TF_RETURN_IF_ERROR(
91 InlineFunctionBody(flib_def, g, n, fbody.get(), inline_options));
92 } else {
93 VLOG(2) << "Failed to inline function call node: "
94 << can_inline_function_call.error_message();
95 }
96
97 return OkStatus();
98}
99
100} // namespace tensorflow
101