1 | #pragma once |
2 | |
3 | #include <torch/csrc/jit/api/module.h> |
4 | #include <torch/csrc/jit/ir/ir.h> |
5 | #include <torch/csrc/jit/passes/quantization/quantization_type.h> |
6 | |
7 | namespace torch { |
8 | namespace jit { |
9 | |
/** Replicate quantize node for prim::If blocks, so that we can match
 * quantization patterns in prim::If blocks
 *
 * A quantize node whose input comes from the output of a prim::If is
 * duplicated into each branch of the prim::If, so the per-branch
 * subgraphs contain complete quantize patterns for later matching.
 *
 * \param graph the graph to transform (modified in place)
 */
TORCH_API void ReplicateQuant(std::shared_ptr<Graph>& graph);
14 | |
/** Replicate dequantize node for each use, so that we can match
 * quantization patterns
 *
 * A dequantize node with multiple uses is cloned so that each use gets
 * its own dequantize, which keeps the quantize-dequantize pattern local
 * to every consumer and makes pattern matching straightforward.
 *
 * \param graph the graph to transform (modified in place)
 */
TORCH_API void ReplicateDeQuant(std::shared_ptr<Graph>& graph);
19 | |
/** \brief Insert quantize - dequantize calls to the Tensors
 * that are observed in insert_observers pass
 *
 * For each Tensor that is observed, get the observer module and call
 * calculate_qparam on the observer module to get quantization parameters
 * and add quantize - int_repr - dequantize function calls using these
 * parameters we also have special handling for quantizing "bias" right now.
 *
 * \param module the input module
 * \param method_name the method we want to insert quantization calls for
 * \param inplace if true, presumably transforms \p module directly instead
 *        of working on a copy — TODO(review): confirm against the
 *        implementation
 * \param debug if true, presumably keeps additional information in the
 *        produced module for inspection — TODO(review): confirm semantics
 * \param quant_type static (default) or dynamic quantization; see
 *        QuantType in quantization_type.h
 * \return the module with quantize/dequantize calls inserted
 */
TORCH_API Module InsertQuantDeQuant(
    Module& module,
    const std::string& method_name,
    bool inplace,
    bool debug,
    QuantType quant_type = QuantType::STATIC);
37 | |
/** \brief Variant of InsertQuantDeQuant used for on-device post-training
 * quantization (PTQ).
 *
 * NOTE(review): behavior inferred from the name only — this header does
 * not show how it differs from InsertQuantDeQuant; verify against the
 * implementation before relying on specifics.
 *
 * \param module the input module
 * \param method_name the method we want to insert quantization calls for
 * \param inplace if true, presumably transforms \p module directly instead
 *        of working on a copy — TODO(review): confirm
 * \param debug if true, presumably keeps additional information in the
 *        produced module for inspection — TODO(review): confirm
 * \param quant_type static (default) or dynamic quantization; see
 *        QuantType in quantization_type.h
 * \return the module with quantize/dequantize calls inserted
 */
TORCH_API Module InsertQuantDeQuantOnDevicePTQ(
    Module& module,
    const std::string& method_name,
    bool inplace,
    bool debug,
    QuantType quant_type = QuantType::STATIC);
44 | |
45 | } // namespace jit |
46 | } // namespace torch |
47 | |