/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
15
16#include <stdint.h>
17#include <stdlib.h>
18
19#include <algorithm>
20#include <cmath>
21#include <functional>
22#include <limits>
23
24#include "tensorflow/lite/c/common.h"
25#include "tensorflow/lite/kernels/internal/quantization_util.h"
26#include "tensorflow/lite/kernels/internal/reference/reference_ops.h"
27#include "tensorflow/lite/kernels/internal/tensor.h"
28#include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
29#include "tensorflow/lite/kernels/kernel_util.h"
30#include "tensorflow/lite/kernels/op_macros.h"
31
32namespace tflite {
33namespace ops {
34namespace builtin {
35namespace elementwise {
36namespace {
37
// Op names passed to GenericPrepare(). Used for error messages and, by
// pointer identity (see the `op_name == kAbsName` comparison there), to
// select per-op quantization setup.
const char kAbsName[] = "Abs";
const char kRsqrtName[] = "Rsqrt";
40
// Per-node state for quantized evaluation, populated in GenericPrepare()
// and consumed by AbsEvalQuantized() / RsqrtEvalQuantized().
struct OpData {
  int32_t multiplier;  // Fixed-point multiplier from QuantizeMultiplier().
  int32_t shift;       // Companion shift for `multiplier`.
  int input_offset;    // Input zero point.
  int output_offset;   // Output zero point.
  bool needs_rescale;  // True when input and output scales differ (Abs only).
};
48
49bool IsNumericSupportedType(const TfLiteType type) {
50 return type == kTfLiteFloat32;
51}
52
53bool IsLogicalSupportedType(const TfLiteType type) {
54 return type == kTfLiteBool;
55}
56
57bool IsAbsSupportedType(const TfLiteType type) {
58 return type == kTfLiteFloat32 || type == kTfLiteInt8 || type == kTfLiteInt16;
59}
60
61bool IsRsqrtSupportedType(const TfLiteType type) {
62 return type == kTfLiteFloat32 || type == kTfLiteInt8;
63}
64
65inline void SetAbsOutputMultiplier(const float input_scale,
66 const float output_scale,
67 int32_t* multiplier, int32_t* shift) {
68 QuantizeMultiplier(input_scale / output_scale, multiplier, shift);
69}
70
71inline void SetRsqrtOutputMultiplier(const float input_scale,
72 const float output_scale,
73 int32_t* multiplier, int32_t* shift) {
74 const double scale = 1. / (std::sqrt(input_scale) * output_scale);
75 QuantizeMultiplier(scale, multiplier, shift);
76}
77
78typedef bool (*IsSupportedType)(TfLiteType);
79TfLiteStatus GenericPrepare(TfLiteContext* context, TfLiteNode* node,
80 IsSupportedType is_supported_type,
81 const char* op_name) {
82 TF_LITE_ENSURE_EQ(context, NumInputs(node), 1);
83 TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
84 const TfLiteTensor* input;
85 TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, 0, &input));
86 TfLiteTensor* output;
87 TF_LITE_ENSURE_OK(context, GetOutputSafe(context, node, 0, &output));
88 TF_LITE_ENSURE_TYPES_EQ(context, input->type, output->type);
89 if (!is_supported_type(input->type)) {
90 TF_LITE_UNSUPPORTED_TYPE(context, input->type, op_name);
91 }
92 // For int16 type input, we support both quantized and non-quantized
93 // evaluation.
94 if (input->type == kTfLiteInt8 ||
95 (input->type == kTfLiteInt16 &&
96 input->quantization.type != kTfLiteNoQuantization)) {
97 TfLiteTensor* output = GetOutput(context, node, 0);
98 auto* op_data = static_cast<OpData*>(node->user_data);
99 TF_LITE_ENSURE_EQ(context, input->quantization.type,
100 kTfLiteAffineQuantization);
101 TF_LITE_ENSURE_EQ(context, output->quantization.type,
102 kTfLiteAffineQuantization);
103 const auto* input_params =
104 reinterpret_cast<TfLiteAffineQuantization*>(input->quantization.params);
105 const auto* output_params = reinterpret_cast<TfLiteAffineQuantization*>(
106 output->quantization.params);
107 TF_LITE_ENSURE(context, input_params != nullptr);
108 TF_LITE_ENSURE(context, input_params->scale != nullptr);
109 TF_LITE_ENSURE(context, input_params->scale->size > 0);
110 TF_LITE_ENSURE(context, input_params->zero_point->size > 0);
111 TF_LITE_ENSURE(context, output_params != nullptr);
112 TF_LITE_ENSURE(context, output_params->scale != nullptr);
113 TF_LITE_ENSURE(context, output_params->scale->size > 0);
114 TF_LITE_ENSURE(context, output_params->zero_point->size > 0);
115 op_data->input_offset = input_params->zero_point->data[0];
116 op_data->output_offset = output_params->zero_point->data[0];
117 if (input->type == kTfLiteInt16) {
118 TF_LITE_ENSURE_EQ(context, op_data->input_offset, 0);
119 TF_LITE_ENSURE_EQ(context, op_data->output_offset, 0);
120 }
121 const float input_scale = input_params->scale->data[0];
122 const float output_scale = output_params->scale->data[0];
123 op_data->needs_rescale = input_scale != output_scale;
124 if (op_name == kAbsName && op_data->needs_rescale) {
125 SetAbsOutputMultiplier(input_scale, output_scale, &op_data->multiplier,
126 &op_data->shift);
127 } else if (op_name == kRsqrtName) {
128 SetRsqrtOutputMultiplier(input_scale, output_scale, &op_data->multiplier,
129 &op_data->shift);
130 }
131 }
132 return context->ResizeTensor(context, output,
133 TfLiteIntArrayCopy(input->dims));
134}
135
136template <typename T>
137inline TfLiteStatus EvalImpl(TfLiteContext* context, TfLiteNode* node,
138 std::function<T(T)> func,
139 std::function<TfLiteStatus(T)> validate_input_func,
140 TfLiteType expected_type) {
141 const TfLiteTensor* input;
142 TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, 0, &input));
143 TfLiteTensor* output;
144 TF_LITE_ENSURE_OK(context, GetOutputSafe(context, node, 0, &output));
145 TF_LITE_ENSURE_TYPES_EQ(context, input->type, expected_type);
146 const int64_t num_elements = NumElements(input);
147 const T* in_data = GetTensorData<T>(input);
148 T* out_data = GetTensorData<T>(output);
149 for (int64_t i = 0; i < num_elements; ++i) {
150 if (validate_input_func) {
151 TF_LITE_ENSURE_OK(context, validate_input_func(in_data[i]));
152 }
153 out_data[i] = func(in_data[i]);
154 }
155 return kTfLiteOk;
156}
157
158// Non-quantized evaluation of Abs op when input is int16.
159inline TfLiteStatus AbsInt16EvalImpl(TfLiteContext* context, TfLiteNode* node,
160 TfLiteType expected_type) {
161 const TfLiteTensor* input;
162 TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, 0, &input));
163 TfLiteTensor* output;
164 TF_LITE_ENSURE_OK(context, GetOutputSafe(context, node, 0, &output));
165 TF_LITE_ENSURE_TYPES_EQ(context, input->type, expected_type);
166 const int64_t num_elements = NumElements(input);
167 const int16_t* in_data = GetTensorData<int16_t>(input);
168 int16_t* out_data = GetTensorData<int16_t>(output);
169 for (int64_t i = 0; i < num_elements; ++i) {
170 out_data[i] = static_cast<int16_t>(
171 std::abs<int32_t>(static_cast<int32_t>(in_data[i])));
172 }
173 return kTfLiteOk;
174}
175
176template <typename T>
177inline TfLiteStatus EvalImpl(TfLiteContext* context, TfLiteNode* node,
178 std::function<T(T)> func,
179 TfLiteType expected_type) {
180 return EvalImpl<T>(context, node, func, /*validate_input_func=*/nullptr,
181 expected_type);
182}
183
184inline TfLiteStatus EvalNumeric(TfLiteContext* context, TfLiteNode* node,
185 float float_func(float)) {
186 return EvalImpl<float>(context, node, float_func, kTfLiteFloat32);
187}
188
189inline TfLiteStatus EvalLogical(TfLiteContext* context, TfLiteNode* node,
190 bool bool_func(bool)) {
191 return EvalImpl<bool>(context, node, bool_func, kTfLiteBool);
192}
193
194void* ElementWiseQuantizedInit(TfLiteContext* context, const char* buffer,
195 size_t length) {
196 return new OpData();
197}
198
199void ElementWiseQuantizedFree(TfLiteContext* context, void* buffer) {
200 delete static_cast<OpData*>(buffer);
201}
202
203template <typename T>
204TfLiteStatus AbsEvalQuantized(TfLiteContext* context, TfLiteNode* node,
205 TfLiteType type) {
206 const auto* op_data = static_cast<const OpData*>(node->user_data);
207 const int kMin = std::numeric_limits<T>::min();
208 const int kMax = std::numeric_limits<T>::max();
209
210 std::function<T(T)> func = [&](T i) {
211 const int32_t value = std::abs(i - op_data->input_offset);
212 if (!op_data->needs_rescale) {
213 return static_cast<T>(
214 std::min(std::max(value + op_data->output_offset, kMin), kMax));
215 }
216 const int32_t output = MultiplyByQuantizedMultiplier(
217 value, op_data->multiplier, op_data->shift) +
218 op_data->output_offset;
219 return static_cast<T>(std::min(std::max(output, kMin), kMax));
220 };
221
222 return EvalImpl<T>(context, node, func, type);
223}
224
225TfLiteStatus AbsEval(TfLiteContext* context, TfLiteNode* node) {
226 const TfLiteTensor* input = GetInput(context, node, 0);
227 const TfLiteType type = input->type;
228 switch (type) {
229 case kTfLiteFloat32:
230 return EvalImpl<float>(context, node, std::abs<float>, type);
231 case kTfLiteInt8:
232 return AbsEvalQuantized<int8_t>(context, node, type);
233 case kTfLiteInt16:
234 return input->quantization.type == kTfLiteNoQuantization
235 ? AbsInt16EvalImpl(context, node, type)
236 : AbsEvalQuantized<int16_t>(context, node, type);
237 default:
238 TF_LITE_KERNEL_LOG(context, "Current data type %s is not supported.",
239 TfLiteTypeGetName(type));
240 return kTfLiteError;
241 }
242}
243
244TfLiteStatus SinEval(TfLiteContext* context, TfLiteNode* node) {
245 return EvalNumeric(context, node, std::sin);
246}
247
248TfLiteStatus CosEval(TfLiteContext* context, TfLiteNode* node) {
249 return EvalNumeric(context, node, std::cos);
250}
251
252TfLiteStatus LogEval(TfLiteContext* context, TfLiteNode* node) {
253 return EvalNumeric(context, node, std::log);
254}
255
256TfLiteStatus SqrtEval(TfLiteContext* context, TfLiteNode* node) {
257 return EvalNumeric(context, node, std::sqrt);
258}
259
// Quantized (int8) reciprocal square root: 1/sqrt(x) in fixed point, using
// the multiplier/shift prepared by SetRsqrtOutputMultiplier().
TfLiteStatus RsqrtEvalQuantized(TfLiteContext* context, TfLiteNode* node,
                                TfLiteType type) {
  const auto* op_data = static_cast<const OpData*>(node->user_data);
  const int kMin = std::numeric_limits<int8_t>::min();
  const int kMax = std::numeric_limits<int8_t>::max();
  // Rsqrt is undefined below zero: reject any element whose de-quantized
  // value would be negative (i.e. raw value below the input zero point).
  std::function<TfLiteStatus(int8_t)> validate_input_func = [&](int8_t i) {
    TF_LITE_ENSURE_MSG(context, i >= op_data->input_offset,
                       "Rsqrt is only defined for positive values");
    return kTfLiteOk;
  };

  std::function<int8_t(int8_t)> func = [&](int8_t i) {
    const int32_t value = (i - op_data->input_offset);
    const int32_t kShift = 20;  // Shift to keep value integer.
    if (value == 0) {
      // Assume that any value close to 0 represents the max output value.
      return static_cast<int8_t>(kMax);
    }
    int32_t inv_sqrt_multiplier;
    int inv_sqrt_shift;
    // Fixed-point approximation of 1/sqrt(value).
    GetInvSqrtQuantizedMultiplierExp(value, kReverseShift, &inv_sqrt_multiplier,
                                     &inv_sqrt_shift);
    // Materialize the approximation with kShift extra bits of intermediate
    // precision...
    const int32_t data = MultiplyByQuantizedMultiplier(1, inv_sqrt_multiplier,
                                                       inv_sqrt_shift + kShift);
    // ...then rescale to the output scale (undoing kShift), re-center on the
    // output zero point, and clamp to the int8 range.
    const int32_t output =
        MultiplyByQuantizedMultiplier(data, op_data->multiplier,
                                      op_data->shift - kShift) +
        op_data->output_offset;
    return static_cast<int8_t>(std::min(std::max(output, kMin), kMax));
  };

  return EvalImpl<int8_t>(context, node, func, validate_input_func, type);
}
293
294TfLiteStatus RsqrtEval(TfLiteContext* context, TfLiteNode* node) {
295 const TfLiteType type = GetInput(context, node, 0)->type;
296 switch (type) {
297 case kTfLiteFloat32:
298 return EvalImpl<float>(
299 context, node, [](float f) { return 1.f / std::sqrt(f); }, type);
300 case kTfLiteInt8:
301 return RsqrtEvalQuantized(context, node, type);
302 default:
303 TF_LITE_KERNEL_LOG(context, "Current data type %s is not supported.",
304 TfLiteTypeGetName(type));
305 return kTfLiteError;
306 }
307}
308
309TfLiteStatus SquareEval(TfLiteContext* context, TfLiteNode* node) {
310 return EvalNumeric(context, node, [](float f) { return f * f; });
311}
312
313TfLiteStatus LogicalNotEval(TfLiteContext* context, TfLiteNode* node) {
314 return EvalLogical(context, node, [](bool v) { return !v; });
315}
316
317} // namespace
318} // namespace elementwise
319
320// Given a function...
321// template<int T>
322// int Foo(int b)
323//
324// typedef int(*Bar)(int);
325//
326// MSVC2015 will not see Foo<10> as the same type as Bar.
327//
328// This works around the issue by instantiating wrapper methods around
329// elementwise::GenericPrepare() rather than using a templated
330// elementwise::GenericPrepare method.
// Instantiates a named, non-template Prepare wrapper that forwards to
// elementwise::GenericPrepare with the given type predicate and op name.
#define GENERIC_PREPARE(function_name, is_supported_type_function, type_name) \
  static TfLiteStatus function_name(TfLiteContext* context,                   \
                                    TfLiteNode* node) {                       \
    return elementwise::GenericPrepare(context, node,                         \
                                       is_supported_type_function, type_name);\
  }
337
GENERIC_PREPARE(PrepareAbs, elementwise::IsAbsSupportedType,
                elementwise::kAbsName)

// Abs supports quantized tensors, so it registers init/free to carry the
// per-node OpData quantization state.
TfLiteRegistration* Register_ABS() {
  static TfLiteRegistration r = {elementwise::ElementWiseQuantizedInit,
                                 elementwise::ElementWiseQuantizedFree,
                                 PrepareAbs, elementwise::AbsEval};
  return &r;
}
347
GENERIC_PREPARE(PrepareSin, elementwise::IsNumericSupportedType, "Sin")

// Float-only op: no per-node state, so init/free are null.
TfLiteRegistration* Register_SIN() {
  static TfLiteRegistration r = {/*init=*/nullptr, /*free=*/nullptr, PrepareSin,
                                 elementwise::SinEval};
  return &r;
}
355
GENERIC_PREPARE(PrepareCos, elementwise::IsNumericSupportedType, "Cos")

// Float-only op: no per-node state, so init/free are null.
TfLiteRegistration* Register_COS() {
  static TfLiteRegistration r = {/*init=*/nullptr, /*free=*/nullptr, PrepareCos,
                                 elementwise::CosEval};
  return &r;
}
363
GENERIC_PREPARE(PrepareLog, elementwise::IsNumericSupportedType, "Log")

// Float-only op: no per-node state, so init/free are null.
TfLiteRegistration* Register_LOG() {
  static TfLiteRegistration r = {/*init=*/nullptr, /*free=*/nullptr, PrepareLog,
                                 elementwise::LogEval};
  return &r;
}
371
GENERIC_PREPARE(PrepareSqrt, elementwise::IsNumericSupportedType, "Sqrt")

// Float-only op: no per-node state, so init/free are null.
TfLiteRegistration* Register_SQRT() {
  static TfLiteRegistration r = {/*init=*/nullptr, /*free=*/nullptr,
                                 PrepareSqrt, elementwise::SqrtEval};
  return &r;
}
379
GENERIC_PREPARE(PrepareRsqrt, elementwise::IsRsqrtSupportedType,
                elementwise::kRsqrtName)

// Rsqrt supports quantized int8, so it registers init/free to carry the
// per-node OpData quantization state.
TfLiteRegistration* Register_RSQRT() {
  static TfLiteRegistration r = {elementwise::ElementWiseQuantizedInit,
                                 elementwise::ElementWiseQuantizedFree,
                                 PrepareRsqrt, elementwise::RsqrtEval};
  return &r;
}
389
GENERIC_PREPARE(PrepareSquare, elementwise::IsNumericSupportedType, "Square")

// Float-only op: no per-node state, so init/free are null.
TfLiteRegistration* Register_SQUARE() {
  static TfLiteRegistration r = {/*init=*/nullptr, /*free=*/nullptr,
                                 PrepareSquare, elementwise::SquareEval};
  return &r;
}
397
GENERIC_PREPARE(PrepareNot, elementwise::IsLogicalSupportedType, "Not")

// Bool-only op: no per-node state, so init/free are null.
TfLiteRegistration* Register_LOGICAL_NOT() {
  static TfLiteRegistration r = {/*init=*/nullptr, /*free=*/nullptr, PrepareNot,
                                 elementwise::LogicalNotEval};
  return &r;
}
405
406} // namespace builtin
407} // namespace ops
408} // namespace tflite
409