#include <torch/csrc/jit/passes/quantization/finalize.h>

#include <torch/csrc/jit/jit_log.h>
#include <torch/csrc/jit/passes/clear_profiling.h>
#include <torch/csrc/jit/passes/common_subexpression_elimination.h>
#include <torch/csrc/jit/passes/constant_pooling.h>
#include <torch/csrc/jit/passes/constant_propagation.h>
#include <torch/csrc/jit/passes/dead_code_elimination.h>
#include <torch/csrc/jit/passes/freeze_module.h>
#include <torch/csrc/jit/passes/loop_unrolling.h>
#include <torch/csrc/jit/passes/peephole.h>
#include <torch/csrc/jit/passes/prepack_folding.h>
#include <torch/csrc/jit/passes/quantization/quantization_patterns.h>
#include <torch/csrc/jit/passes/quantization/register_packed_params.h>
#include <torch/csrc/jit/runtime/graph_iterator.h>

#include <utility>

namespace torch {
namespace jit {

namespace {
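// Rewrites the weight of quantizable linear patterns into an explicit
// quantized::linear_prepack / quantized::linear_unpack pair (the patterns
// come from linear_prepack_unpack_patterns), so that the packed weight can
// later be folded into, or registered on, the module.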
void insertPrepackUnpackForLinear(std::shared_ptr<Graph>& graph) {
  std::vector<QuantFusionInfo> patterns_and_replacements =
      linear_prepack_unpack_patterns();

  for (const auto& entry : patterns_and_replacements) {
    SubgraphRewriter rewriter;
    rewriter.RegisterRewritePattern(entry.pattern, entry.replacement);
    rewriter.runOnGraph(graph, entry.filters);
  }
}

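// Same as above, for the conv and conv_transpose patterns returned by
// conv_prepack_unpack_patterns().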
void insertPrepackUnpackForConv(std::shared_ptr<Graph>& graph) {
  std::vector<QuantFusionInfo> patterns_and_replacements =
      conv_prepack_unpack_patterns();

  for (const auto& entry : patterns_and_replacements) {
    SubgraphRewriter rewriter;
    rewriter.RegisterRewritePattern(entry.pattern, entry.replacement);
    rewriter.runOnGraph(graph, entry.filters);
  }
}

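// Deletes SetAttr nodes that either store one of the given packed param
// attributes or whose target module path matches one of them. This stops the
// cloned quantized method from re-creating packed params (and resetting fp
// weights) on every invocation; the trailing constant pooling and DCE clean
// up the now-unused producers.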
void removePackedParamInsertionAndFPWeightsSetAttr(
    std::shared_ptr<Graph>& g,
    const std::unordered_set<std::string>& packed_param_attr_names) {
  DepthFirstGraphNodeIterator it(g);
  Node* n = nullptr;
  std::vector<Node*> nodes_to_delete;
  while ((n = it.next()) != nullptr) {
    if (n->kind() == prim::SetAttr) {
      const std::string& attr_name = n->s(attr::name);
      if (packed_param_attr_names.count(attr_name)) {
        nodes_to_delete.push_back(n);
      } else {
        // The attribute name alone did not match; check whether the full
        // module access path of the SetAttr target matches instead.
        Value* v = n->input(0);
        Value* self = g->inputs()[0];
        std::vector<std::string> paths = getModuleAccessPath(v, self);
        std::string path = joinPaths(paths);
        if (packed_param_attr_names.count(path)) {
          nodes_to_delete.push_back(n);
        }
      }
    }
  }
  // Collect first, then mutate: drop all inputs before destroying the nodes
  // so that use counts are updated correctly.
  for (auto node : nodes_to_delete) {
    node->removeAllInputs();
  }
  for (auto node : nodes_to_delete) {
    node->destroy();
  }
  ConstantPooling(g);
  EliminateDeadCode(g);
}

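// Deletes calculate_qparams CallMethods on observer submodules. CallMethod
// nodes have side effects, so plain DCE would not remove them or the
// GetAttr nodes feeding them.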
void removeObserverCallMethods(std::shared_ptr<Graph>& g) {
  DepthFirstGraphNodeIterator it(g);
  Node* n = nullptr;
  std::vector<Node*> nodes_to_delete;
  while ((n = it.next()) != nullptr) {
    if (n->kind() == prim::CallMethod) {
      const std::string& attr_name = n->s(attr::name);
      if (attr_name == "calculate_qparams") {
        auto observer_node = n->input(0)->node();
        if (observer_node->kind() == prim::GetAttr &&
            observer_node->s(attr::name).find("_observer_") !=
                std::string::npos) {
          nodes_to_delete.push_back(n);
        }
      }
    }
  }
  for (auto node : nodes_to_delete) {
    node->removeAllInputs();
  }
  for (auto node : nodes_to_delete) {
    node->destroy();
  }
  EliminateDeadCode(g);
}

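// Rewrites the method to return None (updating its schema to match) so that
// the subsequent DCE removes everything except the packed param generation
// and the SetAttrs that store the packed params.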
void keepOnlyPackedParamsGeneration(Module& m, const std::string& method_name) {
  auto g = m.get_method(method_name).graph();
  Function& function = m.get_method(method_name).function();
  const auto& schema = function.getSchema();
  auto new_schema = schema.cloneWithReturns({Argument("", NoneType::get())});
  // Always erase output 0: the output list shrinks on each erase, so erasing
  // at increasing indices would skip outputs and run past the end.
  for (size_t i = 0, output_size = g->outputs().size(); i < output_size; i++) {
    g->eraseOutput(0);
  }
  Node* none_node = g->createNone();
  g->registerOutput(none_node->output());
  none_node->insertBefore(g->return_node());
  function.setSchema(std::move(new_schema));
  EliminateDeadCode(g);
}

} // namespace

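// Applies the quant fusion patterns to the graph, replacing patterns such as
// dequantize + fp32 op + quantize with the corresponding quantized op. The
// pattern set depends on whether the quantization is dynamic or static.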
void QuantFusion(std::shared_ptr<Graph>& graph, QuantType quant_type) {
  std::vector<QuantFusionInfo> patterns;
  if (quant_type == QuantType::DYNAMIC) {
    patterns = dynamic_quant_fusion_pattern_and_replacements();
    // Also apply the dynamic linear patterns that do not dynamically quantize
    // the activation.
    std::vector<QuantFusionInfo> patterns_wo_dynamic_activation_quant =
        dynamic_quantized_linear_pattern_and_replacements();
    patterns.insert(
        patterns.end(),
        patterns_wo_dynamic_activation_quant.begin(),
        patterns_wo_dynamic_activation_quant.end());
  } else {
    patterns = quant_fusion_pattern_and_replacements();
  }
  for (const auto& info : patterns) {
    SubgraphRewriter rewriter;
    rewriter.RegisterRewritePattern(info.pattern, info.replacement);
    rewriter.runOnGraph(graph, info.filters);
  }
}

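// Inserts prepack/unpack pairs for the linear and conv patterns in a graph.
// The Module overload below applies this to every method of the module and
// recurses into its children.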
void InsertPrepackUnpack(std::shared_ptr<Graph>& graph) {
  insertPrepackUnpackForLinear(graph);
  insertPrepackUnpackForConv(graph);
}

void InsertPrepackUnpack(Module& module) {
  for (auto& method : module.get_methods()) {
    auto graph = method.graph();
    InsertPrepackUnpack(graph);
  }
  for (Module m : module.children()) {
    InsertPrepackUnpack(m);
  }
}

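// Folds the outputs of the quantized prepacking ops (linear, conv1d/2d/3d,
// conv_transpose1d/2d) into the module as attributes under the "quantized"
// prefix.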
void FoldQuantizedPrepackingOps(Module& module) {
  auto filter_fn = [](const Node* n) -> bool {
    return (
        n->kind() == Symbol::fromQualString("quantized::linear_prepack") ||
        n->kind() == Symbol::fromQualString("quantized::conv1d_prepack") ||
        n->kind() == Symbol::fromQualString("quantized::conv2d_prepack") ||
        n->kind() == Symbol::fromQualString("quantized::conv3d_prepack") ||
        n->kind() ==
            Symbol::fromQualString("quantized::conv_transpose1d_prepack") ||
        n->kind() ==
            Symbol::fromQualString("quantized::conv_transpose2d_prepack"));
  };
  PrePackingOpsFolder(module, filter_fn, "quantized");
}

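// Registers the outputs of the same set of prepacking ops as attributes on
// the module (without freezing or folding) and returns the names of the new
// packed param attributes, which FinalizeOnDevicePTQ later uses to clean up
// the cloned quantized method.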
std::unordered_set<std::string> RegisterPrePackingParams(
    Module& module,
    const std::string& method_name) {
  auto filter_fn = [](const Node* n) -> bool {
    return (
        n->kind() == Symbol::fromQualString("quantized::linear_prepack") ||
        n->kind() == Symbol::fromQualString("quantized::conv1d_prepack") ||
        n->kind() == Symbol::fromQualString("quantized::conv2d_prepack") ||
        n->kind() == Symbol::fromQualString("quantized::conv3d_prepack") ||
        n->kind() ==
            Symbol::fromQualString("quantized::conv_transpose1d_prepack") ||
        n->kind() ==
            Symbol::fromQualString("quantized::conv_transpose2d_prepack"));
  };
  return RegisterPrePackParams(module, method_name, filter_fn, "");
}

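// Finalizes a quantized module for inference: inserts prepack/unpack pairs,
// runs QuantFusion on the forward graph, freezes the module, and folds the
// prepacking ops into it.
//
// A minimal usage sketch, assuming `m` has already been through the observer
// insertion and quant/dequant insertion passes:
//
//   Module quantized =
//       Finalize(m, QuantType::STATIC, /*preserved_attrs=*/{});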
Module Finalize(
    Module& module,
    QuantType quant_type,
    const std::vector<std::string>& preserved_attrs) {
  // Tracing annotates the resulting graph with shape information. In many
  // cases users then run the traced graph with input shapes that differ from
  // the ones used at trace time, and it is on them to know that doing so is
  // valid. To prevent JIT optimizations from leveraging the annotated shape
  // info, clear the profiling/shape information from the graph.
  for (auto func : module.type()->methods()) {
    ClearProfilingInformation(toGraphFunction(*func).graph());
  }

  auto graph = module.get_method("forward").graph();
  InsertPrepackUnpack(graph);
  GRAPH_DUMP("Before QuantFusion:", graph);
  QuantFusion(graph, quant_type);
  auto frozen = freeze_module(module, preserved_attrs);
  FoldQuantizedPrepackingOps(frozen);
  return frozen;
}

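// On-device PTQ counterpart of Finalize. Expects a quantize_<method> entry
// point (e.g. quantize_forward) produced by the on-device quant/dequant
// insertion step. Instead of freezing the module, it registers the packed
// params as attributes, clones quantize_<method> into quantized_<method>
// (which runs the quantized ops against those attributes), and strips
// quantize_<method> down to packed param generation only.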
Module FinalizeOnDevicePTQ(
    Module& module,
    QuantType quant_type,
    const std::string& method_name) {
  // As in Finalize above: clear the shape information that tracing annotated
  // on the graphs, so that JIT optimizations cannot leverage it when users
  // later run the module with different input shapes.
  for (auto func : module.type()->methods()) {
    ClearProfilingInformation(toGraphFunction(*func).graph());
  }

  const std::string kQuantizeString = "quantize_";
  const auto matched_pos = method_name.find(kQuantizeString);
  // Validate the method name before taking the substring: if the prefix is
  // missing, substr on the computed end position would be out of range.
  TORCH_CHECK(
      matched_pos == 0,
      "Quantized ops can only be added to a quantize_<method>, got: ",
      method_name,
      ". Please make sure to run the quant/dequant node insertion step for "
      "on-device PTQ.");
  const auto end_pos = matched_pos + kQuantizeString.length();
  const std::string orig_method_name = method_name.substr(end_pos);

  const std::string quantized_method_name = "quantized_" + orig_method_name;
  auto graph = module.get_method(method_name).graph();
  // Apply some AOT optimizations here. Of these, CSE seems to be required:
  // without it, in some experiments, the serialized model is incorrect, in
  // that it cannot be deserialized. The rest are included as canonical
  // optimizations that are not specific to inference.
  EliminateCommonSubexpression(graph);
  EliminateDeadCode(graph);
  PeepholeOptimize(graph);
  ConstantPropagation(graph);
  UnrollConstantLoops(graph);
  ConstantPooling(graph);

  InsertPrepackUnpack(graph);
  GRAPH_DUMP("Before QuantFusion:", graph);
  QuantFusion(graph, quant_type);
  auto packed_param_attr_names = RegisterPrePackingParams(module, method_name);
  GRAPH_DUMP("After QuantFusion + packed param registration:", graph);

  // At this point we have:
  // 1. Inserted the quantized weight packed params
  // 2. Inserted the packed params into the module
  // 3. Inserted the quantized ops
  // What remains is to:
  // 1. Replicate this method (quantize_<method>) as quantized_<method>
  // 2. In the clone, remove the SetAttr nodes for the fp weights that
  //    quantize_<method> resets
  // 3. In the clone, remove the SetAttr nodes for the packed params, which
  //    subsequently lets DCE remove the nodes producing them
  // 4. Strip quantize_<method> down to the nodes that generate and store the
  //    packed params
  cloneMethod(module, method_name, quantized_method_name);
  auto quantized_graph = module.get_method(quantized_method_name).graph();
  removePackedParamInsertionAndFPWeightsSetAttr(
      quantized_graph, packed_param_attr_names);
  // Removing the packed param insertion is not sufficient: DCE will not
  // remove the observer nodes' GetAttrs and CallMethods, because CallMethods
  // have side effects.
  removeObserverCallMethods(quantized_graph);
  // This step removes the return output from the graph; the subsequent DCE
  // then removes all remaining ops, so that only the packed param generation
  // is left.
  keepOnlyPackedParamsGeneration(module, method_name);
  return module;
}

} // namespace jit
} // namespace torch