1 | /* Copyright 2017 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | #include "tensorflow/lite/kernels/internal/reference/concatenation.h" |
16 | |
17 | #include <stdint.h> |
18 | #include <limits> |
19 | |
20 | #include "tensorflow/lite/c/builtin_op_data.h" |
21 | #include "tensorflow/lite/c/common.h" |
22 | #include "tensorflow/lite/kernels/internal/compatibility.h" |
23 | #include "tensorflow/lite/kernels/internal/optimized/optimized_ops.h" |
24 | #include "tensorflow/lite/kernels/internal/reference/reference_ops.h" |
25 | #include "tensorflow/lite/kernels/internal/tensor.h" |
26 | #include "tensorflow/lite/kernels/internal/tensor_ctypes.h" |
27 | #include "tensorflow/lite/kernels/internal/types.h" |
28 | #include "tensorflow/lite/kernels/kernel_util.h" |
29 | |
30 | namespace tflite { |
31 | namespace ops { |
32 | namespace builtin { |
33 | namespace concatenation { |
34 | |
// This file has two implementations of Concatenation.
enum KernelType {
  kReference,         // Dispatches to reference_ops::Concatenation.
  kGenericOptimized,  // Dispatches to optimized_ops::Concatenation.
};
40 | |
// Concatenates every input tensor of 'node' into 'output' along 'axis'.
// 'axis' must already be non-negative (callers resolve negative axes before
// calling). Dispatches on output->type; in/out types are assumed equal (see
// the switch comment below). Returns kTfLiteError for unsupported types.
template <KernelType kernel_type>
TfLiteStatus EvalImpl(TfLiteContext* context, TfLiteNode* node, int axis,
                      TfLiteTensor* output) {
// TODO(ahentz): Creating 'all_inputs' below is not very efficient. We should
// allocate and populate these during Prepare().
// TODO(ycling): Activation function parameter is ignored. For now we don't have
// a model with a Concatenation with fused activation function.
//
// Plain concatenation for types that need no re-quantization (int8 inputs are
// required by Prepare() to share the output's scale/zero-point, so they take
// this path too). 'kernel_type' is a compile-time constant, so the
// reference-vs-optimized branch is resolved at compile time.
#define TF_LITE_CONCATENATION(scalar)                                         \
  {                                                                           \
    VectorOfTensors<scalar> all_inputs(*context, *node->inputs);              \
    tflite::ConcatenationParams op_params;                                    \
    op_params.axis = axis;                                                    \
    op_params.inputs_count = node->inputs->size;                              \
    if (kernel_type == kReference) {                                          \
      reference_ops::Concatenation(op_params, all_inputs.shapes(),            \
                                   all_inputs.data(), GetTensorShape(output), \
                                   GetTensorData<scalar>(output));            \
    } else {                                                                  \
      optimized_ops::Concatenation(op_params, all_inputs.shapes(),            \
                                   all_inputs.data(), GetTensorShape(output), \
                                   GetTensorData<scalar>(output));            \
    }                                                                         \
  }

// uint8 concatenation with scaling: inputs may carry different quantization
// parameters, so the kernel re-quantizes each input into the output's
// scale/zero-point while copying.
#define TF_LITE_CONCATENATION_QUANTIZED()                         \
  {                                                               \
    VectorOfQuantizedTensors all_inputs(*context, *node->inputs); \
    tflite::ConcatenationParams op_params;                        \
    op_params.axis = axis;                                        \
    op_params.input_zeropoint = all_inputs.zero_point();          \
    op_params.input_scale = all_inputs.scale();                   \
    op_params.inputs_count = node->inputs->size;                  \
    op_params.output_zeropoint = output->params.zero_point;       \
    op_params.output_scale = output->params.scale;                \
    if (kernel_type == kReference) {                              \
      reference_ops::ConcatenationWithScaling(                    \
          op_params, all_inputs.shapes(), all_inputs.data(),      \
          GetTensorShape(output), GetTensorData<uint8>(output));  \
    } else {                                                      \
      optimized_ops::ConcatenationWithScaling(                    \
          op_params, all_inputs.shapes(), all_inputs.data(),      \
          GetTensorShape(output), GetTensorData<uint8>(output));  \
    }                                                             \
  }

  switch (output->type) {  // Already know in/outtypes are same.
    case kTfLiteFloat32:
      TF_LITE_CONCATENATION(float);
      break;
    case kTfLiteInt32:
      TF_LITE_CONCATENATION(int32);
      break;
    case kTfLiteUInt8:
      // Only uint8 goes through the rescaling path.
      TF_LITE_CONCATENATION_QUANTIZED();
      break;
    case kTfLiteInt8:
      TF_LITE_CONCATENATION(int8_t);
      break;
    case kTfLiteInt64:
      TF_LITE_CONCATENATION(int64_t);
      break;
    case kTfLiteInt16:
      TF_LITE_CONCATENATION(int16_t);
      break;
    case kTfLiteBool:
      TF_LITE_CONCATENATION(bool);
      break;
    default:
      TF_LITE_KERNEL_LOG(context, "Type '%s' is not supported currently." ,
                         TfLiteTypeGetName(output->type));
      return kTfLiteError;
  }

#undef TF_LITE_CONCATENATION_QUANTIZED
#undef TF_LITE_CONCATENATION

  return kTfLiteOk;
}
119 | |
120 | TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { |
121 | auto* params = |
122 | reinterpret_cast<TfLiteConcatenationParams*>(node->builtin_data); |
123 | int axis = params->axis; |
124 | int num_inputs = node->inputs->size; |
125 | |
126 | // The number of dimensions of the input tensors must match, and all |
127 | // dimensions except 'axis' must be equal. |
128 | const TfLiteTensor* t0; |
129 | TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, 0, &t0)); |
130 | TfLiteType input_type = t0->type; |
131 | if (axis < 0) axis += t0->dims->size; |
132 | TF_LITE_ENSURE(context, axis >= 0); |
133 | TF_LITE_ENSURE(context, axis < t0->dims->size); |
134 | |
135 | TF_LITE_ENSURE_EQ(context, params->activation, kTfLiteActNone); |
136 | TF_LITE_ENSURE(context, |
137 | input_type == kTfLiteFloat32 || input_type == kTfLiteUInt8 || |
138 | input_type == kTfLiteInt8 || input_type == kTfLiteInt16 || |
139 | input_type == kTfLiteInt32 || input_type == kTfLiteInt64 || |
140 | input_type == kTfLiteBool); |
141 | |
142 | // Output dimensions will match input dimensions, except 'axis', which |
143 | // will be the sum of inputs |
144 | int sum_axis = t0->dims->data[axis]; |
145 | for (int i = 1; i < num_inputs; ++i) { |
146 | const TfLiteTensor* t; |
147 | TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, i, &t)); |
148 | TF_LITE_ENSURE_EQ(context, t->dims->size, t0->dims->size); |
149 | TF_LITE_ENSURE_EQ(context, t->type, input_type); |
150 | for (int d = 0; d < t0->dims->size; ++d) { |
151 | if (d == axis) { |
152 | // Avoid integer overflow in sum_axis below |
153 | TF_LITE_ENSURE(context, t->dims->data[axis] >= 0); |
154 | TF_LITE_ENSURE(context, t->dims->data[axis] <= |
155 | std::numeric_limits<int>::max() - sum_axis); |
156 | sum_axis += t->dims->data[axis]; |
157 | } else { |
158 | TF_LITE_ENSURE_EQ(context, t->dims->data[d], t0->dims->data[d]); |
159 | } |
160 | } |
161 | } |
162 | |
163 | TfLiteIntArray* output_size = TfLiteIntArrayCreate(t0->dims->size); |
164 | for (int d = 0; d < t0->dims->size; ++d) { |
165 | output_size->data[d] = (d == axis) ? sum_axis : t0->dims->data[d]; |
166 | } |
167 | |
168 | TfLiteTensor* output; |
169 | TF_LITE_ENSURE_OK(context, GetOutputSafe(context, node, 0, &output)); |
170 | TF_LITE_ENSURE_TYPES_EQ(context, output->type, input_type); |
171 | |
172 | if (input_type == kTfLiteInt8) { |
173 | // Make sure there is no re-scaling needed for Int8 quantized kernel. This |
174 | // is a restriction we introduced to Int8 kernels. |
175 | VectorOfTensors<int8_t> all_inputs(*context, *node->inputs); |
176 | for (int i = 0; i < node->inputs->size; ++i) { |
177 | const TfLiteTensor* t; |
178 | TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, i, &t)); |
179 | TF_LITE_ENSURE_EQ(context, t->params.scale, output->params.scale); |
180 | TF_LITE_ENSURE_EQ(context, t->params.zero_point, |
181 | output->params.zero_point); |
182 | } |
183 | } |
184 | |
185 | if (input_type == kTfLiteInt16) { |
186 | // Make sure that all Int16 inputs have a null zero-point. |
187 | for (int i = 0; i < node->inputs->size; ++i) { |
188 | const TfLiteTensor* t = GetInput(context, node, i); |
189 | TF_LITE_ENSURE_EQ(context, t->params.zero_point, 0); |
190 | } |
191 | TF_LITE_ENSURE_EQ(context, output->params.zero_point, 0); |
192 | } |
193 | |
194 | // Check to see if we can calculate the output now. |
195 | bool all_inputs_at_prepare = true; |
196 | for (int i = 0; i < num_inputs; ++i) { |
197 | const TfLiteTensor* t; |
198 | TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, i, &t)); |
199 | if (!IsConstantOrPersistentTensor(t)) { |
200 | all_inputs_at_prepare = false; |
201 | break; |
202 | } |
203 | } |
204 | if (all_inputs_at_prepare) { |
205 | SetTensorToPersistentRo(output); |
206 | context->ResizeTensor(context, output, output_size); |
207 | return EvalImpl<kReference>(context, node, axis, output); |
208 | } |
209 | return context->ResizeTensor(context, output, output_size); |
210 | } |
211 | |
212 | template <KernelType kernel_type> |
213 | TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { |
214 | auto* params = |
215 | reinterpret_cast<TfLiteConcatenationParams*>(node->builtin_data); |
216 | int axis = params->axis; |
217 | TfLiteTensor* output; |
218 | TF_LITE_ENSURE_OK(context, GetOutputSafe(context, node, 0, &output)); |
219 | if (IsConstantOrPersistentTensor(output)) { |
220 | // Output is computed in Prepare. |
221 | return kTfLiteOk; |
222 | } |
223 | if (axis < 0) axis += output->dims->size; |
224 | |
225 | return EvalImpl<kernel_type>(context, node, axis, output); |
226 | } |
227 | |
228 | #undef TF_LITE_MACRO_DISPATCH |
229 | |
230 | } // namespace concatenation |
231 | |
232 | TfLiteRegistration* Register_CONCATENATION_REF() { |
233 | static TfLiteRegistration r = { |
234 | nullptr, nullptr, concatenation::Prepare, |
235 | concatenation::Eval<concatenation::kReference>}; |
236 | return &r; |
237 | } |
238 | |
239 | TfLiteRegistration* Register_CONCATENATION_GENERIC_OPT() { |
240 | static TfLiteRegistration r = { |
241 | nullptr, nullptr, concatenation::Prepare, |
242 | concatenation::Eval<concatenation::kGenericOptimized>}; |
243 | return &r; |
244 | } |
245 | |
246 | TfLiteRegistration* Register_CONCATENATION() { |
247 | // TODO(ahentz): It turns out the two versions of Concatenation are almost |
248 | // identical, so we should consider removing one. |
249 | return Register_CONCATENATION_GENERIC_OPT(); |
250 | } |
251 | |
252 | } // namespace builtin |
253 | } // namespace ops |
254 | } // namespace tflite |
255 | |