1 | #include <torch/csrc/jit/passes/onnx.h> |
2 | |
3 | #include <ATen/core/functional.h> |
4 | #include <c10/util/Exception.h> |
5 | #include <c10/util/irange.h> |
6 | #include <torch/csrc/autograd/function.h> |
7 | #include <torch/csrc/autograd/symbolic.h> |
8 | #include <torch/csrc/jit/ir/constants.h> |
9 | #include <torch/csrc/jit/jit_log.h> |
10 | #include <torch/csrc/jit/passes/dead_code_elimination.h> |
11 | #include <torch/csrc/jit/passes/onnx/constant_map.h> |
12 | #include <torch/csrc/jit/passes/onnx/helper.h> |
13 | #include <torch/csrc/jit/passes/onnx/onnx_log.h> |
14 | #include <torch/csrc/jit/passes/onnx/shape_type_inference.h> |
15 | #include <torch/csrc/jit/python/python_ir.h> |
16 | #include <torch/csrc/utils/pybind.h> |
17 | #include <sstream> |
18 | #include <unordered_map> |
19 | namespace torch { |
20 | namespace jit { |
21 | |
22 | void removePrintOps(Block* block) { |
23 | for (auto it = block->nodes().begin(), end = block->nodes().end(); it != end; |
24 | ++it) { |
25 | for (auto b : it->blocks()) { |
26 | removePrintOps(b); |
27 | } |
28 | if (it->kind() == prim::Print || it->kind() == aten::warn) { |
29 | for (size_t i = 0; i < it->inputs().size();) { |
30 | auto input = it->inputs().at(i); |
31 | // only handling constants bc of potential side effects |
32 | if (input->uses().size() == 1 && |
33 | input->node()->kind() == prim::Constant) { |
34 | it->removeInput(i); |
35 | input->node()->destroy(); |
36 | } else { |
37 | ++i; |
38 | } |
39 | } |
40 | it.destroyCurrent(); |
41 | } |
42 | } |
43 | } |
44 | |
45 | void RemovePrintOps(std::shared_ptr<Graph>& graph) { |
46 | removePrintOps(graph->block()); |
47 | GRAPH_DUMP("After RemovePrintOps: " , graph); |
48 | } |
49 | |
50 | void checkONNXCompatibility(const c10::FunctionSchema& schema) { |
51 | // in ONNX, all inputs are tensors, no support for tensor list |
52 | // so at most one input tensor list is supported |
53 | bool has_tensor_list = false; |
54 | const auto& args = schema.arguments(); |
55 | for (const auto& arg : args) { |
56 | if (arg.name() == "_caffe2_preallocated_outputs" ) { |
57 | continue; |
58 | } |
59 | auto type = arg.type(); |
60 | if (type->kind() == TypeKind::OptionalType) { |
61 | type = reinterpret_cast<OptionalType*>(type.get())->getElementType(); |
62 | // recursive optional type is not supported |
63 | TORCH_INTERNAL_ASSERT(type->kind() != TypeKind::OptionalType); |
64 | } |
65 | if (type->kind() == TypeKind::ListType) { |
66 | const auto& elem_type = |
67 | reinterpret_cast<ListType*>(type.get())->getElementType(); |
68 | if (elem_type->isSubtypeOf(*TensorType::get())) { |
69 | TORCH_INTERNAL_ASSERT( |
70 | !has_tensor_list, |
71 | "ONNX export supports at most one TensorList as input." ); |
72 | has_tensor_list = true; |
73 | } |
74 | } |
75 | } |
76 | } |
77 | |
78 | void preprocessCaffe2Ops(Block* block) { |
79 | for (auto it = block->nodes().begin(), end = block->nodes().end(); it != end; |
80 | ++it) { |
81 | for (auto b : it->blocks()) { |
82 | preprocessCaffe2Ops(b); |
83 | } |
84 | if (it->kind().is_caffe2()) { |
85 | const auto& schema = it->schema(); |
86 | checkONNXCompatibility(schema); |
87 | std::vector<Value*> origin_inputs; |
88 | for (Value* v : it->inputs()) { |
89 | origin_inputs.push_back(v); |
90 | } |
91 | it->removeAllInputs(); |
92 | const auto& args = schema.arguments(); |
93 | size_t origin_inputs_index = 0; |
94 | for (const auto& arg : args) { |
95 | const auto& type = arg.type(); |
96 | TORCH_INTERNAL_ASSERT(origin_inputs_index < origin_inputs.size()); |
97 | const auto& origin_input = origin_inputs[origin_inputs_index++]; |
98 | if (type->kind() == TypeKind::OptionalType && |
99 | origin_input->mustBeNone()) { |
100 | continue; |
101 | } |
102 | if (type->isSubtypeOf(*TensorType::get())) { |
103 | it->addInput(origin_input); |
104 | } else if ( |
105 | type->kind() == TypeKind::BoolType || |
106 | type->kind() == TypeKind::IntType) { |
107 | const auto* constant_node = origin_input->node(); |
108 | TORCH_INTERNAL_ASSERT(constant_node->kind() == prim::Constant); |
109 | it->i_(Symbol::attr(arg.name()), constant_node->i(attr::value)); |
110 | } else if (type->kind() == TypeKind::FloatType) { |
111 | const auto* constant_node = origin_input->node(); |
112 | TORCH_INTERNAL_ASSERT(constant_node->kind() == prim::Constant); |
113 | it->f_(Symbol::attr(arg.name()), constant_node->f(attr::value)); |
114 | } else if (type->kind() == TypeKind::StringType) { |
115 | const auto* constant_node = origin_input->node(); |
116 | TORCH_INTERNAL_ASSERT(constant_node->kind() == prim::Constant); |
117 | it->s_(Symbol::attr(arg.name()), constant_node->s(attr::value)); |
118 | } else if (type->kind() == TypeKind::ListType) { |
119 | const auto& list_node = origin_input->node(); |
120 | const auto& elem_type = type->castRaw<ListType>()->getElementType(); |
121 | TORCH_INTERNAL_ASSERT( |
122 | list_node->kind() == prim::ListConstruct || |
123 | list_node->kind() == prim::Constant); |
124 | if (elem_type->isSubtypeOf(*TensorType::get())) { |
125 | TORCH_INTERNAL_ASSERT(list_node->kind(), prim::ListConstruct); |
126 | const auto& tensor_list = origin_input->node()->inputs(); |
127 | for (const auto& t : tensor_list) { |
128 | it->addInput(t); |
129 | } |
130 | } else if (elem_type->kind() == TypeKind::FloatType) { |
131 | std::vector<double> values; |
132 | if (list_node->kind() == prim::ListConstruct) { |
133 | for (const auto* elem_input : list_node->inputs()) { |
134 | const auto* constant_node = elem_input->node(); |
135 | TORCH_INTERNAL_ASSERT(constant_node->kind() == prim::Constant); |
136 | values.push_back(constant_node->f(attr::value)); |
137 | } |
138 | } else { // is a constant list |
139 | values = list_node->fs(attr::value); |
140 | } |
141 | it->fs_(Symbol::attr(arg.name()), values); |
142 | } else { |
143 | throw std::runtime_error( |
144 | "Unhandled scalar arg: " + arg.name() + |
145 | ", type: " + c10::typeKindToString(elem_type->kind())); |
146 | } |
147 | } else { |
148 | throw std::runtime_error( |
149 | "Unsupported input type of arg " + arg.name() + |
150 | " in Caffe2 operator: " + c10::typeKindToString(type->kind())); |
151 | } |
152 | } |
153 | } |
154 | } |
155 | EliminateDeadCode( |
156 | block, true, DCESideEffectPolicy::ALLOW_DELETING_NODES_WITH_SIDE_EFFECTS); |
157 | } |
158 | |
159 | void PreprocessCaffe2Ops(std::shared_ptr<Graph>& graph) { |
160 | preprocessCaffe2Ops(graph->block()); |
161 | GRAPH_DUMP("After PreprocessCaffe2Ops: " , graph); |
162 | } |
163 | |
164 | // Transform PythonOps into Nodes that match ONNX semantics. |
165 | std::shared_ptr<Graph> ToONNX( |
166 | std::shared_ptr<Graph>& graph, |
167 | ::torch::onnx::OperatorExportTypes operator_export_type) { |
168 | auto constant_value_map = ConstantValueMap::getInstance(); |
169 | ConstantValueMap::ClearMaps(); |
170 | auto new_graph = std::make_shared<Graph>(graph->current_scope()); |
171 | std::unordered_map<Value*, Value*> env; |
172 | try { |
173 | BlockToONNX(graph->block(), new_graph->block(), operator_export_type, env); |
174 | } catch (std::runtime_error& ex) { |
175 | ONNX_LOG( |
176 | "ONNX graph being constructed during exception:\n" , |
177 | new_graph->toString()); |
178 | throw; |
179 | } |
180 | GRAPH_DUMP("after ToONNX: " , new_graph); |
181 | ConstantValueMap::ClearMaps(); |
182 | return new_graph; |
183 | } |
184 | |
185 | // BlockToONNX. |
186 | // is_sub_block = true means the old_block (aten graph) is in the sub block |
187 | // (e.g., if sub block), and we want to convert it into its parent block in onnx |
188 | // graph. In this case, we don't register the input/output or eliminate the dead |
189 | // code. |
190 | std::unordered_map<Value*, Value*> BlockToONNX( |
191 | Block* old_block, |
192 | Block* new_block, |
193 | ::torch::onnx::OperatorExportTypes operator_export_type, |
194 | std::unordered_map<Value*, Value*>& env, |
195 | bool is_sub_block) { |
196 | torch::autograd::SymbolicContext ctx{}; |
197 | ctx.block = new_block; |
198 | |
199 | GRAPH_DEBUG( |
200 | "BlockToONNX: graph of old block: " , |
201 | old_block->owningGraph()->toString()); |
202 | |
203 | // Initialize context and environment |
204 | if (!is_sub_block) { |
205 | for (auto input : old_block->inputs()) { |
206 | auto n = ctx.block->addInput()->copyMetadata(input); |
207 | env[input] = n; |
208 | } |
209 | } |
210 | |
211 | // Finally, visit all nodes in the graph |
212 | for (auto node : old_block->nodes()) { |
213 | NodeToONNX(node, ctx.block, operator_export_type, env); |
214 | } |
215 | |
216 | if (is_sub_block) { |
217 | return env; |
218 | } |
219 | |
220 | for (auto output : old_block->outputs()) { |
221 | ctx.block->registerOutput(env.at(output)); |
222 | } |
223 | // Run dce to clean-up unused functional and inplace ops. |
224 | EliminateDeadCode( |
225 | ctx.block, |
226 | true, |
227 | DCESideEffectPolicy::ALLOW_DELETING_NODES_WITH_SIDE_EFFECTS); |
228 | |
229 | return {}; |
230 | } |
231 | |
232 | bool ConstantFoldCondition(torch::jit::Value* output) { |
233 | auto fold_condition = output->node()->kind() != c10::onnx::Constant && |
234 | ConstantValueMap::HasValue(output->debugName()); |
235 | auto reliable_value = |
236 | ConstantValueMap::GetTypeReliable(output->debugName()).value_or(false); |
237 | return fold_condition && reliable_value; |
238 | } |
239 | |
240 | void NodeToONNX( |
241 | Node* old_node, |
242 | Block* new_block, |
243 | ::torch::onnx::OperatorExportTypes operator_export_type, |
244 | std::unordered_map<Value*, Value*>& env) { |
245 | py::object onnx = py::module::import("torch.onnx" ); |
246 | py::object onnx_globals = py::module::import("torch.onnx._globals" ); |
247 | py::object onnx_registration = |
248 | py::module::import("torch.onnx._internal.registration" ); |
249 | |
250 | // Setup all the lambda helper functions. |
251 | |
252 | // Returns a node that n maps to in the new graph |
253 | auto envFn = [&env](Value* n) -> Value* { |
254 | auto it = env.find(n); |
255 | TORCH_CHECK(it != env.end(), "Dangling node reference" ); |
256 | TORCH_CHECK(it->second, "Unused node was subsequently used" ); |
257 | return it->second; |
258 | }; |
259 | |
260 | // Put the new outputs in our environment map, and copy the type from the |
261 | // input graph if they were not set by the symbolic. This is called only |
262 | // with results of symbolic call (not for nodes that are just cloned). |
263 | auto setOutputs = [&](const std::string& op_name, |
264 | Node* node, |
265 | const value_list& outputs) { |
266 | auto old_outputs = node->outputs(); |
267 | // Count all outputs, excluding Handles |
268 | auto num_old_outputs = old_outputs.size(); |
269 | if (outputs.size() != num_old_outputs) { |
270 | std::ostringstream ss; |
271 | ss << "symbolic for " << op_name |
272 | << " produced an incorrect number of outputs (expected " ; |
273 | ss << num_old_outputs << ", but got " << outputs.size() << ")" ; |
274 | throw std::runtime_error(ss.str()); |
275 | } |
276 | // For const node, it does not need params_dict info, so set it to {}. |
277 | const ParamMap empty_params_dict = {}; |
278 | auto opset_version = py::cast<int>( |
279 | onnx_globals.attr("GLOBALS" ).attr("export_onnx_opset_version" )); |
280 | for (const auto i : c10::irange(num_old_outputs)) { |
281 | auto old = old_outputs[i]; |
282 | if (outputs[i]) { |
283 | bool exist_in_env = |
284 | (env.end() != |
285 | std::find_if( |
286 | env.begin(), env.end(), [&outputs, i](const auto& vt) { |
287 | return vt.second == outputs[i]; |
288 | })); |
289 | // Update ONNX value debug name with ATen value debug name if existed. |
290 | // Skip if ONNX value already exist in environment. |
291 | // This implies the op is a noop, and the value is owned by |
292 | // other node created elsewhere. |
293 | if (old->hasDebugName() && !exist_in_env) { |
294 | auto old_name = outputs[i]->debugName(); |
295 | auto new_name = old->debugNameBase(); |
296 | auto debug_names = new_block->owningGraph()->debugNames(); |
297 | auto exist_name = debug_names.find(new_name); |
298 | outputs[i]->setDebugName(new_name); |
299 | if (exist_name != debug_names.end()) { |
300 | // setDebugName changes name of existing value with same name. |
301 | // Set again to revert the changes, but update name for new value |
302 | // with suffix. |
303 | exist_name->second->setDebugName(new_name); |
304 | } |
305 | ConstantValueMap::UpdateValueName(old_name, outputs[i]->debugName()); |
306 | } |
307 | // Allow symbolic() to skip specifying the type of the return node. |
308 | // Unfortunately, they are on the hook for all internal nodes |
309 | // (though in practice, the types are not computed.) |
310 | // |
311 | // If onnx shape inference is turned on, the new outputs will have |
312 | // types inferred, and they will be merged with the old types. |
313 | if (ConstantFoldCondition(outputs[i])) { |
314 | // Create a const node if the node output value is in |
315 | // ConstantValueMap. |
316 | auto value = |
317 | ConstantValueMap::GetValue(outputs[i]->debugName()).value(); |
318 | Node* const_node = |
319 | new_block->owningGraph()->create(c10::onnx::Constant); |
320 | const_node->t_(attr::value, value); |
321 | const_node->output()->setType(TensorType::create(value)); |
322 | |
323 | // Copy over source location and scope information to all nodes |
324 | // created by the symbolic |
325 | const_node->copyMetadata(node); |
326 | new_block->appendNode(const_node); |
327 | ONNXShapeTypeInference(const_node, empty_params_dict, opset_version); |
328 | env[old] = const_node->output(); |
329 | } else { |
330 | // An update in ConstantValueMap is also needed here, since |
331 | // the user setType can be only accessed in this step, and it |
332 | // should be reliable. |
333 | MergeInferredTypeAndSetMap( |
334 | outputs[i], old->type(), outputs[i]->type()); |
335 | // non ONNX node with no type given will throw out the warnings here. |
336 | UpdateReliable( |
337 | outputs[i], |
338 | AreInputsReliableOrStatic(outputs[i]->node()), |
339 | /*no_type_warning=*/true); |
340 | // For the node type that does not have ComputeConstant logic, it may |
341 | // have reliable shape but its shape is not in ConstantValueMap. So we |
342 | // need to update ConstantValueMap. |
343 | UpdateShapeConstantIfReliable(outputs[i]); |
344 | |
345 | // Copy over source location and scope information to all nodes |
346 | // created by the symbolic |
347 | // Do not set metadata if outputs[i] is already in env. |
348 | if (!exist_in_env) { |
349 | outputs[i]->node()->copyMetadata(node); |
350 | } |
351 | env[old] = outputs[i]; |
352 | } |
353 | } else { |
354 | // Null output means that the ONNX op doesn't have outputs corresponding |
355 | // to certain PyTorch outputs |
356 | env[old] = nullptr; |
357 | if (!old->uses().empty()) { |
358 | std::ostringstream ss; |
359 | ss << "symbolic for " << op_name << " returned None for the output " |
360 | << i; |
361 | ss << " (indicating conversion for that particular output is not supported), " ; |
362 | ss << "but the network uses this output later" ; |
363 | // TODO: Say what actually used it |
364 | throw std::runtime_error(ss.str()); |
365 | } |
366 | } |
367 | } |
368 | }; |
369 | |
370 | // Clone the node and add it to the new graph |
371 | auto cloneNode = [&](Node* node) { |
372 | auto n_ = new_block->appendNode( |
373 | new_block->owningGraph()->createClone(node, envFn)); |
374 | for (const auto i : c10::irange(node->outputs().size())) { |
375 | // n_->outputs()[i]->setType(node->outputs()[i]->type()); |
376 | env[node->output(i)] = n_->output(i); |
377 | } |
378 | }; |
379 | |
380 | // Inline the prim::PythonOp sub-block nodes and append them to the onnx graph |
381 | auto inlineAutograd = [&](Node* PythonOpNode) { |
382 | for (auto subblock : PythonOpNode->blocks()) { |
383 | for (const auto i : c10::irange(PythonOpNode->inputs().size())) { |
384 | env[subblock->inputs()[i]] = env[PythonOpNode->inputs()[i]]; |
385 | } |
386 | for (auto* node : subblock->nodes()) { |
387 | NodeToONNX(node, new_block, operator_export_type, env); |
388 | } |
389 | for (const auto i : c10::irange(PythonOpNode->outputs().size())) { |
390 | env[PythonOpNode->outputs()[i]] = env[subblock->outputs()[i]]; |
391 | } |
392 | } |
393 | }; |
394 | |
395 | // Cast output of symbolic() python implementation |
396 | auto processSymbolicOutput = [&](const std::string& op_name, |
397 | Node* n, |
398 | const py::object& raw_output) { |
399 | if (raw_output.ptr() == Py_None) { |
400 | cloneNode(n); |
401 | return; |
402 | } |
403 | // Cast the outputs back to C++ and put them in the new graph |
404 | std::vector<Value*> outputs; |
405 | try { |
406 | if (py::isinstance<Value>(raw_output)) { |
407 | outputs = value_list{py::cast<Value*>(raw_output)}; |
408 | } else { |
409 | outputs = py::cast<std::vector<Value*>>(raw_output); |
410 | } |
411 | } catch (const std::exception& ex) { |
412 | std::ostringstream ss; |
413 | ss << "Error casting results of symbolic for " << op_name |
414 | << ": expected to return list of op nodes, instead received type ''" |
415 | << py::str(raw_output.get_type()) << "': " << py::str(raw_output); |
416 | throw std::runtime_error(ss.str()); |
417 | } |
418 | |
419 | setOutputs(op_name, n, outputs); |
420 | }; |
421 | |
422 | auto callPySymbolicFunction = [&](Node* n) { |
423 | // The idea is delegate as much of the actual argument massaging to |
424 | // Python as possible |
425 | |
426 | py::tuple py_inputs(n->inputs().size()); |
427 | Py_ssize_t input_nr = 0; |
428 | for (auto* input : n->inputs()) { |
429 | py_inputs[input_nr++] = py::cast(envFn(input)); |
430 | } |
431 | |
432 | Graph* g = new_block->owningGraph(); |
433 | std::unordered_set<Node*> nodes_before; |
434 | for (auto node : g->nodes()) { |
435 | nodes_before.emplace(node); |
436 | } |
437 | |
438 | WithInsertPoint insert_point_guard(new_block); |
439 | WithCurrentScope scope_guard(*g, n->scope()); |
440 | |
441 | // IMPORTANT: NEVER pass raw pointer of smart pointer managed objects to |
442 | // Python. Check #87343 for details. |
443 | py::object raw_output = onnx.attr("_run_symbolic_function" )( |
444 | g->shared_from_this(), |
445 | new_block, |
446 | n, |
447 | py_inputs, |
448 | env, |
449 | operator_export_type); |
450 | |
451 | // Find new nodes that have been created by _run_symbolic_function and |
452 | // propagate metadata |
453 | for (auto node : g->nodes()) { |
454 | if (nodes_before.find(node) == nodes_before.end()) { |
455 | node->copyMetadata(n); |
456 | } |
457 | } |
458 | |
459 | // TODO: Assert it's an ATen identifier??? |
460 | // (Sometimes it's not...) |
461 | processSymbolicOutput(n->kind().toUnqualString(), n, raw_output); |
462 | GRAPH_DUMP("after processSymbolicOutput: " , g); |
463 | }; |
464 | |
465 | auto callPySymbolicMethod = [&](ConcretePythonOp* op) { |
466 | // Test if there is a symbolic function; bail if there is not |
467 | auto pyobj = py::handle(op->pyobj.get()); |
468 | auto func = op->autogradFunction(); |
469 | if (func) { |
470 | pyobj = func->get(); |
471 | } |
472 | |
473 | py::object opset_version = |
474 | onnx_globals.attr("GLOBALS" ).attr("export_onnx_opset_version" ); |
475 | // NOTE(justinchuby): Call the internal registry to register the symbolic |
476 | // method defined in the module. |
477 | bool is_registered_op = |
478 | onnx_registration.attr("registry" ) |
479 | .attr("is_registered_op" )("prim::PythonOp" , opset_version) |
480 | .cast<bool>(); |
481 | if (!py::hasattr(pyobj, "symbolic" ) && !is_registered_op) { |
482 | // Inline the subgraph within the prim::PythonOp unless |
483 | // either of these conditions are satisfied |
484 | // 1. The torch.autograd.Function class of this node object has `symbolic` |
485 | // method defined. |
486 | // 2. Custom export symbolic is registered for prim::PythonOp. |
487 | if (operator_export_type == ::torch::onnx::OperatorExportTypes::ONNX || |
488 | operator_export_type == |
489 | ::torch::onnx::OperatorExportTypes::ONNX_ATEN_FALLBACK) { |
490 | try { |
491 | inlineAutograd(op); |
492 | } catch (const std::exception& ex) { |
493 | TORCH_WARN( |
494 | "Unable to inline PythonOp: " , |
495 | op->name(), |
496 | " due to the following exception\n" , |
497 | ex.what(), |
498 | "prim::PythonOp will be exported as is and without being inlined\n" , |
499 | "Try exporting with the following alternatives: \n" , |
500 | "1) Set operator_export_type to ONNX_FALLTHROUGH mode\n" , |
501 | "2) Register a symbolic method for the prim::PythonOp " , |
502 | op->name()); |
503 | cloneNode(op); |
504 | } |
505 | } else { |
506 | cloneNode(op); |
507 | } |
508 | return; |
509 | } |
510 | |
511 | // Prepare args for Python. First one is the graph, and is followed |
512 | // by regular args, with Variables replaced by corresponding nodes. |
513 | Py_ssize_t input_nr = 0; |
514 | py::tuple py_symbolic_args(op->cconv.size()); |
515 | auto inputs = op->inputs(); |
516 | auto node_it = inputs.begin(); |
517 | auto scalar_it = op->scalar_args.begin(); |
518 | for (auto arg_type : op->cconv) { |
519 | py::object obj; |
520 | if (arg_type == 'c') { |
521 | TORCH_CHECK( |
522 | scalar_it != op->scalar_args.end(), |
523 | "expected too many scalar args" ); |
524 | obj = py::reinterpret_borrow<py::object>( |
525 | py::handle((scalar_it++)->get())); |
526 | } else if (arg_type == 'd') { |
527 | TORCH_CHECK(node_it != inputs.end(), "expected too many inputs" ); |
528 | obj = py::cast(envFn(*node_it++)); |
529 | } else { |
530 | throw std::runtime_error("unexpected calling convention" ); |
531 | } |
532 | py_symbolic_args[input_nr++] = obj; |
533 | } |
534 | |
535 | WithInsertPoint insert_point_guard(new_block); |
536 | WithCurrentScope scope_guard(*new_block->owningGraph(), op->scope()); |
537 | |
538 | if (py::hasattr(pyobj, "symbolic" )) { |
539 | // Call the symbolic function |
540 | // Use a little trampoline function so we can give good error messages |
541 | // upon argument mismatch |
542 | // Register as a custom operator |
543 | // TODO: Find a more elegant way to do this without having to touch |
544 | // internal Python modules. |
545 | // TODO(justinchuby): Define a namespace for these Python Ops. |
546 | onnx_registration.attr("registry" ) |
547 | .attr("register" )( |
548 | "::" + op->name(), |
549 | opset_version, |
550 | pyobj.attr("symbolic" ), |
551 | /* custom */ true); |
552 | |
553 | // IMPORTANT: NEVER pass raw pointer of smart pointer managed objects to |
554 | // Python. Check #87343 for details. |
555 | py::object raw_output = onnx.attr("_run_symbolic_method" )( |
556 | new_block->owningGraph()->shared_from_this(), |
557 | op->name(), |
558 | pyobj.attr("symbolic" ), |
559 | py_symbolic_args); |
560 | |
561 | processSymbolicOutput(op->name(), op, raw_output); |
562 | } else { |
563 | TORCH_INTERNAL_ASSERT(is_registered_op); |
564 | Node* n = static_cast<Node*>(op); |
565 | n->s_(attr::name, op->name()); |
566 | // Call symbolic function |
567 | // IMPORTANT: NEVER pass raw pointer of smart pointer managed objects to |
568 | // Python. Check #87343 for details. |
569 | py::object raw_output = onnx.attr("_run_symbolic_function" )( |
570 | new_block->owningGraph()->shared_from_this(), |
571 | new_block, |
572 | n, |
573 | py_symbolic_args, |
574 | env, |
575 | operator_export_type); |
576 | |
577 | processSymbolicOutput(op->kind().toUnqualString(), n, raw_output); |
578 | } |
579 | }; |
580 | |
581 | auto k = old_node->kind(); |
582 | if (k.is_caffe2()) { |
583 | // Pass on Caffe2 operator, since we already preprocess it |
584 | cloneNode(old_node); |
585 | } else if (k == prim::PythonOp) { |
586 | callPySymbolicMethod(static_cast<ConcretePythonOp*>(old_node)); |
587 | } else { |
588 | callPySymbolicFunction(old_node); |
589 | } |
590 | } |
591 | |
592 | } // namespace jit |
593 | } // namespace torch |
594 | |