/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/lite/kernels/internal/reference/concatenation.h"

#include <stdint.h>

#include <limits>

#include "tensorflow/lite/c/builtin_op_data.h"
#include "tensorflow/lite/c/common.h"
#include "tensorflow/lite/kernels/internal/compatibility.h"
#include "tensorflow/lite/kernels/internal/optimized/optimized_ops.h"
#include "tensorflow/lite/kernels/internal/reference/reference_ops.h"
#include "tensorflow/lite/kernels/internal/tensor.h"
#include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
#include "tensorflow/lite/kernels/internal/types.h"
#include "tensorflow/lite/kernels/kernel_util.h"

namespace tflite {
namespace ops {
namespace builtin {
namespace concatenation {

// This file has two implementations of Concatenation.
enum KernelType {
  kReference,
  kGenericOptimized,
};

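// Evaluates the concatenation for one node, selecting the reference or the
// optimized kernel via the kernel_type template parameter. The per-type
// macros below gather all input tensors into a VectorOfTensors and issue a
// single kernel call.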
template <KernelType kernel_type>
TfLiteStatus EvalImpl(TfLiteContext* context, TfLiteNode* node, int axis,
                      TfLiteTensor* output) {
// TODO(ahentz): Creating 'all_inputs' below is not very efficient. We should
// allocate and populate these during Prepare().
// TODO(ycling): Activation function parameter is ignored. For now we don't
// have a model with a Concatenation with fused activation function.
#define TF_LITE_CONCATENATION(scalar)                                         \
  {                                                                           \
    VectorOfTensors<scalar> all_inputs(*context, *node->inputs);              \
    tflite::ConcatenationParams op_params;                                    \
    op_params.axis = axis;                                                    \
    op_params.inputs_count = node->inputs->size;                              \
    if (kernel_type == kReference) {                                          \
      reference_ops::Concatenation(op_params, all_inputs.shapes(),            \
                                   all_inputs.data(), GetTensorShape(output), \
                                   GetTensorData<scalar>(output));            \
    } else {                                                                  \
      optimized_ops::Concatenation(op_params, all_inputs.shapes(),            \
                                   all_inputs.data(), GetTensorShape(output), \
                                   GetTensorData<scalar>(output));            \
    }                                                                         \
  }

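// uint8 inputs may carry quantization parameters that differ from the
// output's, so the quantized path calls ConcatenationWithScaling, which
// rescales each input to the output's scale and zero point while copying.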
#define TF_LITE_CONCATENATION_QUANTIZED()                           \
  {                                                                 \
    VectorOfQuantizedTensors all_inputs(*context, *node->inputs);   \
    tflite::ConcatenationParams op_params;                          \
    op_params.axis = axis;                                          \
    op_params.input_zeropoint = all_inputs.zero_point();            \
    op_params.input_scale = all_inputs.scale();                     \
    op_params.inputs_count = node->inputs->size;                    \
    op_params.output_zeropoint = output->params.zero_point;         \
    op_params.output_scale = output->params.scale;                  \
    if (kernel_type == kReference) {                                \
      reference_ops::ConcatenationWithScaling(                      \
          op_params, all_inputs.shapes(), all_inputs.data(),        \
          GetTensorShape(output), GetTensorData<uint8_t>(output));  \
    } else {                                                        \
      optimized_ops::ConcatenationWithScaling(                      \
          op_params, all_inputs.shapes(), all_inputs.data(),        \
          GetTensorShape(output), GetTensorData<uint8_t>(output));  \
    }                                                               \
  }

  switch (output->type) {  // We already know the input and output types match.
    case kTfLiteFloat32:
      TF_LITE_CONCATENATION(float);
      break;
    case kTfLiteInt32:
      TF_LITE_CONCATENATION(int32_t);
      break;
    case kTfLiteUInt8:
      TF_LITE_CONCATENATION_QUANTIZED();
      break;
    case kTfLiteInt8:
      TF_LITE_CONCATENATION(int8_t);
      break;
    case kTfLiteInt64:
      TF_LITE_CONCATENATION(int64_t);
      break;
    case kTfLiteInt16:
      TF_LITE_CONCATENATION(int16_t);
      break;
    case kTfLiteBool:
      TF_LITE_CONCATENATION(bool);
      break;
    default:
      TF_LITE_KERNEL_LOG(context, "Type '%s' is not supported currently.",
                         TfLiteTypeGetName(output->type));
      return kTfLiteError;
  }

#undef TF_LITE_CONCATENATION_QUANTIZED
#undef TF_LITE_CONCATENATION

  return kTfLiteOk;
}

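// Validates the node and computes the output shape: all inputs must share the
// same type and rank and must match on every dimension except 'axis', and the
// output's 'axis' dimension is the sum of the inputs' 'axis' dimensions. When
// every input is constant or persistent, the result is also computed here,
// once, and the output is marked persistent so Eval can skip the work.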
TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
  auto* params =
      reinterpret_cast<TfLiteConcatenationParams*>(node->builtin_data);
  int axis = params->axis;
  int num_inputs = node->inputs->size;

  // The number of dimensions of the input tensors must match, and all
  // dimensions except 'axis' must be equal.
  const TfLiteTensor* t0;
  TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, 0, &t0));
  TfLiteType input_type = t0->type;
  if (axis < 0) axis += t0->dims->size;
  TF_LITE_ENSURE(context, axis >= 0);
  TF_LITE_ENSURE(context, axis < t0->dims->size);

  TF_LITE_ENSURE_EQ(context, params->activation, kTfLiteActNone);
  TF_LITE_ENSURE(context,
                 input_type == kTfLiteFloat32 || input_type == kTfLiteUInt8 ||
                     input_type == kTfLiteInt8 || input_type == kTfLiteInt16 ||
                     input_type == kTfLiteInt32 || input_type == kTfLiteInt64 ||
                     input_type == kTfLiteBool);

  // Output dimensions will match input dimensions, except 'axis', which
  // will be the sum of the inputs' 'axis' dimensions.
  int sum_axis = t0->dims->data[axis];
  for (int i = 1; i < num_inputs; ++i) {
    const TfLiteTensor* t;
    TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, i, &t));
    TF_LITE_ENSURE_EQ(context, t->dims->size, t0->dims->size);
    TF_LITE_ENSURE_EQ(context, t->type, input_type);
    for (int d = 0; d < t0->dims->size; ++d) {
      if (d == axis) {
        // Avoid integer overflow in the sum_axis accumulation below.
        TF_LITE_ENSURE(context, t->dims->data[axis] >= 0);
        TF_LITE_ENSURE(context, t->dims->data[axis] <=
                                    std::numeric_limits<int>::max() - sum_axis);
        sum_axis += t->dims->data[axis];
      } else {
        TF_LITE_ENSURE_EQ(context, t->dims->data[d], t0->dims->data[d]);
      }
    }
  }

  TfLiteIntArray* output_size = TfLiteIntArrayCreate(t0->dims->size);
  for (int d = 0; d < t0->dims->size; ++d) {
    output_size->data[d] = (d == axis) ? sum_axis : t0->dims->data[d];
  }

  TfLiteTensor* output;
  TF_LITE_ENSURE_OK(context, GetOutputSafe(context, node, 0, &output));
  TF_LITE_ENSURE_TYPES_EQ(context, output->type, input_type);

  if (input_type == kTfLiteInt8) {
    // Make sure no re-scaling is needed for the int8 quantized kernel: every
    // input must already use the output's scale and zero point. This is a
    // restriction we introduced for int8 kernels.
    for (int i = 0; i < node->inputs->size; ++i) {
      const TfLiteTensor* t;
      TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, i, &t));
      TF_LITE_ENSURE_EQ(context, t->params.scale, output->params.scale);
      TF_LITE_ENSURE_EQ(context, t->params.zero_point,
                        output->params.zero_point);
    }
  }

  if (input_type == kTfLiteInt16) {
    // Make sure that all int16 inputs, and the output, have a zero-point of
    // zero.
    for (int i = 0; i < node->inputs->size; ++i) {
      const TfLiteTensor* t;
      TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, i, &t));
      TF_LITE_ENSURE_EQ(context, t->params.zero_point, 0);
    }
    TF_LITE_ENSURE_EQ(context, output->params.zero_point, 0);
  }

  // Check whether the output can be computed now, at prepare time: this is
  // possible only if every input is a constant or persistent tensor.
  bool all_inputs_at_prepare = true;
  for (int i = 0; i < num_inputs; ++i) {
    const TfLiteTensor* t;
    TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, i, &t));
    if (!IsConstantOrPersistentTensor(t)) {
      all_inputs_at_prepare = false;
      break;
    }
  }
  if (all_inputs_at_prepare) {
    SetTensorToPersistentRo(output);
    TF_LITE_ENSURE_OK(context,
                      context->ResizeTensor(context, output, output_size));
    return EvalImpl<kReference>(context, node, axis, output);
  }
  return context->ResizeTensor(context, output, output_size);
}

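// Runtime entry point. If the output was already computed and marked
// persistent during Prepare (all-constant inputs), this returns immediately;
// otherwise it normalizes a negative axis and runs the selected kernel.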
template <KernelType kernel_type>
TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
  auto* params =
      reinterpret_cast<TfLiteConcatenationParams*>(node->builtin_data);
  int axis = params->axis;
  TfLiteTensor* output;
  TF_LITE_ENSURE_OK(context, GetOutputSafe(context, node, 0, &output));
  if (IsConstantOrPersistentTensor(output)) {
    // Output is computed in Prepare.
    return kTfLiteOk;
  }
  if (axis < 0) axis += output->dims->size;

  return EvalImpl<kernel_type>(context, node, axis, output);
}

}  // namespace concatenation

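// A TfLiteRegistration lists {init, free, prepare, invoke}; Concatenation
// keeps no per-node state, so init and free are null.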
TfLiteRegistration* Register_CONCATENATION_REF() {
  static TfLiteRegistration r = {
      nullptr, nullptr, concatenation::Prepare,
      concatenation::Eval<concatenation::kReference>};
  return &r;
}

TfLiteRegistration* Register_CONCATENATION_GENERIC_OPT() {
  static TfLiteRegistration r = {
      nullptr, nullptr, concatenation::Prepare,
      concatenation::Eval<concatenation::kGenericOptimized>};
  return &r;
}

TfLiteRegistration* Register_CONCATENATION() {
  // TODO(ahentz): It turns out the two versions of Concatenation are almost
  // identical, so we should consider removing one.
  return Register_CONCATENATION_GENERIC_OPT();
}

}  // namespace builtin
}  // namespace ops
}  // namespace tflite