1/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2
3Licensed under the Apache License, Version 2.0 (the "License");
4you may not use this file except in compliance with the License.
5You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9Unless required by applicable law or agreed to in writing, software
10distributed under the License is distributed on an "AS IS" BASIS,
11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12See the License for the specific language governing permissions and
13limitations under the License.
14==============================================================================*/
15
16#include "tensorflow/compiler/mlir/lite/quantization/quantization_utils.h"
17
18#include <algorithm>
19#include <cstdint>
20#include <iterator>
21#include <limits>
22#include <memory>
23#include <numeric>
24#include <string>
25
26#include "llvm/ADT/STLExtras.h"
27#include "llvm/ADT/SmallVector.h"
28#include "llvm/Support/Casting.h"
29#include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project
30#include "mlir/Dialect/Quant/QuantTypes.h" // from @llvm-project
31#include "mlir/IR/Attributes.h" // from @llvm-project
32#include "mlir/IR/BuiltinAttributes.h" // from @llvm-project
33#include "mlir/IR/BuiltinTypes.h" // from @llvm-project
34#include "mlir/IR/Diagnostics.h" // from @llvm-project
35#include "mlir/IR/MLIRContext.h" // from @llvm-project
36#include "mlir/Support/LLVM.h" // from @llvm-project
37#include "mlir/Support/LogicalResult.h" // from @llvm-project
38#include "tensorflow/compiler/mlir/lite/quantization/ir/FakeQuantSupport.h"
39#include "tensorflow/compiler/mlir/lite/quantization/ir/QuantOps.h"
40#include "tensorflow/compiler/mlir/lite/quantization/ir/QuantizeUtils.h"
41#include "tensorflow/compiler/mlir/lite/quantization/ir/UniformSupport.h"
42#include "tensorflow/compiler/mlir/lite/quantization/quantization_traits.h"
43#include "tensorflow/lite/kernels/internal/tensor_utils.h"
44#include "tensorflow/lite/tools/optimize/quantization_utils.h"
45
46namespace mlir {
47
48// This includes the interface class definition. It couldn't be in a namespace
49// because the table gen doesn't emit the namespace when it is used.
50#include "tensorflow/compiler/mlir/lite/quantization/quantization_interface.cc.inc"
51
52namespace quant {
53
54namespace {
// Half of `kNearZeroTolerance`. ExpandVerySmallRange uses it as the bound that
// a near-zero-width range is widened to on each side of 0.0.
constexpr double kSmallestHalfRange = kNearZeroTolerance / 2;
// Shorthand for the quantized type base class, used for its static helpers.
using QType = quant::QuantizedType;
57
58// This method expands the range to be larger than or equal to 1.0e-6, if it is
59// very small (< 1.0e-6). This is to prevent very large quantized value by this
60// range.
61void ExpandVerySmallRange(ArrayRef<double> mins, ArrayRef<double> maxs,
62 SmallVectorImpl<double>* effective_mins,
63 SmallVectorImpl<double>* effective_maxs) {
64 for (auto arg : llvm::zip(mins, maxs)) {
65 double min = std::get<0>(arg);
66 double max = std::get<1>(arg);
67 // The range is wide, then use the same min/max.
68 if ((max - min) > kNearZeroTolerance) {
69 effective_mins->push_back(min);
70 effective_maxs->push_back(max);
71 continue;
72 }
73
74 // The range is small. Expands the range to stride 0.0 and also at least
75 // 1.0e-6.
76 effective_mins->push_back(std::min(min, -kSmallestHalfRange));
77 effective_maxs->push_back(std::max(max, kSmallestHalfRange));
78 }
79}
80
81// Set the min / max, scale and zero_points from the fake quant num_bits
82// attribute from QAT.
83QuantizedType ResetMinMaxFromNumBits(QuantizedType type, int num_bits,
84 bool narrow_range, bool is_signed) {
85 if (num_bits >= 8) {
86 return type;
87 }
88 int64_t qmin = QType::getDefaultMinimumForInteger(is_signed, num_bits);
89 int64_t qmax = QType::getDefaultMaximumForInteger(is_signed, num_bits);
90 if (narrow_range) {
91 qmin += 1;
92 }
93 const int64_t storage_type_min = type.getStorageTypeMin();
94 const int64_t storage_type_max = type.getStorageTypeMax();
95 const double rate =
96 static_cast<double>(storage_type_max - storage_type_min) / (qmax - qmin);
97 const auto& recalculate_scale = [&](double scale) -> double {
98 return scale * rate;
99 };
100 const auto& recalculate_zero_point = [&](int64_t zero_point) -> int64_t {
101 return qmax - std::round((storage_type_max - zero_point) / rate);
102 };
103 if (auto q_type = type.dyn_cast<UniformQuantizedType>()) {
104 const double scale = recalculate_scale(q_type.getScale());
105 const double zero_point = recalculate_zero_point(q_type.getZeroPoint());
106 return UniformQuantizedType::get(q_type.getFlags(), q_type.getStorageType(),
107 q_type.getExpressedType(), scale,
108 zero_point, qmin, qmax);
109 } else if (auto q_type =
110 type.dyn_cast<quant::UniformQuantizedPerAxisType>()) {
111 const int size = q_type.getScales().size();
112 SmallVector<double, 4> scales(size);
113 SmallVector<int64_t, 4> zero_points(size);
114 for (int i = 0; i < size; ++i) {
115 scales[i] = recalculate_scale(q_type.getScales()[i]);
116 zero_points[i] = recalculate_zero_point(q_type.getZeroPoints()[i]);
117 }
118 return quant::UniformQuantizedPerAxisType::get(
119 q_type.getFlags(), q_type.getStorageType(), q_type.getExpressedType(),
120 scales, zero_points, q_type.getQuantizedDimension(), qmin, qmax);
121 } else {
122 llvm_unreachable("Unsupported QuantizedType in ResetMinMaxFromNumBits");
123 }
124 return type;
125}
126
127// Repeats the content of `data` multiple times to resize to `target_size`.
128// Note that this only broadcast across one dimension.
129template <typename T>
130bool BroadcastVector(int target_size, SmallVectorImpl<T>& data) {
131 int size = data.size();
132 if (size != target_size) {
133 if (target_size % size != 0) return true;
134 data.reserve(target_size);
135 for (int i = 1, e = target_size / size; i != e; ++i) {
136 data.insert(data.end(), data.begin(), data.begin() + size);
137 }
138 }
139 return false;
140}
141
142// Changes the axis of the input per-channel quantized type to match the
143// dimension of the target type. Returns nullptr if it fails.
144quant::UniformQuantizedPerAxisType ResetAxisAndBroadcast(
145 ArrayRef<int64_t> shape, quant::UniformQuantizedPerAxisType qtype,
146 Type target, int quant_dim) {
147 auto shaped = target.dyn_cast<RankedTensorType>();
148 if (!shaped) return {};
149 ArrayRef<int64_t> new_shape = shaped.getShape();
150
151 SmallVector<double, 4> scales(qtype.getScales().begin(),
152 qtype.getScales().end());
153 SmallVector<int64_t, 4> zero_points(qtype.getZeroPoints().begin(),
154 qtype.getZeroPoints().end());
155
156 if (new_shape.size() == shape.size()) { // same rank
157 // Broadcast the scales and zero points to match the target size, which is
158 // usually the axis-th dimension of the target type. Currently, it covers
159 // two cases:
160 // - for Transpose, the data layout is changed so the `dim[axis]` still
161 // equals to the `scales_size`. The broadcast skips;
162 // - for Reshape, the data layout isn't changed but the innermost dimension
163 // is expand to cover the last two original dimensions. Thus we just need to
164 // be repeated the `scales` dim[2] times to covers the new dim length.
165 //
166 // TODO(b/141709944): after the fix, the `scales` can be for dim[2], thus we
167 // have to repeat each elements in the `scales` locally dim[3] times.
168 if (BroadcastVector<double>(shaped.getDimSize(quant_dim), scales) ||
169 BroadcastVector<int64_t>(shaped.getDimSize(quant_dim), zero_points)) {
170 return {};
171 }
172 } else if ((new_shape.size() == shape.size() + 1) && new_shape.front() == 1) {
173 // Handle the [A, B, C] -> [1, A, B, C] reshape case.
174 if (!(std::equal(shape.begin(), shape.end(), new_shape.begin() + 1) &&
175 quant_dim == new_shape.size() - 1)) {
176 return {};
177 }
178 } else {
179 return {};
180 }
181
182 return quant::UniformQuantizedPerAxisType::get(
183 qtype.getFlags(), qtype.getStorageType(), qtype.getExpressedType(),
184 scales, zero_points, quant_dim, qtype.getStorageTypeMin(),
185 qtype.getStorageTypeMax());
186}
187
188} // namespace
189
190bool IsOpNotQuantizable(Operation* op) {
191 // If it is terminator or not quantizable or any ops form the mlir quant
192 // ops dialect, we shouldn't rewrite.
193 bool attr_enforced_quantizable =
194 op->hasAttrOfType<StringAttr>(kQuantTraitAttrName) &&
195 op->getAttrOfType<StringAttr>(kQuantTraitAttrName).getValue().str() ==
196 QuantTraitValues[QuantizationTrait::FullyQuantizable];
197
198 // Constant ops do not have QuantizableResult attribute but they can deal with
199 // quantized tensors.
200 if (llvm::isa<func::ConstantOp, arith::ConstantOp, quantfork::StatisticsOp>(
201 op))
202 return false;
203
204 bool prop_enforced_quantizable =
205 op->hasTrait<OpTrait::quant::QuantizableResult>();
206
207 return op->hasTrait<OpTrait::IsTerminator>() ||
208 llvm::isa<quantfork::QuantizeCastOp, quantfork::DequantizeCastOp>(
209 op) ||
210 (!attr_enforced_quantizable && !prop_enforced_quantizable);
211}
212
213// Returns the quantized type for the
214// input_type/min/max/storag_type_width/narrow_range.
215// This is entry point to the Quant dialect and used for both quantizing
216// activations and weights.
217Type GetQuantizedType(Builder builder, Type input_type, ArrayRef<double> min,
218 ArrayRef<double> max, int quant_dim,
219 int storage_type_width, bool narrow_range, bool is_signed,
220 bool legacy_float_scale, bool use_fake_quant_num_bits) {
221 auto converter =
222 quantfork::ExpressedToQuantizedConverter::forInputType(input_type);
223
224 // Expand the range to prevent extremely small scales and large quantized
225 // integers which can cause overflow. This leads to scale
226 // 7.843137254901961e-9 with 8 bits.
227 SmallVector<double, 4> effective_mins, effective_maxs;
228 ExpandVerySmallRange(min, max, &effective_mins, &effective_maxs);
229
230 quant::QuantizedType quantizedEleType;
231 if (min.size() == 1 && max.size() == 1 && quant_dim == -1) {
232 quantizedEleType = quantfork::fakeQuantAttrsToType(
233 builder.getUnknownLoc(), storage_type_width, effective_mins[0],
234 effective_maxs[0], narrow_range, converter.expressedType, is_signed);
235 if (legacy_float_scale) {
236 quantizedEleType =
237 DownCastScale(quantizedEleType, effective_mins[0], effective_maxs[0],
238 builder.getUnknownLoc());
239 }
240 } else if (min.size() == max.size()) {
241 auto shape = input_type.dyn_cast<ShapedType>();
242 if (!shape || shape.getRank() <= quant_dim ||
243 static_cast<int64_t>(min.size()) != shape.getDimSize(quant_dim)) {
244 return {};
245 }
246 // The quantization dim is set to the last dimension.
247 quantizedEleType = quantfork::fakeQuantAttrsToType(
248 builder.getUnknownLoc(), storage_type_width, quant_dim, effective_mins,
249 effective_maxs, narrow_range, converter.expressedType, is_signed);
250 if (legacy_float_scale) {
251 quantizedEleType = DownCastScale(quantizedEleType, effective_mins,
252 effective_maxs, builder.getUnknownLoc());
253 }
254 }
255 if (!quantizedEleType) return {};
256 // Use fake quant configured bit-widths (only supported for
257 // 1 < num_bits < 8 bits) instead of using 8bit defaults.
258 if (use_fake_quant_num_bits && (storage_type_width > 1) &&
259 (storage_type_width < 8) &&
260 (quantizedEleType.getStorageTypeMax() >
261 QType::getDefaultMinimumForInteger(is_signed, storage_type_width))) {
262 auto resetEleType = ResetMinMaxFromNumBits(
263 quantizedEleType, storage_type_width, narrow_range, is_signed);
264 return converter.convert(resetEleType);
265 }
266 return converter.convert(quantizedEleType);
267}
268
269// TODO(fengliuai): promote this utility method to mlir QuantOps.
270TypeAttr RescaleQuantizedType(Type input, Attribute factor) {
271 auto factor_values = factor.dyn_cast_or_null<DenseFPElementsAttr>();
272 if (!factor_values) return {};
273 auto ele_type = quant::QuantizedType::getQuantizedElementType(input);
274 if (!ele_type) return {};
275 if (auto qtype = ele_type.dyn_cast<quant::UniformQuantizedPerAxisType>()) {
276 ArrayRef<double> scales = qtype.getScales();
277 // Broadcasting hasn't been implemented yet.
278 if (static_cast<int64_t>(scales.size()) != factor_values.getNumElements())
279 return {};
280 SmallVector<double, 4> new_scales;
281 new_scales.reserve(scales.size());
282 auto scales_iter = scales.begin();
283 for (const auto& f : factor_values) {
284 new_scales.push_back(*(scales_iter++) *
285 std::fabs(FloatAttr::getValueAsDouble(f)));
286 }
287 // We are assuming symmetric quantization.
288 auto new_ele_type = quant::UniformQuantizedPerAxisType::get(
289 qtype.getFlags(), qtype.getStorageType(), qtype.getExpressedType(),
290 new_scales, qtype.getZeroPoints(), qtype.getQuantizedDimension(),
291 qtype.getStorageTypeMin(), qtype.getStorageTypeMax());
292 if (auto new_type = new_ele_type.castFromExpressedType(
293 quant::QuantizedType::castToExpressedType(input))) {
294 return TypeAttr::get(new_type);
295 }
296 }
297 // Currently, we only support per-axis quantized type.
298 return {};
299}
300
301TypeAttr GetQuantizedTypeAttr(Builder builder, Type input_type, Attribute min,
302 Attribute max, int quant_dim,
303 IntegerAttr num_bits, BoolAttr narrow_range,
304 bool is_signed, bool legacy_float_scale,
305 bool use_fake_quant_num_bits) {
306 SmallVector<double, 4> min_value, max_value;
307 auto mins = min.dyn_cast<DenseFPElementsAttr>();
308 auto maxs = max.dyn_cast<DenseFPElementsAttr>();
309 if (mins && maxs) {
310 min_value.reserve(mins.getNumElements());
311 max_value.reserve(maxs.getNumElements());
312 for (auto it = mins.begin(), e = mins.end(); it != e; ++it) {
313 min_value.push_back(FloatAttr::getValueAsDouble(*it));
314 }
315 for (auto it = maxs.begin(), e = maxs.end(); it != e; ++it) {
316 max_value.push_back(FloatAttr::getValueAsDouble(*it));
317 }
318 } else {
319 auto fmin = min.dyn_cast<FloatAttr>();
320 auto fmax = max.dyn_cast<FloatAttr>();
321 if (fmin && fmax) {
322 min_value.push_back(fmin.getValueAsDouble());
323 max_value.push_back(fmax.getValueAsDouble());
324 } else {
325 return {};
326 }
327 }
328 Type final_type =
329 GetQuantizedType(builder, input_type, min_value, max_value, quant_dim,
330 num_bits.getInt(), narrow_range.getValue(), is_signed,
331 legacy_float_scale, use_fake_quant_num_bits);
332 if (!final_type) return {};
333 return TypeAttr::get(final_type);
334}
335
336TypeAttr CastQuantizedTypeAttrFromExpressedType(Builder builder,
337 TypeAttr source, Type target,
338 int axis) {
339 auto source_type = source.getValue().dyn_cast_or_null<ShapedType>();
340 if (!source_type) return {};
341 auto src_ele_type = source_type.getElementType();
342 auto qtype = src_ele_type.dyn_cast<quant::QuantizedType>();
343
344 // Reset the quantization dimensions if it is per-axis.
345 if (auto per_axis =
346 qtype.dyn_cast_or_null<quant::UniformQuantizedPerAxisType>()) {
347 // For the pass-through ops, we don't know which the dimension will be the
348 // new quantization dimension. Only if the new quantization dimension can
349 // be inferred, it is safe to reset the per-axis quantized type.
350 if (axis == -1) return {};
351 qtype =
352 ResetAxisAndBroadcast(source_type.getShape(), per_axis, target, axis);
353 }
354 if (!qtype) return {};
355 Type final_type = qtype.castFromExpressedType(target);
356 if (!final_type) return {};
357 return TypeAttr::get(final_type);
358}
359
360void ExtractMinMaxFromAttr(DenseFPElementsAttr values, int dim_size,
361 int slice_size, bool symmetric,
362 SmallVectorImpl<double>& mins,
363 SmallVectorImpl<double>& maxs) {
364 // If all the element values are same we don't need to scan the content.
365 if (values.isSplat()) {
366 double single_value =
367 FloatAttr::getValueAsDouble(values.getSplatValue<llvm::APFloat>());
368
369 // When the single value isn't 0.0, we expand it to a range to include
370 // this single value and 0.0. This will give us a scale and zero point
371 // works for both this value and 0.0.
372 if (single_value < 0.0) {
373 mins[0] = single_value;
374 maxs[0] = symmetric ? -single_value : 0.0;
375 } else if (single_value > 0.0) {
376 mins[0] = symmetric ? -single_value : 0.0;
377 maxs[0] = single_value;
378 } else {
379 mins[0] = maxs[0] = single_value;
380 }
381 for (int i = 1; i < dim_size; ++i) {
382 mins[i] = mins[0];
383 maxs[i] = maxs[0];
384 }
385 } else {
386 int64_t flatten_index = 0;
387 for (auto it = values.begin(), e = values.end(); it != e;
388 ++it, ++flatten_index) {
389 double ele_value = FloatAttr::getValueAsDouble(*it);
390 int slice_index = flatten_index / slice_size;
391 int channel_index = slice_index % dim_size;
392 mins[channel_index] = std::min(mins[channel_index], ele_value);
393 maxs[channel_index] = std::max(maxs[channel_index], ele_value);
394 }
395 // Expand range to include 0.
396 for (int i = 0; i < dim_size; ++i) {
397 maxs[i] = std::max(maxs[i], 0.0);
398 mins[i] = std::min(mins[i], 0.0);
399 }
400 if (symmetric) {
401 for (int i = 0; i < dim_size; ++i) {
402 maxs[i] = std::max(std::abs(mins[i]), std::abs(maxs[i]));
403 mins[i] = -maxs[i];
404 }
405 }
406 }
407}
408
409Type GetUniformQuantizedTypeForWeight(ElementsAttr attr, bool symmetric,
410 unsigned num_bits, bool is_signed,
411 bool narrow_range,
412 bool legacy_float_scale,
413 bool use_fake_quant_num_bits) {
414 Builder builder(attr.getContext());
415 // `symmetric` can only be used when it is `signed` and `narrow_range`.
416 if (symmetric && (!is_signed || !narrow_range)) return {};
417
418 SmallVector<double, 4> mins(1, std::numeric_limits<double>::max());
419 SmallVector<double, 4> maxs(1, std::numeric_limits<double>::min());
420 auto fp = attr.dyn_cast<DenseFPElementsAttr>();
421 if (!fp) return {};
422
423 // Computes the effective min/max values of the attribute values.
424 ExtractMinMaxFromAttr(fp, /*dim_size=*/1, /*slice_size=*/1, symmetric, mins,
425 maxs);
426
427 auto type =
428 GetQuantizedType(builder, attr.getType(), mins[0], maxs[0],
429 /*quant_dim=*/-1, num_bits, narrow_range, is_signed,
430 legacy_float_scale, use_fake_quant_num_bits);
431 if (auto ele_type = type.dyn_cast_or_null<TensorType>())
432 return ele_type.getElementType();
433
434 return {};
435}
436
437Type GetUniformQuantizedPerAxisTypeForWeight(ElementsAttr attr, int quant_dim,
438 bool symmetric, unsigned num_bits,
439 bool is_signed, bool narrow_range,
440 bool legacy_float_scale,
441 bool use_fake_quant_num_bits) {
442 Builder builder(attr.getContext());
443 auto shape = attr.getType().cast<ShapedType>().getShape();
444 if (static_cast<int>(shape.size()) <= quant_dim) return {};
445 // `symmetric` can only be used when it is `signed` and `narrow_range`.
446 if (symmetric && (!is_signed || !narrow_range)) return {};
447
448 int dim_size = shape[quant_dim];
449 int slice_size = std::accumulate(std::next(shape.begin(), quant_dim + 1),
450 shape.end(), 1, std::multiplies<int64_t>());
451 SmallVector<double, 4> mins(dim_size, std::numeric_limits<double>::max());
452 SmallVector<double, 4> maxs(dim_size, std::numeric_limits<double>::min());
453 auto fp = attr.dyn_cast<DenseFPElementsAttr>();
454 if (!fp) return {};
455
456 // Computes the effective min/max values of the attribute values.
457 ExtractMinMaxFromAttr(fp, dim_size, slice_size, symmetric, mins, maxs);
458
459 auto type = GetQuantizedType(builder, attr.getType(), mins, maxs, quant_dim,
460 num_bits, narrow_range, is_signed,
461 legacy_float_scale, use_fake_quant_num_bits);
462 if (auto ele_type = type.dyn_cast_or_null<TensorType>())
463 return ele_type.getElementType();
464
465 return {};
466}
467
468quant::QuantizedType GetUniformQuantizedTypeForBias(
469 const std::vector<quant::QuantizedType>& op_types,
470 bool legacy_float_scale) {
471 if (op_types.empty()) return {};
472
473 size_t axis_size = 1;
474 int32_t quant_dim = -1;
475 Type expressed_type;
476 // Requires all the op types are valid UniformQuantizedTypes or
477 // UniformQuantizedPerAxisTypes and also have same expressed type. For all
478 // the UniformQuantizedPerAxisTypes, the quantization dimension index and
479 // dimension sizes are same.
480 for (auto op_type : op_types) {
481 if (!op_type) return {};
482 if (expressed_type && expressed_type != op_type.getExpressedType()) {
483 return {};
484 }
485 expressed_type = op_type.getExpressedType();
486
487 if (auto type = op_type.dyn_cast<quant::UniformQuantizedPerAxisType>()) {
488 if ((axis_size != 1 && axis_size != type.getScales().size())) return {};
489 if (quant_dim != -1 && quant_dim != type.getQuantizedDimension())
490 return {};
491 axis_size = type.getScales().size();
492 quant_dim = type.getQuantizedDimension();
493 } else if (!op_type.isa<quant::UniformQuantizedType>()) {
494 return {};
495 }
496 }
497
498 // The scale from the UniformQuantizedTypes is broadcasted if there are
499 // UniformQuantizedPerAxisTypes.
500 llvm::SmallVector<double, 4> scales(axis_size, 1.0);
501 for (auto op_type : op_types) {
502 if (auto type = op_type.dyn_cast<quant::UniformQuantizedPerAxisType>()) {
503 for (const auto& index_scale : llvm::enumerate(type.getScales())) {
504 scales[index_scale.index()] *= index_scale.value();
505 }
506 } else if (auto type = op_type.dyn_cast<quant::UniformQuantizedType>()) {
507 for (int index = 0, e = axis_size; index != e; ++index) {
508 scales[index] *= type.getScale();
509 }
510 }
511 }
512 if (legacy_float_scale) {
513 for (int i = 0; i < scales.size(); ++i) {
514 scales[i] = static_cast<float>(scales[i]);
515 }
516 }
517
518 // Builds the result quantized type, which has signed 32 bits storage type.
519 Builder builder(expressed_type.getContext());
520 IntegerType storage_type = builder.getIntegerType(32);
521 int64_t storage_type_min =
522 quant::QuantizedType::getDefaultMinimumForInteger(/*isSigned=*/true, 32);
523 int64_t storage_type_max =
524 quant::QuantizedType::getDefaultMaximumForInteger(/*isSigned=*/true, 32);
525 if (axis_size == 1) {
526 return quant::UniformQuantizedType::getChecked(
527 builder.getUnknownLoc(),
528 /*flags=*/true, storage_type, expressed_type, scales[0],
529 /*zeroPoint=*/0, storage_type_min, storage_type_max);
530 } else {
531 llvm::SmallVector<int64_t, 4> zero_points(axis_size, 0);
532 // Assume the bias is a 1-D tensor, and set the quantization dim to the last
533 // dimension, which is 0. If the bias rank is larger than 1, this returned
534 // quantized type couldn't be used to quantize the bias.
535 return quant::UniformQuantizedPerAxisType::getChecked(
536 builder.getUnknownLoc(),
537 /*flags=*/true, storage_type, expressed_type, scales, zero_points,
538 /*quantizedDimension=*/0, storage_type_min, storage_type_max);
539 }
540}
541
// Quantizes the float constant `real_value` to the storage type of the
// quantized element type of `tensor_type`, using the TFLite legacy quantizer
// helpers, and returns the result as a dense attribute of storage values.
//
// The path is chosen by storage bit width:
// - 8 bits with storage range [-127, 127]: TFLite symmetric quantization
//   (per-tensor or per-channel).
// - any other 8-bit range: delegates to Quantize().
// - 16 bits: per-tensor only.
// - 32 bits: bias quantization, per-tensor or per-channel.
// Returns a null attribute when `real_value` is not a DenseFPElementsAttr,
// when `tensor_type` has no quantized element type, or when the width /
// quantized type combination is not handled.
ElementsAttr QuantizeLegacy(Attribute real_value, Type tensor_type) {
  if (!real_value.isa<DenseFPElementsAttr>() ||
      !quant::QuantizedType::getQuantizedElementType(tensor_type)) {
    return {};
  }
  auto real_values_attr = real_value.cast<DenseFPElementsAttr>();
  auto q_type = quant::QuantizedType::getQuantizedElementType(tensor_type);
  std::vector<float> real_values;
  llvm::SmallVector<APInt, 8> quantized_attr;
  real_values.reserve(real_values_attr.getNumElements());
  quantized_attr.reserve(real_values_attr.getNumElements());
  // Flatten the attribute into a plain float buffer for the TFLite helpers.
  std::transform(real_values_attr.begin(), real_values_attr.end(),
                 std::back_inserter(real_values), [&](APFloat value) -> float {
                   return value.convertToFloat();
                 });
  // Same shape as the input, but with the storage type as element type.
  ShapedType new_dense_type =
      q_type.castExpressedToStorageType(real_values_attr.getType())
          .dyn_cast_or_null<ShapedType>();
  // NOTE(review): the dyn_cast result is used without a null check, so this
  // assumes the storage type is always an IntegerType - confirm.
  int width = q_type.getStorageType().dyn_cast<mlir::IntegerType>().getWidth();

  if (width == 8 && q_type.getStorageTypeMax() == 127 &&
      q_type.getStorageTypeMin() == -127) {
    std::vector<int8_t> quantized_values(real_values_attr.getNumElements());
    if (auto uniform_type = q_type.dyn_cast<UniformQuantizedType>()) {
      float min, max, scale;
      tflite::tensor_utils::SymmetricQuantizeFloats(
          real_values.data(), real_values.size(), quantized_values.data(), &min,
          &max, &scale);
      // The scale has been adjusted, so the adjusted scale should be respected.
      // If the helper's scale differs from the type's scale by more than 1e-3,
      // fall back to the generic quantizer, which uses the type's parameters.
      if (std::abs(scale - uniform_type.getScale()) > 1e-3) {
        return Quantize(real_value, tensor_type);
      }
    } else if (auto uniform_type =
                   q_type.dyn_cast<quant::UniformQuantizedPerAxisType>()) {
      std::vector<float> scales_inv;
      std::vector<int32_t> dimension;
      dimension.insert(dimension.end(), new_dense_type.getShape().begin(),
                       new_dense_type.getShape().end());
      // The per-channel helper takes inverse scales. Each scale is narrowed to
      // float by the lambda parameter before being inverted.
      std::transform(uniform_type.getScales().begin(),
                     uniform_type.getScales().end(),
                     std::back_inserter(scales_inv),
                     [](float scale) { return 1.0 / scale; });

      tflite::optimize::utils::SymmetricPerChannelQuantizeValues(
          real_values.data(), scales_inv, dimension,
          uniform_type.getQuantizedDimension(), &quantized_values);
    } else {
      return {};
    }
    // Wrap the int8 results as signed 8-bit APInts for the dense attribute.
    std::transform(quantized_values.begin(), quantized_values.end(),
                   std::back_inserter(quantized_attr),
                   [&](int8_t value) -> APInt {
                     return APInt(8, value, /*isSigned=*/true);
                   });
    return DenseElementsAttr::get(new_dense_type, quantized_attr);
  } else if (width == 8) {
    // This can be a state tensor, or an actual constant tensor with
    // asymmetric range. For a state tensor, assigning correct quantization
    // parameters is sufficient, and for constants with asymmetric range it's
    // not correctly quantized by legacy quantizer so call the new Quantize.
    return Quantize(real_value, tensor_type);
  } else if (width == 16) {
    // Only per-tensor types are handled here; a per-axis 16-bit type falls
    // through to the final `return {}`.
    if (auto uniform_type = q_type.dyn_cast<UniformQuantizedType>()) {
      auto quantized_values =
          tflite::optimize::utils::SymmetricQuantizeFloatsToInt16(
              real_values.data(), real_values.size(), uniform_type.getScale());
      std::transform(quantized_values.begin(), quantized_values.end(),
                     std::back_inserter(quantized_attr),
                     [&](int16_t value) -> APInt {
                       return APInt(16, value, /*isSigned=*/true);
                     });
      return DenseElementsAttr::get(new_dense_type, quantized_attr);
    }
  } else if (width == 32) {
    // Collect one scale (per-tensor) or all scales (per-axis), narrowed to
    // float, for the bias quantization helper.
    std::vector<float> scales;
    if (auto uniform_type = q_type.dyn_cast<UniformQuantizedType>()) {
      scales.push_back(uniform_type.getScale());
    } else if (auto uniform_type =
                   q_type.dyn_cast<quant::UniformQuantizedPerAxisType>()) {
      scales.insert(scales.end(), uniform_type.getScales().begin(),
                    uniform_type.getScales().end());
    } else {
      return {};
    }
    auto quantized_bias =
        tflite::optimize::utils::SymmetricBiasQuantize<std::int32_t>(
            real_values.data(), real_values.size(), scales);
    std::transform(quantized_bias.begin(), quantized_bias.end(),
                   std::back_inserter(quantized_attr),
                   [&](int32_t value) -> APInt {
                     return APInt(32, value, /*isSigned=*/true);
                   });
    return DenseElementsAttr::get(new_dense_type, quantized_attr);
  }
  // Unsupported storage width, or a 16-bit non-per-tensor type.
  return {};
}
638
639ElementsAttr Quantize(Attribute real_value, Type tensor_type) {
640 if (auto q_type =
641 quant::QuantizedType::getQuantizedElementType(tensor_type)) {
642 Type converted_type;
643 return quantfork::quantizeAttr(real_value, q_type, converted_type)
644 .dyn_cast<ElementsAttr>();
645 }
646 return {};
647}
648
649quant::QuantizedType DownCastScale(QuantizedType type, double min, double max,
650 Location loc) {
651 SmallVector<double, 1> mins = {min};
652 SmallVector<double, 1> maxs = {max};
653 return DownCastScale(type, mins, maxs, loc);
654}
655
656quant::QuantizedType DownCastScale(QuantizedType type,
657 const SmallVectorImpl<double>& mins,
658 const SmallVectorImpl<double>& maxs,
659 Location loc) {
660 // The given type can be null. For example, there can be an invalid scale and
661 // so on.
662 if (!type) return type;
663 SmallVector<double, 4> scales(mins.size());
664 SmallVector<int64_t, 4> zero_points(mins.size());
665 if (auto q_type = type.dyn_cast<UniformQuantizedType>()) {
666 zero_points.push_back(q_type.getZeroPoint());
667 } else if (auto q_type =
668 type.dyn_cast<quant::UniformQuantizedPerAxisType>()) {
669 zero_points = {q_type.getZeroPoints().begin(),
670 q_type.getZeroPoints().end()};
671 }
672 for (int i = 0; i < mins.size(); ++i) {
673 scales[i] = (static_cast<float>(maxs[i]) - static_cast<float>(mins[i])) /
674 (type.getStorageTypeMax() - type.getStorageTypeMin());
675 if (type.getStorageTypeMax() != -type.getStorageTypeMin()) {
676 // Only applies for asymmetric quantized range with original scale.
677 float zero_point_from_min =
678 type.getStorageTypeMin() - mins[i] / scales[i];
679 if (zero_point_from_min < type.getStorageTypeMin()) {
680 zero_points[i] = static_cast<int64_t>(type.getStorageTypeMin());
681 } else if (zero_point_from_min > type.getStorageTypeMax()) {
682 zero_points[i] = static_cast<int64_t>(type.getStorageTypeMax());
683 } else {
684 zero_points[i] = static_cast<int64_t>(std::round(zero_point_from_min));
685 }
686 }
687 }
688 if (auto q_type = type.dyn_cast<UniformQuantizedType>()) {
689 return UniformQuantizedType::get(q_type.getFlags(), q_type.getStorageType(),
690 q_type.getExpressedType(), scales[0],
691 zero_points[0], q_type.getStorageTypeMin(),
692 q_type.getStorageTypeMax());
693 } else if (auto q_type =
694 type.dyn_cast<quant::UniformQuantizedPerAxisType>()) {
695 return quant::UniformQuantizedPerAxisType::get(
696 q_type.getFlags(), q_type.getStorageType(), q_type.getExpressedType(),
697 scales, zero_points, q_type.getQuantizedDimension(),
698 q_type.getStorageTypeMin(), q_type.getStorageTypeMax());
699 }
700 return type;
701}
702
703// A heuristic to determine whether the scales needs to be from operands or
704// from results for the ops with the `SameOperandsAndResultsScale` property.
705// The current implementation is based on the number of operands.
706static bool PreferResultScale(Operation* op) {
707 int float_operands = 0;
708 for (auto operand : op->getOperands()) {
709 if (auto operand_type = operand.getType().dyn_cast<ShapedType>()) {
710 if (operand_type.getElementType().isa<FloatType>()) {
711 if (++float_operands > 1) return true;
712 }
713 }
714 }
715 return false;
716}
717
718std::unique_ptr<OpQuantScaleSpec> GetDefaultQuantScaleSpec(Operation* op) {
719 auto spec = std::make_unique<OpQuantScaleSpec>();
720 if (llvm::isa<SameScalesOpInterface>(op)) {
721 spec->has_same_scale_requirement = true;
722 spec->required_same_scale_func = [op](bool sign, int bit_width) {
723 return llvm::cast<SameScalesOpInterface>(op)
724 .RequiredSameOperandsAndResultsScale(sign, bit_width);
725 };
726 spec->required_same_quantized_axes_func = [op]() {
727 return llvm::cast<SameScalesOpInterface>(op).RequiredSameQuantizedAxes();
728 };
729 }
730 if (llvm::isa<FixedOutputRangeInterface>(op)) {
731 spec->has_fixed_output_range = true;
732 spec->fixed_output_range_func = [op](bool sign, int bit_width) {
733 return llvm::cast<FixedOutputRangeInterface>(op).GetFixedOutputRange(
734 sign, bit_width);
735 };
736 }
737 return spec;
738}
739
740// The stats op of some of the ops can be redundant. The current implementation
741// only considers the ops with restricted output params.
742static bool IsStatsRedundant(
743 Operation* op, OpQuantSpecGetter op_quant_spec_getter,
744 OpQuantScaleSpecGetter op_quant_scale_spec_getter) {
745 // If it has FixedOutputRangeInterface, no need to manually create spec.
746 return llvm::isa<FixedOutputRangeInterface>(op) ||
747 op_quant_scale_spec_getter(op)->has_fixed_output_range;
748}
749
750static bool IsSameScaleOp(Operation* op,
751 OpQuantScaleSpecGetter op_quant_scale_spec_getter) {
752 // If it has SameScalesOpInterface, no need to manually create spec.
753 return llvm::dyn_cast<SameScalesOpInterface>(op) ||
754 op_quant_scale_spec_getter(op)->has_same_scale_requirement;
755}
756
757bool RemoveRedundantStatsOps(
758 mlir::func::FuncOp func, OpQuantSpecGetter op_quant_spec_getter,
759 OpQuantScaleSpecGetter op_quant_scale_spec_getter) {
760 llvm::SmallVector<quantfork::StatisticsOp, 16> all_stats_ops;
761 llvm::DenseSet<Operation*> redundant_stats_ops;
762
763 // Step 0: remove the quantfork::StatisticsOp which are used by the
764 // quant.qcast op in case it overrides the information from training FakeQuant
765 // ops.
766 func.walk([&](quantfork::QuantizeCastOp q) {
767 auto input_op = q.getArg().getDefiningOp();
768 if (auto stats =
769 llvm::dyn_cast_or_null<quantfork::StatisticsOp>(input_op)) {
770 q.setOperand(stats.getArg());
771 if (stats.use_empty()) stats.erase();
772 }
773 });
774
775 // Step 1: forward pass: propagate any value scales which are not produces
776 // by `SameOperandsAndResultsScale`. Additionally, remove the value scales
777 // which are produced by the ops with the `FixedOutputRangeInterface`.
778 // Note that we don't propagate across the multiple-operands
779 // `SameOperandsAndResultsScale` ops like `concatenation`.
780 func.walk([&](quantfork::StatisticsOp stats_op) {
781 all_stats_ops.push_back(stats_op);
782 });
783
784 while (!all_stats_ops.empty()) {
785 quantfork::StatisticsOp stats_op = all_stats_ops.back();
786 all_stats_ops.pop_back();
787
788 if (auto def = stats_op.getArg().getDefiningOp()) {
789 if (IsStatsRedundant(def, op_quant_spec_getter,
790 op_quant_scale_spec_getter)) {
791 redundant_stats_ops.insert(stats_op);
792 }
793 }
794
795 for (auto user : stats_op.getResult().getUsers()) {
796 // We don't propagate this parameter down if it has multiple operands.
797 // We want to use the result parameter scales instead.
798 if (!IsSameScaleOp(user, op_quant_scale_spec_getter) ||
799 PreferResultScale(user)) {
800 continue;
801 }
802 for (Value res : user->getResults()) {
803 if (!res.hasOneUse()) {
804 continue;
805 }
806 if (auto next_stats = llvm::dyn_cast<quantfork::StatisticsOp>(
807 *res.getUsers().begin())) {
808 // quantization parameters can be propagated to next_stats
809 redundant_stats_ops.insert(next_stats);
810 // add next_stats to the work list so propagation can continue.
811 all_stats_ops.push_back(next_stats);
812 }
813 }
814 }
815 }
816
817 // Step 2: backward pass: For the ops skiped in the forward pass, propagate
818 // its results scale backwards as far as possible.
819 func.walk([&](quantfork::StatisticsOp stats_op) {
820 if (redundant_stats_ops.find(stats_op) == redundant_stats_ops.end()) {
821 all_stats_ops.push_back(stats_op);
822 }
823 });
824
825 while (!all_stats_ops.empty()) {
826 quantfork::StatisticsOp stats_op = all_stats_ops.back();
827 all_stats_ops.pop_back();
828
829 if (auto def = stats_op.getArg().getDefiningOp()) {
830 if (!IsSameScaleOp(def, op_quant_scale_spec_getter)) {
831 continue;
832 }
833 for (auto input : def->getOperands()) {
834 if (auto next_stats = llvm::dyn_cast_or_null<quantfork::StatisticsOp>(
835 input.getDefiningOp())) {
836 redundant_stats_ops.insert(next_stats);
837 all_stats_ops.push_back(next_stats);
838 }
839 }
840 }
841 }
842
843 // Step3: Remove all the redundant stats ops
844 for (auto it : redundant_stats_ops) {
845 if (!llvm::isa<quantfork::StatisticsOp>(it)) return true;
846 auto stats_op = llvm::cast<quantfork::StatisticsOp>(it);
847 stats_op.getResult().replaceAllUsesWith(stats_op.getArg());
848 stats_op.erase();
849 }
850
851 // Returns false if the steps finish without errors.
852 return false;
853}
854
855LogicalResult VerifySameScales(Operation* op) {
856 auto same_scale_op = llvm::cast<SameScalesOpInterface>(op);
857
858 llvm::SmallVector<QuantizedType, 4> collected_quant_params;
859 for (auto input : op->getOperands()) {
860 auto quant_params = QuantizedType::getQuantizedElementType(input.getType());
861 // Skip non-quantizable operands.
862 if (quant_params) {
863 collected_quant_params.push_back(quant_params);
864 }
865 }
866
867 for (auto output : op->getResults()) {
868 auto quant_params =
869 QuantizedType::getQuantizedElementType(output.getType());
870 // Skip non-quantizable results.
871 if (quant_params) {
872 collected_quant_params.push_back(quant_params);
873 }
874 }
875
876 if (collected_quant_params.size() <= 1) return success();
877 const auto& expected_params = collected_quant_params[0];
878 for (int i = 1; i < collected_quant_params.size(); i++) {
879 const auto& compared_params = collected_quant_params[i];
880 // For some ops (such as Transpose or Squeeze), the quantized axis might not
881 // be the same, this function only verifies the scale and zero point in
882 // that case. The quantized axis should be verified in their own verifier
883 // method.
884 if (!same_scale_op.RequiredSameQuantizedAxes()) {
885 auto expected_per_axis_qtype =
886 expected_params.dyn_cast<quant::UniformQuantizedPerAxisType>();
887 auto compared_per_axis_qtype =
888 compared_params.dyn_cast<quant::UniformQuantizedPerAxisType>();
889 if (expected_per_axis_qtype && compared_per_axis_qtype &&
890 llvm::equal(expected_per_axis_qtype.getScales(),
891 compared_per_axis_qtype.getScales()) &&
892 llvm::equal(expected_per_axis_qtype.getZeroPoints(),
893 compared_per_axis_qtype.getZeroPoints()) &&
894 expected_params.getStorageType() ==
895 compared_params.getStorageType() &&
896 expected_params.getExpressedType() ==
897 compared_params.getExpressedType()) {
898 continue;
899 }
900 }
901 // Same quantization parameters are always ok.
902 if (expected_params == compared_params) continue;
903 // If the quantization parameters are not the same, as long as it has the
904 // same storage type and the op interface doesn't require same scale
905 // constraint for this storage type, it is still ok.
906 if ((expected_params.isSigned() == compared_params.isSigned() &&
907 expected_params.getStorageTypeIntegralWidth() ==
908 compared_params.getStorageTypeIntegralWidth()) &&
909 !same_scale_op.RequiredSameOperandsAndResultsScale(
910 expected_params.isSigned(),
911 expected_params.getStorageTypeIntegralWidth()))
912 continue;
913
914 std::string err_msg =
915 "quantization parameters violate the same scale constraint: ";
916 llvm::raw_string_ostream os(err_msg);
917 expected_params.print(os);
918 os << " vs. ";
919 compared_params.print(os);
920 os.flush();
921 return op->emitOpError(err_msg);
922 }
923 return success();
924}
925
926quant::UniformQuantizedType GetFixedOutputRange(bool is_signed, int bit_width,
927 Type tensor_type, double scale,
928 int64_t zero_point,
929 int64_t storage_min,
930 int64_t storage_max) {
931 auto result_type = tensor_type.cast<ShapedType>();
932 if (!result_type.getElementType().isa<FloatType>()) return {};
933 Builder builder(result_type.getContext());
934
935 // Only support 8-bits
936 if (bit_width != 8) return {};
937 IntegerType storage_type = builder.getIntegerType(bit_width);
938 if (!is_signed) {
939 zero_point += 128;
940 storage_min += 128;
941 storage_max += 128;
942 }
943 return quant::UniformQuantizedType::getChecked(
944 builder.getUnknownLoc(), is_signed, storage_type,
945 result_type.getElementType(), scale, zero_point, storage_min,
946 storage_max);
947}
948
949Type ConvertSignedQuantizedToUnsigned(Type signed_tensor_type, Location loc) {
950 auto qtype = QType::getQuantizedElementType(signed_tensor_type);
951 if (!qtype || !qtype.isSigned()) return {};
952
953 int num_bits = qtype.getStorageTypeIntegralWidth();
954 // This is a negative value, and will be applied on zero points and fixed
955 // point ranges.
956 int64_t offset =
957 QType::getDefaultMinimumForInteger(/*isSigned=*/true, num_bits) -
958 QType::getDefaultMinimumForInteger(/*isSigned=*/false, num_bits);
959
960 auto flags = !quant::QuantizationFlags::Signed;
961 QType new_qtype;
962 if (auto uqtype = qtype.dyn_cast<quant::UniformQuantizedType>()) {
963 new_qtype = quant::UniformQuantizedType::getChecked(
964 loc, flags, qtype.getStorageType(), qtype.getExpressedType(),
965 uqtype.getScale(), uqtype.getZeroPoint() - offset,
966 uqtype.getStorageTypeMin() - offset,
967 uqtype.getStorageTypeMax() - offset);
968 } else if (auto aqtype =
969 qtype.dyn_cast<quant::UniformQuantizedPerAxisType>()) {
970 auto zero_points = aqtype.getZeroPoints();
971 llvm::SmallVector<int64_t, 4> new_zero_points(zero_points.begin(),
972 zero_points.end());
973 for (int i = 0, e = new_zero_points.size(); i != e; ++i) {
974 new_zero_points[i] -= offset;
975 }
976 new_qtype = quant::UniformQuantizedPerAxisType::getChecked(
977 loc, flags, qtype.getStorageType(), qtype.getExpressedType(),
978 aqtype.getScales(), new_zero_points, aqtype.getQuantizedDimension(),
979 aqtype.getStorageTypeMin() - offset,
980 aqtype.getStorageTypeMax() - offset);
981 }
982 return new_qtype.castFromExpressedType(
983 QType::castToExpressedType(signed_tensor_type));
984}
985
986LogicalResult RemoveDebugAttrPattern::matchAndRewrite(
987 Operation* op, PatternRewriter& rewriter) const {
988 // removeAttr will return nullptr if the attribute did not exist. Thus we can
989 // return success(result) to indicate if this op has changed.
990 return success(/*isSuccess=*/
991 op->removeAttr(kDebugModeOpQuantAttrName) ||
992 op->removeAttr(kDebugModeOpFloatAttrName));
993}
994
995} // namespace quant
996} // namespace mlir
997