1#pragma once
2
3#include <torch/csrc/jit/ir/ir.h>
4
5namespace torch {
6namespace jit {
7
8// Erase NumberType information. This is necessary for and only used in
9// exporting to ONNX. This pass ensures that no remaining Values have
10// NumberType types, replacing them with tensors.
11// The following things are done to erase NumberType info:
12// - NumberType outputs are changed to DynamicType.
13// - prim::Constant nodes which are numbers get changed into 0-dim tensors of
14// the corresponding type
15// - prim::TensorToNum, aten::Float, aten::Int and prim::NumToTensor nodes
16// are erased.
17//
18// The pass assumes that DCE will be called sometime after.
// Graph-level entry point for the pass described above. The graph handle is
// taken by const reference to the shared_ptr and nothing is returned.
// NOTE(review): presumably the graph is rewritten in place - confirm against
// the implementation file.
19TORCH_API void EraseNumberTypes(const std::shared_ptr<Graph>& graph);
// Block-level variant: takes a raw Block pointer and returns nothing.
// NOTE(review): presumably applies the same NumberType erasure to the given
// block only; whether nested blocks are visited and whether a null pointer is
// tolerated cannot be determined from this header - confirm in the
// implementation file.
20TORCH_API void EraseNumberTypesOnBlock(Block* block);
21
22} // namespace jit
23} // namespace torch
24