1 | /* |
2 | * Licensed to the Apache Software Foundation (ASF) under one |
3 | * or more contributor license agreements. See the NOTICE file |
4 | * distributed with this work for additional information |
5 | * regarding copyright ownership. The ASF licenses this file |
6 | * to you under the Apache License, Version 2.0 (the |
7 | * "License"); you may not use this file except in compliance |
8 | * with the License. You may obtain a copy of the License at |
9 | * |
10 | * http://www.apache.org/licenses/LICENSE-2.0 |
11 | * |
12 | * Unless required by applicable law or agreed to in writing, |
13 | * software distributed under the License is distributed on an |
14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
15 | * KIND, either express or implied. See the License for the |
16 | * specific language governing permissions and limitations |
17 | * under the License. |
18 | */ |
19 | |
20 | /*! |
 * \file relay/quantize/quantize.h
22 | * \brief Header of definitions for quantization |
23 | */ |
24 | #ifndef TVM_RELAY_QUANTIZE_QUANTIZE_H_ |
25 | #define TVM_RELAY_QUANTIZE_QUANTIZE_H_ |
26 | |
27 | #include <tvm/relay/expr.h> |
28 | #include <tvm/relay/op.h> |
29 | |
30 | #include <string> |
31 | |
32 | #include "../transforms/pattern_utils.h" |
33 | |
34 | namespace tvm { |
35 | namespace relay { |
36 | namespace quantize { |
37 | |
/*!
 * \brief Kind of annotate field.
 *
 * The underlying type is fixed to int and every enumerator has an explicit
 * value, so the numeric encoding does not depend on declaration order.
 */
enum QAnnotateKind : int {
  kQIdentity = 0,
  kQInput = 1,
  kQWeight = 2,
  kQActivation = 3
};
40 | |
41 | /*! \brief Attribute for simulated quantize operator */ |
42 | struct SimulatedQuantizeAttrs : public tvm::AttrsNode<SimulatedQuantizeAttrs> { |
43 | int kind; |
44 | bool sign; |
45 | std::string rounding; |
46 | |
47 | TVM_DECLARE_ATTRS(SimulatedQuantizeAttrs, "relay.attrs.SimulatedQuantizeAttrs" ) { |
48 | TVM_ATTR_FIELD(kind).describe("kind of field, hint for nbit/dtype configuration." ); |
49 | TVM_ATTR_FIELD(sign).set_default(true).describe("whether to use signed data type." ); |
50 | TVM_ATTR_FIELD(rounding).set_default("round" ).describe( |
51 | "rounding mode. Can be 'floor', 'ceil', 'round'" ); |
52 | } |
53 | }; |
54 | |
55 | class QConfig; |
56 | /*! |
57 | * \brief Container for build configuration options |
58 | */ |
59 | class QConfigNode : public Object { |
60 | public: |
61 | int nbit_input = 8; |
62 | int nbit_weight = 8; |
63 | int nbit_activation = 32; |
64 | DataType dtype_input = DataType::Int(8); |
65 | DataType dtype_weight = DataType::Int(8); |
66 | DataType dtype_activation = DataType::Int(32); |
67 | std::string calibrate_mode = "global_scale" ; |
68 | double global_scale = 8.0; |
69 | std::string weight_scale = "power2" ; |
70 | bool skip_dense_layer = true; |
71 | Array<Expr> skip_conv_layers = Array<Expr>(ObjectPtr<Object>(nullptr)); |
72 | bool do_simulation = false; |
73 | bool round_for_shift = true; |
74 | Array<Expr> debug_enabled_ops = Array<Expr>(ObjectPtr<Object>(nullptr)); |
75 | std::string rounding = "UPWARD" ; |
76 | int calibrate_chunk_by = -1; |
77 | std::string partition_conversions = "disabled" ; |
78 | |
79 | void VisitAttrs(AttrVisitor* v) { |
80 | v->Visit("nbit_input" , &nbit_input); |
81 | v->Visit("nbit_weight" , &nbit_weight); |
82 | v->Visit("nbit_activation" , &nbit_activation); |
83 | v->Visit("dtype_input" , &dtype_input); |
84 | v->Visit("dtype_weight" , &dtype_weight); |
85 | v->Visit("dtype_activation" , &dtype_activation); |
86 | v->Visit("calibrate_mode" , &calibrate_mode); |
87 | v->Visit("global_scale" , &global_scale); |
88 | v->Visit("weight_scale" , &weight_scale); |
89 | v->Visit("skip_dense_layer" , &skip_dense_layer); |
90 | v->Visit("skip_conv_layers" , &skip_conv_layers); |
91 | v->Visit("do_simulation" , &do_simulation); |
92 | v->Visit("round_for_shift" , &round_for_shift); |
93 | v->Visit("debug_enabled_ops" , &debug_enabled_ops); |
94 | v->Visit("rounding" , &rounding); |
95 | v->Visit("calibrate_chunk_by" , &calibrate_chunk_by); |
96 | v->Visit("partition_conversions" , &partition_conversions); |
97 | } |
98 | |
99 | static constexpr const char* _type_key = "relay.quantize.QConfig" ; |
100 | TVM_DECLARE_FINAL_OBJECT_INFO(QConfigNode, Object); |
101 | }; |
102 | |
103 | /*! |
104 | * \brief Container for build configuration options |
105 | */ |
106 | class QConfig : public ObjectRef { |
107 | public: |
108 | QConfig() {} |
109 | explicit QConfig(ObjectPtr<Object> n) : ObjectRef(n) {} |
110 | |
111 | const QConfigNode* operator->() const { return static_cast<const QConfigNode*>(get()); } |
112 | |
113 | QConfigNode* operator->() { return static_cast<QConfigNode*>(get_mutable()); } |
114 | |
115 | /*! |
116 | * \brief Push a new BuildConfig context onto the thread local stack. |
117 | * \param build_config The configuration to set as the current context. |
118 | */ |
119 | static void EnterQConfigScope(const QConfig& qconfig); |
120 | |
121 | /*! |
122 | * \brief Pop a build config off the thread local context stack, restoring the previous |
123 | * configuration as the current context. |
124 | */ |
125 | static void ExitQConfigScope(); |
126 | |
127 | /*! |
128 | * \brief Get the current BuildConfig context from thread local storage, or a default |
129 | * configuration if a BuildConfig scope has not been entered. |
130 | * \return The configuration that is the current context. |
131 | */ |
132 | static QConfig& Current(); |
133 | |
134 | using ContainerType = QConfigNode; |
135 | }; |
136 | |
137 | /*! |
138 | * \brief RAII container to provide a scoped BuildConfig context. Pushes a configuration onto the |
139 | * context stack when constructed, and pops it when destructed. |
140 | */ |
141 | struct QConfigContext { |
142 | /*! |
143 | * \brief Enter a new BuildConfig context. The given BuildConfig becomes the new current |
144 | * context. When the BuildConfigContext is destructed, the previous context is restored. |
145 | * \param build_config The BuildConfig to set as the new current context. |
146 | */ |
147 | explicit QConfigContext(const QConfig& qconfig) { QConfig::EnterQConfigScope(qconfig); } |
148 | |
149 | /*! \brief Destructor. Pops the context off the thread local stack. */ |
150 | ~QConfigContext() { QConfig::ExitQConfigScope(); } |
151 | }; |
152 | |
153 | } // namespace quantize |
154 | } // namespace relay |
155 | } // namespace tvm |
156 | #endif // TVM_RELAY_QUANTIZE_QUANTIZE_H_ |
157 | |