1/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2
3Licensed under the Apache License, Version 2.0 (the "License");
4you may not use this file except in compliance with the License.
5You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9Unless required by applicable law or agreed to in writing, software
10distributed under the License is distributed on an "AS IS" BASIS,
11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12See the License for the specific language governing permissions and
13limitations under the License.
14==============================================================================*/
15#include "tensorflow/lite/kernels/internal/reference/comparisons.h"
16
17#include <stdint.h>
18
19#include "tensorflow/lite/c/common.h"
20#include "tensorflow/lite/kernels/internal/compatibility.h"
21#include "tensorflow/lite/kernels/internal/quantization_util.h"
22#include "tensorflow/lite/kernels/internal/reference/reference_ops.h"
23#include "tensorflow/lite/kernels/internal/tensor.h"
24#include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
25#include "tensorflow/lite/kernels/internal/types.h"
26#include "tensorflow/lite/kernels/kernel_util.h"
27#include "tensorflow/lite/string_util.h"
28
29namespace tflite {
30namespace ops {
31namespace builtin {
32namespace comparisons {
33namespace {
34
// Positions of the two operands in the node's input list.
constexpr int kInputTensor1{0};
constexpr int kInputTensor2{1};
// Position of the comparison result in the node's output list.
constexpr int kOutputTensor{0};
38
39TfLiteStatus ComparisonPrepareCommon(TfLiteContext* context, TfLiteNode* node,
40 bool is_string_allowed) {
41 TF_LITE_ENSURE_EQ(context, NumInputs(node), 2);
42 TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
43
44 const TfLiteTensor* input1;
45 TF_LITE_ENSURE_OK(context,
46 GetInputSafe(context, node, kInputTensor1, &input1));
47 const TfLiteTensor* input2;
48 TF_LITE_ENSURE_OK(context,
49 GetInputSafe(context, node, kInputTensor2, &input2));
50 TfLiteTensor* output;
51 TF_LITE_ENSURE_OK(context,
52 GetOutputSafe(context, node, kOutputTensor, &output));
53
54 // Don't support string.
55 if (!is_string_allowed) {
56 TF_LITE_ENSURE(context, input1->type != kTfLiteString);
57 }
58 // Currently only support tensors have the same type.
59 TF_LITE_ENSURE_TYPES_EQ(context, input1->type, input2->type);
60 output->type = kTfLiteBool;
61
62 bool requires_broadcast = !HaveSameShapes(input1, input2);
63
64 TfLiteIntArray* output_size = nullptr;
65 if (requires_broadcast) {
66 TF_LITE_ENSURE_OK(context, CalculateShapeForBroadcast(
67 context, input1, input2, &output_size));
68 } else {
69 output_size = TfLiteIntArrayCopy(input1->dims);
70 }
71
72 return context->ResizeTensor(context, output, output_size);
73}
74
75TfLiteStatus ComparisonPrepare(TfLiteContext* context, TfLiteNode* node) {
76 return ComparisonPrepareCommon(context, node, false);
77}
78
79TfLiteStatus ComparisonPrepareStringAllowed(TfLiteContext* context,
80 TfLiteNode* node) {
81 return ComparisonPrepareCommon(context, node, true);
82}
83
84void QuantizeMultiplier(double double_multiplier, int32_t* quantized_multiplier,
85 int* left_shift) {
86 if (double_multiplier < 1.0) {
87 QuantizeMultiplierSmallerThanOneExp(double_multiplier, quantized_multiplier,
88 left_shift);
89 } else {
90 QuantizeMultiplierGreaterThanOne(double_multiplier, quantized_multiplier,
91 left_shift);
92 }
93}
94
95template <typename input_dtype, reference_ops::ComparisonFn<int32> opname>
96void ComparisonQuantized(const TfLiteTensor* input1, const TfLiteTensor* input2,
97 TfLiteTensor* output, bool requires_broadcast) {
98 if (input1->type == kTfLiteUInt8 || input1->type == kTfLiteInt8) {
99 auto input1_offset = -input1->params.zero_point;
100 auto input2_offset = -input2->params.zero_point;
101 const int left_shift = 8;
102
103 int32 input1_multiplier;
104 int32 input2_multiplier;
105 int input1_shift;
106 int input2_shift;
107 QuantizeMultiplier(input1->params.scale, &input1_multiplier, &input1_shift);
108 QuantizeMultiplier(input2->params.scale, &input2_multiplier, &input2_shift);
109
110 ComparisonParams op_params;
111 op_params.left_shift = left_shift;
112 op_params.input1_offset = input1_offset;
113 op_params.input1_multiplier = input1_multiplier;
114 op_params.input1_shift = input1_shift;
115 op_params.input2_offset = input2_offset;
116 op_params.input2_multiplier = input2_multiplier;
117 op_params.input2_shift = input2_shift;
118 if (requires_broadcast) {
119 reference_ops::BroadcastComparison4DSlowWithScaling<input_dtype, opname>(
120 op_params, GetTensorShape(input1), GetTensorData<input_dtype>(input1),
121 GetTensorShape(input2), GetTensorData<input_dtype>(input2),
122 GetTensorShape(output), GetTensorData<bool>(output));
123 } else {
124 reference_ops::ComparisonWithScaling<input_dtype, opname>(
125 op_params, GetTensorShape(input1), GetTensorData<input_dtype>(input1),
126 GetTensorShape(input2), GetTensorData<input_dtype>(input2),
127 GetTensorShape(output), GetTensorData<bool>(output));
128 }
129 }
130}
131
132template <typename T, reference_ops::ComparisonFn<T> opname>
133void Comparison(const TfLiteTensor* input1, const TfLiteTensor* input2,
134 TfLiteTensor* output, bool requires_broadcast) {
135 ComparisonParams op_params;
136 requires_broadcast
137 ? reference_ops::BroadcastComparison4DSlowImpl<T, opname>(
138 op_params, GetTensorShape(input1), GetTensorData<T>(input1),
139 GetTensorShape(input2), GetTensorData<T>(input2),
140 GetTensorShape(output), GetTensorData<bool>(output))
141 : reference_ops::ComparisonImpl<T, opname>(
142 op_params, GetTensorShape(input1), GetTensorData<T>(input1),
143 GetTensorShape(input2), GetTensorData<T>(input2),
144 GetTensorShape(output), GetTensorData<bool>(output));
145}
146
147void ComparisonString(bool (*opname)(const StringRef&, const StringRef&),
148 const TfLiteTensor* input1, const TfLiteTensor* input2,
149 TfLiteTensor* output, bool requires_broadcast) {
150 bool* output_data = GetTensorData<bool>(output);
151 if (requires_broadcast) {
152 reference_ops::BroadcastComparison4DSlowStringImpl(
153 opname, GetTensorShape(input1), input1, GetTensorShape(input2), input2,
154 GetTensorShape(output), output_data);
155 } else {
156 reference_ops::ComparisonStringImpl(opname, GetTensorShape(input1), input1,
157 GetTensorShape(input2), input2,
158 GetTensorShape(output), output_data);
159 }
160}
161
162TfLiteStatus EqualEval(TfLiteContext* context, TfLiteNode* node) {
163 const TfLiteTensor* input1;
164 TF_LITE_ENSURE_OK(context,
165 GetInputSafe(context, node, kInputTensor1, &input1));
166 const TfLiteTensor* input2;
167 TF_LITE_ENSURE_OK(context,
168 GetInputSafe(context, node, kInputTensor2, &input2));
169 TfLiteTensor* output;
170 TF_LITE_ENSURE_OK(context,
171 GetOutputSafe(context, node, kOutputTensor, &output));
172 bool requires_broadcast = !HaveSameShapes(input1, input2);
173 switch (input1->type) {
174 case kTfLiteBool:
175 Comparison<bool, reference_ops::EqualFn>(input1, input2, output,
176 requires_broadcast);
177 break;
178 case kTfLiteFloat32:
179 Comparison<float, reference_ops::EqualFn>(input1, input2, output,
180 requires_broadcast);
181 break;
182 case kTfLiteInt32:
183 Comparison<int32_t, reference_ops::EqualFn>(input1, input2, output,
184 requires_broadcast);
185 break;
186 case kTfLiteInt64:
187 Comparison<int64_t, reference_ops::EqualFn>(input1, input2, output,
188 requires_broadcast);
189 break;
190 case kTfLiteUInt8:
191 ComparisonQuantized<uint8_t, reference_ops::EqualFn>(
192 input1, input2, output, requires_broadcast);
193 break;
194 case kTfLiteInt8:
195 ComparisonQuantized<int8_t, reference_ops::EqualFn>(
196 input1, input2, output, requires_broadcast);
197 break;
198 case kTfLiteString:
199 ComparisonString(reference_ops::StringRefEqualFn, input1, input2, output,
200 requires_broadcast);
201 break;
202 default:
203 TF_LITE_KERNEL_LOG(
204 context,
205 "Does not support type %d, requires bool|float|int|uint8|string",
206 input1->type);
207 return kTfLiteError;
208 }
209 return kTfLiteOk;
210}
211
212TfLiteStatus NotEqualEval(TfLiteContext* context, TfLiteNode* node) {
213 const TfLiteTensor* input1;
214 TF_LITE_ENSURE_OK(context,
215 GetInputSafe(context, node, kInputTensor1, &input1));
216 const TfLiteTensor* input2;
217 TF_LITE_ENSURE_OK(context,
218 GetInputSafe(context, node, kInputTensor2, &input2));
219 TfLiteTensor* output;
220 TF_LITE_ENSURE_OK(context,
221 GetOutputSafe(context, node, kOutputTensor, &output));
222 bool requires_broadcast = !HaveSameShapes(input1, input2);
223 switch (input1->type) {
224 case kTfLiteBool:
225 Comparison<bool, reference_ops::NotEqualFn>(input1, input2, output,
226 requires_broadcast);
227 break;
228 case kTfLiteFloat32:
229 Comparison<float, reference_ops::NotEqualFn>(input1, input2, output,
230 requires_broadcast);
231 break;
232 case kTfLiteInt32:
233 Comparison<int32_t, reference_ops::NotEqualFn>(input1, input2, output,
234 requires_broadcast);
235 break;
236 case kTfLiteInt64:
237 Comparison<int64_t, reference_ops::NotEqualFn>(input1, input2, output,
238 requires_broadcast);
239 break;
240 case kTfLiteUInt8:
241 ComparisonQuantized<uint8_t, reference_ops::NotEqualFn>(
242 input1, input2, output, requires_broadcast);
243 break;
244 case kTfLiteInt8:
245 ComparisonQuantized<int8_t, reference_ops::NotEqualFn>(
246 input1, input2, output, requires_broadcast);
247 break;
248 case kTfLiteString:
249 ComparisonString(reference_ops::StringRefNotEqualFn, input1, input2,
250 output, requires_broadcast);
251 break;
252 default:
253 TF_LITE_KERNEL_LOG(
254 context,
255 "Does not support type %d, requires bool|float|int|uint8|string",
256 input1->type);
257 return kTfLiteError;
258 }
259 return kTfLiteOk;
260}
261
262TfLiteStatus GreaterEval(TfLiteContext* context, TfLiteNode* node) {
263 const TfLiteTensor* input1;
264 TF_LITE_ENSURE_OK(context,
265 GetInputSafe(context, node, kInputTensor1, &input1));
266 const TfLiteTensor* input2;
267 TF_LITE_ENSURE_OK(context,
268 GetInputSafe(context, node, kInputTensor2, &input2));
269 TfLiteTensor* output;
270 TF_LITE_ENSURE_OK(context,
271 GetOutputSafe(context, node, kOutputTensor, &output));
272 bool requires_broadcast = !HaveSameShapes(input1, input2);
273 switch (input1->type) {
274 case kTfLiteFloat32:
275 Comparison<float, reference_ops::GreaterFn>(input1, input2, output,
276 requires_broadcast);
277 break;
278 case kTfLiteInt32:
279 Comparison<int32_t, reference_ops::GreaterFn>(input1, input2, output,
280 requires_broadcast);
281 break;
282 case kTfLiteInt64:
283 Comparison<int64_t, reference_ops::GreaterFn>(input1, input2, output,
284 requires_broadcast);
285 break;
286 case kTfLiteUInt8:
287 ComparisonQuantized<uint8_t, reference_ops::GreaterFn>(
288 input1, input2, output, requires_broadcast);
289 break;
290 case kTfLiteInt8:
291 ComparisonQuantized<int8_t, reference_ops::GreaterFn>(
292 input1, input2, output, requires_broadcast);
293 break;
294 default:
295 TF_LITE_KERNEL_LOG(context,
296 "Does not support type %d, requires float|int|uint8",
297 input1->type);
298 return kTfLiteError;
299 }
300 return kTfLiteOk;
301}
302
303TfLiteStatus GreaterEqualEval(TfLiteContext* context, TfLiteNode* node) {
304 const TfLiteTensor* input1;
305 TF_LITE_ENSURE_OK(context,
306 GetInputSafe(context, node, kInputTensor1, &input1));
307 const TfLiteTensor* input2;
308 TF_LITE_ENSURE_OK(context,
309 GetInputSafe(context, node, kInputTensor2, &input2));
310 TfLiteTensor* output;
311 TF_LITE_ENSURE_OK(context,
312 GetOutputSafe(context, node, kOutputTensor, &output));
313 bool requires_broadcast = !HaveSameShapes(input1, input2);
314 switch (input1->type) {
315 case kTfLiteFloat32:
316 Comparison<float, reference_ops::GreaterEqualFn>(input1, input2, output,
317 requires_broadcast);
318 break;
319 case kTfLiteInt32:
320 Comparison<int32_t, reference_ops::GreaterEqualFn>(input1, input2, output,
321 requires_broadcast);
322 break;
323 case kTfLiteInt64:
324 Comparison<int64_t, reference_ops::GreaterEqualFn>(input1, input2, output,
325 requires_broadcast);
326 break;
327 case kTfLiteUInt8:
328 ComparisonQuantized<uint8_t, reference_ops::GreaterEqualFn>(
329 input1, input2, output, requires_broadcast);
330 break;
331 case kTfLiteInt8:
332 ComparisonQuantized<int8_t, reference_ops::GreaterEqualFn>(
333 input1, input2, output, requires_broadcast);
334 break;
335 default:
336 TF_LITE_KERNEL_LOG(context,
337 "Does not support type %d, requires float|int|uint8",
338 input1->type);
339 return kTfLiteError;
340 }
341 return kTfLiteOk;
342}
343
344TfLiteStatus LessEval(TfLiteContext* context, TfLiteNode* node) {
345 const TfLiteTensor* input1;
346 TF_LITE_ENSURE_OK(context,
347 GetInputSafe(context, node, kInputTensor1, &input1));
348 const TfLiteTensor* input2;
349 TF_LITE_ENSURE_OK(context,
350 GetInputSafe(context, node, kInputTensor2, &input2));
351 TfLiteTensor* output;
352 TF_LITE_ENSURE_OK(context,
353 GetOutputSafe(context, node, kOutputTensor, &output));
354 bool requires_broadcast = !HaveSameShapes(input1, input2);
355 switch (input1->type) {
356 case kTfLiteFloat32:
357 Comparison<float, reference_ops::LessFn>(input1, input2, output,
358 requires_broadcast);
359 break;
360 case kTfLiteInt32:
361 Comparison<int32_t, reference_ops::LessFn>(input1, input2, output,
362 requires_broadcast);
363 break;
364 case kTfLiteInt64:
365 Comparison<int64_t, reference_ops::LessFn>(input1, input2, output,
366 requires_broadcast);
367 break;
368 case kTfLiteUInt8:
369 ComparisonQuantized<uint8_t, reference_ops::LessFn>(
370 input1, input2, output, requires_broadcast);
371 break;
372 case kTfLiteInt8:
373 ComparisonQuantized<int8_t, reference_ops::LessFn>(input1, input2, output,
374 requires_broadcast);
375 break;
376 default:
377 TF_LITE_KERNEL_LOG(context,
378 "Does not support type %d, requires float|int|uint8",
379 input1->type);
380 return kTfLiteError;
381 }
382 return kTfLiteOk;
383}
384
385TfLiteStatus LessEqualEval(TfLiteContext* context, TfLiteNode* node) {
386 const TfLiteTensor* input1;
387 TF_LITE_ENSURE_OK(context,
388 GetInputSafe(context, node, kInputTensor1, &input1));
389 const TfLiteTensor* input2;
390 TF_LITE_ENSURE_OK(context,
391 GetInputSafe(context, node, kInputTensor2, &input2));
392 TfLiteTensor* output;
393 TF_LITE_ENSURE_OK(context,
394 GetOutputSafe(context, node, kOutputTensor, &output));
395 bool requires_broadcast = !HaveSameShapes(input1, input2);
396 switch (input1->type) {
397 case kTfLiteFloat32:
398 Comparison<float, reference_ops::LessEqualFn>(input1, input2, output,
399 requires_broadcast);
400 break;
401 case kTfLiteInt32:
402 Comparison<int32_t, reference_ops::LessEqualFn>(input1, input2, output,
403 requires_broadcast);
404 break;
405 case kTfLiteInt64:
406 Comparison<int64_t, reference_ops::LessEqualFn>(input1, input2, output,
407 requires_broadcast);
408 break;
409 case kTfLiteUInt8:
410 ComparisonQuantized<uint8_t, reference_ops::LessEqualFn>(
411 input1, input2, output, requires_broadcast);
412 break;
413 case kTfLiteInt8:
414 ComparisonQuantized<int8_t, reference_ops::LessEqualFn>(
415 input1, input2, output, requires_broadcast);
416 break;
417 default:
418 TF_LITE_KERNEL_LOG(context,
419 "Does not support type %d, requires float|int|uint8",
420 input1->type);
421 return kTfLiteError;
422 }
423 return kTfLiteOk;
424}
425
426} // namespace
427} // namespace comparisons
428
429TfLiteRegistration* Register_EQUAL() {
430 static TfLiteRegistration r = {nullptr, nullptr,
431 comparisons::ComparisonPrepareStringAllowed,
432 comparisons::EqualEval};
433 return &r;
434}
435
436TfLiteRegistration* Register_NOT_EQUAL() {
437 static TfLiteRegistration r = {nullptr, nullptr,
438 comparisons::ComparisonPrepareStringAllowed,
439 comparisons::NotEqualEval};
440 return &r;
441}
442
443TfLiteRegistration* Register_GREATER() {
444 static TfLiteRegistration r = {nullptr, nullptr,
445 comparisons::ComparisonPrepare,
446 comparisons::GreaterEval};
447 return &r;
448}
449
450TfLiteRegistration* Register_GREATER_EQUAL() {
451 static TfLiteRegistration r = {nullptr, nullptr,
452 comparisons::ComparisonPrepare,
453 comparisons::GreaterEqualEval};
454 return &r;
455}
456
457TfLiteRegistration* Register_LESS() {
458 static TfLiteRegistration r = {
459 nullptr, nullptr, comparisons::ComparisonPrepare, comparisons::LessEval};
460 return &r;
461}
462
463TfLiteRegistration* Register_LESS_EQUAL() {
464 static TfLiteRegistration r = {nullptr, nullptr,
465 comparisons::ComparisonPrepare,
466 comparisons::LessEqualEval};
467 return &r;
468}
469
470} // namespace builtin
471} // namespace ops
472} // namespace tflite
473