1 | /* Copyright 2019 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | #include <math.h> |
16 | #include <stddef.h> |
17 | #include <stdlib.h> |
18 | |
19 | #include <algorithm> |
20 | #include <cstdint> |
21 | #include <numeric> |
22 | #include <vector> |
23 | |
24 | #include "flatbuffers/flexbuffers.h" // from @flatbuffers |
25 | #include "tensorflow/lite/c/common.h" |
26 | #include "tensorflow/lite/kernels/dequantize.h" |
27 | #include "tensorflow/lite/kernels/internal/optimized/neon_check.h" |
28 | #include "tensorflow/lite/kernels/internal/optimized/optimized_ops.h" |
29 | #include "tensorflow/lite/kernels/internal/reference/integer_ops/dequantize.h" |
30 | #include "tensorflow/lite/kernels/internal/reference/reference_ops.h" |
31 | #include "tensorflow/lite/kernels/internal/tensor.h" |
32 | #include "tensorflow/lite/kernels/internal/tensor_ctypes.h" |
33 | #include "tensorflow/lite/kernels/kernel_util.h" |
34 | |
35 | namespace tflite { |
36 | namespace ops { |
37 | namespace custom { |
38 | namespace numeric_verify { |
39 | |
// Flexbuffer keys used to read this op's custom options in Init().
static constexpr const char kToleranceStr[] = "tolerance";
static constexpr const char kLogIfFailedStr[] = "log_if_failed";
// Index of the node temporary holding the dequantized copy of the input.
static constexpr const int kTemporaryDequantizedTensor = 0;
// Index of the op's single output tensor (element-wise diffs).
static constexpr const int kOutputTensor = 0;
44 | |
45 | struct OpContext { |
46 | OpContext(TfLiteContext* context, TfLiteNode* node) { |
47 | input = GetInput(context, node, 0); |
48 | ref = GetInput(context, node, 1); |
49 | output = GetOutput(context, node, 0); |
50 | } |
51 | const TfLiteTensor* input; |
52 | const TfLiteTensor* ref; |
53 | TfLiteTensor* output; |
54 | }; |
55 | |
// Sentinel meaning the temporary dequantized tensor has not been allocated
// yet; set to a real tensor id by context->AddTensors() in Prepare().
constexpr int kTensorNotAllocated = -1;

// Per-node state, allocated in Init() and released in Free().
struct OpData {
  // The percentage of the tensor value range. Must be a number less than 1.0.
  // Used in Eval() as `tolerance * input scale` to bound allowed diffs.
  float tolerance;
  // Set after a constant input has been verified once, so later invocations
  // can skip re-verification. Only meaningful for constant input tensors.
  bool float_input_initialized;
  // Id of the temporary tensor caching the dequantized input.
  int cache_tensor_id = kTensorNotAllocated;
  // If true (and tolerance >= 0.1), Eval() fails on out-of-tolerance diffs;
  // otherwise it only logs diff statistics.
  bool log_if_failed;
};
67 | |
68 | void* Init(TfLiteContext* context, const char* buffer, size_t length) { |
69 | auto* op_data = new OpData(); |
70 | op_data->float_input_initialized = false; |
71 | |
72 | const uint8_t* buffer_t = reinterpret_cast<const uint8_t*>(buffer); |
73 | const flexbuffers::Map& m = flexbuffers::GetRoot(buffer_t, length).AsMap(); |
74 | const float tolerance = m[kToleranceStr].AsFloat(); |
75 | const bool log_if_failed = m[kLogIfFailedStr].AsBool(); |
76 | op_data->tolerance = tolerance; |
77 | op_data->log_if_failed = log_if_failed; |
78 | |
79 | return op_data; |
80 | } |
81 | |
82 | void Free(TfLiteContext* context, void* buffer) { |
83 | delete reinterpret_cast<OpData*>(buffer); |
84 | } |
85 | |
// Validates input/output tensor types and sets up the temporary tensor that
// will hold the dequantized input, plus the float32 output of diffs.
// Returns kTfLiteError via the TF_LITE_ENSURE* macros on any violation.
TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
  TF_LITE_ENSURE_EQ(context, NumInputs(node), 2);
  TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
  OpData* op_data = reinterpret_cast<OpData*>(node->user_data);

  OpContext op_context(context, node);

  // Input 0 must be a quantized (or float16) tensor; input 1 is the float32
  // reference it will be compared against.
  TF_LITE_ENSURE(context, op_context.input->type == kTfLiteUInt8 ||
                              op_context.input->type == kTfLiteInt8 ||
                              op_context.input->type == kTfLiteInt16 ||
                              op_context.input->type == kTfLiteFloat16);
  TF_LITE_ENSURE(context, op_context.ref->type == kTfLiteFloat32);

  // Allocate tensor to store the dequantized inputs. Only done once; on
  // later Prepare() calls the cached tensor id is reused.
  if (op_data->cache_tensor_id == kTensorNotAllocated) {
    TF_LITE_ENSURE_OK(
        context, context->AddTensors(context, 1, &op_data->cache_tensor_id));
  }

  // (Re)build the node's temporaries list so slot 0 is the cache tensor.
  TfLiteIntArrayFree(node->temporaries);
  node->temporaries = TfLiteIntArrayCreate(1);
  node->temporaries->data[0] = op_data->cache_tensor_id;

  TfLiteTensor* dequantized;
  TF_LITE_ENSURE_OK(context,
                    GetTemporarySafe(context, node, kTemporaryDequantizedTensor,
                                     &dequantized));
  // The temporary mirrors the reference's type (float32) and the input's
  // shape; kTfLiteDynamic lets it be resized/allocated outside the arena.
  dequantized->type = op_context.ref->type;
  dequantized->allocation_type = kTfLiteDynamic;

  TF_LITE_ENSURE_OK(context, context->ResizeTensor(
                                 context, dequantized,
                                 TfLiteIntArrayCopy(op_context.input->dims)));

  // The output holds element-wise diffs; persistent arena allocation keeps
  // the values readable after invocation.
  TF_LITE_ENSURE_OK(
      context, GetOutputSafe(context, node, kOutputTensor, &op_context.output));
  op_context.output->type = kTfLiteFloat32;
  op_context.output->allocation_type = kTfLiteArenaRwPersistent;
  return context->ResizeTensor(context, op_context.output,
                               TfLiteIntArrayCopy(op_context.input->dims));
}
127 | |
128 | static int32_t GetQuantizedValue(const OpContext& op_context, int index) { |
129 | switch (op_context.input->type) { |
130 | case kTfLiteUInt8: |
131 | return GetTensorData<uint8_t>(op_context.input)[index]; |
132 | case kTfLiteInt8: |
133 | return GetTensorData<int8_t>(op_context.input)[index]; |
134 | case kTfLiteInt16: |
135 | return GetTensorData<int16_t>(op_context.input)[index]; |
136 | default: |
137 | return 0; |
138 | } |
139 | } |
140 | |
141 | template <builtin::dequantize::KernelType kernel_type> |
142 | TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { |
143 | OpData* op_data = reinterpret_cast<OpData*>(node->user_data); |
144 | OpContext op_context(context, node); |
145 | if (IsConstantTensor(op_context.input) && op_data->float_input_initialized) { |
146 | return kTfLiteOk; |
147 | } |
148 | |
149 | // Dequantize the input |
150 | TfLiteTensor* dequantized; |
151 | TF_LITE_ENSURE_OK(context, |
152 | GetTemporarySafe(context, node, kTemporaryDequantizedTensor, |
153 | &dequantized)); |
154 | auto status = builtin::dequantize::DequantizeImpl<kernel_type>( |
155 | context, node, op_context.input, dequantized); |
156 | if (status != kTfLiteOk) { |
157 | return status; |
158 | } |
159 | |
160 | if (IsConstantTensor(op_context.input)) { |
161 | op_data->float_input_initialized = true; |
162 | } |
163 | |
164 | TF_LITE_ENSURE_OK( |
165 | context, GetOutputSafe(context, node, kOutputTensor, &op_context.output)); |
166 | auto output_data = GetTensorData<float>(op_context.output); |
167 | |
168 | // If log_if_failed is on, calculate differences between float and |
169 | // quantized values, their statistics and output logs. |
170 | // Throw errors if any diff greater than tolerance exists. |
171 | const int n = NumElements(dequantized); |
172 | if (op_data->log_if_failed && op_data->tolerance >= 0.1) { |
173 | // Verify the dequantized output. |
174 | auto max_diff = op_data->tolerance * op_context.input->params.scale; |
175 | for (int i = 0; i < n; ++i) { |
176 | int32_t value = GetQuantizedValue(op_context, i); |
177 | float dequant = GetTensorData<float>(dequantized)[i]; |
178 | float reference = GetTensorData<float>(op_context.ref)[i]; |
179 | output_data[i] = dequant - reference; |
180 | float diff = std::abs(output_data[i]); |
181 | if (diff > max_diff) { |
182 | TF_LITE_KERNEL_LOG( |
183 | context, |
184 | "Mismatch: %f is quantized to %d with (%f, %d). " |
185 | "abs(%f - %f) = %f > %f (tolerance) range percentage %f.\n" , |
186 | reference, value, op_context.input->params.scale, |
187 | op_context.input->params.zero_point, reference, dequant, diff, |
188 | max_diff, op_data->tolerance); |
189 | return kTfLiteError; |
190 | } |
191 | } |
192 | } else { |
193 | // If tolerance is small or log_if_failed is off, then we only care about |
194 | // statistics. |
195 | // These statistics logging was added to identify some errors in practice. |
196 | std::vector<double> diffs, temp; |
197 | diffs.reserve(n); |
198 | temp.reserve(n); |
199 | diffs.resize(n); |
200 | temp.resize(n); |
201 | for (int i = 0; i < n; ++i) { |
202 | float dequant = GetTensorData<float>(dequantized)[i]; |
203 | float reference = GetTensorData<float>(op_context.ref)[i]; |
204 | diffs[i] = static_cast<double>(dequant - reference); |
205 | output_data[i] = dequant - reference; |
206 | } |
207 | double mean = |
208 | std::accumulate(diffs.begin(), diffs.end(), 0.0) / diffs.size(); |
209 | double max_diff = 0.0; |
210 | std::transform(diffs.begin(), diffs.end(), temp.begin(), |
211 | [mean, &max_diff](double x) { |
212 | max_diff = std::max(max_diff, std::abs(x)); |
213 | return x - mean; |
214 | }); |
215 | double sq_sum = |
216 | std::inner_product(temp.begin(), temp.end(), temp.begin(), 0.0); |
217 | double std = std::sqrt(sq_sum / diffs.size()); |
218 | TF_LITE_KERNEL_LOG( |
219 | context, |
220 | "std: %f, mean: %f, max_diff: %f (scale: %f, zero_point: %d).\n" , std, |
221 | mean, max_diff, op_context.input->params.scale, |
222 | op_context.input->params.zero_point); |
223 | } |
224 | return kTfLiteOk; |
225 | } |
226 | |
227 | } // namespace numeric_verify |
228 | |
229 | TfLiteRegistration* Register_NUMERIC_VERIFY_OPT() { |
230 | static TfLiteRegistration r = { |
231 | numeric_verify::Init, numeric_verify::Free, numeric_verify::Prepare, |
232 | numeric_verify::Eval<builtin::dequantize::kGenericOptimized>}; |
233 | return &r; |
234 | } |
235 | |
236 | TfLiteRegistration* Register_NUMERIC_VERIFY_REF() { |
237 | static TfLiteRegistration r = { |
238 | numeric_verify::Init, numeric_verify::Free, numeric_verify::Prepare, |
239 | numeric_verify::Eval<builtin::dequantize::kReference>}; |
240 | return &r; |
241 | } |
242 | |
// Default registration: picks the NEON-optimized kernel when the build has
// NEON support, otherwise the portable reference kernel.
TfLiteRegistration* Register_NUMERIC_VERIFY() {
#ifdef USE_NEON
  return Register_NUMERIC_VERIFY_OPT();
#else
  return Register_NUMERIC_VERIFY_REF();
#endif
}
250 | |
251 | } // namespace custom |
252 | } // namespace ops |
253 | } // namespace tflite |
254 | |