#include <torch/csrc/jit/passes/onnx.h>

#include <ATen/core/functional.h>
#include <c10/util/Exception.h>
#include <c10/util/irange.h>
#include <torch/csrc/autograd/function.h>
#include <torch/csrc/autograd/symbolic.h>
#include <torch/csrc/jit/ir/constants.h>
#include <torch/csrc/jit/jit_log.h>
#include <torch/csrc/jit/passes/dead_code_elimination.h>
#include <torch/csrc/jit/passes/onnx/constant_map.h>
#include <torch/csrc/jit/passes/onnx/helper.h>
#include <torch/csrc/jit/passes/onnx/onnx_log.h>
#include <torch/csrc/jit/passes/onnx/shape_type_inference.h>
#include <torch/csrc/jit/python/python_ir.h>
#include <torch/csrc/utils/pybind.h>

#include <sstream>
#include <unordered_map>

namespace torch {
namespace jit {

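// Recursively remove prim::Print and aten::warn nodes from the block and all
// of its sub-blocks. A removed node's inputs are destroyed along with it only
// when they are single-use constants; anything else is kept, because removing
// it could drop side effects.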
void removePrintOps(Block* block) {
  for (auto it = block->nodes().begin(), end = block->nodes().end(); it != end;
       ++it) {
    for (auto b : it->blocks()) {
      removePrintOps(b);
    }
    if (it->kind() == prim::Print || it->kind() == aten::warn) {
      for (size_t i = 0; i < it->inputs().size();) {
        auto input = it->inputs().at(i);
        // Only remove constant inputs, because removing anything else could
        // drop side effects.
        if (input->uses().size() == 1 &&
            input->node()->kind() == prim::Constant) {
          it->removeInput(i);
          input->node()->destroy();
        } else {
          ++i;
        }
      }
      it.destroyCurrent();
    }
  }
}

void RemovePrintOps(std::shared_ptr<Graph>& graph) {
  removePrintOps(graph->block());
  GRAPH_DUMP("After RemovePrintOps: ", graph);
}

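// Check that a Caffe2 operator schema can be expressed in ONNX. Optional
// argument types are unwrapped one level (nested optionals are rejected), and
// at most one argument may be a tensor list, since preprocessCaffe2Ops below
// flattens that single list into positional tensor inputs.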
void checkONNXCompatibility(const c10::FunctionSchema& schema) {
  // In ONNX every operator input is a single tensor; there is no first-class
  // tensor-list type, so at most one input tensor list is supported.
  bool has_tensor_list = false;
  const auto& args = schema.arguments();
  for (const auto& arg : args) {
    if (arg.name() == "_caffe2_preallocated_outputs") {
      continue;
    }
    auto type = arg.type();
    if (type->kind() == TypeKind::OptionalType) {
      type = type->castRaw<OptionalType>()->getElementType();
      // Recursive optional types are not supported.
      TORCH_INTERNAL_ASSERT(type->kind() != TypeKind::OptionalType);
    }
    if (type->kind() == TypeKind::ListType) {
      const auto& elem_type = type->castRaw<ListType>()->getElementType();
      if (elem_type->isSubtypeOf(*TensorType::get())) {
        TORCH_INTERNAL_ASSERT(
            !has_tensor_list,
            "ONNX export supports at most one TensorList as input.");
        has_tensor_list = true;
      }
    }
  }
}

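// Rewrite each Caffe2 op so that only tensors remain as node inputs: scalar
// and scalar-list arguments become node attributes, and a tensor-list input
// is flattened into individual tensor inputs. As an illustrative sketch (the
// op and argument names are made up), a node like
//
//   %y = _caffe2::Foo(%x, %alpha)  // %alpha produced by a float prim::Constant
//
// is rewritten to
//
//   %y = _caffe2::Foo[alpha=<float value>](%x)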
void preprocessCaffe2Ops(Block* block) {
  for (auto it = block->nodes().begin(), end = block->nodes().end(); it != end;
       ++it) {
    for (auto b : it->blocks()) {
      preprocessCaffe2Ops(b);
    }
    if (it->kind().is_caffe2()) {
      const auto& schema = it->schema();
      checkONNXCompatibility(schema);
      std::vector<Value*> origin_inputs;
      for (Value* v : it->inputs()) {
        origin_inputs.push_back(v);
      }
      it->removeAllInputs();
      const auto& args = schema.arguments();
      size_t origin_inputs_index = 0;
      for (const auto& arg : args) {
        const auto& type = arg.type();
        TORCH_INTERNAL_ASSERT(origin_inputs_index < origin_inputs.size());
        const auto& origin_input = origin_inputs[origin_inputs_index++];
        if (type->kind() == TypeKind::OptionalType &&
            origin_input->mustBeNone()) {
          continue;
        }
        if (type->isSubtypeOf(*TensorType::get())) {
          it->addInput(origin_input);
        } else if (
            type->kind() == TypeKind::BoolType ||
            type->kind() == TypeKind::IntType) {
          const auto* constant_node = origin_input->node();
          TORCH_INTERNAL_ASSERT(constant_node->kind() == prim::Constant);
          it->i_(Symbol::attr(arg.name()), constant_node->i(attr::value));
        } else if (type->kind() == TypeKind::FloatType) {
          const auto* constant_node = origin_input->node();
          TORCH_INTERNAL_ASSERT(constant_node->kind() == prim::Constant);
          it->f_(Symbol::attr(arg.name()), constant_node->f(attr::value));
        } else if (type->kind() == TypeKind::StringType) {
          const auto* constant_node = origin_input->node();
          TORCH_INTERNAL_ASSERT(constant_node->kind() == prim::Constant);
          it->s_(Symbol::attr(arg.name()), constant_node->s(attr::value));
        } else if (type->kind() == TypeKind::ListType) {
          const auto* list_node = origin_input->node();
          const auto& elem_type = type->castRaw<ListType>()->getElementType();
          TORCH_INTERNAL_ASSERT(
              list_node->kind() == prim::ListConstruct ||
              list_node->kind() == prim::Constant);
          if (elem_type->isSubtypeOf(*TensorType::get())) {
            TORCH_INTERNAL_ASSERT(list_node->kind() == prim::ListConstruct);
            for (const auto& t : list_node->inputs()) {
              it->addInput(t);
            }
          } else if (elem_type->kind() == TypeKind::FloatType) {
            std::vector<double> values;
            if (list_node->kind() == prim::ListConstruct) {
              for (const auto* elem_input : list_node->inputs()) {
                const auto* constant_node = elem_input->node();
                TORCH_INTERNAL_ASSERT(constant_node->kind() == prim::Constant);
                values.push_back(constant_node->f(attr::value));
              }
            } else { // a constant list
              values = list_node->fs(attr::value);
            }
            it->fs_(Symbol::attr(arg.name()), values);
          } else {
            throw std::runtime_error(
                "Unhandled scalar arg: " + arg.name() +
                ", type: " + c10::typeKindToString(elem_type->kind()));
          }
        } else {
          throw std::runtime_error(
              "Unsupported input type of arg " + arg.name() +
              " in Caffe2 operator: " + c10::typeKindToString(type->kind()));
        }
      }
    }
  }
  EliminateDeadCode(
      block, true, DCESideEffectPolicy::ALLOW_DELETING_NODES_WITH_SIDE_EFFECTS);
}

void PreprocessCaffe2Ops(std::shared_ptr<Graph>& graph) {
  preprocessCaffe2Ops(graph->block());
  GRAPH_DUMP("After PreprocessCaffe2Ops: ", graph);
}

// Transform PythonOps into Nodes that match ONNX semantics.
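// A minimal usage sketch (variable names are illustrative):
//
//   std::shared_ptr<Graph> onnx_graph =
//       ToONNX(traced_graph, ::torch::onnx::OperatorExportTypes::ONNX);
//
// The conversion is built into a fresh graph; ConstantValueMap is cleared
// before and after so no state leaks between exports.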
std::shared_ptr<Graph> ToONNX(
    std::shared_ptr<Graph>& graph,
    ::torch::onnx::OperatorExportTypes operator_export_type) {
  auto constant_value_map = ConstantValueMap::getInstance();
  ConstantValueMap::ClearMaps();
  auto new_graph = std::make_shared<Graph>(graph->current_scope());
  std::unordered_map<Value*, Value*> env;
  try {
    BlockToONNX(graph->block(), new_graph->block(), operator_export_type, env);
  } catch (std::runtime_error& ex) {
    ONNX_LOG(
        "ONNX graph being constructed during exception:\n",
        new_graph->toString());
    throw;
  }
  GRAPH_DUMP("after ToONNX: ", new_graph);
  ConstantValueMap::ClearMaps();
  return new_graph;
}

// BlockToONNX.
// is_sub_block = true means old_block (the ATen graph) is a sub-block
// (e.g., the sub-block of a prim::If node), and we want to convert it into its
// parent block in the ONNX graph. In this case, we don't register
// inputs/outputs or eliminate dead code.
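// env maps values in the ATen graph to their counterparts in the ONNX graph.
// When converting a sub-block, the populated map is returned so the caller
// can resolve the sub-block's outputs; otherwise an empty map is returned.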
std::unordered_map<Value*, Value*> BlockToONNX(
    Block* old_block,
    Block* new_block,
    ::torch::onnx::OperatorExportTypes operator_export_type,
    std::unordered_map<Value*, Value*>& env,
    bool is_sub_block) {
  torch::autograd::SymbolicContext ctx{};
  ctx.block = new_block;

  GRAPH_DEBUG(
      "BlockToONNX: graph of old block: ",
      old_block->owningGraph()->toString());

  // Initialize context and environment.
  if (!is_sub_block) {
    for (auto input : old_block->inputs()) {
      auto n = ctx.block->addInput()->copyMetadata(input);
      env[input] = n;
    }
  }

  // Finally, visit all nodes in the graph.
  for (auto node : old_block->nodes()) {
    NodeToONNX(node, ctx.block, operator_export_type, env);
  }

  if (is_sub_block) {
    return env;
  }

  for (auto output : old_block->outputs()) {
    ctx.block->registerOutput(env.at(output));
  }
  // Run DCE to clean up unused functional and inplace ops.
  EliminateDeadCode(
      ctx.block,
      true,
      DCESideEffectPolicy::ALLOW_DELETING_NODES_WITH_SIDE_EFFECTS);

  return {};
}

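// A value may be folded into an onnx::Constant when it is not already produced
// by one, its concrete value is known to ConstantValueMap, and its type has
// been marked reliable.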
bool ConstantFoldCondition(torch::jit::Value* output) {
  auto fold_condition = output->node()->kind() != c10::onnx::Constant &&
      ConstantValueMap::HasValue(output->debugName());
  auto reliable_value =
      ConstantValueMap::GetTypeReliable(output->debugName()).value_or(false);
  return fold_condition && reliable_value;
}

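// Convert a single node and append the result to new_block. Caffe2 ops are
// cloned as-is (they were already preprocessed above), prim::PythonOp nodes
// are handled by callPySymbolicMethod, and everything else is dispatched to
// the Python symbolic registry via callPySymbolicFunction; see the dispatch
// at the bottom of this function.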
void NodeToONNX(
    Node* old_node,
    Block* new_block,
    ::torch::onnx::OperatorExportTypes operator_export_type,
    std::unordered_map<Value*, Value*>& env) {
  py::object onnx = py::module::import("torch.onnx");
  py::object onnx_globals = py::module::import("torch.onnx._globals");
  py::object onnx_registration =
      py::module::import("torch.onnx._internal.registration");

  // Set up all the lambda helper functions.

  // Returns the Value that n maps to in the new graph.
  auto envFn = [&env](Value* n) -> Value* {
    auto it = env.find(n);
    TORCH_CHECK(it != env.end(), "Dangling node reference");
    TORCH_CHECK(it->second, "Unused node was subsequently used");
    return it->second;
  };

  // Put the new outputs in our environment map, and copy the type from the
  // input graph if it was not set by the symbolic. This is called only
  // with the results of a symbolic call (not for nodes that are just cloned).
  auto setOutputs = [&](const std::string& op_name,
                        Node* node,
                        const value_list& outputs) {
    auto old_outputs = node->outputs();
    // Count all outputs, excluding Handles.
    auto num_old_outputs = old_outputs.size();
    if (outputs.size() != num_old_outputs) {
      std::ostringstream ss;
      ss << "symbolic for " << op_name
         << " produced an incorrect number of outputs (expected ";
      ss << num_old_outputs << ", but got " << outputs.size() << ")";
      throw std::runtime_error(ss.str());
    }
    // A Constant node does not need params_dict info, so pass an empty map.
    const ParamMap empty_params_dict = {};
    auto opset_version = py::cast<int>(
        onnx_globals.attr("GLOBALS").attr("export_onnx_opset_version"));
    for (const auto i : c10::irange(num_old_outputs)) {
      auto old = old_outputs[i];
      if (outputs[i]) {
        bool exist_in_env =
            (env.end() !=
             std::find_if(
                 env.begin(), env.end(), [&outputs, i](const auto& vt) {
                   return vt.second == outputs[i];
                 }));
        // Update the ONNX value's debug name with the ATen value's debug name
        // if one exists. Skip if the ONNX value already exists in the
        // environment: that implies the op is a no-op and the value is owned
        // by another node created elsewhere.
        if (old->hasDebugName() && !exist_in_env) {
          auto old_name = outputs[i]->debugName();
          auto new_name = old->debugNameBase();
          auto debug_names = new_block->owningGraph()->debugNames();
          auto exist_name = debug_names.find(new_name);
          outputs[i]->setDebugName(new_name);
          if (exist_name != debug_names.end()) {
            // setDebugName renames any existing value with the same name.
            // Set it again to revert that change; the new value keeps the
            // suffixed name.
            exist_name->second->setDebugName(new_name);
          }
          ConstantValueMap::UpdateValueName(old_name, outputs[i]->debugName());
        }
        // Allow symbolic() to skip specifying the type of the return node.
        // Unfortunately, they are on the hook for all internal nodes
        // (though in practice, the types are not computed.)
        //
        // If ONNX shape inference is turned on, the new outputs will have
        // their types inferred, and they will be merged with the old types.
        if (ConstantFoldCondition(outputs[i])) {
          // Create a Constant node if the node output value is in
          // ConstantValueMap.
          auto value =
              ConstantValueMap::GetValue(outputs[i]->debugName()).value();
          Node* const_node =
              new_block->owningGraph()->create(c10::onnx::Constant);
          const_node->t_(attr::value, value);
          const_node->output()->setType(TensorType::create(value));

          // Copy over source location and scope information to all nodes
          // created by the symbolic.
          const_node->copyMetadata(node);
          new_block->appendNode(const_node);
          ONNXShapeTypeInference(const_node, empty_params_dict, opset_version);
          env[old] = const_node->output();
        } else {
          // An update in ConstantValueMap is also needed here, since
          // the user setType can only be accessed in this step, and it
          // should be reliable.
          MergeInferredTypeAndSetMap(
              outputs[i], old->type(), outputs[i]->type());
          // A non-ONNX node with no type given will emit warnings here.
          UpdateReliable(
              outputs[i],
              AreInputsReliableOrStatic(outputs[i]->node()),
              /*no_type_warning=*/true);
          // For node types that have no ComputeConstant logic, the shape may
          // be reliable but missing from ConstantValueMap, so we need to
          // update ConstantValueMap here.
          UpdateShapeConstantIfReliable(outputs[i]);

          // Copy over source location and scope information to all nodes
          // created by the symbolic.
          // Do not set metadata if outputs[i] is already in env.
          if (!exist_in_env) {
            outputs[i]->node()->copyMetadata(node);
          }
          env[old] = outputs[i];
        }
      } else {
        // A null output means that the ONNX op doesn't have an output
        // corresponding to a particular PyTorch output.
        env[old] = nullptr;
        if (!old->uses().empty()) {
          std::ostringstream ss;
          ss << "symbolic for " << op_name << " returned None for the output "
             << i;
          ss << " (indicating conversion for that particular output is not supported), ";
          ss << "but the network uses this output later";
          // TODO: Say what actually used it
          throw std::runtime_error(ss.str());
        }
      }
    }
  };

  // Clone the node and add it to the new graph.
  auto cloneNode = [&](Node* node) {
    auto n_ = new_block->appendNode(
        new_block->owningGraph()->createClone(node, envFn));
    for (const auto i : c10::irange(node->outputs().size())) {
      // n_->outputs()[i]->setType(node->outputs()[i]->type());
      env[node->output(i)] = n_->output(i);
    }
  };

  // Inline the prim::PythonOp sub-block nodes and append them to the ONNX
  // graph.
  auto inlineAutograd = [&](Node* PythonOpNode) {
    for (auto subblock : PythonOpNode->blocks()) {
      for (const auto i : c10::irange(PythonOpNode->inputs().size())) {
        env[subblock->inputs()[i]] = env[PythonOpNode->inputs()[i]];
      }
      for (auto* node : subblock->nodes()) {
        NodeToONNX(node, new_block, operator_export_type, env);
      }
      for (const auto i : c10::irange(PythonOpNode->outputs().size())) {
        env[PythonOpNode->outputs()[i]] = env[subblock->outputs()[i]];
      }
    }
  };

  // Cast the output of the symbolic() Python implementation.
  auto processSymbolicOutput = [&](const std::string& op_name,
                                   Node* n,
                                   const py::object& raw_output) {
    if (raw_output.ptr() == Py_None) {
      cloneNode(n);
      return;
    }
    // Cast the outputs back to C++ and put them in the new graph.
    std::vector<Value*> outputs;
    try {
      if (py::isinstance<Value>(raw_output)) {
        outputs = value_list{py::cast<Value*>(raw_output)};
      } else {
        outputs = py::cast<std::vector<Value*>>(raw_output);
      }
    } catch (const std::exception& ex) {
      std::ostringstream ss;
      ss << "Error casting results of symbolic for " << op_name
         << ": expected to return a list of op nodes, instead received type '"
         << py::str(raw_output.get_type()) << "': " << py::str(raw_output);
      throw std::runtime_error(ss.str());
    }

    setOutputs(op_name, n, outputs);
  };

  auto callPySymbolicFunction = [&](Node* n) {
    // The idea is to delegate as much of the actual argument massaging to
    // Python as possible.

    py::tuple py_inputs(n->inputs().size());
    Py_ssize_t input_nr = 0;
    for (auto* input : n->inputs()) {
      py_inputs[input_nr++] = py::cast(envFn(input));
    }

    Graph* g = new_block->owningGraph();
    std::unordered_set<Node*> nodes_before;
    for (auto node : g->nodes()) {
      nodes_before.emplace(node);
    }

    WithInsertPoint insert_point_guard(new_block);
    WithCurrentScope scope_guard(*g, n->scope());

    // IMPORTANT: NEVER pass a raw pointer to a smart-pointer-managed object to
    // Python. Check #87343 for details.
    py::object raw_output = onnx.attr("_run_symbolic_function")(
        g->shared_from_this(),
        new_block,
        n,
        py_inputs,
        env,
        operator_export_type);

    // Find the new nodes that were created by _run_symbolic_function and
    // propagate metadata to them.
    for (auto node : g->nodes()) {
      if (nodes_before.find(node) == nodes_before.end()) {
        node->copyMetadata(n);
      }
    }

    // TODO: Assert it's an ATen identifier???
    // (Sometimes it's not...)
    processSymbolicOutput(n->kind().toUnqualString(), n, raw_output);
    GRAPH_DUMP("after processSymbolicOutput: ", g);
  };

  auto callPySymbolicMethod = [&](ConcretePythonOp* op) {
    // Test if there is a symbolic function; bail if there is not.
    auto pyobj = py::handle(op->pyobj.get());
    auto func = op->autogradFunction();
    if (func) {
      pyobj = func->get();
    }

    py::object opset_version =
        onnx_globals.attr("GLOBALS").attr("export_onnx_opset_version");
    // NOTE(justinchuby): Call the internal registry to register the symbolic
    // method defined in the module.
    bool is_registered_op =
        onnx_registration.attr("registry")
            .attr("is_registered_op")("prim::PythonOp", opset_version)
            .cast<bool>();
    if (!py::hasattr(pyobj, "symbolic") && !is_registered_op) {
      // Inline the subgraph within the prim::PythonOp unless
      // either of these conditions is satisfied:
      // 1. The torch.autograd.Function class of this node object has a
      //    `symbolic` method defined.
      // 2. A custom export symbolic is registered for prim::PythonOp.
      if (operator_export_type == ::torch::onnx::OperatorExportTypes::ONNX ||
          operator_export_type ==
              ::torch::onnx::OperatorExportTypes::ONNX_ATEN_FALLBACK) {
        try {
          inlineAutograd(op);
        } catch (const std::exception& ex) {
          TORCH_WARN(
              "Unable to inline PythonOp: ",
              op->name(),
              " due to the following exception:\n",
              ex.what(),
              "\nprim::PythonOp will be exported as is, without being inlined.\n",
              "Try exporting with the following alternatives:\n",
              "1) Set operator_export_type to ONNX_FALLTHROUGH mode\n",
              "2) Register a symbolic method for the prim::PythonOp ",
              op->name());
          cloneNode(op);
        }
      } else {
        cloneNode(op);
      }
      return;
    }

    // Prepare args for Python. The first one is the graph, and it is followed
    // by the regular args, with Variables replaced by corresponding nodes.
    Py_ssize_t input_nr = 0;
    py::tuple py_symbolic_args(op->cconv.size());
    auto inputs = op->inputs();
    auto node_it = inputs.begin();
    auto scalar_it = op->scalar_args.begin();
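    // op->cconv encodes the calling convention: 'c' marks a scalar (constant)
    // argument consumed from scalar_args, and 'd' marks a dynamic input
    // consumed from the node's inputs.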
    for (auto arg_type : op->cconv) {
      py::object obj;
      if (arg_type == 'c') {
        TORCH_CHECK(
            scalar_it != op->scalar_args.end(), "ran out of scalar args");
        obj = py::reinterpret_borrow<py::object>(
            py::handle((scalar_it++)->get()));
      } else if (arg_type == 'd') {
        TORCH_CHECK(node_it != inputs.end(), "ran out of inputs");
        obj = py::cast(envFn(*node_it++));
      } else {
        throw std::runtime_error("unexpected calling convention");
      }
      py_symbolic_args[input_nr++] = obj;
    }

    WithInsertPoint insert_point_guard(new_block);
    WithCurrentScope scope_guard(*new_block->owningGraph(), op->scope());

    if (py::hasattr(pyobj, "symbolic")) {
      // Call the symbolic function.
      // Use a little trampoline function so we can give good error messages
      // upon argument mismatch.
      // Register as a custom operator.
      // TODO: Find a more elegant way to do this without having to touch
      // internal Python modules.
      // TODO(justinchuby): Define a namespace for these Python Ops.
      onnx_registration.attr("registry")
          .attr("register")(
              "::" + op->name(),
              opset_version,
              pyobj.attr("symbolic"),
              /* custom */ true);

      // IMPORTANT: NEVER pass a raw pointer to a smart-pointer-managed object
      // to Python. Check #87343 for details.
      py::object raw_output = onnx.attr("_run_symbolic_method")(
          new_block->owningGraph()->shared_from_this(),
          op->name(),
          pyobj.attr("symbolic"),
          py_symbolic_args);

      processSymbolicOutput(op->name(), op, raw_output);
    } else {
      TORCH_INTERNAL_ASSERT(is_registered_op);
      Node* n = static_cast<Node*>(op);
      n->s_(attr::name, op->name());
      // Call the symbolic function.
      // IMPORTANT: NEVER pass a raw pointer to a smart-pointer-managed object
      // to Python. Check #87343 for details.
      py::object raw_output = onnx.attr("_run_symbolic_function")(
          new_block->owningGraph()->shared_from_this(),
          new_block,
          n,
          py_symbolic_args,
          env,
          operator_export_type);

      processSymbolicOutput(op->kind().toUnqualString(), n, raw_output);
    }
  };

  auto k = old_node->kind();
  if (k.is_caffe2()) {
    // Pass Caffe2 operators through, since we already preprocessed them.
    cloneNode(old_node);
  } else if (k == prim::PythonOp) {
    callPySymbolicMethod(static_cast<ConcretePythonOp*>(old_node));
  } else {
    callPySymbolicFunction(old_node);
  }
}

} // namespace jit
} // namespace torch