1 | #include <torch/csrc/jit/passes/quantization/insert_quant_dequant.h> |
2 | |
3 | #include <c10/core/QScheme.h> |
4 | #include <c10/util/irange.h> |
5 | #include <torch/csrc/jit/frontend/schema_matching.h> |
6 | #include <torch/csrc/jit/ir/subgraph_matcher.h> |
7 | #include <torch/csrc/jit/jit_log.h> |
8 | #include <torch/csrc/jit/passes/constant_propagation.h> |
9 | #include <torch/csrc/jit/passes/fuse_linear.h> |
10 | #include <torch/csrc/jit/passes/graph_rewrite_helper.h> |
11 | #include <torch/csrc/jit/passes/inliner.h> |
12 | #include <torch/csrc/jit/passes/quantization/helper.h> |
13 | #include <torch/csrc/jit/passes/subgraph_rewrite.h> |
14 | |
15 | #include <stack> |
16 | #include <utility> |
17 | |
18 | namespace torch { |
19 | namespace jit { |
20 | |
21 | namespace { |
22 | using graph_rewrite_helper::PatternInfo; |
23 | |
// The nodes inserted for dynamic quantization of an activation, in tuple
// order: (choose_qparams, quant, dequant).
using DynamicQuantOps = std::tuple<Node*, Node*, Node*>;
26 | |
// The string "_scalar_type". Declared const so this file-local constant
// cannot be modified after initialization; its consumers are outside this
// section of the file.
const std::string kScalarType = "_scalar_type";
28 | |
29 | struct QuantOpParams { |
30 | c10::QScheme qscheme{c10::kPerTensorAffine}; |
31 | std::vector<Value*> qparams; |
32 | // This is only so that insertQuantizationOps can be templatized |
33 | // and subsequntly significant portion of that code can be reused. |
34 | std::string back() const { |
35 | return "AttributeDoesNotExist" ; |
36 | } |
37 | }; |
38 | |
39 | c10::QScheme toAffine(c10::QScheme qscheme) { |
40 | switch (qscheme) { |
41 | case c10::kPerTensorAffine: |
42 | case c10::kPerTensorSymmetric: |
43 | return c10::kPerTensorAffine; |
44 | case c10::kPerChannelAffine: |
45 | case c10::kPerChannelSymmetric: |
46 | return c10::kPerChannelAffine; |
47 | default: |
48 | return qscheme; |
49 | } |
50 | } |
51 | |
52 | bool isPerChannel(at::QScheme qscheme) { |
53 | return qscheme == c10::kPerChannelAffine || |
54 | qscheme == c10::kPerChannelSymmetric; |
55 | } |
56 | |
57 | // Go through the CallMethod graph to check if the value is Weight. |
58 | bool isWeight(Module& module, Value* v) { |
59 | if (isWeight(v)) { |
60 | return true; |
61 | } |
62 | c10::optional<bool> result; |
63 | auto* self = v->owningGraph()->inputs()[0]; |
64 | for (const Use& u : v->uses()) { |
65 | Node* n = u.user; |
66 | if (n->kind() == prim::CallMethod) { |
67 | auto m_opt = getInvokedModuleOpt(module, n, self); |
68 | if (!m_opt.has_value()) { |
69 | return false; |
70 | } |
71 | auto m = *m_opt; |
72 | auto g = m.get_method(n->s(attr::name)).graph(); |
73 | auto call_method_result = isWeight(m, g->inputs()[u.offset]); |
74 | if (result.has_value()) { |
75 | // Check to make sure all the CallMethods in the graph produce the same |
76 | // output. |
77 | TORCH_CHECK( |
78 | call_method_result == result.value(), |
79 | "Expected all CallMethods to use either weight " |
80 | "or non-weight value." , |
81 | v->debugName()); |
82 | } else { |
83 | result = call_method_result; |
84 | } |
85 | } |
86 | } |
87 | return result.has_value() ? result.value() : false; |
88 | } |
89 | |
90 | Node* insertChooseQParams(Graph* graph, Value* original_val) { |
91 | std::string choose_qparams_func = "_choose_qparams_per_tensor" ; |
92 | // Set the reduce range to default to true, since qnnpack backend ignores this |
93 | // argument. |
94 | bool reduce_range_param = true; |
95 | auto reduce_range = graph->insertConstant(reduce_range_param); |
96 | // choose_qparams_per_tensor has 2 outputs, (scale, zero_point). |
97 | Node* choose_qparams = graph->create( |
98 | at::Symbol::aten(choose_qparams_func), |
99 | {original_val, reduce_range}, |
100 | /* num_outputs = */ 2); |
101 | choose_qparams->output(0)->setDebugName(original_val->debugName() + ".scale" ); |
102 | choose_qparams->output(0)->setType(FloatType::get()); |
103 | choose_qparams->output(1)->setDebugName( |
104 | original_val->debugName() + ".zero_point" ); |
105 | choose_qparams->output(1)->setType(IntType::get()); |
106 | graph->insertNode(choose_qparams); |
107 | return choose_qparams; |
108 | } |
109 | |
110 | Node* insertQuant( |
111 | Graph* graph, |
112 | const std::vector<Value*>& inputs, |
113 | NodeKind quant_kind, |
114 | const std::string& debugName) { |
115 | Node* quant = graph->create(quant_kind, inputs); |
116 | quant->output()->setDebugName(debugName); |
117 | graph->insertNode(quant); |
118 | return quant; |
119 | } |
120 | |
121 | Node* insertDeQuant( |
122 | Graph* graph, |
123 | Value* quantized_val, |
124 | Value* original_val, |
125 | size_t id = 0) { |
126 | Node* dequant = graph->create(Symbol::aten("dequantize" ), {quantized_val}); |
127 | dequant->output() |
128 | ->setDebugName( |
129 | original_val->debugName() + ".dequant." + c10::guts::to_string(id)) |
130 | ->setType(original_val->type()); |
131 | graph->insertNode(dequant); |
132 | return dequant; |
133 | } |
134 | |
135 | std::vector<Value*> insertDeQuantForAllUse( |
136 | Graph* graph, |
137 | Value* quantized_val, |
138 | Value* original_val) { |
139 | // copy uses to vector since value->uses() is a reference |
140 | // and changing the graph will also change the uses() list |
141 | const std::vector<Use> uses = original_val->uses(); |
142 | std::vector<Value*> outputs; |
143 | for (const auto i : c10::irange(uses.size())) { |
144 | auto* user = uses[i].user; |
145 | // Insert dequantize node right before use node, because |
146 | // we want to make sure use node and dequantize node reside |
147 | // in the same block so that quant fusion can happen |
148 | WithInsertPoint ins(user); |
149 | Node* dequant = insertDeQuant(graph, quantized_val, original_val, i); |
150 | user->replaceInput(uses[i].offset, dequant->output()); |
151 | outputs.push_back(dequant->output()); |
152 | } |
153 | return outputs; |
154 | } |
155 | |
156 | Node* insertQParam( |
157 | Graph* graph, |
158 | Value* quantized_input, |
159 | NodeKind node_kind, |
160 | const TypePtr& output_type, |
161 | const std::string& param_name) { |
162 | Node* qparam = graph->create(node_kind, {quantized_input}); |
163 | qparam->output() |
164 | ->setDebugName(quantized_input->debugName() + "." + param_name) |
165 | ->setType(output_type); |
166 | graph->insertNode(qparam); |
167 | return qparam; |
168 | } |
169 | |
170 | Node* insertScalarToTensor(Graph* graph, Value* scalar_value) { |
171 | Node* n = scalar_value->node(); |
172 | WithInsertPoint ins(n->next()); |
173 | Value* float_scalar_type = graph->insertConstant(IValue(c10::kFloat)); |
174 | Value* none = graph->insertConstant(IValue()); |
175 | Node* tensor_node = graph->create( |
176 | Symbol::aten("scalar_tensor" ), |
177 | {scalar_value, float_scalar_type, none, none, none}); |
178 | Value* tensor_output = tensor_node->output(); |
179 | tensor_output->setDebugName(scalar_value->debugName() + ".tensor" ); |
180 | graph->insertNode(tensor_node); |
181 | // replace original_output with tensor |
182 | scalar_value->replaceAllUsesAfterNodeWith(tensor_node, tensor_output); |
183 | return tensor_node; |
184 | } |
185 | |
186 | Node* insertItem(Graph* graph, Value* tensor, const TypePtr& output_type) { |
187 | WithInsertPoint ins(tensor->node()->next()); |
188 | Node* n = graph->create(Symbol::aten("item" ), {tensor}); |
189 | Value* scalar = n->output(); |
190 | scalar->setDebugName(tensor->debugName() + ".scalar" )->setType(output_type); |
191 | graph->insertNode(n); |
192 | return n; |
193 | } |
194 | |
195 | DynamicQuantOps insertChooseQParamQuantDequant( |
196 | Graph* graph, |
197 | Value* original_val, |
198 | Value* dtype, |
199 | NodeKind quant_kind) { |
200 | Node* choose_qparams = insertChooseQParams(graph, original_val); |
201 | std::vector<Value*> quant_inputs = {original_val}; |
202 | for (auto& out : choose_qparams->outputs()) { |
203 | quant_inputs.push_back(out); |
204 | } |
205 | quant_inputs.push_back(dtype); |
206 | Node* quant = insertQuant( |
207 | graph, quant_inputs, quant_kind, original_val->debugName() + ".quant" ); |
208 | Node* dequant = insertDeQuant(graph, quant->output(), original_val); |
209 | return std::make_tuple(choose_qparams, quant, dequant); |
210 | } |
211 | |
212 | Node* insertFP16CastOps(Graph* graph, Value* observer_out) { |
213 | // If the weight value is outside of the range for FP16 range, i.e. [5.96e-8, |
214 | // 65504], we saturate the values to the min/max of this range. |
215 | Node* saturated_weight = |
216 | graph->create(Symbol::aten("_saturate_weight_to_fp16" ), {observer_out}); |
217 | graph->insertNode(saturated_weight); |
218 | graph->lint(); |
219 | |
220 | return saturated_weight; |
221 | } |
222 | |
223 | // find the observer for Value `v` and return the name of the observer |
224 | c10::optional<std::string> findObserverName(Value* v) { |
225 | // Note that here we just check for the name of observer, but the ideally |
226 | // we should be comparing the type of observer, this is a temporary |
227 | // work around until data only clone of module.clone is supported. |
228 | Node* n = v->node(); |
229 | if (n->kind() == prim::CallMethod && n->s(attr::name) == "forward" ) { |
230 | auto module_instance = n->inputs().at(0); |
231 | if (module_instance->node()->kind() == prim::GetAttr && |
232 | module_instance->node()->s(attr::name).find("_observer_" ) != |
233 | std::string::npos) { |
234 | return module_instance->node()->s(attr::name); |
235 | } |
236 | } |
237 | return c10::nullopt; |
238 | } |
239 | |
240 | bool isPlaceholderObserver(Value* observer) { |
241 | if (getModuleName(observer).has_value()) { |
242 | auto name = getModuleName(observer).value(); |
243 | // if PlaceholderObserver is (anywhere) in name |
244 | if (name.find("PlaceholderObserver" ) != std::string::npos) { |
245 | return true; |
246 | } |
247 | } |
248 | return false; |
249 | } |
250 | |
251 | at::ScalarType getObserverDtype(Module& module, Value* v) { |
252 | auto observer_name = findObserverName(v); |
253 | if (observer_name.has_value()) { |
254 | auto observer_module = module.attr(observer_name.value()).toModule(); |
255 | at::ScalarType scalar_type = observer_module.attr("dtype" ).toScalarType(); |
256 | return scalar_type; |
257 | } |
258 | return at::ScalarType::Undefined; |
259 | } |
260 | |
261 | c10::optional<std::string> getEmbeddingBagObsName( |
262 | script::Module& module, |
263 | Node* n) { |
264 | Value* v = n->output(); |
265 | auto observer = n->input(0); |
266 | auto observer_module = module.attr(findObserverName(v).value()).toModule(); |
267 | if (observer_module.hasattr("custom_op" )) { |
268 | auto op_name = observer_module.attr("custom_op" ).toStringRef(); |
269 | return isPlaceholderObserver(observer) ? std::move(op_name) : "" ; |
270 | } |
271 | return c10::nullopt; |
272 | } |
273 | |
274 | bool isEmbeddingBagOp( |
275 | Node* observer, |
276 | c10::optional<std::string> embedding_bag_name) { |
277 | return embedding_bag_name && |
278 | embedding_bag_name.value().find("embedding_bag_" ) != std::string::npos; |
279 | } |
280 | |
281 | template <typename T> |
282 | Node* insertQuantDequantNodes( |
283 | Value* self, |
284 | Node* observer, |
285 | T& qparams, |
286 | const std::string& quantize_func); |
287 | |
288 | // Insert quant and dequant nodes into the graph for both static and dynamic |
289 | // quant. |
290 | template <> |
291 | Node* insertQuantDequantNodes<std::vector<std::string>>( |
292 | Value* self, |
293 | Node* observer, |
294 | std::vector<std::string>& qparam_names, |
295 | const std::string& quantize_func) { |
296 | Graph* g = observer->owningGraph(); |
297 | Value* observer_out = observer->output(); |
298 | Value* original_val = observer->input(1); |
299 | std::vector<Value*> inputs = {observer_out}; |
300 | // Insert GetAttr nodes for quantization parameters |
301 | for (const auto& qparam_name : qparam_names) { |
302 | inputs.push_back(g->insertGetAttr(self, qparam_name)); |
303 | } |
304 | Node* quant = insertQuant( |
305 | g, |
306 | inputs, |
307 | at::Symbol::aten(quantize_func), |
308 | original_val->debugName() + ".quant" ); |
309 | Node* dequant = insertDeQuant(g, quant->output(), original_val); |
310 | return dequant; |
311 | } |
312 | |
313 | Node* insertEmbeddingBagOps(Node* observer, const std::string& op_name) { |
314 | Graph* g = observer->owningGraph(); |
315 | auto observer_out = observer->output(); |
316 | |
317 | std::string prepack_fn, quant_fn; |
318 | std::vector<Value*> prepack_inputs = {observer_out}; |
319 | if (op_name == "embedding_bag_4bit" ) { |
320 | bool optimized_qparams = false; |
321 | constexpr int NBINS = 200; |
322 | constexpr float RATIO = 0.16; |
323 | Value* optimized_qparams_false = g->insertConstant(optimized_qparams); |
324 | Value* nbins_200 = g->insertConstant(NBINS); |
325 | Value* ratio_0_16 = g->insertConstant(RATIO); |
326 | prepack_fn = "quantized::embedding_bag_4bit_prepack" ; |
327 | quant_fn = "quantized::embedding_bag_4bit_rowwise_offsets" ; |
328 | prepack_inputs.push_back(optimized_qparams_false); |
329 | prepack_inputs.push_back(nbins_200); |
330 | prepack_inputs.push_back(ratio_0_16); |
331 | } else if (op_name == "embedding_bag_byte" ) { |
332 | prepack_fn = "quantized::embedding_bag_byte_prepack" ; |
333 | quant_fn = "quantized::embedding_bag_byte_rowwise_offsets" ; |
334 | } else { |
335 | TORCH_INTERNAL_ASSERT( |
336 | false, |
337 | "Graph Mode Quantization currently supports 4-bit and 8-bit embedding bag quantization." ); |
338 | } |
339 | |
340 | std::vector<Use> uses = observer_out->uses(); |
341 | // NOLINTNEXTLINE(cppcoreguidelines-init-variables) |
342 | Node* embedding_bag_float_op; |
343 | // We expect that the output of the weight observer will be consumed by the |
344 | // embedding_bag operator. |
345 | for (const Use& use : uses) { |
346 | if (matchCallFuncToUse(use, "embedding_bag" , 2) || |
347 | matchAtenFuncToUse(use, "embedding_bag" , 0)) { |
348 | embedding_bag_float_op = use.user; |
349 | } |
350 | } |
351 | |
352 | // Insert prepack op |
353 | Node* prepack = g->create(Symbol::fromQualString(prepack_fn), prepack_inputs); |
354 | g->insertNode(prepack); |
355 | |
356 | std::vector<Value*> embedding_bag_inputs = |
357 | // NOLINTNEXTLINE(clang-analyzer-core.CallAndMessage) |
358 | embedding_bag_float_op->inputs().vec(); |
359 | std::vector<Value*> qembedding_bag_inputs = {prepack->output()}; |
360 | const auto inputs_size = embedding_bag_float_op->inputs().size(); |
361 | const bool is_aten_op = |
362 | embedding_bag_float_op->kind() == Symbol::aten("embedding_bag" ); |
363 | // Create and insert quantized embedding op. |
364 | Value* none = g->insertConstant(IValue()); |
365 | Value* zero = g->insertConstant(IValue(0)); |
366 | bool pruned_wt = false; |
367 | auto pruned_const = g->insertConstant(pruned_wt); |
368 | |
369 | if (is_aten_op) { |
370 | TORCH_CHECK( |
371 | inputs_size == 9, |
372 | "Expecting FP aten::embedding_bag operator to have 9 inputs" ); |
373 | // input 0 is the output of prepack op. |
374 | // Last input is added after we account for extra input in 4-bit case. |
375 | for (unsigned long i = 1; i < inputs_size - 2; ++i) { |
376 | qembedding_bag_inputs.push_back(embedding_bag_inputs[i]); |
377 | } |
378 | // The sparse field in the float operator denotes sparse gradients. |
379 | // For inference this stands for pruned weights. We currently don't support |
380 | // pruning in graph mode API so we set the field to 0 for inference. |
381 | qembedding_bag_inputs[5] = pruned_const; |
382 | } else { |
383 | TORCH_CHECK( |
384 | inputs_size == 12, |
385 | "Expecting F.embedding_bag operator to have 12 inputs" ); |
386 | qembedding_bag_inputs.push_back(embedding_bag_inputs[1]); // indices |
387 | qembedding_bag_inputs.push_back(embedding_bag_inputs[3]); // offsets |
388 | qembedding_bag_inputs.push_back( |
389 | embedding_bag_inputs[6]); // scale_grad_by_freq |
390 | qembedding_bag_inputs.push_back(zero); // mode |
391 | qembedding_bag_inputs.push_back(pruned_const); // pruned_weights |
392 | qembedding_bag_inputs.push_back( |
393 | embedding_bag_inputs[9]); // per_sample_weights |
394 | } |
395 | |
396 | qembedding_bag_inputs.push_back(none); // compressed_indices_mapping |
397 | qembedding_bag_inputs.push_back(embedding_bag_inputs[inputs_size - 2]); |
398 | |
399 | TORCH_CHECK( |
400 | embedding_bag_inputs[inputs_size - 1]->mustBeNone(), |
401 | "Expected aten::embedding_bag padding_idx input to be None" ); |
402 | |
403 | Node* qembedding_bag = |
404 | g->create(Symbol::fromQualString(quant_fn), qembedding_bag_inputs); |
405 | if (is_aten_op) { |
406 | WithInsertPoint ins(embedding_bag_float_op); |
407 | g->insertNode(qembedding_bag); |
408 | // Verify that the outputs (apart from index 0) have no uses in the graph. |
409 | for (const auto i : |
410 | c10::irange(1, embedding_bag_float_op->outputs().size())) { |
411 | TORCH_CHECK( |
412 | !embedding_bag_float_op->output(i)->hasUses(), |
413 | "Expected aten::embedding_bag to only have use for its first output." ); |
414 | } |
415 | } else { |
416 | g->insertNode(qembedding_bag); |
417 | } |
418 | embedding_bag_float_op->output(0)->replaceAllUsesWith( |
419 | qembedding_bag->output()); |
420 | embedding_bag_float_op->removeAllInputs(); |
421 | embedding_bag_float_op->destroy(); |
422 | g->lint(); |
423 | return qembedding_bag; |
424 | } |
425 | |
426 | template <typename T> |
427 | void insertQuantizationOps( |
428 | Module& module, |
429 | Value* self, |
430 | Node* observer, |
431 | bool is_per_channel, |
432 | T& qparams, |
433 | QuantType quant_type = QuantType::STATIC) { |
434 | Graph* g = observer->owningGraph(); |
435 | // Observer output |
436 | Value* observer_out = observer->output(); |
437 | // Inserting before insert point |
438 | WithInsertPoint ins(observer_out->node()->next()); |
439 | |
440 | std::string quantize_func; |
441 | if (is_per_channel) { |
442 | quantize_func = "quantize_per_channel" ; |
443 | } else { |
444 | quantize_func = "quantize_per_tensor" ; |
445 | } |
446 | Value* original_val = observer->input(1); |
447 | // NOLINTNEXTLINE(cppcoreguidelines-init-variables) |
448 | Node *quant, *choose_qparams, *dequant; |
449 | // Temporary solution to quantize embedding_bag operators. Will be re-written |
450 | // once we support quantization of embedding_bag weights. |
451 | auto embedding_bag_name = getEmbeddingBagObsName(module, observer); |
452 | if (isEmbeddingBagOp(observer, embedding_bag_name)) { |
453 | if (isWeight(module, observer_out)) { |
454 | auto op_name = embedding_bag_name.value(); |
455 | Node* dequant = insertEmbeddingBagOps(observer, op_name); |
456 | observer_out->replaceAllUsesWith(original_val); |
457 | original_val->replaceAllUsesAfterNodeWith(dequant, dequant->output()); |
458 | } else { |
459 | // Special case for embedding bag operators indices input - we don't |
460 | // quantize the input but we still need to insert observers for it because |
461 | // the order of input and weight can be changed in the module code. |
462 | observer_out->replaceAllUsesWith(original_val); |
463 | } |
464 | return; |
465 | } |
466 | if (quant_type == QuantType::DYNAMIC) { |
467 | if (getObserverDtype(module, observer_out) == at::ScalarType::Half) { |
468 | dequant = insertFP16CastOps(g, observer_out); |
469 | } else if (!isWeight(module, observer_out)) { |
470 | auto observer_dtype = getObserverDtype(module, observer_out); |
471 | if (observer_dtype == at::ScalarType::QUInt8 || |
472 | observer_dtype == at::ScalarType::QInt8) { |
473 | // For activation tensors we insert choose_qparams, quant, dequant ops. |
474 | Value* dtype = g->insertGetAttr(self, qparams.back()); |
475 | std::tie(choose_qparams, quant, dequant) = |
476 | insertChooseQParamQuantDequant( |
477 | g, observer_out, dtype, at::Symbol::aten(quantize_func)); |
478 | } else { |
479 | // dtype does not require quantization, e.g. float32 |
480 | // will just remove the observer call |
481 | observer_out->replaceAllUsesWith(original_val); |
482 | return; |
483 | } |
484 | } else { |
485 | // For weight tensors we insert quant-dequant ops. |
486 | dequant = insertQuantDequantNodes(self, observer, qparams, quantize_func); |
487 | } |
488 | } else { // Static quant |
489 | dequant = insertQuantDequantNodes(self, observer, qparams, quantize_func); |
490 | } |
491 | observer_out->replaceAllUsesWith(original_val); |
492 | |
493 | original_val->replaceAllUsesAfterNodeWith(dequant, dequant->output()); |
494 | GRAPH_DUMP("insert nodes:" , original_val->owningGraph()); |
495 | } |
496 | |
497 | void ReplicateChooseQParamsQuantDequant(std::shared_ptr<Graph>& graph) { |
498 | const PatternInfo& dynamic_quant_pattern = PatternInfo::parse_from_str(R"( |
499 | graph(%a, %reduce_range, %a_dtype): |
500 | %a_scale : float, %a_zero_point : int = aten::_choose_qparams_per_tensor(%a, %reduce_range) |
501 | %a_quant = aten::quantize_per_tensor(%a, %a_scale, %a_zero_point, %a_dtype) |
502 | %a_dequant = aten::dequantize(%a_quant) |
503 | return (%a_dequant) )" ); |
504 | const Graph& dynamic_quant_graph = *dynamic_quant_pattern.pattern_graph; |
505 | |
506 | const auto& matches = findPatternMatches(dynamic_quant_graph, *graph); |
507 | if (matches.empty()) { |
508 | return; |
509 | } |
510 | |
511 | const auto& vmap = dynamic_quant_pattern.vmap; |
512 | Value* dequant_val = vmap.at("a_dequant" ); |
513 | Node* pattern_dequant = dequant_val->node(); |
514 | Value* quant_val = vmap.at("a_quant" ); |
515 | Node* pattern_quant = quant_val->node(); |
516 | Value* choose_qparam_val = vmap.at("a_scale" ); |
517 | Node* pattern_choose_qparam = choose_qparam_val->node(); |
518 | |
519 | std::vector<DynamicQuantOps> nodes_to_rewrite; |
520 | std::vector<Node*> choose_qparam_nodes_to_rewrite; |
521 | for (const Match& match : matches) { |
522 | Node* matched_dequantize = match.nodes_map.at(pattern_dequant); |
523 | Node* matched_quantize = match.nodes_map.at(pattern_quant); |
524 | Node* matched_choose_qparam = match.nodes_map.at(pattern_choose_qparam); |
525 | if (matched_dequantize->output()->uses().size() > 1) { |
526 | nodes_to_rewrite.emplace_back( |
527 | matched_choose_qparam, matched_quantize, matched_dequantize); |
528 | } |
529 | } |
530 | for (const auto& nodes : nodes_to_rewrite) { |
531 | auto quant_node = std::get<1>(nodes); |
532 | auto dequant_node = std::get<2>(nodes); |
533 | // get input of quantize call. |
534 | Value* original_val = quant_node->inputs()[0]; |
535 | Value* dequant_out = dequant_node->output(); |
536 | Value* dtype = quant_node->inputs()[3]; |
537 | std::vector<Use> uses = dequant_out->uses(); |
538 | for (const Use& use : uses) { |
539 | auto* user = use.user; |
540 | WithInsertPoint ins(user); |
541 | auto quant_ops = insertChooseQParamQuantDequant( |
542 | graph.get(), original_val, dtype, quant_node->kind()); |
543 | user->replaceInputWith(dequant_out, std::get<2>(quant_ops)->output()); |
544 | } |
545 | } |
546 | // NOLINTNEXTLINE(cppcoreguidelines-init-variables) |
547 | Node *choose_qparams, *quant, *dequant; |
548 | for (const auto& n : nodes_to_rewrite) { |
549 | std::tie(choose_qparams, quant, dequant) = n; |
550 | dequant->removeAllInputs(); |
551 | quant->removeAllInputs(); |
552 | choose_qparams->removeAllInputs(); |
553 | } |
554 | for (const auto& n : nodes_to_rewrite) { |
555 | std::tie(choose_qparams, quant, dequant) = n; |
556 | dequant->destroy(); |
557 | quant->destroy(); |
558 | choose_qparams->destroy(); |
559 | } |
560 | } |
561 | |
562 | void RemoveRedundantDequantize(std::shared_ptr<Graph>& graph) { |
563 | const std::string dequantize = R"( |
564 | graph(%a_quant): |
565 | %a_dequant = aten::dequantize(%a_quant) |
566 | return (%a_dequant) )" ; |
567 | const std::string dequantize_replacement = R"( |
568 | graph(%a): |
569 | return (%a) )" ; |
570 | auto filter = [&](const Match& match, |
571 | const std::unordered_map<std::string, Value*>& vmap) { |
572 | const auto& match_vmap = match.values_map; |
573 | auto dequant_node = match_vmap.at(vmap.at("a_dequant" ))->node(); |
574 | Value* dequant_out = dequant_node->output(); |
575 | // Values can be used multiple times in a single node |
576 | if (dequant_out->uses().size() != 1) { |
577 | return false; |
578 | } |
579 | Node* user = dequant_out->uses()[0].user; |
580 | return isTensorInfoNode(user); |
581 | }; |
582 | SubgraphRewriter rewriter; |
583 | rewriter.RegisterRewritePattern(dequantize, dequantize_replacement); |
584 | rewriter.runOnGraph(graph, filter); |
585 | } |
586 | |
587 | void RemoveRedundantQuantizationOps(std::shared_ptr<Graph>& graph) { |
588 | const std::string dynamic_quant_ops = R"( |
589 | graph(%a, %reduce_range, %a_dtype): |
590 | %a_scale : float, %a_zero_point : int = aten::_choose_qparams_per_tensor(%a, %reduce_range) |
591 | %a_quant = aten::quantize_per_tensor(%a, %a_scale, %a_zero_point, %a_dtype) |
592 | %a_dequant = aten::dequantize(%a_quant) |
593 | return (%a_dequant) )" ; |
594 | const std::string dynamic_quant_replacement = R"( |
595 | graph(%a, %reduce_range, %a_dtype): |
596 | return (%a) )" ; |
597 | auto filter = [&](const Match& match, |
598 | const std::unordered_map<std::string, Value*>& vmap) { |
599 | const auto& match_vmap = match.values_map; |
600 | auto dequant_node = match_vmap.at(vmap.at("a_dequant" ))->node(); |
601 | Value* dequant_out = dequant_node->output(); |
602 | // Values can be used multiple times in a single node |
603 | if (dequant_out->uses().size() != 1) { |
604 | return false; |
605 | } |
606 | Node* user = dequant_out->uses()[0].user; |
607 | return !nodeQuantizable(user, QuantType::DYNAMIC); |
608 | }; |
609 | SubgraphRewriter rewriter; |
610 | rewriter.RegisterRewritePattern(dynamic_quant_ops, dynamic_quant_replacement); |
611 | rewriter.runOnGraph(graph, filter); |
612 | } |
613 | |
614 | void ReplicateClampScalarArgs(std::shared_ptr<Graph>& graph) { |
615 | std::stack<Block*> blocks_to_visit; |
616 | std::unordered_set<Node*> scalar_nodes_to_rewrite; |
617 | ; |
618 | blocks_to_visit.push(graph->block()); |
619 | while (!blocks_to_visit.empty()) { |
620 | Block* b = blocks_to_visit.top(); |
621 | blocks_to_visit.pop(); |
622 | for (Node* n : b->nodes()) { |
623 | for (Value* output : n->outputs()) { |
624 | if (getClampScalarInputUse(output) && output->uses().size() > 1) { |
625 | scalar_nodes_to_rewrite.insert(n); |
626 | } |
627 | } |
628 | for (Block* subblock : n->blocks()) { |
629 | blocks_to_visit.push(subblock); |
630 | } |
631 | } |
632 | } |
633 | |
634 | for (Node* n : scalar_nodes_to_rewrite) { |
635 | const std::vector<Use> uses = n->output()->uses(); |
636 | for (const auto& use : uses) { |
637 | Node* user = use.user; |
638 | WithInsertPoint ins(user); |
639 | Node* cloned_node = graph->createClone(n, [](Value* v) { return v; }); |
640 | graph->insertNode(cloned_node); |
641 | user->replaceInput(use.offset, cloned_node->output()); |
642 | } |
643 | } |
644 | |
645 | for (Node* n : scalar_nodes_to_rewrite) { |
646 | n->removeAllInputs(); |
647 | } |
648 | |
649 | for (Node* n : scalar_nodes_to_rewrite) { |
650 | n->destroy(); |
651 | } |
652 | } |
653 | |
654 | void checkCalculateQParamsResult(const IValue& qparams) { |
655 | TORCH_CHECK( |
656 | qparams.isTuple(), |
657 | "`calculate_qparams` function is expected to return a " |
658 | "Tuple, but got:" , |
659 | qparams.tagKind()); |
660 | auto tp = qparams.toTuple(); |
661 | TORCH_CHECK( |
662 | tp->elements().size() == 2, |
663 | "`calculate_qparams` function is expected to return a " |
664 | "Tuple of size 2, got Tuple of size " , |
665 | tp->elements().size()); |
666 | // Expect first two elements of the tuple to be Tensor |
667 | for (const auto i : c10::irange(2)) { |
668 | TORCH_CHECK( |
669 | tp->elements()[i].isTensor(), |
670 | "Element of Tuple is expected to be Tensor, but element " , |
671 | i, |
672 | " has type: " , |
673 | tp->elements()[i].tagKind()); |
674 | } |
675 | } |
676 | |
// Helper that copies a set of nodes into a separate graph. Only the
// declarations are visible here; the member definitions live elsewhere in
// this file.
class SubGraphCloneHelper {
 public:
  // Given a list of nodes, build a graph corresponding to these nodes and
  // return it wrapped in a GraphFunction called `name`.
  // User should make sure to run this graph with expected input.
  std::unique_ptr<GraphFunction> buildGraphFromNodes(
      const std::vector<Node*>& nodes,
      const std::string& name);

  // Given a list of nodes in src, produce a Graph with these nodes.
  // `dest` is the graph that receives them.
  void buildObserverSubgraph(
      const std::vector<Node*>& src,
      std::shared_ptr<Graph> dest);

 private:
  // Clone node in the destination Graph g.
  // NOTE(review): `remap_values` presumably maps source values to their
  // clones in `g` - confirm against the definition.
  void cloneNodeInGraph(
      Node* node,
      std::shared_ptr<Graph>& g,
      std::unordered_map<Value*, Value*>& remap_values);
};
697 | |
698 | class InsertQuantDeQuantHelper { |
699 | public: |
700 | InsertQuantDeQuantHelper(QuantType quant_type, bool debug) |
701 | : quant_type_(quant_type), debug_(debug) {} |
702 | |
703 | void run(Module& module, const std::string& method_name); |
704 | |
705 | void runForOnDevicePTQ(Module& module, const std::string& method_name); |
706 | |
707 | // Cleanup observer nodes from graph and observer modules |
708 | // from module object and ClassType |
709 | void cleanup(Module& module); |
710 | |
711 | // Cleanup observer nodes only but not modules |
712 | // This is for ondevice PTQ |
713 | void removeObserverNodes(Module& m); |
714 | |
715 | // In order to propagate quantization ops through the ops that doesn't |
716 | // require observation, we'll first inline the graph, and call the |
717 | // PropgateQuantizationOps pass |
718 | void propagateQuantizationOps(Module& module); |
719 | |
720 | // Used for dynamic quantization to selectively run the weight observers. |
721 | // It extracts the subgraph corresponding to the weight and runs it with |
722 | // the module instance. |
723 | void runWeightObserver(Module& module, const std::string& method_name); |
724 | |
725 | private: |
726 | ModuleMethodVector getInvokedMethods( |
727 | Module& module, |
728 | const std::string& method_name); |
729 | |
730 | // Get quantization parameter map of the given Value in Graph |
731 | // by searching for observer module of the value and extract the |
732 | // quantization parameters from the observer module |
733 | std::tuple<c10::QScheme, QParamVector> getQSchemeAndQParamVector( |
734 | script::Module& module, |
735 | Node* n); |
736 | QuantOpParams insertCalculateQParams( |
737 | script::Module& module, |
738 | Graph* g, |
739 | Node* n); |
740 | |
741 | void checkQScheme(Graph* g, c10::QScheme qscheme) { |
742 | if (qscheme_for_graph_.count(g)) { |
743 | // FIXME[T110786721]: This check was broken before nevery failing. |
744 | // Once fixed, this check triggers and fails tests. |
745 | // Fix the tests that enabling this check produce! |
746 | /* |
747 | TORCH_CHECK( |
748 | qscheme_for_graph_.at(g) == qscheme, |
749 | "Quantizing same graph with different types of " |
750 | "QSchemes is not supported.\n", |
751 | " Expecting:", |
752 | c10::toString(qscheme_for_graph_.at(g)), |
753 | " Got:", |
754 | c10::toString(qscheme)); |
755 | */ |
756 | } else { |
757 | qscheme_for_graph_[g] = toAffine(qscheme); |
758 | } |
759 | } |
760 | |
// If the given value is produced by an observer call, record the observer
// node for later quantization and schedule the observer call node, its
// GetAttr node and the observer module attribute for removal.
void collectObserverNodesAndValueToQuantize(Module& module, Value*);
// Remove the observer nodes from graph `g` and the observer module
// attributes from `module`.
void cleanup(Module& module, Graph* g);
// Destroy the nodes recorded in nodes_to_destroy_ for graph `g`.
void removeObserverNodes(Graph* g);
// For every observer node recorded for `g`, register the quantization
// parameters as attributes of `module` and insert the quantization ops.
void quantizeTensors(Module& module, Graph* g, Value* self);
void insertCalculateQParamsAndQuantizationOps(
    Module& module,
    Graph* g,
    Value* self);

// Function that extracts and runs the weight observer in a separate
// subgraph.
void extractAndRunWeightObserver(
    Module& module,
    Value* self,
    Value* weight_value);

// Recursively find the nodes that produce the value and add to subgraph.
void findSubgraph(Value* self, Value* v, std::vector<Node*>& weight_subgraph);

// Quantizes two types of general ops (ops that work both for floating point
// and quantized Tensors) in this pass:
// - for ops that only manipulate shape, e.g. flatten, quantization
//   is done by swapping with the previous dequantize op
// - for ops that manipulate values of the Tensor, e.g. average pool,
//   quantization is done by inserting quant/dequant ops after the op
// Also has a special handling of clamp/hardtanh.
void propagateQuantizationOps(Block* block);

// Propagate quantization parameters from other quantized tensors
void propagateQParams(
    Value* original_output,
    const std::vector<Value*>& inputs,
    bool is_scalar = false,
    const c10::optional<std::tuple<c10::QScheme, QParamVector>>& qparams_opt =
        c10::nullopt);
796 | |
797 | bool isQuantized(Value* v) { |
798 | return quantized_values_.count(v) != 0; |
799 | } |
800 | |
// Map from Graph to the names of the observer module attributes that
// should be removed from the module owning that graph.
std::unordered_map<Graph*, std::vector<std::string>>
    observer_modules_to_remove_;
// We only remove observer module attributes from the type in the
// first encounter of the graph; after that, since the attributes
// are already removed from the ClassType, we use the list of slot indices to
// replay this removal
std::unordered_map<Graph*, std::vector<int>> removed_observer_slots_;
// Map from Graph to the nodes (observer forward calls and the GetAttr nodes
// of the observer modules) that are destroyed by removeObserverNodes.
std::unordered_map<Graph*, std::vector<Node*>> nodes_to_destroy_;
// Map from Graph to observer node, we can use observer node to
// get the information of original value that's been observed and
// the quantization parameters
std::unordered_map<Graph*, std::vector<Node*>> observer_nodes_for_graph_;
// A map from qparam name (e.g. _scale) to the attribute name in
// the module(e.g. weight_scale_0)
std::unordered_map<Node*, std::unordered_map<std::string, std::string>>
    qparam_name_map_for_node_;
// Record qscheme for every graph, this is for checking
// each graph is only quantized with one type of QScheme
std::unordered_map<Graph*, c10::QScheme> qscheme_for_graph_;

// Set of quantized values, so that we quantize each value only
// once
std::unordered_set<Value*> quantized_values_;

// Map from original weight value to GraphFunction corresponding to the
// subgraph that includes the weight observer and dependent nodes.
std::unordered_map<Value*, std::unique_ptr<GraphFunction>>
    weight_to_graph_fn_;

// Kind of quantization being applied; runWeightObserver only does work when
// this is QuantType::DYNAMIC.
QuantType quant_type_ = QuantType::STATIC;
// When set, a warning is emitted for binary ops with a scalar input, which
// do not support the debug option.
bool debug_ = false;
832 | }; |
833 | |
834 | void InsertQuantDeQuantHelper::collectObserverNodesAndValueToQuantize( |
835 | Module& module, |
836 | Value* v) { |
837 | auto* g = v->owningGraph(); |
838 | auto observer_name = findObserverName(v); |
839 | if (!observer_name) { |
840 | return; |
841 | } |
842 | observer_modules_to_remove_[g].push_back(observer_name.value()); |
843 | |
844 | Node* observer = v->node(); |
845 | TORCH_INTERNAL_ASSERT( |
846 | observer->kind() == prim::CallMethod && |
847 | observer->s(attr::name) == "forward" && |
848 | observer->inputs()[0]->node()->kind() == prim::GetAttr && |
849 | observer->inputs()[0]->node()->s(attr::name) == observer_name); |
850 | |
851 | // Observer forward call node |
852 | nodes_to_destroy_[g].push_back(observer); |
853 | // GetAttr node for observer module |
854 | nodes_to_destroy_[g].push_back(observer->inputs()[0]->node()); |
855 | observer_nodes_for_graph_[g].push_back(observer); |
856 | } |
857 | |
858 | void InsertQuantDeQuantHelper::removeObserverNodes(Module& module) { |
859 | for (auto& method : module.get_methods()) { |
860 | removeObserverNodes(method.graph().get()); |
861 | } |
862 | for (Module m : module.children()) { |
863 | removeObserverNodes(m); |
864 | } |
865 | } |
866 | |
867 | void InsertQuantDeQuantHelper::removeObserverNodes(Graph* g) { |
868 | if (nodes_to_destroy_.count(g)) { |
869 | for (auto& n : nodes_to_destroy_.at(g)) { |
870 | n->removeAllInputs(); |
871 | } |
872 | for (auto& n : nodes_to_destroy_.at(g)) { |
873 | n->destroy(); |
874 | } |
875 | nodes_to_destroy_.at(g).clear(); |
876 | } |
877 | } |
878 | |
879 | void InsertQuantDeQuantHelper::cleanup(Module& module) { |
880 | for (auto& method : module.get_methods()) { |
881 | cleanup(module, method.graph().get()); |
882 | } |
883 | for (Module m : module.children()) { |
884 | cleanup(m); |
885 | } |
886 | } |
887 | |
888 | void InsertQuantDeQuantHelper::cleanup(Module& module, Graph* g) { |
889 | GRAPH_DUMP("Before Remove Observers:" , g); |
890 | removeObserverNodes(g); |
891 | |
892 | // 1. If we have seen this graph before, this means the observer |
893 | // attributes has been removed from the type(see step 2) but the slot |
894 | // index of these attributes are kept in the list, we'll replay the observer |
895 | // slots removal using these slot indexes |
896 | if (removed_observer_slots_.count(g)) { |
897 | for (auto slot : removed_observer_slots_.at(g)) { |
898 | module._ivalue()->unsafeRemoveSlot(slot); |
899 | } |
900 | } |
901 | |
902 | // 2. Remove observer modules from last one to first one in order to |
903 | // reduce the time complexity, assuming all the observer modules |
904 | // are added after the existing modules, we'll have complexity of |
905 | // O(N) where N is number of observer modules with this optimization |
906 | if (observer_modules_to_remove_.count(g)) { |
907 | auto& observers = observer_modules_to_remove_.at(g); |
908 | for (int64_t i = observers.size() - 1; i >= 0; --i) { |
909 | auto observer_name = observers[i]; |
910 | GRAPH_DEBUG("Trying to remove: " , observer_name); |
911 | if (module.type()->hasAttribute(observer_name)) { |
912 | // We record the slot index here in order to replay the |
913 | // slot removal in other objects that's sharing the ClassType |
914 | // since we're going to remove attribute in the ClassType here |
915 | removed_observer_slots_[g].push_back( |
916 | module.type()->getAttributeSlot(observer_name)); |
917 | module._ivalue()->unsafeRemoveAttr(observer_name); |
918 | module.type()->unsafeRemoveAttribute(observer_name); |
919 | } |
920 | } |
921 | observers.clear(); |
922 | } |
923 | GRAPH_DUMP("After remove observers :" , g); |
924 | } |
925 | |
926 | void SubGraphCloneHelper::cloneNodeInGraph( |
927 | Node* node, |
928 | std::shared_ptr<Graph>& g, |
929 | std::unordered_map<Value*, Value*>& remap_old_to_new) { |
930 | auto* block = g->block(); |
931 | auto value_fn = [&](Value* v) { |
932 | if (remap_old_to_new.count(v) == 0) { |
933 | auto new_value = g->block()->addInput(); |
934 | remap_old_to_new[v] = new_value; |
935 | new_value->copyMetadata(v); |
936 | return new_value; |
937 | } else { |
938 | return remap_old_to_new[v]; |
939 | } |
940 | }; |
941 | |
942 | auto new_node = block->appendNode(g->createClone(node, value_fn)); |
943 | for (size_t i = 0; i < node->outputs().size(); ++i) { |
944 | auto oo = node->outputs()[i]; |
945 | auto no = new_node->outputs()[i]; |
946 | remap_old_to_new[oo] = no; |
947 | } |
948 | } |
949 | |
950 | void SubGraphCloneHelper::buildObserverSubgraph( |
951 | const std::vector<Node*>& weight_subgraph, |
952 | std::shared_ptr<Graph> dest_graph) { |
953 | std::unordered_map<Value*, Value*> remap_old_to_new; |
954 | // Build weight subgraph |
955 | for (auto n : weight_subgraph) { |
956 | cloneNodeInGraph(n, dest_graph, remap_old_to_new); |
957 | } |
958 | LintGraph(dest_graph); |
959 | |
960 | // Add last node output value as subgraph output. |
961 | for (auto out : weight_subgraph.back()->outputs()) { |
962 | dest_graph->registerOutput(remap_old_to_new[out]); |
963 | } |
964 | GRAPH_DUMP("New weight observer subgraph: " , dest_graph); |
965 | } |
966 | |
967 | std::unique_ptr<GraphFunction> SubGraphCloneHelper::buildGraphFromNodes( |
968 | const std::vector<Node*>& nodes, |
969 | const std::string& name) { |
970 | auto observer_subgraph = std::make_shared<Graph>(); |
971 | auto build_observer_graph = [&](GraphFunction& func) { |
972 | buildObserverSubgraph(nodes, func.graph()); |
973 | }; |
974 | return torch::make_unique<GraphFunction>( |
975 | name, observer_subgraph, build_observer_graph); |
976 | } |
977 | |
978 | void InsertQuantDeQuantHelper::findSubgraph( |
979 | Value* self, |
980 | Value* input_val, |
981 | std::vector<Node*>& weight_subgraph) { |
982 | Node* node = input_val->node(); |
983 | weight_subgraph.push_back(node); |
984 | const auto& inputs = node->inputs().vec(); |
985 | for (auto v : inputs) { |
986 | if (!hitGraphInput(v)) { |
987 | findSubgraph(self, v, weight_subgraph); |
988 | } else { |
989 | TORCH_CHECK( |
990 | v == self, |
991 | "Unexpected value found when handling weight value " |
992 | " in findSubgraph, traced back to:" , |
993 | v->debugName(), |
994 | " which is not self:" , |
995 | self->debugName()); |
996 | } |
997 | } |
998 | } |
999 | |
1000 | void InsertQuantDeQuantHelper::extractAndRunWeightObserver( |
1001 | Module& module, |
1002 | Value* self, |
1003 | Value* weight_value) { |
1004 | std::vector<Node*> weight_subgraph; |
1005 | // If the graph was already visited, return the GraphFunction directly. |
1006 | // Multiple module instances can share the same graph code, so we don't need |
1007 | // to re-run the extraction process. |
1008 | if (weight_to_graph_fn_.count(weight_value) == 0) { |
1009 | // Extract the subgraph nodes. |
1010 | findSubgraph(self, weight_value, weight_subgraph); |
1011 | |
1012 | // Reverse to traverse subgraph in correct direction |
1013 | std::reverse(weight_subgraph.begin(), weight_subgraph.end()); |
1014 | |
1015 | // Build the graph using the nodes found from the weight observer. |
1016 | SubGraphCloneHelper o; |
1017 | std::unique_ptr<GraphFunction> func = |
1018 | o.buildGraphFromNodes(weight_subgraph, "observer_subgraph" ); |
1019 | weight_to_graph_fn_[weight_value] = std::move(func); |
1020 | } |
1021 | Stack module_inp = {module._ivalue()}; |
1022 | // Run the graph with the module input. |
1023 | weight_to_graph_fn_[weight_value]->run(module_inp); |
1024 | } |
1025 | |
1026 | void InsertQuantDeQuantHelper::quantizeTensors( |
1027 | Module& module, |
1028 | Graph* g, |
1029 | Value* self) { |
1030 | if (!observer_nodes_for_graph_.count(g)) { |
1031 | return; |
1032 | } |
1033 | for (auto* n : observer_nodes_for_graph_.at(g)) { |
1034 | auto* original_value = n->input(1); |
1035 | auto tp = getQSchemeAndQParamVector(module, n); |
1036 | auto qscheme = std::get<0>(tp); |
1037 | auto qparam_map = std::get<1>(tp); |
1038 | checkQScheme(g, qscheme); |
1039 | std::vector<std::string> qparam_names; |
1040 | for (auto& pr : qparam_map) { |
1041 | const auto& name = pr.first; |
1042 | const auto& qparam = pr.second; |
1043 | size_t uid = 0; |
1044 | auto qparam_name = |
1045 | original_value->debugName() + name + "_" + c10::to_string(uid++); |
1046 | while (module.hasattr(qparam_name)) { |
1047 | qparam_name = |
1048 | original_value->debugName() + name + "_" + c10::to_string(uid++); |
1049 | } |
1050 | qparam_name_map_for_node_[n][name] = qparam_name; |
1051 | module.register_attribute(qparam_name, qparam.type(), qparam); |
1052 | qparam_names.push_back(qparam_name); |
1053 | } |
1054 | insertQuantizationOps( |
1055 | module, self, n, isPerChannel(qscheme), qparam_names, quant_type_); |
1056 | } |
1057 | } |
1058 | |
1059 | std::tuple<c10::QScheme, QParamVector> InsertQuantDeQuantHelper:: |
1060 | getQSchemeAndQParamVector(script::Module& module, Node* n) { |
1061 | // TODO: refactor findObserverName to take Node* as input |
1062 | Value* v = n->output(); |
1063 | TORCH_INTERNAL_ASSERT( |
1064 | v->type()->isSubtypeOf(*TensorType::get()), |
1065 | "Expected output of observer node to be Tensor" ); |
1066 | auto observer_name = findObserverName(v); |
1067 | TORCH_INTERNAL_ASSERT( |
1068 | observer_name, |
1069 | "getQSchemeAndParamMap expects the corresponding observer for " , |
1070 | v->debugName(), |
1071 | " exists." ); |
1072 | QParamVector qparams; |
1073 | c10::QScheme qscheme = c10::kPerTensorAffine; |
1074 | |
1075 | auto observer_module = module.attr(observer_name.value()).toModule(); |
1076 | auto scalar_type = observer_module.attr("dtype" ); |
1077 | if (isPlaceholderObserver(n->input(0))) { |
1078 | // get compute_dtype for dynamic quantization |
1079 | if (observer_module.hasattr("is_dynamic" ) && |
1080 | observer_module.attr("is_dynamic" ).toBool()) { |
1081 | qparams.emplace_back(kScalarType, observer_module.attr("dtype" )); |
1082 | } |
1083 | return std::make_tuple(qscheme, std::move(qparams)); |
1084 | } else if (scalar_type == at::ScalarType::Half) { |
1085 | return std::make_tuple(qscheme, std::move(qparams)); |
1086 | } |
1087 | auto calculate_qparams = observer_module.get_method("calculate_qparams" ); |
1088 | IValue result = calculate_qparams(std::vector<IValue>()); |
1089 | checkCalculateQParamsResult(result); |
1090 | TORCH_CHECK( |
1091 | scalar_type.toScalarType() != at::ScalarType::Undefined, |
1092 | "dtype of observer can't be undefined" ); |
1093 | auto tp = result.toTuple(); |
1094 | at::Tensor scale = tp->elements()[0].toTensor().to(at::kFloat); |
1095 | at::Tensor zero_point = tp->elements()[1].toTensor().to(at::kInt); |
1096 | // quantization parameters should appear in the same order as |
1097 | // the argument for quantize_per_tensor/quantize_per_channel function |
1098 | |
1099 | qscheme = observer_module.attr("qscheme" ).toQScheme(); |
1100 | if (isPerChannel(qscheme)) { |
1101 | auto axis = observer_module.attr("ch_axis" ); |
1102 | qparams.emplace_back("_scale" , scale); |
1103 | qparams.emplace_back("_zero_point" , zero_point); |
1104 | qparams.emplace_back("_axis" , axis.toInt()); |
1105 | } else { |
1106 | qparams.emplace_back("_scale" , scale.item<double>()); |
1107 | qparams.emplace_back("_zero_point" , zero_point.item<int64_t>()); |
1108 | } |
1109 | qparams.emplace_back(kScalarType, scalar_type); |
1110 | return std::make_tuple(qscheme, std::move(qparams)); |
1111 | } |
1112 | |
1113 | ModuleMethodVector InsertQuantDeQuantHelper::getInvokedMethods( |
1114 | Module& module, |
1115 | const std::string& method_name) { |
1116 | auto graph = module.get_method(method_name).graph(); |
1117 | |
1118 | ModuleMethodVector invoked_methods; |
1119 | std::stack<Block*> blocks_to_visit; |
1120 | blocks_to_visit.push(graph->block()); |
1121 | while (!blocks_to_visit.empty()) { |
1122 | Block* b = blocks_to_visit.top(); |
1123 | blocks_to_visit.pop(); |
1124 | for (Node* n : b->nodes()) { |
1125 | if (n->kind() == prim::CallMethod) { |
1126 | auto module_instance = n->inputs()[0]; |
1127 | auto module_method_name = n->s(attr::name); |
1128 | c10::optional<Module> m; |
1129 | // calling method on self |
1130 | if (module_instance == graph->inputs()[0]) { |
1131 | m = module; |
1132 | } else if ( |
1133 | module_instance->node()->kind() == prim::GetAttr && |
1134 | module_instance->node()->s(attr::name).find("_observer_" ) == |
1135 | std::string::npos) { |
1136 | m = getInvokedModuleOpt(module, n, graph->inputs()[0]); |
1137 | } |
1138 | if (m) { |
1139 | invoked_methods.emplace_back(*m, module_method_name); |
1140 | } |
1141 | } |
1142 | |
1143 | for (Block* subblock : n->blocks()) { |
1144 | blocks_to_visit.push(subblock); |
1145 | } |
1146 | } |
1147 | } |
1148 | return invoked_methods; |
1149 | } |
1150 | |
// Quantize `original_output` by inserting a quantize op after it, followed by
// a dequantize op for every use of the quantized value.
//
// - original_output: the value to quantize; a Scalar when `is_scalar` is set.
// - inputs: the dequantized inputs of the op; exactly one input is expected,
//   and the first input of its producing node is taken as the quantized
//   tensor.
// - is_scalar: when true, the value is first converted to a Tensor and the
//   dequantized results are converted back to Scalars.
// - qparams_opt: when set, the given quantization parameters are inserted as
//   constants; otherwise scale, zero_point and dtype are read from the
//   quantized input.
//
// All dequantized outputs are recorded in quantized_values_.
void InsertQuantDeQuantHelper::propagateQParams(
    Value* original_output,
    const std::vector<Value*>& inputs,
    bool is_scalar,
    const c10::optional<std::tuple<c10::QScheme, QParamVector>>& qparams_opt) {
  Node* n = original_output->node();
  Graph* graph = n->owningGraph();
  if (is_scalar) {
    // convert Scalar to Tensor
    n = insertScalarToTensor(graph, original_output);
    original_output = n->output();
  }
  // for ops like average pool, we'll insert quant dequant after the op
  // We'll assume the tensor is a PerTensorAffine quantized Tensor for
  // now, and may generalize later if this becomes an issue
  TORCH_INTERNAL_ASSERT(
      inputs.size() == 1, "Expecting single input for the aten function" );
  // input of the dequantize node
  Value* quantized_input = inputs[0]->node()->input(0);
  // insert ops after the general op
  Node* quantized_input_node = quantized_input->node();
  // Insert after the node that is later in topological order
  WithInsertPoint ins(
      quantized_input_node->isAfter(n) ? quantized_input_node->next()
                                       : n->next());
  std::vector<Value*> quant_inputs;
  auto quant_kind = Symbol::aten("quantize_per_tensor" );
  if (qparams_opt.has_value()) {
    // Fixed quantization parameters: materialize each of them as a constant.
    quant_inputs = {original_output};
    auto qscheme = std::get<0>(*qparams_opt);
    auto qparams = std::get<1>(*qparams_opt);
    if (isPerChannel(qscheme)) {
      quant_kind = Symbol::aten("quantize_per_channel" );
    }
    for (const auto& qparam : qparams) {
      Value* qparam_val = graph->insertConstant(qparam.second);
      qparam_val->setDebugName(quantized_input->debugName() + qparam.first);
      quant_inputs.push_back(qparam_val);
    }
  } else {
    // Only per tensor affine quantized tensor is supported in this case
    // get quantization parameters from previous quantized op
    Node* scale = insertQParam(
        graph,
        quantized_input,
        at::Symbol::aten("q_scale" ),
        FloatType::get(),
        "q_scale" );
    Node* zero_point = insertQParam(
        graph,
        quantized_input,
        at::Symbol::aten("q_zero_point" ),
        IntType::get(),
        "q_zero_point" );
    Node* dtype = insertQParam(
        graph, quantized_input, prim::dtype, IntType::get(), "dtype" );
    quant_inputs = {
        original_output,
        scale->output(),
        zero_point->output(),
        dtype->output()};
  }
  Node* quant = insertQuant(
      graph, quant_inputs, quant_kind, original_output->debugName() + ".quant" );
  Value* quantized_output = quant->output();
  // replace uses of original output of the general op with quantized
  // output
  original_output->replaceAllUsesAfterNodeWith(quant, quantized_output);
  const auto& outputs =
      insertDeQuantForAllUse(graph, quantized_output, quantized_output);
  for (auto* output : outputs) {
    if (is_scalar) {
      // Convert the dequantized Tensor back to Scalar
      Node* item = insertItem(graph, output, FloatType::get());
      Value* scalar = item->output();
      output->replaceAllUsesAfterNodeWith(item, scalar);
      output = scalar;
    }
    // Remember the value so that it is not quantized a second time.
    quantized_values_.insert(output);
  }
}
1232 | |
1233 | void removeDequantizeFromInputs(const std::unordered_set<Value*>& inputs) { |
1234 | // Delete dequantize node, we have one dequantize |
1235 | // for each use of the value |
1236 | for (auto* dequantized_val : inputs) { |
1237 | auto* dequantize_node = dequantized_val->node(); |
1238 | TORCH_INTERNAL_ASSERT( |
1239 | dequantized_val->uses().size() == 1, |
1240 | "Expect to have one dequantize node for each use" ); |
1241 | // Replace useses of dequantized_val with the input of |
1242 | // dequantize node |
1243 | dequantized_val->replaceAllUsesWith(dequantize_node->inputs()[0]); |
1244 | dequantize_node->removeAllInputs(); |
1245 | dequantize_node->destroy(); |
1246 | } |
1247 | } |
1248 | |
1249 | // Check if we need to propagate the quantization ops from input to |
1250 | // output |
1251 | c10::optional<std::vector<Value*>> getDequantizedInputs(Value* output) { |
1252 | auto inputs = getPassThroughInputs(output); |
1253 | if (!inputs.empty()) { |
1254 | // note that we don't need to recursively check for prim::If |
1255 | // here because if all inputs of a prim::If is dequantized |
1256 | // the dequantize will be factored out before we get to this |
1257 | // point |
1258 | bool is_dequantized = true; |
1259 | for (auto* input : inputs) { |
1260 | GRAPH_DEBUG( |
1261 | "checking if input:" , |
1262 | input->debugName(), |
1263 | " in node:" , |
1264 | *input->node(), |
1265 | "is quantized" ); |
1266 | is_dequantized &= input->node()->kind() == Symbol::aten("dequantize" ); |
1267 | } |
1268 | if (is_dequantized) { |
1269 | return inputs; |
1270 | } |
1271 | } |
1272 | return c10::nullopt; |
1273 | } |
1274 | |
// Walk the nodes of `block` and propagate quantization through them.
// Sub-blocks of prim::If nodes are processed recursively first. Each node is
// then handled in one of three ways:
// - single-input general value aten functions: quant/dequant ops are inserted
//   after the op, using the qparams of the quantized input (with extra
//   handling of the scalar min/max arguments of clamp ops)
// - ops with fixed qparams: quant/dequant ops are inserted after the op,
//   using those fixed qparams
// - everything else: the dequantize ops are moved from the inputs to the
//   outputs of the op
void InsertQuantDeQuantHelper::propagateQuantizationOps(Block* block) {
  for (Node* n : block->nodes()) {
    if (n->kind() == prim::If) {
      for (Block* subblock : n->blocks()) {
        propagateQuantizationOps(subblock);
      }
      if (n->outputs().empty()) {
        continue;
      }
      if (n->outputs().size() > 1) {
        // Factoring out dequantize for if blocks with multiple outputs
        // is not supported right now
        continue;
      }
    }
    if (isSingleInputGeneralValueAtenFunction(n)) {
      for (auto* output : n->outputs()) {
        if (isQuantized(output)) {
          continue;
        }
        if (auto inputs = getDequantizedInputs(output)) {
          propagateQParams(output, *inputs);
          if (isClamp(n)) {
            for (size_t i = 1; i <= 2; ++i) {
              // propagate qparams for min and max scalar arguments
              // for aten::clamp/aten::hardtanh
              propagateQParams(n->input(i), *inputs, /* is_scalar */ true);
            }
          }
        }
      }
    } else if (auto qparams_opt = getFixedQParams(n)) {
      for (auto* output : n->outputs()) {
        if (isQuantized(output)) {
          continue;
        }
        if (auto inputs = getDequantizedInputs(output)) {
          propagateQParams(output, *inputs, /* is_scalar */ false, qparams_opt);
        }
      }
    } else {
      // For ops that are quantized by propagating dequantize ops,
      // e.g. flatten we need to
      // 1. check if we need to propagate dequantize op
      // 2. remove the dequantize ops from inputs
      // 3. insert dequantize for all outputs
      // to make sure it works for ops with multiple outputs
      // since removing dequantize from inputs is mutating the graph
      // and it will affect future checks for whether all the inputs
      // has been quantized or not (since currently we just check if
      // the value is produced by dequantize op to decide if the value
      // is quantized or not)
      // list of dequantized input values
      std::unordered_set<Value*> dequantized_inputs;
      std::vector<Value*> outputs_to_dequantize;
      // 1. collect dequantized inputs and outputs we need to dequantize
      for (auto* output : n->outputs()) {
        if (isQuantized(output)) {
          continue;
        }
        if (auto inputs = getDequantizedInputs(output)) {
          std::copy(
              inputs->begin(),
              inputs->end(),
              std::inserter(dequantized_inputs, dequantized_inputs.end()));
          outputs_to_dequantize.push_back(output);
        }
      }
      // 2. remove the dequantize ops from inputs
      removeDequantizeFromInputs(dequantized_inputs);
      // 3. insert dequantize op for outputs
      for (auto* output : outputs_to_dequantize) {
        insertDeQuantForAllUse(output->owningGraph(), output, output);
      }
    }

    if (isBinaryOpWithScalarInput(n)) {
      // Print warning for add_scalar/mul_scalar when debug is enabled
      // since the quantization parameter for these ops depends on
      // input and it's too complicated to encode the equations in
      // the IR:
      // https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/quantized/cpu/BinaryOps.cpp#L64-L74
      if (debug_) {
        TORCH_WARN_ONCE(
            "debug option for add_scalar and mul_scalar is not supported, "
            "please don't use debug option for models that uses these ops." );
      }
    }
  }
}
1365 | |
1366 | void InsertQuantDeQuantHelper::runWeightObserver( |
1367 | Module& module, |
1368 | const std::string& method_name) { |
1369 | if (quant_type_ != QuantType::DYNAMIC) { |
1370 | return; |
1371 | } |
1372 | |
1373 | for (auto& invoked_methods : getInvokedMethods(module, method_name)) { |
1374 | auto& invoked_module = std::get<0>(invoked_methods); |
1375 | const auto& invoked_method_name = std::get<1>(invoked_methods); |
1376 | runWeightObserver(invoked_module, invoked_method_name); |
1377 | } |
1378 | Method method = module.get_method(method_name); |
1379 | auto graph = method.graph(); |
1380 | Value* self = graph->inputs()[0]; |
1381 | |
1382 | std::vector<Value*> weight_values; |
1383 | // Visit all blocks in the current graph to find weight values. |
1384 | std::stack<Block*> blocks_to_visit; |
1385 | blocks_to_visit.push(graph->block()); |
1386 | while (!blocks_to_visit.empty()) { |
1387 | Block* b = blocks_to_visit.top(); |
1388 | blocks_to_visit.pop(); |
1389 | for (auto n : b->nodes()) { |
1390 | for (Value* v : n->outputs()) { |
1391 | if (!v->type()->isSubtypeOf(*TensorType::get())) { |
1392 | continue; |
1393 | } |
1394 | auto observer_name = findObserverName(v); |
1395 | if (observer_name && isWeight(module, v)) { |
1396 | weight_values.push_back(v); |
1397 | } |
1398 | } |
1399 | for (Block* subblock : n->blocks()) { |
1400 | blocks_to_visit.push(subblock); |
1401 | } |
1402 | } |
1403 | } |
1404 | // For all the observed weight values, find the corresponding subgraph that |
1405 | // contributes to the weight tensor, and run that subgraph to observe the |
1406 | // weight. |
1407 | for (const auto& v : weight_values) { |
1408 | extractAndRunWeightObserver(module, self, v); |
1409 | } |
1410 | } |
1411 | |
// Process `module.method_name`: the methods invoked by it are processed
// first (recursively), then the method's own graph.
//
// If the graph already has recorded observer nodes (it was processed before),
// only the values of the previously registered qparam attributes are updated
// on this module. Otherwise the observer nodes of the graph are collected and
// quantizeTensors is called for it.
void InsertQuantDeQuantHelper::run(
    Module& module,
    const std::string& method_name) {
  for (auto& invoked_methods : getInvokedMethods(module, method_name)) {
    auto& invoked_module = std::get<0>(invoked_methods);
    const auto& invoked_method_name = std::get<1>(invoked_methods);
    run(invoked_module, invoked_method_name);
  }

  Method method = module.get_method(method_name);
  auto graph = method.graph();
  // We only need to register new parameters if the graph has
  // been quantized before
  // TODO: dedup this part with code in quantizeTensors
  if (observer_nodes_for_graph_.count(graph.get())) {
    for (auto* n : observer_nodes_for_graph_.at(graph.get())) {
      auto tp = getQSchemeAndQParamVector(module, n);
      checkQScheme(graph.get(), std::get<0>(tp));
      auto qparam_map = std::get<1>(tp);
      // We check the size here because for some observers (like
      // PlaceholderObserver) the qparams might be empty.
      if (!qparam_map.empty()) {
        TORCH_INTERNAL_ASSERT(
            qparam_name_map_for_node_.count(n),
            "Expected to have a qparam_name_map for node:" ,
            *n);
        auto qparam_name_map = qparam_name_map_for_node_.at(n);
        for (auto& pr : qparam_map) {
          const auto& name = pr.first;
          const auto& qparam = pr.second;
          // Reuse the attribute name chosen when the graph was first
          // quantized; only the value differs for this module.
          module._ivalue()->setAttr(qparam_name_map.at(name), qparam);
        }
      }
    }
    return;
  }

  // prim::Param nodes do not belong to the graph. Hence the Insert
  // point is the beginning of graph node. This also safe guards against
  // observing a potentially mutated value due to some in-place operation
  std::vector<Value*> input_values;
  // Index 0 is skipped: it is `self`.
  for (const auto idx : c10::irange(1, method.num_inputs())) {
    auto& v = graph->inputs()[idx];
    if (v->type()->isSubtypeOf(*TensorType::get())) {
      input_values.push_back(v);
    }
  }

  // Collect observer nodes for all Tensor outputs in all (nested) blocks.
  std::stack<Block*> blocks_to_visit;
  blocks_to_visit.push(graph->block());
  while (!blocks_to_visit.empty()) {
    Block* b = blocks_to_visit.top();
    blocks_to_visit.pop();
    for (auto it = b->nodes().begin(), end = b->nodes().end(); it != end;) {
      // The iterator is advanced before the node is processed.
      Node* n = *it++;
      for (Value* v : n->outputs()) {
        if (!v->type()->isSubtypeOf(*TensorType::get())) {
          continue;
        }
        collectObserverNodesAndValueToQuantize(module, v);
      }

      for (Block* subblock : n->blocks()) {
        blocks_to_visit.push(subblock);
      }
    }
  }

  // The Tensor inputs of the graph are handled after the node outputs.
  for (Value* v : input_values) {
    collectObserverNodesAndValueToQuantize(module, v);
  }
  GRAPH_DUMP("Before Quantize Tensors:" , graph);
  Value* self = graph->inputs()[0];
  quantizeTensors(module, graph.get(), self);
  GRAPH_DUMP("After Quantize Tensors:" , graph);
}
1488 | |
// Module-level entry point for propagating quantization ops: runs the
// sequence of passes below on `module` and on the graph of its `forward`
// method, with the block-level propagateQuantizationOps applied to the
// graph's top-level block as the second-to-last step.
void InsertQuantDeQuantHelper::propagateQuantizationOps(Module& module) {
  SwapFunctionalLinear(module);
  auto graph = module.get_method("forward" ).graph();
  Inline(*graph);
  ConstantPropagation(graph);
  ReplicateChooseQParamsQuantDequant(graph);
  RemoveRedundantQuantizationOps(graph);
  ReplicateQuant(graph);
  ReplicateDeQuant(graph);
  // TODO: add filter to the clamp patterns and remove this pass
  ReplicateClampScalarArgs(graph);
  propagateQuantizationOps(graph->block());
  RemoveRedundantDequantize(graph);
}
1503 | |
1504 | // Insert quant and dequant nodes into the graph for both static and dynamic |
1505 | // quant. |
1506 | template <> |
1507 | Node* insertQuantDequantNodes<QuantOpParams>( |
1508 | Value* self, |
1509 | Node* observer, |
1510 | QuantOpParams& qparams, |
1511 | const std::string& quantize_func) { |
1512 | (void)self; |
1513 | Graph* g = observer->owningGraph(); |
1514 | Value* observer_out = observer->output(); |
1515 | Value* original_val = observer->input(1); |
1516 | std::vector<Value*> inputs; |
1517 | // + 1 for tensor to be quantized |
1518 | inputs.reserve(qparams.qparams.size() + 1); |
1519 | inputs.push_back({observer_out}); |
1520 | for (const auto& qparam_values : qparams.qparams) { |
1521 | inputs.push_back(qparam_values); |
1522 | } |
1523 | Node* quant = insertQuant( |
1524 | g, |
1525 | inputs, |
1526 | at::Symbol::aten(quantize_func), |
1527 | original_val->debugName() + ".quant" ); |
1528 | // Have to make sure that quant node appears after the values it depends on. |
1529 | for (Value* v : inputs) { |
1530 | quant->moveAfter(v->node()); |
1531 | } |
1532 | Node* dequant = insertDeQuant(g, quant->output(), original_val); |
1533 | dequant->moveAfter(quant); |
1534 | return dequant; |
1535 | } |
1536 | |
1537 | void checkCalculateQParamsResultTypes(const Node* out) { |
1538 | TORCH_CHECK( |
1539 | out->outputs().size() == 2, |
1540 | "calculate_qparams should produce output of size 2 (scale, zero_point)." ); |
1541 | Value* scale = out->output(0); |
1542 | Value* zp = out->output(1); |
1543 | TORCH_CHECK( |
1544 | scale->type()->expect<TensorType>(), |
1545 | "Scale value should be of Tensor type." ); |
1546 | TORCH_CHECK( |
1547 | zp->type()->expect<TensorType>(), "Scale value should be of float type." ); |
1548 | } |
1549 | |
// Insert nodes into graph `g` that compute the quantization parameters of the
// observer behind observer node `n` at runtime, by calling the observer's
// `calculate_qparams` method in the graph.
//
// Returns the observer's qscheme and the graph Values of the quantization
// parameters, in the order: scale, zero_point, ch_axis (per channel only),
// dtype. For observers with dtype Half the default-constructed QuantOpParams
// is returned. Placeholder observers are rejected with a TORCH_CHECK error.
QuantOpParams InsertQuantDeQuantHelper::insertCalculateQParams(
    script::Module& module,
    Graph* g,
    Node* n) {
  // TODO: refactor findObserverName to take Node* as input
  Value* self = g->inputs()[0];
  Value* v = n->output();
  TORCH_INTERNAL_ASSERT(
      v->type()->isSubtypeOf(*TensorType::get()),
      "Expected output of observer node to be Tensor" );
  auto observer_name = findObserverName(v);
  // NOTE(review): the message below refers to getQSchemeAndParamMap although
  // it is raised from insertCalculateQParams.
  TORCH_INTERNAL_ASSERT(
      observer_name,
      "getQSchemeAndParamMap expects the corresponding observer for " ,
      v->debugName(),
      " exists." );
  std::vector<Value*> qparams_graph_values;
  QuantOpParams quant_op_params;

  TORCH_CHECK(
      !isPlaceholderObserver(n->input(0)),
      "Placeholder observers are not supported in ondevice PTQ." );
  auto observer_module = module.attr(observer_name.value()).toModule();
  // GetAttr node fetching the observer module from `self` inside the graph.
  Value* observer_module_value = g->insertGetAttr(self, observer_name.value());
  auto scalar_type = observer_module.attr("dtype" );
  TORCH_CHECK(
      scalar_type.toScalarType() != at::ScalarType::Undefined,
      "dtype of observer can't be undefined" );
  // Not sure if we need to support this for on device PTQ.
  // NOTE(review): the GetAttr value inserted above is not referenced by the
  // returned params in this case; presumably it is left unused in the graph.
  if (scalar_type == at::ScalarType::Half) {
    return quant_op_params;
  }
  auto calculate_qparams = observer_module.get_method("calculate_qparams" );
  auto calculate_qparams_schema = calculate_qparams.function().getSchema();
  // The observer module value is the only argument of the method call.
  MatchedSchema matched_schema = matchSchema(
      calculate_qparams_schema,
      v->node()->sourceRange(),
      *g,
      {observer_module_value},
      {});
  Node* call = g->insertMethodCall("calculate_qparams" , matched_schema)->node();
  // Unpack the tuple returned by the call into its two elements.
  Node* scale_zp_node = g->insertNode(g->createTupleUnpack(call->output(0)));
  checkCalculateQParamsResultTypes(scale_zp_node);
  auto qscheme = observer_module.attr("qscheme" ).toQScheme();
  quant_op_params.qscheme = qscheme;
  quant_op_params.qparams.push_back(scale_zp_node->output(0)); // scale Value*
  quant_op_params.qparams.push_back(
      scale_zp_node->output(1)); // zero_point Value*
  if (isPerChannel(qscheme)) {
    Value* ch_axis_value = g->insertGetAttr(observer_module_value, "ch_axis" );
    quant_op_params.qparams.push_back(ch_axis_value);
  }
  Value* scalar_type_value = g->insertGetAttr(observer_module_value, "dtype" );
  quant_op_params.qparams.push_back(scalar_type_value);
  return quant_op_params;
}
1606 | |
1607 | void InsertQuantDeQuantHelper::insertCalculateQParamsAndQuantizationOps( |
1608 | Module& module, |
1609 | Graph* graph, |
1610 | Value* self) { |
1611 | if (!observer_nodes_for_graph_.count(graph)) { |
1612 | return; |
1613 | } |
1614 | for (auto* n : observer_nodes_for_graph_.at(graph)) { |
1615 | Graph* g = n->owningGraph(); |
1616 | // Observer output |
1617 | Value* observer_out = n->output(); |
1618 | // Inserting before insert point |
1619 | WithInsertPoint insert_qparams_calc(observer_out->node()->next()); |
1620 | auto quant_op_params = insertCalculateQParams(module, g, n); |
1621 | insertQuantizationOps( |
1622 | module, |
1623 | self, |
1624 | n, |
1625 | isPerChannel(quant_op_params.qscheme), |
1626 | quant_op_params, |
1627 | quant_type_); |
1628 | } |
1629 | } |
1630 | |
1631 | void InsertQuantDeQuantHelper::runForOnDevicePTQ( |
1632 | Module& module, |
1633 | const std::string& method_name) { |
1634 | // In all likelihood this really wont do anything because we expect that |
1635 | // the input method for quantization's prepare step will be inlined. Thus |
1636 | // only call methods we will see will belong to observer's forward calls. |
1637 | for (auto& invoked_methods : getInvokedMethods(module, method_name)) { |
1638 | auto& invoked_module = std::get<0>(invoked_methods); |
1639 | const auto& invoked_method_name = std::get<1>(invoked_methods); |
1640 | runForOnDevicePTQ(invoked_module, invoked_method_name); |
1641 | } |
1642 | |
1643 | Method method = module.get_method(method_name); |
1644 | auto graph = method.graph(); |
1645 | // Unliked the run method we dont need to extract new qparam values for the |
1646 | // the same graph used in different call site. |
1647 | // Reason is that for on device PTQ we dont: |
1648 | // 1. Run calculate_qparams |
1649 | // 2. Get the scale and zero point |
1650 | // 3. get axis and dtype |
1651 | // 4. register values from 2 and 3 as attributes on the parent module. |
1652 | // Instead we insert call to calculate_qparams (1) via insertCalculateQParams |
1653 | // in the graph itself. Then instead of 2 and 3, we get the output Value* |
1654 | // and for 3, we insert GetAttr for axis and dtype and use those Value* |
1655 | // with insterQuantizationOps |
1656 | |
1657 | // prim::Param nodes do not belong to the graph. Hence the Insert |
1658 | // point is the beginning of graph node. This also safe guards against |
1659 | // observing a potentially mutated value due to some in-place operation |
1660 | std::vector<Value*> input_values; |
1661 | for (const auto idx : c10::irange(1, method.num_inputs())) { |
1662 | auto& v = graph->inputs()[idx]; |
1663 | if (v->type()->isSubtypeOf(*TensorType::get())) { |
1664 | input_values.push_back(v); |
1665 | } |
1666 | } |
1667 | |
1668 | std::stack<Block*> blocks_to_visit; |
1669 | blocks_to_visit.push(graph->block()); |
1670 | while (!blocks_to_visit.empty()) { |
1671 | Block* b = blocks_to_visit.top(); |
1672 | blocks_to_visit.pop(); |
1673 | for (auto it = b->nodes().begin(), end = b->nodes().end(); it != end;) { |
1674 | Node* n = *it++; |
1675 | for (Value* v : n->outputs()) { |
1676 | if (!v->type()->isSubtypeOf(*TensorType::get())) { |
1677 | continue; |
1678 | } |
1679 | collectObserverNodesAndValueToQuantize(module, v); |
1680 | } |
1681 | |
1682 | for (Block* subblock : n->blocks()) { |
1683 | blocks_to_visit.push(subblock); |
1684 | } |
1685 | } |
1686 | } |
1687 | |
1688 | for (Value* v : input_values) { |
1689 | collectObserverNodesAndValueToQuantize(module, v); |
1690 | } |
1691 | |
1692 | GRAPH_DUMP("Before insertCalculateQparamsAndQuantizationOps:" , graph); |
1693 | Value* self = graph->inputs()[0]; |
1694 | insertCalculateQParamsAndQuantizationOps(module, graph.get(), self); |
1695 | GRAPH_DUMP("After insertCalculateQparamsAndQuantizationOps:" , graph); |
1696 | } |
1697 | |
1698 | } // namespace |
1699 | |
1700 | void ReplicateQuant(std::shared_ptr<Graph>& graph) { |
1701 | std::stack<Block*> blocks_to_visit; |
1702 | std::vector<Node*> quant_nodes_to_rewrite; |
1703 | blocks_to_visit.push(graph->block()); |
1704 | while (!blocks_to_visit.empty()) { |
1705 | Block* b = blocks_to_visit.top(); |
1706 | blocks_to_visit.pop(); |
1707 | for (Node* n : b->nodes()) { |
1708 | // find quantize node that quantizes the output of if |
1709 | if ((n->kind() == Symbol::aten("quantize_per_tensor" ) || |
1710 | n->kind() == Symbol::aten("quantize_per_channel" )) && |
1711 | n->input(0)->node()->kind() == prim::If) { |
1712 | quant_nodes_to_rewrite.push_back(n); |
1713 | } |
1714 | for (Block* subblock : n->blocks()) { |
1715 | blocks_to_visit.push(subblock); |
1716 | } |
1717 | } |
1718 | } |
1719 | for (Node* n : quant_nodes_to_rewrite) { |
1720 | Node* if_node = n->input(0)->node(); |
1721 | // move the nodes that produces the quantization parameters before |
1722 | // prim::If |
1723 | for (const auto i : c10::irange(1, n->inputs().size())) { |
1724 | n->input(i)->node()->moveBefore(if_node); |
1725 | } |
1726 | // replace all uses of the quantized node with the output of if node |
1727 | n->output()->replaceAllUsesWith(if_node->output()); |
1728 | // add quantize nodes to the end of all blocks |
1729 | for (Block* if_block : if_node->blocks()) { |
1730 | TORCH_CHECK( |
1731 | if_block->outputs().size() == 1, |
1732 | "replicate quantize only works for `if` node with one output right now" ); |
1733 | // the original return value of the block |
1734 | Value* ret_val = if_block->outputs()[0]; |
1735 | std::vector<Value*> quantize_inputs = n->inputs().vec(); |
1736 | quantize_inputs[0] = ret_val; |
1737 | WithInsertPoint ins(if_block->return_node()); |
1738 | Node* quant = graph->create(n->kind(), quantize_inputs); |
1739 | if_block->replaceOutput(0, quant->output()); |
1740 | quant->output()->copyMetadata(ret_val); |
1741 | graph->insertNode(quant); |
1742 | } |
1743 | } |
1744 | |
1745 | for (Node* n : quant_nodes_to_rewrite) { |
1746 | n->removeAllInputs(); |
1747 | } |
1748 | for (Node* n : quant_nodes_to_rewrite) { |
1749 | n->destroy(); |
1750 | } |
1751 | } |
1752 | |
1753 | void ReplicateDeQuant(std::shared_ptr<Graph>& graph) { |
1754 | std::stack<Block*> blocks_to_visit; |
1755 | std::vector<Node*> dequant_nodes_to_rewrite; |
1756 | blocks_to_visit.push(graph->block()); |
1757 | while (!blocks_to_visit.empty()) { |
1758 | Block* b = blocks_to_visit.top(); |
1759 | blocks_to_visit.pop(); |
1760 | for (Node* n : b->nodes()) { |
1761 | if (n->kind() == Symbol::aten("dequantize" ) && |
1762 | n->output()->uses().size() > 1) { |
1763 | dequant_nodes_to_rewrite.push_back(n); |
1764 | } |
1765 | for (Block* subblock : n->blocks()) { |
1766 | blocks_to_visit.push(subblock); |
1767 | } |
1768 | } |
1769 | } |
1770 | for (Node* n : dequant_nodes_to_rewrite) { |
1771 | auto* quantized_val = n->input(0); |
1772 | auto* dequantized_val = n->output(); |
1773 | insertDeQuantForAllUse(graph.get(), quantized_val, dequantized_val); |
1774 | } |
1775 | |
1776 | for (Node* n : dequant_nodes_to_rewrite) { |
1777 | n->removeAllInputs(); |
1778 | } |
1779 | |
1780 | for (Node* n : dequant_nodes_to_rewrite) { |
1781 | n->destroy(); |
1782 | } |
1783 | } |
1784 | |
1785 | Module InsertQuantDeQuant( |
1786 | Module& input_module, |
1787 | const std::string& method_name, |
1788 | bool inplace, |
1789 | bool debug, |
1790 | QuantType quant_type) { |
1791 | Module module = input_module.clone(inplace); |
1792 | InsertQuantDeQuantHelper h(quant_type, debug); |
1793 | h.runWeightObserver(module, method_name); |
1794 | h.run(module, method_name); |
1795 | h.cleanup(module); |
1796 | h.propagateQuantizationOps(module); |
1797 | return module; |
1798 | } |
1799 | |
1800 | /* |
1801 | * |
1802 | * Assumption: method_name method has observer placed |
1803 | * Objective: modify that method to insert calls to: |
1804 | * 1. calculate_qparams |
1805 | * 2. GetAttr for axis and dtype values |
1806 | * 3. Use Values from above two to insert calls to quant + dequant |
1807 | * Thus after this step you have a graph of, e.g., observe_forward, |
1808 | * that has observer nodes, calculate_qparams run on those observer nodes, |
1809 | * output of which is used by quant-dequant nodes. output of dequant is used |
1810 | * by the actual op. |
1811 | * Later on we will replace dequant + op (e.g. linear) with |
1812 | * 1. prepacked_op context |
1813 | * 2. unpack |
1814 | * 3. dequantize |
1815 | * 4. linear |
1816 | * |
1817 | * Of the above pattern 2, 3, and 4 can be replaced by linear_run op |
1818 | */ |
1819 | // Module InsertQuantDeQuantForOnDevicePTQ( |
1820 | Module InsertQuantDeQuantOnDevicePTQ( |
1821 | Module& input_module, |
1822 | const std::string& method_name, |
1823 | bool inplace, |
1824 | bool debug, |
1825 | QuantType quant_type) { |
1826 | Module module = input_module.clone(inplace); |
1827 | const std::string kObserveString = "observe_" ; |
1828 | const auto matched_pos = method_name.find(kObserveString); |
1829 | const auto end_pos = matched_pos + kObserveString.length(); |
1830 | const std::string orig_method_name = method_name.substr(end_pos); |
1831 | TORCH_CHECK( |
1832 | matched_pos == 0, |
1833 | "Quant dequant nodes can only be added to observe_" , |
1834 | orig_method_name, |
1835 | ". Please make sure to run prepare step for on-device PTQ." ); |
1836 | |
1837 | std::string quantize_method_name = "quantize_" + orig_method_name; |
1838 | cloneMethod(module, method_name, quantize_method_name); |
1839 | InsertQuantDeQuantHelper h(quant_type, debug); |
1840 | h.runForOnDevicePTQ(module, quantize_method_name); |
1841 | h.removeObserverNodes(module); |
1842 | // Dont need: |
1843 | // ReplicateChooseQParamsQuantDequant: This is propagating dynamic quant's |
1844 | // quant dequant RemoveRedundantQuantizationOps: THis is removing activation |
1845 | // observers for dynamic quant when the op related to it is not dynamically |
1846 | // quantizable. Doesnt really make sense. In our case we wont have those |
1847 | // anyway since for dynamic quant activations wont be observed We can still |
1848 | // use this function because the above two methods should really be a noop |
1849 | h.propagateQuantizationOps(module); |
1850 | return module; |
1851 | } |
1852 | } // namespace jit |
1853 | } // namespace torch |
1854 | |