1 | /* Copyright 2018 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | #include "tensorflow/lite/kernels/internal/reference/comparisons.h" |
16 | |
17 | #include <stdint.h> |
18 | |
19 | #include "tensorflow/lite/c/common.h" |
20 | #include "tensorflow/lite/kernels/internal/compatibility.h" |
21 | #include "tensorflow/lite/kernels/internal/quantization_util.h" |
22 | #include "tensorflow/lite/kernels/internal/reference/reference_ops.h" |
23 | #include "tensorflow/lite/kernels/internal/tensor.h" |
24 | #include "tensorflow/lite/kernels/internal/tensor_ctypes.h" |
25 | #include "tensorflow/lite/kernels/internal/types.h" |
26 | #include "tensorflow/lite/kernels/kernel_util.h" |
27 | #include "tensorflow/lite/string_util.h" |
28 | |
29 | namespace tflite { |
30 | namespace ops { |
31 | namespace builtin { |
32 | namespace comparisons { |
33 | namespace { |
34 | |
// Positions of the tensors within the node's input and output lists. Every
// comparison kernel consumes two operands and produces one result tensor.
constexpr int kInputTensor1{0};
constexpr int kInputTensor2{1};
constexpr int kOutputTensor{0};
38 | |
39 | TfLiteStatus ComparisonPrepareCommon(TfLiteContext* context, TfLiteNode* node, |
40 | bool is_string_allowed) { |
41 | TF_LITE_ENSURE_EQ(context, NumInputs(node), 2); |
42 | TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1); |
43 | |
44 | const TfLiteTensor* input1; |
45 | TF_LITE_ENSURE_OK(context, |
46 | GetInputSafe(context, node, kInputTensor1, &input1)); |
47 | const TfLiteTensor* input2; |
48 | TF_LITE_ENSURE_OK(context, |
49 | GetInputSafe(context, node, kInputTensor2, &input2)); |
50 | TfLiteTensor* output; |
51 | TF_LITE_ENSURE_OK(context, |
52 | GetOutputSafe(context, node, kOutputTensor, &output)); |
53 | |
54 | // Don't support string. |
55 | if (!is_string_allowed) { |
56 | TF_LITE_ENSURE(context, input1->type != kTfLiteString); |
57 | } |
58 | // Currently only support tensors have the same type. |
59 | TF_LITE_ENSURE_TYPES_EQ(context, input1->type, input2->type); |
60 | output->type = kTfLiteBool; |
61 | |
62 | bool requires_broadcast = !HaveSameShapes(input1, input2); |
63 | |
64 | TfLiteIntArray* output_size = nullptr; |
65 | if (requires_broadcast) { |
66 | TF_LITE_ENSURE_OK(context, CalculateShapeForBroadcast( |
67 | context, input1, input2, &output_size)); |
68 | } else { |
69 | output_size = TfLiteIntArrayCopy(input1->dims); |
70 | } |
71 | |
72 | return context->ResizeTensor(context, output, output_size); |
73 | } |
74 | |
75 | TfLiteStatus ComparisonPrepare(TfLiteContext* context, TfLiteNode* node) { |
76 | return ComparisonPrepareCommon(context, node, false); |
77 | } |
78 | |
79 | TfLiteStatus ComparisonPrepareStringAllowed(TfLiteContext* context, |
80 | TfLiteNode* node) { |
81 | return ComparisonPrepareCommon(context, node, true); |
82 | } |
83 | |
84 | void QuantizeMultiplier(double double_multiplier, int32_t* quantized_multiplier, |
85 | int* left_shift) { |
86 | if (double_multiplier < 1.0) { |
87 | QuantizeMultiplierSmallerThanOneExp(double_multiplier, quantized_multiplier, |
88 | left_shift); |
89 | } else { |
90 | QuantizeMultiplierGreaterThanOne(double_multiplier, quantized_multiplier, |
91 | left_shift); |
92 | } |
93 | } |
94 | |
95 | template <typename input_dtype, reference_ops::ComparisonFn<int32> opname> |
96 | void ComparisonQuantized(const TfLiteTensor* input1, const TfLiteTensor* input2, |
97 | TfLiteTensor* output, bool requires_broadcast) { |
98 | if (input1->type == kTfLiteUInt8 || input1->type == kTfLiteInt8) { |
99 | auto input1_offset = -input1->params.zero_point; |
100 | auto input2_offset = -input2->params.zero_point; |
101 | const int left_shift = 8; |
102 | |
103 | int32 input1_multiplier; |
104 | int32 input2_multiplier; |
105 | int input1_shift; |
106 | int input2_shift; |
107 | QuantizeMultiplier(input1->params.scale, &input1_multiplier, &input1_shift); |
108 | QuantizeMultiplier(input2->params.scale, &input2_multiplier, &input2_shift); |
109 | |
110 | ComparisonParams op_params; |
111 | op_params.left_shift = left_shift; |
112 | op_params.input1_offset = input1_offset; |
113 | op_params.input1_multiplier = input1_multiplier; |
114 | op_params.input1_shift = input1_shift; |
115 | op_params.input2_offset = input2_offset; |
116 | op_params.input2_multiplier = input2_multiplier; |
117 | op_params.input2_shift = input2_shift; |
118 | if (requires_broadcast) { |
119 | reference_ops::BroadcastComparison4DSlowWithScaling<input_dtype, opname>( |
120 | op_params, GetTensorShape(input1), GetTensorData<input_dtype>(input1), |
121 | GetTensorShape(input2), GetTensorData<input_dtype>(input2), |
122 | GetTensorShape(output), GetTensorData<bool>(output)); |
123 | } else { |
124 | reference_ops::ComparisonWithScaling<input_dtype, opname>( |
125 | op_params, GetTensorShape(input1), GetTensorData<input_dtype>(input1), |
126 | GetTensorShape(input2), GetTensorData<input_dtype>(input2), |
127 | GetTensorShape(output), GetTensorData<bool>(output)); |
128 | } |
129 | } |
130 | } |
131 | |
132 | template <typename T, reference_ops::ComparisonFn<T> opname> |
133 | void Comparison(const TfLiteTensor* input1, const TfLiteTensor* input2, |
134 | TfLiteTensor* output, bool requires_broadcast) { |
135 | ComparisonParams op_params; |
136 | requires_broadcast |
137 | ? reference_ops::BroadcastComparison4DSlowImpl<T, opname>( |
138 | op_params, GetTensorShape(input1), GetTensorData<T>(input1), |
139 | GetTensorShape(input2), GetTensorData<T>(input2), |
140 | GetTensorShape(output), GetTensorData<bool>(output)) |
141 | : reference_ops::ComparisonImpl<T, opname>( |
142 | op_params, GetTensorShape(input1), GetTensorData<T>(input1), |
143 | GetTensorShape(input2), GetTensorData<T>(input2), |
144 | GetTensorShape(output), GetTensorData<bool>(output)); |
145 | } |
146 | |
147 | void ComparisonString(bool (*opname)(const StringRef&, const StringRef&), |
148 | const TfLiteTensor* input1, const TfLiteTensor* input2, |
149 | TfLiteTensor* output, bool requires_broadcast) { |
150 | bool* output_data = GetTensorData<bool>(output); |
151 | if (requires_broadcast) { |
152 | reference_ops::BroadcastComparison4DSlowStringImpl( |
153 | opname, GetTensorShape(input1), input1, GetTensorShape(input2), input2, |
154 | GetTensorShape(output), output_data); |
155 | } else { |
156 | reference_ops::ComparisonStringImpl(opname, GetTensorShape(input1), input1, |
157 | GetTensorShape(input2), input2, |
158 | GetTensorShape(output), output_data); |
159 | } |
160 | } |
161 | |
162 | TfLiteStatus EqualEval(TfLiteContext* context, TfLiteNode* node) { |
163 | const TfLiteTensor* input1; |
164 | TF_LITE_ENSURE_OK(context, |
165 | GetInputSafe(context, node, kInputTensor1, &input1)); |
166 | const TfLiteTensor* input2; |
167 | TF_LITE_ENSURE_OK(context, |
168 | GetInputSafe(context, node, kInputTensor2, &input2)); |
169 | TfLiteTensor* output; |
170 | TF_LITE_ENSURE_OK(context, |
171 | GetOutputSafe(context, node, kOutputTensor, &output)); |
172 | bool requires_broadcast = !HaveSameShapes(input1, input2); |
173 | switch (input1->type) { |
174 | case kTfLiteBool: |
175 | Comparison<bool, reference_ops::EqualFn>(input1, input2, output, |
176 | requires_broadcast); |
177 | break; |
178 | case kTfLiteFloat32: |
179 | Comparison<float, reference_ops::EqualFn>(input1, input2, output, |
180 | requires_broadcast); |
181 | break; |
182 | case kTfLiteInt32: |
183 | Comparison<int32_t, reference_ops::EqualFn>(input1, input2, output, |
184 | requires_broadcast); |
185 | break; |
186 | case kTfLiteInt64: |
187 | Comparison<int64_t, reference_ops::EqualFn>(input1, input2, output, |
188 | requires_broadcast); |
189 | break; |
190 | case kTfLiteUInt8: |
191 | ComparisonQuantized<uint8_t, reference_ops::EqualFn>( |
192 | input1, input2, output, requires_broadcast); |
193 | break; |
194 | case kTfLiteInt8: |
195 | ComparisonQuantized<int8_t, reference_ops::EqualFn>( |
196 | input1, input2, output, requires_broadcast); |
197 | break; |
198 | case kTfLiteString: |
199 | ComparisonString(reference_ops::StringRefEqualFn, input1, input2, output, |
200 | requires_broadcast); |
201 | break; |
202 | default: |
203 | TF_LITE_KERNEL_LOG( |
204 | context, |
205 | "Does not support type %d, requires bool|float|int|uint8|string" , |
206 | input1->type); |
207 | return kTfLiteError; |
208 | } |
209 | return kTfLiteOk; |
210 | } |
211 | |
212 | TfLiteStatus NotEqualEval(TfLiteContext* context, TfLiteNode* node) { |
213 | const TfLiteTensor* input1; |
214 | TF_LITE_ENSURE_OK(context, |
215 | GetInputSafe(context, node, kInputTensor1, &input1)); |
216 | const TfLiteTensor* input2; |
217 | TF_LITE_ENSURE_OK(context, |
218 | GetInputSafe(context, node, kInputTensor2, &input2)); |
219 | TfLiteTensor* output; |
220 | TF_LITE_ENSURE_OK(context, |
221 | GetOutputSafe(context, node, kOutputTensor, &output)); |
222 | bool requires_broadcast = !HaveSameShapes(input1, input2); |
223 | switch (input1->type) { |
224 | case kTfLiteBool: |
225 | Comparison<bool, reference_ops::NotEqualFn>(input1, input2, output, |
226 | requires_broadcast); |
227 | break; |
228 | case kTfLiteFloat32: |
229 | Comparison<float, reference_ops::NotEqualFn>(input1, input2, output, |
230 | requires_broadcast); |
231 | break; |
232 | case kTfLiteInt32: |
233 | Comparison<int32_t, reference_ops::NotEqualFn>(input1, input2, output, |
234 | requires_broadcast); |
235 | break; |
236 | case kTfLiteInt64: |
237 | Comparison<int64_t, reference_ops::NotEqualFn>(input1, input2, output, |
238 | requires_broadcast); |
239 | break; |
240 | case kTfLiteUInt8: |
241 | ComparisonQuantized<uint8_t, reference_ops::NotEqualFn>( |
242 | input1, input2, output, requires_broadcast); |
243 | break; |
244 | case kTfLiteInt8: |
245 | ComparisonQuantized<int8_t, reference_ops::NotEqualFn>( |
246 | input1, input2, output, requires_broadcast); |
247 | break; |
248 | case kTfLiteString: |
249 | ComparisonString(reference_ops::StringRefNotEqualFn, input1, input2, |
250 | output, requires_broadcast); |
251 | break; |
252 | default: |
253 | TF_LITE_KERNEL_LOG( |
254 | context, |
255 | "Does not support type %d, requires bool|float|int|uint8|string" , |
256 | input1->type); |
257 | return kTfLiteError; |
258 | } |
259 | return kTfLiteOk; |
260 | } |
261 | |
262 | TfLiteStatus GreaterEval(TfLiteContext* context, TfLiteNode* node) { |
263 | const TfLiteTensor* input1; |
264 | TF_LITE_ENSURE_OK(context, |
265 | GetInputSafe(context, node, kInputTensor1, &input1)); |
266 | const TfLiteTensor* input2; |
267 | TF_LITE_ENSURE_OK(context, |
268 | GetInputSafe(context, node, kInputTensor2, &input2)); |
269 | TfLiteTensor* output; |
270 | TF_LITE_ENSURE_OK(context, |
271 | GetOutputSafe(context, node, kOutputTensor, &output)); |
272 | bool requires_broadcast = !HaveSameShapes(input1, input2); |
273 | switch (input1->type) { |
274 | case kTfLiteFloat32: |
275 | Comparison<float, reference_ops::GreaterFn>(input1, input2, output, |
276 | requires_broadcast); |
277 | break; |
278 | case kTfLiteInt32: |
279 | Comparison<int32_t, reference_ops::GreaterFn>(input1, input2, output, |
280 | requires_broadcast); |
281 | break; |
282 | case kTfLiteInt64: |
283 | Comparison<int64_t, reference_ops::GreaterFn>(input1, input2, output, |
284 | requires_broadcast); |
285 | break; |
286 | case kTfLiteUInt8: |
287 | ComparisonQuantized<uint8_t, reference_ops::GreaterFn>( |
288 | input1, input2, output, requires_broadcast); |
289 | break; |
290 | case kTfLiteInt8: |
291 | ComparisonQuantized<int8_t, reference_ops::GreaterFn>( |
292 | input1, input2, output, requires_broadcast); |
293 | break; |
294 | default: |
295 | TF_LITE_KERNEL_LOG(context, |
296 | "Does not support type %d, requires float|int|uint8" , |
297 | input1->type); |
298 | return kTfLiteError; |
299 | } |
300 | return kTfLiteOk; |
301 | } |
302 | |
303 | TfLiteStatus GreaterEqualEval(TfLiteContext* context, TfLiteNode* node) { |
304 | const TfLiteTensor* input1; |
305 | TF_LITE_ENSURE_OK(context, |
306 | GetInputSafe(context, node, kInputTensor1, &input1)); |
307 | const TfLiteTensor* input2; |
308 | TF_LITE_ENSURE_OK(context, |
309 | GetInputSafe(context, node, kInputTensor2, &input2)); |
310 | TfLiteTensor* output; |
311 | TF_LITE_ENSURE_OK(context, |
312 | GetOutputSafe(context, node, kOutputTensor, &output)); |
313 | bool requires_broadcast = !HaveSameShapes(input1, input2); |
314 | switch (input1->type) { |
315 | case kTfLiteFloat32: |
316 | Comparison<float, reference_ops::GreaterEqualFn>(input1, input2, output, |
317 | requires_broadcast); |
318 | break; |
319 | case kTfLiteInt32: |
320 | Comparison<int32_t, reference_ops::GreaterEqualFn>(input1, input2, output, |
321 | requires_broadcast); |
322 | break; |
323 | case kTfLiteInt64: |
324 | Comparison<int64_t, reference_ops::GreaterEqualFn>(input1, input2, output, |
325 | requires_broadcast); |
326 | break; |
327 | case kTfLiteUInt8: |
328 | ComparisonQuantized<uint8_t, reference_ops::GreaterEqualFn>( |
329 | input1, input2, output, requires_broadcast); |
330 | break; |
331 | case kTfLiteInt8: |
332 | ComparisonQuantized<int8_t, reference_ops::GreaterEqualFn>( |
333 | input1, input2, output, requires_broadcast); |
334 | break; |
335 | default: |
336 | TF_LITE_KERNEL_LOG(context, |
337 | "Does not support type %d, requires float|int|uint8" , |
338 | input1->type); |
339 | return kTfLiteError; |
340 | } |
341 | return kTfLiteOk; |
342 | } |
343 | |
344 | TfLiteStatus LessEval(TfLiteContext* context, TfLiteNode* node) { |
345 | const TfLiteTensor* input1; |
346 | TF_LITE_ENSURE_OK(context, |
347 | GetInputSafe(context, node, kInputTensor1, &input1)); |
348 | const TfLiteTensor* input2; |
349 | TF_LITE_ENSURE_OK(context, |
350 | GetInputSafe(context, node, kInputTensor2, &input2)); |
351 | TfLiteTensor* output; |
352 | TF_LITE_ENSURE_OK(context, |
353 | GetOutputSafe(context, node, kOutputTensor, &output)); |
354 | bool requires_broadcast = !HaveSameShapes(input1, input2); |
355 | switch (input1->type) { |
356 | case kTfLiteFloat32: |
357 | Comparison<float, reference_ops::LessFn>(input1, input2, output, |
358 | requires_broadcast); |
359 | break; |
360 | case kTfLiteInt32: |
361 | Comparison<int32_t, reference_ops::LessFn>(input1, input2, output, |
362 | requires_broadcast); |
363 | break; |
364 | case kTfLiteInt64: |
365 | Comparison<int64_t, reference_ops::LessFn>(input1, input2, output, |
366 | requires_broadcast); |
367 | break; |
368 | case kTfLiteUInt8: |
369 | ComparisonQuantized<uint8_t, reference_ops::LessFn>( |
370 | input1, input2, output, requires_broadcast); |
371 | break; |
372 | case kTfLiteInt8: |
373 | ComparisonQuantized<int8_t, reference_ops::LessFn>(input1, input2, output, |
374 | requires_broadcast); |
375 | break; |
376 | default: |
377 | TF_LITE_KERNEL_LOG(context, |
378 | "Does not support type %d, requires float|int|uint8" , |
379 | input1->type); |
380 | return kTfLiteError; |
381 | } |
382 | return kTfLiteOk; |
383 | } |
384 | |
385 | TfLiteStatus LessEqualEval(TfLiteContext* context, TfLiteNode* node) { |
386 | const TfLiteTensor* input1; |
387 | TF_LITE_ENSURE_OK(context, |
388 | GetInputSafe(context, node, kInputTensor1, &input1)); |
389 | const TfLiteTensor* input2; |
390 | TF_LITE_ENSURE_OK(context, |
391 | GetInputSafe(context, node, kInputTensor2, &input2)); |
392 | TfLiteTensor* output; |
393 | TF_LITE_ENSURE_OK(context, |
394 | GetOutputSafe(context, node, kOutputTensor, &output)); |
395 | bool requires_broadcast = !HaveSameShapes(input1, input2); |
396 | switch (input1->type) { |
397 | case kTfLiteFloat32: |
398 | Comparison<float, reference_ops::LessEqualFn>(input1, input2, output, |
399 | requires_broadcast); |
400 | break; |
401 | case kTfLiteInt32: |
402 | Comparison<int32_t, reference_ops::LessEqualFn>(input1, input2, output, |
403 | requires_broadcast); |
404 | break; |
405 | case kTfLiteInt64: |
406 | Comparison<int64_t, reference_ops::LessEqualFn>(input1, input2, output, |
407 | requires_broadcast); |
408 | break; |
409 | case kTfLiteUInt8: |
410 | ComparisonQuantized<uint8_t, reference_ops::LessEqualFn>( |
411 | input1, input2, output, requires_broadcast); |
412 | break; |
413 | case kTfLiteInt8: |
414 | ComparisonQuantized<int8_t, reference_ops::LessEqualFn>( |
415 | input1, input2, output, requires_broadcast); |
416 | break; |
417 | default: |
418 | TF_LITE_KERNEL_LOG(context, |
419 | "Does not support type %d, requires float|int|uint8" , |
420 | input1->type); |
421 | return kTfLiteError; |
422 | } |
423 | return kTfLiteOk; |
424 | } |
425 | |
426 | } // namespace |
427 | } // namespace comparisons |
428 | |
429 | TfLiteRegistration* Register_EQUAL() { |
430 | static TfLiteRegistration r = {nullptr, nullptr, |
431 | comparisons::ComparisonPrepareStringAllowed, |
432 | comparisons::EqualEval}; |
433 | return &r; |
434 | } |
435 | |
436 | TfLiteRegistration* Register_NOT_EQUAL() { |
437 | static TfLiteRegistration r = {nullptr, nullptr, |
438 | comparisons::ComparisonPrepareStringAllowed, |
439 | comparisons::NotEqualEval}; |
440 | return &r; |
441 | } |
442 | |
443 | TfLiteRegistration* Register_GREATER() { |
444 | static TfLiteRegistration r = {nullptr, nullptr, |
445 | comparisons::ComparisonPrepare, |
446 | comparisons::GreaterEval}; |
447 | return &r; |
448 | } |
449 | |
450 | TfLiteRegistration* Register_GREATER_EQUAL() { |
451 | static TfLiteRegistration r = {nullptr, nullptr, |
452 | comparisons::ComparisonPrepare, |
453 | comparisons::GreaterEqualEval}; |
454 | return &r; |
455 | } |
456 | |
457 | TfLiteRegistration* Register_LESS() { |
458 | static TfLiteRegistration r = { |
459 | nullptr, nullptr, comparisons::ComparisonPrepare, comparisons::LessEval}; |
460 | return &r; |
461 | } |
462 | |
463 | TfLiteRegistration* Register_LESS_EQUAL() { |
464 | static TfLiteRegistration r = {nullptr, nullptr, |
465 | comparisons::ComparisonPrepare, |
466 | comparisons::LessEqualEval}; |
467 | return &r; |
468 | } |
469 | |
470 | } // namespace builtin |
471 | } // namespace ops |
472 | } // namespace tflite |
473 | |