1 | /* Copyright 2019 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | |
16 | #include "tensorflow/core/common_runtime/lower_function_call_op.h" |
17 | |
18 | #include "absl/algorithm/container.h" |
19 | #include "tensorflow/core/common_runtime/function_def_utils.h" |
20 | #include "tensorflow/core/common_runtime/inline_function_utils.h" |
21 | #include "tensorflow/core/common_runtime/lower_function_call_inline_policy.h" |
22 | #include "tensorflow/core/framework/node_def_util.h" |
23 | #include "tensorflow/core/graph/graph.h" |
24 | #include "tensorflow/core/graph/graph_node_util.h" |
25 | #include "tensorflow/core/platform/errors.h" |
26 | |
27 | namespace tensorflow { |
28 | |
29 | using KeepCallerNode = InlineFunctionBodyOptions::KeepCallerNode; |
30 | using OutputControlSrc = InlineFunctionBodyOptions::OutputControlSource; |
31 | |
32 | Status RewriteFunctionCallNode(Node* n, Graph* g, |
33 | const FunctionLibraryDefinition& flib_def, |
34 | bool keep_caller_fetchable) { |
35 | VLOG(2) << "Lower function call node: " << SummarizeNode(*n); |
36 | |
37 | // We support lowering of two types of functions that could be invoked by the |
38 | // node `n`: 1) native functions and 2) multi-device functions. |
39 | // NOTE(ezhulenev): We explicitly choose not to deal with SymbolicGradient, |
40 | // because it has been deprecated for a long time. |
41 | InlineFunctionBodyOptions inline_options; |
42 | inline_options.keep_caller_node = keep_caller_fetchable |
43 | ? KeepCallerNode::kFetchable |
44 | : KeepCallerNode::kTargetable; |
45 | |
46 | FunctionCallInlinePolicy policy = GetFunctionCallInlinePolicy(n); |
47 | if (policy == FunctionCallInlinePolicy::kMultiDevicePlacer) { |
48 | // Multi-device function calls (PartitionedCall or StatefulPartitionedCall |
49 | // ops) can execute on multiple devices and accept DT_RESOURCE inputs that |
50 | // belong to different devices. This type of functions was added in |
51 | // Tensorflow 2.0 Eager mode, and it has control outputs to represent |
52 | // side-effects that must always execute (see `control_ret` in FunctionDef). |
53 | inline_options.output_control_src = OutputControlSrc::kControlOutputs; |
54 | inline_options.inlined_function_body_placer = |
55 | InlinedFunctionBodyPlacer::MultiDevice(); |
56 | } else if (policy == FunctionCallInlinePolicy::kSingleDevicePlacer) { |
57 | // Native function call (node.type_string() is the function name). These |
58 | // functions are always executed on a single-device, which is the device of |
59 | // the function call node. |
60 | inline_options.output_control_src = OutputControlSrc::kDataOutputs; |
61 | inline_options.inlined_function_body_placer = |
62 | InlinedFunctionBodyPlacer::SingleDevice(); |
63 | } else { |
64 | return errors::InvalidArgument("Unsupported function inlining policy" ); |
65 | } |
66 | |
67 | const FunctionDef* fdef; |
68 | if (n->IsPartitionedCall()) { |
69 | NameAttrList func; |
70 | TF_RETURN_IF_ERROR(GetNodeAttr(n->attrs(), "f" , &func)); |
71 | fdef = flib_def.Find(func.name()); |
72 | } else if (n->type_string() == FunctionLibraryDefinition::kGradientOp) { |
73 | VLOG(2) << "Skip SymbolicGradient lowering" ; |
74 | return OkStatus(); |
75 | } else { |
76 | fdef = flib_def.Find(n->type_string()); |
77 | } |
78 | |
79 | if (fdef == nullptr) { |
80 | return errors::Internal("Can't find a function: node=" , SummarizeNode(*n)); |
81 | } |
82 | |
83 | std::unique_ptr<FunctionBody> fbody; |
84 | TF_RETURN_IF_ERROR( |
85 | FunctionDefToBodyHelper(*fdef, n->attrs(), &flib_def, &fbody)); |
86 | |
87 | Status can_inline_function_call = |
88 | ValidateInlining(n, fbody.get(), inline_options); |
89 | if (can_inline_function_call.ok()) { |
90 | TF_RETURN_IF_ERROR( |
91 | InlineFunctionBody(flib_def, g, n, fbody.get(), inline_options)); |
92 | } else { |
93 | VLOG(2) << "Failed to inline function call node: " |
94 | << can_inline_function_call.error_message(); |
95 | } |
96 | |
97 | return OkStatus(); |
98 | } |
99 | |
100 | } // namespace tensorflow |
101 | |