1/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2
3Licensed under the Apache License, Version 2.0 (the "License");
4you may not use this file except in compliance with the License.
5You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9Unless required by applicable law or agreed to in writing, software
10distributed under the License is distributed on an "AS IS" BASIS,
11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12See the License for the specific language governing permissions and
13limitations under the License.
14==============================================================================*/
15#include <stddef.h>
16#include <stdint.h>
17
18#include <algorithm>
19
20#include "ruy/profiler/instrumentation.h" // from @ruy
21#include "tensorflow/lite/c/common.h"
22#include "tensorflow/lite/kernels/internal/optimized/optimized_ops.h"
23#include "tensorflow/lite/kernels/internal/quantization_util.h"
24#include "tensorflow/lite/kernels/internal/reference/binary_function.h"
25#include "tensorflow/lite/kernels/internal/reference/integer_ops/add.h"
26#include "tensorflow/lite/kernels/internal/reference/reference_ops.h"
27#include "tensorflow/lite/kernels/internal/tensor.h"
28#include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
29#include "tensorflow/lite/kernels/kernel_util.h"
30
31namespace tflite {
32namespace ops {
33namespace builtin {
34namespace squared_difference {
35
// Indices of this op's tensors in the node's input/output arrays.
constexpr int kInputTensor1 = 0;
constexpr int kInputTensor2 = 1;
constexpr int kOutputTensor = 0;

// Per-node state allocated in Init, populated in Prepare, consumed in Eval.
struct OpData {
  // True when the two input shapes differ and broadcasting is required.
  bool requires_broadcast;
  // Fixed-point parameters (offsets, multipliers, shifts, clamp range) for
  // the quantized path; only populated in Prepare when inputs are int8.
  ArithmeticParams arithmetic_params;
};
44
// Returns (input1 - input2)^2 for any arithmetic type T.
template <typename T>
T SquaredDifference(T input1, T input2) {
  const T delta = input1 - input2;
  return delta * delta;
}
50
51void* Init(TfLiteContext* context, const char* buffer, size_t length) {
52 auto* data = new OpData;
53 data->requires_broadcast = false;
54 return data;
55}
56
57void Free(TfLiteContext* context, void* buffer) {
58 delete reinterpret_cast<OpData*>(buffer);
59}
60
61TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
62 OpData* data = reinterpret_cast<OpData*>(node->user_data);
63
64 TF_LITE_ENSURE_EQ(context, NumInputs(node), 2);
65 TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
66
67 const TfLiteTensor* input1;
68 TF_LITE_ENSURE_OK(context,
69 GetInputSafe(context, node, kInputTensor1, &input1));
70 const TfLiteTensor* input2;
71 TF_LITE_ENSURE_OK(context,
72 GetInputSafe(context, node, kInputTensor2, &input2));
73 TfLiteTensor* output;
74 TF_LITE_ENSURE_OK(context,
75 GetOutputSafe(context, node, kOutputTensor, &output));
76
77 TF_LITE_ENSURE_TYPES_EQ(context, input1->type, input2->type);
78 output->type = input2->type;
79
80 // Ensure the quantization parameters are equivalent.
81 if (input1->type == kTfLiteInt8) {
82 const auto& input1_quantization_params = input1->params;
83 const auto& input2_quantization_params = input2->params;
84 const auto& output_quantization_params = output->params;
85 const int32_t integer_type_min = std::numeric_limits<int8_t>::min();
86 const int32_t integer_type_max = std::numeric_limits<int8_t>::max();
87 TF_LITE_ENSURE(context,
88 input1_quantization_params.zero_point >= integer_type_min);
89 TF_LITE_ENSURE(context,
90 input1_quantization_params.zero_point <= integer_type_max);
91 TF_LITE_ENSURE(context,
92 input2_quantization_params.zero_point >= integer_type_min);
93 TF_LITE_ENSURE(context,
94 input2_quantization_params.zero_point <= integer_type_max);
95 TF_LITE_ENSURE(context,
96 output_quantization_params.zero_point >= integer_type_min);
97 TF_LITE_ENSURE(context,
98 output_quantization_params.zero_point <= integer_type_max);
99 data->arithmetic_params.input1_offset =
100 -input1_quantization_params.zero_point;
101 data->arithmetic_params.input2_offset =
102 -input2_quantization_params.zero_point;
103 data->arithmetic_params.output_offset =
104 output_quantization_params.zero_point;
105
106 // shift to make integer for scales.
107 data->arithmetic_params.left_shift = 7;
108 const double twice_max_input_scale =
109 2 * std::max(input1_quantization_params.scale,
110 input2_quantization_params.scale);
111 const double real_input1_multiplier =
112 input1_quantization_params.scale / twice_max_input_scale;
113 double real_input2_multiplier =
114 input2_quantization_params.scale / twice_max_input_scale;
115 const double real_output_multiplier =
116 (twice_max_input_scale * twice_max_input_scale) /
117 ((1 << data->arithmetic_params.left_shift * 2) *
118 output_quantization_params.scale);
119 tflite::QuantizeMultiplierSmallerThanOneExp(
120 real_input1_multiplier, &data->arithmetic_params.input1_multiplier,
121 &data->arithmetic_params.input1_shift);
122 tflite::QuantizeMultiplierSmallerThanOneExp(
123 real_input2_multiplier, &data->arithmetic_params.input2_multiplier,
124 &data->arithmetic_params.input2_shift);
125 tflite::QuantizeMultiplierSmallerThanOneExp(
126 real_output_multiplier, &data->arithmetic_params.output_multiplier,
127 &data->arithmetic_params.output_shift);
128 data->arithmetic_params.quantized_activation_min =
129 std::numeric_limits<int8_t>::min();
130 data->arithmetic_params.quantized_activation_max =
131 std::numeric_limits<int8_t>::max();
132 }
133
134 data->requires_broadcast = !HaveSameShapes(input1, input2);
135
136 TfLiteIntArray* output_size = nullptr;
137 if (data->requires_broadcast) {
138 TF_LITE_ENSURE_OK(context, CalculateShapeForBroadcast(
139 context, input1, input2, &output_size));
140 } else {
141 output_size = TfLiteIntArrayCopy(input1->dims);
142 }
143
144 return context->ResizeTensor(context, output, output_size);
145}
146
147inline int8_t SquaredDifference(int8_t x, int8_t y,
148 const ArithmeticParams& params) {
149 const int32_t input1_val = params.input1_offset + x;
150 const int32_t input2_val = params.input2_offset + y;
151 const int32_t shifted_input1_val = input1_val * (1 << params.left_shift);
152 const int32_t shifted_input2_val = input2_val * (1 << params.left_shift);
153 const int32_t scaled_input1_val =
154 MultiplyByQuantizedMultiplierSmallerThanOneExp(
155 shifted_input1_val, params.input1_multiplier, params.input1_shift);
156 const int32_t scaled_input2_val =
157 MultiplyByQuantizedMultiplierSmallerThanOneExp(
158 shifted_input2_val, params.input2_multiplier, params.input2_shift);
159 const int32_t raw_diff = scaled_input1_val - scaled_input2_val;
160
161 // Max of this is 255^2 * (1 << 14), so won't overflow 32 bits.
162 const int32_t squared_raw_diff = raw_diff * raw_diff;
163 const int32_t raw_output =
164 MultiplyByQuantizedMultiplierSmallerThanOneExp(
165 squared_raw_diff, params.output_multiplier, params.output_shift) +
166 params.output_offset;
167 const int32_t clamped_output =
168 std::min(params.quantized_activation_max,
169 std::max(params.quantized_activation_min, raw_output));
170 return static_cast<int8_t>(clamped_output);
171}
172
173template <typename T>
174void EvalQuantizedSquaredDifference(TfLiteContext* context, TfLiteNode* node,
175 const OpData* data,
176 const TfLiteTensor* input1,
177 const TfLiteTensor* input2,
178 TfLiteTensor* output) {
179 const auto* op_data = static_cast<const OpData*>(node->user_data);
180 if (data->requires_broadcast) {
181 reference_integer_ops::BroadcastBinaryFunction4DSlow(
182 op_data->arithmetic_params, GetTensorShape(input1),
183 GetTensorData<T>(input1), GetTensorShape(input2),
184 GetTensorData<T>(input2), GetTensorShape(output),
185 GetTensorData<T>(output), reference_integer_ops::CheckArithmeticParams,
186 SquaredDifference);
187 } else {
188 const int flat_size = GetTensorShape(input1).FlatSize();
189 reference_integer_ops::ElementWise(
190 flat_size, op_data->arithmetic_params, GetTensorData<int8_t>(input1),
191 GetTensorData<int8_t>(input2), GetTensorData<int8_t>(output),
192 reference_integer_ops::CheckArithmeticParams, SquaredDifference);
193 }
194}
195
196template <typename T>
197void EvalSquaredDifference(TfLiteContext* context, TfLiteNode* node,
198 const OpData* data, const TfLiteTensor* input1,
199 const TfLiteTensor* input2, TfLiteTensor* output) {
200 if (data->requires_broadcast) {
201 reference_ops::BroadcastBinaryFunction4DSlow<T, T, T>(
202 GetTensorShape(input1), GetTensorData<T>(input1),
203 GetTensorShape(input2), GetTensorData<T>(input2),
204 GetTensorShape(output), GetTensorData<T>(output), SquaredDifference<T>);
205 } else {
206 reference_ops::BinaryFunction<T, T, T>(
207 GetTensorShape(input1), GetTensorData<T>(input1),
208 GetTensorShape(input2), GetTensorData<T>(input2),
209 GetTensorShape(output), GetTensorData<T>(output), SquaredDifference<T>);
210 }
211}
212
213TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
214 OpData* data = reinterpret_cast<OpData*>(node->user_data);
215 ruy::profiler::ScopeLabel label("SquaredDifference");
216
217 const TfLiteTensor* input1;
218 TF_LITE_ENSURE_OK(context,
219 GetInputSafe(context, node, kInputTensor1, &input1));
220 const TfLiteTensor* input2;
221 TF_LITE_ENSURE_OK(context,
222 GetInputSafe(context, node, kInputTensor2, &input2));
223 TfLiteTensor* output;
224 TF_LITE_ENSURE_OK(context,
225 GetOutputSafe(context, node, kOutputTensor, &output));
226
227 if (output->type == kTfLiteFloat32) {
228 EvalSquaredDifference<float>(context, node, data, input1, input2, output);
229 } else if (output->type == kTfLiteInt32) {
230 EvalSquaredDifference<int32_t>(context, node, data, input1, input2, output);
231 } else if (output->type == kTfLiteInt8) {
232 EvalQuantizedSquaredDifference<int8_t>(context, node, data, input1, input2,
233 output);
234 } else {
235 TF_LITE_KERNEL_LOG(
236 context,
237 "SquaredDifference only supports FLOAT32 and INT32 now, got %d.",
238 output->type);
239 return kTfLiteError;
240 }
241
242 return kTfLiteOk;
243}
244
245} // namespace squared_difference
246
247TfLiteRegistration* Register_SQUARED_DIFFERENCE() {
248 static TfLiteRegistration r = {
249 squared_difference::Init, squared_difference::Free,
250 squared_difference::Prepare, squared_difference::Eval};
251 return &r;
252}
253
254} // namespace builtin
255} // namespace ops
256} // namespace tflite
257