1 | #pragma once |
2 | |
3 | #include <torch/csrc/jit/api/module.h> |
4 | #include <torch/csrc/jit/passes/quantization/quantization_type.h> |
5 | |
6 | namespace std { |
7 | |
8 | template <> |
9 | struct hash<torch::jit::Module> { |
10 | inline size_t operator()(const torch::jit::Module& arg) const { |
11 | return std::hash<c10::intrusive_ptr<c10::ivalue::Object>>()(arg._ivalue()); |
12 | } |
13 | }; |
14 | |
15 | } // namespace std |
16 | |
17 | namespace torch { |
18 | namespace jit { |
19 | |
20 | using QConfig = std::tuple<Module, Module>; |
21 | using QConfigDict = std::unordered_map<std::string, c10::optional<QConfig>>; |
22 | |
23 | /** \brief Insert observer module and observer function call for |
24 | * the Tensors that needs to be observed. |
25 | * |
26 | * For each Tensor that needs to be observed in the method, insert observer |
27 | * module to the input module and add forward calls of observer to the specified |
28 | * method. |
29 | * |
30 | * \param module the input module |
31 | * \param method_name the method we want to insert observers for |
32 | * \param qconfig_dict the qconfig dictionary that specifies how |
33 | * each module is going to be quantized |
34 | * \param inplace whether we want to do inplace modification to the input module |
35 | * or clone the module |
36 | * \param is_dynamic whether the dynamic quantization script is being used. |
37 | */ |
38 | TORCH_API Module InsertObservers( |
39 | Module& module, |
40 | const std::string& method_name, |
41 | const QConfigDict& qconfig_dict, |
42 | bool inplace, |
43 | QuantType quant_type = QuantType::STATIC); |
44 | |
45 | /** \brief Insert observer module and observer method for |
46 | * the Tensors that needs to be observed. |
47 | * |
48 | * For each Tensor that needs to be observed in the method, insert observer |
49 | * module to the input module and observe_<method-name> methods to the module. |
50 | * This method is clone of mehtod_name with forward calls of observer added. |
51 | * |
52 | * \param module the input module |
53 | * \param method_name the method we want to insert observers for |
54 | * \param qconfig_dict the qconfig dictionary that specifies how |
55 | * each module is going to be quantized |
56 | * \param inplace whether we want to do inplace modification to the input module |
57 | * or clone the module |
58 | * \param is_dynamic whether the dynamic quantization script is being used. |
59 | */ |
60 | TORCH_API Module InsertObserversForOnDevicePTQ( |
61 | Module& module, |
62 | const std::string& method_name, |
63 | const QConfigDict& qconfig_dict, |
64 | bool inplace, |
65 | QuantType quant_type = QuantType::STATIC); |
66 | |
67 | } // namespace jit |
68 | } // namespace torch |
69 | |