1#include <torch/csrc/jit/passes/erase_number_types.h>
2
3#include <torch/csrc/jit/ir/constants.h>
4#include <torch/csrc/jit/jit_log.h>
5#include <torch/csrc/jit/passes/dead_code_elimination.h>
6
7#include <ATen/ScalarOps.h>
8
9namespace torch {
10namespace jit {
11
12void SetNumTypeToTensorType(Value* v) {
13 if (v->type()->isSubtypeOf(*NumberType::get())) {
14 v->setType(TensorType::fromNumberType(*v->type()));
15 } else if (v->type()->isSubtypeOf(*BoolType::get())) {
16 v->setType(TensorType::fromBoolType());
17 }
18}
19
20void EraseNumberTypesOnBlock(Block* block) {
21 for (auto it = block->nodes().begin(), end = block->nodes().end(); it != end;
22 ++it) {
23 for (auto inp : it->inputs()) {
24 SetNumTypeToTensorType(inp);
25 }
26 for (auto sub : it->blocks()) {
27 EraseNumberTypesOnBlock(sub);
28 }
29 switch (it->kind()) {
30 case prim::Constant: {
31 // remove primitive constants, replacing with tensor equivalent
32 // ONNX does not support non-tensor constants
33 if (it->output()->type()->isSubtypeOf(*NumberType::get()) ||
34 it->output()->type()->isSubtypeOf(*BoolType::get())) {
35 at::Scalar s;
36 if (it->output()->type()->isSubtypeOf(*BoolType::get())) {
37 s = *constant_as<bool>(it->output());
38 } else {
39 s = *constant_as<at::Scalar>(it->output());
40 }
41
42 WithInsertPoint guard(*it);
43 Value* r = block->owningGraph()->insertConstant(
44 scalar_to_tensor(s), c10::nullopt, it->scope());
45 r->copyMetadata(it->output());
46 it->output()->replaceAllUsesWith(r);
47 it.destroyCurrent();
48 }
49 } break;
50 case aten::Bool:
51 case aten::Float:
52 case aten::Int:
53 case aten::FloatImplicit:
54 case aten::IntImplicit:
55 case aten::ScalarImplicit:
56 case prim::NumToTensor: {
57 it->output()->replaceAllUsesWith(it->inputs()[0]);
58 it.destroyCurrent();
59 } break;
60 default: {
61 for (auto o : it->outputs()) {
62 SetNumTypeToTensorType(o);
63 }
64 } break;
65 }
66 }
67}
68
69void EraseNumberTypes(const std::shared_ptr<Graph>& graph) {
70 for (auto inp : graph->inputs()) {
71 SetNumTypeToTensorType(inp);
72 }
73 EraseNumberTypesOnBlock(graph->block());
74 GRAPH_DUMP("After EraseNumberTypes: ", graph);
75}
76} // namespace jit
77} // namespace torch
78