1 | /* Copyright 2019 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | #include <stdint.h> |
16 | |
17 | #include <algorithm> |
18 | |
19 | #include "tensorflow/lite/c/common.h" |
20 | #include "tensorflow/lite/kernels/cpu_backend_threadpool.h" |
21 | #include "tensorflow/lite/kernels/internal/optimized/optimized_ops.h" |
22 | #include "tensorflow/lite/kernels/internal/tensor.h" |
23 | #include "tensorflow/lite/kernels/internal/tensor_ctypes.h" |
24 | #include "tensorflow/lite/kernels/kernel_util.h" |
25 | |
26 | namespace tflite { |
27 | namespace ops { |
28 | namespace builtin { |
29 | namespace add_n { |
30 | |
// AddN is variadic: inputs occupy indices [0, NumInputs); kInputTensor1 is
// the first of them and is used as the reference for shape/type checks.
constexpr int kInputTensor1 = 0;
constexpr int kOutputTensor = 0;

// Per-node state allocated in Init and released in Free.
struct OpData {
  // The index of the temporary tensor where temporary accumulations are kept.
  int scratch_tensor_index;
};
38 | |
39 | void* Init(TfLiteContext* context, const char* buffer, size_t length) { |
40 | auto* op_data = new OpData(); |
41 | context->AddTensors(context, 1, &op_data->scratch_tensor_index); |
42 | return op_data; |
43 | } |
44 | |
45 | void Free(TfLiteContext* context, void* buffer) { |
46 | delete reinterpret_cast<OpData*>(buffer); |
47 | } |
48 | |
// Validates the node, sets up the scratch tensor for multithreaded
// accumulation, and sizes the output to match the first input.
TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
  // AddN requires at least two inputs and exactly one output.
  int num_inputs = NumInputs(node);
  TF_LITE_ENSURE(context, num_inputs >= 2);
  TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);

  const TfLiteTensor* input1;
  TF_LITE_ENSURE_OK(context,
                    GetInputSafe(context, node, kInputTensor1, &input1));
  TfLiteTensor* output;
  TF_LITE_ENSURE_OK(context,
                    GetOutputSafe(context, node, kOutputTensor, &output));
  // The output inherits the element type of the first input; every other
  // input is checked against input1 below.
  output->type = input1->type;

  // Register the tensor slot reserved in Init as this node's only temporary
  // so the memory planner allocates it. The temporaries array must be
  // installed before GetTemporarySafe can resolve index 0.
  OpData* op_data = reinterpret_cast<OpData*>(node->user_data);
  TfLiteIntArrayFree(node->temporaries);
  node->temporaries = TfLiteIntArrayCreate(1);
  node->temporaries->data[0] = op_data->scratch_tensor_index;
  TfLiteTensor* scratch_tensor;
  TF_LITE_ENSURE_OK(
      context, GetTemporarySafe(context, node, /*index=*/0, &scratch_tensor));
  scratch_tensor->type = input1->type;
  scratch_tensor->allocation_type = kTfLiteArenaRw;

  CpuBackendContext* cpu_backend_context =
      CpuBackendContext::GetFromContext(context);
  // Choose the proper number of threads so that:
  // (1) Each thread gets at least two tensors (1 if we only have 1 input
  // tensor).
  // (2) Total thread_count should be bounded by the maximum allowed threads.
  // (3) Tensors are distributed evenly across different threads.
  const int thread_count =
      std::min(std::max(1, static_cast<int>(num_inputs) / 2),
               cpu_backend_context->max_num_threads());

  // Each worker thread accumulates into its own slice, so the scratch buffer
  // holds `thread_count` copies of one input's worth of elements.
  TfLiteIntArray* scratch_shape = TfLiteIntArrayCreate(1);
  scratch_shape->data[0] = thread_count * NumElements(input1);
  TF_LITE_ENSURE_OK(
      context, context->ResizeTensor(context, scratch_tensor, scratch_shape));

  // Check that all input tensors have the same shape and type.
  for (int i = kInputTensor1 + 1; i < num_inputs; ++i) {
    const TfLiteTensor* input;
    TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, i, &input));
    TF_LITE_ENSURE(context, HaveSameShapes(input1, input));
    TF_LITE_ENSURE_TYPES_EQ(context, input1->type, input->type);
  }

  // Use the first input node's dimension to be the dimension of the output
  // node.
  TfLiteIntArray* input1_dims = input1->dims;
  TfLiteIntArray* output_dims = TfLiteIntArrayCopy(input1_dims);
  return context->ResizeTensor(context, output, output_dims);
}
102 | |
103 | template <typename T> |
104 | TfLiteStatus EvalAddN(TfLiteContext* context, TfLiteNode* node) { |
105 | // TODO(haoliang): Initialize all_inputs only once during init. |
106 | VectorOfTensors<T> all_inputs(*context, *node->inputs); |
107 | // Safe to use unchecked since caller checks that tensor is valid |
108 | TfLiteTensor* output = GetOutput(context, node, kOutputTensor); |
109 | int num_inputs = NumInputs(node); |
110 | // Safe to use unchecked since caller checks that tensor is valid |
111 | const TfLiteTensor* input1 = GetInput(context, node, kInputTensor1); |
112 | |
113 | // Fetch backend context and number of threads. |
114 | CpuBackendContext* cpu_backend_context = |
115 | CpuBackendContext::GetFromContext(context); |
116 | TfLiteTensor* scratch_tensor; |
117 | TF_LITE_ENSURE_OK(context, |
118 | GetTemporarySafe(context, node, 0, &scratch_tensor)); |
119 | optimized_ops::AddN<T>(GetTensorShape(input1), num_inputs, all_inputs.data(), |
120 | GetTensorData<T>(output), |
121 | GetTensorData<T>(scratch_tensor), cpu_backend_context); |
122 | return kTfLiteOk; |
123 | } |
124 | |
125 | TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { |
126 | const TfLiteTensor* input1; |
127 | TF_LITE_ENSURE_OK(context, |
128 | GetInputSafe(context, node, kInputTensor1, &input1)); |
129 | TfLiteTensor* output; |
130 | TF_LITE_ENSURE_OK(context, |
131 | GetOutputSafe(context, node, kOutputTensor, &output)); |
132 | if (output->type == kTfLiteFloat32) { |
133 | TF_LITE_ENSURE_OK(context, EvalAddN<float>(context, node)); |
134 | } else if (output->type == kTfLiteInt32) { |
135 | TF_LITE_ENSURE_OK(context, EvalAddN<int32_t>(context, node)); |
136 | } else { |
137 | TF_LITE_KERNEL_LOG(context, "AddN only supports FLOAT32|INT32 now, got %s." , |
138 | TfLiteTypeGetName(output->type)); |
139 | return kTfLiteError; |
140 | } |
141 | return kTfLiteOk; |
142 | } |
143 | |
144 | } // namespace add_n |
145 | |
146 | TfLiteRegistration* Register_ADD_N() { |
147 | static TfLiteRegistration r = {add_n::Init, add_n::Free, add_n::Prepare, |
148 | add_n::Eval}; |
149 | return &r; |
150 | } |
151 | |
152 | } // namespace builtin |
153 | } // namespace ops |
154 | } // namespace tflite |
155 | |