1 | #include <torch/csrc/jit/passes/quantization/finalize.h> |
2 | |
3 | #include <torch/csrc/jit/jit_log.h> |
4 | #include <torch/csrc/jit/passes/clear_profiling.h> |
5 | #include <torch/csrc/jit/passes/common_subexpression_elimination.h> |
6 | #include <torch/csrc/jit/passes/constant_pooling.h> |
7 | #include <torch/csrc/jit/passes/constant_propagation.h> |
8 | #include <torch/csrc/jit/passes/dead_code_elimination.h> |
9 | #include <torch/csrc/jit/passes/freeze_module.h> |
10 | #include <torch/csrc/jit/passes/loop_unrolling.h> |
11 | #include <torch/csrc/jit/passes/peephole.h> |
12 | #include <torch/csrc/jit/passes/prepack_folding.h> |
13 | #include <torch/csrc/jit/passes/quantization/quantization_patterns.h> |
14 | #include <torch/csrc/jit/passes/quantization/register_packed_params.h> |
15 | #include <torch/csrc/jit/runtime/graph_iterator.h> |
16 | |
17 | #include <utility> |
18 | |
19 | namespace torch { |
20 | namespace jit { |
21 | |
22 | namespace { |
23 | |
24 | void insertPrepackUnpackForLinear(std::shared_ptr<Graph>& graph) { |
25 | std::vector<QuantFusionInfo> patterns_and_replacements = |
26 | linear_prepack_unpack_patterns(); |
27 | |
28 | for (const auto& entry : patterns_and_replacements) { |
29 | SubgraphRewriter rewriter; |
30 | rewriter.RegisterRewritePattern(entry.pattern, entry.replacement); |
31 | rewriter.runOnGraph(graph, entry.filters); |
32 | } |
33 | } |
34 | |
35 | void insertPrepackUnpackForConv(std::shared_ptr<Graph>& graph) { |
36 | std::vector<QuantFusionInfo> patterns_and_replacements = |
37 | conv_prepack_unpack_patterns(); |
38 | |
39 | for (const auto& entry : patterns_and_replacements) { |
40 | SubgraphRewriter rewriter; |
41 | rewriter.RegisterRewritePattern(entry.pattern, entry.replacement); |
42 | rewriter.runOnGraph(graph, entry.filters); |
43 | } |
44 | } |
45 | |
46 | void removePackedParamInsertionAndFPWeightsSetAttr( |
47 | std::shared_ptr<Graph>& g, |
48 | const std::unordered_set<std::string>& packed_param_attr_names) { |
49 | DepthFirstGraphNodeIterator it(g); |
50 | Node* n = nullptr; |
51 | std::vector<Node*> nodes_to_delete; |
52 | while ((n = it.next()) != nullptr) { |
53 | if (n->kind() == prim::SetAttr) { |
54 | const std::string& attr_name = n->s(attr::name); |
55 | if (packed_param_attr_names.count(attr_name)) { |
56 | nodes_to_delete.push_back(n); |
57 | } else { |
58 | Value* v = n->input(0); |
59 | Value* self = g->inputs()[0]; |
60 | std::vector<std::string> paths = getModuleAccessPath(v, self); |
61 | std::string path = joinPaths(paths); |
62 | if (packed_param_attr_names.count(path)) { |
63 | nodes_to_delete.push_back(n); |
64 | } |
65 | } |
66 | } |
67 | } |
68 | for (auto node : nodes_to_delete) { |
69 | node->removeAllInputs(); |
70 | } |
71 | for (auto node : nodes_to_delete) { |
72 | node->destroy(); |
73 | } |
74 | ConstantPooling(g); |
75 | EliminateDeadCode(g); |
76 | } |
77 | |
78 | void removeObserverCallMethods(std::shared_ptr<Graph>& g) { |
79 | DepthFirstGraphNodeIterator it(g); |
80 | Node* n = nullptr; |
81 | std::vector<Node*> nodes_to_delete; |
82 | while ((n = it.next()) != nullptr) { |
83 | if (n->kind() == prim::CallMethod) { |
84 | const std::string& attr_name = n->s(attr::name); |
85 | if (attr_name == "calculate_qparams" ) { |
86 | auto observer_node = n->input(0)->node(); |
87 | if (observer_node->kind() == prim::GetAttr && |
88 | observer_node->s(attr::name).find("_observer_" ) != |
89 | std::string::npos) { |
90 | nodes_to_delete.push_back(n); |
91 | } |
92 | } |
93 | } |
94 | } |
95 | for (auto node : nodes_to_delete) { |
96 | node->removeAllInputs(); |
97 | } |
98 | for (auto node : nodes_to_delete) { |
99 | node->destroy(); |
100 | } |
101 | EliminateDeadCode(g); |
102 | } |
103 | |
104 | void keepOnlyPackedParamsGeneration(Module& m, const std::string& method_name) { |
105 | auto g = m.get_method(method_name).graph(); |
106 | Function& function = m.get_method(method_name).function(); |
107 | const auto& schema = function.getSchema(); |
108 | auto new_schema = schema.cloneWithReturns({Argument("" , NoneType::get())}); |
109 | for (size_t i = 0, output_size = g->outputs().size(); i < output_size; i++) { |
110 | g->eraseOutput(i); |
111 | } |
112 | Node* none_node = g->createNone(); |
113 | g->registerOutput(none_node->output()); |
114 | none_node->insertBefore(g->return_node()); |
115 | function.setSchema(std::move(new_schema)); |
116 | EliminateDeadCode(g); |
117 | } |
118 | |
119 | } // namespace |
120 | |
121 | void QuantFusion(std::shared_ptr<Graph>& graph, QuantType quant_type) { |
122 | std::vector<QuantFusionInfo> patterns; |
123 | if (quant_type == QuantType::DYNAMIC) { |
124 | patterns = dynamic_quant_fusion_pattern_and_replacements(); |
125 | std::vector<QuantFusionInfo> patterns_wo_dynamic_activation_quant = |
126 | dynamic_quantized_linear_pattern_and_replacements(); |
127 | patterns.insert( |
128 | patterns.end(), |
129 | patterns_wo_dynamic_activation_quant.begin(), |
130 | patterns_wo_dynamic_activation_quant.end()); |
131 | } else { |
132 | patterns = quant_fusion_pattern_and_replacements(); |
133 | } |
134 | for (const auto& info : patterns) { |
135 | SubgraphRewriter rewriter; |
136 | rewriter.RegisterRewritePattern(info.pattern, info.replacement); |
137 | rewriter.runOnGraph(graph, info.filters); |
138 | } |
139 | } |
140 | |
// Inserts prepack/unpack op pairs for both linear and conv weights in `graph`
// so the packing calls can later be fused and folded into attributes.
void InsertPrepackUnpack(std::shared_ptr<Graph>& graph) {
  insertPrepackUnpackForLinear(graph);
  insertPrepackUnpackForConv(graph);
}
145 | |
146 | void InsertPrepackUnpack(Module& module) { |
147 | for (auto& method : module.get_methods()) { |
148 | auto graph = method.graph(); |
149 | InsertPrepackUnpack(graph); |
150 | } |
151 | for (Module m : module.children()) { |
152 | InsertPrepackUnpack(m); |
153 | } |
154 | } |
155 | |
156 | void FoldQuantizedPrepackingOps(Module& module) { |
157 | auto filter_fn = [](const Node* n) -> bool { |
158 | return ( |
159 | n->kind() == Symbol::fromQualString("quantized::linear_prepack" ) || |
160 | n->kind() == Symbol::fromQualString("quantized::conv1d_prepack" ) || |
161 | n->kind() == Symbol::fromQualString("quantized::conv2d_prepack" ) || |
162 | n->kind() == Symbol::fromQualString("quantized::conv3d_prepack" ) || |
163 | n->kind() == |
164 | Symbol::fromQualString("quantized::conv_transpose1d_prepack" ) || |
165 | n->kind() == |
166 | Symbol::fromQualString("quantized::conv_transpose2d_prepack" )); |
167 | }; |
168 | PrePackingOpsFolder(module, filter_fn, "quantized" ); |
169 | } |
170 | |
171 | std::unordered_set<std::string> RegisterPrePackingParams( |
172 | Module& module, |
173 | const std::string& method_name) { |
174 | auto filter_fn = [](const Node* n) -> bool { |
175 | return ( |
176 | n->kind() == Symbol::fromQualString("quantized::linear_prepack" ) || |
177 | n->kind() == Symbol::fromQualString("quantized::conv1d_prepack" ) || |
178 | n->kind() == Symbol::fromQualString("quantized::conv2d_prepack" ) || |
179 | n->kind() == Symbol::fromQualString("quantized::conv3d_prepack" ) || |
180 | n->kind() == |
181 | Symbol::fromQualString("quantized::conv_transpose1d_prepack" ) || |
182 | n->kind() == |
183 | Symbol::fromQualString("quantized::conv_transpose2d_prepack" )); |
184 | }; |
185 | return RegisterPrePackParams(module, method_name, filter_fn, "" ); |
186 | } |
187 | |
188 | Module Finalize( |
189 | Module& module, |
190 | QuantType quant_type, |
191 | const std::vector<std::string>& preserved_attrs) { |
192 | // Tracing annotates the resulting graph with shape information. In many case, |
193 | // user applies different input shapes to traced graph. It is on the user to |
194 | // know it is correct to do so. The quantized module needs to be clean up and |
195 | // To prevent the JIT optimizations from leveraging the annotated shape info, |
196 | // clear shape information in the graph. |
197 | for (auto func : module.type()->methods()) { |
198 | ClearProfilingInformation(toGraphFunction(*func).graph()); |
199 | } |
200 | |
201 | auto graph = module.get_method("forward" ).graph(); |
202 | InsertPrepackUnpack(graph); |
203 | GRAPH_DUMP("Before QuantFusion:" , graph); |
204 | QuantFusion(graph, quant_type); |
205 | auto frozen = freeze_module(module, preserved_attrs); |
206 | FoldQuantizedPrepackingOps(frozen); |
207 | return frozen; |
208 | } |
209 | |
210 | Module FinalizeOnDevicePTQ( |
211 | Module& module, |
212 | QuantType quant_type, |
213 | const std::string& method_name) { |
214 | // Tracing annotates the resulting graph with shape information. In many case, |
215 | // user applies different input shapes to traced graph. It is on the user to |
216 | // know it is correct to do so. The quantized module needs to be clean up and |
217 | // To prevent the JIT optimizations from leveraging the annotated shape info, |
218 | // clear shape information in the graph. |
219 | for (auto func : module.type()->methods()) { |
220 | ClearProfilingInformation(toGraphFunction(*func).graph()); |
221 | } |
222 | |
223 | const std::string kQuantizeString = "quantize_" ; |
224 | const auto matched_pos = method_name.find(kQuantizeString); |
225 | const auto end_pos = matched_pos + kQuantizeString.length(); |
226 | const std::string orig_method_name = method_name.substr(end_pos); |
227 | TORCH_CHECK( |
228 | matched_pos == 0, |
229 | "Quantized ops can only be added to quantize_" , |
230 | orig_method_name, |
231 | ". Please make sure to run quant/dequant nodes insertion step for on-device PTQ." ); |
232 | |
233 | const std::string quantized_method_name = "quantized_" + orig_method_name; |
234 | auto graph = module.get_method(method_name).graph(); |
235 | // Doing some AOT optimizations here |
236 | // Of all CSE seeems to be required otherwise in some experiments |
237 | // serialized model is incorrect. As in it cannot be deserialized |
238 | // Rest are included as canonical optimizations that are not for inference |
239 | EliminateCommonSubexpression(graph); |
240 | EliminateDeadCode(graph); |
241 | PeepholeOptimize(graph); |
242 | ConstantPropagation(graph); |
243 | UnrollConstantLoops(graph); |
244 | ConstantPooling(graph); |
245 | |
246 | InsertPrepackUnpack(graph); |
247 | GRAPH_DUMP("Before QuantFusion:" , graph); |
248 | QuantFusion(graph, quant_type); |
249 | auto packed_param_attr_names = RegisterPrePackingParams(module, method_name); |
250 | GRAPH_DUMP("After QuantFusion + packed param registration:" , graph); |
251 | |
252 | // Now we have: |
253 | // 1. Inserted quantized weights packed params |
254 | // 2. Inserted packed params to module |
255 | // 3. Inserted quantized op |
256 | // The next thing we need is: |
257 | // 1. Replicate this method in quantize_forward |
258 | // 2. Remove SetAttr for fp weights that are reset by quantize_forward |
259 | // 3. Remove SetAttr node which will subsequently optimize away the nodes |
260 | // producin packed_params |
261 | // 4. Modify quantized_forward to remove all the nodes except for SetAttrs |
262 | cloneMethod(module, method_name, quantized_method_name); |
263 | // removeWeightSetAttrs(module, quantized_method_name); |
264 | auto quantized_graph = module.get_method(quantized_method_name).graph(); |
265 | removePackedParamInsertionAndFPWeightsSetAttr( |
266 | quantized_graph, packed_param_attr_names); |
267 | // Removing packed params is not sufficient since that does not do DCE |
268 | // for observer node's getatts and callmthods because callmethods have side |
269 | // effects |
270 | removeObserverCallMethods(quantized_graph); |
271 | // This step removed the return output from the graph and subsequent |
272 | // DCE removes all the ops. After that only remaining things should be |
273 | // packed_params |
274 | keepOnlyPackedParamsGeneration(module, method_name); |
275 | return module; |
276 | } |
277 | |
278 | } // namespace jit |
279 | } // namespace torch |
280 | |