1 | /* Copyright 2018 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | #include "tensorflow/lite/kernels/internal/reference/arg_min_max.h" |
16 | |
17 | #include <stdint.h> |
18 | |
19 | #include <functional> |
20 | |
21 | #include "tensorflow/lite/c/builtin_op_data.h" |
22 | #include "tensorflow/lite/c/c_api_types.h" |
23 | #include "tensorflow/lite/c/common.h" |
24 | #include "tensorflow/lite/kernels/internal/optimized/optimized_ops.h" |
25 | #include "tensorflow/lite/kernels/internal/quantization_util.h" |
26 | #include "tensorflow/lite/kernels/internal/tensor.h" |
27 | #include "tensorflow/lite/kernels/internal/tensor_ctypes.h" |
28 | #include "tensorflow/lite/kernels/kernel_util.h" |
29 | |
30 | namespace tflite { |
31 | namespace ops { |
32 | namespace builtin { |
33 | namespace arg_min_max { |
34 | |
// Tensor indices for this op's inputs/outputs in the TFLite node.
constexpr int kInputTensor = 0;   // Tensor to reduce over.
constexpr int kAxis = 1;          // 1-element tensor holding the reduction axis.
constexpr int kOutputTensor = 0;  // Output of argmin/argmax indices.
38 | |
39 | TfLiteStatus ResizeOutput(TfLiteContext* context, const TfLiteTensor* input, |
40 | const TfLiteTensor* axis, TfLiteTensor* output) { |
41 | int axis_value; |
42 | // Retrieve all 8 bytes when axis type is kTfLiteInt64 to avoid data loss. |
43 | if (axis->type == kTfLiteInt64) { |
44 | axis_value = static_cast<int>(*GetTensorData<int64_t>(axis)); |
45 | } else { |
46 | axis_value = *GetTensorData<int>(axis); |
47 | } |
48 | if (axis_value < 0) { |
49 | axis_value += NumDimensions(input); |
50 | } |
51 | |
52 | TF_LITE_ENSURE(context, axis_value >= 0); |
53 | TF_LITE_ENSURE(context, axis_value < NumDimensions(input)); |
54 | |
55 | // Copy the input dimensions to output except the axis dimension. |
56 | TfLiteIntArray* output_dims = TfLiteIntArrayCreate(NumDimensions(input) - 1); |
57 | int j = 0; |
58 | for (int i = 0; i < NumDimensions(input); ++i) { |
59 | if (i != axis_value) { |
60 | output_dims->data[j] = SizeOfDimension(input, i); |
61 | ++j; |
62 | } |
63 | } |
64 | return context->ResizeTensor(context, output, output_dims); |
65 | } |
66 | |
67 | TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { |
68 | TF_LITE_ENSURE_EQ(context, NumInputs(node), 2); |
69 | TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1); |
70 | |
71 | const TfLiteTensor* input; |
72 | TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kInputTensor, &input)); |
73 | const TfLiteTensor* axis; |
74 | TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kAxis, &axis)); |
75 | // Make sure the axis is only 1 dimension. |
76 | TF_LITE_ENSURE_EQ(context, NumElements(axis), 1); |
77 | // Make sure the axis is only either int32 or int64. |
78 | TF_LITE_ENSURE(context, |
79 | axis->type == kTfLiteInt32 || axis->type == kTfLiteInt64); |
80 | |
81 | TfLiteTensor* output; |
82 | TF_LITE_ENSURE_OK(context, |
83 | GetOutputSafe(context, node, kOutputTensor, &output)); |
84 | |
85 | auto* params = reinterpret_cast<TfLiteArgMaxParams*>(node->builtin_data); |
86 | switch (params->output_type) { |
87 | case kTfLiteInt32: |
88 | output->type = kTfLiteInt32; |
89 | break; |
90 | case kTfLiteInt64: |
91 | output->type = kTfLiteInt64; |
92 | break; |
93 | default: |
94 | TF_LITE_KERNEL_LOG(context, "Unknown index output data type: %d" , |
95 | params->output_type); |
96 | return kTfLiteError; |
97 | } |
98 | |
99 | // Check conditions for different types. |
100 | switch (input->type) { |
101 | case kTfLiteFloat32: |
102 | case kTfLiteUInt8: |
103 | case kTfLiteInt8: |
104 | case kTfLiteInt32: |
105 | case kTfLiteBool: |
106 | break; |
107 | |
108 | default: |
109 | TF_LITE_KERNEL_LOG(context, |
110 | "Unknown input type: %d, only float32, int types " |
111 | "and bool are supported" , |
112 | input->type); |
113 | return kTfLiteError; |
114 | } |
115 | |
116 | TF_LITE_ENSURE(context, NumDimensions(input) >= 1); |
117 | |
118 | if (IsConstantTensor(axis)) { |
119 | TF_LITE_ENSURE_STATUS(ResizeOutput(context, input, axis, output)); |
120 | } else { |
121 | SetTensorToDynamic(output); |
122 | } |
123 | |
124 | return kTfLiteOk; |
125 | } |
126 | |
127 | TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node, bool is_arg_max) { |
128 | const TfLiteTensor* input; |
129 | TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kInputTensor, &input)); |
130 | const TfLiteTensor* axis; |
131 | TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kAxis, &axis)); |
132 | TfLiteTensor* output; |
133 | TF_LITE_ENSURE_OK(context, |
134 | GetOutputSafe(context, node, kOutputTensor, &output)); |
135 | if (IsDynamicTensor(output)) { |
136 | TF_LITE_ENSURE_STATUS(ResizeOutput(context, input, axis, output)); |
137 | } |
138 | |
139 | #define TF_LITE_ARG_MIN_MAX(data_type, axis_type, output_type) \ |
140 | optimized_ops::ArgMinMax( \ |
141 | GetTensorShape(input), GetTensorData<data_type>(input), \ |
142 | GetTensorData<axis_type>(axis), GetTensorShape(output), \ |
143 | GetTensorData<output_type>(output), is_arg_max) |
144 | if (axis->type == kTfLiteInt32) { |
145 | switch (output->type) { |
146 | case kTfLiteInt32: { |
147 | switch (input->type) { |
148 | case kTfLiteFloat32: |
149 | TF_LITE_ARG_MIN_MAX(float, int32_t, int32_t); |
150 | break; |
151 | case kTfLiteUInt8: |
152 | TF_LITE_ARG_MIN_MAX(uint8_t, int32_t, int32_t); |
153 | break; |
154 | case kTfLiteInt8: |
155 | TF_LITE_ARG_MIN_MAX(int8_t, int32_t, int32_t); |
156 | break; |
157 | case kTfLiteInt32: |
158 | TF_LITE_ARG_MIN_MAX(int32_t, int32_t, int32_t); |
159 | break; |
160 | case kTfLiteBool: |
161 | TF_LITE_ARG_MIN_MAX(bool, int32_t, int32_t); |
162 | break; |
163 | default: |
164 | TF_LITE_KERNEL_LOG(context, |
165 | "Only float32, uint8, int8, int32 and bool are " |
166 | "supported currently, got %s." , |
167 | TfLiteTypeGetName(input->type)); |
168 | return kTfLiteError; |
169 | } |
170 | } break; |
171 | case kTfLiteInt64: { |
172 | switch (input->type) { |
173 | case kTfLiteFloat32: |
174 | TF_LITE_ARG_MIN_MAX(float, int32_t, int64_t); |
175 | break; |
176 | case kTfLiteUInt8: |
177 | TF_LITE_ARG_MIN_MAX(uint8_t, int32_t, int64_t); |
178 | break; |
179 | case kTfLiteInt8: |
180 | TF_LITE_ARG_MIN_MAX(int8_t, int32_t, int64_t); |
181 | break; |
182 | case kTfLiteInt32: |
183 | TF_LITE_ARG_MIN_MAX(int32_t, int32_t, int64_t); |
184 | break; |
185 | case kTfLiteBool: |
186 | TF_LITE_ARG_MIN_MAX(bool, int32_t, int64_t); |
187 | break; |
188 | default: |
189 | TF_LITE_KERNEL_LOG(context, |
190 | "Only float32, uint8, int8, int32 and bool are " |
191 | "supported currently, got %s." , |
192 | TfLiteTypeGetName(input->type)); |
193 | return kTfLiteError; |
194 | } |
195 | } break; |
196 | default: |
197 | TF_LITE_KERNEL_LOG( |
198 | context, "Only int32 and int64 are supported currently, got %s." , |
199 | TfLiteTypeGetName(output->type)); |
200 | return kTfLiteError; |
201 | } |
202 | } else { |
203 | switch (output->type) { |
204 | case kTfLiteInt32: { |
205 | switch (input->type) { |
206 | case kTfLiteFloat32: |
207 | TF_LITE_ARG_MIN_MAX(float, int64_t, int32_t); |
208 | break; |
209 | case kTfLiteUInt8: |
210 | TF_LITE_ARG_MIN_MAX(uint8_t, int64_t, int32_t); |
211 | break; |
212 | case kTfLiteInt8: |
213 | TF_LITE_ARG_MIN_MAX(int8_t, int64_t, int32_t); |
214 | break; |
215 | case kTfLiteInt32: |
216 | TF_LITE_ARG_MIN_MAX(int32_t, int64_t, int32_t); |
217 | break; |
218 | case kTfLiteBool: |
219 | TF_LITE_ARG_MIN_MAX(bool, int64_t, int32_t); |
220 | break; |
221 | default: |
222 | TF_LITE_KERNEL_LOG(context, |
223 | "Only float32, uint8, int8, int32 and bool are " |
224 | "supported currently, got %s." , |
225 | TfLiteTypeGetName(input->type)); |
226 | return kTfLiteError; |
227 | } |
228 | } break; |
229 | case kTfLiteInt64: { |
230 | switch (input->type) { |
231 | case kTfLiteFloat32: |
232 | TF_LITE_ARG_MIN_MAX(float, int64_t, int64_t); |
233 | break; |
234 | case kTfLiteUInt8: |
235 | TF_LITE_ARG_MIN_MAX(uint8_t, int64_t, int64_t); |
236 | break; |
237 | case kTfLiteInt8: |
238 | TF_LITE_ARG_MIN_MAX(int8_t, int64_t, int64_t); |
239 | break; |
240 | case kTfLiteInt32: |
241 | TF_LITE_ARG_MIN_MAX(int32_t, int64_t, int64_t); |
242 | break; |
243 | case kTfLiteBool: |
244 | TF_LITE_ARG_MIN_MAX(bool, int64_t, int64_t); |
245 | break; |
246 | default: |
247 | TF_LITE_KERNEL_LOG(context, |
248 | "Only float32, uint8, int8, int32 and bool are " |
249 | "supported currently, got %s." , |
250 | TfLiteTypeGetName(input->type)); |
251 | return kTfLiteError; |
252 | } |
253 | } break; |
254 | default: |
255 | TF_LITE_KERNEL_LOG( |
256 | context, "Only int32 and int64 are supported currently, got %s." , |
257 | TfLiteTypeGetName(output->type)); |
258 | return kTfLiteError; |
259 | } |
260 | } |
261 | #undef TF_LITE_ARG_MIN_MAX |
262 | |
263 | return kTfLiteOk; |
264 | } |
265 | |
// Eval entry point registered for the ARG_MIN op (is_arg_max = false).
TfLiteStatus ArgMinEval(TfLiteContext* context, TfLiteNode* node) {
  return Eval(context, node, false);
}
269 | |
// Eval entry point registered for the ARG_MAX op (is_arg_max = true).
TfLiteStatus ArgMaxEval(TfLiteContext* context, TfLiteNode* node) {
  return Eval(context, node, true);
}
273 | |
274 | } // namespace arg_min_max |
275 | |
// Returns the registration for the builtin ARG_MAX op. The kernel keeps no
// per-node state, so the init and free callbacks are null.
TfLiteRegistration* Register_ARG_MAX() {
  static TfLiteRegistration r = {nullptr, nullptr, arg_min_max::Prepare,
                                 arg_min_max::ArgMaxEval};
  return &r;
}
281 | |
// Returns the registration for the builtin ARG_MIN op. The kernel keeps no
// per-node state, so the init and free callbacks are null.
TfLiteRegistration* Register_ARG_MIN() {
  static TfLiteRegistration r = {nullptr, nullptr, arg_min_max::Prepare,
                                 arg_min_max::ArgMinEval};
  return &r;
}
287 | |
288 | } // namespace builtin |
289 | } // namespace ops |
290 | } // namespace tflite |
291 | |