1 | /* Copyright 2020 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | |
16 | #include "tensorflow/compiler/mlir/python/mlir.h" |
17 | |
18 | #include <string> |
19 | #include <type_traits> |
20 | #include <utility> |
21 | |
22 | #include "absl/algorithm/container.h" |
23 | #include "absl/container/flat_hash_set.h" |
24 | #include "absl/container/inlined_vector.h" |
25 | #include "absl/strings/str_cat.h" |
26 | #include "absl/strings/str_join.h" |
27 | #include "absl/strings/str_split.h" |
28 | #include "llvm/Support/raw_ostream.h" |
29 | #include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project |
30 | #include "mlir/IR/BuiltinOps.h" // from @llvm-project |
31 | #include "mlir/InitAllPasses.h" // from @llvm-project |
32 | #include "mlir/Parser/Parser.h" // from @llvm-project |
33 | #include "mlir/Pass/PassManager.h" // from @llvm-project |
34 | #include "mlir/Pass/PassRegistry.h" // from @llvm-project |
35 | #include "tensorflow/c/eager/c_api.h" |
36 | #include "tensorflow/c/eager/tfe_context_internal.h" |
37 | #include "tensorflow/c/tf_status.h" |
38 | #include "tensorflow/c/tf_status_helper.h" |
39 | #include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h" |
40 | #include "tensorflow/compiler/mlir/tensorflow/dialect_registration.h" |
41 | #include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h" |
42 | #include "tensorflow/compiler/mlir/tensorflow/transforms/tf_saved_model_passes.h" |
43 | #include "tensorflow/compiler/mlir/tensorflow/translate/import_model.h" |
44 | #include "tensorflow/compiler/mlir/tensorflow/translate/mlir_roundtrip_flags.h" |
45 | #include "tensorflow/compiler/mlir/tensorflow/translate/tf_mlir_translate.h" |
46 | #include "tensorflow/compiler/mlir/tensorflow/utils/error_util.h" |
47 | #include "tensorflow/compiler/mlir/tensorflow/utils/import_utils.h" |
48 | #include "tensorflow/compiler/mlir/tools/kernel_gen/ir/tf_framework_ops.h" |
49 | #include "tensorflow/compiler/mlir/tosa/tf_passes.h" |
50 | #include "tensorflow/compiler/mlir/tosa/tf_tfl_passes.h" |
51 | #include "tensorflow/compiler/mlir/tosa/tfl_passes.h" |
52 | #include "tensorflow/compiler/mlir/tosa/transforms/passes.h" |
53 | #include "tensorflow/compiler/mlir/xla/transforms/passes.h" |
54 | #include "tensorflow/compiler/mlir/xla/transforms/xla_passes.h" |
55 | #include "tensorflow/compiler/xla/mlir_hlo/include/mlir-hlo/Dialect/lhlo/transforms/register_passes.h" |
56 | #include "tensorflow/compiler/xla/mlir_hlo/include/mlir-hlo/Dialect/mhlo/transforms/register_passes.h" |
57 | #include "tensorflow/compiler/xla/status_macros.h" |
58 | #include "tensorflow/core/common_runtime/eager/context.h" |
59 | #include "tensorflow/core/common_runtime/function_body.h" |
60 | #include "tensorflow/core/common_runtime/function_def_utils.h" |
61 | #include "tensorflow/core/framework/function.h" |
62 | #include "tensorflow/core/framework/function.pb.h" |
63 | #include "tensorflow/core/framework/op.h" |
64 | #include "tensorflow/core/framework/tensor_shape.pb.h" |
65 | #include "tensorflow/core/framework/types.h" |
66 | #include "tensorflow/core/framework/types.pb.h" |
67 | #include "tensorflow/core/lib/core/errors.h" |
68 | #include "tensorflow/core/platform/types.h" |
69 | |
70 | namespace tensorflow { |
71 | |
72 | namespace { |
73 | // All the passes we will make available to Python by default. |
74 | // TODO(tf): this should be sharded instead of being monolithic like that. |
75 | static void RegisterPasses() { |
76 | static bool unique_registration = [] { |
77 | mlir::registerAllPasses(); |
78 | mlir::registerTensorFlowPasses(); |
79 | mlir::TFDevice::registerTensorFlowDevicePasses(); |
80 | mlir::mhlo::registerAllMhloPasses(); |
81 | mlir::lmhlo::registerAllLmhloPasses(); |
82 | // These are in compiler/mlir/xla and not part of the above MHLO |
83 | // passes. |
84 | mlir::mhlo::registerXlaPasses(); |
85 | mlir::mhlo::registerTfXlaPasses(); |
86 | mlir::mhlo::registerLegalizeTFPass(); |
87 | mlir::mhlo::registerLegalizeTFControlFlowPass(); |
88 | mlir::mhlo::registerLegalizeTfTypesPassPass(); |
89 | mlir::tosa::registerLegalizeTosaPasses(); |
90 | mlir::tosa::registerTFtoTOSALegalizationPipeline(); |
91 | mlir::tosa::registerTFLtoTOSALegalizationPipeline(); |
92 | mlir::tosa::registerTFTFLtoTOSALegalizationPipeline(); |
93 | mlir::tf_saved_model::registerTensorFlowSavedModelPasses(); |
94 | return true; |
95 | }(); |
96 | (void)unique_registration; |
97 | } |
98 | |
99 | // Runs pass pipeline `pass_pipeline` on `module` if `pass_pipeline` is not |
100 | // empty. |
101 | std::string RunPassPipelineOnModule(mlir::ModuleOp module, |
102 | const std::string& pass_pipeline, |
103 | bool show_debug_info, TF_Status* status) { |
104 | RegisterPasses(); |
105 | if (!pass_pipeline.empty()) { |
106 | mlir::PassManager pm(module.getContext()); |
107 | std::string error; |
108 | llvm::raw_string_ostream error_stream(error); |
109 | if (failed(mlir::parsePassPipeline(pass_pipeline, pm, error_stream))) { |
110 | TF_SetStatus(status, TF_INVALID_ARGUMENT, |
111 | ("Invalid pass_pipeline: " + error_stream.str()).c_str()); |
112 | return "// error" ; |
113 | } |
114 | |
115 | mlir::StatusScopedDiagnosticHandler statusHandler(module.getContext()); |
116 | if (failed(pm.run(module))) { |
117 | Set_TF_Status_from_Status(status, statusHandler.ConsumeStatus()); |
118 | return "// error" ; |
119 | } |
120 | } |
121 | return MlirModuleToString(module, show_debug_info); |
122 | } |
123 | |
124 | } // anonymous namespace |
125 | |
126 | static std::string ImportGraphDefImpl(const std::string& proto, |
127 | const std::string& pass_pipeline, |
128 | bool show_debug_info, |
129 | GraphDebugInfo& debug_info, |
130 | GraphImportConfig& specs, |
131 | TF_Status* status) { |
132 | GraphDef graphdef; |
133 | auto s = tensorflow::LoadProtoFromBuffer(proto, &graphdef); |
134 | if (!s.ok()) { |
135 | Set_TF_Status_from_Status(status, s); |
136 | return "// error" ; |
137 | } |
138 | mlir::MLIRContext context; |
139 | auto module = ConvertGraphdefToMlir(graphdef, debug_info, specs, &context); |
140 | if (!module.ok()) { |
141 | Set_TF_Status_from_Status(status, module.status()); |
142 | return "// error" ; |
143 | } |
144 | |
145 | return RunPassPipelineOnModule(module->get(), pass_pipeline, show_debug_info, |
146 | status); |
147 | } |
148 | |
149 | std::string ImportFunction(const std::string& functiondef_proto, |
150 | const std::string& pass_pipeline, |
151 | bool show_debug_info, TFE_Context* tfe_context, |
152 | TF_Status* status) { |
153 | FunctionDef functiondef; |
154 | auto s = tensorflow::LoadProtoFromBuffer(functiondef_proto, &functiondef); |
155 | if (!s.ok()) { |
156 | Set_TF_Status_from_Status(status, s); |
157 | return "// error" ; |
158 | } |
159 | |
160 | const std::string& function_name = functiondef.signature().name(); |
161 | EagerContext* cpp_context = ContextFromInterface(unwrap(tfe_context)); |
162 | FunctionLibraryDefinition& flib_def = *cpp_context->FuncLibDef(); |
163 | const tensorflow::FunctionDef* fdef = flib_def.Find(function_name); |
164 | if (fdef == nullptr) { |
165 | s = tensorflow::errors::NotFound("Cannot find function " , function_name); |
166 | Set_TF_Status_from_Status(status, s); |
167 | return "// error" ; |
168 | } |
169 | |
170 | std::unique_ptr<tensorflow::FunctionBody> fbody; |
171 | s = FunctionDefToBodyHelper(*fdef, tensorflow::AttrSlice(), &flib_def, |
172 | &fbody); |
173 | if (!s.ok()) { |
174 | Set_TF_Status_from_Status(status, s); |
175 | return "// error" ; |
176 | } |
177 | |
178 | mlir::MLIRContext context; |
179 | auto module = ConvertFunctionToMlir(fbody.get(), flib_def, &context); |
180 | if (!module.ok()) { |
181 | Set_TF_Status_from_Status(status, module.status()); |
182 | return "// error" ; |
183 | } |
184 | |
185 | return RunPassPipelineOnModule(module->get(), pass_pipeline, show_debug_info, |
186 | status); |
187 | } |
188 | |
189 | std::string ImportGraphDef(const std::string& proto, |
190 | const std::string& pass_pipeline, |
191 | bool show_debug_info, TF_Status* status) { |
192 | GraphDebugInfo debug_info; |
193 | GraphImportConfig specs; |
194 | return ImportGraphDefImpl(proto, pass_pipeline, show_debug_info, debug_info, |
195 | specs, status); |
196 | } |
197 | |
198 | std::string ImportGraphDef(const std::string& proto, |
199 | const std::string& pass_pipeline, |
200 | bool show_debug_info, absl::string_view input_names, |
201 | absl::string_view input_data_types, |
202 | absl::string_view input_data_shapes, |
203 | absl::string_view output_names, TF_Status* status) { |
204 | GraphDebugInfo debug_info; |
205 | GraphImportConfig specs; |
206 | auto s = ParseInputArrayInfo(input_names, input_data_types, input_data_shapes, |
207 | &specs.inputs); |
208 | if (!s.ok()) { |
209 | Set_TF_Status_from_Status(status, s); |
210 | return "// error" ; |
211 | } |
212 | if (!output_names.empty()) { |
213 | specs.outputs = absl::StrSplit(output_names, ','); |
214 | } |
215 | return ImportGraphDefImpl(proto, pass_pipeline, show_debug_info, debug_info, |
216 | specs, status); |
217 | } |
218 | |
219 | std::string ExperimentalConvertSavedModelToMlir( |
220 | const std::string& saved_model_path, const std::string& exported_names_str, |
221 | bool show_debug_info, TF_Status* status) { |
222 | // Load the saved model into a SavedModelV2Bundle. |
223 | |
224 | tensorflow::SavedModelV2Bundle bundle; |
225 | auto load_status = |
226 | tensorflow::SavedModelV2Bundle::Load(saved_model_path, &bundle); |
227 | if (!load_status.ok()) { |
228 | Set_TF_Status_from_Status(status, load_status); |
229 | return "// error" ; |
230 | } |
231 | |
232 | // Convert the SavedModelV2Bundle to an MLIR module. |
233 | |
234 | std::vector<string> exported_names = |
235 | absl::StrSplit(exported_names_str, ',', absl::SkipEmpty()); |
236 | mlir::MLIRContext context; |
237 | auto module_or = ConvertSavedModelToMlir( |
238 | &bundle, &context, absl::Span<std::string>(exported_names)); |
239 | if (!module_or.status().ok()) { |
240 | Set_TF_Status_from_Status(status, module_or.status()); |
241 | return "// error" ; |
242 | } |
243 | |
244 | return MlirModuleToString(*std::move(module_or).value(), show_debug_info); |
245 | } |
246 | |
247 | std::string ExperimentalConvertSavedModelV1ToMlirLite( |
248 | const std::string& saved_model_path, const std::string& exported_names_str, |
249 | const std::string& tags, bool upgrade_legacy, bool show_debug_info, |
250 | TF_Status* status) { |
251 | std::unordered_set<string> tag_set = |
252 | absl::StrSplit(tags, ',', absl::SkipEmpty()); |
253 | |
254 | std::vector<string> exported_names = |
255 | absl::StrSplit(exported_names_str, ',', absl::SkipEmpty()); |
256 | mlir::MLIRContext context; |
257 | |
258 | tensorflow::MLIRImportOptions import_options; |
259 | import_options.upgrade_legacy = upgrade_legacy; |
260 | auto module_or = SavedModelSignatureDefsToMlirImportLite( |
261 | saved_model_path, tag_set, absl::Span<std::string>(exported_names), |
262 | &context, import_options); |
263 | if (!module_or.status().ok()) { |
264 | Set_TF_Status_from_Status(status, module_or.status()); |
265 | return "// error" ; |
266 | } |
267 | |
268 | return MlirModuleToString(*module_or.value(), show_debug_info); |
269 | } |
270 | |
271 | std::string ExperimentalConvertSavedModelV1ToMlir( |
272 | const std::string& saved_model_path, const std::string& exported_names_str, |
273 | const std::string& tags, bool lift_variables, bool upgrade_legacy, |
274 | bool show_debug_info, TF_Status* status) { |
275 | // Load the saved model into a SavedModelBundle. |
276 | |
277 | std::unordered_set<string> tag_set = |
278 | absl::StrSplit(tags, ',', absl::SkipEmpty()); |
279 | |
280 | tensorflow::SavedModelBundle bundle; |
281 | auto load_status = |
282 | tensorflow::LoadSavedModel({}, {}, saved_model_path, tag_set, &bundle); |
283 | if (!load_status.ok()) { |
284 | Set_TF_Status_from_Status(status, load_status); |
285 | return "// error" ; |
286 | } |
287 | |
288 | // Convert the SavedModelBundle to an MLIR module. |
289 | std::vector<string> exported_names = |
290 | absl::StrSplit(exported_names_str, ',', absl::SkipEmpty()); |
291 | mlir::MLIRContext context; |
292 | tensorflow::MLIRImportOptions import_options; |
293 | import_options.upgrade_legacy = upgrade_legacy; |
294 | auto module_or = |
295 | ConvertSavedModelV1ToMlir(bundle, absl::Span<std::string>(exported_names), |
296 | &context, import_options, lift_variables); |
297 | if (!module_or.status().ok()) { |
298 | Set_TF_Status_from_Status(status, module_or.status()); |
299 | return "// error" ; |
300 | } |
301 | |
302 | // Run the tf standard pipeline by default and then, run passes that lift |
303 | // variables if the flag is set on the module. |
304 | mlir::OwningOpRef<mlir::ModuleOp> module = std::move(module_or).value(); |
305 | mlir::PassManager pm(&context); |
306 | std::string error; |
307 | llvm::raw_string_ostream error_stream(error); |
308 | |
309 | mlir::TF::StandardPipelineOptions tf_options; |
310 | mlir::TF::CreateTFStandardPipeline(pm, tf_options); |
311 | |
312 | mlir::StatusScopedDiagnosticHandler diagnostic_handler(&context); |
313 | if (failed(pm.run(*module))) { |
314 | Set_TF_Status_from_Status(status, diagnostic_handler.ConsumeStatus()); |
315 | return "// error" ; |
316 | } |
317 | return MlirModuleToString(*module, show_debug_info); |
318 | } |
319 | |
320 | std::string ExperimentalRunPassPipeline(const std::string& mlir_txt, |
321 | const std::string& pass_pipeline, |
322 | bool show_debug_info, |
323 | TF_Status* status) { |
324 | RegisterPasses(); |
325 | mlir::DialectRegistry registry; |
326 | mlir::RegisterAllTensorFlowDialects(registry); |
327 | mlir::MLIRContext context(registry); |
328 | mlir::OwningOpRef<mlir::ModuleOp> module; |
329 | { |
330 | mlir::StatusScopedDiagnosticHandler diagnostic_handler(&context); |
331 | module = mlir::parseSourceString<mlir::ModuleOp>(mlir_txt, &context); |
332 | if (!module) { |
333 | Set_TF_Status_from_Status(status, diagnostic_handler.ConsumeStatus()); |
334 | return "// error" ; |
335 | } |
336 | } |
337 | |
338 | // Run the pass_pipeline on the module. |
339 | mlir::PassManager pm(&context); |
340 | std::string error; |
341 | llvm::raw_string_ostream error_stream(error); |
342 | if (failed(mlir::parsePassPipeline(pass_pipeline, pm, error_stream))) { |
343 | TF_SetStatus(status, TF_INVALID_ARGUMENT, |
344 | ("Invalid pass_pipeline: " + error_stream.str()).c_str()); |
345 | return "// error" ; |
346 | } |
347 | |
348 | mlir::StatusScopedDiagnosticHandler diagnostic_handler(&context); |
349 | if (failed(pm.run(*module))) { |
350 | Set_TF_Status_from_Status(status, diagnostic_handler.ConsumeStatus()); |
351 | return "// error" ; |
352 | } |
353 | return MlirModuleToString(*module, show_debug_info); |
354 | } |
355 | |
356 | } // namespace tensorflow |
357 | |