1 | #pragma once |
2 | |
3 | #include <torch/csrc/jit/api/module.h> |
4 | #include <torch/csrc/jit/ir/ir.h> |
5 | #include <torch/csrc/jit/passes/quantization/quantization_type.h> |
6 | |
7 | namespace torch { |
8 | namespace jit { |
9 | |
/** \brief Backend-specific pass that fuses dequantize - op - quantize call
 * sequences into single quantized_op calls.
 *
 * Right now this is a fusion for the fbgemm backend and only works for the
 * quantized conv op; we'll extend it to more ops and more backends in the
 * future.
 *
 * Currently supported fusion:
 * q(conv2d(dq(a), dq(w), dq(b))) --> to_nchw(fbgemm_conv2d(prepack(to_nhwc(a)),
 *                                                          prepack(to_nhwc(w)),
 *                                                          prepack(to_nhwc(b))))
 *
 * q(linear(dq(a), dq(w), dq(b))) --> to_nchw(fbgemm_linear(prepack(to_nhwc(a)),
 *                                                          prepack(to_nhwc(w)),
 *                                                          prepack(to_nhwc(b))))
 *
 * \param graph the graph we want to apply fusion to
 * \param quant_type whether the graph was statically or dynamically
 *        quantized; defaults to static quantization
 */
TORCH_API void QuantFusion(
    std::shared_ptr<Graph>& graph,
    QuantType quant_type = QuantType::STATIC);
30 | |
/** \brief Insert prepack and unpack function calls in a graph.
 *
 * We want to add pack/unpack functions for the quantized weight because later
 * we want to fold the packed weight into an attribute of the module, in order
 * to reduce the cost of packing the weight on the fly in quantized models.
 *
 * Each quantized op has its corresponding prepack/unpack function;
 * right now, we only need to do prepack/unpack for quantized::linear
 * and quantized::conv2d.
 *
 * \param graph the graph to insert prepack/unpack calls into
 */
TORCH_API void InsertPrepackUnpack(std::shared_ptr<Graph>& graph);
41 | |
/** \brief Insert prepack and unpack function calls in all graphs
 * of a module.
 *
 * Goes through the graphs of all the methods of all child modules
 * and calls InsertPrepackUnpack on each graph.
 *
 * \param module the module (and, recursively, its children) to process
 */
TORCH_API void InsertPrepackUnpack(Module& module);
49 | |
50 | TORCH_API script::Module Finalize( |
51 | script::Module& module, |
52 | QuantType quant_type = QuantType::STATIC, |
53 | const std::vector<std::string>& preserved_attrs = |
54 | std::vector<std::string>()); |
55 | |
/** \brief Fold the quantized prepacking ops in a module, replacing on-the-fly
 * weight packing with precomputed packed-weight attributes.
 *
 * \param module the module whose prepacking ops are folded in place
 */
TORCH_API void FoldQuantizedPrepackingOps(Module& module);
57 | |
/** \brief Finalize a module quantized via on-device post-training
 * quantization (PTQ).
 *
 * \param module the module to finalize
 * \param quant_type whether the module was statically or dynamically
 *        quantized (note: no default here, unlike Finalize above)
 * \param method_name presumably the name of the method whose graph is
 *        finalized — confirm against the pass implementation
 * \return the finalized module
 */
TORCH_API Module FinalizeOnDevicePTQ(
    Module& module,
    QuantType quant_type,
    const std::string& method_name);
62 | } // namespace jit |
63 | } // namespace torch |
64 | |