/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements. See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership. The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License. You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied. See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

/*!
 * \file src/relay/qnn/utils.h
 * \brief Utility methods needed by quantized ops that can be shared.
 */

#ifndef TVM_RELAY_QNN_UTILS_H_
#define TVM_RELAY_QNN_UTILS_H_

#include <tvm/relay/expr.h>
#include <tvm/relay/qnn/attrs.h>
#include <tvm/tir/expr.h>
#include <tvm/tir/op.h>

#include <limits>
#include <string>
#include <utility>
#include <vector>

#include "./op/requantize_config.h"

namespace tvm {
namespace relay {
namespace qnn {

static inline Array<IndexExpr> get_shape(const Type& type) {
  auto input_tt = type.as<TensorTypeNode>();
  ICHECK(input_tt != nullptr) << "Type information missing."
                              << " Please run infer_type pass.";
  return input_tt->shape;
}

static inline int32_t GetQmin(const DataType& dtype) {
  ICHECK_LE(dtype.bits(), 32) << "QNN ops support int32 or lower precision";
  if (dtype.is_int() || dtype.is_uint()) {
    auto min_value_expr = tvm::min_value(dtype);
    auto* min_value = tir::as_const_int(min_value_expr);
    ICHECK(min_value != nullptr);
    return static_cast<int32_t>(min_value[0]);
  } else {
    LOG(FATAL) << "Type not supported " << dtype;
    return -1;  // Unreachable; silences the missing-return warning.
  }
}

static inline int32_t GetQmax(const DataType& dtype) {
  ICHECK_LE(dtype.bits(), 32) << "QNN ops support int32 or lower precision";
  if (dtype.is_int() || dtype.is_uint()) {
    auto max_value_expr = tvm::max_value(dtype);
    auto* max_value = tir::as_const_int(max_value_expr);
    ICHECK(max_value != nullptr);
    return static_cast<int32_t>(max_value[0]);
  } else {
    LOG(FATAL) << "Type not supported " << dtype;
    return -1;  // Unreachable; silences the missing-return warning.
  }
}
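
/*
 * Example: for DataType::Int(8) the quantized range is [-128, 127], and for
 * DataType::UInt(8) it is [0, 255]:
 *
 *   int32_t qmin = GetQmin(DataType::Int(8));  // -128
 *   int32_t qmax = GetQmax(DataType::Int(8));  //  127
 *
 * These bounds are typically used to clip a requantized value back into the
 * representable range of the output dtype.
 */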

/*
 * \brief Convert an FP32 value into a fixed point representation.
 * \param double_multiplier The input FP32 number.
 * \return The pair of multiplier and shift for the fixed point representation.
 * \note Converts a floating point number so that it can be represented by
 * integers. The representation is
 * float_number = (significand) * 2^(exponent)
 *
 * The significand is a number between 0.5 and 1. This is represented by
 * an integer number. For example, if it is int32, then the decimal point
 * exists between bit 31 and 30 from the LSB (or between the first and second
 * bit from the left).
 *
 * Some examples are
 * 0.25 = (0.5) * 2^(-1)
 * 0.125 = (0.5) * 2^(-2)
 *
 * Credit to the TFLite reference implementation.
 */
std::pair<int32_t, int32_t> GetFixedPointMultiplierShift(double double_multiplier);
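
/*
 * A minimal sketch of this decomposition in the TFLite style, using std::frexp
 * from <cmath> (illustrative only; the actual definition lives in the
 * corresponding .cc file):
 *
 *   int exponent;
 *   double significand = std::frexp(double_multiplier, &exponent);  // [0.5, 1)
 *   auto fixed = static_cast<int64_t>(std::round(significand * (1ll << 31)));
 *   if (fixed == (1ll << 31)) {  // rounding pushed the significand up to 1.0
 *     fixed /= 2;
 *     ++exponent;
 *   }
 *   return {static_cast<int32_t>(fixed), exponent};
 *
 * For example, 0.25 has significand 0.5 and exponent -1, giving the fixed
 * point pair (1 << 30, -1).
 */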

Expr RequantizeLower(const Expr& input_tensor, const Expr& input_scale,
                     const Expr& input_zero_point, const Expr& output_scale,
                     const Expr& output_zero_point, const RequantizeAttrs* param,
                     const Array<IndexExpr>& input_shape, const DataType& out_dtype);

std::string SelectRequntizeParameter(const std::string& arg_value, const std::string& cfg_value,
                                     const bool is_cfg_default, const std::string& name);

static inline Expr Requantize(const Expr& data, const Array<IndexExpr>& input_shape,
                              const Expr& input_scale, const Expr& input_zero_point,
                              const Expr& output_scale, const Expr& output_zero_point,
                              const DataType& out_dtype, const int& axis = -1,
                              const std::string& rounding = "None",
                              const std::string& compute_dtype = "None") {
  auto attrs = make_object<RequantizeAttrs>();
  attrs->axis = axis;
  attrs->out_dtype = out_dtype;
  const RequantizeConfig& cfg = RequantizeConfig::Current();
  attrs->rounding =
      SelectRequntizeParameter(rounding, cfg->get_rounding(), cfg->is_default, "rounding");
  attrs->compute_dtype = SelectRequntizeParameter(compute_dtype, cfg->get_compute_dtype(),
                                                  cfg->is_default, "compute_dtype");
  return RequantizeLower(data, input_scale, input_zero_point, output_scale, output_zero_point,
                         attrs.operator->(), input_shape, attrs->out_dtype);
}
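
/*
 * Requantize computes, elementwise,
 *
 *   Q_out = clip(round((input_scale / output_scale) * (Q_in - input_zero_point))
 *                + output_zero_point, Qmin, Qmax)
 *
 * A hedged usage sketch, assuming the MakeConstantScalar helper from
 * pattern_utils.h and made-up scale/zero point values:
 *
 *   auto out = Requantize(data, input_shape,
 *                         MakeConstantScalar(DataType::Float(32), 0.5f),   // input_scale
 *                         MakeConstantScalar(DataType::Int(32), 0),        // input_zero_point
 *                         MakeConstantScalar(DataType::Float(32), 0.25f),  // output_scale
 *                         MakeConstantScalar(DataType::Int(32), 3),        // output_zero_point
 *                         DataType::Int(8));
 */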

Expr MakeRequantize(Expr data, Expr input_scale, Expr input_zero_point, Expr output_scale,
                    Expr output_zero_point, int axis, String rounding, String compute_dtype,
                    DataType out_dtype);

Expr DequantizeLower(const Expr& input_tensor, const Expr& input_scale,
                     const Expr& input_zero_point, const Array<tvm::relay::Type>& types,
                     const DequantizeAttrs* attrs);

static inline Expr Dequantize(const Expr& data, const Expr& input_scale,
                              const Expr& input_zero_point, const Array<tvm::relay::Type>& types,
                              const int& axis = -1) {
  auto attrs = make_object<DequantizeAttrs>();
  attrs->axis = axis;

  return DequantizeLower(data, input_scale, input_zero_point, types, attrs.operator->());
}

Expr MakeDequantize(Expr data, Expr input_scale, Expr input_zero_point, int axis);

Expr QuantizeLower(const Expr& input_tensor, const Expr& output_scale,
                   const Expr& output_zero_point, const Array<tvm::relay::Type>& types,
                   const QuantizeAttrs* attrs);

static inline Expr Quantize(const Expr& data, const Expr& output_scale,
                            const Expr& output_zero_point, const DataType& out_dtype,
                            const Array<tvm::relay::Type>& types, const int& axis = -1) {
  auto attrs = make_object<QuantizeAttrs>();
  attrs->axis = axis;
  attrs->out_dtype = out_dtype;

  return QuantizeLower(data, output_scale, output_zero_point, types, attrs.operator->());
}

Expr MakeQuantize(Expr data, Expr output_scale, Expr output_zero_point, int axis,
                  DataType out_dtype);
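
/*
 * For reference, the quantize/dequantize lowerings implement the usual affine
 * mapping between real values x and quantized values Q:
 *
 *   quantize:   Q = clip(round(x / output_scale) + output_zero_point, Qmin, Qmax)
 *   dequantize: x = input_scale * (Q - input_zero_point)
 *
 * so dequantizing a quantized value recovers x up to rounding and clipping
 * error.
 */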

static inline int64_t get_const_int(const tvm::PrimExpr& x) {
  auto* value_ptr = tir::as_const_int(x);
  ICHECK(value_ptr) << "Expr is not a constant int";
  return value_ptr[0];
}

/*
 * \brief Fixed point multiplication of an integer tensor with a floating point
 * scalar. This implementation rounds to the nearest value when it is midway
 * between two representable values.
 * \param tensor The quantized input tensor of dtype int64.
 * \param multiplier The scalar multiplier.
 * \param input_shape Shape of the input tensor.
 * \return The sequence of Relay ops for fixed point multiplication with
 * TONEAREST rounding.
 *
 * \note The original computation is scale_fp32 * quantized_tensor. To convert
 * it into an integer computation, the multiplication with the fp32 scalar can
 * be replaced by a multiplication with an int value followed by a right shift
 * of the result. This approximates the floating point computation with a
 * fixed point computation.
 *
 * The fixed point multiplication consists of the following steps:
 * 1) Multiply the fixed point multiplier with the quantized tensor.
 * 2) Round the result.
 * 3) Right shift the result.
 */
Expr FixedPointMultiplyToNearest(Expr tensor, double multiplier,
                                 const Array<IndexExpr>& input_shape);
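
/*
 * Worked example: for multiplier = 0.5, GetFixedPointMultiplierShift yields
 * the pair (1 << 30, 0), and the lowered computation is roughly
 *
 *   out = (tensor * (1 << 30) + rounding_bias) >> 31
 *
 * which halves the tensor while rounding to the nearest value. The
 * rounding_bias term is what distinguishes TONEAREST from plain truncation.
 */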

/*
 * \brief Fixed point multiplication of an integer tensor with a floating point
 * scalar, where the input tensor is per-axis/per-channel quantized.
 * \param tensor The quantized input tensor of dtype int64.
 * \param multiplier The scalar multipliers, one per channel.
 * \param input_shape Shape of the input tensor.
 * \param channel_axis The axis along which the input tensor is quantized.
 * The default value of -1 corresponds to the last axis.
 * \param rounding "UPWARD" or "TONEAREST". The rounding direction when the
 * value is midway between two representable values.
 * \return The sequence of Relay ops for fixed point multiplication.
 *
 * \note The original computation is scale_fp32 * quantized_tensor. To convert
 * it into an integer computation, the multiplication with the fp32 vector can
 * be replaced by a multiplication with an int vector followed by a right shift
 * of the result. This approximates the floating point computation with a
 * fixed point computation.
 *
 * The fixed point multiplication consists of the following steps:
 * 1) Multiply the fixed point multiplier with the quantized tensor.
 * 2) Round the result.
 * 3) Right shift the result.
 */
Expr FixedPointMultiplyPerChannel(Expr tensor, std::vector<double> multiplier,
                                  const Array<IndexExpr>& input_shape, int channel_axis,
                                  const std::string& rounding);

/*
 * Wrapper for 'FixedPointMultiplyPerChannel' with the rounding parameter set
 * to "TONEAREST".
 */
Expr FixedPointMultiplyPerChannelToNearest(Expr tensor, std::vector<double> multiplier,
                                           const Array<IndexExpr>& input_shape, int channel_axis);
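
/*
 * A hedged usage sketch: scaling a conv2d accumulator whose weights were
 * quantized per output channel, with one multiplier per channel along axis 3
 * of an NHWC tensor (the multiplier values here are made up for illustration):
 *
 *   std::vector<double> multipliers = {0.5, 0.25, 0.125};
 *   auto scaled = FixedPointMultiplyPerChannelToNearest(tensor, multipliers,
 *                                                       input_shape, 3);  // channel_axis = 3
 */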

/*
 * \brief Creates a FixedPointMultiply operation where the input tensor is
 * per-axis/per-channel quantized.
 * \param tensor The quantized input tensor.
 * \param multipliers List of scalar multipliers, one per channel.
 * \param axis The axis along which the input tensor is quantized.
 * \return The Relay op.
 */
Expr FixedPointMultiplyPerChannel(Expr tensor, const std::vector<double>& multipliers, int axis);

/*
 * \brief Checks whether an expr type is a scalar of a given data type.
 * \param expr_type The type of the expr to be checked.
 * \param dtype The expected dtype.
 * \return True if the type is a scalar of the given dtype.
 */
static inline bool IsScalarType(const Type& expr_type, const DataType& dtype) {
  const auto* tensor_type = expr_type.as<TensorTypeNode>();
  ICHECK(tensor_type) << "Only tensor type can be checked for scalar values. But got "
                      << AsText(expr_type, false);
  ICHECK_EQ(tensor_type->shape.size(), 0);
  ICHECK(tensor_type->dtype == dtype) << "Expected " << dtype << " but got " << tensor_type->dtype;
  return true;
}
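
/*
 * A hedged usage sketch, as it might appear inside a QNN type relation (the
 * types[] indices are illustrative):
 *
 *   // The scale must be a float32 scalar and the zero point an int32 scalar.
 *   ICHECK(IsScalarType(types[1], DataType::Float(32)));
 *   ICHECK(IsScalarType(types[2], DataType::Int(32)));
 */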

/*
 * \brief Checks whether an expr type is a scalar.
 * \param expr_type The type of the expr to be checked.
 * \return True if the type is a scalar.
 */
static inline bool IsScalarType(const Type& expr_type) {
  const auto* tensor_type = expr_type.as<TensorTypeNode>();
  ICHECK(tensor_type) << "Only tensor type can be checked for scalar values. But got "
                      << AsText(expr_type, false);
  return tensor_type->shape.size() == 0;
}

/*
 * \brief Checks and assigns types to scales and zero points.
 * \param expr_type The type of the expr to be checked.
 * \param dtype The expected dtype.
 * \param shape The shape at the C dim of the original tensor.
 * \param reporter The type reporter of the original InferType call.
 */
static inline void AssignType(const Type& expr_type, const DataType& dtype, const IndexExpr& shape,
                              const TypeReporter& reporter) {
  // Scales/zero points can be either a const scalar or a vector with C axis num elems.
  const auto* tensor_type = expr_type.as<TensorTypeNode>();
  ICHECK(tensor_type) << "Can assign type to Tensor type only. But got "
                      << AsText(expr_type, false);
  const auto tensor_dtype = tensor_type->dtype;
  ICHECK(tensor_dtype == dtype) << "Expected type is " << dtype << " but received " << tensor_dtype;
  if (tensor_type->shape.size() != 0) {
    reporter->Assign(expr_type, TensorType({shape}, tensor_type->dtype));
  }
}
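
/*
 * A hedged usage sketch inside a type relation: allow a scale argument to be
 * either a scalar or a vector with one element per channel (the types[] index
 * and the channels expression are illustrative):
 *
 *   AssignType(types[1], DataType::Float(32), param->channels, reporter);
 */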
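
/*
 * \brief Reads the values of a constant float32 tensor into a std::vector.
 * \param expr The constant expr; its data is assumed to be float32.
 * \return The flattened values.
 * \note The raw data pointer is reinterpreted as float, so callers must ensure
 * the constant's dtype really is float32.
 */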
static inline std::vector<float> GetFloatVectorFromConstant(const Expr& expr) {
  const auto* n = expr.as<ConstantNode>();
  std::vector<float> vals;
  ICHECK(n) << "Expr must be a constant expr - " << AsText(expr, false);
  int64_t num_elems = 1;
  auto shape = n->data.Shape();
  for (size_t i = 0; i < shape.size(); i++) {
    num_elems *= shape[i];
  }
  for (int64_t i = 0; i < num_elems; i++) {
    vals.push_back(static_cast<float*>(n->data->data)[i]);
  }
  return vals;
}

Expr MakeQnnConv2D(Expr data, Expr weight, Expr input_zero_point, Expr kernel_zero_point,
                   Expr input_scale, Expr kernel_scale, Array<IndexExpr> strides,
                   Array<IndexExpr> padding, Array<IndexExpr> dilation, int groups,
                   IndexExpr channels, Array<IndexExpr> kernel_size, String data_layout,
                   String kernel_layout, String out_layout, DataType out_dtype);

}  // namespace qnn
}  // namespace relay
}  // namespace tvm
#endif  // TVM_RELAY_QNN_UTILS_H_