1#pragma once
2
3#include <torch/csrc/jit/api/module.h>
4#include <torch/csrc/jit/passes/quantization/quantization_type.h>
5
6namespace std {
7
8template <>
9struct hash<torch::jit::Module> {
10 inline size_t operator()(const torch::jit::Module& arg) const {
11 return std::hash<c10::intrusive_ptr<c10::ivalue::Object>>()(arg._ivalue());
12 }
13};
14
15} // namespace std
16
17namespace torch {
18namespace jit {
19
20using QConfig = std::tuple<Module, Module>;
21using QConfigDict = std::unordered_map<std::string, c10::optional<QConfig>>;
22
23/** \brief Insert observer module and observer function call for
24 * the Tensors that needs to be observed.
25 *
26 * For each Tensor that needs to be observed in the method, insert observer
27 * module to the input module and add forward calls of observer to the specified
28 * method.
29 *
30 * \param module the input module
31 * \param method_name the method we want to insert observers for
32 * \param qconfig_dict the qconfig dictionary that specifies how
33 * each module is going to be quantized
34 * \param inplace whether we want to do inplace modification to the input module
35 * or clone the module
36 * \param is_dynamic whether the dynamic quantization script is being used.
37 */
38TORCH_API Module InsertObservers(
39 Module& module,
40 const std::string& method_name,
41 const QConfigDict& qconfig_dict,
42 bool inplace,
43 QuantType quant_type = QuantType::STATIC);
44
45/** \brief Insert observer module and observer method for
46 * the Tensors that needs to be observed.
47 *
48 * For each Tensor that needs to be observed in the method, insert observer
49 * module to the input module and observe_<method-name> methods to the module.
50 * This method is clone of mehtod_name with forward calls of observer added.
51 *
52 * \param module the input module
53 * \param method_name the method we want to insert observers for
54 * \param qconfig_dict the qconfig dictionary that specifies how
55 * each module is going to be quantized
56 * \param inplace whether we want to do inplace modification to the input module
57 * or clone the module
58 * \param is_dynamic whether the dynamic quantization script is being used.
59 */
60TORCH_API Module InsertObserversForOnDevicePTQ(
61 Module& module,
62 const std::string& method_name,
63 const QConfigDict& qconfig_dict,
64 bool inplace,
65 QuantType quant_type = QuantType::STATIC);
66
67} // namespace jit
68} // namespace torch
69