/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/lite/kernels/internal/reference/batch_matmul.h"

#include <stddef.h>

#include <algorithm>
#include <cstdint>
#include <cstring>  // memset is used on the hybrid path.
#include <limits>
#include "tensorflow/lite/c/builtin_op_data.h"
#include "tensorflow/lite/c/common.h"
#include "tensorflow/lite/kernels/cpu_backend_context.h"
#include "tensorflow/lite/kernels/internal/compatibility.h"
#include "tensorflow/lite/kernels/internal/optimized/batch_matmul.h"
#include "tensorflow/lite/kernels/internal/optimized/optimized_ops.h"
#include "tensorflow/lite/kernels/internal/reference/reference_ops.h"
#include "tensorflow/lite/kernels/internal/tensor.h"
#include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
#include "tensorflow/lite/kernels/internal/tensor_utils.h"
#include "tensorflow/lite/kernels/internal/types.h"
#include "tensorflow/lite/kernels/kernel_util.h"

namespace tflite {
namespace ops {
namespace builtin {
namespace batch_matmul {

static const int kInputLHSTensor = 0;
static const int kInputRHSTensor = 1;
static const int kOutputTensor = 0;

static const int kNumTempTensorsForAdjoints = 2;
static const int kNumTempTensorsForHybrid = 5;

// This file has two implementations of BatchMatMul.
enum KernelType {
  kReference,
  kGenericOptimized,
};
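// kReference dispatches to the reference_ops kernels; kGenericOptimized
// dispatches to the optimized_ops kernels, which use the CpuBackendContext
// registered with the interpreter. The kernel type is fixed at registration
// time via the Eval<kernel_type> instantiations at the bottom of this file.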

struct OpData {
  // The scaling factor from input to output (aka the 'real multiplier') can
  // be represented as a fixed point multiplier plus a left shift.
  int32_t output_multiplier;
  int output_shift;
  // The range of the fused activation layer. For example for kNone and
  // uint8_t these would be 0 and 255.
  int32_t output_activation_min;
  int32_t output_activation_max;
  // The index of the temporary tensors where we store transposed LHS/RHS.
  int scratch_tensor_index;
  bool rhs_transposed;
  bool compute_row_sums = false;
};
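// Temporary tensor layout, relative to OpData::scratch_tensor_index (see
// InitializeTemporaries below):
//   index 0: transposed LHS              index 1: transposed RHS
// and, for the hybrid float32 x int8 path only:
//   index 2: quantized LHS               index 3: per-batch scaling factors
//   index 4: int32 accumulator scratch   index 5: input offsets
//   index 6: row sums of the RHS (weights)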

struct OpContext {
  OpContext(TfLiteContext* context, TfLiteNode* node) {
    params = reinterpret_cast<TfLiteBatchMatMulParams*>(node->builtin_data);
    lhs = GetInput(context, node, kInputLHSTensor);
    rhs = GetInput(context, node, kInputRHSTensor);
    output = GetOutput(context, node, kOutputTensor);
  }
  TfLiteBatchMatMulParams* params;
  const TfLiteTensor* lhs;
  const TfLiteTensor* rhs;
  TfLiteTensor* output;
};

void* Init(TfLiteContext* context, const char* buffer, size_t length) {
  auto* op_data = new OpData();
  // If the RHS is constant, we only transpose once.
  op_data->rhs_transposed = false;
  // Creates the temp tensors to store the transposed LHS and/or RHS, and
  // extra buffers for the quantized case.
  context->AddTensors(context,
                      kNumTempTensorsForAdjoints + kNumTempTensorsForHybrid,
                      &op_data->scratch_tensor_index);
  return op_data;
}

void Free(TfLiteContext* context, void* buffer) {
  delete static_cast<OpData*>(buffer);
}

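// Computes the output shape: batch dimensions follow the usual broadcasting
// rules, and the trailing two dimensions are {lhs_rows, rhs_cols} (taking
// adj_x/adj_y into account). For example, an LHS of shape [2, 1, 3, 4] and an
// RHS of shape [1, 5, 4, 6] with adj_x = adj_y = false yield an output of
// shape [2, 5, 3, 6].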
TfLiteStatus ResizeOutputTensor(TfLiteContext* context,
                                const RuntimeShape& extended_lhs_shape,
                                const RuntimeShape& extended_rhs_shape,
                                bool adj_x, bool adj_y, int output_rank,
                                TfLiteTensor* output) {
  TfLiteIntArray* output_shape = TfLiteIntArrayCreate(output_rank);
  // Fill in any broadcast dimensions.
  for (int i = 0; i < output_rank - 2; ++i) {
    const int lhs_dim = extended_lhs_shape.Dims(i);
    const int rhs_dim = extended_rhs_shape.Dims(i);
    int broadcast_dim = lhs_dim;
    if ((lhs_dim != rhs_dim) && (lhs_dim == 1)) {
      broadcast_dim = rhs_dim;
    }
    output_shape->data[i] = broadcast_dim;
  }
  // Fill in the matmul dimensions.
  int lhs_rows_index = adj_x ? output_rank - 1 : output_rank - 2;
  int rhs_cols_index = adj_y ? output_rank - 2 : output_rank - 1;

  output_shape->data[output_rank - 2] = extended_lhs_shape.Dims(lhs_rows_index);
  output_shape->data[output_rank - 1] = extended_rhs_shape.Dims(rhs_cols_index);
  return context->ResizeTensor(context, output, output_shape);
}

// Initializes temp tensors to store transposed operands.
TfLiteStatus InitializeTemporaries(TfLiteContext* context, TfLiteNode* node,
                                   OpContext* op_context) {
  // Create temporary tensors to hold transposed LHS/RHS.
  OpData* op_data = reinterpret_cast<OpData*>(node->user_data);
  const TfLiteTensor* lhs = op_context->lhs;
  const TfLiteTensor* rhs = op_context->rhs;
  TfLiteIntArrayFree(node->temporaries);
  // For "hybrid" quantization, we impose the constraint that the LHS
  // is float (typically an activation from a prior layer) and the RHS
  // is quantized int8.
  bool is_hybrid =
      (op_context->lhs->type == kTfLiteFloat32 && rhs->type == kTfLiteInt8);
  if (is_hybrid) {
    node->temporaries = TfLiteIntArrayCreate(kNumTempTensorsForAdjoints +
                                             kNumTempTensorsForHybrid);
  } else {
    node->temporaries = TfLiteIntArrayCreate(kNumTempTensorsForAdjoints);
  }

  const int lhs_rank = NumDimensions(lhs);
  const int rhs_rank = NumDimensions(rhs);
  const int batch_size = op_context->params->adj_x
                             ? lhs->dims->data[lhs_rank - 1]
                             : lhs->dims->data[lhs_rank - 2];
  const int num_units = op_context->params->adj_y
                            ? rhs->dims->data[rhs_rank - 2]
                            : rhs->dims->data[rhs_rank - 1];

  // Temp tensor for transposed LHS.
  {
    node->temporaries->data[0] = op_data->scratch_tensor_index;
    TfLiteTensor* scratch_buffer;
    TF_LITE_ENSURE_OK(
        context, GetTemporarySafe(context, node, /*index=*/0, &scratch_buffer));
    TfLiteIntArray* scratch_buffer_size = TfLiteIntArrayCreate(lhs_rank);
    for (int i = 0; i < lhs_rank - 2; ++i) {
      scratch_buffer_size->data[i] = lhs->dims->data[i];
    }
    // Swap last two dimensions.
    scratch_buffer_size->data[lhs_rank - 2] = lhs->dims->data[lhs_rank - 1];
    scratch_buffer_size->data[lhs_rank - 1] = lhs->dims->data[lhs_rank - 2];

    scratch_buffer->type = op_context->lhs->type;
    scratch_buffer->allocation_type = kTfLiteArenaRw;
    TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, scratch_buffer,
                                                     scratch_buffer_size));
  }

  // We need a temp buffer for the RHS if we need to transpose the RHS. We
  // transpose by default, so that the two inputs (LHS and RHS) are in a proper
  // layout for our fast matrix multiplication routines. If the transpose flag
  // is set by the caller, the data is already in the desired layout.
  {
    node->temporaries->data[1] = op_data->scratch_tensor_index + 1;
    TfLiteTensor* scratch_buffer;
    TF_LITE_ENSURE_OK(
        context, GetTemporarySafe(context, node, /*index=*/1, &scratch_buffer));
    scratch_buffer->name = "BatchMatMul_scratch_buffer";
    const TfLiteTensor* rhs = op_context->rhs;
    int rhs_rank = NumDimensions(rhs);
    TfLiteIntArray* scratch_buffer_size = TfLiteIntArrayCreate(rhs_rank);
    for (int i = 0; i < rhs_rank - 2; ++i) {
      scratch_buffer_size->data[i] = rhs->dims->data[i];
    }
    // Swap last two dimensions.
    scratch_buffer_size->data[rhs_rank - 2] = rhs->dims->data[rhs_rank - 1];
    scratch_buffer_size->data[rhs_rank - 1] = rhs->dims->data[rhs_rank - 2];

    if (IsConstantTensor(op_context->rhs)) {
      scratch_buffer->allocation_type = kTfLiteArenaRwPersistent;
    } else {
      scratch_buffer->allocation_type = kTfLiteArenaRw;
    }
    scratch_buffer->type = op_context->rhs->type;
    TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, scratch_buffer,
                                                     scratch_buffer_size));
  }

  // If we have to perform on-the-fly quantization (with quantized weights and
  // float inputs) first we need to quantize the inputs. Allocate temporary
  // buffers to store the intermediate quantized values, the batch scaling
  // factors, the accumulator buffer (optimized version), the input offsets,
  // and the sums of the rows for each weights matrix.
  // RHS = weights, LHS = inputs.
  if (is_hybrid) {
    // Calculate the total number of LHS batches.
    int num_batches = 1;
    for (int i = 0; i < lhs_rank - 2; ++i) {
      num_batches *= lhs->dims->data[i];
    }
    int num_weights_matrices = 1;
    for (int i = 0; i < rhs_rank - 2; ++i) {
      num_weights_matrices *= rhs->dims->data[i];
    }
    op_data->compute_row_sums = true;
    node->temporaries->data[2] = op_data->scratch_tensor_index + 2;
    TfLiteTensor* input_quantized;
    TF_LITE_ENSURE_OK(context, GetTemporarySafe(context, node, /*index=*/2,
                                                &input_quantized));
    input_quantized->type = op_context->rhs->type;
    input_quantized->allocation_type = kTfLiteArenaRw;

    TfLiteIntArray* input_quantized_size =
        TfLiteIntArrayCopy(op_context->lhs->dims);
    TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, input_quantized,
                                                     input_quantized_size));

    node->temporaries->data[3] = op_data->scratch_tensor_index + 3;
    TfLiteTensor* scaling_factors;
    TF_LITE_ENSURE_OK(context, GetTemporarySafe(context, node, /*index=*/3,
                                                &scaling_factors));
    scaling_factors->type = kTfLiteFloat32;
    scaling_factors->allocation_type = kTfLiteArenaRw;
    // Total size of scaling factors is batch size * number of total batches.
    int scaling_dims[1] = {num_batches * batch_size};
    if (!TfLiteIntArrayEqualsArray(scaling_factors->dims, 1, scaling_dims)) {
      TfLiteIntArray* scaling_factors_size = TfLiteIntArrayCreate(1);
      scaling_factors_size->data[0] = scaling_dims[0];
      TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, scaling_factors,
                                                       scaling_factors_size));
    }

    node->temporaries->data[4] = op_data->scratch_tensor_index + 4;
    TfLiteTensor* accum_scratch;
    TF_LITE_ENSURE_OK(
        context, GetTemporarySafe(context, node, /*index=*/4, &accum_scratch));
    accum_scratch->type = kTfLiteInt32;
    accum_scratch->allocation_type = kTfLiteArenaRw;
    int accum_scratch_dims[2] = {num_units, batch_size};
    if (!TfLiteIntArrayEqualsArray(accum_scratch->dims, 2,
                                   accum_scratch_dims)) {
      TfLiteIntArray* accum_size = TfLiteIntArrayCreate(2);
      accum_size->data[0] = num_units;
      accum_size->data[1] = batch_size;
      TF_LITE_ENSURE_OK(
          context, context->ResizeTensor(context, accum_scratch, accum_size));
    }

    node->temporaries->data[5] = op_data->scratch_tensor_index + 5;
    TfLiteTensor* input_offsets;
    TF_LITE_ENSURE_OK(
        context, GetTemporarySafe(context, node, /*index=*/5, &input_offsets));
    input_offsets->type = kTfLiteInt32;
    input_offsets->allocation_type = kTfLiteArenaRw;
    if (!TfLiteIntArrayEqualsArray(input_offsets->dims, 1, scaling_dims)) {
      TfLiteIntArray* input_offsets_size = TfLiteIntArrayCreate(1);
      input_offsets_size->data[0] = num_batches * batch_size;
      TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, input_offsets,
                                                       input_offsets_size));
    }
    node->temporaries->data[6] = op_data->scratch_tensor_index + 6;
    TfLiteTensor* row_sums;
    TF_LITE_ENSURE_OK(context,
                      GetTemporarySafe(context, node, /*index=*/6, &row_sums));
    row_sums->type = kTfLiteInt32;
    row_sums->allocation_type = kTfLiteArenaRwPersistent;
    int row_sums_dims[1] = {num_weights_matrices * num_units};
    if (!TfLiteIntArrayEqualsArray(row_sums->dims, 1, row_sums_dims)) {
      TfLiteIntArray* row_sums_size = TfLiteIntArrayCreate(1);
      row_sums_size->data[0] = row_sums_dims[0];
      TF_LITE_ENSURE_OK(
          context, context->ResizeTensor(context, row_sums, row_sums_size));
    }
  }

  return kTfLiteOk;
}

TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
  TF_LITE_ENSURE_EQ(context, NumInputs(node), 2);
  TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);

  OpContext op_context(context, node);
  TF_LITE_ENSURE_OK(context,
                    InitializeTemporaries(context, node, &op_context));
  OpData* op_data = reinterpret_cast<OpData*>(node->user_data);

  bool adj_x = op_context.params->adj_x;
  bool adj_y = op_context.params->adj_y;

  const TfLiteTensor* lhs_data;
  TF_LITE_ENSURE_OK(context,
                    GetInputSafe(context, node, kInputLHSTensor, &lhs_data));
  const TfLiteTensor* rhs_data;
  TF_LITE_ENSURE_OK(context,
                    GetInputSafe(context, node, kInputRHSTensor, &rhs_data));
  TfLiteTensor* output;
  TF_LITE_ENSURE_OK(context,
                    GetOutputSafe(context, node, kOutputTensor, &output));

  // Note that quantized inference requires that all tensors have their
  // parameters set. This is usually done during quantized training.
  if ((lhs_data->type == kTfLiteInt8 || lhs_data->type == kTfLiteInt16) &&
      output->type != kTfLiteInt32) {
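    // The effective output scale of the quantized matmul is
    //   real_multiplier = lhs_scale * rhs_scale / output_scale,
    // which QuantizeMultiplier decomposes into a fixed-point multiplier and a
    // signed shift.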
    double real_multiplier = 0.0;
    TF_LITE_ENSURE_STATUS(GetQuantizedConvolutionMultipler(
        context, lhs_data, rhs_data, output, &real_multiplier));
    int exponent;
    QuantizeMultiplier(real_multiplier, &op_data->output_multiplier, &exponent);
    op_data->output_shift = exponent;
    // BatchMatMul has no fused activation functions. Therefore, set
    // output activation min and max to min and max of int8_t or int16_t
    // type.
    if (lhs_data->type == kTfLiteInt8) {
      op_data->output_activation_min = std::numeric_limits<int8_t>::min();
      op_data->output_activation_max = std::numeric_limits<int8_t>::max();
    } else {
      op_data->output_activation_min = std::numeric_limits<int16_t>::min();
      op_data->output_activation_max = std::numeric_limits<int16_t>::max();
    }
  }

  if (lhs_data->type == kTfLiteInt16) {
    TF_LITE_ENSURE_EQ(context, lhs_data->params.zero_point, 0);
    TF_LITE_ENSURE_EQ(context, rhs_data->params.zero_point, 0);
    TF_LITE_ENSURE_EQ(context, output->params.zero_point, 0);
  }

  TF_LITE_ENSURE(context, lhs_data->type == kTfLiteFloat32 ||
                              lhs_data->type == kTfLiteInt8 ||
                              lhs_data->type == kTfLiteInt16);
  TF_LITE_ENSURE(context, rhs_data->type == kTfLiteFloat32 ||
                              rhs_data->type == kTfLiteInt8 ||
                              rhs_data->type == kTfLiteInt16);
  // Either we have hybrid quantization with a float32 LHS and an int8 RHS, or
  // both inputs must be of the same type.
  TF_LITE_ENSURE(context, (lhs_data->type == kTfLiteFloat32 &&
                           rhs_data->type == kTfLiteInt8) ||
                              lhs_data->type == rhs_data->type);
  // Support dimensions between 2 and 5, inclusive.
  TF_LITE_ENSURE(context, NumDimensions(lhs_data) >= 2);
  TF_LITE_ENSURE(context, NumDimensions(lhs_data) <= 5);
  TF_LITE_ENSURE(context, NumDimensions(rhs_data) >= 2);
  TF_LITE_ENSURE(context, NumDimensions(rhs_data) <= 5);

  const int lhs_rank = NumDimensions(lhs_data);
  const int rhs_rank = NumDimensions(rhs_data);
  const int output_rank = std::max(lhs_rank, rhs_rank);
  const RuntimeShape extended_lhs_shape =
      RuntimeShape::ExtendedShape(output_rank, GetTensorShape(lhs_data));
  const RuntimeShape extended_rhs_shape =
      RuntimeShape::ExtendedShape(output_rank, GetTensorShape(rhs_data));

  // Ensure any batch dimensions obey broadcasting rules.
  for (int i = 0; i < output_rank - 2; ++i) {
    const int lhs_dim = extended_lhs_shape.Dims(i);
    const int rhs_dim = extended_rhs_shape.Dims(i);
    if (lhs_dim != rhs_dim) {
      if (lhs_dim != 1) {
        TF_LITE_ENSURE_EQ(context, rhs_dim, 1);
      }
    }
  }
  // Ensure other dimensions work for matrix multiplication.
  int accum_dim_lhs = adj_x ? extended_lhs_shape.Dims(output_rank - 2)
                            : extended_lhs_shape.Dims(output_rank - 1);
  int accum_dim_rhs = adj_y ? extended_rhs_shape.Dims(output_rank - 1)
                            : extended_rhs_shape.Dims(output_rank - 2);

  TF_LITE_ENSURE_EQ(context, accum_dim_lhs, accum_dim_rhs);
  return ResizeOutputTensor(context, extended_lhs_shape, extended_rhs_shape,
                            adj_x, adj_y, output_rank, output);
}

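// Transposes the trailing two dimensions of `tensor_in` into `output`. For a
// rank-4 tensor this builds the permutation {0, 1, 3, 2}: batch dimensions are
// left untouched and only the innermost matrix is transposed.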
template <typename scalar>
void TransposeRowsColumnsImpl(const TfLiteTensor* tensor_in,
                              const scalar* input, TfLiteTensor* tensor_out,
                              scalar* output) {
  RuntimeShape transposed_shape(GetTensorShape(tensor_in));
  RuntimeShape shape(GetTensorShape(tensor_in));
  TransposeParams params;
  int rank = NumDimensions(tensor_in);
  params.perm_count = rank;
  for (int i = 0; i < rank - 2; ++i) {
    params.perm[i] = i;
  }
  // Transpose the last two dimensions.
  params.perm[rank - 2] = rank - 1;
  params.perm[rank - 1] = rank - 2;
  transposed_shape.SetDim(rank - 1, shape.Dims(rank - 2));
  transposed_shape.SetDim(rank - 2, shape.Dims(rank - 1));
  optimized_ops::Transpose(params, shape, input, transposed_shape, output);
}

TfLiteStatus TransposeRowsColumns(TfLiteContext* context,
                                  const TfLiteTensor* tensor_in,
                                  TfLiteTensor* tensor_out) {
  if (tensor_in->type == kTfLiteFloat32) {
    TransposeRowsColumnsImpl<float>(tensor_in, GetTensorData<float>(tensor_in),
                                    tensor_out,
                                    GetTensorData<float>(tensor_out));
    return kTfLiteOk;
  } else if (tensor_in->type == kTfLiteInt8) {
    TransposeRowsColumnsImpl<int8_t>(
        tensor_in, GetTensorData<int8_t>(tensor_in), tensor_out,
        GetTensorData<int8_t>(tensor_out));
    return kTfLiteOk;
  } else if (tensor_in->type == kTfLiteInt16) {
    TransposeRowsColumnsImpl<int16_t>(
        tensor_in, GetTensorData<int16_t>(tensor_in), tensor_out,
        GetTensorData<int16_t>(tensor_out));
    return kTfLiteOk;
  } else {
    TF_LITE_KERNEL_LOG(
        context, "Can only transpose tensors with float, int8 or int16 type.");
    return kTfLiteError;
  }
}

RuntimeShape SwapRowColumnDims(const RuntimeShape& shape) {
  RuntimeShape swapped_shape(shape);
  const int32_t dims = shape.DimensionsCount();
  swapped_shape.SetDim(dims - 2, shape.Dims(dims - 1));
  swapped_shape.SetDim(dims - 1, shape.Dims(dims - 2));
  return swapped_shape;
}

template <KernelType kernel_type>
TfLiteStatus EvalHybrid(TfLiteContext* context, TfLiteNode* node, OpData* data,
                        const RuntimeShape& input_shape,
                        const TfLiteTensor* input,
                        const RuntimeShape& filter_shape,
                        const TfLiteTensor* filter,
                        TfLiteTensor* input_quantized,
                        TfLiteTensor* scaling_factors,
                        TfLiteTensor* accum_scratch, TfLiteTensor* row_sums,
                        TfLiteTensor* input_offsets, TfLiteTensor* output) {
  const auto* params =
      reinterpret_cast<TfLiteBatchMatMulParams*>(node->builtin_data);
  const int32_t num_input_dims = input_shape.DimensionsCount();

  // Input row/cols have been swapped at this point, so dims are
  // {input_size, num_batches}.
  const int input_size = input_shape.Dims(num_input_dims - 2);
  const int batch_size = input_shape.Dims(num_input_dims - 1);

  int num_batches_to_quantize = batch_size;
  for (int i = 0; i < input_shape.DimensionsCount() - 2; ++i) {
    num_batches_to_quantize *= input_shape.Dims(i);
  }
  // Quantize the input from float to int8, recording per-batch quantization
  // params (scaling factors and, when asymmetric, input offsets).
  const int scaling_factor_size = GetTensorShape(scaling_factors).FlatSize();
  TF_LITE_ENSURE(context, scaling_factor_size >= num_batches_to_quantize);
  float* scaling_factors_ptr = GetTensorData<float>(scaling_factors);
  int32_t* input_offset_ptr = nullptr;
  int32_t* row_sums_ptr = nullptr;
  input_offset_ptr = GetTensorData<int32_t>(input_offsets);
  row_sums_ptr = GetTensorData<int32_t>(row_sums);
  if (!params->asymmetric_quantize_inputs) {
    memset(input_offset_ptr, 0, input_offsets->bytes);
  }
  int8_t* quant_data = GetTensorData<int8_t>(input_quantized);
  const int8_t* filter_data = GetTensorData<int8_t>(filter);
  const float* input_ptr = GetTensorData<float>(input);
  // Quantize each batch independently.
  tensor_utils::BatchQuantizeFloats(input_ptr, num_batches_to_quantize,
                                    input_size, quant_data, scaling_factors_ptr,
                                    input_offset_ptr,
                                    params->asymmetric_quantize_inputs);
  for (int b = 0; b < num_batches_to_quantize; ++b) {
    // Incorporate scaling of the filter.
    scaling_factors_ptr[b] *= filter->params.scale;
  }

  RuntimeShape output_shape = GetTensorShape(output);
  int output_size = 1;
  for (int i = 0; i < output_shape.DimensionsCount(); ++i) {
    output_size *= output_shape.Dims(i);
  }
  std::fill_n(GetTensorData<float>(output), output_size, 0.0f);
  if (kernel_type == kGenericOptimized) {
    optimized_ops::BatchMatMul(
        filter_shape, filter_data, input_shape, quant_data, scaling_factors_ptr,
        input_offset_ptr, row_sums_ptr, GetTensorShape(output),
        GetTensorData<int32_t>(accum_scratch), GetTensorData<float>(output),
        &(data->compute_row_sums), CpuBackendContext::GetFromContext(context));
  } else {
    reference_ops::BatchMatMul(
        filter_shape, filter_data, input_shape, quant_data, scaling_factors_ptr,
        input_offset_ptr, row_sums_ptr, GetTensorShape(output),
        GetTensorData<float>(output), &(data->compute_row_sums));
  }

  return kTfLiteOk;
}

template <KernelType kernel_type>
TfLiteStatus EvalInt8Int8(TfLiteContext* context, const OpData* data,
                          const RuntimeShape& lhs_shape,
                          const TfLiteTensor* lhs,
                          const RuntimeShape& rhs_shape,
                          const TfLiteTensor* rhs,
                          const RuntimeShape& output_shape,
                          TfLiteTensor* output) {
  // Reuse params struct from FullyConnected Op.
  FullyConnectedParams op_params;
  int32_t input_offset = -lhs->params.zero_point;
  int32_t filter_offset = -rhs->params.zero_point;
  int32_t output_offset = output->params.zero_point;
  op_params.input_offset = input_offset;
  op_params.weights_offset = filter_offset;
  op_params.output_offset = output_offset;
  op_params.output_multiplier = data->output_multiplier;
  op_params.output_shift = data->output_shift;
  op_params.quantized_activation_min = data->output_activation_min;
  op_params.quantized_activation_max = data->output_activation_max;
  op_params.lhs_cacheable = IsConstantTensor(lhs);
  op_params.rhs_cacheable = IsConstantTensor(rhs);

  if (kernel_type == kReference) {
    reference_ops::BatchMatMul<int8_t, int32_t>(
        op_params, rhs_shape, GetTensorData<int8_t>(rhs), lhs_shape,
        GetTensorData<int8_t>(lhs), GetTensorShape(output),
        GetTensorData<int8_t>(output));
  } else {
    optimized_ops::BatchMatMul(op_params, rhs_shape, GetTensorData<int8_t>(rhs),
                               lhs_shape, GetTensorData<int8_t>(lhs),
                               GetTensorShape(output),
                               GetTensorData<int8_t>(output),
                               CpuBackendContext::GetFromContext(context));
  }
  return kTfLiteOk;
}

template <KernelType kernel_type>
TfLiteStatus EvalInt8Int32(TfLiteContext* context, const OpData* data,
                           const RuntimeShape& lhs_shape,
                           const TfLiteTensor* lhs,
                           const RuntimeShape& rhs_shape,
                           const TfLiteTensor* rhs,
                           const RuntimeShape& output_shape,
                           TfLiteTensor* output) {
  // Reuse params struct from FullyConnected Op.
  FullyConnectedParams op_params;
  int32_t input_offset = -lhs->params.zero_point;
  int32_t weights_offset = -rhs->params.zero_point;
  int32_t output_offset = output->params.zero_point;
  op_params.input_offset = input_offset;
  op_params.weights_offset = weights_offset;
  op_params.output_offset = output_offset;
  op_params.output_multiplier = data->output_multiplier;
  op_params.output_shift = data->output_shift;
  op_params.quantized_activation_min = data->output_activation_min;
  op_params.quantized_activation_max = data->output_activation_max;
  op_params.lhs_cacheable = IsConstantTensor(lhs);
  op_params.rhs_cacheable = IsConstantTensor(rhs);

  // Pass the RHS (filter) as the BatchMatMul lhs argument and the LHS (input)
  // as its rhs argument; see the comment above Eval() for why the operands are
  // swapped.
  if (kernel_type == kReference) {
    reference_ops::BatchMatMul<int8_t, int8_t, int32_t>(
        rhs_shape, GetTensorData<int8_t>(rhs), lhs_shape,
        GetTensorData<int8_t>(lhs), GetTensorShape(output),
        GetTensorData<int32_t>(output));
  } else {
    optimized_ops::BatchMatMul(op_params, rhs_shape, GetTensorData<int8_t>(rhs),
                               lhs_shape, GetTensorData<int8_t>(lhs),
                               GetTensorShape(output),
                               GetTensorData<int32_t>(output),
                               CpuBackendContext::GetFromContext(context));
  }
  return kTfLiteOk;
}

template <KernelType kernel_type>
TfLiteStatus EvalInt16(TfLiteContext* context, const OpData* data,
                       const RuntimeShape& lhs_shape, const TfLiteTensor* lhs,
                       const RuntimeShape& rhs_shape, const TfLiteTensor* rhs,
                       const RuntimeShape& output_shape, TfLiteTensor* output) {
  // Reuse params struct from FullyConnected Op.
  FullyConnectedParams op_params;
  int32_t input_offset = -lhs->params.zero_point;
  int32_t filter_offset = -rhs->params.zero_point;
  int32_t output_offset = output->params.zero_point;
  op_params.input_offset = input_offset;
  op_params.weights_offset = filter_offset;
  op_params.output_offset = output_offset;
  op_params.output_multiplier = data->output_multiplier;
  op_params.output_shift = data->output_shift;
  op_params.quantized_activation_min = data->output_activation_min;
  op_params.quantized_activation_max = data->output_activation_max;

  // optimized_ops is not yet implemented for int16_t; use reference_ops in all
  // cases.
  reference_ops::BatchMatMul<int16_t, int64_t>(
      op_params, rhs_shape, GetTensorData<int16_t>(rhs), lhs_shape,
      GetTensorData<int16_t>(lhs), GetTensorShape(output),
      GetTensorData<int16_t>(output));
  return kTfLiteOk;
}

template <KernelType kernel_type>
TfLiteStatus EvalQuantized(TfLiteContext* context, TfLiteNode* node,
                           OpData* data, const RuntimeShape& lhs_shape,
                           const TfLiteTensor* lhs,
                           const RuntimeShape& rhs_shape,
                           const TfLiteTensor* rhs, TfLiteTensor* output) {
  if (lhs->type == kTfLiteFloat32 && rhs->type == kTfLiteInt8) {
    TfLiteTensor* input_quantized;
    TF_LITE_ENSURE_OK(context, GetTemporarySafe(context, node, /*index=*/2,
                                                &input_quantized));
    TfLiteTensor* scaling_factors;
    TF_LITE_ENSURE_OK(context, GetTemporarySafe(context, node, /*index=*/3,
                                                &scaling_factors));
    TfLiteTensor* accum_scratch;
    TF_LITE_ENSURE_OK(
        context, GetTemporarySafe(context, node, /*index=*/4, &accum_scratch));
    TfLiteTensor* input_offsets;
    TF_LITE_ENSURE_OK(
        context, GetTemporarySafe(context, node, /*index=*/5, &input_offsets));
    TfLiteTensor* row_sums;
    TF_LITE_ENSURE_OK(context,
                      GetTemporarySafe(context, node, /*index=*/6, &row_sums));
    return EvalHybrid<kernel_type>(
        context, node, data, lhs_shape, lhs, rhs_shape, rhs, input_quantized,
        scaling_factors, accum_scratch, row_sums, input_offsets, output);
  } else if (lhs->type == kTfLiteInt8 && rhs->type == kTfLiteInt8) {
    if (output->type == kTfLiteInt8) {
      return EvalInt8Int8<kernel_type>(context, data, lhs_shape, lhs, rhs_shape,
                                       rhs, GetTensorShape(output), output);
    } else {
      return EvalInt8Int32<kernel_type>(context, data, lhs_shape, lhs,
                                        rhs_shape, rhs, GetTensorShape(output),
                                        output);
    }
  } else if (lhs->type == kTfLiteInt16 && rhs->type == kTfLiteInt16) {
    return EvalInt16<kernel_type>(context, data, lhs_shape, lhs, rhs_shape, rhs,
                                  GetTensorShape(output), output);
  } else {
    TF_LITE_KERNEL_LOG(
        context,
        "Currently only hybrid, int8 and int16 quantization are supported.\n");
    return kTfLiteError;
  }
  return kTfLiteOk;
}

TfLiteTensor* GetTempRhs(TfLiteContext* context, TfLiteNode* node,
                         const TfLiteTensor* rhs) {
  TfLiteTensor* transposed_rhs = GetTemporary(context, node, 1);
  if (transposed_rhs == nullptr) {
    return nullptr;
  }

  if (rhs->type == kTfLiteInt8 || rhs->type == kTfLiteInt16) {
    // Get the quantization params from the RHS tensor.
    transposed_rhs->params.scale = rhs->params.scale;
    transposed_rhs->params.zero_point = rhs->params.zero_point;
  }
  return transposed_rhs;
}

TfLiteTensor* GetTempLhs(TfLiteContext* context, TfLiteNode* node,
                         const TfLiteTensor* lhs) {
  TfLiteTensor* transposed_lhs = GetTemporary(context, node, 0);
  if (transposed_lhs == nullptr) {
    return nullptr;
  }

  if (lhs->type == kTfLiteInt8 || lhs->type == kTfLiteInt16) {
    // Get the quantization params from the LHS tensor.
    transposed_lhs->params.scale = lhs->params.scale;
    transposed_lhs->params.zero_point = lhs->params.zero_point;
  }
  return transposed_lhs;
}

// Perform a batch matrix multiply on
// LHS <..., A, B> X RHS <..., B, C>
// where the leading dimensions of LHS and RHS obey broadcasting rules.
// We assume that LHS and RHS are both row oriented (adjacent values in memory
// are in the same row) and will output in the same memory layout. However,
// our fast GEMM libraries assume RCC layout (LHS row oriented,
// RHS column oriented, output column oriented). Therefore, we perform
// RHS <..., C, B> X LHS <..., B, A>
// where the output is C X A and column-oriented, which is equivalent to the
// desired A X C row-oriented output.
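// Concretely, with adj_x = adj_y = false, LHS <2, 3, 4> and RHS <2, 4, 5>: the
// RHS is transposed into its temporary as <2, 5, 4>, the LHS buffer is passed
// unchanged but described with the swapped shape <2, 4, 3>, and the GEMM
// produces the row-oriented output <2, 3, 5>.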
template <KernelType kernel_type>
TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
  OpContext op_context(context, node);
  OpData* op_data = reinterpret_cast<OpData*>(node->user_data);
  const TfLiteTensor* lhs;
  TF_LITE_ENSURE_OK(context,
                    GetInputSafe(context, node, kInputLHSTensor, &lhs));
  const TfLiteTensor* rhs;
  TF_LITE_ENSURE_OK(context,
                    GetInputSafe(context, node, kInputRHSTensor, &rhs));
  TfLiteTensor* output;
  TF_LITE_ENSURE_OK(context,
                    GetOutputSafe(context, node, kOutputTensor, &output));
  RuntimeShape orig_lhs_shape = GetTensorShape(lhs);
  RuntimeShape orig_rhs_shape = GetTensorShape(rhs);

  bool adj_y = op_context.params->adj_y;
  bool adj_x = op_context.params->adj_x;

  const TfLiteTensor* rhs_tensor = adj_y ? rhs : GetTempRhs(context, node, rhs);
  const TfLiteTensor* lhs_tensor = adj_x ? GetTempLhs(context, node, lhs) : lhs;
  if (!adj_y) {
    // TODO(b/154760341): Constant tensors should already be transposed, but
    // we transpose once if necessary for now.
    if (!(IsConstantTensor(rhs) && op_data->rhs_transposed)) {
      TransposeRowsColumns(context, rhs, GetTemporary(context, node, 1));
      op_data->rhs_transposed = true;
    }
  }
  if (adj_x) {
    TransposeRowsColumns(context, lhs, GetTemporary(context, node, 0));
  }
  RuntimeShape rhs_shape =
      adj_y ? orig_rhs_shape : SwapRowColumnDims(orig_rhs_shape);
  RuntimeShape lhs_shape =
      adj_x ? orig_lhs_shape : SwapRowColumnDims(orig_lhs_shape);

  switch (rhs->type) {
    case kTfLiteFloat32:
      // Note we pass RHS args first, LHS args second. See note above.
      if (kernel_type == kGenericOptimized) {
        optimized_ops::BatchMatMul(rhs_shape, GetTensorData<float>(rhs_tensor),
                                   lhs_shape, GetTensorData<float>(lhs_tensor),
                                   GetTensorShape(output),
                                   GetTensorData<float>(output),
                                   CpuBackendContext::GetFromContext(context));
      } else {
        reference_ops::BatchMatMul(rhs_shape, GetTensorData<float>(rhs_tensor),
                                   lhs_shape, GetTensorData<float>(lhs_tensor),
                                   GetTensorShape(output),
                                   GetTensorData<float>(output));
      }
      break;
    case kTfLiteInt8:
    case kTfLiteInt16:
      EvalQuantized<kernel_type>(context, node, op_data, lhs_shape, lhs_tensor,
                                 rhs_shape, rhs_tensor, output);
      break;
    default:
      TF_LITE_KERNEL_LOG(context,
                         "Currently BatchMatMul doesn't support type: %s",
                         TfLiteTypeGetName(lhs->type));
      return kTfLiteError;
  }
  return kTfLiteOk;
}

}  // namespace batch_matmul

TfLiteRegistration* Register_BATCH_MATMUL_REF() {
  static TfLiteRegistration r = {batch_matmul::Init, batch_matmul::Free,
                                 batch_matmul::Prepare,
                                 batch_matmul::Eval<batch_matmul::kReference>};
  return &r;
}

TfLiteRegistration* Register_BATCH_MATMUL_GENERIC_OPTIMIZED() {
  static TfLiteRegistration r = {
      batch_matmul::Init, batch_matmul::Free, batch_matmul::Prepare,
      batch_matmul::Eval<batch_matmul::kGenericOptimized>};
  return &r;
}

TfLiteRegistration* Register_BATCH_MATMUL() {
  return Register_BATCH_MATMUL_GENERIC_OPTIMIZED();
}
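
// Note (illustrative): the builtin op resolver typically maps
// BuiltinOperator_BATCH_MATMUL to Register_BATCH_MATMUL(), i.e. the
// generic-optimized kernel; Register_BATCH_MATMUL_REF() is mainly useful for
// tests that want bit-exact reference behavior.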

}  // namespace builtin
}  // namespace ops
}  // namespace tflite