1#pragma once
2
3#include <torch/csrc/jit/api/module.h>
4#include <torch/csrc/jit/ir/ir.h>
5#include <torch/csrc/jit/passes/quantization/quantization_type.h>
6
7namespace torch {
8namespace jit {
9
/** \brief Replicate quantize nodes for prim::If blocks.
 *
 * Quantize nodes that consume the output of a prim::If are replicated into
 * the blocks of the prim::If so that the quantization patterns can be
 * matched inside each block.
 *
 * \param graph the graph to transform, modified in place
 */
TORCH_API void ReplicateQuant(std::shared_ptr<Graph>& graph);
14
/** \brief Replicate dequantize nodes for each use.
 *
 * A dequantize node with multiple uses is duplicated so that each use gets
 * its own copy, allowing quantization patterns to be matched independently
 * at every use site.
 *
 * \param graph the graph to transform, modified in place
 */
TORCH_API void ReplicateDeQuant(std::shared_ptr<Graph>& graph);
19
/** \brief Insert quantize - dequantize calls for the Tensors
 * that were observed in the insert_observers pass.
 *
 * For each Tensor that is observed, get the observer module and call
 * calculate_qparam on the observer module to get the quantization
 * parameters, then add quantize - int_repr - dequantize function calls
 * using these parameters. There is also special handling for quantizing
 * "bias" right now.
 *
 * \param module the input module
 * \param method_name the method we want to insert quantization calls for
 * \param inplace whether to transform \p module in place instead of copying
 * \param debug whether to produce a debug-friendly graph
 *        (NOTE(review): exact effect not visible from this header — see the
 *        pass implementation)
 * \param quant_type static (default) or dynamic quantization
 * \return the transformed module
 */
TORCH_API Module InsertQuantDeQuant(
    Module& module,
    const std::string& method_name,
    bool inplace,
    bool debug,
    QuantType quant_type = QuantType::STATIC);
37
/** \brief Variant of InsertQuantDeQuant for on-device post-training
 * quantization (PTQ).
 *
 * Takes the same arguments as InsertQuantDeQuant; presumably it inserts the
 * quantize - dequantize calls in a form suitable for running the PTQ flow
 * on device rather than ahead of time (NOTE(review): semantics not visible
 * from this header — confirm against the pass implementation).
 *
 * \param module the input module
 * \param method_name the method we want to insert quantization calls for
 * \param inplace whether to transform \p module in place instead of copying
 * \param debug whether to produce a debug-friendly graph
 * \param quant_type static (default) or dynamic quantization
 * \return the transformed module
 */
TORCH_API Module InsertQuantDeQuantOnDevicePTQ(
    Module& module,
    const std::string& method_name,
    bool inplace,
    bool debug,
    QuantType quant_type = QuantType::STATIC);
44
45} // namespace jit
46} // namespace torch
47