/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
15 | |
#include <string>

#include "pybind11/pybind11.h"
#include "pybind11/pytypes.h"
#include "tensorflow/c/tf_status.h"
#include "tensorflow/compiler/mlir/python/mlir.h"
#include "tensorflow/python/lib/core/pybind11_lib.h"
#include "tensorflow/python/lib/core/pybind11_status.h"
#include "tensorflow/python/lib/core/safe_ptr.h"
23 | |
24 | PYBIND11_MODULE(_pywrap_mlir, m) { |
25 | m.def("ImportGraphDef" , |
26 | [](const std::string &graphdef, const std::string &pass_pipeline, |
27 | bool show_debug_info) { |
28 | tensorflow::Safe_TF_StatusPtr status = |
29 | tensorflow::make_safe(TF_NewStatus()); |
30 | std::string output = tensorflow::ImportGraphDef( |
31 | graphdef, pass_pipeline, show_debug_info, status.get()); |
32 | tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get()); |
33 | return output; |
34 | }); |
35 | |
36 | m.def("ImportFunction" , |
37 | [](const py::handle &context, const std::string &functiondef, |
38 | const std::string &pass_pipeline, bool show_debug_info) { |
39 | tensorflow::Safe_TF_StatusPtr status = |
40 | tensorflow::make_safe(TF_NewStatus()); |
41 | auto *ctxt = static_cast<TFE_Context *>( |
42 | PyCapsule_GetPointer(context.ptr(), nullptr)); |
43 | if (!ctxt) throw py::error_already_set(); |
44 | std::string output = tensorflow::ImportFunction( |
45 | functiondef, pass_pipeline, show_debug_info, ctxt, status.get()); |
46 | tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get()); |
47 | return output; |
48 | }); |
49 | |
50 | m.def("ImportGraphDef" , |
51 | [](const std::string &graphdef, const std::string &pass_pipeline, |
52 | bool show_debug_info, const std::string &input_names, |
53 | const std::string &input_data_types, |
54 | const std::string &input_data_shapes, |
55 | const std::string &output_names) { |
56 | tensorflow::Safe_TF_StatusPtr status = |
57 | tensorflow::make_safe(TF_NewStatus()); |
58 | std::string output = tensorflow::ImportGraphDef( |
59 | graphdef, pass_pipeline, show_debug_info, input_names, |
60 | input_data_types, input_data_shapes, output_names, status.get()); |
61 | tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get()); |
62 | return output; |
63 | }); |
64 | |
65 | m.def("ExperimentalConvertSavedModelToMlir" , |
66 | [](const std::string &saved_model_path, |
67 | const std::string &exported_names, bool show_debug_info) { |
68 | tensorflow::Safe_TF_StatusPtr status = |
69 | tensorflow::make_safe(TF_NewStatus()); |
70 | std::string output = tensorflow::ExperimentalConvertSavedModelToMlir( |
71 | saved_model_path, exported_names, show_debug_info, status.get()); |
72 | tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get()); |
73 | return output; |
74 | }); |
75 | |
76 | m.def("ExperimentalConvertSavedModelV1ToMlirLite" , |
77 | [](const std::string &saved_model_path, |
78 | const std::string &exported_names_str, const std::string &tags, |
79 | bool upgrade_legacy, bool show_debug_info) { |
80 | tensorflow::Safe_TF_StatusPtr status = |
81 | tensorflow::make_safe(TF_NewStatus()); |
82 | std::string output = |
83 | tensorflow::ExperimentalConvertSavedModelV1ToMlirLite( |
84 | saved_model_path, exported_names_str, tags, upgrade_legacy, |
85 | show_debug_info, status.get()); |
86 | tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get()); |
87 | return output; |
88 | }); |
89 | |
90 | m.def("ExperimentalConvertSavedModelV1ToMlir" , |
91 | [](const std::string &saved_model_path, |
92 | const std::string &exported_names_str, const std::string &tags, |
93 | bool lift_variables, bool upgrade_legacy, bool show_debug_info) { |
94 | tensorflow::Safe_TF_StatusPtr status = |
95 | tensorflow::make_safe(TF_NewStatus()); |
96 | std::string output = |
97 | tensorflow::ExperimentalConvertSavedModelV1ToMlir( |
98 | saved_model_path, exported_names_str, tags, lift_variables, |
99 | upgrade_legacy, show_debug_info, status.get()); |
100 | tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get()); |
101 | return output; |
102 | }); |
103 | |
104 | m.def("ExperimentalRunPassPipeline" , |
105 | [](const std::string &mlir_txt, const std::string &pass_pipeline, |
106 | bool show_debug_info) { |
107 | tensorflow::Safe_TF_StatusPtr status = |
108 | tensorflow::make_safe(TF_NewStatus()); |
109 | std::string output = tensorflow::ExperimentalRunPassPipeline( |
110 | mlir_txt, pass_pipeline, show_debug_info, status.get()); |
111 | tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get()); |
112 | return output; |
113 | }); |
114 | }; |
115 | |