1 | #pragma once |
2 | |
3 | #include <torch/csrc/jit/ir/ir.h> |
4 | |
5 | namespace torch { |
6 | namespace jit { |
7 | |
8 | // Erase NumberType information. This is necessary for and only used in |
9 | // exporting to ONNX. This pass ensures that no remaining Values have |
10 | // NumberType types, replacing them with tensors. |
11 | // The following things are done to erase NumberType info: |
12 | // - NumberType outputs are changed to DynamicType. |
13 | // - prim::Constant nodes which are numbers get changed into 0-dim tensors of |
14 | // the corresponding type |
15 | // - prim::TensorToNum, aten::Float, aten::Int and prim::NumToTensor nodes |
16 | // are erased. |
17 | // |
18 | // The pass assumes that DCE will be called sometime after. |
19 | TORCH_API void EraseNumberTypes(const std::shared_ptr<Graph>& graph); |
20 | TORCH_API void EraseNumberTypesOnBlock(Block* block); |
21 | |
22 | } // namespace jit |
23 | } // namespace torch |
24 | |