1/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2
3Licensed under the Apache License, Version 2.0 (the "License");
4you may not use this file except in compliance with the License.
5You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9Unless required by applicable law or agreed to in writing, software
10distributed under the License is distributed on an "AS IS" BASIS,
11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12See the License for the specific language governing permissions and
13limitations under the License.
14==============================================================================*/
15#include "tensorflow/lite/kernels/internal/reference/arg_min_max.h"
16
17#include <stdint.h>
18
19#include <functional>
20
21#include "tensorflow/lite/c/builtin_op_data.h"
22#include "tensorflow/lite/c/c_api_types.h"
23#include "tensorflow/lite/c/common.h"
24#include "tensorflow/lite/kernels/internal/optimized/optimized_ops.h"
25#include "tensorflow/lite/kernels/internal/quantization_util.h"
26#include "tensorflow/lite/kernels/internal/tensor.h"
27#include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
28#include "tensorflow/lite/kernels/kernel_util.h"
29
30namespace tflite {
31namespace ops {
32namespace builtin {
33namespace arg_min_max {
34
// Tensor slots used by this kernel.
constexpr int kInputTensor{0};   // Input 0: data tensor to reduce.
constexpr int kAxis{1};          // Input 1: single-element axis tensor.
constexpr int kOutputTensor{0};  // Output 0: arg-min/arg-max indices.
38
39TfLiteStatus ResizeOutput(TfLiteContext* context, const TfLiteTensor* input,
40 const TfLiteTensor* axis, TfLiteTensor* output) {
41 int axis_value;
42 // Retrieve all 8 bytes when axis type is kTfLiteInt64 to avoid data loss.
43 if (axis->type == kTfLiteInt64) {
44 axis_value = static_cast<int>(*GetTensorData<int64_t>(axis));
45 } else {
46 axis_value = *GetTensorData<int>(axis);
47 }
48 if (axis_value < 0) {
49 axis_value += NumDimensions(input);
50 }
51
52 TF_LITE_ENSURE(context, axis_value >= 0);
53 TF_LITE_ENSURE(context, axis_value < NumDimensions(input));
54
55 // Copy the input dimensions to output except the axis dimension.
56 TfLiteIntArray* output_dims = TfLiteIntArrayCreate(NumDimensions(input) - 1);
57 int j = 0;
58 for (int i = 0; i < NumDimensions(input); ++i) {
59 if (i != axis_value) {
60 output_dims->data[j] = SizeOfDimension(input, i);
61 ++j;
62 }
63 }
64 return context->ResizeTensor(context, output, output_dims);
65}
66
67TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
68 TF_LITE_ENSURE_EQ(context, NumInputs(node), 2);
69 TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
70
71 const TfLiteTensor* input;
72 TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kInputTensor, &input));
73 const TfLiteTensor* axis;
74 TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kAxis, &axis));
75 // Make sure the axis is only 1 dimension.
76 TF_LITE_ENSURE_EQ(context, NumElements(axis), 1);
77 // Make sure the axis is only either int32 or int64.
78 TF_LITE_ENSURE(context,
79 axis->type == kTfLiteInt32 || axis->type == kTfLiteInt64);
80
81 TfLiteTensor* output;
82 TF_LITE_ENSURE_OK(context,
83 GetOutputSafe(context, node, kOutputTensor, &output));
84
85 auto* params = reinterpret_cast<TfLiteArgMaxParams*>(node->builtin_data);
86 switch (params->output_type) {
87 case kTfLiteInt32:
88 output->type = kTfLiteInt32;
89 break;
90 case kTfLiteInt64:
91 output->type = kTfLiteInt64;
92 break;
93 default:
94 TF_LITE_KERNEL_LOG(context, "Unknown index output data type: %d",
95 params->output_type);
96 return kTfLiteError;
97 }
98
99 // Check conditions for different types.
100 switch (input->type) {
101 case kTfLiteFloat32:
102 case kTfLiteUInt8:
103 case kTfLiteInt8:
104 case kTfLiteInt32:
105 case kTfLiteBool:
106 break;
107
108 default:
109 TF_LITE_KERNEL_LOG(context,
110 "Unknown input type: %d, only float32, int types "
111 "and bool are supported",
112 input->type);
113 return kTfLiteError;
114 }
115
116 TF_LITE_ENSURE(context, NumDimensions(input) >= 1);
117
118 if (IsConstantTensor(axis)) {
119 TF_LITE_ENSURE_STATUS(ResizeOutput(context, input, axis, output));
120 } else {
121 SetTensorToDynamic(output);
122 }
123
124 return kTfLiteOk;
125}
126
127TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node, bool is_arg_max) {
128 const TfLiteTensor* input;
129 TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kInputTensor, &input));
130 const TfLiteTensor* axis;
131 TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kAxis, &axis));
132 TfLiteTensor* output;
133 TF_LITE_ENSURE_OK(context,
134 GetOutputSafe(context, node, kOutputTensor, &output));
135 if (IsDynamicTensor(output)) {
136 TF_LITE_ENSURE_STATUS(ResizeOutput(context, input, axis, output));
137 }
138
139#define TF_LITE_ARG_MIN_MAX(data_type, axis_type, output_type) \
140 optimized_ops::ArgMinMax( \
141 GetTensorShape(input), GetTensorData<data_type>(input), \
142 GetTensorData<axis_type>(axis), GetTensorShape(output), \
143 GetTensorData<output_type>(output), is_arg_max)
144 if (axis->type == kTfLiteInt32) {
145 switch (output->type) {
146 case kTfLiteInt32: {
147 switch (input->type) {
148 case kTfLiteFloat32:
149 TF_LITE_ARG_MIN_MAX(float, int32_t, int32_t);
150 break;
151 case kTfLiteUInt8:
152 TF_LITE_ARG_MIN_MAX(uint8_t, int32_t, int32_t);
153 break;
154 case kTfLiteInt8:
155 TF_LITE_ARG_MIN_MAX(int8_t, int32_t, int32_t);
156 break;
157 case kTfLiteInt32:
158 TF_LITE_ARG_MIN_MAX(int32_t, int32_t, int32_t);
159 break;
160 case kTfLiteBool:
161 TF_LITE_ARG_MIN_MAX(bool, int32_t, int32_t);
162 break;
163 default:
164 TF_LITE_KERNEL_LOG(context,
165 "Only float32, uint8, int8, int32 and bool are "
166 "supported currently, got %s.",
167 TfLiteTypeGetName(input->type));
168 return kTfLiteError;
169 }
170 } break;
171 case kTfLiteInt64: {
172 switch (input->type) {
173 case kTfLiteFloat32:
174 TF_LITE_ARG_MIN_MAX(float, int32_t, int64_t);
175 break;
176 case kTfLiteUInt8:
177 TF_LITE_ARG_MIN_MAX(uint8_t, int32_t, int64_t);
178 break;
179 case kTfLiteInt8:
180 TF_LITE_ARG_MIN_MAX(int8_t, int32_t, int64_t);
181 break;
182 case kTfLiteInt32:
183 TF_LITE_ARG_MIN_MAX(int32_t, int32_t, int64_t);
184 break;
185 case kTfLiteBool:
186 TF_LITE_ARG_MIN_MAX(bool, int32_t, int64_t);
187 break;
188 default:
189 TF_LITE_KERNEL_LOG(context,
190 "Only float32, uint8, int8, int32 and bool are "
191 "supported currently, got %s.",
192 TfLiteTypeGetName(input->type));
193 return kTfLiteError;
194 }
195 } break;
196 default:
197 TF_LITE_KERNEL_LOG(
198 context, "Only int32 and int64 are supported currently, got %s.",
199 TfLiteTypeGetName(output->type));
200 return kTfLiteError;
201 }
202 } else {
203 switch (output->type) {
204 case kTfLiteInt32: {
205 switch (input->type) {
206 case kTfLiteFloat32:
207 TF_LITE_ARG_MIN_MAX(float, int64_t, int32_t);
208 break;
209 case kTfLiteUInt8:
210 TF_LITE_ARG_MIN_MAX(uint8_t, int64_t, int32_t);
211 break;
212 case kTfLiteInt8:
213 TF_LITE_ARG_MIN_MAX(int8_t, int64_t, int32_t);
214 break;
215 case kTfLiteInt32:
216 TF_LITE_ARG_MIN_MAX(int32_t, int64_t, int32_t);
217 break;
218 case kTfLiteBool:
219 TF_LITE_ARG_MIN_MAX(bool, int64_t, int32_t);
220 break;
221 default:
222 TF_LITE_KERNEL_LOG(context,
223 "Only float32, uint8, int8, int32 and bool are "
224 "supported currently, got %s.",
225 TfLiteTypeGetName(input->type));
226 return kTfLiteError;
227 }
228 } break;
229 case kTfLiteInt64: {
230 switch (input->type) {
231 case kTfLiteFloat32:
232 TF_LITE_ARG_MIN_MAX(float, int64_t, int64_t);
233 break;
234 case kTfLiteUInt8:
235 TF_LITE_ARG_MIN_MAX(uint8_t, int64_t, int64_t);
236 break;
237 case kTfLiteInt8:
238 TF_LITE_ARG_MIN_MAX(int8_t, int64_t, int64_t);
239 break;
240 case kTfLiteInt32:
241 TF_LITE_ARG_MIN_MAX(int32_t, int64_t, int64_t);
242 break;
243 case kTfLiteBool:
244 TF_LITE_ARG_MIN_MAX(bool, int64_t, int64_t);
245 break;
246 default:
247 TF_LITE_KERNEL_LOG(context,
248 "Only float32, uint8, int8, int32 and bool are "
249 "supported currently, got %s.",
250 TfLiteTypeGetName(input->type));
251 return kTfLiteError;
252 }
253 } break;
254 default:
255 TF_LITE_KERNEL_LOG(
256 context, "Only int32 and int64 are supported currently, got %s.",
257 TfLiteTypeGetName(output->type));
258 return kTfLiteError;
259 }
260 }
261#undef TF_LITE_ARG_MIN_MAX
262
263 return kTfLiteOk;
264}
265
266TfLiteStatus ArgMinEval(TfLiteContext* context, TfLiteNode* node) {
267 return Eval(context, node, false);
268}
269
270TfLiteStatus ArgMaxEval(TfLiteContext* context, TfLiteNode* node) {
271 return Eval(context, node, true);
272}
273
274} // namespace arg_min_max
275
276TfLiteRegistration* Register_ARG_MAX() {
277 static TfLiteRegistration r = {nullptr, nullptr, arg_min_max::Prepare,
278 arg_min_max::ArgMaxEval};
279 return &r;
280}
281
282TfLiteRegistration* Register_ARG_MIN() {
283 static TfLiteRegistration r = {nullptr, nullptr, arg_min_max::Prepare,
284 arg_min_max::ArgMinEval};
285 return &r;
286}
287
288} // namespace builtin
289} // namespace ops
290} // namespace tflite
291