/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <math.h>
#include <stddef.h>
#include <stdlib.h>

#include <algorithm>
#include <cstdint>
#include <numeric>
#include <vector>

#include "flatbuffers/flexbuffers.h"  // from @flatbuffers
#include "tensorflow/lite/c/common.h"
#include "tensorflow/lite/kernels/dequantize.h"
#include "tensorflow/lite/kernels/internal/optimized/neon_check.h"
#include "tensorflow/lite/kernels/internal/optimized/optimized_ops.h"
#include "tensorflow/lite/kernels/internal/reference/integer_ops/dequantize.h"
#include "tensorflow/lite/kernels/internal/reference/reference_ops.h"
#include "tensorflow/lite/kernels/internal/tensor.h"
#include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
#include "tensorflow/lite/kernels/kernel_util.h"

namespace tflite {
namespace ops {
namespace custom {
namespace numeric_verify {

static constexpr const char kToleranceStr[] = "tolerance";
static constexpr const char kLogIfFailedStr[] = "log_if_failed";
static constexpr const int kTemporaryDequantizedTensor = 0;
static constexpr const int kOutputTensor = 0;

struct OpContext {
  OpContext(TfLiteContext* context, TfLiteNode* node) {
    input = GetInput(context, node, 0);
    ref = GetInput(context, node, 1);
    output = GetOutput(context, node, 0);
  }
  // The quantized (or float16) input tensor under verification.
  const TfLiteTensor* input;
  // The float32 reference tensor the dequantized input is compared against.
  const TfLiteTensor* ref;
  // Receives the element-wise differences (dequantized - reference).
  TfLiteTensor* output;
};

const int kTensorNotAllocated = -1;

struct OpData {
  // Maximum allowed difference, expressed as a multiple of the input's
  // quantization scale: the threshold used in Eval is tolerance * scale.
  float tolerance;
  // Set after a constant input has been dequantized once, so subsequent
  // invocations can skip the work.
  bool float_input_initialized;
  int cache_tensor_id = kTensorNotAllocated;
  // If true, any element whose difference exceeds the tolerance is logged and
  // the op returns an error; otherwise only summary statistics are logged.
  bool log_if_failed;
};
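
// As a worked example of the tolerance semantics: with tolerance = 5.0 and an
// input quantization scale of 0.01, Eval flags any element whose dequantized
// value differs from the float reference by more than 5.0 * 0.01 = 0.05.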

void* Init(TfLiteContext* context, const char* buffer, size_t length) {
  auto* op_data = new OpData();
  op_data->float_input_initialized = false;

  const uint8_t* buffer_t = reinterpret_cast<const uint8_t*>(buffer);
  const flexbuffers::Map& m = flexbuffers::GetRoot(buffer_t, length).AsMap();
  const float tolerance = m[kToleranceStr].AsFloat();
  const bool log_if_failed = m[kLogIfFailedStr].AsBool();
  op_data->tolerance = tolerance;
  op_data->log_if_failed = log_if_failed;

  return op_data;
}
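
// For reference, a minimal sketch of how a client could encode the custom
// options consumed by Init() above, using the flexbuffers::Builder API. The
// key names come from kToleranceStr/kLogIfFailedStr; the values are purely
// illustrative:
//
//   flexbuffers::Builder fbb;
//   fbb.Map([&]() {
//     fbb.Float("tolerance", 5.0f);
//     fbb.Bool("log_if_failed", true);
//   });
//   fbb.Finish();
//   // fbb.GetBuffer() yields the bytes passed to Init() as `buffer`.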

void Free(TfLiteContext* context, void* buffer) {
  delete reinterpret_cast<OpData*>(buffer);
}

TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
  TF_LITE_ENSURE_EQ(context, NumInputs(node), 2);
  TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
  OpData* op_data = reinterpret_cast<OpData*>(node->user_data);

  OpContext op_context(context, node);

  TF_LITE_ENSURE(context, op_context.input->type == kTfLiteUInt8 ||
                              op_context.input->type == kTfLiteInt8 ||
                              op_context.input->type == kTfLiteInt16 ||
                              op_context.input->type == kTfLiteFloat16);
  TF_LITE_ENSURE(context, op_context.ref->type == kTfLiteFloat32);

  // Allocate a temporary tensor to store the dequantized input.
  if (op_data->cache_tensor_id == kTensorNotAllocated) {
    TF_LITE_ENSURE_OK(
        context, context->AddTensors(context, 1, &op_data->cache_tensor_id));
  }

  TfLiteIntArrayFree(node->temporaries);
  node->temporaries = TfLiteIntArrayCreate(1);
  node->temporaries->data[0] = op_data->cache_tensor_id;

  TfLiteTensor* dequantized;
  TF_LITE_ENSURE_OK(context,
                    GetTemporarySafe(context, node, kTemporaryDequantizedTensor,
                                     &dequantized));
  dequantized->type = op_context.ref->type;
  dequantized->allocation_type = kTfLiteDynamic;

  TF_LITE_ENSURE_OK(context, context->ResizeTensor(
                                 context, dequantized,
                                 TfLiteIntArrayCopy(op_context.input->dims)));

  TF_LITE_ENSURE_OK(
      context, GetOutputSafe(context, node, kOutputTensor, &op_context.output));
  op_context.output->type = kTfLiteFloat32;
  op_context.output->allocation_type = kTfLiteArenaRwPersistent;
  return context->ResizeTensor(context, op_context.output,
                               TfLiteIntArrayCopy(op_context.input->dims));
}

// Returns the raw quantized value at `index`, widened to int32.
static int32_t GetQuantizedValue(const OpContext& op_context, int index) {
  switch (op_context.input->type) {
    case kTfLiteUInt8:
      return GetTensorData<uint8_t>(op_context.input)[index];
    case kTfLiteInt8:
      return GetTensorData<int8_t>(op_context.input)[index];
    case kTfLiteInt16:
      return GetTensorData<int16_t>(op_context.input)[index];
    default:
      // Types without an integer representation (e.g. float16 inputs, which
      // Prepare also accepts) have no quantized value to report.
      return 0;
  }
}
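
// The dequantized values compared below follow the standard affine mapping
// real = scale * (quantized - zero_point). For example, with scale = 0.5 and
// zero_point = -1, the quantized value 3 dequantizes to 0.5 * (3 + 1) = 2.0.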

template <builtin::dequantize::KernelType kernel_type>
TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
  OpData* op_data = reinterpret_cast<OpData*>(node->user_data);
  OpContext op_context(context, node);
  // A constant input only needs to be dequantized and verified once.
  if (IsConstantTensor(op_context.input) && op_data->float_input_initialized) {
    return kTfLiteOk;
  }

  // Dequantize the input into the temporary tensor.
  TfLiteTensor* dequantized;
  TF_LITE_ENSURE_OK(context,
                    GetTemporarySafe(context, node, kTemporaryDequantizedTensor,
                                     &dequantized));
  auto status = builtin::dequantize::DequantizeImpl<kernel_type>(
      context, node, op_context.input, dequantized);
  if (status != kTfLiteOk) {
    return status;
  }

  if (IsConstantTensor(op_context.input)) {
    op_data->float_input_initialized = true;
  }

  TF_LITE_ENSURE_OK(
      context, GetOutputSafe(context, node, kOutputTensor, &op_context.output));
  auto output_data = GetTensorData<float>(op_context.output);

  // When log_if_failed is on and the tolerance is large enough to be
  // meaningful (>= 0.1), compute the element-wise differences between the
  // dequantized and reference values, and return an error on the first
  // element whose absolute difference exceeds tolerance * scale.
  const int n = NumElements(dequantized);
  if (op_data->log_if_failed && op_data->tolerance >= 0.1) {
    // Verify the dequantized output.
    auto max_diff = op_data->tolerance * op_context.input->params.scale;
    for (int i = 0; i < n; ++i) {
      int32_t value = GetQuantizedValue(op_context, i);
      float dequant = GetTensorData<float>(dequantized)[i];
      float reference = GetTensorData<float>(op_context.ref)[i];
      output_data[i] = dequant - reference;
      float diff = std::abs(output_data[i]);
      if (diff > max_diff) {
        TF_LITE_KERNEL_LOG(
            context,
            "Mismatch: %f is quantized to %d with (scale %f, zero_point %d). "
            "abs(%f - %f) = %f > %f, the max allowed diff for tolerance %f.\n",
            reference, value, op_context.input->params.scale,
            op_context.input->params.zero_point, reference, dequant, diff,
            max_diff, op_data->tolerance);
        return kTfLiteError;
      }
    }
  } else {
    // If the tolerance is small or log_if_failed is off, only summary
    // statistics are reported. This logging was added to help identify
    // errors observed in practice.
    std::vector<double> diffs(n), temp(n);
    for (int i = 0; i < n; ++i) {
      float dequant = GetTensorData<float>(dequantized)[i];
      float reference = GetTensorData<float>(op_context.ref)[i];
      diffs[i] = static_cast<double>(dequant - reference);
      output_data[i] = dequant - reference;
    }
    // Compute the mean, the maximum absolute diff, and the (population)
    // standard deviation of the diffs.
    double mean =
        std::accumulate(diffs.begin(), diffs.end(), 0.0) / diffs.size();
    double max_diff = 0.0;
    std::transform(diffs.begin(), diffs.end(), temp.begin(),
                   [mean, &max_diff](double x) {
                     max_diff = std::max(max_diff, std::abs(x));
                     return x - mean;
                   });
    double sq_sum =
        std::inner_product(temp.begin(), temp.end(), temp.begin(), 0.0);
    double std_dev = std::sqrt(sq_sum / diffs.size());
    TF_LITE_KERNEL_LOG(
        context,
        "std: %f, mean: %f, max_diff: %f (scale: %f, zero_point: %d).\n",
        std_dev, mean, max_diff, op_context.input->params.scale,
        op_context.input->params.zero_point);
  }
  return kTfLiteOk;
}

}  // namespace numeric_verify

TfLiteRegistration* Register_NUMERIC_VERIFY_OPT() {
  static TfLiteRegistration r = {
      numeric_verify::Init, numeric_verify::Free, numeric_verify::Prepare,
      numeric_verify::Eval<builtin::dequantize::kGenericOptimized>};
  return &r;
}

TfLiteRegistration* Register_NUMERIC_VERIFY_REF() {
  static TfLiteRegistration r = {
      numeric_verify::Init, numeric_verify::Free, numeric_verify::Prepare,
      numeric_verify::Eval<builtin::dequantize::kReference>};
  return &r;
}

TfLiteRegistration* Register_NUMERIC_VERIFY() {
#ifdef USE_NEON
  return Register_NUMERIC_VERIFY_OPT();
#else
  return Register_NUMERIC_VERIFY_REF();
#endif
}
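
// A minimal usage sketch for wiring this op into an interpreter. The op name
// "NumericVerify" is assumed here to match what the converter emits for debug
// models; adjust it to whatever name your model actually uses:
//
//   tflite::ops::builtin::BuiltinOpResolver resolver;
//   resolver.AddCustom("NumericVerify",
//                      tflite::ops::custom::Register_NUMERIC_VERIFY());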

}  // namespace custom
}  // namespace ops
}  // namespace tflite