1 | /* Copyright 2015 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | |
16 | #include <string> |
17 | #include <vector> |
18 | |
19 | #include "pybind11/pybind11.h" |
20 | #include "tensorflow/lite/toco/python/toco_python_api.h" |
21 | #include "tensorflow/python/lib/core/pybind11_lib.h" |
22 | |
23 | namespace py = pybind11; |
24 | |
25 | PYBIND11_MODULE(_pywrap_toco_api, m) { |
26 | m.def( |
27 | "TocoConvert" , |
28 | [](py::object model_flags_proto_txt_raw, |
29 | py::object toco_flags_proto_txt_raw, py::object input_contents_txt_raw, |
30 | bool extended_return, py::object debug_info_txt_raw, |
31 | bool enable_mlir_converter) { |
32 | return tensorflow::PyoOrThrow(toco::TocoConvert( |
33 | model_flags_proto_txt_raw.ptr(), toco_flags_proto_txt_raw.ptr(), |
34 | input_contents_txt_raw.ptr(), extended_return, |
35 | debug_info_txt_raw.ptr(), enable_mlir_converter)); |
36 | }, |
37 | py::arg("model_flags_proto_txt_raw" ), py::arg("toco_flags_proto_txt_raw" ), |
38 | py::arg("input_contents_txt_raw" ), py::arg("extended_return" ) = false, |
39 | py::arg("debug_info_txt_raw" ) = py::none(), |
40 | py::arg("enable_mlir_converter" ) = false, |
41 | R"pbdoc( |
42 | Convert a model represented in `input_contents`. `model_flags_proto` |
43 | describes model parameters. `toco_flags_proto` describes conversion |
44 | parameters (see relevant .protos for more information). Returns a string |
45 | representing the contents of the converted model. When extended_return |
46 | flag is set to true returns a dictionary that contains string representation |
47 | of the converted model and some statistics like arithmetic ops count. |
48 | `debug_info_str` contains the `GraphDebugInfo` proto. When |
49 | `enable_mlir_converter` is True, tuse MLIR-based conversion instead of |
50 | TOCO conversion. |
51 | )pbdoc" ); |
52 | m.def( |
53 | "ExperimentalMlirQuantizeModel" , |
54 | [](py::object input_contents_txt_raw, bool disable_per_channel, |
55 | bool fully_quantize, int inference_type, int input_data_type, |
56 | int output_data_type, bool enable_numeric_verify, |
57 | bool enable_whole_model_verify, py::object op_blocklist, |
58 | py::object node_blocklist) { |
59 | return tensorflow::PyoOrThrow(toco::MlirQuantizeModel( |
60 | input_contents_txt_raw.ptr(), disable_per_channel, fully_quantize, |
61 | inference_type, input_data_type, output_data_type, |
62 | enable_numeric_verify, enable_whole_model_verify, |
63 | op_blocklist.ptr(), node_blocklist.ptr())); |
64 | }, |
65 | py::arg("input_contents_txt_raw" ), py::arg("disable_per_channel" ) = false, |
66 | py::arg("fully_quantize" ) = true, py::arg("inference_type" ) = 9, |
67 | py::arg("input_data_type" ) = 0, py::arg("output_data_type" ) = 0, |
68 | py::arg("enable_numeric_verify" ) = false, |
69 | py::arg("enable_whole_model_verify" ) = false, |
70 | py::arg("op_blocklist" ) = py::none(), |
71 | py::arg("node_blocklist" ) = py::none(), |
72 | R"pbdoc( |
73 | Returns a quantized model. |
74 | )pbdoc" ); |
75 | m.def( |
76 | "ExperimentalMlirSparsifyModel" , |
77 | [](py::object input_contents_txt_raw) { |
78 | return tensorflow::PyoOrThrow( |
79 | toco::MlirSparsifyModel(input_contents_txt_raw.ptr())); |
80 | }, |
81 | py::arg("input_contents_txt_raw" ), |
82 | R"pbdoc( |
83 | Returns a sparsified model. |
84 | )pbdoc" ); |
85 | m.def( |
86 | "RegisterCustomOpdefs" , |
87 | [](py::object custom_opdefs_txt_raw) { |
88 | return tensorflow::PyoOrThrow( |
89 | toco::RegisterCustomOpdefs(custom_opdefs_txt_raw.ptr())); |
90 | }, |
91 | py::arg("custom_opdefs_txt_raw" ), |
92 | R"pbdoc( |
93 | Registers the given custom opdefs to the TensorFlow global op registry. |
94 | )pbdoc" ); |
95 | m.def( |
96 | "RetrieveCollectedErrors" , |
97 | []() { |
98 | std::vector<std::string> collected_errors = |
99 | toco::RetrieveCollectedErrors(); |
100 | pybind11::list serialized_message_list(collected_errors.size()); |
101 | int i = 0; |
102 | for (const auto& error_data : collected_errors) { |
103 | serialized_message_list[i++] = pybind11::bytes(error_data); |
104 | } |
105 | return serialized_message_list; |
106 | }, |
107 | R"pbdoc( |
108 | Returns and clears the list of collected errors in ErrorCollector. |
109 | )pbdoc" ); |
110 | m.def( |
111 | "FlatBufferToMlir" , |
112 | [](const std::string& model, bool input_is_filepath) { |
113 | return toco::FlatBufferFileToMlir(model, input_is_filepath); |
114 | }, |
115 | R"pbdoc( |
116 | Returns MLIR dump of the given TFLite model. |
117 | )pbdoc" ); |
118 | } |
119 | |