/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <stdint.h>

#include "tensorflow/lite/c/builtin_op_data.h"
#include "tensorflow/lite/c/common.h"
#include "tensorflow/lite/kernels/internal/optimized/optimized_ops.h"
#include "tensorflow/lite/kernels/internal/reference/reference_ops.h"
#include "tensorflow/lite/kernels/internal/tensor.h"
#include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
#include "tensorflow/lite/kernels/internal/types.h"
#include "tensorflow/lite/kernels/kernel_util.h"
#include "tensorflow/lite/string_util.h"

namespace tflite {
namespace ops {
namespace builtin {
namespace gather {
constexpr int kInputTensor = 0;
constexpr int kInputPositions = 1;
constexpr int kOutputTensor = 0;

TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
  TF_LITE_ENSURE_EQ(context, NumInputs(node), 2);
  TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);

  const auto* params =
      reinterpret_cast<const TfLiteGatherParams*>(node->builtin_data);
  const TfLiteTensor* input;
  TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kInputTensor, &input));
  const TfLiteTensor* positions;
  TF_LITE_ENSURE_OK(context,
                    GetInputSafe(context, node, kInputPositions, &positions));
  TfLiteTensor* output;
  TF_LITE_ENSURE_OK(context,
                    GetOutputSafe(context, node, kOutputTensor, &output));

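  // Positions must be an int32 or int64 tensor.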
  switch (positions->type) {
    case kTfLiteInt64:
    case kTfLiteInt32:
      break;
    default:
      TF_LITE_KERNEL_LOG(context,
                         "Positions of type '%s' are not supported by gather.",
                         TfLiteTypeGetName(positions->type));
      return kTfLiteError;
  }

  // The output has the same type as the input.
  output->type = input->type;

  // Check conditions for different types.
  switch (input->type) {
    case kTfLiteFloat32:
    case kTfLiteUInt8:
    case kTfLiteInt8:
    case kTfLiteInt16:
    case kTfLiteInt64:
    case kTfLiteInt32:
    case kTfLiteBool:
      break;
    case kTfLiteString: {
      // Only 1D input is supported.
      TF_LITE_ENSURE_EQ(context, NumDimensions(input), 1);
    } break;
    default:
      TF_LITE_KERNEL_LOG(context, "Type '%s' is not supported by gather.",
                         TfLiteTypeGetName(input->type));
      return kTfLiteError;
  }

  int axis = params->axis;
  if (axis < 0) {
    axis += NumDimensions(input);
  }
  TF_LITE_ENSURE(context, 0 <= axis && axis < NumDimensions(input));

  int batch_dims = params->batch_dims;
  // batch_dims must be in the range [-rank(positions), rank(positions)].
  // A negative batch_dims is normalized by adding the rank of positions.
  if (batch_dims < 0) {
    batch_dims += NumDimensions(positions);
  }
  TF_LITE_ENSURE(context, batch_dims <= axis);
  TF_LITE_ENSURE(context, 0 <= batch_dims && batch_dims < NumDimensions(input));
  TF_LITE_ENSURE(context, batch_dims <= NumDimensions(positions));
  for (int i = 0; i < batch_dims; ++i) {
    TF_LITE_ENSURE_EQ(context, input->dims->data[i], positions->dims->data[i]);
  }

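  // Output shape is: input.shape[:axis] + positions.shape[batch_dims:] +
  // input.shape[axis + 1:].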
  const int num_dimensions =
      NumDimensions(input) + NumDimensions(positions) - 1 - batch_dims;
  TfLiteIntArray* output_shape = TfLiteIntArrayCreate(num_dimensions);
  int output_index = 0;
  for (int i = 0; i < axis; ++i) {
    output_shape->data[output_index++] = input->dims->data[i];
  }
  for (int i = batch_dims; i < positions->dims->size; ++i) {
    output_shape->data[output_index++] = positions->dims->data[i];
  }
  for (int i = axis + 1; i < input->dims->size; ++i) {
    output_shape->data[output_index++] = input->dims->data[i];
  }
  return context->ResizeTensor(context, output, output_shape);
}

template <typename InputT, typename PositionsT>
TfLiteStatus Gather(TfLiteContext* context, const TfLiteGatherParams& params,
                    const TfLiteTensor* input, const TfLiteTensor* positions,
                    TfLiteTensor* output) {
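  // Reject negative positions: this kernel only supports non-negative indices.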
  const PositionsT* indexes = GetTensorData<PositionsT>(positions);
  bool indices_has_only_positive_elements = true;
  const size_t num_indices = positions->bytes / sizeof(PositionsT);
  for (size_t i = 0; i < num_indices; i++) {
    if (indexes[i] < 0) {
      indices_has_only_positive_elements = false;
      break;
    }
  }
  TF_LITE_ENSURE(context, indices_has_only_positive_elements);

  tflite::GatherParams op_params;
  op_params.axis = params.axis;
  op_params.batch_dims = params.batch_dims;
  return optimized_ops::Gather(
      op_params, GetTensorShape(input), GetTensorData<InputT>(input),
      GetTensorShape(positions), GetTensorData<PositionsT>(positions),
      GetTensorShape(output), GetTensorData<InputT>(output));
}

template <typename PositionT>
TfLiteStatus GatherStrings(TfLiteContext* context, const TfLiteTensor* input,
                           const TfLiteTensor* positions,
                           TfLiteTensor* output) {
  DynamicBuffer buffer;

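  // As in Gather above, reject negative positions before copying any strings.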
  const PositionT* indexes = GetTensorData<PositionT>(positions);
  bool indices_has_only_positive_elements = true;
  const size_t num_indices = positions->bytes / sizeof(PositionT);
  for (size_t i = 0; i < num_indices; i++) {
    if (indexes[i] < 0) {
      indices_has_only_positive_elements = false;
      break;
    }
  }
  TF_LITE_ENSURE(context, indices_has_only_positive_elements);

  const PositionT num_strings = GetStringCount(input);
  const int num_indexes = NumElements(positions);

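  // Copy each selected string into the dynamic buffer, bounds-checking every
  // position against the number of strings in the input.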
  for (int i = 0; i < num_indexes; ++i) {
    const PositionT pos = indexes[i];
    TF_LITE_ENSURE(context, pos < num_strings);
    const auto string_ref = GetString(input, pos);
    buffer.AddString(string_ref.str, string_ref.len);
  }
  buffer.WriteToTensor(output, /*new_shape=*/nullptr);
  return kTfLiteOk;
}

TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
  const auto* params =
      reinterpret_cast<const TfLiteGatherParams*>(node->builtin_data);
  const TfLiteTensor* input;
  TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kInputTensor, &input));
  const TfLiteTensor* positions;
  TF_LITE_ENSURE_OK(context,
                    GetInputSafe(context, node, kInputPositions, &positions));
  TfLiteTensor* output;
  TF_LITE_ENSURE_OK(context,
                    GetOutputSafe(context, node, kOutputTensor, &output));

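  // Dispatch on the positions type first, then on the input type.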
  TfLiteStatus status = kTfLiteError;
  if (positions->type == kTfLiteInt32) {
    switch (input->type) {
      case kTfLiteFloat32:
        status =
            Gather<float, int32_t>(context, *params, input, positions, output);
        break;
      case kTfLiteUInt8:
        status = Gather<uint8_t, int32_t>(context, *params, input, positions,
                                          output);
        break;
      case kTfLiteInt8:
        status = Gather<int8_t, int32_t>(context, *params, input, positions,
                                         output);
        break;
      case kTfLiteInt16:
        status = Gather<int16_t, int32_t>(context, *params, input, positions,
                                          output);
        break;
      case kTfLiteInt32:
        status = Gather<int32_t, int32_t>(context, *params, input, positions,
                                          output);
        break;
      case kTfLiteInt64:
        status = Gather<int64_t, int32_t>(context, *params, input, positions,
                                          output);
        break;
      case kTfLiteBool:
        status =
            Gather<bool, int32_t>(context, *params, input, positions, output);
        break;
      case kTfLiteString:
        status = GatherStrings<int32_t>(context, input, positions, output);
        break;
      default:
        TF_LITE_KERNEL_LOG(context, "Type '%s' is not supported by gather.",
                           TfLiteTypeGetName(input->type));
        return kTfLiteError;
    }
  } else if (positions->type == kTfLiteInt64) {
    switch (input->type) {
      case kTfLiteFloat32:
        status =
            Gather<float, int64_t>(context, *params, input, positions, output);
        break;
      case kTfLiteUInt8:
        status = Gather<uint8_t, int64_t>(context, *params, input, positions,
                                          output);
        break;
      case kTfLiteInt8:
        status = Gather<int8_t, int64_t>(context, *params, input, positions,
                                         output);
        break;
      case kTfLiteInt16:
        status = Gather<int16_t, int64_t>(context, *params, input, positions,
                                          output);
        break;
      case kTfLiteInt32:
        status = Gather<int32_t, int64_t>(context, *params, input, positions,
                                          output);
        break;
      case kTfLiteInt64:
        status = Gather<int64_t, int64_t>(context, *params, input, positions,
                                          output);
        break;
      case kTfLiteBool:
        status =
            Gather<bool, int64_t>(context, *params, input, positions, output);
        break;
      case kTfLiteString:
        status = GatherStrings<int64_t>(context, input, positions, output);
        break;
      default:
        TF_LITE_KERNEL_LOG(context, "Type '%s' is not supported by gather.",
                           TfLiteTypeGetName(input->type));
        return kTfLiteError;
    }
  } else {
    TF_LITE_KERNEL_LOG(context,
                       "Positions of type '%s' are not supported by gather.",
                       TfLiteTypeGetName(positions->type));
    return kTfLiteError;
  }
  if (status != kTfLiteOk) {
    TF_LITE_KERNEL_LOG(context, "gather index out of bounds");
  }
  return status;
}
}  // namespace gather

TfLiteRegistration* Register_GATHER() {
  static TfLiteRegistration r = {nullptr, nullptr, gather::Prepare,
                                 gather::Eval};
  return &r;
}

}  // namespace builtin
}  // namespace ops
}  // namespace tflite