1 | /* Copyright 2017 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | #include "tensorflow/lite/kernels/internal/optimized/integer_ops/mul.h" |
16 | |
17 | #include <stddef.h> |
18 | #include <stdint.h> |
19 | |
20 | #include <complex> |
21 | |
22 | #include "tensorflow/lite/c/builtin_op_data.h" |
23 | #include "tensorflow/lite/c/common.h" |
24 | #include "tensorflow/lite/kernels/internal/compatibility.h" |
25 | #include "tensorflow/lite/kernels/internal/optimized/cpu_check.h" |
26 | #include "tensorflow/lite/kernels/internal/optimized/neon_check.h" |
27 | #include "tensorflow/lite/kernels/internal/optimized/optimized_ops.h" |
28 | #include "tensorflow/lite/kernels/internal/quantization_util.h" |
29 | #include "tensorflow/lite/kernels/internal/reference/integer_ops/mul.h" |
30 | #include "tensorflow/lite/kernels/internal/reference/mul.h" |
31 | #include "tensorflow/lite/kernels/internal/reference/process_broadcast_shapes.h" |
32 | #include "tensorflow/lite/kernels/internal/reference/reference_ops.h" |
33 | #include "tensorflow/lite/kernels/internal/tensor.h" |
34 | #include "tensorflow/lite/kernels/internal/tensor_ctypes.h" |
35 | #include "tensorflow/lite/kernels/internal/types.h" |
36 | #include "tensorflow/lite/kernels/kernel_util.h" |
37 | |
38 | namespace tflite { |
39 | namespace ops { |
40 | namespace builtin { |
41 | namespace mul { |
42 | |
// This file has three implementations of Mul.
enum KernelType {
  kReference,         // Portable reference kernels.
  kGenericOptimized,  // Neon-free optimized kernels.
  kNeonOptimized,     // Kernels that may use NEON (selected when USE_NEON is defined).
};
49 | |
// Indices of this op's tensors within the node's input/output arrays.
constexpr int kInputTensor1 = 0;
constexpr int kInputTensor2 = 1;
constexpr int kOutputTensor = 0;
53 | |
// Per-node state computed once in Prepare() and reused by Eval().
// Note: uses the fixed-width int32_t consistently (the original mixed the
// legacy `int32` typedef and `int32_t` in the same struct).
struct OpData {
  // Activation clamp bounds, used in the quantized paths where the output is
  // 8bit (and in the int16 path).
  int32_t output_activation_min;
  int32_t output_activation_max;

  // Fixed-point representation of
  // input1_scale * input2_scale / output_scale, used in all quantized paths.
  int32_t output_multiplier;
  int output_shift;
};
63 | |
64 | void* Init(TfLiteContext* context, const char* buffer, size_t length) { |
65 | auto* data = new OpData; |
66 | return data; |
67 | } |
68 | |
69 | void Free(TfLiteContext* context, void* buffer) { |
70 | delete reinterpret_cast<OpData*>(buffer); |
71 | } |
72 | |
// Validates the node's tensors, precomputes the quantization parameters used
// by Eval(), and resizes the output (broadcasting the input shapes if they
// differ). Returns kTfLiteError via the TF_LITE_ENSURE_* macros on any
// validation failure.
TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
  auto* params = reinterpret_cast<TfLiteMulParams*>(node->builtin_data);
  OpData* data = reinterpret_cast<OpData*>(node->user_data);

  TF_LITE_ENSURE_EQ(context, NumInputs(node), 2);
  TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);

  const TfLiteTensor* input1;
  TF_LITE_ENSURE_OK(context,
                    GetInputSafe(context, node, kInputTensor1, &input1));
  const TfLiteTensor* input2;
  TF_LITE_ENSURE_OK(context,
                    GetInputSafe(context, node, kInputTensor2, &input2));
  TfLiteTensor* output;
  TF_LITE_ENSURE_OK(context,
                    GetOutputSafe(context, node, kOutputTensor, &output));

  // Both inputs must share an element type; the output type may differ (e.g.
  // int16 inputs with an 8-bit output are handled in EvalQuantized).
  TF_LITE_ENSURE_TYPES_EQ(context, input1->type, input2->type);

  const bool requires_broadcast = !HaveSameShapes(input1, input2);

  TfLiteIntArray* output_size = nullptr;
  if (requires_broadcast) {
    TF_LITE_ENSURE_OK(context, CalculateShapeForBroadcast(
                                   context, input1, input2, &output_size));
  } else {
    output_size = TfLiteIntArrayCopy(input1->dims);
  }

  if (output->type == kTfLiteUInt8 || output->type == kTfLiteInt8 ||
      output->type == kTfLiteInt16) {
    // Precompute the activation clamp and the fixed-point form of
    // input1_scale * input2_scale / output_scale so Eval() needs no
    // floating-point work on the quantized paths.
    TF_LITE_ENSURE_STATUS(CalculateActivationRangeQuantized(
        context, params->activation, output, &data->output_activation_min,
        &data->output_activation_max));
    double real_multiplier =
        input1->params.scale * input2->params.scale / output->params.scale;
    QuantizeMultiplier(real_multiplier, &data->output_multiplier,
                       &data->output_shift);
  }

  // ResizeTensor takes ownership of output_size.
  return context->ResizeTensor(context, output, output_size);
}
115 | |
// Evaluates Mul for the non-quantized output types (int32, float32, int64,
// complex64). Reference vs. optimized kernels are selected by kernel_type,
// and broadcasting variants are used when ProcessBroadcastShapes reports the
// input shapes differ.
template <KernelType kernel_type>
void EvalMul(TfLiteContext* context, TfLiteNode* node, TfLiteMulParams* params,
             const OpData* data, const TfLiteTensor* input1,
             const TfLiteTensor* input2, TfLiteTensor* output) {
  tflite::ArithmeticParams op_params;
  const bool need_broadcast = optimized_ops::ProcessBroadcastShapes(
      GetTensorShape(input1), GetTensorShape(input2), &op_params);
// Computes the fused-activation clamp for data_type, stores it in op_params,
// then invokes type::opname over the three tensors.
#define TF_LITE_MUL(type, opname, data_type)                             \
  data_type output_activation_min, output_activation_max;                \
  CalculateActivationRange(params->activation, &output_activation_min,   \
                           &output_activation_max);                      \
  SetActivationParams(output_activation_min, output_activation_max,      \
                      &op_params);                                       \
  type::opname(op_params, GetTensorShape(input1),                        \
               GetTensorData<data_type>(input1), GetTensorShape(input2), \
               GetTensorData<data_type>(input2), GetTensorShape(output), \
               GetTensorData<data_type>(output))

  if (output->type == kTfLiteInt32) {
    if (kernel_type == kReference) {
      if (need_broadcast) {
        TF_LITE_MUL(reference_ops, BroadcastMul4DSlow, int32_t);
      } else {
        TF_LITE_MUL(reference_ops, Mul, int32_t);
      }
    } else {
      if (need_broadcast) {
        TF_LITE_MUL(optimized_ops, BroadcastMul4DSlow, int32_t);
      } else {
        TF_LITE_MUL(optimized_ops, Mul, int32_t);
      }
    }
  } else if (output->type == kTfLiteFloat32) {
    if (kernel_type == kReference) {
      if (need_broadcast) {
        TF_LITE_MUL(reference_ops, BroadcastMul4DSlow, float);
      } else {
        TF_LITE_MUL(reference_ops, Mul, float);
      }
    } else {
      if (need_broadcast) {
        TF_LITE_MUL(optimized_ops, BroadcastMulDispatch, float);
      } else {
        TF_LITE_MUL(optimized_ops, Mul, float);
      }
    }
  } else if (output->type == kTfLiteInt64) {
    // Only reference kernels are used for int64, regardless of kernel_type.
    if (need_broadcast) {
      TF_LITE_MUL(reference_ops, BroadcastMul4DSlow, int64_t);
    } else {
      TF_LITE_MUL(reference_ops, Mul, int64_t);
    }
#undef TF_LITE_MUL
  } else if (output->type == kTfLiteComplex64) {
// Complex multiply: no activation clamp is applied here (Eval rejects
// COMPLEX64 combined with a fused activation before reaching this point).
#define TF_LITE_MUL_COMPLEX(op_name)                                      \
  reference_ops::op_name(                                                 \
      op_params, GetTensorShape(input1),                                  \
      GetTensorData<std::complex<float>>(input1), GetTensorShape(input2), \
      GetTensorData<std::complex<float>>(input2), GetTensorShape(output), \
      GetTensorData<std::complex<float>>(output));

    if (need_broadcast) {
      TF_LITE_MUL_COMPLEX(BroadcastMul4DSlow);
    } else {
      TF_LITE_MUL_COMPLEX(Mul);
    }
#undef TF_LITE_MUL_COMPLEX
  }
}
185 | |
186 | template <KernelType kernel_type> |
187 | TfLiteStatus EvalQuantized(TfLiteContext* context, TfLiteNode* node, |
188 | TfLiteMulParams* params, const OpData* data, |
189 | const TfLiteTensor* input1, |
190 | const TfLiteTensor* input2, TfLiteTensor* output) { |
191 | if (input1->type == input2->type && input1->type == output->type && |
192 | (input1->type == kTfLiteUInt8 || input1->type == kTfLiteInt8 || |
193 | input1->type == kTfLiteInt16)) { |
194 | tflite::ArithmeticParams op_params; |
195 | SetActivationParams(data->output_activation_min, |
196 | data->output_activation_max, &op_params); |
197 | op_params.input1_offset = -input1->params.zero_point; |
198 | op_params.input2_offset = -input2->params.zero_point; |
199 | op_params.output_offset = output->params.zero_point; |
200 | op_params.output_multiplier = data->output_multiplier; |
201 | op_params.output_shift = data->output_shift; |
202 | bool need_broadcast = optimized_ops::ProcessBroadcastShapes( |
203 | GetTensorShape(input1), GetTensorShape(input2), &op_params); |
204 | #define TF_LITE_MUL(type, opname, dtype) \ |
205 | type::opname(op_params, GetTensorShape(input1), \ |
206 | GetTensorData<dtype>(input1), GetTensorShape(input2), \ |
207 | GetTensorData<dtype>(input2), GetTensorShape(output), \ |
208 | GetTensorData<dtype>(output)) |
209 | if (input1->type == kTfLiteInt8) { |
210 | if (kernel_type == kReference) { |
211 | if (need_broadcast) { |
212 | TF_LITE_MUL(reference_integer_ops, BroadcastMul4DSlow, int8_t); |
213 | } else { |
214 | TF_LITE_MUL(reference_integer_ops, Mul, int8_t); |
215 | } |
216 | } else { |
217 | if (need_broadcast) { |
218 | TF_LITE_MUL(optimized_integer_ops, BroadcastMulDispatch, int8_t); |
219 | } else { |
220 | TF_LITE_MUL(optimized_integer_ops, Mul, int8_t); |
221 | } |
222 | } |
223 | } else if (input1->type == kTfLiteInt16) { |
224 | // We have this check, because in case of int16 |
225 | // input1_val*input2_val can overflow int32: |
226 | // see MulElementwise - |
227 | // tensorflow/lite/kernels/internal/reference/integer_ops/mul.h in case of |
228 | // 16-bit this function is used in symmetric quantization, so offset |
229 | // should be zero. |
230 | TF_LITE_ENSURE_EQ(context, op_params.input1_offset, 0.0); |
231 | TF_LITE_ENSURE_EQ(context, op_params.input2_offset, 0.0); |
232 | TF_LITE_ENSURE_EQ(context, op_params.output_offset, 0.0); |
233 | |
234 | if (need_broadcast) { |
235 | TF_LITE_MUL(reference_integer_ops, BroadcastMul4DSlow, int16_t); |
236 | } else { |
237 | TF_LITE_MUL(reference_integer_ops, Mul, int16_t); |
238 | } |
239 | } else { |
240 | // type == kTfLiteUInt8 |
241 | if (kernel_type == kReference) { |
242 | if (need_broadcast) { |
243 | TF_LITE_MUL(reference_ops, BroadcastMul4DSlow, uint8_t); |
244 | } else { |
245 | TF_LITE_MUL(reference_ops, Mul, uint8_t); |
246 | } |
247 | } else { |
248 | if (need_broadcast) { |
249 | TF_LITE_MUL(optimized_ops, BroadcastMulDispatch, uint8_t); |
250 | } else { |
251 | TF_LITE_MUL(optimized_ops, Mul, uint8_t); |
252 | } |
253 | } |
254 | } |
255 | #undef TF_LITE_MUL |
256 | } else if (input1->type == kTfLiteInt16 && input2->type == kTfLiteInt16 && |
257 | (output->type == kTfLiteUInt8 || output->type == kTfLiteInt8)) { |
258 | #define TF_LITE_MUL(type, opname, output_dtype) \ |
259 | tflite::ArithmeticParams op_params; \ |
260 | SetActivationParams(data->output_activation_min, \ |
261 | data->output_activation_max, &op_params); \ |
262 | op_params.output_offset = output->params.zero_point; \ |
263 | type::opname(op_params, GetTensorShape(input1), \ |
264 | GetTensorData<int16_t>(input1), GetTensorShape(input2), \ |
265 | GetTensorData<int16_t>(input2), GetTensorShape(output), \ |
266 | GetTensorData<output_dtype>(output)) |
267 | if (output->type == kTfLiteInt8) { |
268 | TF_LITE_MUL(reference_integer_ops, Mul, int8_t); |
269 | } else { |
270 | if (kernel_type == kReference) { |
271 | TF_LITE_MUL(reference_ops, Mul, uint8_t); |
272 | } else { |
273 | TF_LITE_MUL(optimized_ops, Mul, uint8_t); |
274 | } |
275 | } |
276 | #undef TF_LITE_MUL |
277 | } else { |
278 | TF_LITE_KERNEL_LOG( |
279 | context, "Unsupported combination of input and output types in Mul." ); |
280 | return kTfLiteError; |
281 | } |
282 | return kTfLiteOk; |
283 | } |
284 | |
285 | template <KernelType kernel_type> |
286 | TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { |
287 | auto* params = reinterpret_cast<TfLiteMulParams*>(node->builtin_data); |
288 | OpData* data = reinterpret_cast<OpData*>(node->user_data); |
289 | |
290 | const TfLiteTensor* input1; |
291 | TF_LITE_ENSURE_OK(context, |
292 | GetInputSafe(context, node, kInputTensor1, &input1)); |
293 | const TfLiteTensor* input2; |
294 | TF_LITE_ENSURE_OK(context, |
295 | GetInputSafe(context, node, kInputTensor2, &input2)); |
296 | TfLiteTensor* output; |
297 | TF_LITE_ENSURE_OK(context, |
298 | GetOutputSafe(context, node, kOutputTensor, &output)); |
299 | if (output->type == kTfLiteComplex64 && params->activation) { |
300 | TF_LITE_KERNEL_LOG(context, |
301 | "Activation is not allowed for COMPLEX64 input." ); |
302 | return kTfLiteError; |
303 | } |
304 | if (output->type == kTfLiteFloat32 || output->type == kTfLiteInt32 || |
305 | output->type == kTfLiteInt64 || output->type == kTfLiteComplex64) { |
306 | EvalMul<kernel_type>(context, node, params, data, input1, input2, output); |
307 | } else if (output->type == kTfLiteUInt8 || output->type == kTfLiteInt8 || |
308 | output->type == kTfLiteInt16) { |
309 | TF_LITE_ENSURE_OK( |
310 | context, EvalQuantized<kernel_type>(context, node, params, data, input1, |
311 | input2, output)); |
312 | } else { |
313 | TF_LITE_KERNEL_LOG( |
314 | context, "Mul only supports FLOAT32, COMPLEX32, INT8, INT16," , |
315 | " INT32, INT64 and quantized UINT8 now, got %d." , output->type); |
316 | return kTfLiteError; |
317 | } |
318 | |
319 | return kTfLiteOk; |
320 | } |
321 | |
322 | } // namespace mul |
323 | |
324 | TfLiteRegistration* Register_MUL_REF() { |
325 | static TfLiteRegistration r = {mul::Init, mul::Free, mul::Prepare, |
326 | mul::Eval<mul::kReference>}; |
327 | return &r; |
328 | } |
329 | |
330 | TfLiteRegistration* Register_MUL_GENERIC_OPT() { |
331 | static TfLiteRegistration r = {mul::Init, mul::Free, mul::Prepare, |
332 | mul::Eval<mul::kGenericOptimized>}; |
333 | return &r; |
334 | } |
335 | |
336 | TfLiteRegistration* Register_MUL_NEON_OPT() { |
337 | static TfLiteRegistration r = {mul::Init, mul::Free, mul::Prepare, |
338 | mul::Eval<mul::kNeonOptimized>}; |
339 | return &r; |
340 | } |
341 | |
// Default registration: selects the NEON-optimized kernels when the build
// defines USE_NEON, otherwise the generic optimized kernels.
TfLiteRegistration* Register_MUL() {
#ifdef USE_NEON
  return Register_MUL_NEON_OPT();
#else
  return Register_MUL_GENERIC_OPT();
#endif
}
349 | |
350 | } // namespace builtin |
351 | } // namespace ops |
352 | } // namespace tflite |
353 | |