1 | /* Copyright 2018 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
#include <stddef.h>
#include <stdint.h>

#include <algorithm>
#include <limits>

#include "ruy/profiler/instrumentation.h"  // from @ruy
#include "tensorflow/lite/c/common.h"
#include "tensorflow/lite/kernels/internal/optimized/optimized_ops.h"
#include "tensorflow/lite/kernels/internal/quantization_util.h"
#include "tensorflow/lite/kernels/internal/reference/binary_function.h"
#include "tensorflow/lite/kernels/internal/reference/integer_ops/add.h"
#include "tensorflow/lite/kernels/internal/reference/reference_ops.h"
#include "tensorflow/lite/kernels/internal/tensor.h"
#include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
#include "tensorflow/lite/kernels/kernel_util.h"
30 | |
31 | namespace tflite { |
32 | namespace ops { |
33 | namespace builtin { |
34 | namespace squared_difference { |
35 | |
constexpr int kInputTensor1 = 0;  // Index of the first input tensor.
constexpr int kInputTensor2 = 1;  // Index of the second input tensor.
constexpr int kOutputTensor = 0;  // Index of the output tensor.
39 | |
// Per-op-instance state, allocated in Init and released in Free.
struct OpData {
  // True when the two input shapes differ and broadcasting is required.
  // Set in Prepare.
  bool requires_broadcast;
  // Offsets, multipliers and shifts precomputed in Prepare for the
  // quantized (int8) path. Left uninitialized for float/int32 inputs.
  ArithmeticParams arithmetic_params;
};
44 | |
// Returns (input1 - input2)^2 for plain arithmetic types.
template <typename T>
T SquaredDifference(T input1, T input2) {
  const T delta = input1 - input2;
  return delta * delta;
}
50 | |
51 | void* Init(TfLiteContext* context, const char* buffer, size_t length) { |
52 | auto* data = new OpData; |
53 | data->requires_broadcast = false; |
54 | return data; |
55 | } |
56 | |
57 | void Free(TfLiteContext* context, void* buffer) { |
58 | delete reinterpret_cast<OpData*>(buffer); |
59 | } |
60 | |
// Validates the node, precomputes the int8 quantization parameters, and
// resizes the output tensor (computing the broadcast shape when the two
// input shapes differ).
TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
  OpData* data = reinterpret_cast<OpData*>(node->user_data);

  TF_LITE_ENSURE_EQ(context, NumInputs(node), 2);
  TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);

  const TfLiteTensor* input1;
  TF_LITE_ENSURE_OK(context,
                    GetInputSafe(context, node, kInputTensor1, &input1));
  const TfLiteTensor* input2;
  TF_LITE_ENSURE_OK(context,
                    GetInputSafe(context, node, kInputTensor2, &input2));
  TfLiteTensor* output;
  TF_LITE_ENSURE_OK(context,
                    GetOutputSafe(context, node, kOutputTensor, &output));

  // Both inputs must have the same type; the output inherits it.
  TF_LITE_ENSURE_TYPES_EQ(context, input1->type, input2->type);
  output->type = input2->type;

  // Ensure the quantization parameters are equivalent.
  if (input1->type == kTfLiteInt8) {
    const auto& input1_quantization_params = input1->params;
    const auto& input2_quantization_params = input2->params;
    const auto& output_quantization_params = output->params;
    // All zero points must fit in the int8 value range.
    const int32_t integer_type_min = std::numeric_limits<int8_t>::min();
    const int32_t integer_type_max = std::numeric_limits<int8_t>::max();
    TF_LITE_ENSURE(context,
                   input1_quantization_params.zero_point >= integer_type_min);
    TF_LITE_ENSURE(context,
                   input1_quantization_params.zero_point <= integer_type_max);
    TF_LITE_ENSURE(context,
                   input2_quantization_params.zero_point >= integer_type_min);
    TF_LITE_ENSURE(context,
                   input2_quantization_params.zero_point <= integer_type_max);
    TF_LITE_ENSURE(context,
                   output_quantization_params.zero_point >= integer_type_min);
    TF_LITE_ENSURE(context,
                   output_quantization_params.zero_point <= integer_type_max);
    // Offsets are the negated zero points: the kernel adds them to re-center
    // quantized values around zero before scaling.
    data->arithmetic_params.input1_offset =
        -input1_quantization_params.zero_point;
    data->arithmetic_params.input2_offset =
        -input2_quantization_params.zero_point;
    data->arithmetic_params.output_offset =
        output_quantization_params.zero_point;

    // shift to make integer for scales.
    data->arithmetic_params.left_shift = 7;
    // Both inputs are rescaled relative to twice the larger input scale so
    // that each input multiplier is <= 1/2 (required by
    // QuantizeMultiplierSmallerThanOneExp).
    const double twice_max_input_scale =
        2 * std::max(input1_quantization_params.scale,
                     input2_quantization_params.scale);
    const double real_input1_multiplier =
        input1_quantization_params.scale / twice_max_input_scale;
    double real_input2_multiplier =
        input2_quantization_params.scale / twice_max_input_scale;
    // Note: `*` binds tighter than `<<`, so this divides by
    // (1 << (2 * left_shift)) = 1 << 14, undoing the squared left shift
    // applied to the scaled difference in the int8 kernel below.
    const double real_output_multiplier =
        (twice_max_input_scale * twice_max_input_scale) /
        ((1 << data->arithmetic_params.left_shift * 2) *
         output_quantization_params.scale);
    tflite::QuantizeMultiplierSmallerThanOneExp(
        real_input1_multiplier, &data->arithmetic_params.input1_multiplier,
        &data->arithmetic_params.input1_shift);
    tflite::QuantizeMultiplierSmallerThanOneExp(
        real_input2_multiplier, &data->arithmetic_params.input2_multiplier,
        &data->arithmetic_params.input2_shift);
    tflite::QuantizeMultiplierSmallerThanOneExp(
        real_output_multiplier, &data->arithmetic_params.output_multiplier,
        &data->arithmetic_params.output_shift);
    // No fused activation: clamp only to the full int8 range.
    data->arithmetic_params.quantized_activation_min =
        std::numeric_limits<int8_t>::min();
    data->arithmetic_params.quantized_activation_max =
        std::numeric_limits<int8_t>::max();
  }

  data->requires_broadcast = !HaveSameShapes(input1, input2);

  TfLiteIntArray* output_size = nullptr;
  if (data->requires_broadcast) {
    TF_LITE_ENSURE_OK(context, CalculateShapeForBroadcast(
                                   context, input1, input2, &output_size));
  } else {
    output_size = TfLiteIntArrayCopy(input1->dims);
  }

  // ResizeTensor takes ownership of output_size.
  return context->ResizeTensor(context, output, output_size);
}
146 | |
147 | inline int8_t SquaredDifference(int8_t x, int8_t y, |
148 | const ArithmeticParams& params) { |
149 | const int32_t input1_val = params.input1_offset + x; |
150 | const int32_t input2_val = params.input2_offset + y; |
151 | const int32_t shifted_input1_val = input1_val * (1 << params.left_shift); |
152 | const int32_t shifted_input2_val = input2_val * (1 << params.left_shift); |
153 | const int32_t scaled_input1_val = |
154 | MultiplyByQuantizedMultiplierSmallerThanOneExp( |
155 | shifted_input1_val, params.input1_multiplier, params.input1_shift); |
156 | const int32_t scaled_input2_val = |
157 | MultiplyByQuantizedMultiplierSmallerThanOneExp( |
158 | shifted_input2_val, params.input2_multiplier, params.input2_shift); |
159 | const int32_t raw_diff = scaled_input1_val - scaled_input2_val; |
160 | |
161 | // Max of this is 255^2 * (1 << 14), so won't overflow 32 bits. |
162 | const int32_t squared_raw_diff = raw_diff * raw_diff; |
163 | const int32_t raw_output = |
164 | MultiplyByQuantizedMultiplierSmallerThanOneExp( |
165 | squared_raw_diff, params.output_multiplier, params.output_shift) + |
166 | params.output_offset; |
167 | const int32_t clamped_output = |
168 | std::min(params.quantized_activation_max, |
169 | std::max(params.quantized_activation_min, raw_output)); |
170 | return static_cast<int8_t>(clamped_output); |
171 | } |
172 | |
173 | template <typename T> |
174 | void EvalQuantizedSquaredDifference(TfLiteContext* context, TfLiteNode* node, |
175 | const OpData* data, |
176 | const TfLiteTensor* input1, |
177 | const TfLiteTensor* input2, |
178 | TfLiteTensor* output) { |
179 | const auto* op_data = static_cast<const OpData*>(node->user_data); |
180 | if (data->requires_broadcast) { |
181 | reference_integer_ops::BroadcastBinaryFunction4DSlow( |
182 | op_data->arithmetic_params, GetTensorShape(input1), |
183 | GetTensorData<T>(input1), GetTensorShape(input2), |
184 | GetTensorData<T>(input2), GetTensorShape(output), |
185 | GetTensorData<T>(output), reference_integer_ops::CheckArithmeticParams, |
186 | SquaredDifference); |
187 | } else { |
188 | const int flat_size = GetTensorShape(input1).FlatSize(); |
189 | reference_integer_ops::ElementWise( |
190 | flat_size, op_data->arithmetic_params, GetTensorData<int8_t>(input1), |
191 | GetTensorData<int8_t>(input2), GetTensorData<int8_t>(output), |
192 | reference_integer_ops::CheckArithmeticParams, SquaredDifference); |
193 | } |
194 | } |
195 | |
196 | template <typename T> |
197 | void EvalSquaredDifference(TfLiteContext* context, TfLiteNode* node, |
198 | const OpData* data, const TfLiteTensor* input1, |
199 | const TfLiteTensor* input2, TfLiteTensor* output) { |
200 | if (data->requires_broadcast) { |
201 | reference_ops::BroadcastBinaryFunction4DSlow<T, T, T>( |
202 | GetTensorShape(input1), GetTensorData<T>(input1), |
203 | GetTensorShape(input2), GetTensorData<T>(input2), |
204 | GetTensorShape(output), GetTensorData<T>(output), SquaredDifference<T>); |
205 | } else { |
206 | reference_ops::BinaryFunction<T, T, T>( |
207 | GetTensorShape(input1), GetTensorData<T>(input1), |
208 | GetTensorShape(input2), GetTensorData<T>(input2), |
209 | GetTensorShape(output), GetTensorData<T>(output), SquaredDifference<T>); |
210 | } |
211 | } |
212 | |
213 | TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { |
214 | OpData* data = reinterpret_cast<OpData*>(node->user_data); |
215 | ruy::profiler::ScopeLabel label("SquaredDifference" ); |
216 | |
217 | const TfLiteTensor* input1; |
218 | TF_LITE_ENSURE_OK(context, |
219 | GetInputSafe(context, node, kInputTensor1, &input1)); |
220 | const TfLiteTensor* input2; |
221 | TF_LITE_ENSURE_OK(context, |
222 | GetInputSafe(context, node, kInputTensor2, &input2)); |
223 | TfLiteTensor* output; |
224 | TF_LITE_ENSURE_OK(context, |
225 | GetOutputSafe(context, node, kOutputTensor, &output)); |
226 | |
227 | if (output->type == kTfLiteFloat32) { |
228 | EvalSquaredDifference<float>(context, node, data, input1, input2, output); |
229 | } else if (output->type == kTfLiteInt32) { |
230 | EvalSquaredDifference<int32_t>(context, node, data, input1, input2, output); |
231 | } else if (output->type == kTfLiteInt8) { |
232 | EvalQuantizedSquaredDifference<int8_t>(context, node, data, input1, input2, |
233 | output); |
234 | } else { |
235 | TF_LITE_KERNEL_LOG( |
236 | context, |
237 | "SquaredDifference only supports FLOAT32 and INT32 now, got %d." , |
238 | output->type); |
239 | return kTfLiteError; |
240 | } |
241 | |
242 | return kTfLiteOk; |
243 | } |
244 | |
245 | } // namespace squared_difference |
246 | |
// Returns the registration record for the SQUARED_DIFFERENCE builtin op.
// Remaining TfLiteRegistration fields are zero-initialized by the braced
// initializer.
TfLiteRegistration* Register_SQUARED_DIFFERENCE() {
  static TfLiteRegistration r = {
      squared_difference::Init, squared_difference::Free,
      squared_difference::Prepare, squared_difference::Eval};
  return &r;
}
253 | |
254 | } // namespace builtin |
255 | } // namespace ops |
256 | } // namespace tflite |
257 | |