/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/lite/kernels/internal/optimized/integer_ops/mul.h"

#include <stddef.h>
#include <stdint.h>

#include <complex>

#include "tensorflow/lite/c/builtin_op_data.h"
#include "tensorflow/lite/c/common.h"
#include "tensorflow/lite/kernels/internal/compatibility.h"
#include "tensorflow/lite/kernels/internal/optimized/cpu_check.h"
#include "tensorflow/lite/kernels/internal/optimized/neon_check.h"
#include "tensorflow/lite/kernels/internal/optimized/optimized_ops.h"
#include "tensorflow/lite/kernels/internal/quantization_util.h"
#include "tensorflow/lite/kernels/internal/reference/integer_ops/mul.h"
#include "tensorflow/lite/kernels/internal/reference/mul.h"
#include "tensorflow/lite/kernels/internal/reference/process_broadcast_shapes.h"
#include "tensorflow/lite/kernels/internal/reference/reference_ops.h"
#include "tensorflow/lite/kernels/internal/tensor.h"
#include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
#include "tensorflow/lite/kernels/internal/types.h"
#include "tensorflow/lite/kernels/kernel_util.h"

namespace tflite {
namespace ops {
namespace builtin {
namespace mul {

// This file has three implementations of Mul.
enum KernelType {
  kReference,
  kGenericOptimized,  // Neon-free
  kNeonOptimized,
};

constexpr int kInputTensor1 = 0;
constexpr int kInputTensor2 = 1;
constexpr int kOutputTensor = 0;

struct OpData {
  // Parameters used in the quantized paths where the output is 8-bit.
  int32_t output_activation_min;
  int32_t output_activation_max;

  // Parameters used in all quantized paths.
  int32_t output_multiplier;
  int output_shift;
};

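// Allocates the per-node OpData that Prepare fills in and Eval reads.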
void* Init(TfLiteContext* context, const char* buffer, size_t length) {
  auto* data = new OpData;
  return data;
}

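// Releases the OpData allocated in Init.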
void Free(TfLiteContext* context, void* buffer) {
  delete reinterpret_cast<OpData*>(buffer);
}

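// Validates the node (two inputs of the same type, one output), resolves the
// output shape (broadcasting the input shapes when they differ), and, for
// quantized types, precomputes the activation range and the fixed-point
// output multiplier/shift used at Eval time.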
TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
  auto* params = reinterpret_cast<TfLiteMulParams*>(node->builtin_data);
  OpData* data = reinterpret_cast<OpData*>(node->user_data);

  TF_LITE_ENSURE_EQ(context, NumInputs(node), 2);
  TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);

  const TfLiteTensor* input1;
  TF_LITE_ENSURE_OK(context,
                    GetInputSafe(context, node, kInputTensor1, &input1));
  const TfLiteTensor* input2;
  TF_LITE_ENSURE_OK(context,
                    GetInputSafe(context, node, kInputTensor2, &input2));
  TfLiteTensor* output;
  TF_LITE_ENSURE_OK(context,
                    GetOutputSafe(context, node, kOutputTensor, &output));

  TF_LITE_ENSURE_TYPES_EQ(context, input1->type, input2->type);

  const bool requires_broadcast = !HaveSameShapes(input1, input2);

  TfLiteIntArray* output_size = nullptr;
  if (requires_broadcast) {
    TF_LITE_ENSURE_OK(context, CalculateShapeForBroadcast(
                                   context, input1, input2, &output_size));
  } else {
    output_size = TfLiteIntArrayCopy(input1->dims);
  }

  if (output->type == kTfLiteUInt8 || output->type == kTfLiteInt8 ||
      output->type == kTfLiteInt16) {
    TF_LITE_ENSURE_STATUS(CalculateActivationRangeQuantized(
        context, params->activation, output, &data->output_activation_min,
        &data->output_activation_max));
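    // The product of two quantized values carries scale s1 * s2; rescaling it
    // to the output scale therefore uses real_multiplier = s1 * s2 / s_out,
    // stored as a fixed-point multiplier plus shift.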
    double real_multiplier =
        input1->params.scale * input2->params.scale / output->params.scale;
    QuantizeMultiplier(real_multiplier, &data->output_multiplier,
                       &data->output_shift);
  }

  return context->ResizeTensor(context, output, output_size);
}

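// Handles the non-quantized types (float32, int32, int64, complex64),
// dispatching to the reference or optimized kernels and to the broadcasting
// variants when the input shapes differ.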
template <KernelType kernel_type>
void EvalMul(TfLiteContext* context, TfLiteNode* node, TfLiteMulParams* params,
             const OpData* data, const TfLiteTensor* input1,
             const TfLiteTensor* input2, TfLiteTensor* output) {
  tflite::ArithmeticParams op_params;
  const bool need_broadcast = optimized_ops::ProcessBroadcastShapes(
      GetTensorShape(input1), GetTensorShape(input2), &op_params);
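// Computes the activation bounds for `data_type`, stores them in op_params,
// and invokes type::opname on the two input tensors and the output tensor.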
#define TF_LITE_MUL(type, opname, data_type)                             \
  data_type output_activation_min, output_activation_max;                \
  CalculateActivationRange(params->activation, &output_activation_min,   \
                           &output_activation_max);                      \
  SetActivationParams(output_activation_min, output_activation_max,      \
                      &op_params);                                       \
  type::opname(op_params, GetTensorShape(input1),                        \
               GetTensorData<data_type>(input1), GetTensorShape(input2), \
               GetTensorData<data_type>(input2), GetTensorShape(output), \
               GetTensorData<data_type>(output))

  if (output->type == kTfLiteInt32) {
    if (kernel_type == kReference) {
      if (need_broadcast) {
        TF_LITE_MUL(reference_ops, BroadcastMul4DSlow, int32_t);
      } else {
        TF_LITE_MUL(reference_ops, Mul, int32_t);
      }
    } else {
      if (need_broadcast) {
        TF_LITE_MUL(optimized_ops, BroadcastMul4DSlow, int32_t);
      } else {
        TF_LITE_MUL(optimized_ops, Mul, int32_t);
      }
    }
  } else if (output->type == kTfLiteFloat32) {
    if (kernel_type == kReference) {
      if (need_broadcast) {
        TF_LITE_MUL(reference_ops, BroadcastMul4DSlow, float);
      } else {
        TF_LITE_MUL(reference_ops, Mul, float);
      }
    } else {
      if (need_broadcast) {
        TF_LITE_MUL(optimized_ops, BroadcastMulDispatch, float);
      } else {
        TF_LITE_MUL(optimized_ops, Mul, float);
      }
    }
  } else if (output->type == kTfLiteInt64) {
    if (need_broadcast) {
      TF_LITE_MUL(reference_ops, BroadcastMul4DSlow, int64_t);
    } else {
      TF_LITE_MUL(reference_ops, Mul, int64_t);
    }
#undef TF_LITE_MUL
  } else if (output->type == kTfLiteComplex64) {
#define TF_LITE_MUL_COMPLEX(op_name)                                      \
  reference_ops::op_name(                                                 \
      op_params, GetTensorShape(input1),                                  \
      GetTensorData<std::complex<float>>(input1), GetTensorShape(input2), \
      GetTensorData<std::complex<float>>(input2), GetTensorShape(output), \
      GetTensorData<std::complex<float>>(output));

    if (need_broadcast) {
      TF_LITE_MUL_COMPLEX(BroadcastMul4DSlow);
    } else {
      TF_LITE_MUL_COMPLEX(Mul);
    }
#undef TF_LITE_MUL_COMPLEX
  }
}

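// Handles the quantized types. When the inputs and the output share an 8-bit
// or 16-bit type, the element-wise product is rescaled with the multiplier
// and shift computed in Prepare; int16 inputs may also be requantized down to
// an 8-bit output.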
template <KernelType kernel_type>
TfLiteStatus EvalQuantized(TfLiteContext* context, TfLiteNode* node,
                           TfLiteMulParams* params, const OpData* data,
                           const TfLiteTensor* input1,
                           const TfLiteTensor* input2, TfLiteTensor* output) {
  if (input1->type == input2->type && input1->type == output->type &&
      (input1->type == kTfLiteUInt8 || input1->type == kTfLiteInt8 ||
       input1->type == kTfLiteInt16)) {
    tflite::ArithmeticParams op_params;
    SetActivationParams(data->output_activation_min,
                        data->output_activation_max, &op_params);
    op_params.input1_offset = -input1->params.zero_point;
    op_params.input2_offset = -input2->params.zero_point;
    op_params.output_offset = output->params.zero_point;
    op_params.output_multiplier = data->output_multiplier;
    op_params.output_shift = data->output_shift;
    bool need_broadcast = optimized_ops::ProcessBroadcastShapes(
        GetTensorShape(input1), GetTensorShape(input2), &op_params);
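// Invokes type::opname with the quantization parameters assembled above on
// tensors of element type `dtype`.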
#define TF_LITE_MUL(type, opname, dtype)                             \
  type::opname(op_params, GetTensorShape(input1),                    \
               GetTensorData<dtype>(input1), GetTensorShape(input2), \
               GetTensorData<dtype>(input2), GetTensorShape(output), \
               GetTensorData<dtype>(output))
    if (input1->type == kTfLiteInt8) {
      if (kernel_type == kReference) {
        if (need_broadcast) {
          TF_LITE_MUL(reference_integer_ops, BroadcastMul4DSlow, int8_t);
        } else {
          TF_LITE_MUL(reference_integer_ops, Mul, int8_t);
        }
      } else {
        if (need_broadcast) {
          TF_LITE_MUL(optimized_integer_ops, BroadcastMulDispatch, int8_t);
        } else {
          TF_LITE_MUL(optimized_integer_ops, Mul, int8_t);
        }
      }
    } else if (input1->type == kTfLiteInt16) {
      // In the 16-bit case this kernel only supports symmetric quantization,
      // so all offsets must be zero: MulElementwise (see
      // tensorflow/lite/kernels/internal/reference/integer_ops/mul.h) computes
      // (input1_offset + input1_val) * (input2_offset + input2_val) in int32,
      // and with non-zero offsets that product can overflow int32 for 16-bit
      // inputs.
      TF_LITE_ENSURE_EQ(context, op_params.input1_offset, 0);
      TF_LITE_ENSURE_EQ(context, op_params.input2_offset, 0);
      TF_LITE_ENSURE_EQ(context, op_params.output_offset, 0);

      if (need_broadcast) {
        TF_LITE_MUL(reference_integer_ops, BroadcastMul4DSlow, int16_t);
      } else {
        TF_LITE_MUL(reference_integer_ops, Mul, int16_t);
      }
    } else {
      // type == kTfLiteUInt8
      if (kernel_type == kReference) {
        if (need_broadcast) {
          TF_LITE_MUL(reference_ops, BroadcastMul4DSlow, uint8_t);
        } else {
          TF_LITE_MUL(reference_ops, Mul, uint8_t);
        }
      } else {
        if (need_broadcast) {
          TF_LITE_MUL(optimized_ops, BroadcastMulDispatch, uint8_t);
        } else {
          TF_LITE_MUL(optimized_ops, Mul, uint8_t);
        }
      }
    }
#undef TF_LITE_MUL
  } else if (input1->type == kTfLiteInt16 && input2->type == kTfLiteInt16 &&
             (output->type == kTfLiteUInt8 || output->type == kTfLiteInt8)) {
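// Inputs are 16-bit fixed-point but the output is 8-bit: only the output zero
// point and the activation range from Prepare are passed to the kernels
// below, which handle the narrowing to the 8-bit output themselves.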
#define TF_LITE_MUL(type, opname, output_dtype)                        \
  tflite::ArithmeticParams op_params;                                  \
  SetActivationParams(data->output_activation_min,                     \
                      data->output_activation_max, &op_params);        \
  op_params.output_offset = output->params.zero_point;                 \
  type::opname(op_params, GetTensorShape(input1),                      \
               GetTensorData<int16_t>(input1), GetTensorShape(input2), \
               GetTensorData<int16_t>(input2), GetTensorShape(output), \
               GetTensorData<output_dtype>(output))
    if (output->type == kTfLiteInt8) {
      TF_LITE_MUL(reference_integer_ops, Mul, int8_t);
    } else {
      if (kernel_type == kReference) {
        TF_LITE_MUL(reference_ops, Mul, uint8_t);
      } else {
        TF_LITE_MUL(optimized_ops, Mul, uint8_t);
      }
    }
#undef TF_LITE_MUL
  } else {
    TF_LITE_KERNEL_LOG(
        context, "Unsupported combination of input and output types in Mul.");
    return kTfLiteError;
  }
  return kTfLiteOk;
}

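// Entry point for execution: dispatches on the output type to EvalMul for
// float/int32/int64/complex64 tensors and to EvalQuantized for
// uint8/int8/int16 tensors.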
template <KernelType kernel_type>
TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
  auto* params = reinterpret_cast<TfLiteMulParams*>(node->builtin_data);
  OpData* data = reinterpret_cast<OpData*>(node->user_data);

  const TfLiteTensor* input1;
  TF_LITE_ENSURE_OK(context,
                    GetInputSafe(context, node, kInputTensor1, &input1));
  const TfLiteTensor* input2;
  TF_LITE_ENSURE_OK(context,
                    GetInputSafe(context, node, kInputTensor2, &input2));
  TfLiteTensor* output;
  TF_LITE_ENSURE_OK(context,
                    GetOutputSafe(context, node, kOutputTensor, &output));
  if (output->type == kTfLiteComplex64 && params->activation) {
    TF_LITE_KERNEL_LOG(context,
                       "Activation is not allowed for COMPLEX64 input.");
    return kTfLiteError;
  }
  if (output->type == kTfLiteFloat32 || output->type == kTfLiteInt32 ||
      output->type == kTfLiteInt64 || output->type == kTfLiteComplex64) {
    EvalMul<kernel_type>(context, node, params, data, input1, input2, output);
  } else if (output->type == kTfLiteUInt8 || output->type == kTfLiteInt8 ||
             output->type == kTfLiteInt16) {
    TF_LITE_ENSURE_OK(
        context, EvalQuantized<kernel_type>(context, node, params, data,
                                            input1, input2, output));
  } else {
    TF_LITE_KERNEL_LOG(context,
                       "Mul only supports FLOAT32, COMPLEX64, INT8, INT16,"
                       " INT32, INT64 and quantized UINT8 now, got %d.",
                       output->type);
    return kTfLiteError;
  }

  return kTfLiteOk;
}

}  // namespace mul

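// Each Register_* function below returns a statically allocated
// TfLiteRegistration that wires Init/Free/Prepare to the Eval specialization
// for the corresponding kernel type.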
TfLiteRegistration* Register_MUL_REF() {
  static TfLiteRegistration r = {mul::Init, mul::Free, mul::Prepare,
                                 mul::Eval<mul::kReference>};
  return &r;
}

TfLiteRegistration* Register_MUL_GENERIC_OPT() {
  static TfLiteRegistration r = {mul::Init, mul::Free, mul::Prepare,
                                 mul::Eval<mul::kGenericOptimized>};
  return &r;
}

TfLiteRegistration* Register_MUL_NEON_OPT() {
  static TfLiteRegistration r = {mul::Init, mul::Free, mul::Prepare,
                                 mul::Eval<mul::kNeonOptimized>};
  return &r;
}

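// Default registration: picks the NEON kernel when the build defines USE_NEON
// and falls back to the generic optimized kernel otherwise.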
TfLiteRegistration* Register_MUL() {
#ifdef USE_NEON
  return Register_MUL_NEON_OPT();
#else
  return Register_MUL_GENERIC_OPT();
#endif
}

}  // namespace builtin
}  // namespace ops
}  // namespace tflite