1 | /* Copyright 2021 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | |
16 | #include "tensorflow/core/function/runtime_client.h" |
17 | |
18 | #include <memory> |
19 | #include <string> |
20 | #include <utility> |
21 | #include <vector> |
22 | |
23 | #include "absl/strings/str_cat.h" |
24 | #include "absl/types/span.h" |
25 | #include "mlir/IR/MLIRContext.h" // from @llvm-project |
26 | #include "mlir/Pass/PassManager.h" // from @llvm-project |
27 | #include "mlir/Pass/PassRegistry.h" // from @llvm-project |
28 | #include "mlir/Support/LogicalResult.h" // from @llvm-project |
29 | #include "tensorflow/c/eager/abstract_tensor_handle.h" |
30 | #include "tensorflow/c/eager/immediate_execution_context.h" |
31 | #include "tensorflow/c/eager/immediate_execution_operation.h" |
32 | #include "tensorflow/c/eager/immediate_execution_tensor_handle.h" |
33 | #include "tensorflow/compiler/mlir/tensorflow/utils/error_util.h" |
34 | #include "tensorflow/core/common_runtime/device_mgr.h" |
35 | #include "tensorflow/core/common_runtime/eager/context.h" |
36 | #include "tensorflow/core/framework/device.h" |
37 | #include "tensorflow/core/framework/device_factory.h" |
38 | #include "tensorflow/core/framework/function.pb.h" |
39 | #include "tensorflow/core/framework/graph.pb.h" |
40 | #include "tensorflow/core/framework/op_def.pb.h" |
41 | #include "tensorflow/core/ir/importexport/graphdef_export.h" |
42 | #include "tensorflow/core/ir/importexport/graphdef_import.h" |
43 | #include "tensorflow/core/ir/ops.h" |
44 | #include "tensorflow/core/platform/errors.h" |
45 | #include "tensorflow/core/platform/status.h" |
46 | #include "tensorflow/core/platform/statusor.h" |
47 | #include "tensorflow/core/platform/stringpiece.h" |
48 | #include "tensorflow/core/platform/types.h" |
49 | #include "tensorflow/core/protobuf/error_codes.pb.h" |
50 | #include "tensorflow/core/public/session_options.h" |
51 | |
52 | namespace tensorflow { |
53 | namespace core { |
54 | namespace function { |
55 | |
56 | EagerContext& GlobalEagerContext() { |
57 | static EagerContext* global_ctx = []() { |
58 | SessionOptions opts; |
59 | std::vector<std::unique_ptr<Device>> devices; |
60 | Status&& device_init_status = DeviceFactory::AddDevices( |
61 | opts, "/job:localhost/replica:0/task:0" , &devices); |
62 | CHECK(device_init_status.ok()); // Crash OK |
63 | |
64 | return new EagerContext( |
65 | opts, ContextDevicePlacementPolicy::DEVICE_PLACEMENT_SILENT, |
66 | /*async=*/false, |
67 | /*device_mgr=*/new DynamicDeviceMgr(std::move(devices)), |
68 | /*device_mgr_owned=*/true, |
69 | /*rendezvous=*/nullptr, |
70 | /*cluster_flr=*/nullptr, |
71 | /*collective_executor_mgr=*/nullptr, |
72 | /*run_eager_op_as_function=*/true); |
73 | }(); |
74 | return *global_ctx; |
75 | } |
76 | |
77 | EagerContext& GlobalPythonEagerContext() { |
78 | EagerContext* ctx = reinterpret_cast<EagerContext*>(GetCEagerContext()); |
79 | DCHECK(ctx) << "The Python eager context must be initialized first." ; |
80 | return *ctx; |
81 | } |
82 | |
83 | StatusOr<FunctionDef> Runtime::GetFunctionProto(StringPiece name) { |
84 | EagerContext& ctx = this->eager_ctx_; |
85 | |
86 | const FunctionDef* f = ctx.FindFunctionDef(std::string(name)); |
87 | if (f == nullptr) { |
88 | return Status(error::INVALID_ARGUMENT, |
89 | absl::StrCat("Could not find an attribute for key " , name)); |
90 | } |
91 | |
92 | return *f; |
93 | } |
94 | |
95 | Status Runtime::CreateFunction(const FunctionDef& fdef) { |
96 | const auto& fname = fdef.signature().name(); |
97 | if (this->eager_ctx_.FindFunctionByName(fname)) { |
98 | TF_RETURN_WITH_CONTEXT_IF_ERROR(this->eager_ctx_.RemoveFunction(fname), |
99 | "removing function " , fname); |
100 | } |
101 | return this->eager_ctx_.AddFunctionDef(fdef); |
102 | } |
103 | |
104 | Status Runtime::CreateFunction(OpaqueTfgGraphFuncOp* fop) { |
105 | mlir::tfg::GraphFuncOp fop_proper = |
106 | *reinterpret_cast<mlir::tfg::GraphFuncOp*>(fop); |
107 | return mlir::tfg::ConvertToFunctionDef(fop_proper, |
108 | *this->eager_ctx_.FuncLibDef()); |
109 | } |
110 | |
Status Runtime::TransformFunction(StringPiece name, StringPiece pipeline_name) {
  // Applies the MLIR pass pipeline named `pipeline_name` to the function
  // registered as `name`, then re-registers the resulting function(s) with
  // the eager context. Returns INVALID_ARGUMENT if the pipeline cannot be
  // parsed or fails to run, and propagates lookup/import/registration errors.
  // TODO(mdan): Use a longer-lived context.
  mlir::MLIRContext ctx;
  mlir::PassManager pm(&ctx);

  std::string error;
  llvm::raw_string_ostream error_stream(error);
  // StringPiece doesn't seem to always be compatible with StringRef.
  if (mlir::failed(mlir::parsePassPipeline(std::string(pipeline_name), pm,
                                           error_stream))) {
    return Status(error::INVALID_ARGUMENT,
                  absl::StrCat("locating pass pipeline " , pipeline_name, ": " ,
                               error_stream.str()));
  }

  // For now, we roundtrip from proto. Once we have a permanent MLIR
  // representation, we should be able to use it directly.
  auto fn = GetFunctionProto(name);
  TF_RETURN_WITH_CONTEXT_IF_ERROR(fn.status(), "loading function " , name);

  // Wrap the single FunctionDef in a GraphDef's function library so the
  // generic tfg GraphDef importer can consume it.
  GraphDef graph;
  *graph.mutable_library()->add_function() = *fn;
  tensorflow::GraphDebugInfo debug_info;
  auto mlir_fn = mlir::tfg::ImportGraphDef(&ctx, debug_info, graph);
  TF_RETURN_WITH_CONTEXT_IF_ERROR(mlir_fn.status(), "importing function " ,
                                  name);

  // While in scope, the handler collects MLIR diagnostics emitted by the
  // pass run so they can be folded into the returned Status.
  mlir::StatusScopedDiagnosticHandler diagnostics_handler(&ctx);
  if (failed(pm.run(mlir_fn->get()))) {
    return diagnostics_handler.Combine(
        Status(error::INVALID_ARGUMENT,
               absl::StrCat("running pass pipeline " , pipeline_name, ": " )));
  }

  // Re-register every function found in the transformed module (a pipeline
  // may rewrite the input in place or produce additional functions). The
  // loop variable is a value on purpose: CreateFunction takes its address,
  // which requires an lvalue.
  for (auto fn : mlir_fn->get().getBody()->getOps<mlir::tfg::GraphFuncOp>()) {
    TF_RETURN_WITH_CONTEXT_IF_ERROR(
        CreateFunction(reinterpret_cast<OpaqueTfgGraphFuncOp*>(&fn)),
        absl::StrCat("updating function " , fn.getName().str()));
  }

  return OkStatus();
}
153 | |
154 | StatusOr<ReturnValues> Runtime::CallFunction( |
155 | StringPiece name, absl::Span<AbstractTensorHandle* const> args) { |
156 | EagerContext& ctx = this->eager_ctx_; |
157 | |
158 | ImmediateOpPtr op(ctx.CreateOperation()); |
159 | TF_RETURN_WITH_CONTEXT_IF_ERROR(op->Reset(name.data(), nullptr), |
160 | "initializing call op for " , name); |
161 | |
162 | TF_RETURN_WITH_CONTEXT_IF_ERROR(op->AddInputList(args), |
163 | "preparing call args for " , name); |
164 | |
165 | const FunctionDef* fn_def = ctx.GetFunctionDef(string(name)); |
166 | int num_retvals = fn_def->signature().output_arg_size(); |
167 | int actual_retvals = num_retvals; |
168 | std::vector<ImmediateExecutionTensorHandle*> retvals(num_retvals); |
169 | TF_RETURN_WITH_CONTEXT_IF_ERROR( |
170 | op->Execute(absl::MakeSpan( |
171 | reinterpret_cast<AbstractTensorHandle**>(retvals.data()), |
172 | num_retvals), |
173 | &actual_retvals), |
174 | "executing call op for " , name); |
175 | DCHECK(num_retvals == actual_retvals); |
176 | |
177 | ReturnValues final_returns; |
178 | for (const auto& r : retvals) { |
179 | final_returns.emplace_back(ImmediateTensorHandlePtr(r)); |
180 | } |
181 | |
182 | return final_returns; |
183 | } |
184 | |
185 | } // namespace function |
186 | } // namespace core |
187 | } // namespace tensorflow |
188 | |