/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <stdint.h>

#include <algorithm>

#include "tensorflow/lite/c/common.h"
#include "tensorflow/lite/kernels/cpu_backend_context.h"
#include "tensorflow/lite/kernels/cpu_backend_threadpool.h"
#include "tensorflow/lite/kernels/internal/optimized/optimized_ops.h"
#include "tensorflow/lite/kernels/internal/tensor.h"
#include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
#include "tensorflow/lite/kernels/kernel_util.h"

namespace tflite {
namespace ops {
namespace builtin {
namespace add_n {

constexpr int kInputTensor1 = 0;
constexpr int kOutputTensor = 0;
struct OpData {
  // The index of the temporary tensor holding the per-thread partial
  // accumulations.
  int scratch_tensor_index;
};

void* Init(TfLiteContext* context, const char* buffer, size_t length) {
  auto* op_data = new OpData();
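  // Reserve one slot in the context's tensor list up front; Prepare later
  // wires this index in as the node's scratch temporary.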
  context->AddTensors(context, 1, &op_data->scratch_tensor_index);
  return op_data;
}

void Free(TfLiteContext* context, void* buffer) {
  delete reinterpret_cast<OpData*>(buffer);
}

TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
  int num_inputs = NumInputs(node);
  TF_LITE_ENSURE(context, num_inputs >= 2);
  TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);

  const TfLiteTensor* input1;
  TF_LITE_ENSURE_OK(context,
                    GetInputSafe(context, node, kInputTensor1, &input1));
  TfLiteTensor* output;
  TF_LITE_ENSURE_OK(context,
                    GetOutputSafe(context, node, kOutputTensor, &output));
  output->type = input1->type;

  OpData* op_data = reinterpret_cast<OpData*>(node->user_data);
  TfLiteIntArrayFree(node->temporaries);
  node->temporaries = TfLiteIntArrayCreate(1);
  node->temporaries->data[0] = op_data->scratch_tensor_index;
  TfLiteTensor* scratch_tensor;
  TF_LITE_ENSURE_OK(
      context, GetTemporarySafe(context, node, /*index=*/0, &scratch_tensor));
  scratch_tensor->type = input1->type;
  scratch_tensor->allocation_type = kTfLiteArenaRw;
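  // kTfLiteArenaRw means the scratch buffer is carved out of the
  // interpreter's shared memory arena, so the planner can reuse its space for
  // tensors whose lifetimes do not overlap with this node's execution.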

  CpuBackendContext* cpu_backend_context =
      CpuBackendContext::GetFromContext(context);
  // Choose the number of threads so that:
  // (1) Each thread gets at least two tensors (one if there is only a single
  //     input tensor).
  // (2) The total thread count is bounded by the maximum number of allowed
  //     threads.
  // (3) Tensors are distributed evenly across the threads.
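  // For example, with 9 inputs and an 8-thread maximum this yields
  // min(max(1, 9 / 2), 8) = 4 threads, each accumulating 2 or 3 tensors.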
  const int thread_count =
      std::min(std::max(1, static_cast<int>(num_inputs) / 2),
               cpu_backend_context->max_num_threads());

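  // Size the scratch buffer to hold one input-sized accumulation slice per
  // thread; each thread reduces its share of the inputs into its own slice.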
  TfLiteIntArray* scratch_shape = TfLiteIntArrayCreate(1);
  scratch_shape->data[0] = thread_count * NumElements(input1);
  TF_LITE_ENSURE_OK(
      context, context->ResizeTensor(context, scratch_tensor, scratch_shape));

  // Check that all input tensors have the same shape and type.
  for (int i = kInputTensor1 + 1; i < num_inputs; ++i) {
    const TfLiteTensor* input;
    TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, i, &input));
    TF_LITE_ENSURE(context, HaveSameShapes(input1, input));
    TF_LITE_ENSURE_TYPES_EQ(context, input1->type, input->type);
  }

  // Use the first input tensor's shape as the shape of the output tensor.
  TfLiteIntArray* input1_dims = input1->dims;
  TfLiteIntArray* output_dims = TfLiteIntArrayCopy(input1_dims);
  return context->ResizeTensor(context, output, output_dims);
}

template <typename T>
TfLiteStatus EvalAddN(TfLiteContext* context, TfLiteNode* node) {
  // TODO(haoliang): Initialize all_inputs only once during init.
  VectorOfTensors<T> all_inputs(*context, *node->inputs);
  // Safe to use unchecked since Prepare checks that the tensor is valid.
  TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
  int num_inputs = NumInputs(node);
  // Safe to use unchecked since Prepare checks that the tensor is valid.
  const TfLiteTensor* input1 = GetInput(context, node, kInputTensor1);

  // Fetch the backend context and the scratch buffer set up in Prepare.
  CpuBackendContext* cpu_backend_context =
      CpuBackendContext::GetFromContext(context);
  TfLiteTensor* scratch_tensor;
  TF_LITE_ENSURE_OK(context,
                    GetTemporarySafe(context, node, 0, &scratch_tensor));
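  // The optimized kernel splits the inputs across threads; each thread
  // accumulates its share into its own slice of the scratch buffer (hence
  // the thread_count * NumElements sizing in Prepare) before the per-thread
  // partial sums are combined into the output.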
  optimized_ops::AddN<T>(GetTensorShape(input1), num_inputs, all_inputs.data(),
                         GetTensorData<T>(output),
                         GetTensorData<T>(scratch_tensor), cpu_backend_context);
  return kTfLiteOk;
}

TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
  const TfLiteTensor* input1;
  TF_LITE_ENSURE_OK(context,
                    GetInputSafe(context, node, kInputTensor1, &input1));
  TfLiteTensor* output;
  TF_LITE_ENSURE_OK(context,
                    GetOutputSafe(context, node, kOutputTensor, &output));
  if (output->type == kTfLiteFloat32) {
    TF_LITE_ENSURE_OK(context, EvalAddN<float>(context, node));
  } else if (output->type == kTfLiteInt32) {
    TF_LITE_ENSURE_OK(context, EvalAddN<int32_t>(context, node));
  } else {
    TF_LITE_KERNEL_LOG(context, "AddN only supports FLOAT32|INT32 now, got %s.",
                       TfLiteTypeGetName(output->type));
    return kTfLiteError;
  }
  return kTfLiteOk;
}

}  // namespace add_n

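// The registration wires the kernel's lifecycle callbacks into the
// interpreter: {init, free, prepare, invoke}.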
TfLiteRegistration* Register_ADD_N() {
  static TfLiteRegistration r = {add_n::Init, add_n::Free, add_n::Prepare,
                                 add_n::Eval};
  return &r;
}

}  // namespace builtin
}  // namespace ops
}  // namespace tflite