/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/lite/kernels/internal/reference/maximum_minimum.h"

#include <stdint.h>

#include "tensorflow/lite/c/common.h"
#include "tensorflow/lite/kernels/internal/compatibility.h"
#include "tensorflow/lite/kernels/internal/optimized/optimized_ops.h"
#include "tensorflow/lite/kernels/internal/reference/process_broadcast_shapes.h"
#include "tensorflow/lite/kernels/internal/reference/reference_ops.h"
#include "tensorflow/lite/kernels/internal/tensor.h"
#include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
#include "tensorflow/lite/kernels/internal/types.h"
#include "tensorflow/lite/kernels/kernel_util.h"

namespace tflite {
namespace ops {
namespace builtin {
namespace maximum_minimum {

// This file contains the reference and generic optimized implementations of
// the TFLite Maximum/Minimum operators.
enum KernelType {
  kReference,
  kGenericOptimized,
};

constexpr int kInputTensor1 = 0;
constexpr int kInputTensor2 = 1;
constexpr int kOutputTensor = 0;

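// Gathers the node's input and output tensors once so the kernels below can
// access them directly.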
struct OpContext {
  OpContext(TfLiteContext* context, TfLiteNode* node) {
    input1 = GetInput(context, node, kInputTensor1);
    input2 = GetInput(context, node, kInputTensor2);
    output = GetOutput(context, node, kOutputTensor);
  }
  const TfLiteTensor* input1;
  const TfLiteTensor* input2;
  TfLiteTensor* output;
};

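// Checks that the node has exactly two inputs of the same type and a single
// output, then resizes the output tensor to the result shape (the broadcast
// shape when the input shapes differ).
//
// Illustrative example: with input1 of shape [2, 3] and input2 of shape
// [1, 3], the shapes differ, so CalculateShapeForBroadcast is used and the
// output is resized to [2, 3].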
TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
  TF_LITE_ENSURE_EQ(context, NumInputs(node), 2);
  TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);

  OpContext op_context(context, node);
  TF_LITE_ENSURE_TYPES_EQ(context, op_context.input1->type,
                          op_context.input2->type);
  op_context.output->type = op_context.input1->type;

  bool requires_broadcast =
      !HaveSameShapes(op_context.input1, op_context.input2);

  TfLiteIntArray* output_size = nullptr;
  if (requires_broadcast) {
    TF_LITE_ENSURE_OK(
        context, CalculateShapeForBroadcast(context, op_context.input1,
                                            op_context.input2, &output_size));
  } else {
    output_size = TfLiteIntArrayCopy(op_context.input1->dims);
  }

  return context->ResizeTensor(context, op_context.output, output_size);
}

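// Element-wise comparison functors plugged into the templated kernels below,
// e.g. MaximumOp::op<int32_t>(3, 7) returns 7 while MinimumOp::op<int32_t>(3, 7)
// returns 3.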
struct MaximumOp {
  template <typename data_type>
  static data_type op(data_type el1, data_type el2) {
    return el1 > el2 ? el1 : el2;
  }
};

struct MinimumOp {
  template <typename data_type>
  static data_type op(data_type el1, data_type el2) {
    return el1 < el2 ? el1 : el2;
  }
};

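// Generic path: the reference broadcast kernel handles both the broadcast and
// the same-shape case.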
template <KernelType kernel_type, typename data_type, typename op_type>
void TFLiteOperation(TfLiteContext* context, TfLiteNode* node,
                     const OpContext& op_context) {
  reference_ops::MaximumMinimumBroadcastSlow(
      GetTensorShape(op_context.input1),
      GetTensorData<data_type>(op_context.input1),
      GetTensorShape(op_context.input2),
      GetTensorData<data_type>(op_context.input2),
      GetTensorShape(op_context.output),
      GetTensorData<data_type>(op_context.output),
      op_type::template op<data_type>);
}

// Maximum, generic optimized path for int8: uses the optimized broadcast
// dispatch when the shapes require broadcasting, otherwise falls back to the
// reference kernel.
template <>
void TFLiteOperation<maximum_minimum::kGenericOptimized, int8, MaximumOp>(
    TfLiteContext* context, TfLiteNode* node, const OpContext& op_context) {
  tflite::ArithmeticParams op_params;
  const bool need_broadcast = optimized_ops::ProcessBroadcastShapes(
      GetTensorShape(op_context.input1), GetTensorShape(op_context.input2),
      &op_params);
  if (need_broadcast) {
    optimized_ops::BroadcastMaximumDispatch(
        op_params, GetTensorShape(op_context.input1),
        GetTensorData<int8>(op_context.input1),
        GetTensorShape(op_context.input2),
        GetTensorData<int8>(op_context.input2),
        GetTensorShape(op_context.output),
        GetTensorData<int8>(op_context.output), MaximumOp::template op<int8>);
    return;
  }
  reference_ops::MaximumMinimumBroadcastSlow(
      GetTensorShape(op_context.input1), GetTensorData<int8>(op_context.input1),
      GetTensorShape(op_context.input2), GetTensorData<int8>(op_context.input2),
      GetTensorShape(op_context.output), GetTensorData<int8>(op_context.output),
      MaximumOp::template op<int8>);
}

// Minimum, generic optimized path for int8: same structure as the Maximum
// specialization above.
template <>
void TFLiteOperation<maximum_minimum::kGenericOptimized, int8, MinimumOp>(
    TfLiteContext* context, TfLiteNode* node, const OpContext& op_context) {
  tflite::ArithmeticParams op_params;
  const bool need_broadcast = optimized_ops::ProcessBroadcastShapes(
      GetTensorShape(op_context.input1), GetTensorShape(op_context.input2),
      &op_params);
  if (need_broadcast) {
    optimized_ops::BroadcastMinimumDispatch(
        op_params, GetTensorShape(op_context.input1),
        GetTensorData<int8>(op_context.input1),
        GetTensorShape(op_context.input2),
        GetTensorData<int8>(op_context.input2),
        GetTensorShape(op_context.output),
        GetTensorData<int8>(op_context.output), MinimumOp::template op<int8>);
    return;
  }
  reference_ops::MaximumMinimumBroadcastSlow(
      GetTensorShape(op_context.input1), GetTensorData<int8>(op_context.input1),
      GetTensorShape(op_context.input2), GetTensorData<int8>(op_context.input2),
      GetTensorShape(op_context.output), GetTensorData<int8>(op_context.output),
      MinimumOp::template op<int8>);
}

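// Dispatches to the typed TFLiteOperation above based on the output tensor
// type, which Prepare has already set to match the input type.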
template <KernelType kernel_type, typename OpType>
TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
  OpContext op_context(context, node);

  // If either input has no elements, short-circuit: the output is empty too.
  if (NumElements(op_context.input1) == 0 ||
      NumElements(op_context.input2) == 0) {
    return kTfLiteOk;
  }

  switch (op_context.output->type) {
    case kTfLiteFloat32:
      TFLiteOperation<kernel_type, float, OpType>(context, node, op_context);
      break;
    case kTfLiteUInt8:
      TFLiteOperation<kernel_type, uint8_t, OpType>(context, node, op_context);
      break;
    case kTfLiteInt8:
      TFLiteOperation<kernel_type, int8_t, OpType>(context, node, op_context);
      break;
    case kTfLiteInt32:
      TFLiteOperation<kernel_type, int32_t, OpType>(context, node, op_context);
      break;
    case kTfLiteInt64:
      TFLiteOperation<kernel_type, int64_t, OpType>(context, node, op_context);
      break;
    case kTfLiteInt16:
      TFLiteOperation<kernel_type, int16_t, OpType>(context, node, op_context);
      break;
    default:
      TF_LITE_KERNEL_LOG(context,
                         "Type %d is currently not supported by "
                         "Maximum/Minimum.",
                         op_context.output->type);
      return kTfLiteError;
  }
  return kTfLiteOk;
}

}  // namespace maximum_minimum

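// Registration entry points: the *_REF variants always use the reference
// kernels, while the *_GENERIC_OPT variants also pick up the optimized int8
// specializations above.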
TfLiteRegistration* Register_MAXIMUM_REF() {
  static TfLiteRegistration r = {
      nullptr, nullptr, maximum_minimum::Prepare,
      maximum_minimum::Eval<maximum_minimum::kReference,
                            maximum_minimum::MaximumOp>};
  return &r;
}

TfLiteRegistration* Register_MAXIMUM_GENERIC_OPT() {
  static TfLiteRegistration r = {
      nullptr, nullptr, maximum_minimum::Prepare,
      maximum_minimum::Eval<maximum_minimum::kGenericOptimized,
                            maximum_minimum::MaximumOp>};
  return &r;
}

TfLiteRegistration* Register_MINIMUM_REF() {
  static TfLiteRegistration r = {
      nullptr, nullptr, maximum_minimum::Prepare,
      maximum_minimum::Eval<maximum_minimum::kReference,
                            maximum_minimum::MinimumOp>};
  return &r;
}

TfLiteRegistration* Register_MINIMUM_GENERIC_OPT() {
  static TfLiteRegistration r = {
      nullptr, nullptr, maximum_minimum::Prepare,
      maximum_minimum::Eval<maximum_minimum::kGenericOptimized,
                            maximum_minimum::MinimumOp>};
  return &r;
}

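// By default, Maximum and Minimum resolve to the generic optimized kernels.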
TfLiteRegistration* Register_MAXIMUM() {
  return Register_MAXIMUM_GENERIC_OPT();
}
TfLiteRegistration* Register_MINIMUM() {
  return Register_MINIMUM_GENERIC_OPT();
}

}  // namespace builtin
}  // namespace ops
}  // namespace tflite