/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
15
#include "tensorflow/core/function/runtime_client.h"

#include <memory>
#include <string>
#include <utility>
#include <vector>

#include "absl/strings/str_cat.h"
#include "absl/types/span.h"
#include "mlir/IR/MLIRContext.h"  // from @llvm-project
#include "mlir/Pass/PassManager.h"  // from @llvm-project
#include "mlir/Pass/PassRegistry.h"  // from @llvm-project
#include "mlir/Support/LogicalResult.h"  // from @llvm-project
#include "tensorflow/c/eager/abstract_tensor_handle.h"
#include "tensorflow/c/eager/immediate_execution_context.h"
#include "tensorflow/c/eager/immediate_execution_operation.h"
#include "tensorflow/c/eager/immediate_execution_tensor_handle.h"
#include "tensorflow/compiler/mlir/tensorflow/utils/error_util.h"
#include "tensorflow/core/common_runtime/device_mgr.h"
#include "tensorflow/core/common_runtime/eager/context.h"
#include "tensorflow/core/framework/device.h"
#include "tensorflow/core/framework/device_factory.h"
#include "tensorflow/core/framework/function.pb.h"
#include "tensorflow/core/framework/graph.pb.h"
#include "tensorflow/core/framework/op_def.pb.h"
#include "tensorflow/core/ir/importexport/graphdef_export.h"
#include "tensorflow/core/ir/importexport/graphdef_import.h"
#include "tensorflow/core/ir/ops.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/platform/status.h"
#include "tensorflow/core/platform/statusor.h"
#include "tensorflow/core/platform/stringpiece.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/protobuf/error_codes.pb.h"
#include "tensorflow/core/public/session_options.h"
51
namespace tensorflow {
namespace core {
namespace function {

56EagerContext& GlobalEagerContext() {
57 static EagerContext* global_ctx = []() {
58 SessionOptions opts;
59 std::vector<std::unique_ptr<Device>> devices;
60 Status&& device_init_status = DeviceFactory::AddDevices(
61 opts, "/job:localhost/replica:0/task:0", &devices);
62 CHECK(device_init_status.ok()); // Crash OK
63
64 return new EagerContext(
65 opts, ContextDevicePlacementPolicy::DEVICE_PLACEMENT_SILENT,
66 /*async=*/false,
67 /*device_mgr=*/new DynamicDeviceMgr(std::move(devices)),
68 /*device_mgr_owned=*/true,
69 /*rendezvous=*/nullptr,
70 /*cluster_flr=*/nullptr,
71 /*collective_executor_mgr=*/nullptr,
72 /*run_eager_op_as_function=*/true);
73 }();
74 return *global_ctx;
75}
76
77EagerContext& GlobalPythonEagerContext() {
78 EagerContext* ctx = reinterpret_cast<EagerContext*>(GetCEagerContext());
79 DCHECK(ctx) << "The Python eager context must be initialized first.";
80 return *ctx;
81}
82
83StatusOr<FunctionDef> Runtime::GetFunctionProto(StringPiece name) {
84 EagerContext& ctx = this->eager_ctx_;
85
86 const FunctionDef* f = ctx.FindFunctionDef(std::string(name));
87 if (f == nullptr) {
88 return Status(error::INVALID_ARGUMENT,
89 absl::StrCat("Could not find an attribute for key ", name));
90 }
91
92 return *f;
93}
94
95Status Runtime::CreateFunction(const FunctionDef& fdef) {
96 const auto& fname = fdef.signature().name();
97 if (this->eager_ctx_.FindFunctionByName(fname)) {
98 TF_RETURN_WITH_CONTEXT_IF_ERROR(this->eager_ctx_.RemoveFunction(fname),
99 "removing function ", fname);
100 }
101 return this->eager_ctx_.AddFunctionDef(fdef);
102}
103
104Status Runtime::CreateFunction(OpaqueTfgGraphFuncOp* fop) {
105 mlir::tfg::GraphFuncOp fop_proper =
106 *reinterpret_cast<mlir::tfg::GraphFuncOp*>(fop);
107 return mlir::tfg::ConvertToFunctionDef(fop_proper,
108 *this->eager_ctx_.FuncLibDef());
109}
110
// Runs the MLIR pass pipeline named `pipeline_name` over the function `name`
// and re-registers the transformed result(s) in the eager context. The
// function currently roundtrips through FunctionDef/GraphDef protos.
Status Runtime::TransformFunction(StringPiece name, StringPiece pipeline_name) {
  // TODO(mdan): Use a longer-lived context.
  mlir::MLIRContext ctx;
  mlir::PassManager pm(&ctx);

  std::string error;
  llvm::raw_string_ostream error_stream(error);
  // StringPiece doesn't seem to always be compatible with StringRef.
  if (mlir::failed(mlir::parsePassPipeline(std::string(pipeline_name), pm,
                                           error_stream))) {
    // The pipeline string could not be parsed/resolved; surface the parser's
    // own error text in the returned status.
    return Status(error::INVALID_ARGUMENT,
                  absl::StrCat("locating pass pipeline ", pipeline_name, ": ",
                               error_stream.str()));
  }

  // For now, we roundtrip from proto. Once we have a permanent MLIR
  // representation, we should be able to use it directly.
  auto fn = GetFunctionProto(name);
  TF_RETURN_WITH_CONTEXT_IF_ERROR(fn.status(), "loading function ", name);

  // Wrap the single FunctionDef in a GraphDef's function library so it can
  // be imported as a TFG module.
  GraphDef graph;
  *graph.mutable_library()->add_function() = *fn;
  tensorflow::GraphDebugInfo debug_info;
  auto mlir_fn = mlir::tfg::ImportGraphDef(&ctx, debug_info, graph);
  TF_RETURN_WITH_CONTEXT_IF_ERROR(mlir_fn.status(), "importing function ",
                                  name);

  // Capture MLIR diagnostics emitted while the passes run so failures are
  // folded into the returned Status rather than lost.
  mlir::StatusScopedDiagnosticHandler diagnostics_handler(&ctx);
  if (failed(pm.run(mlir_fn->get()))) {
    return diagnostics_handler.Combine(
        Status(error::INVALID_ARGUMENT,
               absl::StrCat("running pass pipeline ", pipeline_name, ": ")));
  }

  // Re-register every function present in the transformed module (the
  // pipeline may have rewritten or added functions).
  // NOTE(review): the loop variable `fn` shadows the StatusOr `fn` above;
  // consider renaming one of them.
  for (auto fn : mlir_fn->get().getBody()->getOps<mlir::tfg::GraphFuncOp>()) {
    TF_RETURN_WITH_CONTEXT_IF_ERROR(
        CreateFunction(reinterpret_cast<OpaqueTfgGraphFuncOp*>(&fn)),
        absl::StrCat("updating function ", fn.getName().str()));
  }

  return OkStatus();
}
153
154StatusOr<ReturnValues> Runtime::CallFunction(
155 StringPiece name, absl::Span<AbstractTensorHandle* const> args) {
156 EagerContext& ctx = this->eager_ctx_;
157
158 ImmediateOpPtr op(ctx.CreateOperation());
159 TF_RETURN_WITH_CONTEXT_IF_ERROR(op->Reset(name.data(), nullptr),
160 "initializing call op for ", name);
161
162 TF_RETURN_WITH_CONTEXT_IF_ERROR(op->AddInputList(args),
163 "preparing call args for ", name);
164
165 const FunctionDef* fn_def = ctx.GetFunctionDef(string(name));
166 int num_retvals = fn_def->signature().output_arg_size();
167 int actual_retvals = num_retvals;
168 std::vector<ImmediateExecutionTensorHandle*> retvals(num_retvals);
169 TF_RETURN_WITH_CONTEXT_IF_ERROR(
170 op->Execute(absl::MakeSpan(
171 reinterpret_cast<AbstractTensorHandle**>(retvals.data()),
172 num_retvals),
173 &actual_retvals),
174 "executing call op for ", name);
175 DCHECK(num_retvals == actual_retvals);
176
177 ReturnValues final_returns;
178 for (const auto& r : retvals) {
179 final_returns.emplace_back(ImmediateTensorHandlePtr(r));
180 }
181
182 return final_returns;
183}
184
}  // namespace function
}  // namespace core
}  // namespace tensorflow
188