1#include <onnx/onnx_pb.h>
2#include <torch/csrc/onnx/init.h>
3#include <torch/csrc/onnx/onnx.h>
4#include <torch/version.h>
5
6#include <torch/csrc/Exceptions.h>
7#include <torch/csrc/jit/passes/onnx.h>
8#include <torch/csrc/jit/passes/onnx/cast_all_constant_to_floating.h>
9#include <torch/csrc/jit/passes/onnx/constant_fold.h>
10#include <torch/csrc/jit/passes/onnx/deduplicate_initializers.h>
11#include <torch/csrc/jit/passes/onnx/eliminate_unused_items.h>
12#include <torch/csrc/jit/passes/onnx/eval_peephole.h>
13#include <torch/csrc/jit/passes/onnx/fixup_onnx_controlflow.h>
14#include <torch/csrc/jit/passes/onnx/function_extraction.h>
15#include <torch/csrc/jit/passes/onnx/function_substitution.h>
16#include <torch/csrc/jit/passes/onnx/list_model_parameters.h>
17#include <torch/csrc/jit/passes/onnx/naming.h>
18#include <torch/csrc/jit/passes/onnx/onnx_log.h>
19#include <torch/csrc/jit/passes/onnx/pattern_conversion/autograd_function_process.h>
20#include <torch/csrc/jit/passes/onnx/pattern_conversion/pattern_conversion.h>
21#include <torch/csrc/jit/passes/onnx/pattern_conversion/pattern_encapsulation.h>
22#include <torch/csrc/jit/passes/onnx/peephole.h>
23#include <torch/csrc/jit/passes/onnx/prepare_division_for_onnx.h>
24#include <torch/csrc/jit/passes/onnx/preprocess_for_onnx.h>
25#include <torch/csrc/jit/passes/onnx/remove_inplace_ops_for_onnx.h>
26#include <torch/csrc/jit/passes/onnx/scalar_type_analysis.h>
27#include <torch/csrc/jit/passes/onnx/shape_type_inference.h>
28#include <torch/csrc/jit/passes/onnx/unpack_quantized_weights.h>
29#include <torch/csrc/jit/serialization/export.h>
30
31namespace torch {
32namespace onnx {
33
34using namespace torch::jit;
35
36void initONNXBindings(PyObject* module) {
37 auto m = py::handle(module).cast<py::module>();
38
39 // ONNX specific passes
40 m.def("_jit_pass_onnx_remove_print", RemovePrintOps)
41 .def("_jit_pass_onnx_preprocess_caffe2", PreprocessCaffe2Ops)
42 .def("_jit_pass_onnx", ToONNX)
43 .def(
44 "_jit_pass_onnx_assign_output_shape",
45 ::torch::wrap_pybind_function(
46 [](std::shared_ptr<Graph>& graph,
47 const std::vector<at::Tensor>& tensors,
48 const python::IODescriptor& desc,
49 bool onnx_shape_inference,
50 bool is_script,
51 int opset_version) {
52 ONNXAssignOutputShape(
53 graph,
54 tensors,
55 desc,
56 onnx_shape_inference,
57 is_script,
58 opset_version);
59 }))
60 .def(
61 "_jit_pass_onnx_function_substitution",
62 wrap_pybind_function(ONNXFunctionCallSubstitution))
63 .def(
64 "_jit_pass_onnx_autograd_function_process",
65 wrap_pybind_function(ONNXAutogradFunctionProcess))
66 .def(
67 "_jit_pass_onnx_peephole",
68 ::torch::wrap_pybind_function([](std::shared_ptr<Graph>& graph,
69 int opset_version,
70 bool fixed_batch_size) {
71 return PeepholeOptimizeONNX(graph, opset_version, fixed_batch_size);
72 }))
73 .def(
74 "_jit_pass_onnx_preprocess",
75 ::torch::wrap_pybind_function(PreprocessForONNX))
76 .def(
77 "_jit_pass_onnx_eval_peephole",
78 ::torch::wrap_pybind_function(
79 [](std::shared_ptr<Graph>& graph,
80 std::map<std::string, IValue>& paramsDict) {
81 EvalPeepholeONNX(graph, paramsDict);
82 return paramsDict;
83 }),
84 pybind11::return_value_policy::move)
85 .def(
86 "_jit_pass_onnx_cast_all_constant_to_floating",
87 ::torch::wrap_pybind_function(CastAllConstantToFloating))
88 .def(
89 "_jit_pass_onnx_constant_fold",
90 ::torch::wrap_pybind_function(
91 [](std::shared_ptr<Graph>& graph,
92 std::map<std::string, IValue>& paramsDict,
93 int opset_version) {
94 ConstantFoldONNX(
95 graph,
96 paramsDict,
97 opset_version); // overload resolution
98 return paramsDict;
99 }),
100 pybind11::return_value_policy::move)
101 .def(
102 "_jit_pass_onnx_eliminate_unused_items",
103 ::torch::wrap_pybind_function(
104 [](std::shared_ptr<Graph>& graph,
105 std::map<std::string, IValue>& paramsDict) {
106 EliminateUnusedItemsONNX(
107 graph->block(),
108 paramsDict); // overload resolution
109 return paramsDict;
110 }),
111 pybind11::return_value_policy::move)
112 .def(
113 "_jit_pass_onnx_scalar_type_analysis",
114 ::torch::wrap_pybind_function([](std::shared_ptr<Graph>& graph,
115 bool lowprecision_cast,
116 int opset_version) {
117 return ScalarTypeAnalysisForONNX(
118 graph, lowprecision_cast, opset_version);
119 }),
120 py::arg("graph"),
121 py::arg("lowprecision_cast") = true,
122 py::arg("opset_version"))
123 .def(
124 "_jit_pass_onnx_remove_inplace_ops_for_onnx",
125 ::torch::wrap_pybind_function(RemoveInplaceOpsForONNX))
126 .def(
127 "_jit_pass_onnx_node_shape_type_inference",
128 ::torch::wrap_pybind_function(
129 [](Node* n,
130 std::map<std::string, IValue>& params_dict,
131 int opset_version) {
132 ONNXShapeTypeInference(n, params_dict, opset_version);
133 }))
134 .def(
135 "_jit_pass_onnx_graph_shape_type_inference",
136 ::torch::wrap_pybind_function(
137 [](std::shared_ptr<Graph>& graph,
138 std::map<std::string, IValue>& params_dict,
139 int opset_version) {
140 ONNXShapeTypeInference(graph, params_dict, opset_version);
141 }),
142 py::arg("graph"),
143 py::arg("params_dict"),
144 py::arg("opset_version"))
145 .def(
146 "_jit_pass_onnx_set_dynamic_input_shape",
147 ::torch::wrap_pybind_function(ONNXSetDynamicInputShape))
148 .def("_jit_pass_onnx_lint", torch::wrap_pybind_function(ONNXLintGraph))
149 .def(
150 "_jit_pass_onnx_function_extraction",
151 ::torch::wrap_pybind_function(
152 torch::jit::onnx::ONNXFunctionExtraction))
153 .def("_jit_pass_onnx_block", torch::wrap_pybind_function(BlockToONNX))
154 .def(
155 "_jit_pass_onnx_unpack_quantized_weights",
156 ::torch::wrap_pybind_function(
157 [](std::shared_ptr<Graph>& graph,
158 std::map<std::string, IValue>& paramsDict,
159 bool caffe2) {
160 UnpackQuantizedWeights(graph, paramsDict, caffe2);
161 return paramsDict;
162 }),
163 pybind11::return_value_policy::move)
164 .def(
165 "_jit_pass_onnx_quantization_insert_permutes",
166 ::torch::wrap_pybind_function(
167 [](std::shared_ptr<Graph>& graph,
168 std::map<std::string, IValue>& paramsDict) {
169 insertPermutes(graph, paramsDict);
170 return paramsDict;
171 }),
172 pybind11::return_value_policy::move)
173 .def(
174 "_jit_onnx_list_model_parameters",
175 ::torch::wrap_pybind_function(
176 [](Module& module) { return list_module_parameters(module); }))
177 .def(
178 "_jit_pass_prepare_division_for_onnx",
179 ::torch::wrap_pybind_function(PrepareDivisionForONNX))
180 .def(
181 "_jit_onnx_convert_pattern_from_subblock",
182 ::torch::wrap_pybind_function(ConvertPatternFromSubblock))
183 .def(
184 "_jit_pass_fixup_onnx_controlflow_node",
185 ::torch::wrap_pybind_function(FixupONNXControlflowNode))
186 .def(
187 "_jit_pass_onnx_deduplicate_initializers",
188 ::torch::wrap_pybind_function(
189 [](std::shared_ptr<Graph>& graph,
190 std::map<std::string, IValue> params_dict,
191 bool is_train) {
192 DeduplicateInitializers(graph, params_dict, is_train);
193 return params_dict;
194 }),
195 pybind11::return_value_policy::move)
196 .def(
197 "_jit_pass_onnx_clear_scope_records",
198 &torch::jit::onnx::ONNXClearScopeRecords)
199 .def(
200 "_jit_pass_onnx_track_scope_attributes",
201 &torch::jit::onnx::ONNXTrackScopeAttributes)
202 .def(
203 "_jit_is_onnx_log_enabled",
204 ::torch::jit::onnx::is_log_enabled,
205 "Returns whether ONNX logging is enabled or disabled.")
206 .def(
207 "_jit_set_onnx_log_enabled",
208 ::torch::jit::onnx::set_log_enabled,
209 "Enables or disables ONNX logging.")
210 .def(
211 "_jit_set_onnx_log_output_stream",
212 [](std::string stream_name = "stdout") -> void {
213 std::shared_ptr<std::ostream> out;
214 if (stream_name == "stdout") {
215 out = std::shared_ptr<std::ostream>(
216 &std::cout, [](std::ostream*) {});
217 } else if (stream_name == "stderr") {
218 out = std::shared_ptr<std::ostream>(
219 &std::cerr, [](std::ostream*) {});
220 } else {
221 std::cerr << "ERROR: only `stdout` and `stderr`"
222 << "are supported as `stream_name`" << std::endl;
223 }
224 ::torch::jit::onnx::set_log_output_stream(out);
225 },
226 "Set specific file stream for ONNX logging.")
227 .def(
228 "_jit_onnx_log",
229 [](py::args args) -> void {
230 if (::torch::jit::onnx::is_log_enabled()) {
231 auto& out = ::torch::jit::onnx::_get_log_output_stream();
232 for (auto arg : args) {
233 out << ::c10::str(arg);
234 }
235 out << std::endl;
236 }
237 },
238 "Write `args` to the previously specified ONNX log stream.")
239 .def(
240 "_jit_pass_onnx_assign_scoped_names_for_node_and_value",
241 ::torch::wrap_pybind_function(
242 ::torch::jit::onnx::AssignScopedNamesForNodeAndValue),
243 "Assign informative scoped names for nodes and values.")
244 .def(
245 "_jit_onnx_create_full_scope_name",
246 ::torch::wrap_pybind_function(
247 ::torch::jit::onnx::ONNXScopeName::createFullScopeName),
248 "Create a full scope name from class name and variable name.");
249
250 m.def(
251 "_check_onnx_proto",
252 ::torch::wrap_pybind_function([](const std::string& proto_string) {
253 check_onnx_proto(proto_string);
254 }),
255 py::arg("proto_string"));
256
257 auto onnx = m.def_submodule("_onnx");
258 py::enum_<::ONNX_NAMESPACE::TensorProto_DataType>(onnx, "TensorProtoDataType")
259 .value("UNDEFINED", ::ONNX_NAMESPACE::TensorProto_DataType_UNDEFINED)
260 .value("FLOAT", ::ONNX_NAMESPACE::TensorProto_DataType_FLOAT)
261 .value("UINT8", ::ONNX_NAMESPACE::TensorProto_DataType_UINT8)
262 .value("INT8", ::ONNX_NAMESPACE::TensorProto_DataType_INT8)
263 .value("UINT16", ::ONNX_NAMESPACE::TensorProto_DataType_UINT16)
264 .value("INT16", ::ONNX_NAMESPACE::TensorProto_DataType_INT16)
265 .value("INT32", ::ONNX_NAMESPACE::TensorProto_DataType_INT32)
266 .value("INT64", ::ONNX_NAMESPACE::TensorProto_DataType_INT64)
267 .value("STRING", ::ONNX_NAMESPACE::TensorProto_DataType_STRING)
268 .value("BOOL", ::ONNX_NAMESPACE::TensorProto_DataType_BOOL)
269 .value("FLOAT16", ::ONNX_NAMESPACE::TensorProto_DataType_FLOAT16)
270 .value("DOUBLE", ::ONNX_NAMESPACE::TensorProto_DataType_DOUBLE)
271 .value("UINT32", ::ONNX_NAMESPACE::TensorProto_DataType_UINT32)
272 .value("UINT64", ::ONNX_NAMESPACE::TensorProto_DataType_UINT64)
273 .value("COMPLEX64", ::ONNX_NAMESPACE::TensorProto_DataType_COMPLEX64)
274 .value("COMPLEX128", ::ONNX_NAMESPACE::TensorProto_DataType_COMPLEX128)
275 .value("BFLOAT16", ::ONNX_NAMESPACE::TensorProto_DataType_BFLOAT16);
276
277 py::enum_<OperatorExportTypes>(onnx, "OperatorExportTypes")
278 .value("ONNX", OperatorExportTypes::ONNX)
279 .value("ONNX_ATEN", OperatorExportTypes::ONNX_ATEN)
280 .value("ONNX_ATEN_FALLBACK", OperatorExportTypes::ONNX_ATEN_FALLBACK)
281 .value("ONNX_FALLTHROUGH", OperatorExportTypes::ONNX_FALLTHROUGH);
282
283 py::enum_<TrainingMode>(onnx, "TrainingMode")
284 .value("EVAL", TrainingMode::EVAL)
285 .value("PRESERVE", TrainingMode::PRESERVE)
286 .value("TRAINING", TrainingMode::TRAINING);
287
288 onnx.attr("PRODUCER_VERSION") = py::str(TORCH_VERSION);
289
290#ifdef BUILD_CAFFE2
291 onnx.attr("_CAFFE2_ATEN_FALLBACK") = true;
292#else
293 onnx.attr("_CAFFE2_ATEN_FALLBACK") = false;
294#endif
295}
296} // namespace onnx
297} // namespace torch
298