/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/lite/kernels/internal/reference/quantize.h"

#include <cstddef>
#include <cstdint>

#include "tensorflow/lite/c/common.h"
#include "tensorflow/lite/kernels/internal/optimized/optimized_ops.h"
#include "tensorflow/lite/kernels/internal/quantization_util.h"
#include "tensorflow/lite/kernels/internal/reference/reference_ops.h"
#include "tensorflow/lite/kernels/internal/reference/requantize.h"
#include "tensorflow/lite/kernels/internal/tensor.h"
#include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
#include "tensorflow/lite/kernels/internal/types.h"
#include "tensorflow/lite/kernels/kernel_util.h"

namespace tflite {
namespace ops {
namespace builtin {
namespace quantize {

// This file has two implementations of Quantize: reference and generic
// optimized.
enum KernelType {
  kReference,
  kGenericOptimized,
};

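// Fixed-point representation (multiplier/shift pair) of the requantization
// scale input_scale / output_scale; filled in by Prepare() via
// QuantizeMultiplier() and consumed by the Requantize paths in Eval().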
struct OpData {
  int32_t output_multiplier;
  int output_shift;
};

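// Returns true when the tensor carries per-channel affine quantization
// parameters, i.e. more than one scale value.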
inline bool IsQuantizedPerChannel(const TfLiteTensor* input) {
  if (input->quantization.type == kTfLiteAffineQuantization &&
      input->quantization.params) {
    auto* quant_params =
        reinterpret_cast<TfLiteAffineQuantization*>(input->quantization.params);
    return (quant_params->scale && quant_params->scale->size > 1);
  }
  return false;
}

namespace {
template <KernelType kernel_type, typename output_type>
static inline void AffineQuantize(const tflite::QuantizationParams& op_params,
                                  const RuntimeShape& input_shape,
                                  const float* input_data,
                                  const RuntimeShape& output_shape,
                                  output_type* output_data) {
  if (kernel_type == kReference) {
    reference_ops::AffineQuantize(op_params, input_shape, input_data,
                                  output_shape, output_data);
  } else {
    optimized_ops::AffineQuantize(op_params, input_shape, input_data,
                                  output_shape, output_data);
  }
}

template <KernelType kernel_type, typename input_type, typename output_type>
static inline void Requantize(const input_type* input_data, int32_t size,
                              int32_t effective_scale_multiplier,
                              int32_t effective_scale_shift,
                              int32_t input_zeropoint, int32_t output_zeropoint,
                              output_type* output_data) {
  if (kernel_type == kReference) {
    reference_ops::Requantize(input_data, size, effective_scale_multiplier,
                              effective_scale_shift, input_zeropoint,
                              output_zeropoint, output_data);
  } else {
    optimized_ops::Requantize(input_data, size, effective_scale_multiplier,
                              effective_scale_shift, input_zeropoint,
                              output_zeropoint, output_data);
  }
}

void ReportError(TfLiteContext* context, TfLiteType input_type,
                 TfLiteType output_type) {
  TF_LITE_KERNEL_LOG(
      context, "Input type %s with Output type %s is not currently supported.",
      TfLiteTypeGetName(input_type), TfLiteTypeGetName(output_type));
}
}  // namespace

void* Init(TfLiteContext* context, const char* buffer, size_t length) {
  return new OpData;
}

void Free(TfLiteContext* context, void* buffer) {
  delete static_cast<OpData*>(buffer);
}

TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
  OpData* data = static_cast<OpData*>(node->user_data);
  TF_LITE_ENSURE_EQ(context, NumInputs(node), 1);
  TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);

  const TfLiteTensor* input;
  TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, 0, &input));
  TfLiteTensor* output;
  TF_LITE_ENSURE_OK(context, GetOutputSafe(context, node, 0, &output));

  // Currently this only supports affine quantization.
  TF_LITE_ENSURE_EQ(context, output->quantization.type,
                    kTfLiteAffineQuantization);

  if (input->type == kTfLiteFloat32) {
    // Quantize use case.
    TF_LITE_ENSURE(context, output->type == kTfLiteUInt8 ||
                                output->type == kTfLiteInt8 ||
                                output->type == kTfLiteInt16);
  } else {
    // Requantize use case.
    if (input->type == kTfLiteInt16) {
      TF_LITE_ENSURE(context, output->type == kTfLiteInt8 ||
                                  output->type == kTfLiteInt16 ||
                                  output->type == kTfLiteInt32);
    } else if (input->type == kTfLiteInt32) {
      TF_LITE_ENSURE(
          context, output->type == kTfLiteInt8 || output->type == kTfLiteInt16);
    } else {
      TF_LITE_ENSURE(context,
                     input->type == kTfLiteInt8 || input->type == kTfLiteUInt8);
      TF_LITE_ENSURE(
          context, output->type == kTfLiteUInt8 || output->type == kTfLiteInt8);
    }
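    // Requantization maps q_out = zp_out + round((q_in - zp_in) * S_in /
    // S_out). The scale ratio S_in / S_out is folded into a fixed-point
    // multiplier/shift pair below. Illustrative example (values assumed, not
    // taken from any model): S_in = 0.5 and S_out = 0.25 give an effective
    // scale of 2.0.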
    const double effective_output_scale =
        static_cast<double>(input->params.scale) /
        static_cast<double>(output->params.scale);
    QuantizeMultiplier(effective_output_scale, &data->output_multiplier,
                       &data->output_shift);
  }

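  // TFLite 16-bit quantization is expected to be symmetric, so the
  // int16 -> int16 path requires a zero point of 0 on both tensors.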
  if (input->type == kTfLiteInt16 && output->type == kTfLiteInt16) {
    TF_LITE_ENSURE_EQ(context, input->params.zero_point, 0);
    TF_LITE_ENSURE_EQ(context, output->params.zero_point, 0);
  }

  return context->ResizeTensor(context, output,
                               TfLiteIntArrayCopy(input->dims));
}

template <KernelType kernel_type>
TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
  OpData* data = static_cast<OpData*>(node->user_data);

  const TfLiteTensor* input;
  TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, 0, &input));
  TfLiteTensor* output;
  TF_LITE_ENSURE_OK(context, GetOutputSafe(context, node, 0, &output));

  const RuntimeShape input_shape = GetTensorShape(input);
  const RuntimeShape output_shape = GetTensorShape(output);

  switch (input->type) {
    case kTfLiteFloat32: {
      // Float to int8, uint8, int16.
      const float* input_data = GetTensorData<float>(input);

      if (IsQuantizedPerChannel(output)) {
        // Per-channel quantization: one scale and zero point for each channel.
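        // For each element x in channel c along quantized_dimension, roughly:
        //   q = zero_point[c] + round(x / scale[c]),
        // clamped to the output type's representable range.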
        const auto* quantization_params =
            reinterpret_cast<const TfLiteAffineQuantization*>(
                output->quantization.params);
        PerChannelQuantizationParams per_channel_op_params;
        per_channel_op_params.quantized_dimension =
            quantization_params->quantized_dimension;
        per_channel_op_params.scale = quantization_params->scale->data;
        per_channel_op_params.zero_point =
            quantization_params->zero_point->data;

        switch (output->type) {
          case kTfLiteInt8:
            reference_ops::PerChannelQuantize(
                per_channel_op_params, input_shape, input_data, output_shape,
                GetTensorData<int8_t>(output));
            return kTfLiteOk;
          case kTfLiteUInt8:
            reference_ops::PerChannelQuantize(
                per_channel_op_params, input_shape, input_data, output_shape,
                GetTensorData<uint8_t>(output));
            return kTfLiteOk;
          case kTfLiteInt16:
            reference_ops::PerChannelQuantize(
                per_channel_op_params, input_shape, input_data, output_shape,
                GetTensorData<int16_t>(output));
            return kTfLiteOk;
          default:
            ReportError(context, input->type, output->type);
            return kTfLiteError;
        }
      } else {
        // Per-node quantization: single scale and zero point for all channels.
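        // With a single scale/zero point for the whole tensor, roughly:
        //   q = zero_point + round(x / scale),
        // clamped to the output type's representable range.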
        tflite::QuantizationParams op_params;
        op_params.zero_point = output->params.zero_point;
        op_params.scale = output->params.scale;

        switch (output->type) {
          case kTfLiteInt8:
            AffineQuantize<kernel_type>(op_params, input_shape, input_data,
                                        output_shape,
                                        GetTensorData<int8_t>(output));
            return kTfLiteOk;
          case kTfLiteUInt8:
            AffineQuantize<kernel_type>(op_params, input_shape, input_data,
                                        output_shape,
                                        GetTensorData<uint8_t>(output));
            return kTfLiteOk;
          case kTfLiteInt16:
            AffineQuantize<kernel_type>(op_params, input_shape, input_data,
                                        output_shape,
                                        GetTensorData<int16_t>(output));
            return kTfLiteOk;
          default:
            ReportError(context, input->type, output->type);
            return kTfLiteError;
        }
      }
    }
    // This case is not supported by the converter or other TFLite tools. The
    // only use case is for applications that take quantized int32 inference
    // inputs.
    case kTfLiteInt32: {
      // int32 to int8 or int16.
      switch (output->type) {
        case kTfLiteInt8:
          Requantize<kernel_type>(GetTensorData<int32_t>(input),
                                  MatchingFlatSize(input_shape, output_shape),
                                  data->output_multiplier, data->output_shift,
                                  input->params.zero_point,
                                  output->params.zero_point,
                                  GetTensorData<int8_t>(output));
          return kTfLiteOk;
        case kTfLiteInt16:
          Requantize<kernel_type>(GetTensorData<int32_t>(input),
                                  MatchingFlatSize(input_shape, output_shape),
                                  data->output_multiplier, data->output_shift,
                                  input->params.zero_point,
                                  output->params.zero_point,
                                  GetTensorData<int16_t>(output));
          return kTfLiteOk;
        default:
          ReportError(context, input->type, output->type);
          return kTfLiteError;
      }
    }
    case kTfLiteInt16: {
      // int16 to int8, int16 or int32.
      switch (output->type) {
        case kTfLiteInt8:
          Requantize<kernel_type>(GetTensorData<int16_t>(input),
                                  MatchingFlatSize(input_shape, output_shape),
                                  data->output_multiplier, data->output_shift,
                                  input->params.zero_point,
                                  output->params.zero_point,
                                  GetTensorData<int8_t>(output));
          return kTfLiteOk;
        case kTfLiteInt16:
          Requantize<kernel_type>(GetTensorData<int16_t>(input),
                                  MatchingFlatSize(input_shape, output_shape),
                                  data->output_multiplier, data->output_shift,
                                  input->params.zero_point,
                                  output->params.zero_point,
                                  GetTensorData<int16_t>(output));
          return kTfLiteOk;
        case kTfLiteInt32:
          // This case is not supported by the converter or other TFLite tools.
          // The only use case is for applications that take quantized int32
          // inference outputs.
          Requantize<kernel_type>(GetTensorData<int16_t>(input),
                                  MatchingFlatSize(input_shape, output_shape),
                                  data->output_multiplier, data->output_shift,
                                  input->params.zero_point,
                                  output->params.zero_point,
                                  GetTensorData<int32_t>(output));
          return kTfLiteOk;
        default:
          ReportError(context, input->type, output->type);
          return kTfLiteError;
      }
    }
    case kTfLiteInt8: {
      // int8 to int8, uint8.
      const int32_t size = MatchingFlatSize(input_shape, output_shape);
      const int8_t* input_data = GetTensorData<int8_t>(input);
      switch (output->type) {
        case kTfLiteInt8:
          Requantize<kernel_type>(input_data, size, data->output_multiplier,
                                  data->output_shift, input->params.zero_point,
                                  output->params.zero_point,
                                  GetTensorData<int8_t>(output));
          return kTfLiteOk;
        case kTfLiteUInt8:
          Requantize<kernel_type>(input_data, size, data->output_multiplier,
                                  data->output_shift, input->params.zero_point,
                                  output->params.zero_point,
                                  GetTensorData<uint8_t>(output));
          return kTfLiteOk;
        default:
          ReportError(context, input->type, output->type);
          return kTfLiteError;
      }
    }
    case kTfLiteUInt8: {
      // uint8 to int8, uint8.
      const int32_t size = MatchingFlatSize(input_shape, output_shape);
      const uint8_t* input_data = GetTensorData<uint8_t>(input);
      switch (output->type) {
        case kTfLiteInt8:
          Requantize<kernel_type>(input_data, size, data->output_multiplier,
                                  data->output_shift, input->params.zero_point,
                                  output->params.zero_point,
                                  GetTensorData<int8_t>(output));
          return kTfLiteOk;
        case kTfLiteUInt8:
          Requantize<kernel_type>(input_data, size, data->output_multiplier,
                                  data->output_shift, input->params.zero_point,
                                  output->params.zero_point,
                                  GetTensorData<uint8_t>(output));
          return kTfLiteOk;
        default:
          ReportError(context, input->type, output->type);
          return kTfLiteError;
      }
    }
    default:
      ReportError(context, input->type, output->type);
      return kTfLiteError;
  }
}

}  // namespace quantize

// This Op (QUANTIZE) quantizes the input and produces quantized output.
// The input can be either float or quantized. If the input is float,
// AffineQuantize takes the output scale and zero point and quantizes the float
// values to quantized output in int8, uint8, or int16 format. If the input is
// already quantized, the op requantizes the input (of a given type, scale, and
// zero point) to an output of the same or a different type, with the same or a
// different scale and zero point.
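// Roughly, quantization computes q = zero_point + round(x / scale) and
// requantization computes
//   q_out = zp_out + round((q_in - zp_in) * scale_in / scale_out),
// each clamped to the output type's representable range.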
TfLiteRegistration* Register_QUANTIZE_OPT() {
  static TfLiteRegistration r = {quantize::Init, quantize::Free,
                                 quantize::Prepare,
                                 quantize::Eval<quantize::kGenericOptimized>};
  return &r;
}

TfLiteRegistration* Register_QUANTIZE_REF() {
  static TfLiteRegistration r = {quantize::Init, quantize::Free,
                                 quantize::Prepare,
                                 quantize::Eval<quantize::kReference>};
  return &r;
}

TfLiteRegistration* Register_QUANTIZE() { return Register_QUANTIZE_OPT(); }

}  // namespace builtin
}  // namespace ops
}  // namespace tflite