1/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2
3Licensed under the Apache License, Version 2.0 (the "License");
4you may not use this file except in compliance with the License.
5You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9Unless required by applicable law or agreed to in writing, software
10distributed under the License is distributed on an "AS IS" BASIS,
11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12See the License for the specific language governing permissions and
13limitations under the License.
14==============================================================================*/
15
16#include "tensorflow/compiler/mlir/python/mlir.h"
17
18#include <string>
19#include <type_traits>
20#include <utility>
21
22#include "absl/algorithm/container.h"
23#include "absl/container/flat_hash_set.h"
24#include "absl/container/inlined_vector.h"
25#include "absl/strings/str_cat.h"
26#include "absl/strings/str_join.h"
27#include "absl/strings/str_split.h"
28#include "llvm/Support/raw_ostream.h"
29#include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project
30#include "mlir/IR/BuiltinOps.h" // from @llvm-project
31#include "mlir/InitAllPasses.h" // from @llvm-project
32#include "mlir/Parser/Parser.h" // from @llvm-project
33#include "mlir/Pass/PassManager.h" // from @llvm-project
34#include "mlir/Pass/PassRegistry.h" // from @llvm-project
35#include "tensorflow/c/eager/c_api.h"
36#include "tensorflow/c/eager/tfe_context_internal.h"
37#include "tensorflow/c/tf_status.h"
38#include "tensorflow/c/tf_status_helper.h"
39#include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h"
40#include "tensorflow/compiler/mlir/tensorflow/dialect_registration.h"
41#include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h"
42#include "tensorflow/compiler/mlir/tensorflow/transforms/tf_saved_model_passes.h"
43#include "tensorflow/compiler/mlir/tensorflow/translate/import_model.h"
44#include "tensorflow/compiler/mlir/tensorflow/translate/mlir_roundtrip_flags.h"
45#include "tensorflow/compiler/mlir/tensorflow/translate/tf_mlir_translate.h"
46#include "tensorflow/compiler/mlir/tensorflow/utils/error_util.h"
47#include "tensorflow/compiler/mlir/tensorflow/utils/import_utils.h"
48#include "tensorflow/compiler/mlir/tools/kernel_gen/ir/tf_framework_ops.h"
49#include "tensorflow/compiler/mlir/tosa/tf_passes.h"
50#include "tensorflow/compiler/mlir/tosa/tf_tfl_passes.h"
51#include "tensorflow/compiler/mlir/tosa/tfl_passes.h"
52#include "tensorflow/compiler/mlir/tosa/transforms/passes.h"
53#include "tensorflow/compiler/mlir/xla/transforms/passes.h"
54#include "tensorflow/compiler/mlir/xla/transforms/xla_passes.h"
55#include "tensorflow/compiler/xla/mlir_hlo/include/mlir-hlo/Dialect/lhlo/transforms/register_passes.h"
56#include "tensorflow/compiler/xla/mlir_hlo/include/mlir-hlo/Dialect/mhlo/transforms/register_passes.h"
57#include "tensorflow/compiler/xla/status_macros.h"
58#include "tensorflow/core/common_runtime/eager/context.h"
59#include "tensorflow/core/common_runtime/function_body.h"
60#include "tensorflow/core/common_runtime/function_def_utils.h"
61#include "tensorflow/core/framework/function.h"
62#include "tensorflow/core/framework/function.pb.h"
63#include "tensorflow/core/framework/op.h"
64#include "tensorflow/core/framework/tensor_shape.pb.h"
65#include "tensorflow/core/framework/types.h"
66#include "tensorflow/core/framework/types.pb.h"
67#include "tensorflow/core/lib/core/errors.h"
68#include "tensorflow/core/platform/types.h"
69
70namespace tensorflow {
71
72namespace {
73// All the passes we will make available to Python by default.
74// TODO(tf): this should be sharded instead of being monolithic like that.
75static void RegisterPasses() {
76 static bool unique_registration = [] {
77 mlir::registerAllPasses();
78 mlir::registerTensorFlowPasses();
79 mlir::TFDevice::registerTensorFlowDevicePasses();
80 mlir::mhlo::registerAllMhloPasses();
81 mlir::lmhlo::registerAllLmhloPasses();
82 // These are in compiler/mlir/xla and not part of the above MHLO
83 // passes.
84 mlir::mhlo::registerXlaPasses();
85 mlir::mhlo::registerTfXlaPasses();
86 mlir::mhlo::registerLegalizeTFPass();
87 mlir::mhlo::registerLegalizeTFControlFlowPass();
88 mlir::mhlo::registerLegalizeTfTypesPassPass();
89 mlir::tosa::registerLegalizeTosaPasses();
90 mlir::tosa::registerTFtoTOSALegalizationPipeline();
91 mlir::tosa::registerTFLtoTOSALegalizationPipeline();
92 mlir::tosa::registerTFTFLtoTOSALegalizationPipeline();
93 mlir::tf_saved_model::registerTensorFlowSavedModelPasses();
94 return true;
95 }();
96 (void)unique_registration;
97}
98
99// Runs pass pipeline `pass_pipeline` on `module` if `pass_pipeline` is not
100// empty.
101std::string RunPassPipelineOnModule(mlir::ModuleOp module,
102 const std::string& pass_pipeline,
103 bool show_debug_info, TF_Status* status) {
104 RegisterPasses();
105 if (!pass_pipeline.empty()) {
106 mlir::PassManager pm(module.getContext());
107 std::string error;
108 llvm::raw_string_ostream error_stream(error);
109 if (failed(mlir::parsePassPipeline(pass_pipeline, pm, error_stream))) {
110 TF_SetStatus(status, TF_INVALID_ARGUMENT,
111 ("Invalid pass_pipeline: " + error_stream.str()).c_str());
112 return "// error";
113 }
114
115 mlir::StatusScopedDiagnosticHandler statusHandler(module.getContext());
116 if (failed(pm.run(module))) {
117 Set_TF_Status_from_Status(status, statusHandler.ConsumeStatus());
118 return "// error";
119 }
120 }
121 return MlirModuleToString(module, show_debug_info);
122}
123
124} // anonymous namespace
125
126static std::string ImportGraphDefImpl(const std::string& proto,
127 const std::string& pass_pipeline,
128 bool show_debug_info,
129 GraphDebugInfo& debug_info,
130 GraphImportConfig& specs,
131 TF_Status* status) {
132 GraphDef graphdef;
133 auto s = tensorflow::LoadProtoFromBuffer(proto, &graphdef);
134 if (!s.ok()) {
135 Set_TF_Status_from_Status(status, s);
136 return "// error";
137 }
138 mlir::MLIRContext context;
139 auto module = ConvertGraphdefToMlir(graphdef, debug_info, specs, &context);
140 if (!module.ok()) {
141 Set_TF_Status_from_Status(status, module.status());
142 return "// error";
143 }
144
145 return RunPassPipelineOnModule(module->get(), pass_pipeline, show_debug_info,
146 status);
147}
148
149std::string ImportFunction(const std::string& functiondef_proto,
150 const std::string& pass_pipeline,
151 bool show_debug_info, TFE_Context* tfe_context,
152 TF_Status* status) {
153 FunctionDef functiondef;
154 auto s = tensorflow::LoadProtoFromBuffer(functiondef_proto, &functiondef);
155 if (!s.ok()) {
156 Set_TF_Status_from_Status(status, s);
157 return "// error";
158 }
159
160 const std::string& function_name = functiondef.signature().name();
161 EagerContext* cpp_context = ContextFromInterface(unwrap(tfe_context));
162 FunctionLibraryDefinition& flib_def = *cpp_context->FuncLibDef();
163 const tensorflow::FunctionDef* fdef = flib_def.Find(function_name);
164 if (fdef == nullptr) {
165 s = tensorflow::errors::NotFound("Cannot find function ", function_name);
166 Set_TF_Status_from_Status(status, s);
167 return "// error";
168 }
169
170 std::unique_ptr<tensorflow::FunctionBody> fbody;
171 s = FunctionDefToBodyHelper(*fdef, tensorflow::AttrSlice(), &flib_def,
172 &fbody);
173 if (!s.ok()) {
174 Set_TF_Status_from_Status(status, s);
175 return "// error";
176 }
177
178 mlir::MLIRContext context;
179 auto module = ConvertFunctionToMlir(fbody.get(), flib_def, &context);
180 if (!module.ok()) {
181 Set_TF_Status_from_Status(status, module.status());
182 return "// error";
183 }
184
185 return RunPassPipelineOnModule(module->get(), pass_pipeline, show_debug_info,
186 status);
187}
188
189std::string ImportGraphDef(const std::string& proto,
190 const std::string& pass_pipeline,
191 bool show_debug_info, TF_Status* status) {
192 GraphDebugInfo debug_info;
193 GraphImportConfig specs;
194 return ImportGraphDefImpl(proto, pass_pipeline, show_debug_info, debug_info,
195 specs, status);
196}
197
198std::string ImportGraphDef(const std::string& proto,
199 const std::string& pass_pipeline,
200 bool show_debug_info, absl::string_view input_names,
201 absl::string_view input_data_types,
202 absl::string_view input_data_shapes,
203 absl::string_view output_names, TF_Status* status) {
204 GraphDebugInfo debug_info;
205 GraphImportConfig specs;
206 auto s = ParseInputArrayInfo(input_names, input_data_types, input_data_shapes,
207 &specs.inputs);
208 if (!s.ok()) {
209 Set_TF_Status_from_Status(status, s);
210 return "// error";
211 }
212 if (!output_names.empty()) {
213 specs.outputs = absl::StrSplit(output_names, ',');
214 }
215 return ImportGraphDefImpl(proto, pass_pipeline, show_debug_info, debug_info,
216 specs, status);
217}
218
219std::string ExperimentalConvertSavedModelToMlir(
220 const std::string& saved_model_path, const std::string& exported_names_str,
221 bool show_debug_info, TF_Status* status) {
222 // Load the saved model into a SavedModelV2Bundle.
223
224 tensorflow::SavedModelV2Bundle bundle;
225 auto load_status =
226 tensorflow::SavedModelV2Bundle::Load(saved_model_path, &bundle);
227 if (!load_status.ok()) {
228 Set_TF_Status_from_Status(status, load_status);
229 return "// error";
230 }
231
232 // Convert the SavedModelV2Bundle to an MLIR module.
233
234 std::vector<string> exported_names =
235 absl::StrSplit(exported_names_str, ',', absl::SkipEmpty());
236 mlir::MLIRContext context;
237 auto module_or = ConvertSavedModelToMlir(
238 &bundle, &context, absl::Span<std::string>(exported_names));
239 if (!module_or.status().ok()) {
240 Set_TF_Status_from_Status(status, module_or.status());
241 return "// error";
242 }
243
244 return MlirModuleToString(*std::move(module_or).value(), show_debug_info);
245}
246
247std::string ExperimentalConvertSavedModelV1ToMlirLite(
248 const std::string& saved_model_path, const std::string& exported_names_str,
249 const std::string& tags, bool upgrade_legacy, bool show_debug_info,
250 TF_Status* status) {
251 std::unordered_set<string> tag_set =
252 absl::StrSplit(tags, ',', absl::SkipEmpty());
253
254 std::vector<string> exported_names =
255 absl::StrSplit(exported_names_str, ',', absl::SkipEmpty());
256 mlir::MLIRContext context;
257
258 tensorflow::MLIRImportOptions import_options;
259 import_options.upgrade_legacy = upgrade_legacy;
260 auto module_or = SavedModelSignatureDefsToMlirImportLite(
261 saved_model_path, tag_set, absl::Span<std::string>(exported_names),
262 &context, import_options);
263 if (!module_or.status().ok()) {
264 Set_TF_Status_from_Status(status, module_or.status());
265 return "// error";
266 }
267
268 return MlirModuleToString(*module_or.value(), show_debug_info);
269}
270
271std::string ExperimentalConvertSavedModelV1ToMlir(
272 const std::string& saved_model_path, const std::string& exported_names_str,
273 const std::string& tags, bool lift_variables, bool upgrade_legacy,
274 bool show_debug_info, TF_Status* status) {
275 // Load the saved model into a SavedModelBundle.
276
277 std::unordered_set<string> tag_set =
278 absl::StrSplit(tags, ',', absl::SkipEmpty());
279
280 tensorflow::SavedModelBundle bundle;
281 auto load_status =
282 tensorflow::LoadSavedModel({}, {}, saved_model_path, tag_set, &bundle);
283 if (!load_status.ok()) {
284 Set_TF_Status_from_Status(status, load_status);
285 return "// error";
286 }
287
288 // Convert the SavedModelBundle to an MLIR module.
289 std::vector<string> exported_names =
290 absl::StrSplit(exported_names_str, ',', absl::SkipEmpty());
291 mlir::MLIRContext context;
292 tensorflow::MLIRImportOptions import_options;
293 import_options.upgrade_legacy = upgrade_legacy;
294 auto module_or =
295 ConvertSavedModelV1ToMlir(bundle, absl::Span<std::string>(exported_names),
296 &context, import_options, lift_variables);
297 if (!module_or.status().ok()) {
298 Set_TF_Status_from_Status(status, module_or.status());
299 return "// error";
300 }
301
302 // Run the tf standard pipeline by default and then, run passes that lift
303 // variables if the flag is set on the module.
304 mlir::OwningOpRef<mlir::ModuleOp> module = std::move(module_or).value();
305 mlir::PassManager pm(&context);
306 std::string error;
307 llvm::raw_string_ostream error_stream(error);
308
309 mlir::TF::StandardPipelineOptions tf_options;
310 mlir::TF::CreateTFStandardPipeline(pm, tf_options);
311
312 mlir::StatusScopedDiagnosticHandler diagnostic_handler(&context);
313 if (failed(pm.run(*module))) {
314 Set_TF_Status_from_Status(status, diagnostic_handler.ConsumeStatus());
315 return "// error";
316 }
317 return MlirModuleToString(*module, show_debug_info);
318}
319
320std::string ExperimentalRunPassPipeline(const std::string& mlir_txt,
321 const std::string& pass_pipeline,
322 bool show_debug_info,
323 TF_Status* status) {
324 RegisterPasses();
325 mlir::DialectRegistry registry;
326 mlir::RegisterAllTensorFlowDialects(registry);
327 mlir::MLIRContext context(registry);
328 mlir::OwningOpRef<mlir::ModuleOp> module;
329 {
330 mlir::StatusScopedDiagnosticHandler diagnostic_handler(&context);
331 module = mlir::parseSourceString<mlir::ModuleOp>(mlir_txt, &context);
332 if (!module) {
333 Set_TF_Status_from_Status(status, diagnostic_handler.ConsumeStatus());
334 return "// error";
335 }
336 }
337
338 // Run the pass_pipeline on the module.
339 mlir::PassManager pm(&context);
340 std::string error;
341 llvm::raw_string_ostream error_stream(error);
342 if (failed(mlir::parsePassPipeline(pass_pipeline, pm, error_stream))) {
343 TF_SetStatus(status, TF_INVALID_ARGUMENT,
344 ("Invalid pass_pipeline: " + error_stream.str()).c_str());
345 return "// error";
346 }
347
348 mlir::StatusScopedDiagnosticHandler diagnostic_handler(&context);
349 if (failed(pm.run(*module))) {
350 Set_TF_Status_from_Status(status, diagnostic_handler.ConsumeStatus());
351 return "// error";
352 }
353 return MlirModuleToString(*module, show_debug_info);
354}
355
356} // namespace tensorflow
357