1 | /* Copyright 2018 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | |
16 | #include <stdint.h> |
17 | #include <stdlib.h> |
18 | |
19 | #include <algorithm> |
20 | #include <cmath> |
21 | #include <functional> |
22 | #include <limits> |
23 | |
24 | #include "tensorflow/lite/c/common.h" |
25 | #include "tensorflow/lite/kernels/internal/quantization_util.h" |
26 | #include "tensorflow/lite/kernels/internal/reference/reference_ops.h" |
27 | #include "tensorflow/lite/kernels/internal/tensor.h" |
28 | #include "tensorflow/lite/kernels/internal/tensor_ctypes.h" |
29 | #include "tensorflow/lite/kernels/kernel_util.h" |
30 | #include "tensorflow/lite/kernels/op_macros.h" |
31 | |
32 | namespace tflite { |
33 | namespace ops { |
34 | namespace builtin { |
35 | namespace elementwise { |
36 | namespace { |
37 | |
// Op-name constants. GenericPrepare compares its `op_name` argument against
// these by pointer identity (`op_name == kAbsName`), so the Abs/Rsqrt prepare
// functions must pass these exact symbols, not equal string literals.
const char kAbsName[] = "Abs";
const char kRsqrtName[] = "Rsqrt";
40 | |
// Per-node state for quantized (int8/int16) evaluation. Allocated in
// ElementWiseQuantizedInit and populated by GenericPrepare.
struct OpData {
  int32_t multiplier;  // Fixed-point multiplier produced by QuantizeMultiplier.
  int32_t shift;       // Exponent paired with `multiplier`.
  int input_offset;    // Input zero point (zero_point->data[0]).
  int output_offset;   // Output zero point (zero_point->data[0]).
  bool needs_rescale;  // True iff input and output scales differ.
};
48 | |
49 | bool IsNumericSupportedType(const TfLiteType type) { |
50 | return type == kTfLiteFloat32; |
51 | } |
52 | |
53 | bool IsLogicalSupportedType(const TfLiteType type) { |
54 | return type == kTfLiteBool; |
55 | } |
56 | |
57 | bool IsAbsSupportedType(const TfLiteType type) { |
58 | return type == kTfLiteFloat32 || type == kTfLiteInt8 || type == kTfLiteInt16; |
59 | } |
60 | |
61 | bool IsRsqrtSupportedType(const TfLiteType type) { |
62 | return type == kTfLiteFloat32 || type == kTfLiteInt8; |
63 | } |
64 | |
// Computes the fixed-point (multiplier, shift) pair representing
// input_scale / output_scale, used by quantized Abs to rescale |x|
// from the input scale to the output scale.
inline void SetAbsOutputMultiplier(const float input_scale,
                                   const float output_scale,
                                   int32_t* multiplier, int32_t* shift) {
  QuantizeMultiplier(input_scale / output_scale, multiplier, shift);
}
70 | |
// Computes the fixed-point (multiplier, shift) pair for quantized Rsqrt:
// 1 / (sqrt(input_scale) * output_scale). The sqrt of the input scale appears
// because 1/sqrt(q * input_scale) = (1/sqrt(q)) / sqrt(input_scale), and the
// result must then be divided by output_scale to re-quantize.
inline void SetRsqrtOutputMultiplier(const float input_scale,
                                     const float output_scale,
                                     int32_t* multiplier, int32_t* shift) {
  const double scale = 1. / (std::sqrt(input_scale) * output_scale);
  QuantizeMultiplier(scale, multiplier, shift);
}
77 | |
78 | typedef bool (*IsSupportedType)(TfLiteType); |
79 | TfLiteStatus GenericPrepare(TfLiteContext* context, TfLiteNode* node, |
80 | IsSupportedType is_supported_type, |
81 | const char* op_name) { |
82 | TF_LITE_ENSURE_EQ(context, NumInputs(node), 1); |
83 | TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1); |
84 | const TfLiteTensor* input; |
85 | TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, 0, &input)); |
86 | TfLiteTensor* output; |
87 | TF_LITE_ENSURE_OK(context, GetOutputSafe(context, node, 0, &output)); |
88 | TF_LITE_ENSURE_TYPES_EQ(context, input->type, output->type); |
89 | if (!is_supported_type(input->type)) { |
90 | TF_LITE_UNSUPPORTED_TYPE(context, input->type, op_name); |
91 | } |
92 | // For int16 type input, we support both quantized and non-quantized |
93 | // evaluation. |
94 | if (input->type == kTfLiteInt8 || |
95 | (input->type == kTfLiteInt16 && |
96 | input->quantization.type != kTfLiteNoQuantization)) { |
97 | TfLiteTensor* output = GetOutput(context, node, 0); |
98 | auto* op_data = static_cast<OpData*>(node->user_data); |
99 | TF_LITE_ENSURE_EQ(context, input->quantization.type, |
100 | kTfLiteAffineQuantization); |
101 | TF_LITE_ENSURE_EQ(context, output->quantization.type, |
102 | kTfLiteAffineQuantization); |
103 | const auto* input_params = |
104 | reinterpret_cast<TfLiteAffineQuantization*>(input->quantization.params); |
105 | const auto* output_params = reinterpret_cast<TfLiteAffineQuantization*>( |
106 | output->quantization.params); |
107 | TF_LITE_ENSURE(context, input_params != nullptr); |
108 | TF_LITE_ENSURE(context, input_params->scale != nullptr); |
109 | TF_LITE_ENSURE(context, input_params->scale->size > 0); |
110 | TF_LITE_ENSURE(context, input_params->zero_point->size > 0); |
111 | TF_LITE_ENSURE(context, output_params != nullptr); |
112 | TF_LITE_ENSURE(context, output_params->scale != nullptr); |
113 | TF_LITE_ENSURE(context, output_params->scale->size > 0); |
114 | TF_LITE_ENSURE(context, output_params->zero_point->size > 0); |
115 | op_data->input_offset = input_params->zero_point->data[0]; |
116 | op_data->output_offset = output_params->zero_point->data[0]; |
117 | if (input->type == kTfLiteInt16) { |
118 | TF_LITE_ENSURE_EQ(context, op_data->input_offset, 0); |
119 | TF_LITE_ENSURE_EQ(context, op_data->output_offset, 0); |
120 | } |
121 | const float input_scale = input_params->scale->data[0]; |
122 | const float output_scale = output_params->scale->data[0]; |
123 | op_data->needs_rescale = input_scale != output_scale; |
124 | if (op_name == kAbsName && op_data->needs_rescale) { |
125 | SetAbsOutputMultiplier(input_scale, output_scale, &op_data->multiplier, |
126 | &op_data->shift); |
127 | } else if (op_name == kRsqrtName) { |
128 | SetRsqrtOutputMultiplier(input_scale, output_scale, &op_data->multiplier, |
129 | &op_data->shift); |
130 | } |
131 | } |
132 | return context->ResizeTensor(context, output, |
133 | TfLiteIntArrayCopy(input->dims)); |
134 | } |
135 | |
136 | template <typename T> |
137 | inline TfLiteStatus EvalImpl(TfLiteContext* context, TfLiteNode* node, |
138 | std::function<T(T)> func, |
139 | std::function<TfLiteStatus(T)> validate_input_func, |
140 | TfLiteType expected_type) { |
141 | const TfLiteTensor* input; |
142 | TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, 0, &input)); |
143 | TfLiteTensor* output; |
144 | TF_LITE_ENSURE_OK(context, GetOutputSafe(context, node, 0, &output)); |
145 | TF_LITE_ENSURE_TYPES_EQ(context, input->type, expected_type); |
146 | const int64_t num_elements = NumElements(input); |
147 | const T* in_data = GetTensorData<T>(input); |
148 | T* out_data = GetTensorData<T>(output); |
149 | for (int64_t i = 0; i < num_elements; ++i) { |
150 | if (validate_input_func) { |
151 | TF_LITE_ENSURE_OK(context, validate_input_func(in_data[i])); |
152 | } |
153 | out_data[i] = func(in_data[i]); |
154 | } |
155 | return kTfLiteOk; |
156 | } |
157 | |
158 | // Non-quantized evaluation of Abs op when input is int16. |
159 | inline TfLiteStatus AbsInt16EvalImpl(TfLiteContext* context, TfLiteNode* node, |
160 | TfLiteType expected_type) { |
161 | const TfLiteTensor* input; |
162 | TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, 0, &input)); |
163 | TfLiteTensor* output; |
164 | TF_LITE_ENSURE_OK(context, GetOutputSafe(context, node, 0, &output)); |
165 | TF_LITE_ENSURE_TYPES_EQ(context, input->type, expected_type); |
166 | const int64_t num_elements = NumElements(input); |
167 | const int16_t* in_data = GetTensorData<int16_t>(input); |
168 | int16_t* out_data = GetTensorData<int16_t>(output); |
169 | for (int64_t i = 0; i < num_elements; ++i) { |
170 | out_data[i] = static_cast<int16_t>( |
171 | std::abs<int32_t>(static_cast<int32_t>(in_data[i]))); |
172 | } |
173 | return kTfLiteOk; |
174 | } |
175 | |
176 | template <typename T> |
177 | inline TfLiteStatus EvalImpl(TfLiteContext* context, TfLiteNode* node, |
178 | std::function<T(T)> func, |
179 | TfLiteType expected_type) { |
180 | return EvalImpl<T>(context, node, func, /*validate_input_func=*/nullptr, |
181 | expected_type); |
182 | } |
183 | |
184 | inline TfLiteStatus EvalNumeric(TfLiteContext* context, TfLiteNode* node, |
185 | float float_func(float)) { |
186 | return EvalImpl<float>(context, node, float_func, kTfLiteFloat32); |
187 | } |
188 | |
189 | inline TfLiteStatus EvalLogical(TfLiteContext* context, TfLiteNode* node, |
190 | bool bool_func(bool)) { |
191 | return EvalImpl<bool>(context, node, bool_func, kTfLiteBool); |
192 | } |
193 | |
// Allocates the per-node OpData used by the quantized kernels. `buffer` and
// `length` (serialized op options) are unused. Note `new OpData()` value-
// initializes the members to zero/false. Paired with
// ElementWiseQuantizedFree.
void* ElementWiseQuantizedInit(TfLiteContext* context, const char* buffer,
                               size_t length) {
  return new OpData();
}
198 | |
// Releases the OpData allocated by ElementWiseQuantizedInit.
void ElementWiseQuantizedFree(TfLiteContext* context, void* buffer) {
  delete static_cast<OpData*>(buffer);
}
202 | |
// Quantized evaluation of Abs for T = int8_t or int16_t, using the offsets
// and rescale multiplier cached in OpData by GenericPrepare.
template <typename T>
TfLiteStatus AbsEvalQuantized(TfLiteContext* context, TfLiteNode* node,
                              TfLiteType type) {
  const auto* op_data = static_cast<const OpData*>(node->user_data);
  // Saturation bounds of the output type, widened to int for clamping.
  const int kMin = std::numeric_limits<T>::min();
  const int kMax = std::numeric_limits<T>::max();

  std::function<T(T)> func = [&](T i) {
    // |x| in the integer domain; the subtraction promotes to int, so this
    // cannot overflow for int8/int16 inputs.
    const int32_t value = std::abs(i - op_data->input_offset);
    if (!op_data->needs_rescale) {
      // Same input/output scale: just re-apply the output zero point and
      // clamp to the output type's range.
      return static_cast<T>(
          std::min(std::max(value + op_data->output_offset, kMin), kMax));
    }
    // Different scales: rescale |x| by the fixed-point multiplier computed in
    // SetAbsOutputMultiplier, then re-apply the output zero point and clamp.
    const int32_t output = MultiplyByQuantizedMultiplier(
                               value, op_data->multiplier, op_data->shift) +
                           op_data->output_offset;
    return static_cast<T>(std::min(std::max(output, kMin), kMax));
  };

  return EvalImpl<T>(context, node, func, type);
}
224 | |
225 | TfLiteStatus AbsEval(TfLiteContext* context, TfLiteNode* node) { |
226 | const TfLiteTensor* input = GetInput(context, node, 0); |
227 | const TfLiteType type = input->type; |
228 | switch (type) { |
229 | case kTfLiteFloat32: |
230 | return EvalImpl<float>(context, node, std::abs<float>, type); |
231 | case kTfLiteInt8: |
232 | return AbsEvalQuantized<int8_t>(context, node, type); |
233 | case kTfLiteInt16: |
234 | return input->quantization.type == kTfLiteNoQuantization |
235 | ? AbsInt16EvalImpl(context, node, type) |
236 | : AbsEvalQuantized<int16_t>(context, node, type); |
237 | default: |
238 | TF_LITE_KERNEL_LOG(context, "Current data type %s is not supported." , |
239 | TfLiteTypeGetName(type)); |
240 | return kTfLiteError; |
241 | } |
242 | } |
243 | |
244 | TfLiteStatus SinEval(TfLiteContext* context, TfLiteNode* node) { |
245 | return EvalNumeric(context, node, std::sin); |
246 | } |
247 | |
248 | TfLiteStatus CosEval(TfLiteContext* context, TfLiteNode* node) { |
249 | return EvalNumeric(context, node, std::cos); |
250 | } |
251 | |
252 | TfLiteStatus LogEval(TfLiteContext* context, TfLiteNode* node) { |
253 | return EvalNumeric(context, node, std::log); |
254 | } |
255 | |
256 | TfLiteStatus SqrtEval(TfLiteContext* context, TfLiteNode* node) { |
257 | return EvalNumeric(context, node, std::sqrt); |
258 | } |
259 | |
// Quantized int8 evaluation of Rsqrt using the multiplier/shift cached in
// OpData by GenericPrepare (see SetRsqrtOutputMultiplier).
TfLiteStatus RsqrtEvalQuantized(TfLiteContext* context, TfLiteNode* node,
                                TfLiteType type) {
  const auto* op_data = static_cast<const OpData*>(node->user_data);
  const int kMin = std::numeric_limits<int8_t>::min();
  const int kMax = std::numeric_limits<int8_t>::max();
  // Inputs below the zero point correspond to negative real values, where
  // rsqrt is undefined, so evaluation fails with an error.
  std::function<TfLiteStatus(int8_t)> validate_input_func = [&](int8_t i) {
    TF_LITE_ENSURE_MSG(context, i >= op_data->input_offset,
                       "Rsqrt is only defined for positive values");
    return kTfLiteOk;
  };

  std::function<int8_t(int8_t)> func = [&](int8_t i) {
    const int32_t value = (i - op_data->input_offset);
    const int32_t kShift = 20;  // Shift to keep value integer.
    if (value == 0) {
      // Assume that any value close to 0 represents the max output value.
      return static_cast<int8_t>(kMax);
    }
    // Compute 1/sqrt(value) as a fixed-point multiplier/exponent pair.
    int32_t inv_sqrt_multiplier;
    int inv_sqrt_shift;
    GetInvSqrtQuantizedMultiplierExp(value, kReverseShift, &inv_sqrt_multiplier,
                                     &inv_sqrt_shift);
    // Materialize 1/sqrt(value) scaled up by 2^kShift to preserve precision,
    // then rescale by the op's output multiplier (undoing the extra kShift)
    // and re-apply the output zero point.
    const int32_t data = MultiplyByQuantizedMultiplier(1, inv_sqrt_multiplier,
                                                       inv_sqrt_shift + kShift);
    const int32_t output =
        MultiplyByQuantizedMultiplier(data, op_data->multiplier,
                                      op_data->shift - kShift) +
        op_data->output_offset;
    // Clamp to the int8 output range.
    return static_cast<int8_t>(std::min(std::max(output, kMin), kMax));
  };

  return EvalImpl<int8_t>(context, node, func, validate_input_func, type);
}
293 | |
294 | TfLiteStatus RsqrtEval(TfLiteContext* context, TfLiteNode* node) { |
295 | const TfLiteType type = GetInput(context, node, 0)->type; |
296 | switch (type) { |
297 | case kTfLiteFloat32: |
298 | return EvalImpl<float>( |
299 | context, node, [](float f) { return 1.f / std::sqrt(f); }, type); |
300 | case kTfLiteInt8: |
301 | return RsqrtEvalQuantized(context, node, type); |
302 | default: |
303 | TF_LITE_KERNEL_LOG(context, "Current data type %s is not supported." , |
304 | TfLiteTypeGetName(type)); |
305 | return kTfLiteError; |
306 | } |
307 | } |
308 | |
309 | TfLiteStatus SquareEval(TfLiteContext* context, TfLiteNode* node) { |
310 | return EvalNumeric(context, node, [](float f) { return f * f; }); |
311 | } |
312 | |
313 | TfLiteStatus LogicalNotEval(TfLiteContext* context, TfLiteNode* node) { |
314 | return EvalLogical(context, node, [](bool v) { return !v; }); |
315 | } |
316 | |
317 | } // namespace |
318 | } // namespace elementwise |
319 | |
320 | // Given a function... |
321 | // template<int T> |
322 | // int Foo(int b) |
323 | // |
324 | // typedef int(*Bar)(int); |
325 | // |
326 | // MSVC2015 will not see Foo<10> as the same type as Bar. |
327 | // |
328 | // This works around the issue by instantiating wrapper methods around |
329 | // elementwise::GenericPrepare() rather than using a templated |
330 | // elementwise::GenericPrepare method. |
// NOTE: `type_name` is forwarded as GenericPrepare's `op_name`, which is
// compared against kAbsName/kRsqrtName by pointer identity — the Abs and
// Rsqrt instantiations must therefore pass those exact constants rather than
// equal string literals.
#define GENERIC_PREPARE(function_name, is_supported_type_function, type_name) \
  static TfLiteStatus function_name(TfLiteContext* context,                   \
                                    TfLiteNode* node) {                       \
    return elementwise::GenericPrepare(context, node,                         \
                                       is_supported_type_function, type_name);\
  }
337 | |
338 | GENERIC_PREPARE(PrepareAbs, elementwise::IsAbsSupportedType, |
339 | elementwise::kAbsName) |
340 | |
341 | TfLiteRegistration* Register_ABS() { |
342 | static TfLiteRegistration r = {elementwise::ElementWiseQuantizedInit, |
343 | elementwise::ElementWiseQuantizedFree, |
344 | PrepareAbs, elementwise::AbsEval}; |
345 | return &r; |
346 | } |
347 | |
348 | GENERIC_PREPARE(PrepareSin, elementwise::IsNumericSupportedType, "Sin" ) |
349 | |
350 | TfLiteRegistration* Register_SIN() { |
351 | static TfLiteRegistration r = {/*init=*/nullptr, /*free=*/nullptr, PrepareSin, |
352 | elementwise::SinEval}; |
353 | return &r; |
354 | } |
355 | |
356 | GENERIC_PREPARE(PrepareCos, elementwise::IsNumericSupportedType, "Cos" ) |
357 | |
358 | TfLiteRegistration* Register_COS() { |
359 | static TfLiteRegistration r = {/*init=*/nullptr, /*free=*/nullptr, PrepareCos, |
360 | elementwise::CosEval}; |
361 | return &r; |
362 | } |
363 | |
364 | GENERIC_PREPARE(PrepareLog, elementwise::IsNumericSupportedType, "Log" ) |
365 | |
366 | TfLiteRegistration* Register_LOG() { |
367 | static TfLiteRegistration r = {/*init=*/nullptr, /*free=*/nullptr, PrepareLog, |
368 | elementwise::LogEval}; |
369 | return &r; |
370 | } |
371 | |
372 | GENERIC_PREPARE(PrepareSqrt, elementwise::IsNumericSupportedType, "Sqrt" ) |
373 | |
374 | TfLiteRegistration* Register_SQRT() { |
375 | static TfLiteRegistration r = {/*init=*/nullptr, /*free=*/nullptr, |
376 | PrepareSqrt, elementwise::SqrtEval}; |
377 | return &r; |
378 | } |
379 | |
380 | GENERIC_PREPARE(PrepareRsqrt, elementwise::IsRsqrtSupportedType, |
381 | elementwise::kRsqrtName) |
382 | |
383 | TfLiteRegistration* Register_RSQRT() { |
384 | static TfLiteRegistration r = {elementwise::ElementWiseQuantizedInit, |
385 | elementwise::ElementWiseQuantizedFree, |
386 | PrepareRsqrt, elementwise::RsqrtEval}; |
387 | return &r; |
388 | } |
389 | |
390 | GENERIC_PREPARE(PrepareSquare, elementwise::IsNumericSupportedType, "Square" ) |
391 | |
392 | TfLiteRegistration* Register_SQUARE() { |
393 | static TfLiteRegistration r = {/*init=*/nullptr, /*free=*/nullptr, |
394 | PrepareSquare, elementwise::SquareEval}; |
395 | return &r; |
396 | } |
397 | |
398 | GENERIC_PREPARE(PrepareNot, elementwise::IsLogicalSupportedType, "Not" ) |
399 | |
400 | TfLiteRegistration* Register_LOGICAL_NOT() { |
401 | static TfLiteRegistration r = {/*init=*/nullptr, /*free=*/nullptr, PrepareNot, |
402 | elementwise::LogicalNotEval}; |
403 | return &r; |
404 | } |
405 | |
406 | } // namespace builtin |
407 | } // namespace ops |
408 | } // namespace tflite |
409 | |