1 | /* Copyright 2017 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | #include <stdint.h> |
16 | |
17 | #include "tensorflow/lite/c/builtin_op_data.h" |
18 | #include "tensorflow/lite/c/common.h" |
19 | #include "tensorflow/lite/kernels/internal/optimized/optimized_ops.h" |
20 | #include "tensorflow/lite/kernels/internal/reference/reference_ops.h" |
21 | #include "tensorflow/lite/kernels/internal/tensor.h" |
22 | #include "tensorflow/lite/kernels/internal/tensor_ctypes.h" |
23 | #include "tensorflow/lite/kernels/internal/types.h" |
24 | #include "tensorflow/lite/kernels/kernel_util.h" |
25 | #include "tensorflow/lite/string_util.h" |
26 | |
27 | namespace tflite { |
28 | namespace ops { |
29 | namespace builtin { |
30 | namespace gather { |
// Tensor slots used by this kernel: the values tensor is input 0, the index
// ("positions") tensor is input 1, and the gathered result is output 0.
constexpr int kInputTensor = 0;
constexpr int kInputPositions = kInputTensor + 1;
constexpr int kOutputTensor = 0;
34 | |
35 | TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { |
36 | TF_LITE_ENSURE_EQ(context, NumInputs(node), 2); |
37 | TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1); |
38 | |
39 | const auto* params = |
40 | reinterpret_cast<const TfLiteGatherParams*>(node->builtin_data); |
41 | const TfLiteTensor* input; |
42 | TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kInputTensor, &input)); |
43 | const TfLiteTensor* positions; |
44 | TF_LITE_ENSURE_OK(context, |
45 | GetInputSafe(context, node, kInputPositions, &positions)); |
46 | TfLiteTensor* output; |
47 | TF_LITE_ENSURE_OK(context, |
48 | GetOutputSafe(context, node, kOutputTensor, &output)); |
49 | |
50 | switch (positions->type) { |
51 | case kTfLiteInt64: |
52 | case kTfLiteInt32: |
53 | break; |
54 | default: |
55 | TF_LITE_KERNEL_LOG(context, |
56 | "Positions of type '%s' are not supported by gather." , |
57 | TfLiteTypeGetName(positions->type)); |
58 | return kTfLiteError; |
59 | } |
60 | |
61 | // Assign to output the input type. |
62 | output->type = input->type; |
63 | |
64 | // Check conditions for different types. |
65 | switch (input->type) { |
66 | case kTfLiteFloat32: |
67 | case kTfLiteUInt8: |
68 | case kTfLiteInt8: |
69 | case kTfLiteInt16: |
70 | case kTfLiteInt64: |
71 | case kTfLiteInt32: |
72 | case kTfLiteBool: |
73 | break; |
74 | case kTfLiteString: { |
75 | // Only 1D input is supported. |
76 | TF_LITE_ENSURE_EQ(context, NumDimensions(input), 1); |
77 | } break; |
78 | default: |
79 | TF_LITE_KERNEL_LOG(context, "Type '%s' is not supported by gather." , |
80 | TfLiteTypeGetName(input->type)); |
81 | return kTfLiteError; |
82 | } |
83 | |
84 | int axis = params->axis; |
85 | if (axis < 0) { |
86 | axis += NumDimensions(input); |
87 | } |
88 | TF_LITE_ENSURE(context, 0 <= axis && axis < NumDimensions(input)); |
89 | |
90 | int batch_dims = params->batch_dims; |
91 | // batch_dims should be in range: [-rank(positions), rank(positions)]. |
92 | // Negative batch_dims is added with rank of positions. |
93 | if (batch_dims < 0) { |
94 | batch_dims += NumDimensions(positions); |
95 | } |
96 | TF_LITE_ENSURE(context, batch_dims <= axis); |
97 | TF_LITE_ENSURE(context, 0 <= batch_dims && batch_dims < NumDimensions(input)); |
98 | TF_LITE_ENSURE(context, batch_dims <= NumDimensions(positions)); |
99 | for (int i = 0; i < batch_dims; ++i) { |
100 | TF_LITE_ENSURE_EQ(context, input->dims->data[i], positions->dims->data[i]); |
101 | } |
102 | |
103 | const int num_dimensions = |
104 | NumDimensions(input) + NumDimensions(positions) - 1 - batch_dims; |
105 | TfLiteIntArray* output_shape = TfLiteIntArrayCreate(num_dimensions); |
106 | int output_index = 0; |
107 | for (int i = 0; i < axis; ++i) { |
108 | output_shape->data[output_index++] = input->dims->data[i]; |
109 | } |
110 | for (int i = batch_dims; i < positions->dims->size; ++i) { |
111 | output_shape->data[output_index++] = positions->dims->data[i]; |
112 | } |
113 | for (int i = axis + 1; i < input->dims->size; ++i) { |
114 | output_shape->data[output_index++] = input->dims->data[i]; |
115 | } |
116 | return context->ResizeTensor(context, output, output_shape); |
117 | } |
118 | |
119 | template <typename InputT, typename PositionsT> |
120 | TfLiteStatus Gather(TfLiteContext* context, const TfLiteGatherParams& params, |
121 | const TfLiteTensor* input, const TfLiteTensor* positions, |
122 | TfLiteTensor* output) { |
123 | const PositionsT* indexes = GetTensorData<PositionsT>(positions); |
124 | bool indices_has_only_positive_elements = true; |
125 | const size_t num_indices = positions->bytes / sizeof(PositionsT); |
126 | for (size_t i = 0; i < num_indices; i++) { |
127 | if (indexes[i] < 0) { |
128 | indices_has_only_positive_elements = false; |
129 | break; |
130 | } |
131 | } |
132 | TF_LITE_ENSURE(context, indices_has_only_positive_elements); |
133 | |
134 | tflite::GatherParams op_params; |
135 | op_params.axis = params.axis; |
136 | op_params.batch_dims = params.batch_dims; |
137 | return optimized_ops::Gather( |
138 | op_params, GetTensorShape(input), GetTensorData<InputT>(input), |
139 | GetTensorShape(positions), GetTensorData<PositionsT>(positions), |
140 | GetTensorShape(output), GetTensorData<InputT>(output)); |
141 | } |
142 | |
143 | template <typename PositionT> |
144 | TfLiteStatus GatherStrings(TfLiteContext* context, const TfLiteTensor* input, |
145 | const TfLiteTensor* positions, |
146 | TfLiteTensor* output) { |
147 | DynamicBuffer buffer; |
148 | |
149 | const PositionT* indexes = GetTensorData<PositionT>(positions); |
150 | bool indices_has_only_positive_elements = true; |
151 | const size_t num_indices = positions->bytes / sizeof(PositionT); |
152 | for (size_t i = 0; i < num_indices; i++) { |
153 | if (indexes[i] < 0) { |
154 | indices_has_only_positive_elements = false; |
155 | break; |
156 | } |
157 | } |
158 | TF_LITE_ENSURE(context, indices_has_only_positive_elements); |
159 | |
160 | const PositionT num_strings = GetStringCount(input); |
161 | const int num_indexes = NumElements(positions); |
162 | |
163 | for (int i = 0; i < num_indexes; ++i) { |
164 | const PositionT pos = indexes[i]; |
165 | TF_LITE_ENSURE(context, pos < num_strings); |
166 | const auto string_ref = GetString(input, pos); |
167 | buffer.AddString(string_ref.str, string_ref.len); |
168 | } |
169 | buffer.WriteToTensor(output, /*new_shape=*/nullptr); |
170 | return kTfLiteOk; |
171 | } |
172 | |
173 | TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { |
174 | const auto* params = |
175 | reinterpret_cast<const TfLiteGatherParams*>(node->builtin_data); |
176 | const TfLiteTensor* input; |
177 | TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kInputTensor, &input)); |
178 | const TfLiteTensor* positions; |
179 | TF_LITE_ENSURE_OK(context, |
180 | GetInputSafe(context, node, kInputPositions, &positions)); |
181 | TfLiteTensor* output; |
182 | TF_LITE_ENSURE_OK(context, |
183 | GetOutputSafe(context, node, kOutputTensor, &output)); |
184 | |
185 | TfLiteStatus status = kTfLiteError; |
186 | if (positions->type == kTfLiteInt32) { |
187 | switch (input->type) { |
188 | case kTfLiteFloat32: |
189 | status = |
190 | Gather<float, int32_t>(context, *params, input, positions, output); |
191 | break; |
192 | case kTfLiteUInt8: |
193 | status = Gather<uint8_t, int32_t>(context, *params, input, positions, |
194 | output); |
195 | break; |
196 | case kTfLiteInt8: |
197 | status = |
198 | Gather<int8_t, int32_t>(context, *params, input, positions, output); |
199 | break; |
200 | case kTfLiteInt16: |
201 | status = Gather<int16_t, int32_t>(context, *params, input, positions, |
202 | output); |
203 | break; |
204 | case kTfLiteInt32: |
205 | status = Gather<int32_t, int32_t>(context, *params, input, positions, |
206 | output); |
207 | break; |
208 | case kTfLiteInt64: |
209 | status = Gather<int64_t, int32_t>(context, *params, input, positions, |
210 | output); |
211 | break; |
212 | case kTfLiteBool: |
213 | status = |
214 | Gather<bool, int32_t>(context, *params, input, positions, output); |
215 | break; |
216 | case kTfLiteString: |
217 | status = GatherStrings<int32_t>(context, input, positions, output); |
218 | break; |
219 | default: |
220 | TF_LITE_KERNEL_LOG(context, "Type '%s' is not supported by gather." , |
221 | TfLiteTypeGetName(input->type)); |
222 | return kTfLiteError; |
223 | } |
224 | } |
225 | if (positions->type == kTfLiteInt64) { |
226 | switch (input->type) { |
227 | case kTfLiteFloat32: |
228 | status = |
229 | Gather<float, int64_t>(context, *params, input, positions, output); |
230 | break; |
231 | case kTfLiteUInt8: |
232 | status = Gather<uint8_t, int64_t>(context, *params, input, positions, |
233 | output); |
234 | break; |
235 | case kTfLiteInt8: |
236 | status = |
237 | Gather<int8_t, int64_t>(context, *params, input, positions, output); |
238 | break; |
239 | case kTfLiteInt16: |
240 | status = Gather<int16_t, int64_t>(context, *params, input, positions, |
241 | output); |
242 | break; |
243 | case kTfLiteInt32: |
244 | status = Gather<int32_t, int64_t>(context, *params, input, positions, |
245 | output); |
246 | break; |
247 | case kTfLiteInt64: |
248 | status = Gather<int64_t, int64_t>(context, *params, input, positions, |
249 | output); |
250 | break; |
251 | case kTfLiteBool: |
252 | status = |
253 | Gather<bool, int64_t>(context, *params, input, positions, output); |
254 | break; |
255 | case kTfLiteString: |
256 | status = GatherStrings<int64_t>(context, input, positions, output); |
257 | break; |
258 | default: |
259 | TF_LITE_KERNEL_LOG(context, "Type '%s' is not supported by gather." , |
260 | TfLiteTypeGetName(input->type)); |
261 | return kTfLiteError; |
262 | } |
263 | } |
264 | if (status != kTfLiteOk) { |
265 | TF_LITE_KERNEL_LOG(context, "gather index out of bounds" ); |
266 | } |
267 | return status; |
268 | TF_LITE_KERNEL_LOG(context, |
269 | "Positions of type '%s' are not supported by gather." , |
270 | TfLiteTypeGetName(positions->type)); |
271 | return kTfLiteError; |
272 | } |
273 | } // namespace gather |
274 | |
275 | TfLiteRegistration* Register_GATHER() { |
276 | static TfLiteRegistration r = {nullptr, nullptr, gather::Prepare, |
277 | gather::Eval}; |
278 | return &r; |
279 | } |
280 | |
281 | } // namespace builtin |
282 | } // namespace ops |
283 | } // namespace tflite |
284 | |