1/*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19
20/*!
 * \file src/relay/quantize/quantize.h
22 * \brief Header of definitions for quantization
23 */
24#ifndef TVM_RELAY_QUANTIZE_QUANTIZE_H_
25#define TVM_RELAY_QUANTIZE_QUANTIZE_H_
26
#include <tvm/relay/expr.h>
#include <tvm/relay/op.h>

#include <string>
#include <utility>

#include "../transforms/pattern_utils.h"
33
34namespace tvm {
35namespace relay {
36namespace quantize {
37
/*!
 * \brief Kind of annotate field.
 *
 * Each enumerator is pinned to an explicit integer value and the underlying type is `int`.
 * NOTE(review): presumably these are the values stored in `SimulatedQuantizeAttrs::kind`
 * (declared as a plain `int`) — confirm against the operator implementation.
 */
enum QAnnotateKind : int {
  kQIdentity = 0,
  kQInput = 1,
  kQWeight = 2,
  kQActivation = 3,
};
40
41/*! \brief Attribute for simulated quantize operator */
42struct SimulatedQuantizeAttrs : public tvm::AttrsNode<SimulatedQuantizeAttrs> {
43 int kind;
44 bool sign;
45 std::string rounding;
46
47 TVM_DECLARE_ATTRS(SimulatedQuantizeAttrs, "relay.attrs.SimulatedQuantizeAttrs") {
48 TVM_ATTR_FIELD(kind).describe("kind of field, hint for nbit/dtype configuration.");
49 TVM_ATTR_FIELD(sign).set_default(true).describe("whether to use signed data type.");
50 TVM_ATTR_FIELD(rounding).set_default("round").describe(
51 "rounding mode. Can be 'floor', 'ceil', 'round'");
52 }
53};
54
55class QConfig;
56/*!
57 * \brief Container for build configuration options
58 */
59class QConfigNode : public Object {
60 public:
61 int nbit_input = 8;
62 int nbit_weight = 8;
63 int nbit_activation = 32;
64 DataType dtype_input = DataType::Int(8);
65 DataType dtype_weight = DataType::Int(8);
66 DataType dtype_activation = DataType::Int(32);
67 std::string calibrate_mode = "global_scale";
68 double global_scale = 8.0;
69 std::string weight_scale = "power2";
70 bool skip_dense_layer = true;
71 Array<Expr> skip_conv_layers = Array<Expr>(ObjectPtr<Object>(nullptr));
72 bool do_simulation = false;
73 bool round_for_shift = true;
74 Array<Expr> debug_enabled_ops = Array<Expr>(ObjectPtr<Object>(nullptr));
75 std::string rounding = "UPWARD";
76 int calibrate_chunk_by = -1;
77 std::string partition_conversions = "disabled";
78
79 void VisitAttrs(AttrVisitor* v) {
80 v->Visit("nbit_input", &nbit_input);
81 v->Visit("nbit_weight", &nbit_weight);
82 v->Visit("nbit_activation", &nbit_activation);
83 v->Visit("dtype_input", &dtype_input);
84 v->Visit("dtype_weight", &dtype_weight);
85 v->Visit("dtype_activation", &dtype_activation);
86 v->Visit("calibrate_mode", &calibrate_mode);
87 v->Visit("global_scale", &global_scale);
88 v->Visit("weight_scale", &weight_scale);
89 v->Visit("skip_dense_layer", &skip_dense_layer);
90 v->Visit("skip_conv_layers", &skip_conv_layers);
91 v->Visit("do_simulation", &do_simulation);
92 v->Visit("round_for_shift", &round_for_shift);
93 v->Visit("debug_enabled_ops", &debug_enabled_ops);
94 v->Visit("rounding", &rounding);
95 v->Visit("calibrate_chunk_by", &calibrate_chunk_by);
96 v->Visit("partition_conversions", &partition_conversions);
97 }
98
99 static constexpr const char* _type_key = "relay.quantize.QConfig";
100 TVM_DECLARE_FINAL_OBJECT_INFO(QConfigNode, Object);
101};
102
103/*!
104 * \brief Container for build configuration options
105 */
106class QConfig : public ObjectRef {
107 public:
108 QConfig() {}
109 explicit QConfig(ObjectPtr<Object> n) : ObjectRef(n) {}
110
111 const QConfigNode* operator->() const { return static_cast<const QConfigNode*>(get()); }
112
113 QConfigNode* operator->() { return static_cast<QConfigNode*>(get_mutable()); }
114
115 /*!
116 * \brief Push a new BuildConfig context onto the thread local stack.
117 * \param build_config The configuration to set as the current context.
118 */
119 static void EnterQConfigScope(const QConfig& qconfig);
120
121 /*!
122 * \brief Pop a build config off the thread local context stack, restoring the previous
123 * configuration as the current context.
124 */
125 static void ExitQConfigScope();
126
127 /*!
128 * \brief Get the current BuildConfig context from thread local storage, or a default
129 * configuration if a BuildConfig scope has not been entered.
130 * \return The configuration that is the current context.
131 */
132 static QConfig& Current();
133
134 using ContainerType = QConfigNode;
135};
136
137/*!
138 * \brief RAII container to provide a scoped BuildConfig context. Pushes a configuration onto the
139 * context stack when constructed, and pops it when destructed.
140 */
141struct QConfigContext {
142 /*!
143 * \brief Enter a new BuildConfig context. The given BuildConfig becomes the new current
144 * context. When the BuildConfigContext is destructed, the previous context is restored.
145 * \param build_config The BuildConfig to set as the new current context.
146 */
147 explicit QConfigContext(const QConfig& qconfig) { QConfig::EnterQConfigScope(qconfig); }
148
149 /*! \brief Destructor. Pops the context off the thread local stack. */
150 ~QConfigContext() { QConfig::ExitQConfigScope(); }
151};
152
153} // namespace quantize
154} // namespace relay
155} // namespace tvm
156#endif // TVM_RELAY_QUANTIZE_QUANTIZE_H_
157