1 | /* Copyright 2018 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | #include "tensorflow/lite/kernels/internal/reference/maximum_minimum.h" |
16 | |
17 | #include <stdint.h> |
18 | |
19 | #include "tensorflow/lite/c/common.h" |
20 | #include "tensorflow/lite/kernels/internal/compatibility.h" |
21 | #include "tensorflow/lite/kernels/internal/optimized/optimized_ops.h" |
22 | #include "tensorflow/lite/kernels/internal/reference/process_broadcast_shapes.h" |
23 | #include "tensorflow/lite/kernels/internal/reference/reference_ops.h" |
24 | #include "tensorflow/lite/kernels/internal/tensor.h" |
25 | #include "tensorflow/lite/kernels/internal/tensor_ctypes.h" |
26 | #include "tensorflow/lite/kernels/internal/types.h" |
27 | #include "tensorflow/lite/kernels/kernel_util.h" |
28 | |
29 | namespace tflite { |
30 | namespace ops { |
31 | namespace builtin { |
32 | namespace maximum_minimum { |
33 | |
34 | // This file has a reference implementation of TFMaximum/TFMinimum. |
// Selects which kernel family Eval dispatches to.
enum KernelType {
  kReference,         // Always uses the reference_ops slow path.
  kGenericOptimized,  // Uses the optimized_ops int8 specializations below;
                      // other dtypes still fall back to reference_ops.
};
39 | |
// Indices into the node's input/output tensor arrays.
constexpr int kInputTensor1 = 0;
constexpr int kInputTensor2 = 1;
constexpr int kOutputTensor = 0;
43 | |
// Bundles the tensors this op reads and writes, looked up once per call.
struct OpContext {
  OpContext(TfLiteContext* context, TfLiteNode* node) {
    // NOTE(review): GetInput/GetOutput may return nullptr for a malformed
    // model; members are dereferenced unchecked by Prepare/Eval — confirm
    // validation happens before these kernels run.
    input1 = GetInput(context, node, kInputTensor1);
    input2 = GetInput(context, node, kInputTensor2);
    output = GetOutput(context, node, kOutputTensor);
  }
  const TfLiteTensor* input1;  // First operand.
  const TfLiteTensor* input2;  // Second operand.
  TfLiteTensor* output;        // Destination tensor.
};
54 | |
55 | TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { |
56 | TF_LITE_ENSURE_EQ(context, NumInputs(node), 2); |
57 | TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1); |
58 | |
59 | OpContext op_context(context, node); |
60 | TF_LITE_ENSURE_TYPES_EQ(context, op_context.input1->type, |
61 | op_context.input2->type); |
62 | op_context.output->type = op_context.input1->type; |
63 | |
64 | bool requires_broadcast = |
65 | !HaveSameShapes(op_context.input1, op_context.input2); |
66 | |
67 | TfLiteIntArray* output_size = nullptr; |
68 | if (requires_broadcast) { |
69 | TF_LITE_ENSURE_OK( |
70 | context, CalculateShapeForBroadcast(context, op_context.input1, |
71 | op_context.input2, &output_size)); |
72 | } else { |
73 | output_size = TfLiteIntArrayCopy(op_context.input1->dims); |
74 | } |
75 | |
76 | return context->ResizeTensor(context, op_context.output, output_size); |
77 | } |
78 | |
// Element selector for TFMaximum: yields the larger of two values.
struct MaximumOp {
  template <typename data_type>
  static data_type op(data_type lhs, data_type rhs) {
    if (lhs > rhs) {
      return lhs;
    }
    return rhs;
  }
};
85 | |
// Element selector for TFMinimum: yields the smaller of two values.
struct MinimumOp {
  template <typename data_type>
  static data_type op(data_type lhs, data_type rhs) {
    if (lhs < rhs) {
      return lhs;
    }
    return rhs;
  }
};
92 | |
// Generic element-wise max/min evaluation over two tensors with arbitrary
// broadcasting, via the slow reference kernel. `kernel_type` only matters for
// the int8 specializations below; every other instantiation takes this path.
// `context` and `node` are unused here but keep the signature uniform with
// those specializations.
template <KernelType kernel_type, typename data_type, typename op_type>
void TFLiteOperation(TfLiteContext* context, TfLiteNode* node,
                     const OpContext& op_context) {
  reference_ops::MaximumMinimumBroadcastSlow(
      GetTensorShape(op_context.input1),
      GetTensorData<data_type>(op_context.input1),
      GetTensorShape(op_context.input2),
      GetTensorData<data_type>(op_context.input2),
      GetTensorShape(op_context.output),
      GetTensorData<data_type>(op_context.output),
      op_type::template op<data_type>);
}
105 | |
106 | // Maximum generic opt int8. |
107 | template <> |
108 | void TFLiteOperation<maximum_minimum::kGenericOptimized, int8, MaximumOp>( |
109 | TfLiteContext* context, TfLiteNode* node, const OpContext& op_context) { |
110 | tflite::ArithmeticParams op_params; |
111 | const bool need_broadcast = optimized_ops::ProcessBroadcastShapes( |
112 | GetTensorShape(op_context.input1), GetTensorShape(op_context.input2), |
113 | &op_params); |
114 | if (need_broadcast) { |
115 | optimized_ops::BroadcastMaximumDispatch( |
116 | op_params, GetTensorShape(op_context.input1), |
117 | GetTensorData<int8>(op_context.input1), |
118 | GetTensorShape(op_context.input2), |
119 | GetTensorData<int8>(op_context.input2), |
120 | GetTensorShape(op_context.output), |
121 | GetTensorData<int8>(op_context.output), MaximumOp::template op<int8>); |
122 | return; |
123 | } |
124 | reference_ops::MaximumMinimumBroadcastSlow( |
125 | GetTensorShape(op_context.input1), GetTensorData<int8>(op_context.input1), |
126 | GetTensorShape(op_context.input2), GetTensorData<int8>(op_context.input2), |
127 | GetTensorShape(op_context.output), GetTensorData<int8>(op_context.output), |
128 | MaximumOp::template op<int8>); |
129 | } |
130 | |
131 | // Minimum generic opt int8. |
132 | template <> |
133 | void TFLiteOperation<maximum_minimum::kGenericOptimized, int8, MinimumOp>( |
134 | TfLiteContext* context, TfLiteNode* node, const OpContext& op_context) { |
135 | tflite::ArithmeticParams op_params; |
136 | const bool need_broadcast = optimized_ops::ProcessBroadcastShapes( |
137 | GetTensorShape(op_context.input1), GetTensorShape(op_context.input2), |
138 | &op_params); |
139 | if (need_broadcast) { |
140 | optimized_ops::BroadcastMinimumDispatch( |
141 | op_params, GetTensorShape(op_context.input1), |
142 | GetTensorData<int8>(op_context.input1), |
143 | GetTensorShape(op_context.input2), |
144 | GetTensorData<int8>(op_context.input2), |
145 | GetTensorShape(op_context.output), |
146 | GetTensorData<int8>(op_context.output), MinimumOp::template op<int8>); |
147 | return; |
148 | } |
149 | reference_ops::MaximumMinimumBroadcastSlow( |
150 | GetTensorShape(op_context.input1), GetTensorData<int8>(op_context.input1), |
151 | GetTensorShape(op_context.input2), GetTensorData<int8>(op_context.input2), |
152 | GetTensorShape(op_context.output), GetTensorData<int8>(op_context.output), |
153 | MinimumOp::template op<int8>); |
154 | } |
155 | |
156 | template <KernelType kernel_type, typename OpType> |
157 | TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { |
158 | OpContext op_context(context, node); |
159 | |
160 | // If inputs have no element, shortcircuit. |
161 | if (NumElements(op_context.input1) == 0 || |
162 | NumElements(op_context.input2) == 0) { |
163 | return kTfLiteOk; |
164 | } |
165 | |
166 | switch (op_context.output->type) { |
167 | case kTfLiteFloat32: |
168 | TFLiteOperation<kernel_type, float, OpType>(context, node, op_context); |
169 | break; |
170 | case kTfLiteUInt8: |
171 | TFLiteOperation<kernel_type, uint8_t, OpType>(context, node, op_context); |
172 | break; |
173 | case kTfLiteInt8: |
174 | TFLiteOperation<kernel_type, int8_t, OpType>(context, node, op_context); |
175 | break; |
176 | case kTfLiteInt32: |
177 | TFLiteOperation<kernel_type, int32_t, OpType>(context, node, op_context); |
178 | break; |
179 | case kTfLiteInt64: |
180 | TFLiteOperation<kernel_type, int64_t, OpType>(context, node, op_context); |
181 | break; |
182 | case kTfLiteInt16: |
183 | TFLiteOperation<kernel_type, int16_t, OpType>(context, node, op_context); |
184 | break; |
185 | default: |
186 | TF_LITE_KERNEL_LOG(context, |
187 | "Type %d is currently not supported by Maximum." , |
188 | op_context.output->type); |
189 | return kTfLiteError; |
190 | } |
191 | return kTfLiteOk; |
192 | } |
193 | |
194 | } // namespace maximum_minimum |
195 | |
// Registration for MAXIMUM using reference kernels only.
TfLiteRegistration* Register_MAXIMUM_REF() {
  static TfLiteRegistration r = {
      nullptr, nullptr, maximum_minimum::Prepare,
      maximum_minimum::Eval<maximum_minimum::kReference,
                            maximum_minimum::MaximumOp>};
  return &r;
}
203 | |
// Registration for MAXIMUM using the generic optimized kernels (int8 fast
// path; other dtypes fall back to reference).
TfLiteRegistration* Register_MAXIMUM_GENERIC_OPT() {
  static TfLiteRegistration r = {
      nullptr, nullptr, maximum_minimum::Prepare,
      maximum_minimum::Eval<maximum_minimum::kGenericOptimized,
                            maximum_minimum::MaximumOp>};
  return &r;
}
211 | |
// Registration for MINIMUM using reference kernels only.
TfLiteRegistration* Register_MINIMUM_REF() {
  static TfLiteRegistration r = {
      nullptr, nullptr, maximum_minimum::Prepare,
      maximum_minimum::Eval<maximum_minimum::kReference,
                            maximum_minimum::MinimumOp>};
  return &r;
}
219 | |
// Registration for MINIMUM using the generic optimized kernels (int8 fast
// path; other dtypes fall back to reference).
TfLiteRegistration* Register_MINIMUM_GENERIC_OPT() {
  static TfLiteRegistration r = {
      nullptr, nullptr, maximum_minimum::Prepare,
      maximum_minimum::Eval<maximum_minimum::kGenericOptimized,
                            maximum_minimum::MinimumOp>};
  return &r;
}
227 | |
// Default MAXIMUM registration: the generic optimized variant.
TfLiteRegistration* Register_MAXIMUM() {
  return Register_MAXIMUM_GENERIC_OPT();
}
// Default MINIMUM registration: the generic optimized variant.
TfLiteRegistration* Register_MINIMUM() {
  return Register_MINIMUM_GENERIC_OPT();
}
234 | |
235 | } // namespace builtin |
236 | } // namespace ops |
237 | } // namespace tflite |
238 | |