1 | /* Copyright 2019 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | #include "tensorflow/lite/kernels/internal/reference/quantize.h" |
16 | |
17 | #include <cstddef> |
18 | #include <cstdint> |
19 | |
20 | #include "tensorflow/lite/c/common.h" |
21 | #include "tensorflow/lite/kernels/internal/optimized/optimized_ops.h" |
22 | #include "tensorflow/lite/kernels/internal/quantization_util.h" |
23 | #include "tensorflow/lite/kernels/internal/reference/reference_ops.h" |
24 | #include "tensorflow/lite/kernels/internal/reference/requantize.h" |
25 | #include "tensorflow/lite/kernels/internal/tensor.h" |
26 | #include "tensorflow/lite/kernels/internal/tensor_ctypes.h" |
27 | #include "tensorflow/lite/kernels/internal/types.h" |
28 | #include "tensorflow/lite/kernels/kernel_util.h" |
29 | |
30 | namespace tflite { |
31 | namespace ops { |
32 | namespace builtin { |
33 | namespace quantize { |
34 | |
// This file has two implementations of Quantize.
// Selects which implementation Eval dispatches to; fixed at registration
// time (see Register_QUANTIZE_OPT / Register_QUANTIZE_REF below).
enum KernelType {
  kReference,         // Portable reference_ops implementation.
  kGenericOptimized,  // optimized_ops implementation.
};
40 | |
// Per-node state. The multiplier/shift pair is computed in Prepare (for the
// requantize, i.e. quantized-to-quantized, path) and consumed in Eval.
struct OpData {
  int32_t output_multiplier;  // Fixed-point representation of in/out scale ratio.
  int output_shift;           // Left shift paired with output_multiplier.
};
45 | |
46 | inline bool IsQuantizedPerChannel(const TfLiteTensor* input) { |
47 | if (input->quantization.type == kTfLiteAffineQuantization && |
48 | input->quantization.params) { |
49 | auto* quant_params = |
50 | reinterpret_cast<TfLiteAffineQuantization*>(input->quantization.params); |
51 | return (quant_params->scale && quant_params->scale->size > 1); |
52 | } |
53 | return false; |
54 | } |
55 | |
56 | namespace { |
57 | template <KernelType kernel_type, typename output_type> |
58 | static inline void AffineQuantize(const tflite::QuantizationParams& op_params, |
59 | const RuntimeShape& input_shape, |
60 | const float* input_data, |
61 | const RuntimeShape& output_shape, |
62 | output_type* output_data) { |
63 | if (kernel_type == kReference) { |
64 | reference_ops::AffineQuantize(op_params, input_shape, input_data, |
65 | output_shape, output_data); |
66 | } else { |
67 | optimized_ops::AffineQuantize(op_params, input_shape, input_data, |
68 | output_shape, output_data); |
69 | } |
70 | } |
71 | |
72 | template <KernelType kernel_type, typename input_type, typename output_type> |
73 | static inline void Requantize(const input_type* input_data, int32_t size, |
74 | int32_t effective_scale_multiplier, |
75 | int32_t effective_scale_shift, |
76 | int32_t input_zeropoint, int32_t output_zeropoint, |
77 | output_type* output_data) { |
78 | if (kernel_type == kReference) { |
79 | reference_ops::Requantize(input_data, size, effective_scale_multiplier, |
80 | effective_scale_shift, input_zeropoint, |
81 | output_zeropoint, output_data); |
82 | } else { |
83 | optimized_ops::Requantize(input_data, size, effective_scale_multiplier, |
84 | effective_scale_shift, input_zeropoint, |
85 | output_zeropoint, output_data); |
86 | } |
87 | } |
88 | |
89 | void ReportError(TfLiteContext* context, TfLiteType input_type, |
90 | TfLiteType output_type) { |
91 | TF_LITE_KERNEL_LOG( |
92 | context, "Input type %s with Output type %s is not currently supported." , |
93 | TfLiteTypeGetName(input_type), TfLiteTypeGetName(output_type)); |
94 | } |
95 | } // namespace |
96 | |
// Allocates the per-node OpData. Note: members are left default-initialized
// (i.e. indeterminate) here; Prepare fills them in before Eval runs.
void* Init(TfLiteContext* context, const char* buffer, size_t length) {
  return new OpData;
}
100 | |
// Releases the OpData allocated in Init.
void Free(TfLiteContext* context, void* buffer) {
  delete static_cast<OpData*>(buffer);
}
104 | |
// Validates the input/output type pairing, precomputes the requantization
// multiplier/shift when both tensors are quantized, and resizes the output
// to match the input's shape.
TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
  OpData* data = static_cast<OpData*>(node->user_data);
  TF_LITE_ENSURE_EQ(context, NumInputs(node), 1);
  TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);

  const TfLiteTensor* input;
  TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, 0, &input));
  TfLiteTensor* output;
  TF_LITE_ENSURE_OK(context, GetOutputSafe(context, node, 0, &output));

  // Currently this only supports affine quantization.
  TF_LITE_ENSURE_EQ(context, output->quantization.type,
                    kTfLiteAffineQuantization);

  if (input->type == kTfLiteFloat32) {
    // Quantize use case.
    TF_LITE_ENSURE(context, output->type == kTfLiteUInt8 ||
                                output->type == kTfLiteInt8 ||
                                output->type == kTfLiteInt16);
  } else {
    // Requantize use case: both tensors are quantized, so precompute the
    // fixed-point representation of input_scale / output_scale for Eval.
    if (input->type == kTfLiteInt16) {
      TF_LITE_ENSURE(context, output->type == kTfLiteInt8 ||
                                  output->type == kTfLiteInt16 ||
                                  output->type == kTfLiteInt32);
    } else if (input->type == kTfLiteInt32) {
      TF_LITE_ENSURE(
          context, output->type == kTfLiteInt8 || output->type == kTfLiteInt16);
    } else {
      TF_LITE_ENSURE(context,
                     input->type == kTfLiteInt8 || input->type == kTfLiteUInt8);
      TF_LITE_ENSURE(
          context, output->type == kTfLiteUInt8 || output->type == kTfLiteInt8);
    }
    // Computed in double for precision before converting to fixed point.
    const double effective_output_scale =
        static_cast<double>(input->params.scale) /
        static_cast<double>(output->params.scale);
    QuantizeMultiplier(effective_output_scale, &data->output_multiplier,
                       &data->output_shift);
  }

  // int16 -> int16 requantization requires symmetric quantization (zero
  // points of exactly 0) on both sides.
  if (input->type == kTfLiteInt16 && output->type == kTfLiteInt16) {
    TF_LITE_ENSURE_EQ(context, input->params.zero_point, 0);
    TF_LITE_ENSURE_EQ(context, output->params.zero_point, 0);
  }

  // Output always has the same shape as the input.
  return context->ResizeTensor(context, output,
                               TfLiteIntArrayCopy(input->dims));
}
154 | |
// Runs the QUANTIZE op: quantizes a float input, or requantizes an already
// quantized input, into the output tensor. Dispatches on the (input type,
// output type) pair; unsupported pairs log an error and fail.
template <KernelType kernel_type>
TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
  OpData* data = static_cast<OpData*>(node->user_data);

  const TfLiteTensor* input;
  TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, 0, &input));
  TfLiteTensor* output;
  TF_LITE_ENSURE_OK(context, GetOutputSafe(context, node, 0, &output));

  const RuntimeShape input_shape = GetTensorShape(input);
  const RuntimeShape output_shape = GetTensorShape(output);

  switch (input->type) {
    case kTfLiteFloat32: {
      // Float to int8, uint8, int16.
      const float* input_data = GetTensorData<float>(input);

      if (IsQuantizedPerChannel(output)) {
        // Per-channel quantization: one scale and zero point for each channel.
        // Note: the per-channel path always uses the reference kernels,
        // regardless of kernel_type.
        const auto* quantization_params =
            reinterpret_cast<const TfLiteAffineQuantization*>(
                output->quantization.params);
        PerChannelQuantizationParams per_channel_op_params;
        per_channel_op_params.quantized_dimension =
            quantization_params->quantized_dimension;
        per_channel_op_params.scale = quantization_params->scale->data;
        per_channel_op_params.zero_point =
            quantization_params->zero_point->data;

        switch (output->type) {
          case kTfLiteInt8:
            reference_ops::PerChannelQuantize(
                per_channel_op_params, input_shape, input_data, output_shape,
                GetTensorData<int8_t>(output));
            return kTfLiteOk;
          case kTfLiteUInt8:
            reference_ops::PerChannelQuantize(
                per_channel_op_params, input_shape, input_data, output_shape,
                GetTensorData<uint8_t>(output));
            return kTfLiteOk;
          case kTfLiteInt16:
            reference_ops::PerChannelQuantize(
                per_channel_op_params, input_shape, input_data, output_shape,
                GetTensorData<int16_t>(output));
            return kTfLiteOk;
          default:
            ReportError(context, input->type, output->type);
            return kTfLiteError;
        }
      } else {
        // Per-node quantization: single scale and zero point for all channels.
        tflite::QuantizationParams op_params;
        op_params.zero_point = output->params.zero_point;
        op_params.scale = output->params.scale;

        switch (output->type) {
          case kTfLiteInt8:
            AffineQuantize<kernel_type>(op_params, input_shape, input_data,
                                        output_shape,
                                        GetTensorData<int8_t>(output));
            return kTfLiteOk;
          case kTfLiteUInt8:
            AffineQuantize<kernel_type>(op_params, input_shape, input_data,
                                        output_shape,
                                        GetTensorData<uint8_t>(output));
            return kTfLiteOk;
          case kTfLiteInt16:
            AffineQuantize<kernel_type>(op_params, input_shape, input_data,
                                        output_shape,
                                        GetTensorData<int16_t>(output));
            return kTfLiteOk;
          default:
            ReportError(context, input->type, output->type);
            return kTfLiteError;
        }
      }
    }
    // This case is not supported by the converter or other TFLite tools. The
    // only use case is for applications that take quantized int32 inference
    // inputs.
    case kTfLiteInt32: {
      // int32 to int8 or int16. Uses the multiplier/shift computed in Prepare.
      switch (output->type) {
        case kTfLiteInt8:
          Requantize<kernel_type>(GetTensorData<int32_t>(input),
                                  MatchingFlatSize(input_shape, output_shape),
                                  data->output_multiplier, data->output_shift,
                                  input->params.zero_point,
                                  output->params.zero_point,
                                  GetTensorData<int8_t>(output));
          return kTfLiteOk;
        case kTfLiteInt16:
          Requantize<kernel_type>(GetTensorData<int32_t>(input),
                                  MatchingFlatSize(input_shape, output_shape),
                                  data->output_multiplier, data->output_shift,
                                  input->params.zero_point,
                                  output->params.zero_point,
                                  GetTensorData<int16_t>(output));
          return kTfLiteOk;
        default:
          ReportError(context, input->type, output->type);
          return kTfLiteError;
      }
    }
    case kTfLiteInt16: {
      // int16 to int8, int16, or int32.
      switch (output->type) {
        case kTfLiteInt8:
          Requantize<kernel_type>(GetTensorData<int16_t>(input),
                                  MatchingFlatSize(input_shape, output_shape),
                                  data->output_multiplier, data->output_shift,
                                  input->params.zero_point,
                                  output->params.zero_point,
                                  GetTensorData<int8_t>(output));
          return kTfLiteOk;
        case kTfLiteInt16:
          Requantize<kernel_type>(GetTensorData<int16_t>(input),
                                  MatchingFlatSize(input_shape, output_shape),
                                  data->output_multiplier, data->output_shift,
                                  input->params.zero_point,
                                  output->params.zero_point,
                                  GetTensorData<int16_t>(output));
          return kTfLiteOk;
        case kTfLiteInt32:
          // This case is not supported by the converter or other TFLite tools.
          // The only use case is for applications that take quantized int32
          // inference outputs.
          Requantize<kernel_type>(GetTensorData<int16_t>(input),
                                  MatchingFlatSize(input_shape, output_shape),
                                  data->output_multiplier, data->output_shift,
                                  input->params.zero_point,
                                  output->params.zero_point,
                                  GetTensorData<int32_t>(output));
          return kTfLiteOk;
        default:
          ReportError(context, input->type, output->type);
          return kTfLiteError;
      }
    }
    case kTfLiteInt8: {
      // int8 to int8, uint8.
      const int32_t size = MatchingFlatSize(input_shape, output_shape);
      const int8_t* input_data = GetTensorData<int8_t>(input);
      switch (output->type) {
        case kTfLiteInt8:
          Requantize<kernel_type>(input_data, size, data->output_multiplier,
                                  data->output_shift, input->params.zero_point,
                                  output->params.zero_point,
                                  GetTensorData<int8_t>(output));
          return kTfLiteOk;
        case kTfLiteUInt8:
          Requantize<kernel_type>(input_data, size, data->output_multiplier,
                                  data->output_shift, input->params.zero_point,
                                  output->params.zero_point,
                                  GetTensorData<uint8_t>(output));
          return kTfLiteOk;
        default:
          ReportError(context, input->type, output->type);
          return kTfLiteError;
      }
    }
    case kTfLiteUInt8: {
      // uint8 to int8, uint8.
      const int32_t size = MatchingFlatSize(input_shape, output_shape);
      const uint8_t* input_data = GetTensorData<uint8_t>(input);
      switch (output->type) {
        case kTfLiteInt8:
          Requantize<kernel_type>(input_data, size, data->output_multiplier,
                                  data->output_shift, input->params.zero_point,
                                  output->params.zero_point,
                                  GetTensorData<int8_t>(output));
          return kTfLiteOk;
        case kTfLiteUInt8:
          Requantize<kernel_type>(input_data, size, data->output_multiplier,
                                  data->output_shift, input->params.zero_point,
                                  output->params.zero_point,
                                  GetTensorData<uint8_t>(output));
          return kTfLiteOk;
        default:
          ReportError(context, input->type, output->type);
          return kTfLiteError;
      }
    }
    default:
      ReportError(context, input->type, output->type);
      return kTfLiteError;
  }
}
343 | |
344 | } // namespace quantize |
345 | |
// This Op (QUANTIZE) quantizes the input and produces quantized output.
// The input can be either float or quantized. If the input is float,
// AffineQuantize takes the scale and zero point and quantizes the float value
// to quantized output, in int8 or uint8 format. If the input is a quantized
// value, the op requantizes the input (of a certain type, with a given scale
// and zero point) to the output of the same or a different type with the same
// or a different scale and zero point.
353 | TfLiteRegistration* Register_QUANTIZE_OPT() { |
354 | static TfLiteRegistration r = {quantize::Init, quantize::Free, |
355 | quantize::Prepare, |
356 | quantize::Eval<quantize::kGenericOptimized>}; |
357 | return &r; |
358 | } |
359 | |
360 | TfLiteRegistration* Register_QUANTIZE_REF() { |
361 | static TfLiteRegistration r = {quantize::Init, quantize::Free, |
362 | quantize::Prepare, |
363 | quantize::Eval<quantize::kReference>}; |
364 | return &r; |
365 | } |
366 | |
367 | TfLiteRegistration* Register_QUANTIZE() { return Register_QUANTIZE_OPT(); } |
368 | |
369 | } // namespace builtin |
370 | } // namespace ops |
371 | } // namespace tflite |
372 | |