1/*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19
20/*!
21 * \file src/relay/qnn/utils.cc
22 * \brief Utility functions for QNN.
23 */
24
#include "utils.h"

#include <cmath>
#include <utility>

#include "../transforms/pattern_utils.h"
28
29namespace tvm {
30namespace relay {
31namespace qnn {
32
33std::pair<int32_t, int32_t> GetFixedPointMultiplierShift(double double_multiplier) {
34 int32_t significand, exponent;
35 if (double_multiplier == 0.) {
36 significand = 0;
37 exponent = 0;
38 return std::make_pair(significand, exponent);
39 }
40
41 // Get the significand and exponent.
42 double significand_d = std::frexp(double_multiplier, &exponent);
43
44 // Convert the double significand to int significand, i.e., convert into a
45 // integer where the decimal point is between bit 31 and 30. This is done by
46 // multiplying the double value with 2^31 and then casting to int.
47 significand_d = std::round(significand_d * (1ll << 31));
48 auto significand_int64 = static_cast<int64_t>(significand_d);
49 ICHECK_LE(significand_int64, (1ll << 31));
50 if (significand_int64 == (1ll << 31)) {
51 significand_int64 /= 2;
52 ++exponent;
53 }
54 ICHECK_LE(significand_int64, std::numeric_limits<int32_t>::max());
55 significand = static_cast<int32_t>(significand_int64);
56 return std::make_pair(significand, exponent);
57}
58
59Expr FixedPointMultiplyToNearest(Expr tensor, double multiplier,
60 const Array<IndexExpr>& input_shape) {
61 // Choose high precision datatype to be int64. This is for avoiding overflow
62 // in multiplication of two int32 values.
63 DataType hp_dtype = DataType::Int(64);
64 tensor = Cast(tensor, hp_dtype);
65
66 // 1) Calculating the integer multiplier and integer shift
67 auto [fixed_point_multiplier, shift] = GetFixedPointMultiplierShift(multiplier);
68 int left_shift = shift > 0 ? shift : 0;
69 int right_shift = shift > 0 ? 0 : -shift;
70
71 // 2) Multiply the integer multiplier
72 if (left_shift != 0) {
73 tensor = LeftShift(tensor, MakeConstantScalar(hp_dtype, left_shift));
74 }
75
76 // 3) Perform the multiplication in higher precision.
77 // The scalar is a fixed point value of int32 where the decimal point is
78 // between bits 31 and 30. After multiplying with input_tensor, the result
79 // is in int64 where the decimal point is sitting between bits 31 and 30
80 // (from the right, rightmost bit is bit 0). The computation is performed in
81 // higher precision to avoid overflow in multiplying two int32 values.
82 Expr scalar = MakeConstantScalar(hp_dtype, fixed_point_multiplier);
83 tensor = Multiply(tensor, scalar);
84
85 // 4) Find the rounding scalar. This depends on where the final decimal
86 // point sits. As we will be right shifting the multiplied_t, we need to
87 // first calculate the total_right_shift.
88 int total_right_shift = right_shift + 31;
89 int64_t pos_rounding_value = (1ll << (total_right_shift - 1));
90
91 Expr round_scalar;
92
93 auto pos_rounder = MakeConstantScalar(hp_dtype, pos_rounding_value);
94 auto neg_rounder = MakeConstantScalar(hp_dtype, pos_rounding_value - 1);
95 auto pos_rounder_t = Full(pos_rounder, input_shape, hp_dtype);
96 auto neg_rounder_t = Full(neg_rounder, input_shape, hp_dtype);
97
98 auto zero_t = Zeros(input_shape, hp_dtype);
99 round_scalar = Where(GreaterEqual(tensor, zero_t), pos_rounder_t, neg_rounder_t);
100
101 // Add the rounding scalar.
102 tensor = Add(tensor, round_scalar);
103
104 // 5) Simply right shift the result to get the final output.
105 tensor = RightShift(tensor, MakeConstantScalar(hp_dtype, total_right_shift));
106
107 // 6) The fixed point multiplication keeps the value in int32 range. Casting back to int32.
108 return Cast(tensor, DataType::Int(32));
109}
110
111Expr FixedPointMultiplyPerChannel(Expr tensor, const std::vector<double>& multipliers, int axis) {
112 DataType dtype = DataType::Int(32);
113 int64_t n_channels = static_cast<int64_t>(multipliers.size());
114
115 std::vector<int32_t> fixed_pt_multipliers, lshifts, rshifts;
116 bool is_lshift_required = false, is_rshift_required = false;
117 for (auto multiplier : multipliers) {
118 auto [fixed_pt_multiplier, shift] = GetFixedPointMultiplierShift(multiplier);
119 int lshift = shift > 0 ? shift : 0;
120 int rshift = shift > 0 ? 0 : -shift;
121 fixed_pt_multipliers.push_back(fixed_pt_multiplier);
122 lshifts.push_back(lshift);
123 rshifts.push_back(rshift);
124 is_lshift_required = is_lshift_required | (lshift != 0);
125 is_rshift_required = is_rshift_required | (rshift != 0);
126 }
127
128 auto left_shift_expr = MakeConstantTensor(dtype, {n_channels}, lshifts);
129 auto right_shift_expr = MakeConstantTensor(dtype, {n_channels}, rshifts);
130 auto fixed_pt_multiplier_expr = MakeConstantTensor(dtype, {n_channels}, fixed_pt_multipliers);
131
132 return FixedPointMultiplyPerAxis(tensor, fixed_pt_multiplier_expr, left_shift_expr,
133 right_shift_expr, is_lshift_required, is_rshift_required,
134 {axis});
135}
136
137Expr FixedPointMultiplyPerChannel(Expr tensor, std::vector<double> multipliers,
138 const Array<IndexExpr>& input_shape, int channel_axis,
139 const std::string& rounding) {
140 // Get the n dim. This will be used to expand the multiplier to match the axis.
141 size_t n_dim = input_shape.size();
142
143 // Get the num of channels/axis along which the tensor was quantized.
144 int64_t n_channels = (int64_t)multipliers.size();
145
146 // Choose high precision datatype to be int64. This is for avoiding overflow
147 // in multiplication of two int32 values.
148 DataType hp_dtype = DataType::Int(64);
149 tensor = Cast(tensor, hp_dtype);
150
151 // 1) Calculating the integer multiplier and integer shift. These are calculated per axis/per
152 // channel.
153 std::vector<int32_t> fixed_pt_multipliers, lshifts, rshifts;
154 bool is_lshift_required = false;
155 for (auto multiplier : multipliers) {
156 auto [fixed_pt_multiplier, shift] = GetFixedPointMultiplierShift(multiplier);
157 int lshift = shift > 0 ? shift : 0;
158 int rshift = shift > 0 ? 0 : -shift;
159 fixed_pt_multipliers.push_back(fixed_pt_multiplier);
160 lshifts.push_back(lshift);
161 rshifts.push_back(rshift);
162 is_lshift_required = is_lshift_required | (lshift != 0);
163 }
164
165 // 2) Multiply the integer multiplier. Convert lefts shifts into expr and multiply.
166 if (is_lshift_required) {
167 auto lshift_expr = MakeConstantTensor(hp_dtype, {n_channels}, lshifts);
168 auto exp_lshift_expr = ExpandBiasToMatchAxis(lshift_expr, n_dim, {channel_axis});
169 tensor = LeftShift(tensor, exp_lshift_expr);
170 }
171
172 // 3) Perform the multiplication in higher precision.
173 // The scalar is a fixed point value of int32 where the decimal point is
174 // between bits 31 and 30. After multiplying with input_tensor, the result
175 // is in int64 where the decimal point is sitting between bits 31 and 30
176 // (from the right, rightmost bit is bit 0). The computation is performed in
177 // higher precision to avoid overflow in multiplying two int32 values.
178 auto fixed_pt_multiplier_expr = MakeConstantTensor(hp_dtype, {n_channels}, fixed_pt_multipliers);
179 auto exp_fixed_pt_multiplier_expr =
180 ExpandBiasToMatchAxis(fixed_pt_multiplier_expr, n_dim, {channel_axis});
181 tensor = Multiply(tensor, exp_fixed_pt_multiplier_expr);
182
183 // 4) Find the rounding scalar. This depends on where the final decimal point sits. As we will be
184 // right shifting the multiplied_t, we need to first calculate the total_rshift. Further, we can
185 // calculate the pos and neg rounding offset.
186 std::vector<int64_t> pos_rounding_values, neg_rounding_values, total_rshifts;
187 for (auto rshift : rshifts) {
188 int total_rshift = rshift + 31;
189 total_rshifts.push_back(total_rshift);
190 pos_rounding_values.push_back((1ll << (total_rshift - 1)));
191 neg_rounding_values.push_back((1ll << (total_rshift - 1)) - 1);
192 }
193 // Make a Relay expr from positive and negative rounding offset values.
194 auto pos_rounding_value_expr = MakeConstantTensor(hp_dtype, {n_channels}, pos_rounding_values);
195 auto exp_pos_rounding_value_expr =
196 ExpandBiasToMatchAxis(pos_rounding_value_expr, n_dim, {channel_axis});
197 auto neg_rounding_value_expr = MakeConstantTensor(hp_dtype, {n_channels}, neg_rounding_values);
198 auto exp_neg_rounding_value_expr =
199 ExpandBiasToMatchAxis(neg_rounding_value_expr, n_dim, {channel_axis});
200
201 Expr round_scalar;
202 if (rounding == "UPWARD") {
203 round_scalar = exp_pos_rounding_value_expr;
204 } else if (rounding == "TONEAREST") {
205 // To satisfy where op shape requirements, the rounding values are broadcasted.
206 auto pos_rounder = BroadCastTo(exp_pos_rounding_value_expr, input_shape);
207 auto neg_rounder = BroadCastTo(exp_neg_rounding_value_expr, input_shape);
208
209 auto zero_t = Zeros(input_shape, hp_dtype);
210 round_scalar = Where(GreaterEqual(tensor, zero_t), pos_rounder, neg_rounder);
211 } else {
212 LOG(FATAL) << "Rounding mode " << rounding << " not supported.";
213 }
214 // Add the rounding scalar.
215 tensor = Add(tensor, round_scalar);
216
217 // 5) Simply right shift the result to get the final output.
218 auto total_rshift_expr = MakeConstantTensor(hp_dtype, {n_channels}, total_rshifts);
219 auto exp_total_rshift_expr = ExpandBiasToMatchAxis(total_rshift_expr, n_dim, {channel_axis});
220 tensor = RightShift(tensor, exp_total_rshift_expr);
221
222 // 6) The fixed point multiplication keeps the value in int32 range. Casting back to int32.
223 return Cast(tensor, DataType::Int(32));
224}
225
226Expr FixedPointMultiplyPerChannelToNearest(Expr tensor, std::vector<double> multipliers,
227 const Array<IndexExpr>& input_shape, int channel_axis) {
228 return FixedPointMultiplyPerChannel(tensor, multipliers, input_shape, channel_axis, "TONEAREST");
229}
230
231std::string SelectRequntizeParameter(const std::string& arg_value, const std::string& cfg_value,
232 const bool is_cfg_default, const std::string& name) {
233 if (arg_value == "None") {
234 return cfg_value;
235 } else {
236 if (!is_cfg_default && arg_value != cfg_value) {
237 DLOG(INFO) << "The value of parameter \"" << name
238 << "\" from the non-default requantize config will not be used. The value "
239 "provided from "
240 "requantize function argument will be used instead. The value used is \""
241 << arg_value << "\".";
242 }
243 return arg_value;
244 }
245}
246
247} // namespace qnn
248} // namespace relay
249} // namespace tvm
250