/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/lite/kernels/internal/reference/reduce.h"

#include <stddef.h>

#include <cmath>
#include <cstdint>
#include <functional>
#include <limits>
#include <tuple>
#include <vector>

#include "ruy/profiler/instrumentation.h"  // from @ruy
#include "tensorflow/lite/c/builtin_op_data.h"
#include "tensorflow/lite/c/c_api_types.h"
#include "tensorflow/lite/c/common.h"
#include "tensorflow/lite/kernels/cpu_backend_context.h"
#include "tensorflow/lite/kernels/cpu_backend_threadpool.h"
#include "tensorflow/lite/kernels/internal/compatibility.h"
#include "tensorflow/lite/kernels/internal/optimized/integer_ops/mean.h"
#include "tensorflow/lite/kernels/internal/optimized/neon_check.h"
#include "tensorflow/lite/kernels/internal/optimized/optimized_ops.h"
#include "tensorflow/lite/kernels/internal/optimized/reduce.h"
#include "tensorflow/lite/kernels/internal/quantization_util.h"
#include "tensorflow/lite/kernels/internal/reduce_common.h"
#include "tensorflow/lite/kernels/internal/reference/integer_ops/mean.h"
#include "tensorflow/lite/kernels/internal/reference/reference_ops.h"
#include "tensorflow/lite/kernels/internal/tensor.h"
#include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
#include "tensorflow/lite/kernels/internal/types.h"
#include "tensorflow/lite/kernels/kernel_util.h"

namespace tflite {
namespace ops {
namespace builtin {
namespace reduce {

// This file implements the reduce_* operators.
enum KernelType {
  kReference,
  kGenericOptimized,
};

struct OpData {
  int32_t multiplier;
  int shift;
  // The index of the first of the four temporary tensors (index, resolved
  // axis, accumulator, normalized dims) allocated in Init().
  int scratch_tensor_index;
};

struct OpContext {
  OpContext(TfLiteContext* context, TfLiteNode* node) {
    params = reinterpret_cast<TfLiteReducerParams*>(node->builtin_data);
    input = GetInput(context, node, 0);
    axis = GetInput(context, node, 1);
    output = GetOutput(context, node, 0);
  }
  TfLiteReducerParams* params;
  const TfLiteTensor* input;
  const TfLiteTensor* axis;
  TfLiteTensor* output;
};

void* Init(TfLiteContext* context, const char* buffer, size_t length) {
  // Creates four temp tensors to store index, resolved axis, accumulator and
  // normalized dims for internal implementation only.
  auto* op_data = new OpData();
  context->AddTensors(context, 4, &op_data->scratch_tensor_index);
  return op_data;
}

void Free(TfLiteContext* context, void* buffer) {
  delete reinterpret_cast<OpData*>(buffer);
}

// Resizes the temp tensor that stores resolved axis.
TfLiteStatus ResizeTempAxis(TfLiteContext* context, OpContext* op_context,
                            TfLiteTensor* resolved_axis) {
  TfLiteIntArray* axis_size = TfLiteIntArrayCreate(1);
  axis_size->data[0] = static_cast<int>(NumElements(op_context->axis));
  return context->ResizeTensor(context, resolved_axis, axis_size);
}

// Resizes the temp tensor that stores temp sum of reduced elements.
TfLiteStatus ResizeTempAccum(TfLiteContext* context, OpContext* op_context,
                             TfLiteTensor* temp_accum) {
  TfLiteIntArray* size = TfLiteIntArrayCreate(1);
  size->data[0] = static_cast<int>(NumElements(op_context->output));
  return context->ResizeTensor(context, temp_accum, size);
}

// Resizes output array based on the input size and resolved axis.
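// For example, reducing an input of shape {2, 3, 4} over axis {1} produces an
// output of shape {2, 4} when keep_dims is false, and {2, 1, 4} when
// keep_dims is true.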
TfLiteStatus ResizeOutputTensor(TfLiteContext* context, OpContext* op_context) {
  size_t num_axis = NumElements(op_context->axis);
  const TfLiteIntArray* input_dims = op_context->input->dims;
  int input_num_dims = NumDimensions(op_context->input);
  if (input_num_dims == 0) {
    return context->ResizeTensor(context, op_context->output,
                                 TfLiteIntArrayCreate(0));
  }
  const int* axis = GetTensorData<int>(op_context->axis);
  if (op_context->params->keep_dims) {
    TfLiteIntArray* output_dims = TfLiteIntArrayCreate(input_num_dims);
    for (int idx = 0; idx < input_num_dims; ++idx) {
      bool is_axis = false;
      for (int axis_idx = 0; axis_idx < num_axis; ++axis_idx) {
        if (axis[axis_idx] == idx || axis[axis_idx] + input_num_dims == idx) {
          is_axis = true;
          break;
        }
      }
      if (is_axis) {
        output_dims->data[idx] = 1;
      } else {
        output_dims->data[idx] = input_dims->data[idx];
      }
    }
    return context->ResizeTensor(context, op_context->output, output_dims);
  } else {
    // Calculates size of reducing axis.
    int num_reduce_axis = num_axis;
    for (int i = 0; i < num_axis; ++i) {
      int current = axis[i];
      if (current < 0) {
        current += input_num_dims;
      }
      TF_LITE_ENSURE(context, current >= 0 && current < input_num_dims);
      for (int j = 0; j < i; ++j) {
        int previous = axis[j];
        if (previous < 0) {
          previous += input_num_dims;
        }
        if (current == previous) {
          --num_reduce_axis;
          break;
        }
      }
    }
    // Determines output dimensions.
    TfLiteIntArray* output_dims =
        TfLiteIntArrayCreate(input_num_dims - num_reduce_axis);
    int num_skip_axis = 0;
    for (int idx = 0; idx < input_num_dims; ++idx) {
      bool is_axis = false;
      for (int axis_idx = 0; axis_idx < num_axis; ++axis_idx) {
        if (axis[axis_idx] == idx || axis[axis_idx] + input_num_dims == idx) {
          ++num_skip_axis;
          is_axis = true;
          break;
        }
      }
      if (!is_axis) {
        output_dims->data[idx - num_skip_axis] = input_dims->data[idx];
      }
    }
    return context->ResizeTensor(context, op_context->output, output_dims);
  }
}

// Resizes the temp tensor that stores normalized dims.
TfLiteStatus ResizeTempDims(TfLiteContext* context, OpContext* op_context,
                            TfLiteTensor* normalized_dims) {
  TfLiteIntArray* dims_size = TfLiteIntArrayCreate(1);
  dims_size->data[0] = (op_context->input->dims->size);
  return context->ResizeTensor(context, normalized_dims, dims_size);
}

// Initializes temp tensors to store index and resolved axis.
TfLiteStatus InitializeTemporaries(TfLiteContext* context, TfLiteNode* node,
                                   OpContext* op_context) {
  // Creates a temp index to iterate through input data.
  OpData* op_data = reinterpret_cast<OpData*>(node->user_data);
  TfLiteIntArrayFree(node->temporaries);
  node->temporaries = TfLiteIntArrayCreate(4);
  node->temporaries->data[0] = op_data->scratch_tensor_index;
  TfLiteTensor* scratch_tensor;
  TF_LITE_ENSURE_OK(
      context, GetTemporarySafe(context, node, /*index=*/0, &scratch_tensor));
  scratch_tensor->type = kTfLiteInt32;
  scratch_tensor->allocation_type = kTfLiteArenaRw;
  TfLiteIntArray* index_size = TfLiteIntArrayCreate(1);
  index_size->data[0] = NumDimensions(op_context->input);
  TF_LITE_ENSURE_OK(context,
                    context->ResizeTensor(context, scratch_tensor, index_size));

  // Creates a temp tensor to store resolved axis given input data.
  node->temporaries->data[1] = op_data->scratch_tensor_index + 1;
  TfLiteTensor* resolved_axis;
  TF_LITE_ENSURE_OK(
      context, GetTemporarySafe(context, node, /*index=*/1, &resolved_axis));
  resolved_axis->type = kTfLiteInt32;
  // Creates a temp accumulation tensor to store the intermediate sums when
  // computing mean, or the intermediate products when computing reduce_prod.
  node->temporaries->data[2] = op_data->scratch_tensor_index + 2;
  TfLiteTensor* temp_accum;
  TF_LITE_ENSURE_OK(context,
                    GetTemporarySafe(context, node, /*index=*/2, &temp_accum));
  switch (op_context->input->type) {
    case kTfLiteFloat32:
      temp_accum->type = kTfLiteFloat32;
      break;
    case kTfLiteInt32:
      temp_accum->type = kTfLiteInt64;
      break;
    case kTfLiteInt64:
      temp_accum->type = kTfLiteInt64;
      break;
    case kTfLiteUInt8:
    case kTfLiteInt8:
    case kTfLiteInt16:
      temp_accum->type = kTfLiteInt32;
      break;
    case kTfLiteBool:
      temp_accum->type = kTfLiteBool;
      break;
    default:
      return kTfLiteError;
  }
  // Creates a temp tensor to store normalized shape given input data.
  node->temporaries->data[3] = op_data->scratch_tensor_index + 3;
  TfLiteTensor* normalized_dims;
  TF_LITE_ENSURE_OK(
      context, GetTemporarySafe(context, node, /*index=*/3, &normalized_dims));
  normalized_dims->type = kTfLiteInt32;
  return kTfLiteOk;
}

TfLiteStatus PrepareSimple(TfLiteContext* context, TfLiteNode* node) {
  TF_LITE_ENSURE_EQ(context, NumInputs(node), 2);
  TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);

  OpContext op_context(context, node);
  TF_LITE_ENSURE_TYPES_EQ(context, op_context.axis->type, kTfLiteInt32);
  TF_LITE_ENSURE_OK(context, InitializeTemporaries(context, node, &op_context));

  if (op_context.input->type == kTfLiteInt16) {
    TF_LITE_ENSURE_EQ(context, op_context.input->params.zero_point, 0);
    TF_LITE_ENSURE_EQ(context, op_context.output->params.zero_point, 0);
  }

  TfLiteTensor* resolved_axis;
  TF_LITE_ENSURE_OK(
      context, GetTemporarySafe(context, node, /*index=*/1, &resolved_axis));
  TfLiteTensor* normalized_dims;
  TF_LITE_ENSURE_OK(
      context, GetTemporarySafe(context, node, /*index=*/3, &normalized_dims));

  if (!IsConstantTensor(op_context.input)) {
    SetTensorToDynamic(normalized_dims);
  } else {
    normalized_dims->allocation_type = kTfLiteArenaRw;
    TF_LITE_ENSURE_OK(context,
                      ResizeTempDims(context, &op_context, normalized_dims));
  }
  // Leaves work to Eval if axis is not constant; else resizes output.
  if (!IsConstantTensor(op_context.axis)) {
    SetTensorToDynamic(op_context.output);
    SetTensorToDynamic(resolved_axis);
    return kTfLiteOk;
  }
  resolved_axis->allocation_type = kTfLiteArenaRw;
  TF_LITE_ENSURE_OK(context,
                    ResizeTempAxis(context, &op_context, resolved_axis));
  TF_LITE_ENSURE_OK(context, ResizeOutputTensor(context, &op_context));
  return kTfLiteOk;
}

TfLiteStatus PrepareAllOrAny(TfLiteContext* context, TfLiteNode* node) {
  TF_LITE_ENSURE_EQ(context, NumInputs(node), 2);
  const TfLiteTensor* input;
  TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, 0, &input));
  TF_LITE_ENSURE_TYPES_EQ(context, input->type, kTfLiteBool);
  return PrepareSimple(context, node);
}

TfLiteStatus PrepareMeanOrSum(TfLiteContext* context, TfLiteNode* node) {
  TF_LITE_ENSURE_OK(context, PrepareSimple(context, node));
  OpData* data = reinterpret_cast<OpData*>(node->user_data);

  // reduce_mean requires a buffer to store intermediate sum result.
  OpContext op_context(context, node);
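  // For quantized inputs the mean is computed as a rescaled sum:
  // QuantizeMultiplier decomposes input_scale / output_scale into a 32-bit
  // fixed-point multiplier and a power-of-two exponent, so that
  // real_multiplier ~= (multiplier / 2^31) * 2^shift.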
  if (op_context.input->type == kTfLiteInt8 ||
      op_context.input->type == kTfLiteUInt8 ||
      op_context.input->type == kTfLiteInt16) {
    const double real_multiplier =
        static_cast<double>(op_context.input->params.scale) /
        static_cast<double>(op_context.output->params.scale);
    int exponent;
    QuantizeMultiplier(real_multiplier, &data->multiplier, &exponent);
    data->shift = exponent;
  }

  if (op_context.input->type == kTfLiteInt16) {
    TF_LITE_ENSURE_EQ(context, op_context.input->params.zero_point, 0);
    TF_LITE_ENSURE_EQ(context, op_context.output->params.zero_point, 0);
  }

  TfLiteTensor* temp_sum;
  TF_LITE_ENSURE_OK(context,
                    GetTemporarySafe(context, node, /*index=*/2, &temp_sum));
  if (!IsConstantTensor(op_context.axis)) {
    SetTensorToDynamic(temp_sum);
    return kTfLiteOk;
  }
  temp_sum->allocation_type = kTfLiteArenaRw;
  return ResizeTempAccum(context, &op_context, temp_sum);
}

double GetQuantProdScaling(double input_scale, double output_scale,
                           int reduced_axis_size) {
  // The scaling after taking the product of all the quantized values should
  // be (input_scale**reduced_axis_size)/output_scale but to avoid overflowing
  // the accumulator we instead scale each multiplication by
  // input_scale/nth_root(output_scale, reduced_axis_size).
  return input_scale / std::pow(output_scale, 1.0 / reduced_axis_size);
}
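// For example, with reduced_axis_size == 3 each multiplication is scaled by
// input_scale / cbrt(output_scale), so the three per-step factors multiply to
// input_scale^3 / output_scale, the overall scaling described above.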

TfLiteStatus PrepareProd(TfLiteContext* context, TfLiteNode* node) {
  TF_LITE_ENSURE_OK(context, PrepareSimple(context, node));

  OpContext op_context(context, node);
  OpData* data = reinterpret_cast<OpData*>(node->user_data);

  TfLiteTensor* temp_prod;
  TF_LITE_ENSURE_OK(context,
                    GetTemporarySafe(context, node, /*index=*/2, &temp_prod));

  if (op_context.input->type == kTfLiteInt16) {
    TF_LITE_ENSURE_EQ(context, op_context.input->params.zero_point, 0);
    TF_LITE_ENSURE_EQ(context, op_context.output->params.zero_point, 0);
  }

  if (!IsConstantTensor(op_context.axis)) {
    SetTensorToDynamic(temp_prod);
    return kTfLiteOk;
  }

  const int input_size = GetTensorShape(op_context.input).FlatSize();
  const int output_size = GetTensorShape(op_context.output).FlatSize();
  // We support both quantized and non-quantized int8/int16 inputs
  if (op_context.input->quantization.type != kTfLiteNoQuantization &&
      (op_context.input->type == kTfLiteInt8 ||
       op_context.input->type == kTfLiteInt16) &&
      input_size != 0 && output_size != 0) {
    const int reduced_axis_size = input_size / output_size;
    const double scaling = GetQuantProdScaling(
        static_cast<double>(op_context.input->params.scale),
        static_cast<double>(op_context.output->params.scale),
        reduced_axis_size);
    QuantizeMultiplier(scaling, &data->multiplier, &data->shift);
  }

  temp_prod->allocation_type = kTfLiteArenaRw;
  return ResizeTempAccum(context, &op_context, temp_prod);
}

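// Copies the requested axes into MeanParams (which holds at most four axes)
// and fills the remaining slots with 1 so the fixed-size array is fully
// initialized.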
void ResolveAxis(const int* axis_data, int axis_count,
                 tflite::MeanParams* op_params) {
  int i = 0;
  for (; i < axis_count; ++i) {
    op_params->axis[i] = static_cast<int16>(axis_data[i]);
  }
  for (; i < 4; ++i) {
    op_params->axis[i] = 1;
  }
}

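// Packs the shared argument list into a tuple and forwards it, via std::apply,
// to either the reference or the optimized Mean kernel; both kernels take the
// same parameters.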
template <typename T, typename U>
TfLiteStatus Mean(TfLiteContext* context, const OpContext* op_context,
                  int* temp_index, int* resolved_axis, U* temp_sum,
                  KernelType kernel_type) {
  int num_axis = static_cast<int>(NumElements(op_context->axis));
  auto args = std::tuple(
      GetTensorData<T>(op_context->input), &op_context->input->dims->data[0],
      op_context->input->dims->size, GetTensorData<T>(op_context->output),
      &op_context->output->dims->data[0], op_context->output->dims->size,
      GetTensorData<int>(op_context->axis), num_axis,
      op_context->params->keep_dims, temp_index, resolved_axis, temp_sum);
  if (kernel_type == kReference) {
    TF_LITE_ENSURE(context, std::apply(reference_ops::Mean<T, U>, args));
  } else {
    TF_LITE_ENSURE(context, std::apply(optimized_ops::Mean<T, U>, args));
  }
  return kTfLiteOk;
}

template <typename T>
TfLiteStatus QuantizedMeanOrSum(TfLiteContext* context,
                                const OpContext* op_context, int* temp_index,
                                int* resolved_axis, int* temp_sum,
                                KernelType kernel_type, bool compute_sum) {
  int num_axis = static_cast<int>(NumElements(op_context->axis));
  auto args = std::tuple(
      GetTensorData<T>(op_context->input), op_context->input->params.zero_point,
      op_context->input->params.scale, &op_context->input->dims->data[0],
      op_context->input->dims->size, GetTensorData<T>(op_context->output),
      op_context->output->params.zero_point, op_context->output->params.scale,
      &op_context->output->dims->data[0], op_context->output->dims->size,
      GetTensorData<int>(op_context->axis), num_axis,
      op_context->params->keep_dims, temp_index, resolved_axis, temp_sum,
      compute_sum);
  if (kernel_type == kReference) {
    TF_LITE_ENSURE(
        context,
        std::apply(reference_ops::QuantizedMeanOrSum<T, int32_t>, args));
  } else {
    TF_LITE_ENSURE(
        context,
        std::apply(optimized_ops::QuantizedMeanOrSum<T, int32_t>, args));
  }
  return kTfLiteOk;
}

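// Computes the mean of a quantized tensor. When the input and output share
// quantization parameters the plain Mean kernel (integer sum followed by
// division) is enough; otherwise QuantizedMeanOrSum rescales the result from
// the input scale to the output scale.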
template <typename integer_type>
TfLiteStatus EvalIntegerMean(TfLiteContext* context,
                             const OpContext& op_context, int num_axis,
                             OpData* data, TfLiteTensor* temp_index,
                             TfLiteTensor* resolved_axis,
                             TfLiteTensor* temp_sum,
                             TfLiteTensor* normalized_dims,
                             KernelType kernel_type) {
  tflite::MeanParams op_params;
  op_params.axis_count = num_axis;
  ResolveAxis(GetTensorData<int>(op_context.axis), num_axis, &op_params);
  const TfLiteTensor* input = op_context.input;

  if (input->params.zero_point == op_context.output->params.zero_point &&
      input->params.scale == op_context.output->params.scale) {
    Mean<integer_type, int>(context, &op_context,
                            GetTensorData<int>(temp_index),
                            GetTensorData<int>(resolved_axis),
                            GetTensorData<int>(temp_sum), kernel_type);
  } else {
    QuantizedMeanOrSum<integer_type>(
        context, &op_context, GetTensorData<int>(temp_index),
        GetTensorData<int>(resolved_axis), GetTensorData<int32_t>(temp_sum),
        kernel_type, /*compute_sum=*/false);
  }
  return kTfLiteOk;
}

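// Fills the output with quiet_NaN(). Note that for integer types
// std::numeric_limits<T>::quiet_NaN() returns a default-constructed value,
// so integer outputs are effectively zero-filled.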
template <typename T>
void InitializeMeanOutputTyped(TfLiteTensor* output) {
  RuntimeShape output_shape = GetTensorShape(output);
  const size_t flat_size = output_shape.FlatSize();
  T* output_data = GetTensorData<T>(output);
  T nan_value = std::numeric_limits<T>::quiet_NaN();
  for (int idx = 0; idx < flat_size; ++idx) {
    *output_data++ = nan_value;
  }
}

TfLiteStatus InitializeMeanOutput(TfLiteTensor* output) {
  switch (output->type) {
    case kTfLiteFloat32:
      InitializeMeanOutputTyped<float>(output);
      break;
    case kTfLiteInt32:
      InitializeMeanOutputTyped<int>(output);
      break;
    case kTfLiteInt64:
      InitializeMeanOutputTyped<int64_t>(output);
      break;
    case kTfLiteUInt8:
      InitializeMeanOutputTyped<uint8_t>(output);
      break;
    case kTfLiteInt8:
      InitializeMeanOutputTyped<int8_t>(output);
      break;
    case kTfLiteInt16:
      InitializeMeanOutputTyped<int16_t>(output);
      break;
    default:
      return kTfLiteError;
  }
  return kTfLiteOk;
}

template <KernelType kernel_type>
TfLiteStatus EvalMean(TfLiteContext* context, TfLiteNode* node) {
  OpContext op_context(context, node);
  OpData* data = reinterpret_cast<OpData*>(node->user_data);

  int num_axis = static_cast<int>(NumElements(op_context.axis));
  TfLiteTensor* temp_index;
  TF_LITE_ENSURE_OK(context,
                    GetTemporarySafe(context, node, /*index=*/0, &temp_index));
  TfLiteTensor* resolved_axis;
  TF_LITE_ENSURE_OK(
      context, GetTemporarySafe(context, node, /*index=*/1, &resolved_axis));
  TfLiteTensor* temp_sum;
  TF_LITE_ENSURE_OK(context,
                    GetTemporarySafe(context, node, /*index=*/2, &temp_sum));
  // Resize the output tensor if the output tensor is dynamic.
  if (IsDynamicTensor(op_context.output)) {
    TF_LITE_ENSURE_OK(context,
                      ResizeTempAxis(context, &op_context, resolved_axis));
    TF_LITE_ENSURE_OK(context, ResizeOutputTensor(context, &op_context));
    TF_LITE_ENSURE_OK(context, ResizeTempAccum(context, &op_context, temp_sum));
  }
  TfLiteTensor* normalized_dims;
  TF_LITE_ENSURE_OK(
      context, GetTemporarySafe(context, node, /*index=*/3, &normalized_dims));
  if (IsDynamicTensor(normalized_dims)) {
    TF_LITE_ENSURE_OK(context,
                      ResizeTempDims(context, &op_context, normalized_dims));
  }

  // Return early when input is empty.
  const TfLiteTensor* input = op_context.input;
  RuntimeShape input_shape = GetTensorShape(input);
  if (input_shape.FlatSize() == 0) {
    TF_LITE_ENSURE_OK(context, InitializeMeanOutput(op_context.output));
    return kTfLiteOk;
  }

  if (kernel_type == kGenericOptimized) {
    // Use optimized ops if available.
    switch (input->type) {
      case kTfLiteInt8: {
        tflite::MeanParams op_params;
        op_params.axis_count = num_axis;
        ResolveAxis(GetTensorData<int>(op_context.axis), num_axis, &op_params);
        if (op_context.params->keep_dims && NumDimensions(input) == 4 &&
            op_params.axis_count == 2 &&
            ((op_params.axis[0] == 1 && op_params.axis[1] == 2) ||
             (op_params.axis[0] == 2 && op_params.axis[1] == 1))) {
          optimized_integer_ops::Mean(
              op_params, input_shape, GetTensorData<int8_t>(input),
              input->params.zero_point, input->params.scale,
              GetTensorShape(op_context.output),
              GetTensorData<int8_t>(op_context.output),
              op_context.output->params.zero_point,
              op_context.output->params.scale,
              CpuBackendContext::GetFromContext(context));
          return kTfLiteOk;
        }
      } break;
      case kTfLiteUInt8: {
        tflite::MeanParams op_params;
        op_params.axis_count = num_axis;
        ResolveAxis(GetTensorData<int>(op_context.axis), num_axis, &op_params);
        if (op_context.params->keep_dims && NumDimensions(input) == 4 &&
            op_params.axis_count == 2 &&
            ((op_params.axis[0] == 1 && op_params.axis[1] == 2) ||
             (op_params.axis[0] == 2 && op_params.axis[1] == 1))) {
          optimized_ops::Mean(op_params, input_shape,
                              GetTensorData<uint8_t>(input),
                              input->params.zero_point, input->params.scale,
                              GetTensorShape(op_context.output),
                              GetTensorData<uint8_t>(op_context.output),
                              op_context.output->params.zero_point,
                              op_context.output->params.scale,
                              CpuBackendContext::GetFromContext(context));
          return kTfLiteOk;
        }
      } break;
      default:
        break;
    }
  }

  switch (op_context.input->type) {
    case kTfLiteFloat32:
      Mean<float, float>(context, &op_context, GetTensorData<int>(temp_index),
                         GetTensorData<int>(resolved_axis),
                         GetTensorData<float>(temp_sum), kernel_type);
      break;
    case kTfLiteInt32:
      Mean<int, int64_t>(context, &op_context, GetTensorData<int>(temp_index),
                         GetTensorData<int>(resolved_axis),
                         GetTensorData<int64_t>(temp_sum), kernel_type);
      break;
    case kTfLiteInt64:
      Mean<int64_t, int64_t>(context, &op_context,
                             GetTensorData<int>(temp_index),
                             GetTensorData<int>(resolved_axis),
                             GetTensorData<int64_t>(temp_sum), kernel_type);
      break;
    case kTfLiteInt8: {
      TF_LITE_ENSURE_OK(
          context, EvalIntegerMean<int8_t>(context, op_context, num_axis, data,
                                           temp_index, resolved_axis, temp_sum,
                                           normalized_dims, kernel_type));
    } break;
    case kTfLiteInt16: {
      TF_LITE_ENSURE_OK(
          context, EvalIntegerMean<int16_t>(context, op_context, num_axis, data,
                                            temp_index, resolved_axis, temp_sum,
                                            normalized_dims, kernel_type));
    } break;
    case kTfLiteUInt8: {
      TF_LITE_ENSURE_OK(
          context, EvalIntegerMean<uint8_t>(context, op_context, num_axis, data,
                                            temp_index, resolved_axis, temp_sum,
                                            normalized_dims, kernel_type));
    } break;
    default:
      return kTfLiteError;
  }
  return kTfLiteOk;
}

template <typename T>
struct EvalData {
  std::function<T(T, T)> reduce_func;
  const T* input_data;
  T output;
};

// Returns true if 'axis' holds all dims [0 ... N-1] where N is num_dims.
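// For example, with num_dims == 3 and axis == {1, 0, 2} the mask becomes
// 0b111 == (1 << 3) - 1, so the function returns true.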
bool IsReduceAllDims(const TfLiteTensor* axis, int num_axis, int num_dims) {
  int dims_mask = 0;
  for (int i = 0; i < num_axis; ++i) {
    dims_mask |= 1 << (axis->data.i32[i]);
  }
  return num_dims == 0 ? dims_mask == 0 : (dims_mask == (1 << num_dims) - 1);
}

// Worker for reducing a single interval, identified by the index range
// [start, end).
template <typename T>
struct ReduceWorkerTask : cpu_backend_threadpool::Task {
  ReduceWorkerTask(EvalData<T>* eval_data, int start, int end)
      : eval_data(eval_data), start(start), end(end) {}
  void Run() override {
    auto* input_data = eval_data->input_data;
    T& output = eval_data->output;
    auto& reducer = eval_data->reduce_func;
    for (int i = start; i < end; ++i) {
      output = reducer(output, input_data[i]);
    }
  }

 private:
  EvalData<T>* eval_data;
  int start;
  int end;
};

// Applies the 'reducer' function over all of 'input_data', reducing it to a
// single element.
template <typename T>
void ReduceAllDims(const T* input_data, const int* input_dims,
                   const int input_num_dims, T* output_data, T init_value,
                   T reducer(const T current, const T in),
                   TfLiteContext* context) {
  EvalData<T> eval_data;
  eval_data.reduce_func = reducer;
  eval_data.input_data = input_data;
  eval_data.output = init_value;

  int num_elems = NumElements(input_dims, input_num_dims);

  // Fetch backend context and number of threads.
  CpuBackendContext* cpu_backend_context =
      CpuBackendContext::GetFromContext(context);
  int thread_count = cpu_backend_context->max_num_threads();
  const int kMinElementsPerThread = 1024;
  if (num_elems / thread_count < kMinElementsPerThread) thread_count = 1;

  if (thread_count == 1) {
    output_data[0] = num_elems > 0 ? input_data[0] : init_value;
    for (int i = 1; i < num_elems; ++i) {
      output_data[0] = reducer(output_data[0], input_data[i]);
    }
    return;
  }
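  // Multi-threaded path: each task reduces a contiguous slice [start, end)
  // into its own copy of eval_data, and the partial results are combined
  // below. The vectors are reserved up front so the pointers handed to the
  // tasks stay valid.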
  std::vector<ReduceWorkerTask<T>> tasks;
  std::vector<EvalData<T>> data;
  tasks.reserve(thread_count);
  data.reserve(thread_count);
  int start = 0;
  for (int i = 0; i < thread_count; ++i) {
    data.push_back(eval_data);
    int end = start + (num_elems - start) / (thread_count - i);
    tasks.emplace_back(ReduceWorkerTask<T>(&data.back(), start, end));
    start = end;
  }
  // Run all tasks on the thread pool.
  cpu_backend_threadpool::Execute(tasks.size(), tasks.data(),
                                  cpu_backend_context);
  // Reduce all data from different workers.
  output_data[0] = data[0].output;
  for (int i = 1; i < data.size(); ++i) {
    output_data[0] = reducer(output_data[0], data[i].output);
  }
}

// The underlying logic for Reduce Sum/Prod/Max/Min/Any/All.
template <typename T>
TfLiteStatus EvalType(TfLiteContext* context, TfLiteNode* node,
                      OpContext* op_context, KernelType kernel_type,
                      ReduceType reduce_type) {
  int64_t num_axis = NumElements(op_context->axis);
  TfLiteTensor* temp_index;
  TF_LITE_ENSURE_OK(context,
                    GetTemporarySafe(context, node, /*index=*/0, &temp_index));
  TfLiteTensor* resolved_axis;
  TF_LITE_ENSURE_OK(
      context, GetTemporarySafe(context, node, /*index=*/1, &resolved_axis));
  // Resize the output tensor if the output tensor is dynamic.
  if (IsDynamicTensor(op_context->output)) {
    TF_LITE_ENSURE_OK(context,
                      ResizeTempAxis(context, op_context, resolved_axis));
    TF_LITE_ENSURE_OK(context, ResizeOutputTensor(context, op_context));
  }

  const TfLiteTensor* input = op_context->input;
  if (input->type == kTfLiteUInt8 || input->type == kTfLiteInt8 ||
      input->type == kTfLiteInt16) {
    TF_LITE_ENSURE_EQ(context, input->params.scale,
                      op_context->output->params.scale);
    TF_LITE_ENSURE_EQ(context, input->params.zero_point,
                      op_context->output->params.zero_point);
  }
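  // The reference path expresses every reduce type as an initial value plus a
  // binary reducer (e.g. 0 and addition for sum, 1 and multiplication for
  // prod), which is then folded over the reduced axes.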
  if (kernel_type == kReference) {
    T init_value = 0;
    T (*reducer)(const T current, const T in);
    switch (reduce_type) {
      case kSum:
        reducer = [](const T current, const T in) -> T { return in + current; };
        init_value = T(0);
        break;
      case kProd:
        init_value = static_cast<T>(1);
        reducer = [](const T current, const T in) -> T { return in * current; };
        break;
      case kMax:
        init_value = std::numeric_limits<T>::lowest();
        reducer = [](const T current, const T in) -> T {
          return (in > current) ? in : current;
        };
        break;
      case kMin:
        init_value = std::numeric_limits<T>::max();
        reducer = [](const T current, const T in) -> T {
          return (in < current) ? in : current;
        };
        break;
      case kAny:
        init_value = false;
        reducer = [](const T current, const T in) -> T {
          return in || current;
        };
        break;
      case kAll:
        init_value = true;
        reducer = [](const T current, const T in) -> T {
          return in && current;
        };
        break;
      default:
        TF_LITE_KERNEL_LOG(context, "Unsupported ReduceType: %d", reduce_type);
        return kTfLiteError;
    }

    int num_resolved_axis = 0;
    TF_LITE_ENSURE_MSG(
        context,
        tflite::reference_ops::ResolveAxis(
            input->dims->size, GetTensorData<int>(op_context->axis), num_axis,
            GetTensorData<int>(resolved_axis), &num_resolved_axis),
        "Invalid axis index.");

    if (IsReduceAllDims(resolved_axis, num_resolved_axis, input->dims->size)) {
      ReduceAllDims(GetTensorData<T>(input), input->dims->data,
                    input->dims->size, GetTensorData<T>(op_context->output),
                    init_value, reducer, context);
      return kTfLiteOk;
    }
    TF_LITE_ENSURE(
        context,
        reference_ops::ReduceGeneric<T>(
            GetTensorData<T>(input), input->dims->data, input->dims->size,
            GetTensorData<T>(op_context->output),
            op_context->output->dims->data, op_context->output->dims->size,
            GetTensorData<int>(op_context->axis), num_axis,
            op_context->params->keep_dims, GetTensorData<int>(temp_index),
            GetTensorData<int>(resolved_axis), init_value, reducer));
    return kTfLiteOk;
  } else {
    TfLiteTensor* normalized_dims;
    TF_LITE_ENSURE_OK(context, GetTemporarySafe(context, node, /*index=*/3,
                                                &normalized_dims));
    if (IsDynamicTensor(normalized_dims)) {
      TF_LITE_ENSURE_OK(context,
                        ResizeTempDims(context, op_context, normalized_dims));
    }
    TF_LITE_ENSURE(
        context,
        optimized_ops::ReduceGeneric<T>(
            GetTensorData<T>(input), input->dims->data, input->dims->size,
            GetTensorData<T>(op_context->output),
            op_context->output->dims->data, op_context->output->dims->size,
            GetTensorData<int>(op_context->axis), num_axis,
            GetTensorData<int>(resolved_axis),
            GetTensorData<int>(normalized_dims), reduce_type));
    return kTfLiteOk;
  }
}

// The entry point that dispatches on the input type and then calls the
// templated functions that handle each ReduceType.
template <KernelType kernel_type, ReduceType reduce_type>
TfLiteStatus EvalGeneric(TfLiteContext* context, TfLiteNode* node) {
  OpContext op_context(context, node);
  switch (op_context.input->type) {
    case kTfLiteFloat32:
      return EvalType<float>(context, node, &op_context, kernel_type,
                             reduce_type);
      break;
    case kTfLiteInt32:
      return EvalType<int>(context, node, &op_context, kernel_type,
                           reduce_type);
      break;
    case kTfLiteInt64:
      return EvalType<int64_t>(context, node, &op_context, kernel_type,
                               reduce_type);
      break;
    case kTfLiteUInt8:
      return EvalType<uint8_t>(context, node, &op_context, kernel_type,
                               reduce_type);
      break;
    case kTfLiteInt8:
      return EvalType<int8_t>(context, node, &op_context, kernel_type,
                              reduce_type);
      break;
    case kTfLiteInt16:
      return EvalType<int16_t>(context, node, &op_context, kernel_type,
                               reduce_type);
      break;
    case kTfLiteBool:
      return EvalType<bool>(context, node, &op_context, kernel_type,
                            reduce_type);
      break;
    default:
      return kTfLiteError;
  }
}

template <KernelType kernel_type>
TfLiteStatus EvalSum(TfLiteContext* context, TfLiteNode* node) {
  OpContext op_context(context, node);
  ruy::profiler::ScopeLabel label("Sum");
  const auto& input = op_context.input;
  const auto& output = op_context.output;
  const bool same_scale =
      (input->params.scale == output->params.scale &&
       input->params.zero_point == output->params.zero_point);
  const bool eight_bit_quantized =
      input->type == kTfLiteUInt8 || input->type == kTfLiteInt8;
  const bool need_rescale = (eight_bit_quantized && !same_scale);
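  // Only 8-bit inputs whose quantization parameters differ from the output's
  // need the rescaling sum below; all other cases (including same-scale
  // 8-bit inputs) fall through to EvalGeneric.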
  if (need_rescale) {
    // Rescaling 8bit reduce sum.
    TfLiteTensor* temp_index;
    TF_LITE_ENSURE_OK(
        context, GetTemporarySafe(context, node, /*index=*/0, &temp_index));
    TfLiteTensor* resolved_axis;
    TF_LITE_ENSURE_OK(
        context, GetTemporarySafe(context, node, /*index=*/1, &resolved_axis));
    TfLiteTensor* temp_sum;
    TF_LITE_ENSURE_OK(context,
                      GetTemporarySafe(context, node, /*index=*/2, &temp_sum));
    // Resize the output tensor if the output tensor is dynamic.
    if (IsDynamicTensor(op_context.output)) {
      TF_LITE_ENSURE_OK(context,
                        ResizeTempAxis(context, &op_context, resolved_axis));
      TF_LITE_ENSURE_OK(context, ResizeOutputTensor(context, &op_context));
      TF_LITE_ENSURE_OK(context,
                        ResizeTempAccum(context, &op_context, temp_sum));
    }

    if (input->type == kTfLiteUInt8) {
      QuantizedMeanOrSum<uint8_t>(context, &op_context,
                                  GetTensorData<int>(temp_index),
                                  GetTensorData<int>(resolved_axis),
                                  GetTensorData<int32_t>(temp_sum), kernel_type,
                                  /*compute_sum=*/true);
    } else {
      QuantizedMeanOrSum<int8_t>(context, &op_context,
                                 GetTensorData<int>(temp_index),
                                 GetTensorData<int>(resolved_axis),
                                 GetTensorData<int32_t>(temp_sum), kernel_type,
                                 /*compute_sum=*/true);
    }
  } else {
    return EvalGeneric<kernel_type, kSum>(context, node);
  }

  return kTfLiteOk;
}

template <KernelType kernel_type, typename T>
TfLiteStatus EvalQuantizedProd(TfLiteContext* context, TfLiteNode* node,
                               OpContext* op_context) {
  OpData* data = reinterpret_cast<OpData*>(node->user_data);

  const int64_t num_axis = NumElements(op_context->axis);
  TfLiteTensor* temp_index;
  TF_LITE_ENSURE_OK(context,
                    GetTemporarySafe(context, node, /*index=*/0, &temp_index));
  TfLiteTensor* resolved_axis;
  TF_LITE_ENSURE_OK(
      context, GetTemporarySafe(context, node, /*index=*/1, &resolved_axis));
  TfLiteTensor* temp_prod;
  TF_LITE_ENSURE_OK(context,
                    GetTemporarySafe(context, node, /*index=*/2, &temp_prod));
  TfLiteTensor* normalized_dims;
  TF_LITE_ENSURE_OK(
      context, GetTemporarySafe(context, node, /*index=*/3, &normalized_dims));
  const TfLiteTensor* input = op_context->input;
  TfLiteTensor* output = op_context->output;

  // Return early when input shape has zero dim.
  for (int i = 0; i < input->dims->size; ++i) {
    if (input->dims->data[i] == 0) return kTfLiteOk;
  }

  if (IsDynamicTensor(normalized_dims)) {
    TF_LITE_ENSURE_OK(context,
                      ResizeTempDims(context, op_context, normalized_dims));
  }
  // Resize the output tensor if the output tensor is dynamic.
  if (IsDynamicTensor(output)) {
    TF_LITE_ENSURE_OK(context,
                      ResizeTempAxis(context, op_context, resolved_axis));
    TF_LITE_ENSURE_OK(context, ResizeOutputTensor(context, op_context));
    TF_LITE_ENSURE_OK(context, ResizeTempAccum(context, op_context, temp_prod));

    const int input_size = GetTensorShape(input).FlatSize();
    const int output_size = GetTensorShape(output).FlatSize();
    TF_LITE_ENSURE(context, input_size != 0);
    TF_LITE_ENSURE(context, output_size != 0);

    const int reduced_axis_size = input_size / output_size;
    const double scaling = GetQuantProdScaling(
        static_cast<double>(input->params.scale),
        static_cast<double>(output->params.scale), reduced_axis_size);
    QuantizeMultiplier(scaling, &data->multiplier, &data->shift);
  }

  if (kernel_type == kReference) {
    TF_LITE_ENSURE(
        context,
        reference_ops::QuantizedReduceProd<T>(
            GetTensorData<T>(input), input->params.zero_point,
            GetTensorShape(input), GetTensorData<T>(output),
            output->params.zero_point, GetTensorShape(output),
            GetTensorData<int>(op_context->axis), num_axis,
            op_context->params->keep_dims, GetTensorData<int>(temp_index),
            GetTensorData<int>(resolved_axis), GetTensorData<int32>(temp_prod),
            data->multiplier, data->shift));
    return kTfLiteOk;
  } else {
    TF_LITE_ENSURE(
        context,
        optimized_ops::QuantizedReduceProd<T>(
            GetTensorData<T>(input), input->params.zero_point,
            GetTensorShape(input), GetTensorData<T>(output),
            output->params.zero_point, GetTensorShape(output),
            GetTensorData<int>(op_context->axis), num_axis,
            GetTensorData<int>(resolved_axis),
            GetTensorData<int>(normalized_dims),
            GetTensorData<int32>(temp_prod), data->multiplier, data->shift));
    return kTfLiteOk;
  }
}

template <KernelType kernel_type>
TfLiteStatus EvalProd(TfLiteContext* context, TfLiteNode* node) {
  OpContext op_context(context, node);
  // As we need to support both quantized and non-quantized int8/int16 inputs,
  // we separate the evaluation between EvalQuantizedProd for quantized
  // int8/int16 inputs and EvalGeneric for non-quantized int8/int16 (and
  // other non-quantized types).
  if (op_context.input->quantization.type != kTfLiteNoQuantization) {
    if (op_context.input->type == kTfLiteInt8) {
      return EvalQuantizedProd<kernel_type, int8_t>(context, node, &op_context);
    } else if (op_context.input->type == kTfLiteInt16) {
      return EvalQuantizedProd<kernel_type, int16_t>(context, node,
                                                     &op_context);
    } else {
      TF_LITE_KERNEL_LOG(context, "Unsupported quantized data type: %d",
                         op_context.input->type);
      return kTfLiteError;
    }
  } else {
    return EvalGeneric<kernel_type, kProd>(context, node);
  }
}

}  // namespace reduce

using ops::builtin::reduce::ReduceType;

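// Each reduce operator is registered in two flavors: a reference kernel and a
// generic optimized kernel. The unsuffixed Register_* entry points at the
// bottom of this file default to the optimized variants.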
TfLiteRegistration* Register_MEAN_OPT() {
  static TfLiteRegistration r = {reduce::Init, reduce::Free,
                                 reduce::PrepareMeanOrSum,
                                 reduce::EvalMean<reduce::kGenericOptimized>};
  return &r;
}

TfLiteRegistration* Register_MEAN_REF() {
  static TfLiteRegistration r = {reduce::Init, reduce::Free,
                                 reduce::PrepareMeanOrSum,
                                 reduce::EvalMean<reduce::kReference>};
  return &r;
}

TfLiteRegistration* Register_SUM_REF() {
  static TfLiteRegistration r = {reduce::Init, reduce::Free,
                                 reduce::PrepareMeanOrSum,
                                 reduce::EvalSum<reduce::kReference>};
  return &r;
}

TfLiteRegistration* Register_SUM_OPT() {
  static TfLiteRegistration r = {reduce::Init, reduce::Free,
                                 reduce::PrepareMeanOrSum,
                                 reduce::EvalSum<reduce::kGenericOptimized>};
  return &r;
}

TfLiteRegistration* Register_REDUCE_PROD_REF() {
  static TfLiteRegistration r = {reduce::Init, reduce::Free,
                                 reduce::PrepareProd,
                                 reduce::EvalProd<reduce::kReference>};
  return &r;
}

TfLiteRegistration* Register_REDUCE_PROD_OPT() {
  static TfLiteRegistration r = {reduce::Init, reduce::Free,
                                 reduce::PrepareProd,
                                 reduce::EvalProd<reduce::kGenericOptimized>};
  return &r;
}

TfLiteRegistration* Register_REDUCE_MAX_REF() {
  static TfLiteRegistration r = {
      reduce::Init, reduce::Free, reduce::PrepareSimple,
      reduce::EvalGeneric<reduce::kReference, ReduceType::kMax>};
  return &r;
}

TfLiteRegistration* Register_REDUCE_MAX_OPT() {
  static TfLiteRegistration r = {
      reduce::Init, reduce::Free, reduce::PrepareSimple,
      reduce::EvalGeneric<reduce::kGenericOptimized, ReduceType::kMax>};
  return &r;
}

TfLiteRegistration* Register_REDUCE_MIN_REF() {
  static TfLiteRegistration r = {
      reduce::Init, reduce::Free, reduce::PrepareSimple,
      reduce::EvalGeneric<reduce::kReference, ReduceType::kMin>};
  return &r;
}

TfLiteRegistration* Register_REDUCE_MIN_OPT() {
  static TfLiteRegistration r = {
      reduce::Init, reduce::Free, reduce::PrepareSimple,
      reduce::EvalGeneric<reduce::kGenericOptimized, ReduceType::kMin>};
  return &r;
}

TfLiteRegistration* Register_REDUCE_ANY_REF() {
  static TfLiteRegistration r = {
      reduce::Init, reduce::Free, reduce::PrepareAllOrAny,
      reduce::EvalGeneric<reduce::kReference, ReduceType::kAny>};
  return &r;
}

TfLiteRegistration* Register_REDUCE_ANY_OPT() {
  static TfLiteRegistration r = {
      reduce::Init, reduce::Free, reduce::PrepareAllOrAny,
      reduce::EvalGeneric<reduce::kGenericOptimized, ReduceType::kAny>};
  return &r;
}

TfLiteRegistration* Register_REDUCE_ALL_REF() {
  static TfLiteRegistration r = {
      reduce::Init, reduce::Free, reduce::PrepareAllOrAny,
      reduce::EvalGeneric<reduce::kReference, ReduceType::kAll>};
  return &r;
}

TfLiteRegistration* Register_REDUCE_ALL_OPT() {
  static TfLiteRegistration r = {
      reduce::Init, reduce::Free, reduce::PrepareAllOrAny,
      reduce::EvalGeneric<reduce::kGenericOptimized, ReduceType::kAll>};
  return &r;
}

TfLiteRegistration* Register_MEAN() { return Register_MEAN_OPT(); }

TfLiteRegistration* Register_SUM() { return Register_SUM_OPT(); }
TfLiteRegistration* Register_REDUCE_PROD() {
  return Register_REDUCE_PROD_OPT();
}
TfLiteRegistration* Register_REDUCE_MAX() { return Register_REDUCE_MAX_OPT(); }
TfLiteRegistration* Register_REDUCE_MIN() { return Register_REDUCE_MIN_OPT(); }
TfLiteRegistration* Register_REDUCE_ANY() { return Register_REDUCE_ANY_OPT(); }
TfLiteRegistration* Register_REDUCE_ALL() { return Register_REDUCE_ALL_OPT(); }

}  // namespace builtin
}  // namespace ops
}  // namespace tflite