1 | #include <onnx/onnx_pb.h> |
2 | #include <torch/csrc/onnx/init.h> |
3 | #include <torch/csrc/onnx/onnx.h> |
4 | #include <torch/version.h> |
5 | |
6 | #include <torch/csrc/Exceptions.h> |
7 | #include <torch/csrc/jit/passes/onnx.h> |
8 | #include <torch/csrc/jit/passes/onnx/cast_all_constant_to_floating.h> |
9 | #include <torch/csrc/jit/passes/onnx/constant_fold.h> |
10 | #include <torch/csrc/jit/passes/onnx/deduplicate_initializers.h> |
11 | #include <torch/csrc/jit/passes/onnx/eliminate_unused_items.h> |
12 | #include <torch/csrc/jit/passes/onnx/eval_peephole.h> |
13 | #include <torch/csrc/jit/passes/onnx/fixup_onnx_controlflow.h> |
14 | #include <torch/csrc/jit/passes/onnx/function_extraction.h> |
15 | #include <torch/csrc/jit/passes/onnx/function_substitution.h> |
16 | #include <torch/csrc/jit/passes/onnx/list_model_parameters.h> |
17 | #include <torch/csrc/jit/passes/onnx/naming.h> |
18 | #include <torch/csrc/jit/passes/onnx/onnx_log.h> |
19 | #include <torch/csrc/jit/passes/onnx/pattern_conversion/autograd_function_process.h> |
20 | #include <torch/csrc/jit/passes/onnx/pattern_conversion/pattern_conversion.h> |
21 | #include <torch/csrc/jit/passes/onnx/pattern_conversion/pattern_encapsulation.h> |
22 | #include <torch/csrc/jit/passes/onnx/peephole.h> |
23 | #include <torch/csrc/jit/passes/onnx/prepare_division_for_onnx.h> |
24 | #include <torch/csrc/jit/passes/onnx/preprocess_for_onnx.h> |
25 | #include <torch/csrc/jit/passes/onnx/remove_inplace_ops_for_onnx.h> |
26 | #include <torch/csrc/jit/passes/onnx/scalar_type_analysis.h> |
27 | #include <torch/csrc/jit/passes/onnx/shape_type_inference.h> |
28 | #include <torch/csrc/jit/passes/onnx/unpack_quantized_weights.h> |
29 | #include <torch/csrc/jit/serialization/export.h> |
30 | |
31 | namespace torch { |
32 | namespace onnx { |
33 | |
34 | using namespace torch::jit; |
35 | |
36 | void initONNXBindings(PyObject* module) { |
37 | auto m = py::handle(module).cast<py::module>(); |
38 | |
39 | // ONNX specific passes |
40 | m.def("_jit_pass_onnx_remove_print" , RemovePrintOps) |
41 | .def("_jit_pass_onnx_preprocess_caffe2" , PreprocessCaffe2Ops) |
42 | .def("_jit_pass_onnx" , ToONNX) |
43 | .def( |
44 | "_jit_pass_onnx_assign_output_shape" , |
45 | ::torch::wrap_pybind_function( |
46 | [](std::shared_ptr<Graph>& graph, |
47 | const std::vector<at::Tensor>& tensors, |
48 | const python::IODescriptor& desc, |
49 | bool onnx_shape_inference, |
50 | bool is_script, |
51 | int opset_version) { |
52 | ONNXAssignOutputShape( |
53 | graph, |
54 | tensors, |
55 | desc, |
56 | onnx_shape_inference, |
57 | is_script, |
58 | opset_version); |
59 | })) |
60 | .def( |
61 | "_jit_pass_onnx_function_substitution" , |
62 | wrap_pybind_function(ONNXFunctionCallSubstitution)) |
63 | .def( |
64 | "_jit_pass_onnx_autograd_function_process" , |
65 | wrap_pybind_function(ONNXAutogradFunctionProcess)) |
66 | .def( |
67 | "_jit_pass_onnx_peephole" , |
68 | ::torch::wrap_pybind_function([](std::shared_ptr<Graph>& graph, |
69 | int opset_version, |
70 | bool fixed_batch_size) { |
71 | return PeepholeOptimizeONNX(graph, opset_version, fixed_batch_size); |
72 | })) |
73 | .def( |
74 | "_jit_pass_onnx_preprocess" , |
75 | ::torch::wrap_pybind_function(PreprocessForONNX)) |
76 | .def( |
77 | "_jit_pass_onnx_eval_peephole" , |
78 | ::torch::wrap_pybind_function( |
79 | [](std::shared_ptr<Graph>& graph, |
80 | std::map<std::string, IValue>& paramsDict) { |
81 | EvalPeepholeONNX(graph, paramsDict); |
82 | return paramsDict; |
83 | }), |
84 | pybind11::return_value_policy::move) |
85 | .def( |
86 | "_jit_pass_onnx_cast_all_constant_to_floating" , |
87 | ::torch::wrap_pybind_function(CastAllConstantToFloating)) |
88 | .def( |
89 | "_jit_pass_onnx_constant_fold" , |
90 | ::torch::wrap_pybind_function( |
91 | [](std::shared_ptr<Graph>& graph, |
92 | std::map<std::string, IValue>& paramsDict, |
93 | int opset_version) { |
94 | ConstantFoldONNX( |
95 | graph, |
96 | paramsDict, |
97 | opset_version); // overload resolution |
98 | return paramsDict; |
99 | }), |
100 | pybind11::return_value_policy::move) |
101 | .def( |
102 | "_jit_pass_onnx_eliminate_unused_items" , |
103 | ::torch::wrap_pybind_function( |
104 | [](std::shared_ptr<Graph>& graph, |
105 | std::map<std::string, IValue>& paramsDict) { |
106 | EliminateUnusedItemsONNX( |
107 | graph->block(), |
108 | paramsDict); // overload resolution |
109 | return paramsDict; |
110 | }), |
111 | pybind11::return_value_policy::move) |
112 | .def( |
113 | "_jit_pass_onnx_scalar_type_analysis" , |
114 | ::torch::wrap_pybind_function([](std::shared_ptr<Graph>& graph, |
115 | bool lowprecision_cast, |
116 | int opset_version) { |
117 | return ScalarTypeAnalysisForONNX( |
118 | graph, lowprecision_cast, opset_version); |
119 | }), |
120 | py::arg("graph" ), |
121 | py::arg("lowprecision_cast" ) = true, |
122 | py::arg("opset_version" )) |
123 | .def( |
124 | "_jit_pass_onnx_remove_inplace_ops_for_onnx" , |
125 | ::torch::wrap_pybind_function(RemoveInplaceOpsForONNX)) |
126 | .def( |
127 | "_jit_pass_onnx_node_shape_type_inference" , |
128 | ::torch::wrap_pybind_function( |
129 | [](Node* n, |
130 | std::map<std::string, IValue>& params_dict, |
131 | int opset_version) { |
132 | ONNXShapeTypeInference(n, params_dict, opset_version); |
133 | })) |
134 | .def( |
135 | "_jit_pass_onnx_graph_shape_type_inference" , |
136 | ::torch::wrap_pybind_function( |
137 | [](std::shared_ptr<Graph>& graph, |
138 | std::map<std::string, IValue>& params_dict, |
139 | int opset_version) { |
140 | ONNXShapeTypeInference(graph, params_dict, opset_version); |
141 | }), |
142 | py::arg("graph" ), |
143 | py::arg("params_dict" ), |
144 | py::arg("opset_version" )) |
145 | .def( |
146 | "_jit_pass_onnx_set_dynamic_input_shape" , |
147 | ::torch::wrap_pybind_function(ONNXSetDynamicInputShape)) |
148 | .def("_jit_pass_onnx_lint" , torch::wrap_pybind_function(ONNXLintGraph)) |
149 | .def( |
150 | "_jit_pass_onnx_function_extraction" , |
151 | ::torch::wrap_pybind_function( |
152 | torch::jit::onnx::ONNXFunctionExtraction)) |
153 | .def("_jit_pass_onnx_block" , torch::wrap_pybind_function(BlockToONNX)) |
154 | .def( |
155 | "_jit_pass_onnx_unpack_quantized_weights" , |
156 | ::torch::wrap_pybind_function( |
157 | [](std::shared_ptr<Graph>& graph, |
158 | std::map<std::string, IValue>& paramsDict, |
159 | bool caffe2) { |
160 | UnpackQuantizedWeights(graph, paramsDict, caffe2); |
161 | return paramsDict; |
162 | }), |
163 | pybind11::return_value_policy::move) |
164 | .def( |
165 | "_jit_pass_onnx_quantization_insert_permutes" , |
166 | ::torch::wrap_pybind_function( |
167 | [](std::shared_ptr<Graph>& graph, |
168 | std::map<std::string, IValue>& paramsDict) { |
169 | insertPermutes(graph, paramsDict); |
170 | return paramsDict; |
171 | }), |
172 | pybind11::return_value_policy::move) |
173 | .def( |
174 | "_jit_onnx_list_model_parameters" , |
175 | ::torch::wrap_pybind_function( |
176 | [](Module& module) { return list_module_parameters(module); })) |
177 | .def( |
178 | "_jit_pass_prepare_division_for_onnx" , |
179 | ::torch::wrap_pybind_function(PrepareDivisionForONNX)) |
180 | .def( |
181 | "_jit_onnx_convert_pattern_from_subblock" , |
182 | ::torch::wrap_pybind_function(ConvertPatternFromSubblock)) |
183 | .def( |
184 | "_jit_pass_fixup_onnx_controlflow_node" , |
185 | ::torch::wrap_pybind_function(FixupONNXControlflowNode)) |
186 | .def( |
187 | "_jit_pass_onnx_deduplicate_initializers" , |
188 | ::torch::wrap_pybind_function( |
189 | [](std::shared_ptr<Graph>& graph, |
190 | std::map<std::string, IValue> params_dict, |
191 | bool is_train) { |
192 | DeduplicateInitializers(graph, params_dict, is_train); |
193 | return params_dict; |
194 | }), |
195 | pybind11::return_value_policy::move) |
196 | .def( |
197 | "_jit_pass_onnx_clear_scope_records" , |
198 | &torch::jit::onnx::ONNXClearScopeRecords) |
199 | .def( |
200 | "_jit_pass_onnx_track_scope_attributes" , |
201 | &torch::jit::onnx::ONNXTrackScopeAttributes) |
202 | .def( |
203 | "_jit_is_onnx_log_enabled" , |
204 | ::torch::jit::onnx::is_log_enabled, |
205 | "Returns whether ONNX logging is enabled or disabled." ) |
206 | .def( |
207 | "_jit_set_onnx_log_enabled" , |
208 | ::torch::jit::onnx::set_log_enabled, |
209 | "Enables or disables ONNX logging." ) |
210 | .def( |
211 | "_jit_set_onnx_log_output_stream" , |
212 | [](std::string stream_name = "stdout" ) -> void { |
213 | std::shared_ptr<std::ostream> out; |
214 | if (stream_name == "stdout" ) { |
215 | out = std::shared_ptr<std::ostream>( |
216 | &std::cout, [](std::ostream*) {}); |
217 | } else if (stream_name == "stderr" ) { |
218 | out = std::shared_ptr<std::ostream>( |
219 | &std::cerr, [](std::ostream*) {}); |
220 | } else { |
221 | std::cerr << "ERROR: only `stdout` and `stderr`" |
222 | << "are supported as `stream_name`" << std::endl; |
223 | } |
224 | ::torch::jit::onnx::set_log_output_stream(out); |
225 | }, |
226 | "Set specific file stream for ONNX logging." ) |
227 | .def( |
228 | "_jit_onnx_log" , |
229 | [](py::args args) -> void { |
230 | if (::torch::jit::onnx::is_log_enabled()) { |
231 | auto& out = ::torch::jit::onnx::_get_log_output_stream(); |
232 | for (auto arg : args) { |
233 | out << ::c10::str(arg); |
234 | } |
235 | out << std::endl; |
236 | } |
237 | }, |
238 | "Write `args` to the previously specified ONNX log stream." ) |
239 | .def( |
240 | "_jit_pass_onnx_assign_scoped_names_for_node_and_value" , |
241 | ::torch::wrap_pybind_function( |
242 | ::torch::jit::onnx::AssignScopedNamesForNodeAndValue), |
243 | "Assign informative scoped names for nodes and values." ) |
244 | .def( |
245 | "_jit_onnx_create_full_scope_name" , |
246 | ::torch::wrap_pybind_function( |
247 | ::torch::jit::onnx::ONNXScopeName::createFullScopeName), |
248 | "Create a full scope name from class name and variable name." ); |
249 | |
250 | m.def( |
251 | "_check_onnx_proto" , |
252 | ::torch::wrap_pybind_function([](const std::string& proto_string) { |
253 | check_onnx_proto(proto_string); |
254 | }), |
255 | py::arg("proto_string" )); |
256 | |
257 | auto onnx = m.def_submodule("_onnx" ); |
258 | py::enum_<::ONNX_NAMESPACE::TensorProto_DataType>(onnx, "TensorProtoDataType" ) |
259 | .value("UNDEFINED" , ::ONNX_NAMESPACE::TensorProto_DataType_UNDEFINED) |
260 | .value("FLOAT" , ::ONNX_NAMESPACE::TensorProto_DataType_FLOAT) |
261 | .value("UINT8" , ::ONNX_NAMESPACE::TensorProto_DataType_UINT8) |
262 | .value("INT8" , ::ONNX_NAMESPACE::TensorProto_DataType_INT8) |
263 | .value("UINT16" , ::ONNX_NAMESPACE::TensorProto_DataType_UINT16) |
264 | .value("INT16" , ::ONNX_NAMESPACE::TensorProto_DataType_INT16) |
265 | .value("INT32" , ::ONNX_NAMESPACE::TensorProto_DataType_INT32) |
266 | .value("INT64" , ::ONNX_NAMESPACE::TensorProto_DataType_INT64) |
267 | .value("STRING" , ::ONNX_NAMESPACE::TensorProto_DataType_STRING) |
268 | .value("BOOL" , ::ONNX_NAMESPACE::TensorProto_DataType_BOOL) |
269 | .value("FLOAT16" , ::ONNX_NAMESPACE::TensorProto_DataType_FLOAT16) |
270 | .value("DOUBLE" , ::ONNX_NAMESPACE::TensorProto_DataType_DOUBLE) |
271 | .value("UINT32" , ::ONNX_NAMESPACE::TensorProto_DataType_UINT32) |
272 | .value("UINT64" , ::ONNX_NAMESPACE::TensorProto_DataType_UINT64) |
273 | .value("COMPLEX64" , ::ONNX_NAMESPACE::TensorProto_DataType_COMPLEX64) |
274 | .value("COMPLEX128" , ::ONNX_NAMESPACE::TensorProto_DataType_COMPLEX128) |
275 | .value("BFLOAT16" , ::ONNX_NAMESPACE::TensorProto_DataType_BFLOAT16); |
276 | |
277 | py::enum_<OperatorExportTypes>(onnx, "OperatorExportTypes" ) |
278 | .value("ONNX" , OperatorExportTypes::ONNX) |
279 | .value("ONNX_ATEN" , OperatorExportTypes::ONNX_ATEN) |
280 | .value("ONNX_ATEN_FALLBACK" , OperatorExportTypes::ONNX_ATEN_FALLBACK) |
281 | .value("ONNX_FALLTHROUGH" , OperatorExportTypes::ONNX_FALLTHROUGH); |
282 | |
283 | py::enum_<TrainingMode>(onnx, "TrainingMode" ) |
284 | .value("EVAL" , TrainingMode::EVAL) |
285 | .value("PRESERVE" , TrainingMode::PRESERVE) |
286 | .value("TRAINING" , TrainingMode::TRAINING); |
287 | |
288 | onnx.attr("PRODUCER_VERSION" ) = py::str(TORCH_VERSION); |
289 | |
290 | #ifdef BUILD_CAFFE2 |
291 | onnx.attr("_CAFFE2_ATEN_FALLBACK" ) = true; |
292 | #else |
293 | onnx.attr("_CAFFE2_ATEN_FALLBACK" ) = false; |
294 | #endif |
295 | } |
296 | } // namespace onnx |
297 | } // namespace torch |
298 | |