1 | /* Copyright 2017 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | #include <stddef.h> |
16 | |
17 | #include <algorithm> |
18 | #include <cmath> |
19 | #include <cstdint> |
20 | #include <functional> |
21 | #include <limits> |
22 | |
23 | #include "tensorflow/lite/c/builtin_op_data.h" |
24 | #include "tensorflow/lite/c/common.h" |
25 | #include "tensorflow/lite/kernels/cpu_backend_context.h" |
26 | #include "tensorflow/lite/kernels/internal/common.h" |
27 | #include "tensorflow/lite/kernels/internal/compatibility.h" |
28 | #include "tensorflow/lite/kernels/internal/cppmath.h" |
29 | #include "tensorflow/lite/kernels/internal/optimized/integer_ops/leaky_relu.h" |
30 | #include "tensorflow/lite/kernels/internal/optimized/optimized_ops.h" |
31 | #include "tensorflow/lite/kernels/internal/quantization_util.h" |
32 | #include "tensorflow/lite/kernels/internal/reference/binary_function.h" |
33 | #include "tensorflow/lite/kernels/internal/reference/gelu.h" |
34 | #include "tensorflow/lite/kernels/internal/reference/integer_ops/log_softmax.h" |
35 | #include "tensorflow/lite/kernels/internal/reference/integer_ops/logistic.h" |
36 | #include "tensorflow/lite/kernels/internal/reference/integer_ops/tanh.h" |
37 | #include "tensorflow/lite/kernels/internal/reference/logistic.h" |
38 | #include "tensorflow/lite/kernels/internal/reference/prelu.h" |
39 | #include "tensorflow/lite/kernels/internal/reference/reference_ops.h" |
40 | #include "tensorflow/lite/kernels/internal/reference/softmax.h" |
41 | #include "tensorflow/lite/kernels/internal/reference/tanh.h" |
42 | #include "tensorflow/lite/kernels/internal/tensor.h" |
43 | #include "tensorflow/lite/kernels/internal/tensor_ctypes.h" |
44 | #include "tensorflow/lite/kernels/internal/types.h" |
45 | #include "tensorflow/lite/kernels/kernel_util.h" |
46 | |
47 | #if __aarch64__ && __clang__ |
48 | #include <arm_neon.h> |
49 | #endif |
50 | |
51 | namespace tflite { |
52 | namespace ops { |
53 | namespace builtin { |
54 | namespace activations { |
55 | |
56 | // TODO(b/142762739): We should figure out a multi-threading plan for most of |
57 | // the activation ops below. |
58 | |
59 | enum KernelType { |
60 | kReference, |
61 | kGenericOptimized, |
62 | kFixedPointOptimized, |
63 | }; |
64 | |
65 | struct OpData { |
66 | int32_t input_multiplier = 0; |
67 | int input_left_shift = 0; |
68 | int32_t input_range_radius = 0; |
69 | int diff_min = 0; |
70 | uint8_t table[256] = {0}; |
71 | }; |
72 | |
73 | struct SoftmaxOpData { |
74 | struct SoftmaxParams params = {}; |
75 | float table[256]; |
76 | #ifdef TFLITE_SOFTMAX_USE_UINT16_LUT |
77 | uint8_t uint8_table1[256]; |
78 | uint8_t uint8_table2[256]; |
79 | #endif |
80 | static constexpr int kInt16LUTArraySize = lut_size<int16_t>(); |
  // int16 LUT for exp(x), where x is uniformly distributed in
  // [-10.0, 0.0].
  int16_t exp_lut[kInt16LUTArraySize];
  // int16 LUT for 1 / (1 + x), where x is uniformly distributed in
  // [0.0, 1.0].
  int16_t one_over_one_plus_x_lut[kInt16LUTArraySize];
87 | }; |
88 | |
89 | struct LogSoftmaxOpData : public OpData { |
90 | int32_t reverse_scaling_divisor = 0; |
91 | int32_t reverse_scaling_right_shift = 0; |
92 | struct SoftmaxParams params = {}; |
93 | float f_table[256]; |
94 | }; |
95 | |
96 | struct LeakyReluOpData : public OpData { |
97 | int32_t output_multiplier_alpha = 0; |
98 | int32_t output_shift_alpha = 0; |
99 | int32_t output_multiplier_identity = 0; |
100 | int32_t output_shift_identity = 0; |
101 | }; |
102 | |
103 | struct PreluOpData : public OpData { |
104 | int32_t output_multiplier_1 = 0; |
105 | int32_t output_shift_1 = 0; |
106 | int32_t output_multiplier_2 = 0; |
107 | int32_t output_shift_2 = 0; |
  bool requires_broadcast = false;
109 | }; |
110 | |
111 | struct HardSwishData { |
112 | HardSwishParams params; |
113 | }; |
114 | |
115 | struct ReluOpData : public OpData { |
116 | int32_t output_multiplier = 0; |
117 | int output_shift = 0; |
118 | }; |
119 | |
120 | namespace { |
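// Builds a 256-entry table mapping every quantized input value directly to
// its quantized output value: dequantize with the input scale/zero-point,
// apply `transform`, requantize with the output scale/zero-point, and clamp
// to T's range. Illustrative example (values assumed, not from any particular
// model): with input scale 1/16 and zero-point 0, the uint8 input 16
// dequantizes to 1.0; for a sigmoid transform, sigmoid(1.0) ~= 0.731, which
// requantizes with output scale 1/256 and zero-point 0 to
// round(0.731 * 256) = 187.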
121 | template <typename T> |
122 | void PopulateLookupTable(struct OpData* data, const TfLiteTensor* input, |
123 | TfLiteTensor* output, |
124 | const std::function<float(float)>& transform) { |
  static_assert(sizeof(T) == 1, "Lookup table valid only for 8-bit types");
126 | const float inverse_scale = 1 / output->params.scale; |
127 | int32_t maxval = std::numeric_limits<T>::max(); |
128 | int32_t minval = std::numeric_limits<T>::min(); |
129 | for (int32_t val = minval; val <= maxval; ++val) { |
130 | const float dequantized = |
131 | input->params.scale * (val - input->params.zero_point); |
132 | const float transformed = transform(dequantized); |
133 | const float rescaled = std::round(transformed * inverse_scale); |
134 | const int32_t quantized = |
135 | static_cast<int32_t>(rescaled + output->params.zero_point); |
136 | data->table[static_cast<uint8_t>(static_cast<T>(val))] = |
137 | static_cast<uint8_t>( |
138 | static_cast<T>(std::max(std::min(maxval, quantized), minval))); |
139 | } |
140 | } |
141 | |
142 | // TODO(b/143696793): move this to optimized_ops. |
143 | void EvalUsingLookupTable(struct OpData* data, const TfLiteTensor* input, |
144 | TfLiteTensor* output) { |
145 | const int size = |
146 | MatchingFlatSize(GetTensorShape(input), GetTensorShape(output)); |
147 | uint8_t* output_data = GetTensorData<uint8_t>(output); |
148 | const uint8_t* input_data = GetTensorData<uint8_t>(input); |
149 | int i = 0; |
150 | #if __aarch64__ && __clang__ |
151 | // This code uses ARM64-only instructions. |
152 | // TODO(b/143709993): Port to ARMv7 |
153 | |
154 | // Load the tables into registers. (4*4 128-bit registers) |
155 | uint8x16x4_t table[4]; |
156 | table[0] = vld1q_u8_x4(data->table + 16 * 4 * 0); |
157 | table[1] = vld1q_u8_x4(data->table + 16 * 4 * 1); |
158 | table[2] = vld1q_u8_x4(data->table + 16 * 4 * 2); |
159 | table[3] = vld1q_u8_x4(data->table + 16 * 4 * 3); |
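  // The helper used below resolves a full 8-bit index in four table lookups:
  // each uint8x16x4_t above holds one 64-byte quarter of the 256-byte table,
  // and the indices are offset by 64 between quarters before the partial
  // results are merged. (This is our reading of aarch64_lookup_vector; see
  // optimized_ops for the actual implementation.)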
160 | |
161 | // Vectorized loop; process uint8x16_t (16 elements) at a time. |
162 | constexpr int vectorized_16_loop_step = 16; |
163 | const int vectorized_16_loop_end = |
164 | size / vectorized_16_loop_step * vectorized_16_loop_step; |
165 | for (; i < vectorized_16_loop_end; i += vectorized_16_loop_step) { |
166 | uint8x16_t input = vld1q_u8(input_data + i); |
167 | uint8x16_t output = optimized_ops::aarch64_lookup_vector(table, input); |
168 | vst1q_u8(output_data + i, output); |
169 | } |
170 | // Postamble and non-ARM64 code: simple for loop. |
171 | #endif |
172 | for (; i < size; ++i) { |
173 | output_data[i] = data->table[input_data[i]]; |
174 | } |
175 | } |
176 | |
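// Applies ReluX in the quantized domain: each float bound in
// [act_min, act_max] becomes output zero-point + round(bound / output scale),
// clipped to T's representable range; an unbounded act_max (+infinity) simply
// saturates at T's maximum. The input-to-output rescale uses the
// multiplier/shift computed in ReluPrepare.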
177 | template <typename T> |
178 | void QuantizedReluX(float act_min, float act_max, const TfLiteTensor* input, |
179 | TfLiteTensor* output, const ReluOpData* data) { |
180 | ReluParams params; |
181 | params.quantized_activation_min = |
182 | std::max(static_cast<int32_t>(std::numeric_limits<T>::min()), |
183 | output->params.zero_point + |
                   static_cast<int32_t>(roundf(act_min / output->params.scale)));
185 | params.quantized_activation_max = |
186 | act_max == std::numeric_limits<float>::infinity() |
187 | ? static_cast<int32_t>(std::numeric_limits<T>::max()) |
188 | : std::min( |
189 | static_cast<int32_t>(std::numeric_limits<T>::max()), |
190 | output->params.zero_point + |
                       static_cast<int32_t>(roundf(act_max / output->params.scale)));
192 | params.input_offset = input->params.zero_point; |
193 | params.output_offset = output->params.zero_point; |
194 | params.output_multiplier = data->output_multiplier; |
195 | params.output_shift = data->output_shift; |
196 | optimized_ops::ReluX(params, GetTensorShape(input), GetTensorData<T>(input), |
197 | GetTensorShape(output), GetTensorData<T>(output)); |
198 | } |
199 | |
200 | } // namespace |
201 | |
202 | void* Init(TfLiteContext* context, const char* buffer, size_t length) { |
203 | // This is a builtin op, so we don't use the contents in 'buffer', if any. |
204 | // Instead, we allocate a new object to carry information from Prepare() to |
205 | // Eval(). |
206 | return new OpData; |
207 | } |
208 | |
209 | void* SoftmaxInit(TfLiteContext* context, const char* buffer, size_t length) { |
210 | return new SoftmaxOpData; |
211 | } |
212 | |
213 | void SoftmaxFree(TfLiteContext* context, void* buffer) { |
214 | delete reinterpret_cast<SoftmaxOpData*>(buffer); |
215 | } |
216 | |
217 | void* LogSoftmaxInit(TfLiteContext* context, const char* buffer, |
218 | size_t length) { |
219 | return new LogSoftmaxOpData; |
220 | } |
221 | |
222 | void* PreluInit(TfLiteContext* context, const char* buffer, size_t length) { |
223 | return new PreluOpData; |
224 | } |
225 | |
226 | void Free(TfLiteContext* context, void* buffer) { |
227 | delete reinterpret_cast<OpData*>(buffer); |
228 | } |
229 | |
230 | void LogSoftmaxFree(TfLiteContext* context, void* buffer) { |
231 | delete reinterpret_cast<LogSoftmaxOpData*>(buffer); |
232 | } |
233 | |
234 | void PreluFree(TfLiteContext* context, void* buffer) { |
235 | delete reinterpret_cast<PreluOpData*>(buffer); |
236 | } |
237 | |
238 | void* HardSwishInit(TfLiteContext* context, const char* buffer, size_t length) { |
239 | return new HardSwishData; |
240 | } |
241 | |
242 | TfLiteStatus GenericPrepare(TfLiteContext* context, TfLiteNode* node) { |
243 | TF_LITE_ENSURE_EQ(context, NumInputs(node), 1); |
244 | TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1); |
245 | const TfLiteTensor* input; |
246 | TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, 0, &input)); |
247 | TfLiteTensor* output; |
248 | TF_LITE_ENSURE_OK(context, GetOutputSafe(context, node, 0, &output)); |
249 | TF_LITE_ENSURE_TYPES_EQ(context, input->type, output->type); |
250 | |
251 | return context->ResizeTensor(context, output, |
252 | TfLiteIntArrayCopy(input->dims)); |
253 | } |
254 | |
255 | void* ReluInit(TfLiteContext* context, const char* buffer, size_t length) { |
256 | return new ReluOpData; |
257 | } |
258 | |
259 | void ReluFree(TfLiteContext* context, void* buffer) { |
260 | delete reinterpret_cast<ReluOpData*>(buffer); |
261 | } |
262 | |
263 | TfLiteStatus ReluPrepare(TfLiteContext* context, TfLiteNode* node) { |
264 | ReluOpData* data = reinterpret_cast<ReluOpData*>(node->user_data); |
265 | TF_LITE_ENSURE_EQ(context, NumInputs(node), 1); |
266 | TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1); |
267 | const TfLiteTensor* input; |
268 | TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, 0, &input)); |
269 | TfLiteTensor* output; |
270 | TF_LITE_ENSURE_OK(context, GetOutputSafe(context, node, 0, &output)); |
271 | TF_LITE_ENSURE_TYPES_EQ(context, input->type, output->type); |
272 | |
273 | if (input->type == kTfLiteInt8 || input->type == kTfLiteUInt8 || |
274 | input->type == kTfLiteInt16) { |
275 | double real_multiplier = input->params.scale / output->params.scale; |
276 | QuantizeMultiplier(real_multiplier, &data->output_multiplier, |
277 | &data->output_shift); |
278 | } |
279 | |
280 | if (input->type == kTfLiteInt16) { |
281 | TF_LITE_ENSURE_EQ(context, input->params.zero_point, 0); |
282 | TF_LITE_ENSURE_EQ(context, output->params.zero_point, 0); |
283 | } |
284 | |
285 | return context->ResizeTensor(context, output, |
286 | TfLiteIntArrayCopy(input->dims)); |
287 | } |
288 | |
289 | void* LeakyReluInit(TfLiteContext* context, const char* buffer, size_t length) { |
290 | return new LeakyReluOpData; |
291 | } |
292 | |
293 | void LeakyReluFree(TfLiteContext* context, void* buffer) { |
294 | delete reinterpret_cast<LeakyReluOpData*>(buffer); |
295 | } |
296 | |
297 | void HardSwishFree(TfLiteContext* context, void* buffer) { |
298 | delete static_cast<HardSwishData*>(buffer); |
299 | } |
300 | |
301 | TfLiteStatus HardSwishPrepare(TfLiteContext* context, TfLiteNode* node) { |
302 | TF_LITE_ENSURE_STATUS(GenericPrepare(context, node)); |
303 | TfLiteTensor* output; |
304 | TF_LITE_ENSURE_OK(context, GetOutputSafe(context, node, 0, &output)); |
305 | |
306 | if (output->type == kTfLiteUInt8 || output->type == kTfLiteInt8) { |
307 | HardSwishData* data = static_cast<HardSwishData*>(node->user_data); |
308 | HardSwishParams* params = &data->params; |
309 | const TfLiteTensor* input; |
310 | TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, 0, &input)); |
311 | params->input_zero_point = input->params.zero_point; |
312 | params->output_zero_point = output->params.zero_point; |
313 | const float input_scale = input->params.scale; |
314 | const float hires_input_scale = (1.0f / 128.0f) * input_scale; |
315 | const float reluish_scale = 3.0f / 32768.0f; |
316 | const float output_scale = output->params.scale; |
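    // A note on these constants (our reading of the fixed-point layout): the
    // 1/128 factor gives the "hires" input 7 extra fractional bits in int16,
    // and reluish_scale = 3/32768 means the full int16 range represents
    // +/-3.0, matching the [-3, 3] clamp of hard-swish's inner term.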
317 | |
318 | const float output_multiplier = hires_input_scale / output_scale; |
319 | |
320 | int32_t output_multiplier_fixedpoint_int32; |
321 | QuantizeMultiplier(output_multiplier, &output_multiplier_fixedpoint_int32, |
322 | ¶ms->output_multiplier_exponent); |
323 | DownScaleInt32ToInt16Multiplier( |
324 | output_multiplier_fixedpoint_int32, |
325 | ¶ms->output_multiplier_fixedpoint_int16); |
326 | TF_LITE_ENSURE(context, params->output_multiplier_exponent <= 0); |
327 | |
328 | const float reluish_multiplier = hires_input_scale / reluish_scale; |
329 | int32_t reluish_multiplier_fixedpoint_int32; |
330 | QuantizeMultiplier(reluish_multiplier, &reluish_multiplier_fixedpoint_int32, |
331 | ¶ms->reluish_multiplier_exponent); |
332 | DownScaleInt32ToInt16Multiplier( |
333 | reluish_multiplier_fixedpoint_int32, |
334 | ¶ms->reluish_multiplier_fixedpoint_int16); |
335 | } |
336 | return kTfLiteOk; |
337 | } |
338 | |
339 | TfLiteStatus LeakyReluPrepare(TfLiteContext* context, TfLiteNode* node) { |
340 | TF_LITE_ENSURE_EQ(context, NumInputs(node), 1); |
341 | TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1); |
342 | const TfLiteTensor* input; |
343 | TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, 0, &input)); |
344 | TfLiteTensor* output; |
345 | TF_LITE_ENSURE_OK(context, GetOutputSafe(context, node, 0, &output)); |
346 | TF_LITE_ENSURE_TYPES_EQ(context, input->type, output->type); |
347 | |
348 | LeakyReluOpData* data = reinterpret_cast<LeakyReluOpData*>(node->user_data); |
349 | |
350 | if (output->type == kTfLiteUInt8 || output->type == kTfLiteInt8 || |
351 | output->type == kTfLiteInt16) { |
352 | const auto* params = |
353 | reinterpret_cast<TfLiteLeakyReluParams*>(node->builtin_data); |
354 | |
355 | double alpha_multiplier = |
356 | input->params.scale * params->alpha / output->params.scale; |
357 | QuantizeMultiplier(alpha_multiplier, &data->output_multiplier_alpha, |
358 | &data->output_shift_alpha); |
359 | double identity_multiplier = input->params.scale / output->params.scale; |
360 | QuantizeMultiplier(identity_multiplier, &data->output_multiplier_identity, |
361 | &data->output_shift_identity); |
362 | } |
363 | |
364 | if (input->type == kTfLiteInt16 && output->type == kTfLiteInt16) { |
365 | TF_LITE_ENSURE_EQ(context, input->params.zero_point, 0); |
366 | TF_LITE_ENSURE_EQ(context, output->params.zero_point, 0); |
367 | } |
368 | |
369 | return context->ResizeTensor(context, output, |
370 | TfLiteIntArrayCopy(input->dims)); |
371 | } |
372 | |
373 | template <KernelType kernel_type> |
374 | TfLiteStatus TanhPrepare(TfLiteContext* context, TfLiteNode* node) { |
375 | OpData* data = reinterpret_cast<OpData*>(node->user_data); |
376 | |
377 | TF_LITE_ENSURE_EQ(context, NumInputs(node), 1); |
378 | TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1); |
379 | const TfLiteTensor* input; |
380 | TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, 0, &input)); |
381 | TfLiteTensor* output; |
382 | TF_LITE_ENSURE_OK(context, GetOutputSafe(context, node, 0, &output)); |
383 | TF_LITE_ENSURE_TYPES_EQ(context, input->type, output->type); |
384 | |
385 | if (kernel_type == kFixedPointOptimized) { |
386 | if (input->type == kTfLiteUInt8 || input->type == kTfLiteInt8) { |
387 | static constexpr int kInputIntegerBits = 4; |
388 | |
389 | const double input_real_multiplier = |
390 | input->params.scale * |
391 | static_cast<double>(1 << (15 - kInputIntegerBits)); |
392 | |
393 | const double q = |
394 | std::frexp(input_real_multiplier, &data->input_left_shift); |
395 | auto q_fixed = static_cast<int32_t>(TfLiteRound(q * (1LL << 15))); |
396 | data->input_multiplier = static_cast<int16_t>(q_fixed); |
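      // Worked example with assumed values: for input scale 1/128 and
      // kInputIntegerBits = 4, input_real_multiplier = (1/128) * 2^11 = 16;
      // frexp(16) = 0.5 * 2^5, so input_left_shift = 5 and
      // input_multiplier = round(0.5 * 2^15) = 16384.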
397 | |
398 | int16_t input_range_radius = |
399 | CalculateInputRadius(kInputIntegerBits, data->input_left_shift, 15); |
400 | data->input_range_radius = input_range_radius; |
401 | } |
402 | } |
403 | |
404 | if (kernel_type == kGenericOptimized || kernel_type == kReference) { |
405 | if (input->type == kTfLiteUInt8) { |
406 | PopulateLookupTable<uint8_t>( |
407 | data, input, output, [](float value) { return std::tanh(value); }); |
408 | } else if (input->type == kTfLiteInt8) { |
409 | PopulateLookupTable<int8_t>(data, input, output, |
410 | [](float value) { return std::tanh(value); }); |
411 | } |
412 | } |
413 | |
414 | if (input->type == kTfLiteInt16) { |
415 | static constexpr int kInputIntegerBits = 3; |
416 | static constexpr int kOutputFractionalBits = 15; |
417 | |
418 | // These operators are implemented in fixed-point arithmetic, |
419 | // which intrinsically wants symmetric ranges (zero_point==0) |
420 | // and power-of-two scales (power-of-two is abbreviated below as POT). |
421 | // While more general support would be possible by means of rescaling, |
422 | // that would add some overhead and some loss of accuracy and wouldn't |
423 | // be used at the moment as current quantized LSTM applications are |
424 | // happy with symmetric, power-of-two-scales quantization. So we just |
425 | // implement that narrow case only for now. |
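    // For example, a symmetric int16 input with the power-of-two scale 2^-12
    // satisfies this (input_left_shift = (15 - 3) - 12 = 0), while a scale
    // like 1/10000 falls through to the rescaling path below.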
426 | |
427 | TF_LITE_ENSURE_EQ(context, input->params.zero_point, 0); |
428 | TF_LITE_ENSURE_EQ(context, output->params.zero_point, 0); |
429 | |
430 | int input_scale_log2_rounded; |
431 | bool param_scale_pot = |
432 | CheckedLog2(input->params.scale, &input_scale_log2_rounded); |
433 | |
434 | data->input_left_shift = |
435 | (15 - kInputIntegerBits) + input_scale_log2_rounded; |
436 | param_scale_pot &= |
437 | (data->input_left_shift == 0 || data->input_left_shift == 1); |
438 | |
439 | if (!param_scale_pot) { |
      // Calculate a multiplier to change the input scale to 1/(3*4096), as
      // required by the table lookup. The factor of 3.0 appears because the
      // covered interval is [-10.7, 10.7] rather than [-8, 8]; in this
      // scaling, +/-2^17 represents +/-10.7.
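      // (Check: the table scale is 1/(3*4096) = 1/12288, and
      // 2^17 / 12288 = 131072 / 12288 ~= 10.67.)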
445 | |
446 | double multiplier = input->params.scale * 4096.0 * 3.0; |
447 | data->input_left_shift = 0; |
448 | |
449 | while (multiplier <= 32767.0 / 2.0 && data->input_left_shift <= 30) { |
450 | data->input_left_shift++; |
451 | multiplier = multiplier * 2.0; |
452 | } |
453 | |
454 | data->input_multiplier = static_cast<int32_t>(multiplier); |
455 | } |
456 | |
457 | int output_scale_log2_rounded; |
458 | TF_LITE_ENSURE( |
459 | context, CheckedLog2(output->params.scale, &output_scale_log2_rounded)); |
460 | TF_LITE_ENSURE_EQ(context, output_scale_log2_rounded, |
461 | -kOutputFractionalBits); |
462 | } |
463 | |
464 | return context->ResizeTensor(context, output, |
465 | TfLiteIntArrayCopy(input->dims)); |
466 | } |
467 | |
468 | template <KernelType kernel_type> |
469 | TfLiteStatus SigmoidPrepare(TfLiteContext* context, TfLiteNode* node) { |
470 | OpData* data = reinterpret_cast<OpData*>(node->user_data); |
471 | |
472 | TF_LITE_ENSURE_EQ(context, NumInputs(node), 1); |
473 | TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1); |
474 | const TfLiteTensor* input; |
475 | TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, 0, &input)); |
476 | TfLiteTensor* output; |
477 | TF_LITE_ENSURE_OK(context, GetOutputSafe(context, node, 0, &output)); |
478 | TF_LITE_ENSURE_TYPES_EQ(context, input->type, output->type); |
479 | |
480 | if (kernel_type == kFixedPointOptimized) { |
481 | if (input->type == kTfLiteUInt8 || input->type == kTfLiteInt8) { |
482 | if (input->type == kTfLiteUInt8) { |
483 | TF_LITE_ENSURE_EQ(context, output->params.zero_point, |
484 | std::numeric_limits<uint8_t>::min()); |
485 | } |
486 | if (input->type == kTfLiteInt8) { |
487 | TF_LITE_ENSURE_EQ(context, output->params.zero_point, |
488 | std::numeric_limits<int8_t>::min()); |
489 | } |
490 | TF_LITE_ENSURE(context, output->params.scale == 1. / 256); |
491 | |
492 | static constexpr int kInputIntegerBits = 4; |
493 | |
494 | const double input_real_multiplier = |
495 | input->params.scale * |
496 | static_cast<double>(1 << (15 - kInputIntegerBits)); |
497 | |
498 | const double q = |
499 | std::frexp(input_real_multiplier, &data->input_left_shift); |
500 | auto q_fixed = static_cast<int32_t>(TfLiteRound(q * (1LL << 15))); |
501 | data->input_multiplier = static_cast<int16_t>(q_fixed); |
502 | |
503 | int16_t input_range_radius = |
504 | CalculateInputRadius(kInputIntegerBits, data->input_left_shift, 15); |
505 | data->input_range_radius = input_range_radius; |
506 | } |
507 | } |
508 | |
509 | if (kernel_type == kGenericOptimized || kernel_type == kReference) { |
510 | if (input->type == kTfLiteUInt8) { |
511 | TF_LITE_ENSURE(context, output->params.scale == 1. / 256); |
512 | PopulateLookupTable<uint8_t>(data, input, output, [](float value) { |
513 | return 1.0f / (1.0f + std::exp(-value)); |
514 | }); |
515 | } else if (input->type == kTfLiteInt8) { |
516 | TF_LITE_ENSURE(context, output->params.scale == 1. / 256); |
517 | PopulateLookupTable<int8_t>(data, input, output, [](float value) { |
518 | return 1.0f / (1.0f + std::exp(-value)); |
519 | }); |
520 | } else if (input->type == kTfLiteInt16) { |
521 | TF_LITE_ENSURE(context, output->params.scale == 1. / 32768); |
522 | TF_LITE_ENSURE(context, output->params.zero_point == 0); |
523 | } |
524 | } |
525 | |
526 | if (input->type == kTfLiteInt16) { |
527 | static constexpr int kInputIntegerBits = 3; |
528 | static constexpr int kOutputFractionalBits = 15; |
529 | |
530 | // See comments in TanhPrepare about requiring zero_point==0 |
531 | // and a power-of-two ("POT") scale. |
532 | |
533 | TF_LITE_ENSURE_EQ(context, input->params.zero_point, 0); |
534 | TF_LITE_ENSURE_EQ(context, output->params.zero_point, 0); |
535 | |
536 | int input_scale_log2_rounded; |
537 | bool param_scale_pot = |
538 | CheckedLog2(input->params.scale, &input_scale_log2_rounded); |
539 | |
540 | data->input_left_shift = |
541 | (15 - kInputIntegerBits) + input_scale_log2_rounded; |
542 | param_scale_pot &= (data->input_left_shift == 0); |
543 | |
544 | if (!param_scale_pot) { |
545 | // Calculate multiplier to change input scale to 1/(3*4096) |
546 | // as required by the table lookup. |
547 | // In this scaling +/-2^17 represents +/-10.7 |
548 | double multiplier = input->params.scale * 4096.0 * 3.0; |
549 | |
550 | data->input_left_shift = 0; |
551 | |
552 | while (multiplier <= 32767.0 / 2.0 && data->input_left_shift <= 30) { |
553 | data->input_left_shift++; |
554 | multiplier = multiplier * 2.0; |
555 | } |
556 | |
557 | data->input_multiplier = static_cast<int32_t>(multiplier); |
558 | } |
559 | |
560 | int output_scale_log2_rounded; |
561 | TF_LITE_ENSURE( |
562 | context, CheckedLog2(output->params.scale, &output_scale_log2_rounded)); |
563 | TF_LITE_ENSURE_EQ(context, output_scale_log2_rounded, |
564 | -kOutputFractionalBits); |
565 | } |
566 | |
567 | return context->ResizeTensor(context, output, |
568 | TfLiteIntArrayCopy(input->dims)); |
569 | } |
570 | |
571 | template <KernelType kernel_type> |
572 | TfLiteStatus SoftmaxPrepare(TfLiteContext* context, TfLiteNode* node) { |
573 | auto* params = reinterpret_cast<TfLiteSoftmaxParams*>(node->builtin_data); |
574 | SoftmaxOpData* data = reinterpret_cast<SoftmaxOpData*>(node->user_data); |
575 | |
576 | TF_LITE_ENSURE_EQ(context, NumInputs(node), 1); |
577 | TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1); |
578 | const TfLiteTensor* input; |
579 | TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, 0, &input)); |
580 | TfLiteTensor* output; |
581 | TF_LITE_ENSURE_OK(context, GetOutputSafe(context, node, 0, &output)); |
582 | |
583 | TF_LITE_ENSURE(context, NumDimensions(input) >= 1); |
584 | |
585 | if (input->type == kTfLiteInt8 && output->type == kTfLiteInt8) { |
586 | TF_LITE_ENSURE_EQ(context, output->params.zero_point, -128); |
587 | TF_LITE_ENSURE_NEAR(context, output->params.scale, 1.f / 256, |
588 | (0.001f * 1.f / 256)); |
589 | } else if (input->type == kTfLiteInt16 && output->type == kTfLiteInt16) { |
590 | TF_LITE_ENSURE_EQ(context, output->params.zero_point, 0); |
591 | TF_LITE_ENSURE_NEAR(context, output->params.scale, 1.f / 32768, |
592 | (0.001f * 1.f / 32768)); |
593 | } |
594 | |
595 | if (input->type == kTfLiteUInt8 || input->type == kTfLiteInt8) { |
596 | if (kernel_type == kReference) { |
597 | const int kScaledDiffIntegerBits = 5; |
598 | int input_left_shift; |
599 | tflite::PreprocessSoftmaxScaling( |
600 | static_cast<double>(params->beta), |
601 | static_cast<double>(input->params.scale), kScaledDiffIntegerBits, |
602 | &data->params.input_multiplier, &input_left_shift); |
603 | data->params.input_left_shift = input_left_shift; |
604 | data->params.diff_min = |
605 | -1.0 * tflite::CalculateInputRadius(kScaledDiffIntegerBits, |
606 | input_left_shift); |
607 | } else { |
608 | switch (output->type) { |
609 | case kTfLiteUInt8: |
610 | case kTfLiteInt8: |
611 | #ifdef TFLITE_SOFTMAX_USE_UINT16_LUT |
          // Only applied when both input and output are uint8/int8 and the
          // binary is built with clang on aarch64.
614 | // TODO(b/143709993): Port to ARMv7 and other platforms. |
615 | data->params.uint8_table1 = data->uint8_table1; |
616 | data->params.uint8_table2 = data->uint8_table2; |
617 | optimized_ops::PopulateSoftmaxUInt8LookupTable( |
618 | &data->params, input->params.scale, params->beta); |
619 | break; |
620 | #endif |
621 | case kTfLiteInt16: |
622 | default: |
623 | data->params.table = data->table; |
624 | optimized_ops::PopulateSoftmaxLookupTable( |
625 | &data->params, input->params.scale, params->beta); |
626 | } |
627 | |
628 | data->params.zero_point = output->params.zero_point; |
629 | data->params.scale = output->params.scale; |
630 | } |
631 | } else if (input->type == kTfLiteInt16) { |
632 | TF_LITE_ENSURE_EQ(context, input->params.zero_point, 0); |
633 | TF_LITE_ENSURE_EQ(context, output->params.zero_point, 0); |
634 | |
635 | data->params.exp_lut = data->exp_lut; |
    // The exp LUT is only used for negative inputs; we treat exp(-10.0) as
    // insignificant to the accumulation.
638 | gen_lut<double, int16_t, int16_t>( |
639 | [](double value) { return std::exp(value); }, -10.0, 0.0, -1.0, 1.0, |
640 | data->params.exp_lut); |
641 | data->params.one_over_one_plus_x_lut = data->one_over_one_plus_x_lut; |
642 | gen_lut<double, int16_t, int16_t>( |
643 | [](double value) { return 1.0 / (1.0 + value); }, 0.0, 1.0, -1.0, 1.0, |
644 | data->params.one_over_one_plus_x_lut); |
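    // As used here, gen_lut samples each function at uniform steps over the
    // stated input range and quantizes the results so that the stated output
    // range [-1.0, 1.0] maps onto the full int16 range.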
645 | data->params.zero_point = output->params.zero_point; |
646 | data->params.scale = output->params.scale; |
647 | |
    // Scale input_diff such that [-65535, 0] corresponds to [-10.0, 0.0].
    double input_scale_beta_rescale =
        input->params.scale * params->beta / (10.0 / 65535.0);
652 | QuantizeMultiplier(input_scale_beta_rescale, &data->params.input_multiplier, |
653 | &data->params.input_left_shift); |
654 | } |
655 | |
656 | return context->ResizeTensor(context, output, |
657 | TfLiteIntArrayCopy(input->dims)); |
658 | } |
659 | |
660 | template <KernelType kernel_type> |
661 | TfLiteStatus LogSoftmaxPrepare(TfLiteContext* context, TfLiteNode* node) { |
662 | LogSoftmaxOpData* data = reinterpret_cast<LogSoftmaxOpData*>(node->user_data); |
663 | |
664 | TF_LITE_ENSURE_EQ(context, NumInputs(node), 1); |
665 | TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1); |
666 | const TfLiteTensor* input; |
667 | TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, 0, &input)); |
668 | TfLiteTensor* output; |
669 | TF_LITE_ENSURE_OK(context, GetOutputSafe(context, node, 0, &output)); |
670 | TF_LITE_ENSURE_TYPES_EQ(context, input->type, output->type); |
671 | |
672 | if (input->type == kTfLiteUInt8 || input->type == kTfLiteInt8) { |
673 | TF_LITE_ENSURE_EQ(context, output->params.scale, 16.0 / 256); |
674 | static const double kBeta = 1.0; |
675 | if (input->type == kTfLiteUInt8) { |
676 | TF_LITE_ENSURE_EQ(context, output->params.zero_point, 255); |
677 | } |
678 | if (input->type == kTfLiteInt8) { |
679 | TF_LITE_ENSURE_EQ(context, output->params.zero_point, 127); |
680 | } |
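    // With scale 16/256 and these zero points, the quantized output spans
    // (q - zp) * 16/256 = [-15.9375, 0] over the full uint8/int8 range:
    // exactly the non-positive values log-softmax produces, with the maximum
    // (0) at the top of the range.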
681 | |
682 | if (kernel_type == kReference) { |
683 | const int kScaledDiffIntegerBits = 5; |
684 | int input_left_shift; |
685 | int reverse_scaling_right_shift; |
686 | tflite::PreprocessLogSoftmaxScalingExp( |
687 | kBeta, static_cast<double>(input->params.scale), |
688 | kScaledDiffIntegerBits, &data->params.input_multiplier, |
689 | &input_left_shift, &data->params.reverse_scaling_divisor, |
690 | &reverse_scaling_right_shift); |
691 | reverse_scaling_right_shift *= -1; |
692 | data->params.input_left_shift = input_left_shift; |
693 | data->params.reverse_scaling_right_shift = reverse_scaling_right_shift; |
694 | data->params.diff_min = -tflite::CalculateInputRadius( |
695 | kScaledDiffIntegerBits, input_left_shift); |
696 | } else { |
697 | data->params.table = data->f_table; |
698 | optimized_ops::PopulateSoftmaxLookupTable(&data->params, |
699 | input->params.scale, kBeta); |
700 | data->params.zero_point = output->params.zero_point; |
701 | data->params.scale = output->params.scale; |
702 | } |
703 | } |
704 | |
705 | return context->ResizeTensor(context, output, |
706 | TfLiteIntArrayCopy(input->dims)); |
707 | } |
708 | |
709 | TfLiteStatus PreluPrepare(TfLiteContext* context, TfLiteNode* node) { |
710 | TF_LITE_ENSURE_EQ(context, NumInputs(node), 2); |
711 | TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1); |
712 | const TfLiteTensor* input; |
713 | TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, 0, &input)); |
714 | TfLiteTensor* output; |
715 | TF_LITE_ENSURE_OK(context, GetOutputSafe(context, node, 0, &output)); |
716 | const TfLiteTensor* alpha; |
717 | TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, 1, &alpha)); |
718 | PreluOpData* data = reinterpret_cast<PreluOpData*>(node->user_data); |
719 | |
720 | TF_LITE_ENSURE_TYPES_EQ(context, input->type, alpha->type); |
721 | |
722 | output->type = input->type; |
723 | |
724 | if (output->type == kTfLiteUInt8 || output->type == kTfLiteInt8) { |
725 | // prelu(x) = x if x >= 0 else x * alpha. |
726 | // So if we translate that for quantized computation: |
727 | // |
728 | // input_float = (input_q - input_zp) * input_scale |
729 | // output_float = (output_q - output_zp) * output_scale |
730 | // alpha_float = (alpha_q - alpha_zp) * alpha_scale |
731 | // |
    // When input_q - input_zp >= 0:
    //   output_q = (input_q - input_zp) * input_scale / output_scale
    //              + output_zp
    // else:
    //   output_q = (input_q - input_zp) * (alpha_q - alpha_zp) * input_scale
    //              * alpha_scale / output_scale + output_zp
    //
    // So for input_q - input_zp >= 0, the real output multiplier 1 is
    // input_scale / output_scale; for input_q - input_zp < 0, the real
    // output multiplier 2 is input_scale * alpha_scale / output_scale.
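    // Illustrative example (assumed scales): input_scale = 0.5,
    // alpha_scale = 0.1, and output_scale = 0.25 give real_multiplier_1 = 2.0
    // and real_multiplier_2 = 0.5 * 0.1 / 0.25 = 0.2; QuantizeMultiplier then
    // encodes each as a fixed-point multiplier plus a shift.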
742 | double real_multiplier_1 = input->params.scale / output->params.scale; |
743 | double real_multiplier_2 = |
744 | input->params.scale * alpha->params.scale / output->params.scale; |
745 | QuantizeMultiplier(real_multiplier_1, &data->output_multiplier_1, |
746 | &data->output_shift_1); |
747 | QuantizeMultiplier(real_multiplier_2, &data->output_multiplier_2, |
748 | &data->output_shift_2); |
749 | } |
750 | |
751 | data->requires_broadcast = !HaveSameShapes(input, alpha); |
  // PRelu (parametric ReLU) shares the same alpha value along the "shared
  // axis", so alpha values always need to be "broadcast" in PRelu.
754 | TfLiteIntArray* output_size = nullptr; |
755 | TF_LITE_ENSURE_OK( |
756 | context, CalculateShapeForBroadcast(context, input, alpha, &output_size)); |
757 | |
758 | TF_LITE_ENSURE_OK(context, |
759 | context->ResizeTensor(context, output, output_size)); |
760 | // After broadcasting, the output shape should always be the same as the |
761 | // input shape. |
762 | TF_LITE_ENSURE(context, HaveSameShapes(input, output)); |
763 | |
764 | return kTfLiteOk; |
765 | } |
766 | |
767 | TfLiteStatus ReluEval(TfLiteContext* context, TfLiteNode* node) { |
768 | const TfLiteTensor* input; |
769 | TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, 0, &input)); |
770 | TfLiteTensor* output; |
771 | TF_LITE_ENSURE_OK(context, GetOutputSafe(context, node, 0, &output)); |
772 | const ReluOpData* data = reinterpret_cast<ReluOpData*>(node->user_data); |
773 | switch (input->type) { |
774 | case kTfLiteFloat32: { |
775 | optimized_ops::Relu(GetTensorShape(input), GetTensorData<float>(input), |
776 | GetTensorShape(output), GetTensorData<float>(output)); |
777 | } break; |
778 | // TODO(renjieliu): We may revisit the quantization calculation logic, |
779 | // the unbounded upper limit is actually hard to quantize. |
780 | case kTfLiteUInt8: { |
781 | QuantizedReluX<uint8_t>(0.0f, std::numeric_limits<float>::infinity(), |
782 | input, output, data); |
783 | } break; |
784 | case kTfLiteInt8: { |
785 | QuantizedReluX<int8_t>(0.0f, std::numeric_limits<float>::infinity(), |
786 | input, output, data); |
787 | } break; |
788 | case kTfLiteInt16: { |
789 | QuantizedReluX<int16_t>(0.0f, std::numeric_limits<float>::infinity(), |
790 | input, output, data); |
791 | } break; |
792 | default: |
793 | TF_LITE_KERNEL_LOG(context, |
794 | "Only float32, uint8, int8 and int16 are supported " |
795 | "currently, got %s." , |
796 | TfLiteTypeGetName(input->type)); |
797 | return kTfLiteError; |
798 | } |
799 | return kTfLiteOk; |
800 | } |
801 | |
802 | TfLiteStatus Relu1Eval(TfLiteContext* context, TfLiteNode* node) { |
803 | const TfLiteTensor* input; |
804 | TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, 0, &input)); |
805 | TfLiteTensor* output; |
806 | TF_LITE_ENSURE_OK(context, GetOutputSafe(context, node, 0, &output)); |
807 | const ReluOpData* data = reinterpret_cast<ReluOpData*>(node->user_data); |
808 | switch (input->type) { |
809 | case kTfLiteFloat32: { |
810 | optimized_ops::Relu1(GetTensorShape(input), GetTensorData<float>(input), |
811 | GetTensorShape(output), |
812 | GetTensorData<float>(output)); |
813 | return kTfLiteOk; |
814 | } |
815 | case kTfLiteUInt8: { |
816 | QuantizedReluX<uint8_t>(-1.0f, 1.0f, input, output, data); |
817 | return kTfLiteOk; |
818 | } |
819 | case kTfLiteInt8: { |
820 | QuantizedReluX<int8_t>(-1, 1, input, output, data); |
821 | return kTfLiteOk; |
822 | } |
823 | default: |
824 | TF_LITE_KERNEL_LOG(context, |
825 | "Only float32, uint8, int8 supported " |
826 | "currently, got %s." , |
827 | TfLiteTypeGetName(input->type)); |
828 | return kTfLiteError; |
829 | } |
830 | } |
831 | |
832 | template <KernelType kernel_type> |
833 | TfLiteStatus HardSwishEval(TfLiteContext* context, TfLiteNode* node) { |
834 | HardSwishData* data = static_cast<HardSwishData*>(node->user_data); |
835 | |
836 | const TfLiteTensor* input; |
837 | TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, 0, &input)); |
838 | TfLiteTensor* output; |
839 | TF_LITE_ENSURE_OK(context, GetOutputSafe(context, node, 0, &output)); |
840 | switch (input->type) { |
841 | case kTfLiteFloat32: { |
842 | if (kernel_type == kReference) { |
843 | reference_ops::HardSwish( |
844 | GetTensorShape(input), GetTensorData<float>(input), |
845 | GetTensorShape(output), GetTensorData<float>(output)); |
846 | } else { |
847 | optimized_ops::HardSwish( |
848 | GetTensorShape(input), GetTensorData<float>(input), |
849 | GetTensorShape(output), GetTensorData<float>(output)); |
850 | } |
851 | return kTfLiteOk; |
852 | } break; |
853 | case kTfLiteUInt8: { |
854 | HardSwishParams& params = data->params; |
855 | if (kernel_type == kReference) { |
856 | reference_ops::HardSwish( |
857 | params, GetTensorShape(input), GetTensorData<uint8_t>(input), |
858 | GetTensorShape(output), GetTensorData<uint8_t>(output)); |
859 | } else { |
860 | optimized_ops::HardSwish( |
861 | params, GetTensorShape(input), GetTensorData<uint8_t>(input), |
862 | GetTensorShape(output), GetTensorData<uint8_t>(output)); |
863 | } |
864 | return kTfLiteOk; |
865 | } break; |
866 | case kTfLiteInt8: { |
867 | HardSwishParams& params = data->params; |
868 | if (kernel_type == kReference) { |
869 | reference_ops::HardSwish( |
870 | params, GetTensorShape(input), GetTensorData<int8_t>(input), |
871 | GetTensorShape(output), GetTensorData<int8_t>(output)); |
872 | } else { |
873 | optimized_ops::HardSwish( |
874 | params, GetTensorShape(input), GetTensorData<int8_t>(input), |
875 | GetTensorShape(output), GetTensorData<int8_t>(output)); |
876 | } |
877 | return kTfLiteOk; |
878 | } break; |
879 | default: |
880 | TF_LITE_KERNEL_LOG( |
881 | context, |
882 | "Only float32, uint8 and int8 are supported currently, got %s." , |
883 | TfLiteTypeGetName(input->type)); |
884 | return kTfLiteError; |
885 | } |
886 | } |
887 | |
888 | TfLiteStatus Relu0to1Eval(TfLiteContext* context, TfLiteNode* node) { |
889 | const TfLiteTensor* input; |
890 | TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, 0, &input)); |
891 | TfLiteTensor* output; |
892 | TF_LITE_ENSURE_OK(context, GetOutputSafe(context, node, 0, &output)); |
893 | const ReluOpData* data = reinterpret_cast<ReluOpData*>(node->user_data); |
894 | switch (input->type) { |
895 | case kTfLiteFloat32: { |
896 | optimized_ops::Relu0To1( |
897 | GetTensorShape(input), GetTensorData<float>(input), |
898 | GetTensorShape(output), GetTensorData<float>(output)); |
899 | return kTfLiteOk; |
900 | } |
901 | case kTfLiteUInt8: { |
902 | QuantizedReluX<uint8_t>(0.0f, 1.0f, input, output, data); |
903 | return kTfLiteOk; |
904 | } |
905 | case kTfLiteInt8: { |
906 | QuantizedReluX<int8_t>(0, 1, input, output, data); |
907 | return kTfLiteOk; |
908 | } |
909 | default: |
910 | TF_LITE_KERNEL_LOG(context, |
911 | "Only float32, uint8, int8 supported " |
912 | "currently, got %s." , |
913 | TfLiteTypeGetName(input->type)); |
914 | return kTfLiteError; |
915 | } |
916 | } |
917 | |
918 | TfLiteStatus Relu6Eval(TfLiteContext* context, TfLiteNode* node) { |
919 | const TfLiteTensor* input; |
920 | TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, 0, &input)); |
921 | TfLiteTensor* output; |
922 | TF_LITE_ENSURE_OK(context, GetOutputSafe(context, node, 0, &output)); |
923 | ReluOpData* data = reinterpret_cast<ReluOpData*>(node->user_data); |
924 | switch (input->type) { |
925 | case kTfLiteFloat32: { |
926 | size_t elements = input->bytes / sizeof(float); |
927 | const float* in = GetTensorData<float>(input); |
928 | const float* in_end = in + elements; |
929 | float* out = GetTensorData<float>(output); |
930 | for (; in < in_end; in++, out++) *out = std::min(std::max(0.f, *in), 6.f); |
931 | return kTfLiteOk; |
932 | } |
933 | case kTfLiteUInt8: |
934 | QuantizedReluX<uint8_t>(0.0f, 6.0f, input, output, data); |
935 | return kTfLiteOk; |
936 | case kTfLiteInt8: { |
937 | QuantizedReluX<int8_t>(0.0f, 6.0f, input, output, data); |
938 | return kTfLiteOk; |
939 | } |
940 | case kTfLiteInt16: { |
941 | QuantizedReluX<int16_t>(0.0f, 6.0f, input, output, data); |
942 | return kTfLiteOk; |
943 | } |
944 | default: |
945 | TF_LITE_KERNEL_LOG(context, |
946 | "Only float32, uint8, int8 and int16 are supported " |
947 | "currently, got %s." , |
948 | TfLiteTypeGetName(input->type)); |
949 | return kTfLiteError; |
950 | } |
951 | } |
952 | |
953 | template <KernelType kernel_type> |
954 | TfLiteStatus TanhEval(TfLiteContext* context, TfLiteNode* node) { |
955 | OpData* data = reinterpret_cast<OpData*>(node->user_data); |
956 | const TfLiteTensor* input; |
957 | TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, 0, &input)); |
958 | TfLiteTensor* output; |
959 | TF_LITE_ENSURE_OK(context, GetOutputSafe(context, node, 0, &output)); |
960 | switch (input->type) { |
961 | case kTfLiteFloat32: { |
962 | if (kernel_type == kReference) { |
963 | reference_ops::Tanh(GetTensorShape(input), GetTensorData<float>(input), |
964 | GetTensorShape(output), |
965 | GetTensorData<float>(output)); |
966 | } else { |
967 | optimized_ops::Tanh(GetTensorShape(input), GetTensorData<float>(input), |
968 | GetTensorShape(output), |
969 | GetTensorData<float>(output)); |
970 | } |
971 | return kTfLiteOk; |
972 | } break; |
973 | case kTfLiteInt16: { |
974 | TanhParams params; |
975 | params.input_left_shift = data->input_left_shift; |
976 | if (kernel_type == kReference || (data->input_multiplier > 0)) { |
977 | reference_integer_ops::Tanh( |
978 | data->input_multiplier, data->input_left_shift, |
979 | GetTensorShape(input), GetTensorData<int16_t>(input), |
980 | GetTensorShape(output), GetTensorData<int16_t>(output)); |
981 | } else { |
982 | optimized_ops::Tanh( |
983 | params, GetTensorShape(input), GetTensorData<int16_t>(input), |
984 | GetTensorShape(output), GetTensorData<int16_t>(output)); |
985 | } |
986 | return kTfLiteOk; |
987 | } break; |
988 | case kTfLiteUInt8: { |
989 | if (kernel_type == kFixedPointOptimized) { |
990 | TanhParams params; |
991 | params.input_zero_point = input->params.zero_point; |
992 | params.input_range_radius = data->input_range_radius; |
993 | params.input_multiplier = data->input_multiplier; |
994 | params.input_left_shift = data->input_left_shift; |
995 | optimized_ops::Tanh16bitPrecision( |
996 | params, GetTensorShape(input), GetTensorData<uint8_t>(input), |
997 | GetTensorShape(output), GetTensorData<uint8_t>(output)); |
998 | } else { |
999 | EvalUsingLookupTable(data, input, output); |
1000 | } |
1001 | return kTfLiteOk; |
1002 | } break; |
1003 | case kTfLiteInt8: { |
1004 | if (kernel_type == kFixedPointOptimized) { |
1005 | TanhParams params; |
1006 | params.input_zero_point = input->params.zero_point; |
1007 | params.input_range_radius = data->input_range_radius; |
1008 | params.input_multiplier = data->input_multiplier; |
1009 | params.input_left_shift = data->input_left_shift; |
1010 | optimized_ops::Tanh16bitPrecision( |
1011 | params, GetTensorShape(input), GetTensorData<int8_t>(input), |
1012 | GetTensorShape(output), GetTensorData<int8_t>(output)); |
1013 | } else { |
1014 | EvalUsingLookupTable(data, input, output); |
1015 | } |
1016 | return kTfLiteOk; |
1017 | } break; |
1018 | default: |
1019 | TF_LITE_KERNEL_LOG(context, |
1020 | "Only float32, uint8, int16 and int8 are supported " |
1021 | "currently, got %s." , |
1022 | TfLiteTypeGetName(input->type)); |
1023 | return kTfLiteError; |
1024 | } |
1025 | } |
1026 | |
// Sigmoid is also known as "Logistic".
1028 | template <KernelType kernel_type> |
1029 | TfLiteStatus SigmoidEval(TfLiteContext* context, TfLiteNode* node) { |
1030 | OpData* data = reinterpret_cast<OpData*>(node->user_data); |
1031 | |
1032 | const TfLiteTensor* input; |
1033 | TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, 0, &input)); |
1034 | TfLiteTensor* output; |
1035 | TF_LITE_ENSURE_OK(context, GetOutputSafe(context, node, 0, &output)); |
1036 | switch (input->type) { |
1037 | case kTfLiteFloat32: { |
1038 | if (kernel_type == kReference) { |
1039 | reference_ops::Logistic( |
1040 | GetTensorShape(input), GetTensorData<float>(input), |
1041 | GetTensorShape(output), GetTensorData<float>(output)); |
1042 | } else { |
1043 | optimized_ops::Logistic( |
1044 | GetTensorShape(input), GetTensorData<float>(input), |
1045 | GetTensorShape(output), GetTensorData<float>(output)); |
1046 | } |
1047 | break; |
1048 | } |
1049 | case kTfLiteInt16: { |
1050 | LogisticParams params; |
1051 | if (kernel_type == kReference || (data->input_multiplier > 0)) { |
1052 | const int size = |
1053 | MatchingFlatSize(GetTensorShape(input), GetTensorShape(output)); |
1054 | |
1055 | reference_integer_ops::Logistic( |
1056 | data->input_multiplier, data->input_left_shift, size, |
1057 | GetTensorData<int16_t>(input), GetTensorData<int16_t>(output)); |
1058 | } else { |
1059 | optimized_ops::Logistic( |
1060 | params, GetTensorShape(input), GetTensorData<int16_t>(input), |
1061 | GetTensorShape(output), GetTensorData<int16_t>(output)); |
1062 | } |
1063 | break; |
1064 | } |
1065 | case kTfLiteUInt8: { |
1066 | if (kernel_type == kFixedPointOptimized) { |
1067 | LogisticParams params; |
1068 | params.input_zero_point = input->params.zero_point; |
1069 | params.input_range_radius = data->input_range_radius; |
1070 | params.input_multiplier = data->input_multiplier; |
1071 | params.input_left_shift = data->input_left_shift; |
1072 | optimized_ops::Logistic16bitPrecision( |
1073 | params, GetTensorShape(input), GetTensorData<uint8_t>(input), |
1074 | GetTensorShape(output), GetTensorData<uint8_t>(output)); |
1075 | } else { |
1076 | EvalUsingLookupTable(data, input, output); |
1077 | } |
1078 | break; |
1079 | } |
1080 | case kTfLiteInt8: { |
1081 | if (kernel_type == kFixedPointOptimized) { |
1082 | LogisticParams params; |
1083 | params.input_zero_point = input->params.zero_point; |
1084 | params.input_range_radius = data->input_range_radius; |
1085 | params.input_multiplier = data->input_multiplier; |
1086 | params.input_left_shift = data->input_left_shift; |
1087 | optimized_ops::Logistic16bitPrecision( |
1088 | params, GetTensorShape(input), GetTensorData<int8_t>(input), |
1089 | GetTensorShape(output), GetTensorData<int8_t>(output)); |
1090 | } else { |
1091 | EvalUsingLookupTable(data, input, output); |
1092 | } |
1093 | break; |
1094 | } |
1095 | default: |
1096 | TF_LITE_KERNEL_LOG(context, |
1097 | "Only float32, uint8, int16 and int8 are supported " |
1098 | "currently, got %s." , |
1099 | TfLiteTypeGetName(input->type)); |
1100 | return kTfLiteError; |
1101 | } |
1102 | return kTfLiteOk; |
1103 | } |
1104 | |
1105 | TfLiteStatus SoftmaxFloat(TfLiteContext* context, const TfLiteTensor* input, |
1106 | TfLiteTensor* output, TfLiteSoftmaxParams* params, |
1107 | KernelType kernel_type = kGenericOptimized) { |
1108 | SoftmaxParams op_params; |
1109 | op_params.beta = params->beta; |
1110 | if (kernel_type == kReference) { |
1111 | reference_ops::Softmax(op_params, GetTensorShape(input), |
1112 | GetTensorData<float>(input), GetTensorShape(output), |
1113 | GetTensorData<float>(output)); |
1114 | } else { |
1115 | optimized_ops::Softmax(op_params, GetTensorShape(input), |
1116 | GetTensorData<float>(input), GetTensorShape(output), |
1117 | GetTensorData<float>(output), |
1118 | CpuBackendContext::GetFromContext(context)); |
1119 | } |
1120 | return kTfLiteOk; |
1121 | } |
1122 | |
1123 | template <typename In, typename Out> |
1124 | TfLiteStatus SoftmaxQuantized(TfLiteContext* context, const TfLiteTensor* input, |
1125 | TfLiteTensor* output, SoftmaxOpData* data, |
1126 | KernelType kernel_type = kGenericOptimized) { |
1127 | if (kernel_type == kReference) { |
1128 | reference_ops::Softmax(data->params, GetTensorShape(input), |
1129 | GetTensorData<In>(input), GetTensorShape(output), |
1130 | GetTensorData<Out>(output)); |
1131 | } else { |
1132 | optimized_ops::Softmax(data->params, GetTensorShape(input), |
1133 | GetTensorData<In>(input), GetTensorShape(output), |
1134 | GetTensorData<Out>(output)); |
1135 | } |
1136 | return kTfLiteOk; |
1137 | } |
1138 | |
1139 | template <> |
1140 | TfLiteStatus SoftmaxQuantized<int8_t, int8_t>(TfLiteContext* context, |
1141 | const TfLiteTensor* input, |
1142 | TfLiteTensor* output, |
1143 | SoftmaxOpData* data, |
1144 | KernelType kernel_type) { |
1145 | if (kernel_type == kReference) { |
1146 | reference_ops::Softmax(data->params, GetTensorShape(input), |
1147 | GetTensorData<int8_t>(input), GetTensorShape(output), |
1148 | GetTensorData<int8_t>(output)); |
1149 | } else { |
1150 | #ifdef TFLITE_SOFTMAX_USE_UINT16_LUT |
1151 | optimized_ops::SoftmaxInt8LUT( |
1152 | data->params, GetTensorShape(input), GetTensorData<int8_t>(input), |
1153 | GetTensorShape(output), GetTensorData<int8_t>(output)); |
1154 | #else |
1155 | optimized_ops::Softmax(data->params, GetTensorShape(input), |
1156 | GetTensorData<int8_t>(input), GetTensorShape(output), |
1157 | GetTensorData<int8_t>(output)); |
1158 | #endif |
1159 | } |
1160 | return kTfLiteOk; |
1161 | } |
1162 | |
1163 | template <> |
1164 | TfLiteStatus SoftmaxQuantized<uint8_t, uint8_t>(TfLiteContext* context, |
1165 | const TfLiteTensor* input, |
1166 | TfLiteTensor* output, |
1167 | SoftmaxOpData* data, |
1168 | KernelType kernel_type) { |
1169 | if (kernel_type == kReference) { |
1170 | reference_ops::Softmax( |
1171 | data->params, GetTensorShape(input), GetTensorData<uint8_t>(input), |
1172 | GetTensorShape(output), GetTensorData<uint8_t>(output)); |
1173 | } else { |
1174 | #ifdef TFLITE_SOFTMAX_USE_UINT16_LUT |
1175 | optimized_ops::SoftmaxInt8LUT( |
1176 | data->params, GetTensorShape(input), GetTensorData<uint8_t>(input), |
1177 | GetTensorShape(output), GetTensorData<uint8_t>(output)); |
1178 | #else |
1179 | optimized_ops::Softmax( |
1180 | data->params, GetTensorShape(input), GetTensorData<uint8_t>(input), |
1181 | GetTensorShape(output), GetTensorData<uint8_t>(output)); |
1182 | #endif |
1183 | } |
1184 | return kTfLiteOk; |
1185 | } |
1186 | |
1187 | template <> |
TfLiteStatus SoftmaxQuantized<int16_t, int16_t>(TfLiteContext* context,
1189 | const TfLiteTensor* input, |
1190 | TfLiteTensor* output, |
1191 | SoftmaxOpData* data, |
1192 | KernelType kernel_type) { |
1193 | if (NumDimensions(input) >= 1 && NumDimensions(input) <= 4) { |
1194 | reference_ops::SoftmaxInt16( |
1195 | data->params, GetTensorShape(input), GetTensorData<int16_t>(input), |
1196 | GetTensorShape(output), GetTensorData<int16_t>(output)); |
1197 | return kTfLiteOk; |
1198 | } else { |
1199 | TF_LITE_KERNEL_LOG(context, |
1200 | "Only 1D, 2D, 3D and 4D tensors supported for int16 " |
1201 | "input with int16 output, got %dD." , |
1202 | NumDimensions(input)); |
1203 | return kTfLiteError; |
1204 | } |
1205 | } |
1206 | |
1207 | template <KernelType kernel_type> |
1208 | TfLiteStatus SoftmaxEval(TfLiteContext* context, TfLiteNode* node) { |
1209 | auto* params = reinterpret_cast<TfLiteSoftmaxParams*>(node->builtin_data); |
1210 | SoftmaxOpData* data = reinterpret_cast<SoftmaxOpData*>(node->user_data); |
1211 | |
1212 | const TfLiteTensor* input; |
1213 | TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, 0, &input)); |
1214 | TfLiteTensor* output; |
1215 | TF_LITE_ENSURE_OK(context, GetOutputSafe(context, node, 0, &output)); |
1216 | |
1217 | switch (input->type) { |
1218 | case kTfLiteFloat32: { |
1219 | return SoftmaxFloat(context, input, output, params, kernel_type); |
1220 | } |
1221 | case kTfLiteUInt8: { |
1222 | switch (output->type) { |
1223 | case kTfLiteUInt8: |
1224 | return SoftmaxQuantized<uint8_t, uint8_t>(context, input, output, |
1225 | data, kernel_type); |
1226 | case kTfLiteInt16: |
1227 | return SoftmaxQuantized<uint8_t, int16_t>(context, input, output, |
1228 | data, kernel_type); |
1229 | default: |
1230 | TF_LITE_KERNEL_LOG(context, |
1231 | "Only uint8_t and int16_t outputs are supported " |
1232 | "with uint8_t inputs currently, got %s." , |
1233 | TfLiteTypeGetName(output->type)); |
1234 | return kTfLiteError; |
1235 | } |
1236 | } |
1237 | case kTfLiteInt8: { |
1238 | switch (output->type) { |
1239 | case kTfLiteInt8: |
1240 | return SoftmaxQuantized<int8_t, int8_t>(context, input, output, data, |
1241 | kernel_type); |
1242 | case kTfLiteInt16: |
1243 | return SoftmaxQuantized<int8_t, int16_t>(context, input, output, data, |
1244 | kernel_type); |
1245 | default: |
1246 | TF_LITE_KERNEL_LOG(context, |
1247 | "Only int8_t and int16_t outputs are supported " |
1248 | "with int8_t inputs currently, got %s." , |
1249 | TfLiteTypeGetName(output->type)); |
1250 | return kTfLiteError; |
1251 | } |
1252 | } |
1253 | case kTfLiteInt16: { |
1254 | return SoftmaxQuantized<int16_t, int16_t>(context, input, output, data, |
1255 | kernel_type); |
1256 | } |
1257 | |
1258 | default: |
1259 | TF_LITE_KERNEL_LOG(context, |
1260 | "Only float32, uint8_t, Int8_t, Int16_t are supported " |
1261 | "currently, got %s." , |
1262 | TfLiteTypeGetName(input->type)); |
1263 | return kTfLiteError; |
1264 | } |
1265 | } |
1266 | |
1267 | template <KernelType kernel_type> |
1268 | TfLiteStatus LogSoftmaxEval(TfLiteContext* context, TfLiteNode* node) { |
1269 | const LogSoftmaxOpData* data = |
1270 | reinterpret_cast<LogSoftmaxOpData*>(node->user_data); |
1271 | const TfLiteTensor* input; |
1272 | TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, 0, &input)); |
1273 | TfLiteTensor* output; |
1274 | TF_LITE_ENSURE_OK(context, GetOutputSafe(context, node, 0, &output)); |
1275 | switch (input->type) { |
1276 | case kTfLiteFloat32: { |
1277 | SoftmaxParams op_params; |
1278 | if (kernel_type == kGenericOptimized) { |
1279 | optimized_ops::LogSoftmax( |
1280 | op_params, GetTensorShape(input), GetTensorData<float>(input), |
1281 | GetTensorShape(output), GetTensorData<float>(output)); |
1282 | } else { |
1283 | reference_ops::LogSoftmax( |
1284 | op_params, GetTensorShape(input), GetTensorData<float>(input), |
1285 | GetTensorShape(output), GetTensorData<float>(output)); |
1286 | } |
1287 | return kTfLiteOk; |
1288 | } |
1289 | case kTfLiteUInt8: { |
1290 | const SoftmaxParams& op_params = data->params; |
1291 | if (kernel_type == kGenericOptimized) { |
1292 | optimized_ops::LogSoftmax( |
1293 | op_params, input->params.scale, GetTensorShape(input), |
1294 | GetTensorData<uint8_t>(input), GetTensorShape(output), |
1295 | GetTensorData<uint8_t>(output)); |
1296 | } else { |
1297 | reference_ops::LogSoftmax( |
1298 | op_params, GetTensorShape(input), GetTensorData<uint8_t>(input), |
1299 | GetTensorShape(output), GetTensorData<uint8_t>(output)); |
1300 | } |
1301 | return kTfLiteOk; |
1302 | } |
1303 | case kTfLiteInt8: { |
1304 | const SoftmaxParams& op_params = data->params; |
1305 | if (kernel_type == kGenericOptimized) { |
1306 | optimized_ops::LogSoftmax( |
1307 | op_params, input->params.scale, GetTensorShape(input), |
1308 | GetTensorData<int8_t>(input), GetTensorShape(output), |
1309 | GetTensorData<int8_t>(output)); |
1310 | } else { |
1311 | const auto input_shape = GetTensorShape(input); |
1312 | const auto output_shape = GetTensorShape(output); |
1313 | const int trailing_dim = input_shape.DimensionsCount() - 1; |
1314 | const int outer_size = |
1315 | MatchingFlatSizeSkipDim(input_shape, trailing_dim, output_shape); |
1316 | const int depth = |
1317 | MatchingDim(input_shape, trailing_dim, output_shape, trailing_dim); |
1318 | reference_integer_ops::LogSoftmax( |
1319 | op_params.input_multiplier, op_params.input_left_shift, |
1320 | op_params.reverse_scaling_divisor, |
1321 | op_params.reverse_scaling_right_shift, op_params.diff_min, |
1322 | outer_size, depth, GetTensorData<int8_t>(input), |
1323 | GetTensorData<int8_t>(output)); |
1324 | } |
1325 | return kTfLiteOk; |
1326 | } |
1327 | default: |
1328 | TF_LITE_KERNEL_LOG( |
1329 | context, |
1330 | "Only float32, uint8 and int8 are supported currently, got %s." , |
1331 | TfLiteTypeGetName(input->type)); |
1332 | return kTfLiteError; |
1333 | } |
1334 | } |
1335 | |
1336 | template <typename T> |
1337 | T ApplyPrelu(T input, T alpha) { |
1338 | return input >= 0.0 ? input : input * alpha; |
1339 | } |
1340 | |
1341 | template <KernelType kernel_type> |
1342 | TfLiteStatus PreluEval(TfLiteContext* context, TfLiteNode* node) { |
1343 | const TfLiteTensor* input; |
1344 | TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, 0, &input)); |
1345 | const TfLiteTensor* alpha; |
1346 | TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, 1, &alpha)); |
1347 | TfLiteTensor* output; |
1348 | TF_LITE_ENSURE_OK(context, GetOutputSafe(context, node, 0, &output)); |
1349 | const PreluOpData* data = reinterpret_cast<PreluOpData*>(node->user_data); |
1350 | switch (input->type) { |
1351 | case kTfLiteFloat32: { |
1352 | if (kernel_type == kGenericOptimized) { |
1353 | tflite::ArithmeticParams op_params; |
1354 | bool need_broadcast = optimized_ops::ProcessBroadcastShapes( |
1355 | GetTensorShape(input), GetTensorShape(alpha), &op_params); |
1356 | if (need_broadcast) { |
1357 | optimized_ops::BroadcastPReluDispatch( |
1358 | op_params, GetTensorShape(input), GetTensorData<float>(input), |
1359 | GetTensorShape(alpha), GetTensorData<float>(alpha), |
1360 | GetTensorShape(output), GetTensorData<float>(output), |
1361 | ApplyPrelu<float>); |
1362 | } else { |
1363 | const int flat_size = |
1364 | MatchingElementsSize(GetTensorShape(input), GetTensorShape(alpha), |
1365 | GetTensorShape(output)); |
1366 | optimized_ops::PReluElementWise( |
1367 | flat_size, op_params, GetTensorData<float>(alpha), |
1368 | GetTensorData<float>(input), GetTensorData<float>(output)); |
1369 | } |
1370 | } else { |
1371 | if (data->requires_broadcast) { |
1372 | reference_ops::BroadcastBinaryFunction4DSlow<float, float, float>( |
1373 | GetTensorShape(input), GetTensorData<float>(input), |
1374 | GetTensorShape(alpha), GetTensorData<float>(alpha), |
1375 | GetTensorShape(output), GetTensorData<float>(output), |
1376 | ApplyPrelu<float>); |
1377 | } else { |
1378 | reference_ops::BinaryFunction<float, float, float>( |
1379 | GetTensorShape(input), GetTensorData<float>(input), |
1380 | GetTensorShape(alpha), GetTensorData<float>(alpha), |
1381 | GetTensorShape(output), GetTensorData<float>(output), |
1382 | ApplyPrelu<float>); |
1383 | } |
1384 | } |
1385 | return kTfLiteOk; |
1386 | } |
1387 | case kTfLiteUInt8: { |
1388 | PreluParams op_params; |
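      // Input and alpha zero points are negated so the kernel can recenter
      // values by addition; the output zero point is applied as-is after
      // requantization.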
1389 | op_params.input_offset = -input->params.zero_point; |
1390 | op_params.alpha_offset = -alpha->params.zero_point; |
1391 | op_params.output_offset = output->params.zero_point; |
1392 | op_params.output_multiplier_1 = data->output_multiplier_1; |
1393 | op_params.output_shift_1 = data->output_shift_1; |
1394 | op_params.output_multiplier_2 = data->output_multiplier_2; |
1395 | op_params.output_shift_2 = data->output_shift_2; |
1396 | if (data->requires_broadcast) { |
1397 | reference_ops::BroadcastPrelu4DSlow( |
1398 | op_params, GetTensorShape(input), GetTensorData<uint8_t>(input), |
1399 | GetTensorShape(alpha), GetTensorData<uint8_t>(alpha), |
1400 | GetTensorShape(output), GetTensorData<uint8_t>(output)); |
1401 | } else { |
1402 | reference_ops::Prelu( |
1403 | op_params, GetTensorShape(input), GetTensorData<uint8_t>(input), |
1404 | GetTensorShape(alpha), GetTensorData<uint8_t>(alpha), |
1405 | GetTensorShape(output), GetTensorData<uint8_t>(output)); |
1406 | } |
1407 | return kTfLiteOk; |
1408 | } |
1409 | case kTfLiteInt8: { |
1410 | PreluParams op_params; |
1411 | op_params.input_offset = -input->params.zero_point; |
1412 | op_params.alpha_offset = -alpha->params.zero_point; |
1413 | op_params.output_offset = output->params.zero_point; |
1414 | op_params.output_multiplier_1 = data->output_multiplier_1; |
1415 | op_params.output_shift_1 = data->output_shift_1; |
1416 | op_params.output_multiplier_2 = data->output_multiplier_2; |
1417 | op_params.output_shift_2 = data->output_shift_2; |
1418 | if (data->requires_broadcast) { |
1419 | reference_ops::BroadcastPrelu4DSlow( |
1420 | op_params, GetTensorShape(input), GetTensorData<int8_t>(input), |
1421 | GetTensorShape(alpha), GetTensorData<int8_t>(alpha), |
1422 | GetTensorShape(output), GetTensorData<int8_t>(output)); |
1423 | } else { |
1424 | reference_ops::Prelu( |
1425 | op_params, GetTensorShape(input), GetTensorData<int8_t>(input), |
1426 | GetTensorShape(alpha), GetTensorData<int8_t>(alpha), |
1427 | GetTensorShape(output), GetTensorData<int8_t>(output)); |
1428 | } |
1429 | return kTfLiteOk; |
1430 | } |
1431 | default: |
1432 | TF_LITE_KERNEL_LOG( |
1433 | context, |
1434 | "Only float32 and uint8 and int8 are supported currently, got %d." , |
1435 | TfLiteTypeGetName(input->type)); |
1436 | return kTfLiteError; |
1437 | } |
1438 | } |
1439 | |
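// Shared helper for the quantized LeakyRelu paths. Both kernels requantize
// through the (multiplier, shift) pairs computed at prepare time: one pair
// for the identity region and one for the alpha-scaled region.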
1440 | template <KernelType kernel_type, typename T> |
1441 | void QuantizeLeakyRelu(const TfLiteTensor* input, TfLiteTensor* output, |
1442 | const LeakyReluOpData* data) { |
1443 | LeakyReluParams op_params; |
1444 | |
1445 | op_params.input_offset = input->params.zero_point; |
1446 | op_params.output_offset = output->params.zero_point; |
1447 | op_params.output_multiplier_alpha = data->output_multiplier_alpha; |
1448 | op_params.output_shift_alpha = data->output_shift_alpha; |
1449 | op_params.output_multiplier_identity = data->output_multiplier_identity; |
1450 | op_params.output_shift_identity = data->output_shift_identity; |
1451 | if (kernel_type != KernelType::kReference && input->type == kTfLiteInt16) { |
1452 | optimized_integer_ops::QuantizeLeakyRelu( |
1453 | op_params, GetTensorShape(input), GetTensorData<int16>(input), |
1454 | GetTensorShape(output), GetTensorData<int16>(output)); |
1455 | } else { |
1456 | reference_ops::QuantizeLeakyRelu( |
1457 | op_params, GetTensorShape(input), GetTensorData<T>(input), |
1458 | GetTensorShape(output), GetTensorData<T>(output)); |
1459 | } |
1460 | } |
1461 | |
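// Dispatches LeakyRelu on the input type: float runs the optimized kernel
// directly, while uint8/int8/int16 go through QuantizeLeakyRelu above.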
1462 | template <KernelType kernel_type> |
1463 | TfLiteStatus LeakyReluEval(TfLiteContext* context, TfLiteNode* node) { |
1464 | const TfLiteTensor* input; |
1465 | TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, 0, &input)); |
1466 | TfLiteTensor* output; |
1467 | TF_LITE_ENSURE_OK(context, GetOutputSafe(context, node, 0, &output)); |
1468 | const auto* params = |
1469 | reinterpret_cast<TfLiteLeakyReluParams*>(node->builtin_data); |
1470 | const LeakyReluOpData* data = |
1471 | reinterpret_cast<LeakyReluOpData*>(node->user_data); |
1472 | |
1473 | LeakyReluParams op_params; |
1474 | switch (input->type) { |
1475 | case kTfLiteFloat32: { |
1476 | op_params.alpha = params->alpha; |
1477 | optimized_ops::LeakyRelu( |
1478 | op_params, GetTensorShape(input), GetTensorData<float>(input), |
1479 | GetTensorShape(output), GetTensorData<float>(output)); |
1480 | return kTfLiteOk; |
1481 | } |
1482 | case kTfLiteUInt8: { |
1483 | QuantizeLeakyRelu<kernel_type, uint8_t>(input, output, data); |
1484 | return kTfLiteOk; |
1485 | } |
1486 | case kTfLiteInt8: { |
1487 | QuantizeLeakyRelu<kernel_type, int8_t>(input, output, data); |
1488 | return kTfLiteOk; |
1489 | } |
1490 | case kTfLiteInt16: { |
1491 | QuantizeLeakyRelu<kernel_type, int16_t>(input, output, data); |
1492 | return kTfLiteOk; |
1493 | } |
1494 | default: |
1495 | TF_LITE_KERNEL_LOG( |
1496 | context, |
1497 | "Only float32, int8, int16 and uint8 is supported currently, got %s." , |
1498 | TfLiteTypeGetName(input->type)); |
1499 | return kTfLiteError; |
1500 | } |
1501 | } |
1502 | |
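// ELU: f(x) = x for x > 0 and exp(x) - 1 otherwise. For int8 inputs the
// transfer function is baked into a 256-entry lookup table at prepare time.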
1503 | TfLiteStatus EluPrepare(TfLiteContext* context, TfLiteNode* node) { |
1504 | const TfLiteTensor* input; |
1505 | TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, 0, &input)); |
1506 | TfLiteTensor* output; |
1507 | TF_LITE_ENSURE_OK(context, GetOutputSafe(context, node, 0, &output)); |
1508 | OpData* data = reinterpret_cast<OpData*>(node->user_data); |
1509 | |
1510 | // Use LUT to handle quantized elu path. |
1511 | if (input->type == kTfLiteInt8) { |
1512 | PopulateLookupTable<int8_t>(data, input, output, [](float value) { |
1513 | return value < 0.0f ? std::expm1(value) : value; |
1514 | }); |
1515 | } |
1516 | return GenericPrepare(context, node); |
1517 | } |
1518 | |
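// Evaluates ELU: float inputs run the optimized kernel; int8 inputs are
// resolved through the lookup table built in EluPrepare.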
1519 | TfLiteStatus EluEval(TfLiteContext* context, TfLiteNode* node) { |
1520 | const TfLiteTensor* input; |
1521 | TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, 0, &input)); |
1522 | TfLiteTensor* output; |
1523 | TF_LITE_ENSURE_OK(context, GetOutputSafe(context, node, 0, &output)); |
1524 | switch (input->type) { |
1525 | case kTfLiteFloat32: { |
1526 | optimized_ops::Elu(GetTensorShape(input), GetTensorData<float>(input), |
1527 | GetTensorShape(output), GetTensorData<float>(output)); |
1528 | return kTfLiteOk; |
1529 | } |
1530 | case kTfLiteInt8: { |
1531 | OpData* data = reinterpret_cast<OpData*>(node->user_data); |
1532 | EvalUsingLookupTable(data, input, output); |
1533 | return kTfLiteOk; |
1534 | } |
1535 | default: |
1536 | TF_LITE_KERNEL_LOG( |
1537 | context, "Only float32 and int8 is supported currently, got %s." , |
1538 | TfLiteTypeGetName(input->type)); |
1539 | return kTfLiteError; |
1540 | } |
1541 | } |
1542 | |
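// GELU prepare: for quantized int8/uint8 inputs, the transfer function
// (exact or tanh-approximated, per params->approximate) is baked into a
// lookup table so Eval reduces to a table walk.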
1543 | TfLiteStatus GeluPrepare(TfLiteContext* context, TfLiteNode* node) { |
1544 | const TfLiteTensor* input; |
1545 | TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, 0, &input)); |
1546 | TfLiteTensor* output; |
1547 | TF_LITE_ENSURE_OK(context, GetOutputSafe(context, node, 0, &output)); |
1548 | OpData* data = reinterpret_cast<OpData*>(node->user_data); |
1549 | auto* params = reinterpret_cast<TfLiteGeluParams*>(node->builtin_data); |
1550 | |
1551 | if (input->type == kTfLiteInt8) { |
1552 | PopulateLookupTable<int8_t>( |
1553 | data, input, output, reference_ops::GeluTransform(params->approximate)); |
1554 | } else if (input->type == kTfLiteUInt8) { |
1555 | PopulateLookupTable<uint8_t>( |
1556 | data, input, output, reference_ops::GeluTransform(params->approximate)); |
1557 | } |
1558 | return GenericPrepare(context, node); |
1559 | } |
1560 | |
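// Evaluates GELU: float inputs use the reference kernel; quantized inputs
// reuse the lookup table built in GeluPrepare.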
1561 | TfLiteStatus GeluEval(TfLiteContext* context, TfLiteNode* node) { |
1562 | auto* params = reinterpret_cast<TfLiteGeluParams*>(node->builtin_data); |
1563 | const TfLiteTensor* input; |
1564 | TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, 0, &input)); |
1565 | TfLiteTensor* output; |
1566 | TF_LITE_ENSURE_OK(context, GetOutputSafe(context, node, 0, &output)); |
1567 | |
1568 | switch (input->type) { |
1569 | case kTfLiteFloat32: { |
1570 | reference_ops::Gelu(GetTensorShape(input), GetTensorData<float>(input), |
1571 | params->approximate, GetTensorShape(output), |
1572 | GetTensorData<float>(output)); |
1573 | return kTfLiteOk; |
1574 | } |
1575 | case kTfLiteInt8: |
1576 | case kTfLiteUInt8: { |
1577 | OpData* data = reinterpret_cast<OpData*>(node->user_data); |
1578 | EvalUsingLookupTable(data, input, output); |
1579 | return kTfLiteOk; |
1580 | } |
1581 | default: |
1582 | TF_LITE_KERNEL_LOG( |
1583 | context, "Only float32, int8 and uint8 supported currently, got %s." , |
1584 | TfLiteTypeGetName(input->type)); |
1585 | return kTfLiteError; |
1586 | } |
1588 | } |
1589 | |
1590 | } // namespace activations |
1591 | |
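// Each Register_* function below returns a pointer to a function-local
// static TfLiteRegistration aggregate of {init, free, prepare, invoke}
// callbacks. For illustration (a sketch, not part of this file), a custom
// resolver could wire one of these ops up through the existing
// MutableOpResolver API:
//
//   tflite::MutableOpResolver resolver;
//   resolver.AddBuiltin(tflite::BuiltinOperator_GELU,
//                       tflite::ops::builtin::Register_GELU());
//
// The stock BuiltinOpResolver performs the equivalent registration for all
// of the ops in this file.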
1592 | TfLiteRegistration* Register_ELU() { |
1593 | static TfLiteRegistration r = {activations::Init, activations::Free, |
1594 | activations::EluPrepare, activations::EluEval}; |
1595 | return &r; |
1596 | } |
1597 | |
1598 | TfLiteRegistration* Register_RELU() { |
1599 | static TfLiteRegistration r = {activations::ReluInit, activations::ReluFree, |
1600 | activations::ReluPrepare, |
1601 | activations::ReluEval}; |
1602 | return &r; |
1603 | } |
1604 | |
1605 | TfLiteRegistration* Register_RELU_N1_TO_1() { |
1606 | static TfLiteRegistration r = {activations::ReluInit, activations::ReluFree, |
1607 | activations::ReluPrepare, |
1608 | activations::Relu1Eval}; |
1609 | return &r; |
1610 | } |
1611 | |
1612 | TfLiteRegistration* Register_RELU6() { |
1613 | static TfLiteRegistration r = {activations::ReluInit, activations::ReluFree, |
1614 | activations::ReluPrepare, |
1615 | activations::Relu6Eval}; |
1616 | return &r; |
1617 | } |
1618 | |
1619 | TfLiteRegistration* Register_RELU_0_TO_1() { |
1620 | static TfLiteRegistration r = {activations::ReluInit, activations::ReluFree, |
1621 | activations::ReluPrepare, |
1622 | activations::Relu0to1Eval}; |
1623 | return &r; |
1624 | } |
1625 | |
1626 | TfLiteRegistration* Register_TANH_REF() { |
1627 | static TfLiteRegistration r = { |
1628 | activations::Init, activations::Free, |
1629 | activations::TanhPrepare<activations::kReference>, |
1630 | activations::TanhEval<activations::kReference>}; |
1631 | return &r; |
1632 | } |
1633 | |
1634 | TfLiteRegistration* Register_TANH_GENERIC_OPT() { |
1635 | static TfLiteRegistration r = { |
1636 | activations::Init, activations::Free, |
1637 | activations::TanhPrepare<activations::kGenericOptimized>, |
1638 | activations::TanhEval<activations::kGenericOptimized>}; |
1639 | return &r; |
1640 | } |
1641 | |
1642 | TfLiteRegistration* Register_TANH_FIXED_POINT_OPT() { |
1643 | static TfLiteRegistration r = { |
1644 | activations::Init, activations::Free, |
1645 | activations::TanhPrepare<activations::kFixedPointOptimized>, |
1646 | activations::TanhEval<activations::kFixedPointOptimized>}; |
1647 | return &r; |
1648 | } |
1649 | |
1650 | TfLiteRegistration* Register_TANH() { |
1651 | // TODO(b/134622898): Switch over from the LUT optimized method to the fixed |
1652 | // point optimized method when typical Android hardware performs better on |
1653 | // the latter one. |
1654 | return Register_TANH_GENERIC_OPT(); |
1655 | } |
1656 | |
1657 | TfLiteRegistration* Register_LOGISTIC_REF() { |
1658 | static TfLiteRegistration r = { |
1659 | activations::Init, activations::Free, |
1660 | activations::SigmoidPrepare<activations::kReference>, |
1661 | activations::SigmoidEval<activations::kReference>}; |
1662 | return &r; |
1663 | } |
1664 | |
1665 | TfLiteRegistration* Register_LOGISTIC_GENERIC_OPT() { |
1666 | static TfLiteRegistration r = { |
1667 | activations::Init, activations::Free, |
1668 | activations::SigmoidPrepare<activations::kGenericOptimized>, |
1669 | activations::SigmoidEval<activations::kGenericOptimized>}; |
1670 | return &r; |
1671 | } |
1672 | |
1673 | TfLiteRegistration* Register_LOGISTIC_FIXED_POINT_OPT() { |
1674 | static TfLiteRegistration r = { |
1675 | activations::Init, activations::Free, |
1676 | activations::SigmoidPrepare<activations::kFixedPointOptimized>, |
1677 | activations::SigmoidEval<activations::kFixedPointOptimized>}; |
1678 | return &r; |
1679 | } |
1680 | |
1681 | TfLiteRegistration* Register_LOGISTIC() { |
1682 | // TODO(b/134622898): Switch over from the LUT optimized method to the fixed |
1683 | // point optimized method when typical Android hardware performs better on |
1684 | // the latter one. |
1685 | return Register_LOGISTIC_GENERIC_OPT(); |
1686 | } |
1687 | |
1688 | TfLiteRegistration* Register_SOFTMAX_REF() { |
1689 | static TfLiteRegistration r = { |
1690 | activations::SoftmaxInit, activations::SoftmaxFree, |
1691 | activations::SoftmaxPrepare<activations::kReference>, |
1692 | activations::SoftmaxEval<activations::kReference>}; |
1693 | return &r; |
1694 | } |
1695 | |
1696 | TfLiteRegistration* Register_SOFTMAX() { |
1697 | static TfLiteRegistration r = { |
1698 | activations::SoftmaxInit, activations::SoftmaxFree, |
1699 | activations::SoftmaxPrepare<activations::kGenericOptimized>, |
1700 | activations::SoftmaxEval<activations::kGenericOptimized>}; |
1701 | return &r; |
1702 | } |
1703 | |
1704 | TfLiteRegistration* Register_LOG_SOFTMAX_REF() { |
1705 | static TfLiteRegistration r = { |
1706 | activations::LogSoftmaxInit, activations::LogSoftmaxFree, |
1707 | activations::LogSoftmaxPrepare<activations::kReference>, |
1708 | activations::LogSoftmaxEval<activations::kReference>}; |
1709 | return &r; |
1710 | } |
1711 | |
1712 | TfLiteRegistration* Register_LOG_SOFTMAX() { |
1713 | static TfLiteRegistration r = { |
1714 | activations::LogSoftmaxInit, activations::LogSoftmaxFree, |
1715 | activations::LogSoftmaxPrepare<activations::kGenericOptimized>, |
1716 | activations::LogSoftmaxEval<activations::kGenericOptimized>}; |
1717 | return &r; |
1718 | } |
1719 | |
1720 | TfLiteRegistration* Register_PRELU_REF() { |
1721 | static TfLiteRegistration r = { |
1722 | activations::PreluInit, activations::PreluFree, activations::PreluPrepare, |
1723 | activations::PreluEval<activations::kReference>}; |
1724 | return &r; |
1725 | } |
1726 | |
1727 | TfLiteRegistration* Register_PRELU() { |
1728 | static TfLiteRegistration r = { |
1729 | activations::PreluInit, activations::PreluFree, activations::PreluPrepare, |
1730 | activations::PreluEval<activations::kGenericOptimized>}; |
1731 | return &r; |
1732 | } |
1733 | |
1734 | TfLiteRegistration* Register_LEAKY_RELU_REF() { |
1735 | static TfLiteRegistration r = { |
1736 | activations::LeakyReluInit, activations::LeakyReluFree, |
1737 | activations::LeakyReluPrepare, |
1738 | activations::LeakyReluEval<activations::kReference>}; |
1739 | return &r; |
1740 | } |
1741 | |
1742 | TfLiteRegistration* Register_LEAKY_RELU() { |
1743 | static TfLiteRegistration r = { |
1744 | activations::LeakyReluInit, activations::LeakyReluFree, |
1745 | activations::LeakyReluPrepare, |
1746 | activations::LeakyReluEval<activations::kGenericOptimized>}; |
1747 | return &r; |
1748 | } |
1749 | |
1750 | TfLiteRegistration* Register_HARD_SWISH() { |
1751 | static TfLiteRegistration r = { |
1752 | activations::HardSwishInit, activations::HardSwishFree, |
1753 | activations::HardSwishPrepare, |
1754 | activations::HardSwishEval<activations::kGenericOptimized>}; |
1755 | return &r; |
1756 | } |
1757 | |
1758 | TfLiteRegistration* Register_HARD_SWISH_REF() { |
1759 | static TfLiteRegistration r = { |
1760 | activations::HardSwishInit, activations::HardSwishFree, |
1761 | activations::HardSwishPrepare, |
1762 | activations::HardSwishEval<activations::kReference>}; |
1763 | return &r; |
1764 | } |
1765 | |
1766 | TfLiteRegistration* Register_GELU() { |
1767 | static TfLiteRegistration r = {activations::Init, activations::Free, |
1768 | activations::GeluPrepare, |
1769 | activations::GeluEval}; |
1770 | return &r; |
1771 | } |
1772 | |
1773 | } // namespace builtin |
1774 | } // namespace ops |
1775 | } // namespace tflite |
1776 | |