1 | #include <torch/csrc/jit/passes/erase_number_types.h> |
---|---|
2 | |
3 | #include <torch/csrc/jit/ir/constants.h> |
4 | #include <torch/csrc/jit/jit_log.h> |
5 | #include <torch/csrc/jit/passes/dead_code_elimination.h> |
6 | |
7 | #include <ATen/ScalarOps.h> |
8 | |
9 | namespace torch { |
10 | namespace jit { |
11 | |
12 | void SetNumTypeToTensorType(Value* v) { |
13 | if (v->type()->isSubtypeOf(*NumberType::get())) { |
14 | v->setType(TensorType::fromNumberType(*v->type())); |
15 | } else if (v->type()->isSubtypeOf(*BoolType::get())) { |
16 | v->setType(TensorType::fromBoolType()); |
17 | } |
18 | } |
19 | |
20 | void EraseNumberTypesOnBlock(Block* block) { |
21 | for (auto it = block->nodes().begin(), end = block->nodes().end(); it != end; |
22 | ++it) { |
23 | for (auto inp : it->inputs()) { |
24 | SetNumTypeToTensorType(inp); |
25 | } |
26 | for (auto sub : it->blocks()) { |
27 | EraseNumberTypesOnBlock(sub); |
28 | } |
29 | switch (it->kind()) { |
30 | case prim::Constant: { |
31 | // remove primitive constants, replacing with tensor equivalent |
32 | // ONNX does not support non-tensor constants |
33 | if (it->output()->type()->isSubtypeOf(*NumberType::get()) || |
34 | it->output()->type()->isSubtypeOf(*BoolType::get())) { |
35 | at::Scalar s; |
36 | if (it->output()->type()->isSubtypeOf(*BoolType::get())) { |
37 | s = *constant_as<bool>(it->output()); |
38 | } else { |
39 | s = *constant_as<at::Scalar>(it->output()); |
40 | } |
41 | |
42 | WithInsertPoint guard(*it); |
43 | Value* r = block->owningGraph()->insertConstant( |
44 | scalar_to_tensor(s), c10::nullopt, it->scope()); |
45 | r->copyMetadata(it->output()); |
46 | it->output()->replaceAllUsesWith(r); |
47 | it.destroyCurrent(); |
48 | } |
49 | } break; |
50 | case aten::Bool: |
51 | case aten::Float: |
52 | case aten::Int: |
53 | case aten::FloatImplicit: |
54 | case aten::IntImplicit: |
55 | case aten::ScalarImplicit: |
56 | case prim::NumToTensor: { |
57 | it->output()->replaceAllUsesWith(it->inputs()[0]); |
58 | it.destroyCurrent(); |
59 | } break; |
60 | default: { |
61 | for (auto o : it->outputs()) { |
62 | SetNumTypeToTensorType(o); |
63 | } |
64 | } break; |
65 | } |
66 | } |
67 | } |
68 | |
69 | void EraseNumberTypes(const std::shared_ptr<Graph>& graph) { |
70 | for (auto inp : graph->inputs()) { |
71 | SetNumTypeToTensorType(inp); |
72 | } |
73 | EraseNumberTypesOnBlock(graph->block()); |
74 | GRAPH_DUMP("After EraseNumberTypes: ", graph); |
75 | } |
76 | } // namespace jit |
77 | } // namespace torch |
78 |