1/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2
3Licensed under the Apache License, Version 2.0 (the "License");
4you may not use this file except in compliance with the License.
5You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9Unless required by applicable law or agreed to in writing, software
10distributed under the License is distributed on an "AS IS" BASIS,
11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12See the License for the specific language governing permissions and
13limitations under the License.
14==============================================================================*/
15#include <stdint.h>
16
17#include "tensorflow/lite/c/c_api_types.h"
18#include "tensorflow/lite/c/common.h"
19#include "tensorflow/lite/kernels/internal/optimized/optimized_ops.h"
20#include "tensorflow/lite/kernels/internal/reference/reference_ops.h"
21#include "tensorflow/lite/kernels/internal/tensor.h"
22#include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
23#include "tensorflow/lite/kernels/kernel_util.h"
24
25namespace tflite {
26namespace ops {
27namespace builtin {
28namespace gather_nd {
// Tensor slot indices for this op: two inputs (params, indices), one output.
constexpr int kParams{0};
constexpr int kIndices{1};
constexpr int kOutputTensor{0};
32
33TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
34 TF_LITE_ENSURE_EQ(context, NumInputs(node), 2);
35 TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
36
37 const TfLiteTensor* params;
38 TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kParams, &params));
39 const TfLiteTensor* indices;
40 TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kIndices, &indices));
41 TfLiteTensor* output;
42 TF_LITE_ENSURE_OK(context,
43 GetOutputSafe(context, node, kOutputTensor, &output));
44
45 switch (params->type) {
46 case kTfLiteFloat32:
47 case kTfLiteUInt8:
48 case kTfLiteInt8:
49 case kTfLiteInt16:
50 case kTfLiteInt64:
51 case kTfLiteInt32:
52 case kTfLiteString:
53 break;
54 default:
55 TF_LITE_KERNEL_LOG(context,
56 "Params of type '%s' are not supported by gather_nd.",
57 TfLiteTypeGetName(params->type));
58 return kTfLiteError;
59 }
60 switch (indices->type) {
61 case kTfLiteInt64:
62 case kTfLiteInt32:
63 break;
64 default:
65 TF_LITE_KERNEL_LOG(context,
66 "Indices of type '%s' are not supported by gather_nd.",
67 TfLiteTypeGetName(indices->type));
68 return kTfLiteError;
69 }
70
71 const int params_rank = NumDimensions(params);
72 const int indices_rank = NumDimensions(indices);
73 const int indices_nd = SizeOfDimension(indices, indices_rank - 1);
74 if (params_rank < 1) {
75 TF_LITE_KERNEL_LOG(context, "Params must be at least a vector.");
76 return kTfLiteError;
77 }
78 if (indices_rank < 1) {
79 TF_LITE_KERNEL_LOG(context, "Indices must be at least a vector.");
80 return kTfLiteError;
81 }
82 if (indices_nd > params_rank) {
83 TF_LITE_KERNEL_LOG(
84 context, "Index innermost dimension length must be <= params rank.");
85 return kTfLiteError;
86 }
87
88 // Assign to output the input type.
89 output->type = params->type;
90
91 // The result shape is
92 // indices.shape[:-1] + params.shape[indices.shape[-1]:]
93 const int output_rank = indices_rank + params_rank - indices_nd - 1;
94 TfLiteIntArray* output_shape = TfLiteIntArrayCreate(output_rank);
95 int output_index = 0;
96 for (int i = 0; i < indices_rank - 1; ++i) {
97 output_shape->data[output_index++] = indices->dims->data[i];
98 }
99 for (int i = indices_nd; i < params_rank; ++i) {
100 output_shape->data[output_index++] = params->dims->data[i];
101 }
102 return context->ResizeTensor(context, output, output_shape);
103}
104
105template <typename ParamsT, typename IndicesT>
106TfLiteStatus GatherNd(const TfLiteTensor* params, const TfLiteTensor* indices,
107 TfLiteTensor* output) {
108 return reference_ops::GatherNd(
109 GetTensorShape(params), GetTensorData<ParamsT>(params),
110 GetTensorShape(indices), GetTensorData<IndicesT>(indices),
111 GetTensorShape(output), GetTensorData<ParamsT>(output));
112}
113
114template <typename IndicesT>
115TfLiteStatus GatherNdString(const TfLiteTensor* params,
116 const TfLiteTensor* indices, TfLiteTensor* output) {
117 return reference_ops::GatherNdString(
118 GetTensorShape(params), params, GetTensorShape(indices),
119 GetTensorData<IndicesT>(indices), GetTensorShape(output), output);
120}
121
122template <typename IndicesT>
123TfLiteStatus EvalGatherNd(TfLiteContext* context, const TfLiteTensor* params,
124 const TfLiteTensor* indices, TfLiteTensor* output) {
125 bool indices_has_only_positive_elements = true;
126 const auto* indices_values = GetTensorData<IndicesT>(indices);
127 const size_t num_indices = indices->bytes / sizeof(IndicesT);
128 for (size_t i = 0; i < num_indices; i++) {
129 if (indices_values[i] < 0) {
130 indices_has_only_positive_elements = false;
131 break;
132 }
133 }
134 TF_LITE_ENSURE(context, indices_has_only_positive_elements);
135
136 TfLiteStatus status = kTfLiteError;
137 switch (params->type) {
138 case kTfLiteFloat32:
139 status = GatherNd<float, IndicesT>(params, indices, output);
140 break;
141 case kTfLiteUInt8:
142 status = GatherNd<uint8_t, IndicesT>(params, indices, output);
143 break;
144 case kTfLiteInt8:
145 status = GatherNd<int8_t, IndicesT>(params, indices, output);
146 break;
147 case kTfLiteInt16:
148 status = GatherNd<int16_t, IndicesT>(params, indices, output);
149 break;
150 case kTfLiteInt32:
151 status = GatherNd<int32_t, IndicesT>(params, indices, output);
152 break;
153 case kTfLiteInt64:
154 status = GatherNd<int64_t, IndicesT>(params, indices, output);
155 break;
156 case kTfLiteString:
157 status = GatherNdString<IndicesT>(params, indices, output);
158 break;
159 default:
160 TF_LITE_KERNEL_LOG(context,
161 "Params type '%s' are not supported by gather_nd.",
162 TfLiteTypeGetName(params->type));
163 return kTfLiteError;
164 }
165 if (status != kTfLiteOk) {
166 TF_LITE_KERNEL_LOG(context, "gather_nd index out of bounds");
167 }
168 return status;
169}
170
171TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
172 const TfLiteTensor* params;
173 TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kParams, &params));
174 const TfLiteTensor* indices;
175 TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kIndices, &indices));
176 TfLiteTensor* output;
177 TF_LITE_ENSURE_OK(context,
178 GetOutputSafe(context, node, kOutputTensor, &output));
179
180 // Prevent division by 0 in the helper.
181 // In TF, GatherND supports empty `params` only when `indices` is also empty.
182 TF_LITE_ENSURE(context,
183 (NumElements(params) == 0 && NumElements(indices) == 0) ||
184 NumElements(params) > 0);
185
186 switch (indices->type) {
187 case kTfLiteInt32:
188 return EvalGatherNd<int32_t>(context, params, indices, output);
189 case kTfLiteInt64:
190 return EvalGatherNd<int64_t>(context, params, indices, output);
191 default:
192 TF_LITE_KERNEL_LOG(context,
193 "Indices of type '%s' are not supported by gather_nd.",
194 TfLiteTypeGetName(indices->type));
195 return kTfLiteError;
196 }
197}
198} // namespace gather_nd
199
200TfLiteRegistration* Register_GATHER_ND() {
201 static TfLiteRegistration r = {/*init*/ nullptr, /*free*/ nullptr,
202 gather_nd::Prepare, gather_nd::Eval};
203 return &r;
204}
205} // namespace builtin
206} // namespace ops
207} // namespace tflite
208