1 | #pragma once |
2 | #include <torch/csrc/jit/api/module.h> |
3 | #include <torch/csrc/jit/ir/ir.h> |
4 | #include <torch/csrc/jit/ir/subgraph_matcher.h> |
5 | #include <torch/csrc/jit/passes/graph_rewrite_helper.h> |
6 | #include <torch/csrc/jit/passes/quantization/quantization_type.h> |
7 | |
8 | #include <functional> |
9 | #include <regex> |
10 | |
11 | namespace torch { |
12 | namespace jit { |
13 | |
14 | using graph_rewrite_helper::getFuncName; |
15 | |
16 | // Vector of a module and the name of its method |
17 | using ModuleMethodVector = std::vector<std::pair<Module, std::string>>; |
18 | // Map of quantization parameter name and value |
19 | // for example _scale, _zero_point, |
20 | // _scalar_type and _axis(for per channel quantization) |
21 | using QParamVector = std::vector<std::pair<std::string, IValue>>; |
22 | |
// =========== helper functions for Value =========
// Check if a value is weight, since we need to use weight observer
// for weight
TORCH_API bool isWeight(Value* v);

// Check if a value is bias for conv and linear, which we do not
// quantize
TORCH_API bool isBiasOfConvOrLinear(Value* v);

// Check if a value is a non-input argument of an embedding_bag op
// (presumably weight/offsets, which get special quantization handling
// — verify against the definition in helper.cpp)
TORCH_API bool isEmbeddingBagNonInput(Value* v);

// Get the use of `v` as the scalar input of a clamp op, or
// c10::nullopt if there is no such use
c10::optional<Use> getClampScalarInputUse(Value* v);
36 | |
// For a given value `v`, get the list of values that we need to check
// if they are observed/quantized or not, if so, we can say the
// `v` is also observed/quantized, since we can derive
// the quantization parameters for `v` given the list of values
TORCH_API std::vector<Value*> getPassThroughInputs(Value* v);

// Clones the method named `orig_method_name` on `module` into a new
// method named `new_method_name`
TORCH_API void cloneMethod(
    Module& module,
    const std::string& orig_method_name,
    const std::string& new_method_name);
48 | |
// Check if a value in the graph is a Scalar value
TORCH_API bool isScalar(Value* v);

// Check if value is the input of the graph
TORCH_API bool hitGraphInput(Value* value);

// Converts a mangled name, such as
// __torch__.torch.ao.nn.quantized.modules.conv.___torch_mangle_7.Conv2d
// into an unmangled name, such as
// __torch__.torch.ao.nn.quantized.modules.conv.Conv2d
TORCH_API std::string removeTorchMangle(const std::string& orig_name);

// Return the module name that corresponds to the value.
// (Returns c10::optional — presumably c10::nullopt when `value` does not
// correspond to a module; confirm against the definition in helper.cpp.)
TORCH_API c10::optional<std::string> getModuleName(Value* value);
63 | |
// =========== helper functions for Node =========
// Check if `n` is a single-input aten function whose output only depends
// on the shape of the input (see helper.cpp for the exact op list)
TORCH_API bool isSingleInputGeneralShapeAtenFunction(Node* n);

// Check if `n` is a single-input aten function that operates on the
// values of the input (see helper.cpp for the exact op list)
TORCH_API bool isSingleInputGeneralValueAtenFunction(Node* n);

// Check if `n` is a CallFunction with a single input (see helper.cpp
// for the exact function list)
TORCH_API bool isSingleInputGeneralCallFunction(Node* n);

// Check if `n` is a single-input general aten function — presumably the
// union of the shape and value checks above; confirm in helper.cpp
TORCH_API bool isSingleInputGeneralAtenFunction(Node* n);

// Check if `n` is a clamp op (exact op set defined in helper.cpp)
TORCH_API bool isClamp(Node* n);

// Check if the node will produce the same result regardless of whether
// the input tensor is quantized or not, example: aten::size
TORCH_API bool isTensorInfoNode(Node* n);

// Check if this is the propagate op that has a single input, e.g. aten::cat
TORCH_API bool isPropagateQuantSingleInputOp(Node* n);

// Check if this is the propagate op that has two inputs, e.g. aten::add
TORCH_API bool isPropagateQuantBinaryOp(Node* n);

// Check if this is the node that we'll quantize or not quantize depending on
// whether the input of the node is quantized, example: aten::cat
TORCH_API bool isPropagateQuantOp(Node* n);

// Check if the node is a binary op like aten::add and aten::mul and
// if the input 1 is a scalar, these ops will be quantized to
// quantized::{op}_scalar
TORCH_API bool isBinaryOpWithScalarInput(Node* n);

// Get the fixed quantization parameters (qscheme + qparam name/value
// pairs) for `n`, or c10::nullopt if the op has no fixed qparams
// (the set of ops with fixed qparams is defined in helper.cpp)
TORCH_API c10::optional<std::tuple<c10::QScheme, QParamVector>> getFixedQParams(
    Node* n);
96 | |
// We don't want to analyze the graph for some `builtin` CallFunctions
// like `linear` because we want to preserve the op boundary
TORCH_API bool userDefinedCallFunction(Node* n);

// Check if the node has scalar input
TORCH_API bool hasScalarInput(Node* n);

// Check if a node is quantizable for the given quantization type
// (defaults to static quantization)
TORCH_API bool nodeQuantizable(
    Node* n,
    QuantType quant_type = QuantType::STATIC);

// Nodes which only require quantization of weight value, eg. embedding_bag
bool isWeightOnlyStaticQuantOp(Node* n);

// Check if a use of the value is quantizable, this depends on
// both the use node and the offset
TORCH_API bool useQuantizable(const Use& use, QuantType quant_type);

// Given a CallFunction node, extract the graph of the called function
TORCH_API std::shared_ptr<Graph> getCallFunctionGraph(Node* n);

// Check if `use` is a CallFunction of name `func_name` and if the used
// value is the nth argument (when `nth_arg` is provided) of the function
bool matchCallFuncToUse(
    const Use& use,
    const std::string& func_name,
    c10::optional<int> nth_arg);

// Check if `use` is an aten function of name `func_name` and if the used
// value is the nth argument (when `nth_arg` is provided) of the function
bool matchAtenFuncToUse(
    const Use& use,
    const std::string& func_name,
    c10::optional<int> nth_arg);
132 | |
// =========== helper functions for Block =========
// checks if a block will always raise an Exception
TORCH_API bool alwaysRaisesException(Block* block);

// =========== helper functions for Module ==========
// Get the access path from `self` to the module `instance` as a list of
// attribute names (presumably — verify against the definition in helper.cpp)
// TODO: remove
TORCH_API std::vector<std::string> getModuleAccessPath(
    Value* instance,
    Value* self);
// Follow the attribute-name `path` starting from `module` and return
// the child module reached
// TODO: remove
TORCH_API Module
findChildModule(const Module& module, const std::vector<std::string>& path);

// Given a CallMethod node, get the module instance corresponding
// to the instance Value
// TODO: refactor all current uses of this function to the Opt one
TORCH_API Module getInvokedModule(Module& module, Node* n, Value* self);

// Given a CallMethod node, get the module instance corresponding
// to the instance Value if the instance is a module, otherwise return
// c10::nullopt
c10::optional<Module> getInvokedModuleOpt(
    const Module& module,
    Node* n,
    Value* self);
158 | |
159 | // ==================== filter functions for matches ============== |
160 | // filter to check Value `vname` is a constant of int value `value` |
161 | bool is_int_constant( |
162 | const Match& match, |
163 | const std::unordered_map<std::string, Value*>& vmap, |
164 | const std::string& vname, |
165 | int value); |
166 | |
167 | // filter to check if the %alpha argument of aten::add is constant 1 |
168 | bool aten_add_alpha_is_one( |
169 | const Match& match, |
170 | const std::unordered_map<std::string, Value*>& vmap); |
171 | |
172 | // filter to check if the functional in CallFunction is relu |
173 | bool is_functional_relu( |
174 | const Match& match, |
175 | const std::unordered_map<std::string, Value*>& vmap); |
176 | |
177 | // filter to check if the module is torch.nn.ReLU |
178 | bool is_relu_module( |
179 | const Match& match, |
180 | const std::unordered_map<std::string, Value*>& vmap); |
181 | |
182 | bool is_linear_module( |
183 | const Match& match, |
184 | const std::unordered_map<std::string, Value*>& vmap); |
185 | |
186 | // TODO: add a macro to declare the filters |
187 | bool is_conv1d_module( |
188 | const Match& match, |
189 | const std::unordered_map<std::string, Value*>& vmap); |
190 | |
191 | bool is_conv2d_module( |
192 | const Match& match, |
193 | const std::unordered_map<std::string, Value*>& vmap); |
194 | |
195 | bool is_conv3d_module( |
196 | const Match& match, |
197 | const std::unordered_map<std::string, Value*>& vmap); |
198 | |
199 | bool is_conv_transpose1d_module( |
200 | const Match& match, |
201 | const std::unordered_map<std::string, Value*>& vmap); |
202 | |
203 | bool is_conv_transpose2d_module( |
204 | const Match& match, |
205 | const std::unordered_map<std::string, Value*>& vmap); |
206 | |
207 | bool is_batchnorm2d_module( |
208 | const Match& match, |
209 | const std::unordered_map<std::string, Value*>& vmap); |
210 | |
211 | bool is_batchnorm3d_module( |
212 | const Match& match, |
213 | const std::unordered_map<std::string, Value*>& vmap); |
214 | |
215 | } // namespace jit |
216 | } // namespace torch |
217 | |