1#include <torch/csrc/jit/passes/quantization/insert_quant_dequant.h>
2
3#include <c10/core/QScheme.h>
4#include <c10/util/irange.h>
5#include <torch/csrc/jit/frontend/schema_matching.h>
6#include <torch/csrc/jit/ir/subgraph_matcher.h>
7#include <torch/csrc/jit/jit_log.h>
8#include <torch/csrc/jit/passes/constant_propagation.h>
9#include <torch/csrc/jit/passes/fuse_linear.h>
10#include <torch/csrc/jit/passes/graph_rewrite_helper.h>
11#include <torch/csrc/jit/passes/inliner.h>
12#include <torch/csrc/jit/passes/quantization/helper.h>
13#include <torch/csrc/jit/passes/subgraph_rewrite.h>
14
15#include <stack>
16#include <utility>
17
18namespace torch {
19namespace jit {
20
21namespace {
22using graph_rewrite_helper::PatternInfo;
23
24// dynamic quantization ops for activation: choose_qparams, quant, dequant
25using DynamicQuantOps = std::tuple<Node*, Node*, Node*>;
26
27std::string kScalarType = "_scalar_type";
28
29struct QuantOpParams {
30 c10::QScheme qscheme{c10::kPerTensorAffine};
31 std::vector<Value*> qparams;
32 // This is only so that insertQuantizationOps can be templatized
33 // and subsequntly significant portion of that code can be reused.
34 std::string back() const {
35 return "AttributeDoesNotExist";
36 }
37};
38
39c10::QScheme toAffine(c10::QScheme qscheme) {
40 switch (qscheme) {
41 case c10::kPerTensorAffine:
42 case c10::kPerTensorSymmetric:
43 return c10::kPerTensorAffine;
44 case c10::kPerChannelAffine:
45 case c10::kPerChannelSymmetric:
46 return c10::kPerChannelAffine;
47 default:
48 return qscheme;
49 }
50}
51
52bool isPerChannel(at::QScheme qscheme) {
53 return qscheme == c10::kPerChannelAffine ||
54 qscheme == c10::kPerChannelSymmetric;
55}
56
57// Go through the CallMethod graph to check if the value is Weight.
58bool isWeight(Module& module, Value* v) {
59 if (isWeight(v)) {
60 return true;
61 }
62 c10::optional<bool> result;
63 auto* self = v->owningGraph()->inputs()[0];
64 for (const Use& u : v->uses()) {
65 Node* n = u.user;
66 if (n->kind() == prim::CallMethod) {
67 auto m_opt = getInvokedModuleOpt(module, n, self);
68 if (!m_opt.has_value()) {
69 return false;
70 }
71 auto m = *m_opt;
72 auto g = m.get_method(n->s(attr::name)).graph();
73 auto call_method_result = isWeight(m, g->inputs()[u.offset]);
74 if (result.has_value()) {
75 // Check to make sure all the CallMethods in the graph produce the same
76 // output.
77 TORCH_CHECK(
78 call_method_result == result.value(),
79 "Expected all CallMethods to use either weight "
80 "or non-weight value.",
81 v->debugName());
82 } else {
83 result = call_method_result;
84 }
85 }
86 }
87 return result.has_value() ? result.value() : false;
88}
89
90Node* insertChooseQParams(Graph* graph, Value* original_val) {
91 std::string choose_qparams_func = "_choose_qparams_per_tensor";
92 // Set the reduce range to default to true, since qnnpack backend ignores this
93 // argument.
94 bool reduce_range_param = true;
95 auto reduce_range = graph->insertConstant(reduce_range_param);
96 // choose_qparams_per_tensor has 2 outputs, (scale, zero_point).
97 Node* choose_qparams = graph->create(
98 at::Symbol::aten(choose_qparams_func),
99 {original_val, reduce_range},
100 /* num_outputs = */ 2);
101 choose_qparams->output(0)->setDebugName(original_val->debugName() + ".scale");
102 choose_qparams->output(0)->setType(FloatType::get());
103 choose_qparams->output(1)->setDebugName(
104 original_val->debugName() + ".zero_point");
105 choose_qparams->output(1)->setType(IntType::get());
106 graph->insertNode(choose_qparams);
107 return choose_qparams;
108}
109
110Node* insertQuant(
111 Graph* graph,
112 const std::vector<Value*>& inputs,
113 NodeKind quant_kind,
114 const std::string& debugName) {
115 Node* quant = graph->create(quant_kind, inputs);
116 quant->output()->setDebugName(debugName);
117 graph->insertNode(quant);
118 return quant;
119}
120
121Node* insertDeQuant(
122 Graph* graph,
123 Value* quantized_val,
124 Value* original_val,
125 size_t id = 0) {
126 Node* dequant = graph->create(Symbol::aten("dequantize"), {quantized_val});
127 dequant->output()
128 ->setDebugName(
129 original_val->debugName() + ".dequant." + c10::guts::to_string(id))
130 ->setType(original_val->type());
131 graph->insertNode(dequant);
132 return dequant;
133}
134
135std::vector<Value*> insertDeQuantForAllUse(
136 Graph* graph,
137 Value* quantized_val,
138 Value* original_val) {
139 // copy uses to vector since value->uses() is a reference
140 // and changing the graph will also change the uses() list
141 const std::vector<Use> uses = original_val->uses();
142 std::vector<Value*> outputs;
143 for (const auto i : c10::irange(uses.size())) {
144 auto* user = uses[i].user;
145 // Insert dequantize node right before use node, because
146 // we want to make sure use node and dequantize node reside
147 // in the same block so that quant fusion can happen
148 WithInsertPoint ins(user);
149 Node* dequant = insertDeQuant(graph, quantized_val, original_val, i);
150 user->replaceInput(uses[i].offset, dequant->output());
151 outputs.push_back(dequant->output());
152 }
153 return outputs;
154}
155
156Node* insertQParam(
157 Graph* graph,
158 Value* quantized_input,
159 NodeKind node_kind,
160 const TypePtr& output_type,
161 const std::string& param_name) {
162 Node* qparam = graph->create(node_kind, {quantized_input});
163 qparam->output()
164 ->setDebugName(quantized_input->debugName() + "." + param_name)
165 ->setType(output_type);
166 graph->insertNode(qparam);
167 return qparam;
168}
169
170Node* insertScalarToTensor(Graph* graph, Value* scalar_value) {
171 Node* n = scalar_value->node();
172 WithInsertPoint ins(n->next());
173 Value* float_scalar_type = graph->insertConstant(IValue(c10::kFloat));
174 Value* none = graph->insertConstant(IValue());
175 Node* tensor_node = graph->create(
176 Symbol::aten("scalar_tensor"),
177 {scalar_value, float_scalar_type, none, none, none});
178 Value* tensor_output = tensor_node->output();
179 tensor_output->setDebugName(scalar_value->debugName() + ".tensor");
180 graph->insertNode(tensor_node);
181 // replace original_output with tensor
182 scalar_value->replaceAllUsesAfterNodeWith(tensor_node, tensor_output);
183 return tensor_node;
184}
185
186Node* insertItem(Graph* graph, Value* tensor, const TypePtr& output_type) {
187 WithInsertPoint ins(tensor->node()->next());
188 Node* n = graph->create(Symbol::aten("item"), {tensor});
189 Value* scalar = n->output();
190 scalar->setDebugName(tensor->debugName() + ".scalar")->setType(output_type);
191 graph->insertNode(n);
192 return n;
193}
194
195DynamicQuantOps insertChooseQParamQuantDequant(
196 Graph* graph,
197 Value* original_val,
198 Value* dtype,
199 NodeKind quant_kind) {
200 Node* choose_qparams = insertChooseQParams(graph, original_val);
201 std::vector<Value*> quant_inputs = {original_val};
202 for (auto& out : choose_qparams->outputs()) {
203 quant_inputs.push_back(out);
204 }
205 quant_inputs.push_back(dtype);
206 Node* quant = insertQuant(
207 graph, quant_inputs, quant_kind, original_val->debugName() + ".quant");
208 Node* dequant = insertDeQuant(graph, quant->output(), original_val);
209 return std::make_tuple(choose_qparams, quant, dequant);
210}
211
212Node* insertFP16CastOps(Graph* graph, Value* observer_out) {
213 // If the weight value is outside of the range for FP16 range, i.e. [5.96e-8,
214 // 65504], we saturate the values to the min/max of this range.
215 Node* saturated_weight =
216 graph->create(Symbol::aten("_saturate_weight_to_fp16"), {observer_out});
217 graph->insertNode(saturated_weight);
218 graph->lint();
219
220 return saturated_weight;
221}
222
223// find the observer for Value `v` and return the name of the observer
224c10::optional<std::string> findObserverName(Value* v) {
225 // Note that here we just check for the name of observer, but the ideally
226 // we should be comparing the type of observer, this is a temporary
227 // work around until data only clone of module.clone is supported.
228 Node* n = v->node();
229 if (n->kind() == prim::CallMethod && n->s(attr::name) == "forward") {
230 auto module_instance = n->inputs().at(0);
231 if (module_instance->node()->kind() == prim::GetAttr &&
232 module_instance->node()->s(attr::name).find("_observer_") !=
233 std::string::npos) {
234 return module_instance->node()->s(attr::name);
235 }
236 }
237 return c10::nullopt;
238}
239
240bool isPlaceholderObserver(Value* observer) {
241 if (getModuleName(observer).has_value()) {
242 auto name = getModuleName(observer).value();
243 // if PlaceholderObserver is (anywhere) in name
244 if (name.find("PlaceholderObserver") != std::string::npos) {
245 return true;
246 }
247 }
248 return false;
249}
250
251at::ScalarType getObserverDtype(Module& module, Value* v) {
252 auto observer_name = findObserverName(v);
253 if (observer_name.has_value()) {
254 auto observer_module = module.attr(observer_name.value()).toModule();
255 at::ScalarType scalar_type = observer_module.attr("dtype").toScalarType();
256 return scalar_type;
257 }
258 return at::ScalarType::Undefined;
259}
260
261c10::optional<std::string> getEmbeddingBagObsName(
262 script::Module& module,
263 Node* n) {
264 Value* v = n->output();
265 auto observer = n->input(0);
266 auto observer_module = module.attr(findObserverName(v).value()).toModule();
267 if (observer_module.hasattr("custom_op")) {
268 auto op_name = observer_module.attr("custom_op").toStringRef();
269 return isPlaceholderObserver(observer) ? std::move(op_name) : "";
270 }
271 return c10::nullopt;
272}
273
274bool isEmbeddingBagOp(
275 Node* observer,
276 c10::optional<std::string> embedding_bag_name) {
277 return embedding_bag_name &&
278 embedding_bag_name.value().find("embedding_bag_") != std::string::npos;
279}
280
281template <typename T>
282Node* insertQuantDequantNodes(
283 Value* self,
284 Node* observer,
285 T& qparams,
286 const std::string& quantize_func);
287
288// Insert quant and dequant nodes into the graph for both static and dynamic
289// quant.
290template <>
291Node* insertQuantDequantNodes<std::vector<std::string>>(
292 Value* self,
293 Node* observer,
294 std::vector<std::string>& qparam_names,
295 const std::string& quantize_func) {
296 Graph* g = observer->owningGraph();
297 Value* observer_out = observer->output();
298 Value* original_val = observer->input(1);
299 std::vector<Value*> inputs = {observer_out};
300 // Insert GetAttr nodes for quantization parameters
301 for (const auto& qparam_name : qparam_names) {
302 inputs.push_back(g->insertGetAttr(self, qparam_name));
303 }
304 Node* quant = insertQuant(
305 g,
306 inputs,
307 at::Symbol::aten(quantize_func),
308 original_val->debugName() + ".quant");
309 Node* dequant = insertDeQuant(g, quant->output(), original_val);
310 return dequant;
311}
312
313Node* insertEmbeddingBagOps(Node* observer, const std::string& op_name) {
314 Graph* g = observer->owningGraph();
315 auto observer_out = observer->output();
316
317 std::string prepack_fn, quant_fn;
318 std::vector<Value*> prepack_inputs = {observer_out};
319 if (op_name == "embedding_bag_4bit") {
320 bool optimized_qparams = false;
321 constexpr int NBINS = 200;
322 constexpr float RATIO = 0.16;
323 Value* optimized_qparams_false = g->insertConstant(optimized_qparams);
324 Value* nbins_200 = g->insertConstant(NBINS);
325 Value* ratio_0_16 = g->insertConstant(RATIO);
326 prepack_fn = "quantized::embedding_bag_4bit_prepack";
327 quant_fn = "quantized::embedding_bag_4bit_rowwise_offsets";
328 prepack_inputs.push_back(optimized_qparams_false);
329 prepack_inputs.push_back(nbins_200);
330 prepack_inputs.push_back(ratio_0_16);
331 } else if (op_name == "embedding_bag_byte") {
332 prepack_fn = "quantized::embedding_bag_byte_prepack";
333 quant_fn = "quantized::embedding_bag_byte_rowwise_offsets";
334 } else {
335 TORCH_INTERNAL_ASSERT(
336 false,
337 "Graph Mode Quantization currently supports 4-bit and 8-bit embedding bag quantization.");
338 }
339
340 std::vector<Use> uses = observer_out->uses();
341 // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
342 Node* embedding_bag_float_op;
343 // We expect that the output of the weight observer will be consumed by the
344 // embedding_bag operator.
345 for (const Use& use : uses) {
346 if (matchCallFuncToUse(use, "embedding_bag", 2) ||
347 matchAtenFuncToUse(use, "embedding_bag", 0)) {
348 embedding_bag_float_op = use.user;
349 }
350 }
351
352 // Insert prepack op
353 Node* prepack = g->create(Symbol::fromQualString(prepack_fn), prepack_inputs);
354 g->insertNode(prepack);
355
356 std::vector<Value*> embedding_bag_inputs =
357 // NOLINTNEXTLINE(clang-analyzer-core.CallAndMessage)
358 embedding_bag_float_op->inputs().vec();
359 std::vector<Value*> qembedding_bag_inputs = {prepack->output()};
360 const auto inputs_size = embedding_bag_float_op->inputs().size();
361 const bool is_aten_op =
362 embedding_bag_float_op->kind() == Symbol::aten("embedding_bag");
363 // Create and insert quantized embedding op.
364 Value* none = g->insertConstant(IValue());
365 Value* zero = g->insertConstant(IValue(0));
366 bool pruned_wt = false;
367 auto pruned_const = g->insertConstant(pruned_wt);
368
369 if (is_aten_op) {
370 TORCH_CHECK(
371 inputs_size == 9,
372 "Expecting FP aten::embedding_bag operator to have 9 inputs");
373 // input 0 is the output of prepack op.
374 // Last input is added after we account for extra input in 4-bit case.
375 for (unsigned long i = 1; i < inputs_size - 2; ++i) {
376 qembedding_bag_inputs.push_back(embedding_bag_inputs[i]);
377 }
378 // The sparse field in the float operator denotes sparse gradients.
379 // For inference this stands for pruned weights. We currently don't support
380 // pruning in graph mode API so we set the field to 0 for inference.
381 qembedding_bag_inputs[5] = pruned_const;
382 } else {
383 TORCH_CHECK(
384 inputs_size == 12,
385 "Expecting F.embedding_bag operator to have 12 inputs");
386 qembedding_bag_inputs.push_back(embedding_bag_inputs[1]); // indices
387 qembedding_bag_inputs.push_back(embedding_bag_inputs[3]); // offsets
388 qembedding_bag_inputs.push_back(
389 embedding_bag_inputs[6]); // scale_grad_by_freq
390 qembedding_bag_inputs.push_back(zero); // mode
391 qembedding_bag_inputs.push_back(pruned_const); // pruned_weights
392 qembedding_bag_inputs.push_back(
393 embedding_bag_inputs[9]); // per_sample_weights
394 }
395
396 qembedding_bag_inputs.push_back(none); // compressed_indices_mapping
397 qembedding_bag_inputs.push_back(embedding_bag_inputs[inputs_size - 2]);
398
399 TORCH_CHECK(
400 embedding_bag_inputs[inputs_size - 1]->mustBeNone(),
401 "Expected aten::embedding_bag padding_idx input to be None");
402
403 Node* qembedding_bag =
404 g->create(Symbol::fromQualString(quant_fn), qembedding_bag_inputs);
405 if (is_aten_op) {
406 WithInsertPoint ins(embedding_bag_float_op);
407 g->insertNode(qembedding_bag);
408 // Verify that the outputs (apart from index 0) have no uses in the graph.
409 for (const auto i :
410 c10::irange(1, embedding_bag_float_op->outputs().size())) {
411 TORCH_CHECK(
412 !embedding_bag_float_op->output(i)->hasUses(),
413 "Expected aten::embedding_bag to only have use for its first output.");
414 }
415 } else {
416 g->insertNode(qembedding_bag);
417 }
418 embedding_bag_float_op->output(0)->replaceAllUsesWith(
419 qembedding_bag->output());
420 embedding_bag_float_op->removeAllInputs();
421 embedding_bag_float_op->destroy();
422 g->lint();
423 return qembedding_bag;
424}
425
426template <typename T>
427void insertQuantizationOps(
428 Module& module,
429 Value* self,
430 Node* observer,
431 bool is_per_channel,
432 T& qparams,
433 QuantType quant_type = QuantType::STATIC) {
434 Graph* g = observer->owningGraph();
435 // Observer output
436 Value* observer_out = observer->output();
437 // Inserting before insert point
438 WithInsertPoint ins(observer_out->node()->next());
439
440 std::string quantize_func;
441 if (is_per_channel) {
442 quantize_func = "quantize_per_channel";
443 } else {
444 quantize_func = "quantize_per_tensor";
445 }
446 Value* original_val = observer->input(1);
447 // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
448 Node *quant, *choose_qparams, *dequant;
449 // Temporary solution to quantize embedding_bag operators. Will be re-written
450 // once we support quantization of embedding_bag weights.
451 auto embedding_bag_name = getEmbeddingBagObsName(module, observer);
452 if (isEmbeddingBagOp(observer, embedding_bag_name)) {
453 if (isWeight(module, observer_out)) {
454 auto op_name = embedding_bag_name.value();
455 Node* dequant = insertEmbeddingBagOps(observer, op_name);
456 observer_out->replaceAllUsesWith(original_val);
457 original_val->replaceAllUsesAfterNodeWith(dequant, dequant->output());
458 } else {
459 // Special case for embedding bag operators indices input - we don't
460 // quantize the input but we still need to insert observers for it because
461 // the order of input and weight can be changed in the module code.
462 observer_out->replaceAllUsesWith(original_val);
463 }
464 return;
465 }
466 if (quant_type == QuantType::DYNAMIC) {
467 if (getObserverDtype(module, observer_out) == at::ScalarType::Half) {
468 dequant = insertFP16CastOps(g, observer_out);
469 } else if (!isWeight(module, observer_out)) {
470 auto observer_dtype = getObserverDtype(module, observer_out);
471 if (observer_dtype == at::ScalarType::QUInt8 ||
472 observer_dtype == at::ScalarType::QInt8) {
473 // For activation tensors we insert choose_qparams, quant, dequant ops.
474 Value* dtype = g->insertGetAttr(self, qparams.back());
475 std::tie(choose_qparams, quant, dequant) =
476 insertChooseQParamQuantDequant(
477 g, observer_out, dtype, at::Symbol::aten(quantize_func));
478 } else {
479 // dtype does not require quantization, e.g. float32
480 // will just remove the observer call
481 observer_out->replaceAllUsesWith(original_val);
482 return;
483 }
484 } else {
485 // For weight tensors we insert quant-dequant ops.
486 dequant = insertQuantDequantNodes(self, observer, qparams, quantize_func);
487 }
488 } else { // Static quant
489 dequant = insertQuantDequantNodes(self, observer, qparams, quantize_func);
490 }
491 observer_out->replaceAllUsesWith(original_val);
492
493 original_val->replaceAllUsesAfterNodeWith(dequant, dequant->output());
494 GRAPH_DUMP("insert nodes:", original_val->owningGraph());
495}
496
497void ReplicateChooseQParamsQuantDequant(std::shared_ptr<Graph>& graph) {
498 const PatternInfo& dynamic_quant_pattern = PatternInfo::parse_from_str(R"(
499 graph(%a, %reduce_range, %a_dtype):
500 %a_scale : float, %a_zero_point : int = aten::_choose_qparams_per_tensor(%a, %reduce_range)
501 %a_quant = aten::quantize_per_tensor(%a, %a_scale, %a_zero_point, %a_dtype)
502 %a_dequant = aten::dequantize(%a_quant)
503 return (%a_dequant) )");
504 const Graph& dynamic_quant_graph = *dynamic_quant_pattern.pattern_graph;
505
506 const auto& matches = findPatternMatches(dynamic_quant_graph, *graph);
507 if (matches.empty()) {
508 return;
509 }
510
511 const auto& vmap = dynamic_quant_pattern.vmap;
512 Value* dequant_val = vmap.at("a_dequant");
513 Node* pattern_dequant = dequant_val->node();
514 Value* quant_val = vmap.at("a_quant");
515 Node* pattern_quant = quant_val->node();
516 Value* choose_qparam_val = vmap.at("a_scale");
517 Node* pattern_choose_qparam = choose_qparam_val->node();
518
519 std::vector<DynamicQuantOps> nodes_to_rewrite;
520 std::vector<Node*> choose_qparam_nodes_to_rewrite;
521 for (const Match& match : matches) {
522 Node* matched_dequantize = match.nodes_map.at(pattern_dequant);
523 Node* matched_quantize = match.nodes_map.at(pattern_quant);
524 Node* matched_choose_qparam = match.nodes_map.at(pattern_choose_qparam);
525 if (matched_dequantize->output()->uses().size() > 1) {
526 nodes_to_rewrite.emplace_back(
527 matched_choose_qparam, matched_quantize, matched_dequantize);
528 }
529 }
530 for (const auto& nodes : nodes_to_rewrite) {
531 auto quant_node = std::get<1>(nodes);
532 auto dequant_node = std::get<2>(nodes);
533 // get input of quantize call.
534 Value* original_val = quant_node->inputs()[0];
535 Value* dequant_out = dequant_node->output();
536 Value* dtype = quant_node->inputs()[3];
537 std::vector<Use> uses = dequant_out->uses();
538 for (const Use& use : uses) {
539 auto* user = use.user;
540 WithInsertPoint ins(user);
541 auto quant_ops = insertChooseQParamQuantDequant(
542 graph.get(), original_val, dtype, quant_node->kind());
543 user->replaceInputWith(dequant_out, std::get<2>(quant_ops)->output());
544 }
545 }
546 // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
547 Node *choose_qparams, *quant, *dequant;
548 for (const auto& n : nodes_to_rewrite) {
549 std::tie(choose_qparams, quant, dequant) = n;
550 dequant->removeAllInputs();
551 quant->removeAllInputs();
552 choose_qparams->removeAllInputs();
553 }
554 for (const auto& n : nodes_to_rewrite) {
555 std::tie(choose_qparams, quant, dequant) = n;
556 dequant->destroy();
557 quant->destroy();
558 choose_qparams->destroy();
559 }
560}
561
562void RemoveRedundantDequantize(std::shared_ptr<Graph>& graph) {
563 const std::string dequantize = R"(
564 graph(%a_quant):
565 %a_dequant = aten::dequantize(%a_quant)
566 return (%a_dequant) )";
567 const std::string dequantize_replacement = R"(
568 graph(%a):
569 return (%a) )";
570 auto filter = [&](const Match& match,
571 const std::unordered_map<std::string, Value*>& vmap) {
572 const auto& match_vmap = match.values_map;
573 auto dequant_node = match_vmap.at(vmap.at("a_dequant"))->node();
574 Value* dequant_out = dequant_node->output();
575 // Values can be used multiple times in a single node
576 if (dequant_out->uses().size() != 1) {
577 return false;
578 }
579 Node* user = dequant_out->uses()[0].user;
580 return isTensorInfoNode(user);
581 };
582 SubgraphRewriter rewriter;
583 rewriter.RegisterRewritePattern(dequantize, dequantize_replacement);
584 rewriter.runOnGraph(graph, filter);
585}
586
587void RemoveRedundantQuantizationOps(std::shared_ptr<Graph>& graph) {
588 const std::string dynamic_quant_ops = R"(
589 graph(%a, %reduce_range, %a_dtype):
590 %a_scale : float, %a_zero_point : int = aten::_choose_qparams_per_tensor(%a, %reduce_range)
591 %a_quant = aten::quantize_per_tensor(%a, %a_scale, %a_zero_point, %a_dtype)
592 %a_dequant = aten::dequantize(%a_quant)
593 return (%a_dequant) )";
594 const std::string dynamic_quant_replacement = R"(
595 graph(%a, %reduce_range, %a_dtype):
596 return (%a) )";
597 auto filter = [&](const Match& match,
598 const std::unordered_map<std::string, Value*>& vmap) {
599 const auto& match_vmap = match.values_map;
600 auto dequant_node = match_vmap.at(vmap.at("a_dequant"))->node();
601 Value* dequant_out = dequant_node->output();
602 // Values can be used multiple times in a single node
603 if (dequant_out->uses().size() != 1) {
604 return false;
605 }
606 Node* user = dequant_out->uses()[0].user;
607 return !nodeQuantizable(user, QuantType::DYNAMIC);
608 };
609 SubgraphRewriter rewriter;
610 rewriter.RegisterRewritePattern(dynamic_quant_ops, dynamic_quant_replacement);
611 rewriter.runOnGraph(graph, filter);
612}
613
614void ReplicateClampScalarArgs(std::shared_ptr<Graph>& graph) {
615 std::stack<Block*> blocks_to_visit;
616 std::unordered_set<Node*> scalar_nodes_to_rewrite;
617 ;
618 blocks_to_visit.push(graph->block());
619 while (!blocks_to_visit.empty()) {
620 Block* b = blocks_to_visit.top();
621 blocks_to_visit.pop();
622 for (Node* n : b->nodes()) {
623 for (Value* output : n->outputs()) {
624 if (getClampScalarInputUse(output) && output->uses().size() > 1) {
625 scalar_nodes_to_rewrite.insert(n);
626 }
627 }
628 for (Block* subblock : n->blocks()) {
629 blocks_to_visit.push(subblock);
630 }
631 }
632 }
633
634 for (Node* n : scalar_nodes_to_rewrite) {
635 const std::vector<Use> uses = n->output()->uses();
636 for (const auto& use : uses) {
637 Node* user = use.user;
638 WithInsertPoint ins(user);
639 Node* cloned_node = graph->createClone(n, [](Value* v) { return v; });
640 graph->insertNode(cloned_node);
641 user->replaceInput(use.offset, cloned_node->output());
642 }
643 }
644
645 for (Node* n : scalar_nodes_to_rewrite) {
646 n->removeAllInputs();
647 }
648
649 for (Node* n : scalar_nodes_to_rewrite) {
650 n->destroy();
651 }
652}
653
654void checkCalculateQParamsResult(const IValue& qparams) {
655 TORCH_CHECK(
656 qparams.isTuple(),
657 "`calculate_qparams` function is expected to return a "
658 "Tuple, but got:",
659 qparams.tagKind());
660 auto tp = qparams.toTuple();
661 TORCH_CHECK(
662 tp->elements().size() == 2,
663 "`calculate_qparams` function is expected to return a "
664 "Tuple of size 2, got Tuple of size ",
665 tp->elements().size());
666 // Expect first two elements of the tuple to be Tensor
667 for (const auto i : c10::irange(2)) {
668 TORCH_CHECK(
669 tp->elements()[i].isTensor(),
670 "Element of Tuple is expected to be Tensor, but element ",
671 i,
672 " has type: ",
673 tp->elements()[i].tagKind());
674 }
675}
676
677class SubGraphCloneHelper {
678 public:
679 // Given a list of nodes, build a graph corresponding to these nodes.
680 // User should make sure to run this graph with expected input.
681 std::unique_ptr<GraphFunction> buildGraphFromNodes(
682 const std::vector<Node*>& nodes,
683 const std::string& name);
684
685 // Given a list of nodes in src, produce a Graph with these nodes.
686 void buildObserverSubgraph(
687 const std::vector<Node*>& src,
688 std::shared_ptr<Graph> dest);
689
690 private:
691 // Clone node in the destination Graph g.
692 void cloneNodeInGraph(
693 Node* node,
694 std::shared_ptr<Graph>& g,
695 std::unordered_map<Value*, Value*>& remap_values);
696};
697
698class InsertQuantDeQuantHelper {
699 public:
700 InsertQuantDeQuantHelper(QuantType quant_type, bool debug)
701 : quant_type_(quant_type), debug_(debug) {}
702
703 void run(Module& module, const std::string& method_name);
704
705 void runForOnDevicePTQ(Module& module, const std::string& method_name);
706
707 // Cleanup observer nodes from graph and observer modules
708 // from module object and ClassType
709 void cleanup(Module& module);
710
711 // Cleanup observer nodes only but not modules
712 // This is for ondevice PTQ
713 void removeObserverNodes(Module& m);
714
715 // In order to propagate quantization ops through the ops that doesn't
716 // require observation, we'll first inline the graph, and call the
717 // PropgateQuantizationOps pass
718 void propagateQuantizationOps(Module& module);
719
720 // Used for dynamic quantization to selectively run the weight observers.
721 // It extracts the subgraph corresponding to the weight and runs it with
722 // the module instance.
723 void runWeightObserver(Module& module, const std::string& method_name);
724
725 private:
726 ModuleMethodVector getInvokedMethods(
727 Module& module,
728 const std::string& method_name);
729
730 // Get quantization parameter map of the given Value in Graph
731 // by searching for observer module of the value and extract the
732 // quantization parameters from the observer module
733 std::tuple<c10::QScheme, QParamVector> getQSchemeAndQParamVector(
734 script::Module& module,
735 Node* n);
736 QuantOpParams insertCalculateQParams(
737 script::Module& module,
738 Graph* g,
739 Node* n);
740
741 void checkQScheme(Graph* g, c10::QScheme qscheme) {
742 if (qscheme_for_graph_.count(g)) {
743 // FIXME[T110786721]: This check was broken before nevery failing.
744 // Once fixed, this check triggers and fails tests.
745 // Fix the tests that enabling this check produce!
746 /*
747 TORCH_CHECK(
748 qscheme_for_graph_.at(g) == qscheme,
749 "Quantizing same graph with different types of "
750 "QSchemes is not supported.\n",
751 " Expecting:",
752 c10::toString(qscheme_for_graph_.at(g)),
753 " Got:",
754 c10::toString(qscheme));
755 */
756 } else {
757 qscheme_for_graph_[g] = toAffine(qscheme);
758 }
759 }
760
761 void collectObserverNodesAndValueToQuantize(Module& module, Value*);
762 void cleanup(Module& module, Graph* g);
763 void removeObserverNodes(Graph* g);
764 void quantizeTensors(Module& module, Graph* g, Value* self);
765 void insertCalculateQParamsAndQuantizationOps(
766 Module& module,
767 Graph* g,
768 Value* self);
769
770 // Function that extracts and runs the weight observer in a separate
771 // subgraph.
772 void extractAndRunWeightObserver(
773 Module& module,
774 Value* self,
775 Value* weight_value);
776
777 // Recursively find the nodes that produce the value and add to subgraph.
778 void findSubgraph(Value* self, Value* v, std::vector<Node*>& weight_subgraph);
779
780 // Quantizes two types of general ops(ops that works both for floating point
781 // and quantized Tensors) in this pass
782 // for ops that only manipulates shape, e.g. flatten, quantization
783 // is done by swapping with previous dequantize op
784 // for ops that manipulates values of Tensor, e.g. average pool, quantization
785 // is done by inserting quant/dequant ops after the op
786 // also has a special handling of clamp/hardtanh
787 void propagateQuantizationOps(Block* block);
788
789 // Propagate quantization parameters from other quantized tensors
790 void propagateQParams(
791 Value* original_output,
792 const std::vector<Value*>& inputs,
793 bool is_scalar = false,
794 const c10::optional<std::tuple<c10::QScheme, QParamVector>>& qparams_opt =
795 c10::nullopt);
796
797 bool isQuantized(Value* v) {
798 return quantized_values_.count(v) != 0;
799 }
800
801 std::unordered_map<Graph*, std::vector<std::string>>
802 observer_modules_to_remove_;
803 // We only remove observer module attributes from type in the
804 // first encounter of the graph, after that since the attributes
805 // is already removed from the ClassType, we'll use the list of slot index to
806 // replay this removal
807 std::unordered_map<Graph*, std::vector<int>> removed_observer_slots_;
808 std::unordered_map<Graph*, std::vector<Node*>> nodes_to_destroy_;
809 // Map from Graph to observer node, we can use observer node to
810 // get the information of original value that's been observed and
811 // the quantization parameters
812 std::unordered_map<Graph*, std::vector<Node*>> observer_nodes_for_graph_;
813 // A map from qparam name (e.g. _scale) to the attribute name in
814 // the module(e.g. weight_scale_0)
815 std::unordered_map<Node*, std::unordered_map<std::string, std::string>>
816 qparam_name_map_for_node_;
817 // Record qscheme for every graph, this is for checking
818 // each graph is only quantized with one type of QScheme
819 std::unordered_map<Graph*, c10::QScheme> qscheme_for_graph_;
820
821 // Set of quantized values, so that we quantize each value only
822 // once
823 std::unordered_set<Value*> quantized_values_;
824
825 // Map from original weight value to GraphFunction corresponding to the
826 // subgraph that includes the weight observer and dependent nodes.
827 std::unordered_map<Value*, std::unique_ptr<GraphFunction>>
828 weight_to_graph_fn_;
829
830 QuantType quant_type_ = QuantType::STATIC;
831 bool debug_ = false;
832};
833
834void InsertQuantDeQuantHelper::collectObserverNodesAndValueToQuantize(
835 Module& module,
836 Value* v) {
837 auto* g = v->owningGraph();
838 auto observer_name = findObserverName(v);
839 if (!observer_name) {
840 return;
841 }
842 observer_modules_to_remove_[g].push_back(observer_name.value());
843
844 Node* observer = v->node();
845 TORCH_INTERNAL_ASSERT(
846 observer->kind() == prim::CallMethod &&
847 observer->s(attr::name) == "forward" &&
848 observer->inputs()[0]->node()->kind() == prim::GetAttr &&
849 observer->inputs()[0]->node()->s(attr::name) == observer_name);
850
851 // Observer forward call node
852 nodes_to_destroy_[g].push_back(observer);
853 // GetAttr node for observer module
854 nodes_to_destroy_[g].push_back(observer->inputs()[0]->node());
855 observer_nodes_for_graph_[g].push_back(observer);
856}
857
858void InsertQuantDeQuantHelper::removeObserverNodes(Module& module) {
859 for (auto& method : module.get_methods()) {
860 removeObserverNodes(method.graph().get());
861 }
862 for (Module m : module.children()) {
863 removeObserverNodes(m);
864 }
865}
866
867void InsertQuantDeQuantHelper::removeObserverNodes(Graph* g) {
868 if (nodes_to_destroy_.count(g)) {
869 for (auto& n : nodes_to_destroy_.at(g)) {
870 n->removeAllInputs();
871 }
872 for (auto& n : nodes_to_destroy_.at(g)) {
873 n->destroy();
874 }
875 nodes_to_destroy_.at(g).clear();
876 }
877}
878
879void InsertQuantDeQuantHelper::cleanup(Module& module) {
880 for (auto& method : module.get_methods()) {
881 cleanup(module, method.graph().get());
882 }
883 for (Module m : module.children()) {
884 cleanup(m);
885 }
886}
887
888void InsertQuantDeQuantHelper::cleanup(Module& module, Graph* g) {
889 GRAPH_DUMP("Before Remove Observers:", g);
890 removeObserverNodes(g);
891
892 // 1. If we have seen this graph before, this means the observer
893 // attributes has been removed from the type(see step 2) but the slot
894 // index of these attributes are kept in the list, we'll replay the observer
895 // slots removal using these slot indexes
896 if (removed_observer_slots_.count(g)) {
897 for (auto slot : removed_observer_slots_.at(g)) {
898 module._ivalue()->unsafeRemoveSlot(slot);
899 }
900 }
901
902 // 2. Remove observer modules from last one to first one in order to
903 // reduce the time complexity, assuming all the observer modules
904 // are added after the existing modules, we'll have complexity of
905 // O(N) where N is number of observer modules with this optimization
906 if (observer_modules_to_remove_.count(g)) {
907 auto& observers = observer_modules_to_remove_.at(g);
908 for (int64_t i = observers.size() - 1; i >= 0; --i) {
909 auto observer_name = observers[i];
910 GRAPH_DEBUG("Trying to remove: ", observer_name);
911 if (module.type()->hasAttribute(observer_name)) {
912 // We record the slot index here in order to replay the
913 // slot removal in other objects that's sharing the ClassType
914 // since we're going to remove attribute in the ClassType here
915 removed_observer_slots_[g].push_back(
916 module.type()->getAttributeSlot(observer_name));
917 module._ivalue()->unsafeRemoveAttr(observer_name);
918 module.type()->unsafeRemoveAttribute(observer_name);
919 }
920 }
921 observers.clear();
922 }
923 GRAPH_DUMP("After remove observers :", g);
924}
925
926void SubGraphCloneHelper::cloneNodeInGraph(
927 Node* node,
928 std::shared_ptr<Graph>& g,
929 std::unordered_map<Value*, Value*>& remap_old_to_new) {
930 auto* block = g->block();
931 auto value_fn = [&](Value* v) {
932 if (remap_old_to_new.count(v) == 0) {
933 auto new_value = g->block()->addInput();
934 remap_old_to_new[v] = new_value;
935 new_value->copyMetadata(v);
936 return new_value;
937 } else {
938 return remap_old_to_new[v];
939 }
940 };
941
942 auto new_node = block->appendNode(g->createClone(node, value_fn));
943 for (size_t i = 0; i < node->outputs().size(); ++i) {
944 auto oo = node->outputs()[i];
945 auto no = new_node->outputs()[i];
946 remap_old_to_new[oo] = no;
947 }
948}
949
950void SubGraphCloneHelper::buildObserverSubgraph(
951 const std::vector<Node*>& weight_subgraph,
952 std::shared_ptr<Graph> dest_graph) {
953 std::unordered_map<Value*, Value*> remap_old_to_new;
954 // Build weight subgraph
955 for (auto n : weight_subgraph) {
956 cloneNodeInGraph(n, dest_graph, remap_old_to_new);
957 }
958 LintGraph(dest_graph);
959
960 // Add last node output value as subgraph output.
961 for (auto out : weight_subgraph.back()->outputs()) {
962 dest_graph->registerOutput(remap_old_to_new[out]);
963 }
964 GRAPH_DUMP("New weight observer subgraph: ", dest_graph);
965}
966
967std::unique_ptr<GraphFunction> SubGraphCloneHelper::buildGraphFromNodes(
968 const std::vector<Node*>& nodes,
969 const std::string& name) {
970 auto observer_subgraph = std::make_shared<Graph>();
971 auto build_observer_graph = [&](GraphFunction& func) {
972 buildObserverSubgraph(nodes, func.graph());
973 };
974 return torch::make_unique<GraphFunction>(
975 name, observer_subgraph, build_observer_graph);
976}
977
978void InsertQuantDeQuantHelper::findSubgraph(
979 Value* self,
980 Value* input_val,
981 std::vector<Node*>& weight_subgraph) {
982 Node* node = input_val->node();
983 weight_subgraph.push_back(node);
984 const auto& inputs = node->inputs().vec();
985 for (auto v : inputs) {
986 if (!hitGraphInput(v)) {
987 findSubgraph(self, v, weight_subgraph);
988 } else {
989 TORCH_CHECK(
990 v == self,
991 "Unexpected value found when handling weight value "
992 " in findSubgraph, traced back to:",
993 v->debugName(),
994 " which is not self:",
995 self->debugName());
996 }
997 }
998}
999
1000void InsertQuantDeQuantHelper::extractAndRunWeightObserver(
1001 Module& module,
1002 Value* self,
1003 Value* weight_value) {
1004 std::vector<Node*> weight_subgraph;
1005 // If the graph was already visited, return the GraphFunction directly.
1006 // Multiple module instances can share the same graph code, so we don't need
1007 // to re-run the extraction process.
1008 if (weight_to_graph_fn_.count(weight_value) == 0) {
1009 // Extract the subgraph nodes.
1010 findSubgraph(self, weight_value, weight_subgraph);
1011
1012 // Reverse to traverse subgraph in correct direction
1013 std::reverse(weight_subgraph.begin(), weight_subgraph.end());
1014
1015 // Build the graph using the nodes found from the weight observer.
1016 SubGraphCloneHelper o;
1017 std::unique_ptr<GraphFunction> func =
1018 o.buildGraphFromNodes(weight_subgraph, "observer_subgraph");
1019 weight_to_graph_fn_[weight_value] = std::move(func);
1020 }
1021 Stack module_inp = {module._ivalue()};
1022 // Run the graph with the module input.
1023 weight_to_graph_fn_[weight_value]->run(module_inp);
1024}
1025
1026void InsertQuantDeQuantHelper::quantizeTensors(
1027 Module& module,
1028 Graph* g,
1029 Value* self) {
1030 if (!observer_nodes_for_graph_.count(g)) {
1031 return;
1032 }
1033 for (auto* n : observer_nodes_for_graph_.at(g)) {
1034 auto* original_value = n->input(1);
1035 auto tp = getQSchemeAndQParamVector(module, n);
1036 auto qscheme = std::get<0>(tp);
1037 auto qparam_map = std::get<1>(tp);
1038 checkQScheme(g, qscheme);
1039 std::vector<std::string> qparam_names;
1040 for (auto& pr : qparam_map) {
1041 const auto& name = pr.first;
1042 const auto& qparam = pr.second;
1043 size_t uid = 0;
1044 auto qparam_name =
1045 original_value->debugName() + name + "_" + c10::to_string(uid++);
1046 while (module.hasattr(qparam_name)) {
1047 qparam_name =
1048 original_value->debugName() + name + "_" + c10::to_string(uid++);
1049 }
1050 qparam_name_map_for_node_[n][name] = qparam_name;
1051 module.register_attribute(qparam_name, qparam.type(), qparam);
1052 qparam_names.push_back(qparam_name);
1053 }
1054 insertQuantizationOps(
1055 module, self, n, isPerChannel(qscheme), qparam_names, quant_type_);
1056 }
1057}
1058
1059std::tuple<c10::QScheme, QParamVector> InsertQuantDeQuantHelper::
1060 getQSchemeAndQParamVector(script::Module& module, Node* n) {
1061 // TODO: refactor findObserverName to take Node* as input
1062 Value* v = n->output();
1063 TORCH_INTERNAL_ASSERT(
1064 v->type()->isSubtypeOf(*TensorType::get()),
1065 "Expected output of observer node to be Tensor");
1066 auto observer_name = findObserverName(v);
1067 TORCH_INTERNAL_ASSERT(
1068 observer_name,
1069 "getQSchemeAndParamMap expects the corresponding observer for ",
1070 v->debugName(),
1071 " exists.");
1072 QParamVector qparams;
1073 c10::QScheme qscheme = c10::kPerTensorAffine;
1074
1075 auto observer_module = module.attr(observer_name.value()).toModule();
1076 auto scalar_type = observer_module.attr("dtype");
1077 if (isPlaceholderObserver(n->input(0))) {
1078 // get compute_dtype for dynamic quantization
1079 if (observer_module.hasattr("is_dynamic") &&
1080 observer_module.attr("is_dynamic").toBool()) {
1081 qparams.emplace_back(kScalarType, observer_module.attr("dtype"));
1082 }
1083 return std::make_tuple(qscheme, std::move(qparams));
1084 } else if (scalar_type == at::ScalarType::Half) {
1085 return std::make_tuple(qscheme, std::move(qparams));
1086 }
1087 auto calculate_qparams = observer_module.get_method("calculate_qparams");
1088 IValue result = calculate_qparams(std::vector<IValue>());
1089 checkCalculateQParamsResult(result);
1090 TORCH_CHECK(
1091 scalar_type.toScalarType() != at::ScalarType::Undefined,
1092 "dtype of observer can't be undefined");
1093 auto tp = result.toTuple();
1094 at::Tensor scale = tp->elements()[0].toTensor().to(at::kFloat);
1095 at::Tensor zero_point = tp->elements()[1].toTensor().to(at::kInt);
1096 // quantization parameters should appear in the same order as
1097 // the argument for quantize_per_tensor/quantize_per_channel function
1098
1099 qscheme = observer_module.attr("qscheme").toQScheme();
1100 if (isPerChannel(qscheme)) {
1101 auto axis = observer_module.attr("ch_axis");
1102 qparams.emplace_back("_scale", scale);
1103 qparams.emplace_back("_zero_point", zero_point);
1104 qparams.emplace_back("_axis", axis.toInt());
1105 } else {
1106 qparams.emplace_back("_scale", scale.item<double>());
1107 qparams.emplace_back("_zero_point", zero_point.item<int64_t>());
1108 }
1109 qparams.emplace_back(kScalarType, scalar_type);
1110 return std::make_tuple(qscheme, std::move(qparams));
1111}
1112
1113ModuleMethodVector InsertQuantDeQuantHelper::getInvokedMethods(
1114 Module& module,
1115 const std::string& method_name) {
1116 auto graph = module.get_method(method_name).graph();
1117
1118 ModuleMethodVector invoked_methods;
1119 std::stack<Block*> blocks_to_visit;
1120 blocks_to_visit.push(graph->block());
1121 while (!blocks_to_visit.empty()) {
1122 Block* b = blocks_to_visit.top();
1123 blocks_to_visit.pop();
1124 for (Node* n : b->nodes()) {
1125 if (n->kind() == prim::CallMethod) {
1126 auto module_instance = n->inputs()[0];
1127 auto module_method_name = n->s(attr::name);
1128 c10::optional<Module> m;
1129 // calling method on self
1130 if (module_instance == graph->inputs()[0]) {
1131 m = module;
1132 } else if (
1133 module_instance->node()->kind() == prim::GetAttr &&
1134 module_instance->node()->s(attr::name).find("_observer_") ==
1135 std::string::npos) {
1136 m = getInvokedModuleOpt(module, n, graph->inputs()[0]);
1137 }
1138 if (m) {
1139 invoked_methods.emplace_back(*m, module_method_name);
1140 }
1141 }
1142
1143 for (Block* subblock : n->blocks()) {
1144 blocks_to_visit.push(subblock);
1145 }
1146 }
1147 }
1148 return invoked_methods;
1149}
1150
1151void InsertQuantDeQuantHelper::propagateQParams(
1152 Value* original_output,
1153 const std::vector<Value*>& inputs,
1154 bool is_scalar,
1155 const c10::optional<std::tuple<c10::QScheme, QParamVector>>& qparams_opt) {
1156 Node* n = original_output->node();
1157 Graph* graph = n->owningGraph();
1158 if (is_scalar) {
1159 // convert Scalar to Tensor
1160 n = insertScalarToTensor(graph, original_output);
1161 original_output = n->output();
1162 }
1163 // for ops like average pool, we'll insert quant dequant after the op
1164 // We'll assume the tensor is a PerTensorAffine quantized Tensor for
1165 // now, and may generalize later if this becomes an issue
1166 TORCH_INTERNAL_ASSERT(
1167 inputs.size() == 1, "Expecting single input for the aten function");
1168 // input of the dequantize node
1169 Value* quantized_input = inputs[0]->node()->input(0);
1170 // insert ops after the general op
1171 Node* quantized_input_node = quantized_input->node();
1172 // Insert after the node that is later in topological order
1173 WithInsertPoint ins(
1174 quantized_input_node->isAfter(n) ? quantized_input_node->next()
1175 : n->next());
1176 std::vector<Value*> quant_inputs;
1177 auto quant_kind = Symbol::aten("quantize_per_tensor");
1178 if (qparams_opt.has_value()) {
1179 quant_inputs = {original_output};
1180 auto qscheme = std::get<0>(*qparams_opt);
1181 auto qparams = std::get<1>(*qparams_opt);
1182 if (isPerChannel(qscheme)) {
1183 quant_kind = Symbol::aten("quantize_per_channel");
1184 }
1185 for (const auto& qparam : qparams) {
1186 Value* qparam_val = graph->insertConstant(qparam.second);
1187 qparam_val->setDebugName(quantized_input->debugName() + qparam.first);
1188 quant_inputs.push_back(qparam_val);
1189 }
1190 } else {
1191 // Only per tensor affine quantized tensor is supported in this case
1192 // get quantization parameters from previous quantized op
1193 Node* scale = insertQParam(
1194 graph,
1195 quantized_input,
1196 at::Symbol::aten("q_scale"),
1197 FloatType::get(),
1198 "q_scale");
1199 Node* zero_point = insertQParam(
1200 graph,
1201 quantized_input,
1202 at::Symbol::aten("q_zero_point"),
1203 IntType::get(),
1204 "q_zero_point");
1205 Node* dtype = insertQParam(
1206 graph, quantized_input, prim::dtype, IntType::get(), "dtype");
1207 quant_inputs = {
1208 original_output,
1209 scale->output(),
1210 zero_point->output(),
1211 dtype->output()};
1212 }
1213 Node* quant = insertQuant(
1214 graph, quant_inputs, quant_kind, original_output->debugName() + ".quant");
1215 Value* quantized_output = quant->output();
1216 // replace uses of original output of the general op with quantized
1217 // output
1218 original_output->replaceAllUsesAfterNodeWith(quant, quantized_output);
1219 const auto& outputs =
1220 insertDeQuantForAllUse(graph, quantized_output, quantized_output);
1221 for (auto* output : outputs) {
1222 if (is_scalar) {
1223 // Convert the dequantized Tensor back to Scalar
1224 Node* item = insertItem(graph, output, FloatType::get());
1225 Value* scalar = item->output();
1226 output->replaceAllUsesAfterNodeWith(item, scalar);
1227 output = scalar;
1228 }
1229 quantized_values_.insert(output);
1230 }
1231}
1232
1233void removeDequantizeFromInputs(const std::unordered_set<Value*>& inputs) {
1234 // Delete dequantize node, we have one dequantize
1235 // for each use of the value
1236 for (auto* dequantized_val : inputs) {
1237 auto* dequantize_node = dequantized_val->node();
1238 TORCH_INTERNAL_ASSERT(
1239 dequantized_val->uses().size() == 1,
1240 "Expect to have one dequantize node for each use");
1241 // Replace useses of dequantized_val with the input of
1242 // dequantize node
1243 dequantized_val->replaceAllUsesWith(dequantize_node->inputs()[0]);
1244 dequantize_node->removeAllInputs();
1245 dequantize_node->destroy();
1246 }
1247}
1248
1249// Check if we need to propagate the quantization ops from input to
1250// output
1251c10::optional<std::vector<Value*>> getDequantizedInputs(Value* output) {
1252 auto inputs = getPassThroughInputs(output);
1253 if (!inputs.empty()) {
1254 // note that we don't need to recursively check for prim::If
1255 // here because if all inputs of a prim::If is dequantized
1256 // the dequantize will be factored out before we get to this
1257 // point
1258 bool is_dequantized = true;
1259 for (auto* input : inputs) {
1260 GRAPH_DEBUG(
1261 "checking if input:",
1262 input->debugName(),
1263 " in node:",
1264 *input->node(),
1265 "is quantized");
1266 is_dequantized &= input->node()->kind() == Symbol::aten("dequantize");
1267 }
1268 if (is_dequantized) {
1269 return inputs;
1270 }
1271 }
1272 return c10::nullopt;
1273}
1274
1275void InsertQuantDeQuantHelper::propagateQuantizationOps(Block* block) {
1276 for (Node* n : block->nodes()) {
1277 if (n->kind() == prim::If) {
1278 for (Block* subblock : n->blocks()) {
1279 propagateQuantizationOps(subblock);
1280 }
1281 if (n->outputs().empty()) {
1282 continue;
1283 }
1284 if (n->outputs().size() > 1) {
1285 // Factoring out dequantize for if blocks with multiple outputs
1286 // is not supported right now
1287 continue;
1288 }
1289 }
1290 if (isSingleInputGeneralValueAtenFunction(n)) {
1291 for (auto* output : n->outputs()) {
1292 if (isQuantized(output)) {
1293 continue;
1294 }
1295 if (auto inputs = getDequantizedInputs(output)) {
1296 propagateQParams(output, *inputs);
1297 if (isClamp(n)) {
1298 for (size_t i = 1; i <= 2; ++i) {
1299 // propagate qparams for min and max scalar arguments
1300 // for aten::clamp/aten::hardtanh
1301 propagateQParams(n->input(i), *inputs, /* is_scalar */ true);
1302 }
1303 }
1304 }
1305 }
1306 } else if (auto qparams_opt = getFixedQParams(n)) {
1307 for (auto* output : n->outputs()) {
1308 if (isQuantized(output)) {
1309 continue;
1310 }
1311 if (auto inputs = getDequantizedInputs(output)) {
1312 propagateQParams(output, *inputs, /* is_scalar */ false, qparams_opt);
1313 }
1314 }
1315 } else {
1316 // For ops that are quantized by propagating dequantize ops,
1317 // e.g. flatten we need to
1318 // 1. check if we need to propagate dequantize op
1319 // 2. remove the dequantize ops from inputs
1320 // 3. insert dequantize for all outputs
1321 // to make sure it works for ops with multiple outputs
1322 // since removing dequantize from inputs is mutating the graph
1323 // and it will affect future checks for whether all the inputs
1324 // has been quantized or not(since currently we just check if
1325 // the value is produced by dequantize op to decide if the value
1326 // is quantized or not
1327 // list of dequantized input values
1328 std::unordered_set<Value*> dequantized_inputs;
1329 std::vector<Value*> outputs_to_dequantize;
1330 // 1. collect dequantized inputs and outputs we need to dequantize
1331 for (auto* output : n->outputs()) {
1332 if (isQuantized(output)) {
1333 continue;
1334 }
1335 if (auto inputs = getDequantizedInputs(output)) {
1336 std::copy(
1337 inputs->begin(),
1338 inputs->end(),
1339 std::inserter(dequantized_inputs, dequantized_inputs.end()));
1340 outputs_to_dequantize.push_back(output);
1341 }
1342 }
1343 // 2. remove the dequantize ops from inputs
1344 removeDequantizeFromInputs(dequantized_inputs);
1345 // 3. insert dequantize op for outpus
1346 for (auto* output : outputs_to_dequantize) {
1347 insertDeQuantForAllUse(output->owningGraph(), output, output);
1348 }
1349 }
1350
1351 if (isBinaryOpWithScalarInput(n)) {
1352 // Print warning for add_scalar/mul_scalar when debug is enabled
1353 // since the quantization parameter for these ops depends on
1354 // input and it's too complicated to encode the equations in
1355 // the IR:
1356 // https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/quantized/cpu/BinaryOps.cpp#L64-L74
1357 if (debug_) {
1358 TORCH_WARN_ONCE(
1359 "debug option for add_scalar and mul_scalar is not supported, "
1360 "please don't use debug option for models that uses these ops.");
1361 }
1362 }
1363 }
1364}
1365
1366void InsertQuantDeQuantHelper::runWeightObserver(
1367 Module& module,
1368 const std::string& method_name) {
1369 if (quant_type_ != QuantType::DYNAMIC) {
1370 return;
1371 }
1372
1373 for (auto& invoked_methods : getInvokedMethods(module, method_name)) {
1374 auto& invoked_module = std::get<0>(invoked_methods);
1375 const auto& invoked_method_name = std::get<1>(invoked_methods);
1376 runWeightObserver(invoked_module, invoked_method_name);
1377 }
1378 Method method = module.get_method(method_name);
1379 auto graph = method.graph();
1380 Value* self = graph->inputs()[0];
1381
1382 std::vector<Value*> weight_values;
1383 // Visit all blocks in the current graph to find weight values.
1384 std::stack<Block*> blocks_to_visit;
1385 blocks_to_visit.push(graph->block());
1386 while (!blocks_to_visit.empty()) {
1387 Block* b = blocks_to_visit.top();
1388 blocks_to_visit.pop();
1389 for (auto n : b->nodes()) {
1390 for (Value* v : n->outputs()) {
1391 if (!v->type()->isSubtypeOf(*TensorType::get())) {
1392 continue;
1393 }
1394 auto observer_name = findObserverName(v);
1395 if (observer_name && isWeight(module, v)) {
1396 weight_values.push_back(v);
1397 }
1398 }
1399 for (Block* subblock : n->blocks()) {
1400 blocks_to_visit.push(subblock);
1401 }
1402 }
1403 }
1404 // For all the observed weight values, find the corresponding subgraph that
1405 // contributes to the weight tensor, and run that subgraph to observe the
1406 // weight.
1407 for (const auto& v : weight_values) {
1408 extractAndRunWeightObserver(module, self, v);
1409 }
1410}
1411
1412void InsertQuantDeQuantHelper::run(
1413 Module& module,
1414 const std::string& method_name) {
1415 for (auto& invoked_methods : getInvokedMethods(module, method_name)) {
1416 auto& invoked_module = std::get<0>(invoked_methods);
1417 const auto& invoked_method_name = std::get<1>(invoked_methods);
1418 run(invoked_module, invoked_method_name);
1419 }
1420
1421 Method method = module.get_method(method_name);
1422 auto graph = method.graph();
1423 // We only need to register new parameters if the graph has
1424 // been quantized before
1425 // TODO: dedup this part with code in quantizeTensors
1426 if (observer_nodes_for_graph_.count(graph.get())) {
1427 for (auto* n : observer_nodes_for_graph_.at(graph.get())) {
1428 auto tp = getQSchemeAndQParamVector(module, n);
1429 checkQScheme(graph.get(), std::get<0>(tp));
1430 auto qparam_map = std::get<1>(tp);
1431 // We check the size here because for some observers (like
1432 // PlaceholderObserver) the qparams might be empty.
1433 if (!qparam_map.empty()) {
1434 TORCH_INTERNAL_ASSERT(
1435 qparam_name_map_for_node_.count(n),
1436 "Expected to have a qparam_name_map for node:",
1437 *n);
1438 auto qparam_name_map = qparam_name_map_for_node_.at(n);
1439 for (auto& pr : qparam_map) {
1440 const auto& name = pr.first;
1441 const auto& qparam = pr.second;
1442 module._ivalue()->setAttr(qparam_name_map.at(name), qparam);
1443 }
1444 }
1445 }
1446 return;
1447 }
1448
1449 // prim::Param nodes do not belong to the graph. Hence the Insert
1450 // point is the beginning of graph node. This also safe guards against
1451 // observing a potentially mutated value due to some in-place operation
1452 std::vector<Value*> input_values;
1453 for (const auto idx : c10::irange(1, method.num_inputs())) {
1454 auto& v = graph->inputs()[idx];
1455 if (v->type()->isSubtypeOf(*TensorType::get())) {
1456 input_values.push_back(v);
1457 }
1458 }
1459
1460 std::stack<Block*> blocks_to_visit;
1461 blocks_to_visit.push(graph->block());
1462 while (!blocks_to_visit.empty()) {
1463 Block* b = blocks_to_visit.top();
1464 blocks_to_visit.pop();
1465 for (auto it = b->nodes().begin(), end = b->nodes().end(); it != end;) {
1466 Node* n = *it++;
1467 for (Value* v : n->outputs()) {
1468 if (!v->type()->isSubtypeOf(*TensorType::get())) {
1469 continue;
1470 }
1471 collectObserverNodesAndValueToQuantize(module, v);
1472 }
1473
1474 for (Block* subblock : n->blocks()) {
1475 blocks_to_visit.push(subblock);
1476 }
1477 }
1478 }
1479
1480 for (Value* v : input_values) {
1481 collectObserverNodesAndValueToQuantize(module, v);
1482 }
1483 GRAPH_DUMP("Before Quantize Tensors:", graph);
1484 Value* self = graph->inputs()[0];
1485 quantizeTensors(module, graph.get(), self);
1486 GRAPH_DUMP("After Quantize Tensors:", graph);
1487}
1488
1489void InsertQuantDeQuantHelper::propagateQuantizationOps(Module& module) {
1490 SwapFunctionalLinear(module);
1491 auto graph = module.get_method("forward").graph();
1492 Inline(*graph);
1493 ConstantPropagation(graph);
1494 ReplicateChooseQParamsQuantDequant(graph);
1495 RemoveRedundantQuantizationOps(graph);
1496 ReplicateQuant(graph);
1497 ReplicateDeQuant(graph);
1498 // TODO: add filter to the clamp patterns and remove this pass
1499 ReplicateClampScalarArgs(graph);
1500 propagateQuantizationOps(graph->block());
1501 RemoveRedundantDequantize(graph);
1502}
1503
1504// Insert quant and dequant nodes into the graph for both static and dynamic
1505// quant.
1506template <>
1507Node* insertQuantDequantNodes<QuantOpParams>(
1508 Value* self,
1509 Node* observer,
1510 QuantOpParams& qparams,
1511 const std::string& quantize_func) {
1512 (void)self;
1513 Graph* g = observer->owningGraph();
1514 Value* observer_out = observer->output();
1515 Value* original_val = observer->input(1);
1516 std::vector<Value*> inputs;
1517 // + 1 for tensor to be quantized
1518 inputs.reserve(qparams.qparams.size() + 1);
1519 inputs.push_back({observer_out});
1520 for (const auto& qparam_values : qparams.qparams) {
1521 inputs.push_back(qparam_values);
1522 }
1523 Node* quant = insertQuant(
1524 g,
1525 inputs,
1526 at::Symbol::aten(quantize_func),
1527 original_val->debugName() + ".quant");
1528 // Have to make sure that quant node appears after the values it depends on.
1529 for (Value* v : inputs) {
1530 quant->moveAfter(v->node());
1531 }
1532 Node* dequant = insertDeQuant(g, quant->output(), original_val);
1533 dequant->moveAfter(quant);
1534 return dequant;
1535}
1536
void checkCalculateQParamsResultTypes(const Node* out) {
  TORCH_CHECK(
      out->outputs().size() == 2,
      "calculate_qparams should produce output of size 2 (scale, zero_point).");
  Value* scale = out->output(0);
  Value* zp = out->output(1);
  TORCH_CHECK(
      scale->type()->expect<TensorType>(),
      "Scale value should be of Tensor type.");
  TORCH_CHECK(
      zp->type()->expect<TensorType>(),
      "Zero-point value should be of Tensor type.");
}

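// Inserts, at the current insertion point, a call to the observer's
// calculate_qparams method and collects the Values needed for quantization.
// A rough sketch of the inserted IR (illustrative, not exact):
//   %observer = prim::GetAttr[name=<observer attr>](%self)
//   %qparams = prim::CallMethod[name="calculate_qparams"](%observer)
//   %scale, %zero_point = prim::TupleUnpack(%qparams)
//   %dtype = prim::GetAttr[name="dtype"](%observer)
// For per-channel qschemes a GetAttr for "ch_axis" is also inserted.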
QuantOpParams InsertQuantDeQuantHelper::insertCalculateQParams(
    script::Module& module,
    Graph* g,
    Node* n) {
  // TODO: refactor findObserverName to take Node* as input
  Value* self = g->inputs()[0];
  Value* v = n->output();
  TORCH_INTERNAL_ASSERT(
      v->type()->isSubtypeOf(*TensorType::get()),
      "Expected output of observer node to be Tensor");
  auto observer_name = findObserverName(v);
  TORCH_INTERNAL_ASSERT(
      observer_name,
      "insertCalculateQParams expects the corresponding observer for ",
      v->debugName(),
      " to exist.");
  std::vector<Value*> qparams_graph_values;
  QuantOpParams quant_op_params;

  TORCH_CHECK(
      !isPlaceholderObserver(n->input(0)),
      "Placeholder observers are not supported in on-device PTQ.");
  auto observer_module = module.attr(observer_name.value()).toModule();
  Value* observer_module_value = g->insertGetAttr(self, observer_name.value());
  auto scalar_type = observer_module.attr("dtype");
  TORCH_CHECK(
      scalar_type.toScalarType() != at::ScalarType::Undefined,
      "dtype of observer can't be undefined");
  // Not sure if we need to support fp16 for on-device PTQ.
  if (scalar_type == at::ScalarType::Half) {
    return quant_op_params;
  }
  auto calculate_qparams = observer_module.get_method("calculate_qparams");
  auto calculate_qparams_schema = calculate_qparams.function().getSchema();
  MatchedSchema matched_schema = matchSchema(
      calculate_qparams_schema,
      v->node()->sourceRange(),
      *g,
      {observer_module_value},
      {});
  Node* call = g->insertMethodCall("calculate_qparams", matched_schema)->node();
  Node* scale_zp_node = g->insertNode(g->createTupleUnpack(call->output(0)));
  checkCalculateQParamsResultTypes(scale_zp_node);
  auto qscheme = observer_module.attr("qscheme").toQScheme();
  quant_op_params.qscheme = qscheme;
  quant_op_params.qparams.push_back(scale_zp_node->output(0)); // scale Value*
  quant_op_params.qparams.push_back(
      scale_zp_node->output(1)); // zero_point Value*
  if (isPerChannel(qscheme)) {
    Value* ch_axis_value = g->insertGetAttr(observer_module_value, "ch_axis");
    quant_op_params.qparams.push_back(ch_axis_value);
  }
  Value* scalar_type_value = g->insertGetAttr(observer_module_value, "dtype");
  quant_op_params.qparams.push_back(scalar_type_value);
  return quant_op_params;
}

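// For every observer node recorded for this graph, insert a calculate_qparams
// call followed by the quant/dequant ops right after the observer's output.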
void InsertQuantDeQuantHelper::insertCalculateQParamsAndQuantizationOps(
    Module& module,
    Graph* graph,
    Value* self) {
  if (!observer_nodes_for_graph_.count(graph)) {
    return;
  }
  for (auto* n : observer_nodes_for_graph_.at(graph)) {
    Graph* g = n->owningGraph();
    // Observer output
    Value* observer_out = n->output();
    // Insert right after the observer node, i.e. before its successor.
    WithInsertPoint insert_qparams_calc(observer_out->node()->next());
    auto quant_op_params = insertCalculateQParams(module, g, n);
    insertQuantizationOps(
        module,
        self,
        n,
        isPerChannel(quant_op_params.qscheme),
        quant_op_params,
        quant_type_);
  }
}

void InsertQuantDeQuantHelper::runForOnDevicePTQ(
    Module& module,
    const std::string& method_name) {
  // In all likelihood this won't do anything, because we expect the input
  // method for quantization's prepare step to be inlined. Thus the only
  // CallMethod nodes we will see should belong to observers' forward calls.
  for (auto& invoked_methods : getInvokedMethods(module, method_name)) {
    auto& invoked_module = std::get<0>(invoked_methods);
    const auto& invoked_method_name = std::get<1>(invoked_methods);
    runForOnDevicePTQ(invoked_module, invoked_method_name);
  }

  Method method = module.get_method(method_name);
  auto graph = method.graph();
  // Unlike the run method, we don't need to extract new qparam values for the
  // same graph used at different call sites. The reason is that for on-device
  // PTQ we don't:
  // 1. run calculate_qparams,
  // 2. get the scale and zero point,
  // 3. get axis and dtype,
  // 4. register the values from 2 and 3 as attributes on the parent module.
  // Instead we insert a call to calculate_qparams (1) via
  // insertCalculateQParams in the graph itself. Then, instead of 2, we use the
  // call's output Values directly, and for 3 we insert GetAttr nodes for axis
  // and dtype and use those Values with insertQuantizationOps.

  // prim::Param nodes do not belong to the graph, so the insertion point is
  // the beginning of the graph. This also safeguards against observing a
  // value that may have been mutated by an in-place operation.
  std::vector<Value*> input_values;
  for (const auto idx : c10::irange(1, method.num_inputs())) {
    auto& v = graph->inputs()[idx];
    if (v->type()->isSubtypeOf(*TensorType::get())) {
      input_values.push_back(v);
    }
  }

  std::stack<Block*> blocks_to_visit;
  blocks_to_visit.push(graph->block());
  while (!blocks_to_visit.empty()) {
    Block* b = blocks_to_visit.top();
    blocks_to_visit.pop();
    for (auto it = b->nodes().begin(), end = b->nodes().end(); it != end;) {
      Node* n = *it++;
      for (Value* v : n->outputs()) {
        if (!v->type()->isSubtypeOf(*TensorType::get())) {
          continue;
        }
        collectObserverNodesAndValueToQuantize(module, v);
      }

      for (Block* subblock : n->blocks()) {
        blocks_to_visit.push(subblock);
      }
    }
  }

  for (Value* v : input_values) {
    collectObserverNodesAndValueToQuantize(module, v);
  }

  GRAPH_DUMP("Before insertCalculateQparamsAndQuantizationOps:", graph);
  Value* self = graph->inputs()[0];
  insertCalculateQParamsAndQuantizationOps(module, graph.get(), self);
  GRAPH_DUMP("After insertCalculateQparamsAndQuantizationOps:", graph);
}

} // namespace

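// If the input of a quantize node comes from a prim::If, push a copy of the
// quantize call into each branch so that every branch output is already
// quantized, then replace uses of the original quantize node with the
// prim::If output. A rough before/after sketch (illustrative, not exact IR):
//   before: %y = prim::If(...)
//           %yq = aten::quantize_per_tensor(%y, %scale, %zp, %dtype)
//   after:  each if-block returns its own aten::quantize_per_tensor(...),
//           and users of %yq consume the prim::If output directly.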
void ReplicateQuant(std::shared_ptr<Graph>& graph) {
  std::stack<Block*> blocks_to_visit;
  std::vector<Node*> quant_nodes_to_rewrite;
  blocks_to_visit.push(graph->block());
  while (!blocks_to_visit.empty()) {
    Block* b = blocks_to_visit.top();
    blocks_to_visit.pop();
    for (Node* n : b->nodes()) {
      // find quantize nodes that quantize the output of a prim::If
      if ((n->kind() == Symbol::aten("quantize_per_tensor") ||
           n->kind() == Symbol::aten("quantize_per_channel")) &&
          n->input(0)->node()->kind() == prim::If) {
        quant_nodes_to_rewrite.push_back(n);
      }
      for (Block* subblock : n->blocks()) {
        blocks_to_visit.push(subblock);
      }
    }
  }
  for (Node* n : quant_nodes_to_rewrite) {
    Node* if_node = n->input(0)->node();
    // move the nodes that produce the quantization parameters before
    // prim::If
    for (const auto i : c10::irange(1, n->inputs().size())) {
      n->input(i)->node()->moveBefore(if_node);
    }
    // replace all uses of the quantize node with the output of the if node
    n->output()->replaceAllUsesWith(if_node->output());
    // add quantize nodes to the end of all blocks
    for (Block* if_block : if_node->blocks()) {
      TORCH_CHECK(
          if_block->outputs().size() == 1,
          "replicate quantize only works for `if` node with one output right now");
      // the original return value of the block
      Value* ret_val = if_block->outputs()[0];
      std::vector<Value*> quantize_inputs = n->inputs().vec();
      quantize_inputs[0] = ret_val;
      WithInsertPoint ins(if_block->return_node());
      Node* quant = graph->create(n->kind(), quantize_inputs);
      if_block->replaceOutput(0, quant->output());
      quant->output()->copyMetadata(ret_val);
      graph->insertNode(quant);
    }
  }

  for (Node* n : quant_nodes_to_rewrite) {
    n->removeAllInputs();
  }
  for (Node* n : quant_nodes_to_rewrite) {
    n->destroy();
  }
}

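// If a dequantize output has more than one use, rewrite the graph so that each
// use gets its own dequantize node; later rewrites can then match and consume
// each quantized-op pattern independently.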
void ReplicateDeQuant(std::shared_ptr<Graph>& graph) {
  std::stack<Block*> blocks_to_visit;
  std::vector<Node*> dequant_nodes_to_rewrite;
  blocks_to_visit.push(graph->block());
  while (!blocks_to_visit.empty()) {
    Block* b = blocks_to_visit.top();
    blocks_to_visit.pop();
    for (Node* n : b->nodes()) {
      if (n->kind() == Symbol::aten("dequantize") &&
          n->output()->uses().size() > 1) {
        dequant_nodes_to_rewrite.push_back(n);
      }
      for (Block* subblock : n->blocks()) {
        blocks_to_visit.push(subblock);
      }
    }
  }
  for (Node* n : dequant_nodes_to_rewrite) {
    auto* quantized_val = n->input(0);
    auto* dequantized_val = n->output();
    insertDeQuantForAllUse(graph.get(), quantized_val, dequantized_val);
  }

  for (Node* n : dequant_nodes_to_rewrite) {
    n->removeAllInputs();
  }

  for (Node* n : dequant_nodes_to_rewrite) {
    n->destroy();
  }
}

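// Top-level entry point for quant/dequant insertion: clone the module (unless
// inplace), run the weight observers, insert quant/dequant nodes, clean up
// observer state, and propagate quantization ops through the graph.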
Module InsertQuantDeQuant(
    Module& input_module,
    const std::string& method_name,
    bool inplace,
    bool debug,
    QuantType quant_type) {
  Module module = input_module.clone(inplace);
  InsertQuantDeQuantHelper h(quant_type, debug);
  h.runWeightObserver(module, method_name);
  h.run(module, method_name);
  h.cleanup(module);
  h.propagateQuantizationOps(module);
  return module;
}

/*
 *
 * Assumption: the method named method_name has observers placed.
 * Objective: modify that method to insert calls to:
 * 1. calculate_qparams
 * 2. GetAttr for axis and dtype values
 * 3. quant + dequant, using the Values from the two steps above
 * Thus after this step you have a graph of, e.g., observe_forward,
 * that has observer nodes, calculate_qparams run on those observer nodes,
 * whose output is used by quant-dequant nodes. The output of dequant is used
 * by the actual op.
 * Later on we will replace dequant + op (e.g. linear) with
 * 1. prepacked_op context
 * 2. unpack
 * 3. dequantize
 * 4. linear
 *
 * Of the above pattern, 2, 3, and 4 can be replaced by the linear_run op.
 */
// Module InsertQuantDeQuantForOnDevicePTQ(
Module InsertQuantDeQuantOnDevicePTQ(
    Module& input_module,
    const std::string& method_name,
    bool inplace,
    bool debug,
    QuantType quant_type) {
  Module module = input_module.clone(inplace);
  const std::string kObserveString = "observe_";
  const auto matched_pos = method_name.find(kObserveString);
  const auto end_pos = matched_pos + kObserveString.length();
  const std::string orig_method_name = method_name.substr(end_pos);
  TORCH_CHECK(
      matched_pos == 0,
      "Quant dequant nodes can only be added to observe_",
      orig_method_name,
      ". Please make sure to run prepare step for on-device PTQ.");

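  // e.g. method_name == "observe_forward" is cloned as "quantize_forward",
  // and that cloned method is the one rewritten below.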
  std::string quantize_method_name = "quantize_" + orig_method_name;
  cloneMethod(module, method_name, quantize_method_name);
  InsertQuantDeQuantHelper h(quant_type, debug);
  h.runForOnDevicePTQ(module, quantize_method_name);
  h.removeObserverNodes(module);
  // Not needed here:
  // - ReplicateChooseQParamsQuantDequant: propagates dynamic quant's
  //   choose_qparams-quant-dequant pattern.
  // - RemoveRedundantQuantizationOps: removes activation observers for
  //   dynamic quant when the op they feed is not dynamically quantizable.
  // Neither applies to on-device PTQ, since activations are not observed for
  // dynamic quant here. We can still reuse propagateQuantizationOps because
  // both of those passes should be no-ops in this case.
  h.propagateQuantizationOps(module);
  return module;
}
} // namespace jit
} // namespace torch