1#pragma once
2
3#include <torch/csrc/jit/api/module.h>
4#include <torch/csrc/jit/ir/ir.h>
5#include <torch/csrc/jit/passes/quantization/quantization_type.h>
6
7namespace torch {
8namespace jit {
9
/** \brief Backend specific pass to fuse dequantize - op - quantize calls
 * as quantized_op calls.
 *
 * Right now this is a fusion for fbgemm backend and only works for quantized
 * conv op, we'll extend to more ops and more backends in the future.
 *
 * Currently supported fusion:
 * q(conv2d(dq(a), dq(w), dq(b))) --> to_nchw(fbgemm_conv2d(prepack(to_nhwc(a)),
 *                                                prepack(to_nhwc(w)),
 *                                                prepack(to_nhwc(b))))
 *
 * q(linear(dq(a), dq(w), dq(b))) --> to_nchw(fbgemm_linear(prepack(to_nhwc(a)),
 *                                                prepack(to_nhwc(w)),
 *                                                prepack(to_nhwc(b))))
 *
 * \param graph the graph we want to apply fusion
 * \param quant_type whether the graph was produced by static or dynamic
 *        quantization; defaults to static
 */
TORCH_API void QuantFusion(
    std::shared_ptr<Graph>& graph,
    QuantType quant_type = QuantType::STATIC);
30
/** \brief Insert prepack and unpack function in graph
 * We want to add pack/unpack functions for quantized weight because later we
 * want to fold the packed weight as an attribute of the module, in order to
 * reduce the cost of packing the weight on the fly in quantized models.
 *
 * Each quantized op has its corresponding prepack/unpack function;
 * right now, we only need to do prepack/unpack for quantized::linear
 * and quantized::conv2d.
 *
 * \param graph the graph to insert prepack/unpack calls into
 */
TORCH_API void InsertPrepackUnpack(std::shared_ptr<Graph>& graph);
41
/** \brief Insert pack and unpack function in all graphs
 * of module
 *
 * Goes through the graphs of all the methods of the module and all of its
 * child modules, and calls InsertPrepackUnpack on each graph.
 *
 * \param module the module whose method graphs are rewritten
 */
TORCH_API void InsertPrepackUnpack(Module& module);
49
50TORCH_API script::Module Finalize(
51 script::Module& module,
52 QuantType quant_type = QuantType::STATIC,
53 const std::vector<std::string>& preserved_attrs =
54 std::vector<std::string>());
55
/** \brief Fold quantized prepacking ops into module attributes.
 *
 * Companion to InsertPrepackUnpack: once prepack calls exist in the graphs,
 * this folds the packed weights into attributes of the module so they are not
 * re-packed on the fly at runtime (see the InsertPrepackUnpack docs above).
 *
 * \param module the module whose prepacking ops are folded in place
 */
TORCH_API void FoldQuantizedPrepackingOps(Module& module);
57
/** \brief Finalize variant for on-device post-training quantization (PTQ).
 *
 * Unlike Finalize, this targets a single named method of the module.
 * NOTE(review): `quant_type` has no default here, unlike the other APIs in
 * this header — callers must pass it explicitly.
 *
 * \param module the module to finalize
 * \param quant_type static or dynamic quantization
 * \param method_name name of the method to finalize
 * \return the finalized module
 */
TORCH_API Module FinalizeOnDevicePTQ(
    Module& module,
    QuantType quant_type,
    const std::string& method_name);
62} // namespace jit
63} // namespace torch
64