1 | /* Copyright 2015 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | |
16 | #include "tensorflow/core/common_runtime/function_def_utils.h" |
17 | |
18 | #include <vector> |
19 | |
20 | #include "tensorflow/core/common_runtime/function_body.h" |
21 | #include "tensorflow/core/common_runtime/graph_constructor.h" |
22 | #include "tensorflow/core/framework/function.h" |
23 | #include "tensorflow/core/framework/node_def_util.h" |
24 | #include "tensorflow/core/graph/control_flow.h" |
25 | #include "tensorflow/core/graph/graph.h" |
26 | |
27 | namespace tensorflow { |
28 | |
29 | Status FunctionDefToBodyHelper( |
30 | const FunctionDef& fdef, const AttrSlice& attrs, |
31 | const FunctionLibraryDefinition* const lib_def, |
32 | const std::function<Status(const string&, const OpDef**)>& get_func_sig, |
33 | std::unique_ptr<FunctionBody>* fbody) { |
34 | // Instantiates the function template into a graph def. |
35 | InstantiationResult result; |
36 | TF_RETURN_IF_ERROR(InstantiateFunction(fdef, attrs, get_func_sig, &result)); |
37 | |
38 | auto graph = std::make_unique<Graph>(lib_def); |
39 | |
40 | auto construction_context_iter = fdef.attr().find("_construction_context" ); |
41 | if (construction_context_iter != fdef.attr().end()) { |
42 | if (construction_context_iter->second.s() == "kEagerRuntime" ) { |
43 | graph->SetConstructionContext(ConstructionContext::kEagerRuntime); |
44 | } else { |
45 | DCHECK(false) << "Unknown _construction_context attribute: " |
46 | << construction_context_iter->second.s(); |
47 | } |
48 | } |
49 | |
50 | GraphConstructorOptions opts; |
51 | opts.allow_internal_ops = true; |
52 | opts.expect_device_spec = false; |
53 | TF_RETURN_IF_ERROR(ConvertNodeDefsToGraph(opts, result.nodes, graph.get())); |
54 | |
55 | const StackTracesMap& stack_traces = |
56 | lib_def->GetStackTraces(fdef.signature().name()); |
57 | for (Node* n : graph->nodes()) { |
58 | if (n) { |
59 | auto it = stack_traces.find(n->name()); |
60 | if (it != stack_traces.end()) { |
61 | n->SetStackTrace(it->second); |
62 | } |
63 | } |
64 | } |
65 | |
66 | // Call BuildControlFlowInfo to validate that this function body has |
67 | // well-formed control flow. |
68 | std::vector<ControlFlowInfo> dummy; |
69 | TF_RETURN_IF_ERROR(BuildControlFlowInfo(graph.get(), &dummy)); |
70 | |
71 | *fbody = std::make_unique<FunctionBody>(fdef, result.arg_types, |
72 | result.ret_types, graph.release()); |
73 | return OkStatus(); |
74 | } |
75 | |
76 | Status FunctionDefToBodyHelper(const FunctionDef& fdef, const AttrSlice& attrs, |
77 | const FunctionLibraryDefinition* lib_def, |
78 | std::unique_ptr<FunctionBody>* fbody) { |
79 | const auto get_func_sig = [&lib_def](const string& op, const OpDef** sig) { |
80 | return lib_def->LookUpOpDef(op, sig); |
81 | }; |
82 | return FunctionDefToBodyHelper(fdef, attrs, lib_def, get_func_sig, fbody); |
83 | } |
84 | |
85 | } // end namespace tensorflow |
86 | |