1 | /* Copyright 2020 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | |
16 | #include "tensorflow/lite/kernels/internal/reference/batch_matmul.h" |
17 | |
18 | #include <stddef.h> |
19 | |
20 | #include <algorithm> |
21 | #include <cstdint> |
#include <cstring>
#include <limits>
23 | |
24 | #include "tensorflow/lite/c/builtin_op_data.h" |
25 | #include "tensorflow/lite/c/common.h" |
26 | #include "tensorflow/lite/kernels/cpu_backend_context.h" |
27 | #include "tensorflow/lite/kernels/internal/compatibility.h" |
28 | #include "tensorflow/lite/kernels/internal/optimized/batch_matmul.h" |
29 | #include "tensorflow/lite/kernels/internal/optimized/optimized_ops.h" |
30 | #include "tensorflow/lite/kernels/internal/reference/reference_ops.h" |
31 | #include "tensorflow/lite/kernels/internal/tensor.h" |
32 | #include "tensorflow/lite/kernels/internal/tensor_ctypes.h" |
33 | #include "tensorflow/lite/kernels/internal/tensor_utils.h" |
34 | #include "tensorflow/lite/kernels/internal/types.h" |
35 | #include "tensorflow/lite/kernels/kernel_util.h" |
36 | |
37 | namespace tflite { |
38 | namespace ops { |
39 | namespace builtin { |
40 | namespace batch_matmul { |
41 | |
42 | static const int kInputLHSTensor = 0; |
43 | static const int kInputRHSTensor = 1; |
44 | static const int kOutputTensor = 0; |
45 | |
46 | static const int kNumTempTensorsForAdjoints = 2; |
47 | static const int kNumTempTensorsForHybrid = 5; |
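// Layout of the temporaries allocated below (see InitializeTemporaries):
// index 0 holds the transposed LHS and index 1 the transposed RHS; for the
// hybrid (float32 LHS / int8 RHS) path only, indices 2-6 additionally hold
// the quantized input, the per-batch scaling factors, the accumulator
// scratch, the input offsets, and the RHS row sums.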
48 | |
// This file has two implementations of BatchMatMul.
50 | enum KernelType { |
51 | kReference, |
52 | kGenericOptimized, |
53 | }; |
54 | |
55 | struct OpData { |
56 | // The scaling factor from input to output (aka the 'real multiplier') can |
57 | // be represented as a fixed point multiplier plus a left shift. |
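  // Roughly, real_multiplier ~= output_multiplier * 2^(output_shift - 31),
  // where output_multiplier is a Q0.31 fixed-point value. Illustrative
  // example (not normative): a real multiplier of 0.75 would be stored as
  // output_multiplier = round(0.75 * 2^31) = 1610612736 with output_shift = 0.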
58 | int32_t output_multiplier; |
59 | int output_shift; |
60 | // The range of the fused activation layer. For example for kNone and |
61 | // uint8_t these would be 0 and 255. |
62 | int32_t output_activation_min; |
63 | int32_t output_activation_max; |
  // The index of the first temporary tensor; the temporaries store the
  // transposed LHS/RHS (and extra buffers for the hybrid case).
65 | int scratch_tensor_index; |
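  // True once the RHS has been transposed into its scratch buffer; for a
  // constant RHS this lets Eval skip the transpose on later invocations.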
66 | bool rhs_transposed; |
67 | bool compute_row_sums = false; |
68 | }; |
69 | |
70 | struct OpContext { |
71 | OpContext(TfLiteContext* context, TfLiteNode* node) { |
72 | params = reinterpret_cast<TfLiteBatchMatMulParams*>(node->builtin_data); |
73 | lhs = GetInput(context, node, kInputLHSTensor); |
74 | rhs = GetInput(context, node, kInputRHSTensor); |
    output = GetOutput(context, node, kOutputTensor);
76 | } |
77 | TfLiteBatchMatMulParams* params; |
78 | const TfLiteTensor* lhs; |
79 | const TfLiteTensor* rhs; |
80 | TfLiteTensor* output; |
81 | }; |
82 | |
83 | void* Init(TfLiteContext* context, const char* buffer, size_t length) { |
84 | auto* op_data = new OpData(); |
85 | // If the RHS is constant, we only transpose once. |
86 | op_data->rhs_transposed = false; |
87 | // Creates the temp tensors to store the transposed LHS and/or RHS, and |
88 | // extra buffers for the quantized case. |
89 | context->AddTensors(context, |
90 | kNumTempTensorsForAdjoints + kNumTempTensorsForHybrid, |
91 | &op_data->scratch_tensor_index); |
92 | return op_data; |
93 | } |
94 | |
95 | void Free(TfLiteContext* context, void* buffer) { |
96 | delete static_cast<OpData*>(buffer); |
97 | } |
98 | |
99 | TfLiteStatus ResizeOutputTensor(TfLiteContext* context, |
100 | const RuntimeShape& extended_lhs_shape, |
101 | const RuntimeShape& extended_rhs_shape, |
102 | bool adj_x, bool adj_y, int output_rank, |
103 | TfLiteTensor* output) { |
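  // Illustrative example: with adj_x = adj_y = false, LHS shape [2, 1, 4, 5]
  // and RHS shape [1, 3, 5, 6] broadcast to an output shape of [2, 3, 4, 6].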
104 | TfLiteIntArray* output_shape = TfLiteIntArrayCreate(output_rank); |
105 | // Fill in any broadcast dimensions. |
106 | for (int i = 0; i < output_rank - 2; ++i) { |
107 | const int lhs_dim = extended_lhs_shape.Dims(i); |
108 | const int rhs_dim = extended_rhs_shape.Dims(i); |
109 | int broadcast_dim = lhs_dim; |
110 | if ((lhs_dim != rhs_dim) && (lhs_dim == 1)) { |
111 | broadcast_dim = rhs_dim; |
112 | } |
113 | output_shape->data[i] = broadcast_dim; |
114 | } |
115 | // Fill in the matmul dimensions. |
116 | int lhs_rows_index = adj_x ? output_rank - 1 : output_rank - 2; |
117 | int rhs_cols_index = adj_y ? output_rank - 2 : output_rank - 1; |
118 | |
119 | output_shape->data[output_rank - 2] = extended_lhs_shape.Dims(lhs_rows_index); |
120 | output_shape->data[output_rank - 1] = extended_rhs_shape.Dims(rhs_cols_index); |
  return context->ResizeTensor(context, output, output_shape);
123 | } |
124 | |
125 | // Initializes temp tensors to store transposed operands. |
126 | TfLiteStatus InitializeTemporaries(TfLiteContext* context, TfLiteNode* node, |
127 | OpContext* op_context) { |
128 | // Create temporary tensors to hold transposed LHS/RHS. |
129 | OpData* op_data = reinterpret_cast<OpData*>(node->user_data); |
130 | const TfLiteTensor* lhs = op_context->lhs; |
131 | const TfLiteTensor* rhs = op_context->rhs; |
132 | TfLiteIntArrayFree(node->temporaries); |
133 | // For "hybrid" quantization, we impose the constraint that the LHS |
134 | // is float (typically an activation from a prior layer) and the RHS |
135 | // is quantized int8. |
136 | bool is_hybrid = |
137 | (op_context->lhs->type == kTfLiteFloat32 && rhs->type == kTfLiteInt8); |
138 | if (is_hybrid) { |
139 | node->temporaries = TfLiteIntArrayCreate(kNumTempTensorsForAdjoints + |
140 | kNumTempTensorsForHybrid); |
141 | } else { |
142 | node->temporaries = TfLiteIntArrayCreate(kNumTempTensorsForAdjoints); |
143 | } |
144 | |
145 | const int lhs_rank = NumDimensions(lhs); |
146 | const int rhs_rank = NumDimensions(rhs); |
147 | const int batch_size = op_context->params->adj_x |
148 | ? lhs->dims->data[lhs_rank - 1] |
149 | : lhs->dims->data[lhs_rank - 2]; |
150 | const int num_units = op_context->params->adj_y |
151 | ? rhs->dims->data[rhs_rank - 2] |
152 | : rhs->dims->data[rhs_rank - 1]; |
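  // batch_size is the number of rows each LHS matrix contributes to the
  // output and num_units is the number of output columns taken from the RHS
  // (both after accounting for the adjoint flags); they size the hybrid
  // scratch buffers below. E.g. LHS [B, T, K] x RHS [B, K, N] with
  // adj_x = adj_y = false gives batch_size = T and num_units = N.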
153 | |
  // Temp tensor for the transposed LHS.
155 | { |
156 | node->temporaries->data[0] = op_data->scratch_tensor_index; |
157 | TfLiteTensor* scratch_buffer; |
158 | TF_LITE_ENSURE_OK( |
159 | context, GetTemporarySafe(context, node, /*index=*/0, &scratch_buffer)); |
160 | TfLiteIntArray* scratch_buffer_size = TfLiteIntArrayCreate(lhs_rank); |
161 | for (int i = 0; i < lhs_rank - 2; ++i) { |
162 | scratch_buffer_size->data[i] = lhs->dims->data[i]; |
163 | } |
164 | // Swap last two dimensions. |
165 | scratch_buffer_size->data[lhs_rank - 2] = lhs->dims->data[lhs_rank - 1]; |
166 | scratch_buffer_size->data[lhs_rank - 1] = lhs->dims->data[lhs_rank - 2]; |
167 | |
168 | scratch_buffer->type = op_context->lhs->type; |
169 | scratch_buffer->allocation_type = kTfLiteArenaRw; |
170 | TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, scratch_buffer, |
171 | scratch_buffer_size)); |
172 | } |
173 | |
174 | // We need a temp buffer for the RHS if we need to transpose the RHS. We |
175 | // transpose by default, so that the two inputs (LHS and RHS) are in a proper |
176 | // layout for our fast matrix multiplication routines. If the transpose flag |
177 | // is set by the caller, the data is already in the desired layout. |
178 | { |
179 | node->temporaries->data[1] = op_data->scratch_tensor_index + 1; |
180 | TfLiteTensor* scratch_buffer; |
181 | TF_LITE_ENSURE_OK( |
182 | context, GetTemporarySafe(context, node, /*index=*/1, &scratch_buffer)); |
    scratch_buffer->name = "BatchMatMul_scratch_buffer";
184 | const TfLiteTensor* rhs = op_context->rhs; |
185 | int rhs_rank = NumDimensions(rhs); |
186 | TfLiteIntArray* scratch_buffer_size = TfLiteIntArrayCreate(rhs_rank); |
187 | for (int i = 0; i < rhs_rank - 2; ++i) { |
188 | scratch_buffer_size->data[i] = rhs->dims->data[i]; |
189 | } |
190 | // Swap last two dimensions. |
191 | scratch_buffer_size->data[rhs_rank - 2] = rhs->dims->data[rhs_rank - 1]; |
192 | scratch_buffer_size->data[rhs_rank - 1] = rhs->dims->data[rhs_rank - 2]; |
193 | |
194 | if (IsConstantTensor(op_context->rhs)) { |
195 | scratch_buffer->allocation_type = kTfLiteArenaRwPersistent; |
196 | } else { |
197 | scratch_buffer->allocation_type = kTfLiteArenaRw; |
198 | } |
199 | scratch_buffer->type = op_context->rhs->type; |
200 | TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, scratch_buffer, |
201 | scratch_buffer_size)); |
202 | } |
203 | |
  // If we have to perform on-the-fly quantization (with quantized weights and
  // float inputs), we first need to quantize the inputs. Allocate temporary
206 | // buffer to store the intermediate quantized values, the batch scaling |
207 | // factors, the accumulator buffer (optimized version), the input offsets, |
208 | // and the sums of the rows for each weights matrix. |
209 | // RHS = weights, LHS = inputs |
210 | if (is_hybrid) { |
211 | // Calculate the total number of LHS batches. |
212 | int num_batches = 1; |
213 | for (int i = 0; i < lhs_rank - 2; ++i) { |
214 | num_batches *= lhs->dims->data[i]; |
215 | } |
216 | int num_weights_matrices = 1; |
217 | for (int i = 0; i < rhs_rank - 2; ++i) { |
218 | num_weights_matrices *= rhs->dims->data[i]; |
219 | } |
220 | op_data->compute_row_sums = true; |
221 | node->temporaries->data[2] = op_data->scratch_tensor_index + 2; |
222 | TfLiteTensor* input_quantized; |
223 | TF_LITE_ENSURE_OK(context, GetTemporarySafe(context, node, /*index=*/2, |
224 | &input_quantized)); |
225 | input_quantized->type = op_context->rhs->type; |
226 | input_quantized->allocation_type = kTfLiteArenaRw; |
227 | |
228 | TfLiteIntArray* input_quantized_size = |
229 | TfLiteIntArrayCopy(op_context->lhs->dims); |
230 | TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, input_quantized, |
231 | input_quantized_size)); |
232 | |
233 | node->temporaries->data[3] = op_data->scratch_tensor_index + 3; |
234 | TfLiteTensor* scaling_factors; |
235 | TF_LITE_ENSURE_OK(context, GetTemporarySafe(context, node, /*index=*/3, |
236 | &scaling_factors)); |
237 | scaling_factors->type = kTfLiteFloat32; |
238 | scaling_factors->allocation_type = kTfLiteArenaRw; |
239 | // Total size of scaling factors is batch size * number of total batches |
240 | int scaling_dims[1] = {num_batches * batch_size}; |
241 | if (!TfLiteIntArrayEqualsArray(scaling_factors->dims, 1, scaling_dims)) { |
242 | TfLiteIntArray* scaling_factors_size = TfLiteIntArrayCreate(1); |
243 | scaling_factors_size->data[0] = scaling_dims[0]; |
244 | TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, scaling_factors, |
245 | scaling_factors_size)); |
246 | } |
247 | |
248 | node->temporaries->data[4] = op_data->scratch_tensor_index + 4; |
249 | TfLiteTensor* accum_scratch; |
250 | TF_LITE_ENSURE_OK( |
251 | context, GetTemporarySafe(context, node, /*index=*/4, &accum_scratch)); |
252 | accum_scratch->type = kTfLiteInt32; |
253 | accum_scratch->allocation_type = kTfLiteArenaRw; |
254 | int accum_scratch_dims[2] = {num_units, batch_size}; |
255 | if (!TfLiteIntArrayEqualsArray(accum_scratch->dims, 2, |
256 | accum_scratch_dims)) { |
257 | TfLiteIntArray* accum_size = TfLiteIntArrayCreate(2); |
258 | accum_size->data[0] = num_units; |
259 | accum_size->data[1] = batch_size; |
260 | TF_LITE_ENSURE_OK( |
261 | context, context->ResizeTensor(context, accum_scratch, accum_size)); |
262 | } |
263 | |
264 | node->temporaries->data[5] = op_data->scratch_tensor_index + 5; |
265 | TfLiteTensor* input_offsets; |
266 | TF_LITE_ENSURE_OK( |
267 | context, GetTemporarySafe(context, node, /*index=*/5, &input_offsets)); |
268 | input_offsets->type = kTfLiteInt32; |
269 | input_offsets->allocation_type = kTfLiteArenaRw; |
270 | if (!TfLiteIntArrayEqualsArray(input_offsets->dims, 1, scaling_dims)) { |
271 | TfLiteIntArray* input_offsets_size = TfLiteIntArrayCreate(1); |
272 | input_offsets_size->data[0] = num_batches * batch_size; |
273 | TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, input_offsets, |
274 | input_offsets_size)); |
275 | } |
276 | node->temporaries->data[6] = op_data->scratch_tensor_index + 6; |
277 | TfLiteTensor* row_sums; |
278 | TF_LITE_ENSURE_OK(context, |
279 | GetTemporarySafe(context, node, /*index=*/6, &row_sums)); |
280 | row_sums->type = kTfLiteInt32; |
281 | row_sums->allocation_type = kTfLiteArenaRwPersistent; |
282 | int row_sums_dims[1] = {num_weights_matrices * num_units}; |
283 | if (!TfLiteIntArrayEqualsArray(row_sums->dims, 1, row_sums_dims)) { |
284 | TfLiteIntArray* row_sums_size = TfLiteIntArrayCreate(1); |
285 | row_sums_size->data[0] = row_sums_dims[0]; |
286 | TF_LITE_ENSURE_OK( |
287 | context, context->ResizeTensor(context, row_sums, row_sums_size)); |
288 | } |
289 | } |
290 | |
291 | return kTfLiteOk; |
292 | } |
293 | |
294 | TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { |
295 | TF_LITE_ENSURE_EQ(context, NumInputs(node), 2); |
296 | TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1); |
297 | |
298 | OpContext op_context(context, node); |
299 | TF_LITE_ENSURE_OK(context, InitializeTemporaries(context, node, &op_context)); |
300 | OpData* op_data = reinterpret_cast<OpData*>(node->user_data); |
301 | |
302 | bool adj_x = op_context.params->adj_x; |
303 | bool adj_y = op_context.params->adj_y; |
304 | |
305 | const TfLiteTensor* lhs_data; |
306 | TF_LITE_ENSURE_OK(context, |
307 | GetInputSafe(context, node, kInputLHSTensor, &lhs_data)); |
308 | const TfLiteTensor* rhs_data; |
309 | TF_LITE_ENSURE_OK(context, |
310 | GetInputSafe(context, node, kInputRHSTensor, &rhs_data)); |
311 | TfLiteTensor* output; |
312 | TF_LITE_ENSURE_OK(context, |
313 | GetOutputSafe(context, node, kOutputTensor, &output)); |
314 | |
315 | // Note that quantized inference requires that all tensors have their |
316 | // parameters set. This is usually done during quantized training. |
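  // In effect (a sketch of what GetQuantizedConvolutionMultipler computes):
  // real_multiplier = lhs_scale * rhs_scale / output_scale, which
  // QuantizeMultiplier then splits into output_multiplier and output_shift.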
317 | if ((lhs_data->type == kTfLiteInt8 || lhs_data->type == kTfLiteInt16) && |
318 | output->type != kTfLiteInt32) { |
319 | double real_multiplier = 0.0; |
320 | TF_LITE_ENSURE_STATUS(GetQuantizedConvolutionMultipler( |
321 | context, lhs_data, rhs_data, output, &real_multiplier)); |
322 | int exponent; |
323 | QuantizeMultiplier(real_multiplier, &op_data->output_multiplier, &exponent); |
324 | op_data->output_shift = exponent; |
325 | // BatchMatMul has no fused activation functions. Therefore, set |
326 | // output activation min and max to min and max of int8_t or int16_t |
327 | // type. |
328 | if (lhs_data->type == kTfLiteInt8) { |
329 | op_data->output_activation_min = std::numeric_limits<int8_t>::min(); |
330 | op_data->output_activation_max = std::numeric_limits<int8_t>::max(); |
331 | } else { |
332 | op_data->output_activation_min = std::numeric_limits<int16_t>::min(); |
333 | op_data->output_activation_max = std::numeric_limits<int16_t>::max(); |
334 | } |
335 | } |
336 | |
337 | if (lhs_data->type == kTfLiteInt16) { |
338 | TF_LITE_ENSURE_EQ(context, lhs_data->params.zero_point, 0); |
339 | TF_LITE_ENSURE_EQ(context, rhs_data->params.zero_point, 0); |
340 | TF_LITE_ENSURE_EQ(context, output->params.zero_point, 0); |
341 | } |
342 | |
343 | TF_LITE_ENSURE(context, lhs_data->type == kTfLiteFloat32 || |
344 | lhs_data->type == kTfLiteInt8 || |
345 | lhs_data->type == kTfLiteInt16); |
346 | TF_LITE_ENSURE(context, rhs_data->type == kTfLiteFloat32 || |
347 | rhs_data->type == kTfLiteInt8 || |
348 | rhs_data->type == kTfLiteInt16); |
  // Either we have hybrid quantization with a float32 LHS and an int8 RHS,
  // or both inputs must be of the same type.
351 | TF_LITE_ENSURE(context, (lhs_data->type == kTfLiteFloat32 && |
352 | rhs_data->type == kTfLiteInt8) || |
353 | lhs_data->type == rhs_data->type); |
354 | // Support dimensions between 2 and 5, inclusive. |
355 | TF_LITE_ENSURE(context, NumDimensions(lhs_data) >= 2); |
356 | TF_LITE_ENSURE(context, NumDimensions(lhs_data) <= 5); |
357 | TF_LITE_ENSURE(context, NumDimensions(rhs_data) >= 2); |
358 | TF_LITE_ENSURE(context, NumDimensions(rhs_data) <= 5); |
359 | |
360 | const int lhs_rank = NumDimensions(lhs_data); |
361 | const int rhs_rank = NumDimensions(rhs_data); |
362 | const int output_rank = std::max(lhs_rank, rhs_rank); |
363 | const RuntimeShape extended_lhs_shape = |
364 | RuntimeShape::ExtendedShape(output_rank, GetTensorShape(lhs_data)); |
365 | const RuntimeShape extended_rhs_shape = |
366 | RuntimeShape::ExtendedShape(output_rank, GetTensorShape(rhs_data)); |
367 | |
  // Ensure any batch dimensions obey broadcasting rules.
369 | for (int i = 0; i < output_rank - 2; ++i) { |
370 | const int lhs_dim = extended_lhs_shape.Dims(i); |
371 | const int rhs_dim = extended_rhs_shape.Dims(i); |
372 | if (lhs_dim != rhs_dim) { |
373 | if (lhs_dim != 1) { |
374 | TF_LITE_ENSURE_EQ(context, rhs_dim, 1); |
375 | } |
376 | } |
377 | } |
378 | // Ensure other dimensions work for matrix multiplication. |
379 | int accum_dim_lhs = adj_x ? extended_lhs_shape.Dims(output_rank - 2) |
380 | : extended_lhs_shape.Dims(output_rank - 1); |
381 | int accum_dim_rhs = adj_y ? extended_rhs_shape.Dims(output_rank - 1) |
382 | : extended_rhs_shape.Dims(output_rank - 2); |
383 | |
384 | TF_LITE_ENSURE_EQ(context, accum_dim_lhs, accum_dim_rhs); |
385 | TfLiteStatus status = |
386 | ResizeOutputTensor(context, extended_lhs_shape, extended_rhs_shape, adj_x, |
387 | adj_y, output_rank, output); |
388 | return status; |
389 | } |
390 | |
391 | template <typename scalar> |
392 | void TransposeRowsColumnsImpl(const TfLiteTensor* tensor_in, |
393 | const scalar* input, TfLiteTensor* tensor_out, |
394 | scalar* output) { |
395 | RuntimeShape transposed_shape(GetTensorShape(tensor_in)); |
396 | RuntimeShape shape(GetTensorShape(tensor_in)); |
397 | TransposeParams params; |
398 | int rank = NumDimensions(tensor_in); |
399 | params.perm_count = rank; |
400 | for (int i = 0; i < rank - 2; ++i) { |
401 | params.perm[i] = i; |
402 | } |
403 | // Transpose the last two dimensions. |
404 | params.perm[rank - 2] = rank - 1; |
405 | params.perm[rank - 1] = rank - 2; |
406 | transposed_shape.SetDim(rank - 1, shape.Dims(rank - 2)); |
407 | transposed_shape.SetDim(rank - 2, shape.Dims(rank - 1)); |
408 | optimized_ops::Transpose(params, shape, input, transposed_shape, output); |
409 | } |
410 | |
411 | TfLiteStatus TransposeRowsColumns(TfLiteContext* context, |
412 | const TfLiteTensor* tensor_in, |
413 | TfLiteTensor* tensor_out) { |
414 | if (tensor_in->type == kTfLiteFloat32) { |
415 | TransposeRowsColumnsImpl<float>(tensor_in, GetTensorData<float>(tensor_in), |
416 | tensor_out, |
417 | GetTensorData<float>(tensor_out)); |
418 | return kTfLiteOk; |
419 | } else if (tensor_in->type == kTfLiteInt8) { |
420 | TransposeRowsColumnsImpl<int8_t>( |
421 | tensor_in, GetTensorData<int8_t>(tensor_in), tensor_out, |
422 | GetTensorData<int8_t>(tensor_out)); |
423 | return kTfLiteOk; |
424 | } else if (tensor_in->type == kTfLiteInt16) { |
425 | TransposeRowsColumnsImpl<int16_t>( |
426 | tensor_in, GetTensorData<int16_t>(tensor_in), tensor_out, |
427 | GetTensorData<int16_t>(tensor_out)); |
428 | return kTfLiteOk; |
429 | } else { |
430 | TF_LITE_KERNEL_LOG( |
431 | context, "Can only transpose tensors with float, int8 or int16 type." ); |
432 | return kTfLiteError; |
433 | } |
434 | } |
435 | |
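// Returns a copy of `shape` with its last two dimensions swapped, e.g.
// {2, 3, 5} -> {2, 5, 3}.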
436 | RuntimeShape SwapRowColumnDims(const RuntimeShape& shape) { |
437 | RuntimeShape swapped_shape(shape); |
438 | const int32_t dims = shape.DimensionsCount(); |
439 | swapped_shape.SetDim(dims - 2, shape.Dims(dims - 1)); |
440 | swapped_shape.SetDim(dims - 1, shape.Dims(dims - 2)); |
441 | return swapped_shape; |
442 | } |
443 | |
444 | template <KernelType kernel_type> |
445 | TfLiteStatus EvalHybrid(TfLiteContext* context, TfLiteNode* node, OpData* data, |
446 | const RuntimeShape& input_shape, |
447 | const TfLiteTensor* input, |
448 | const RuntimeShape& filter_shape, |
449 | const TfLiteTensor* filter, |
450 | TfLiteTensor* input_quantized, |
451 | TfLiteTensor* scaling_factors, |
452 | TfLiteTensor* accum_scratch, TfLiteTensor* row_sums, |
453 | TfLiteTensor* input_offsets, TfLiteTensor* output) { |
454 | const auto* params = |
455 | reinterpret_cast<TfLiteBatchMatMulParams*>(node->builtin_data); |
456 | const int32_t num_input_dims = input_shape.DimensionsCount(); |
457 | |
458 | // Input row/cols have been swapped at this point, so dims are |
459 | // {input_size, num_batches} |
460 | const int input_size = input_shape.Dims(num_input_dims - 2); |
461 | const int batch_size = input_shape.Dims(num_input_dims - 1); |
462 | |
463 | int num_batches_to_quantize = batch_size; |
464 | for (int i = 0; i < input_shape.DimensionsCount() - 2; ++i) { |
465 | num_batches_to_quantize *= input_shape.Dims(i); |
466 | } |
  // Quantize input from float to int8 + quantization params (scaling factor).
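  // Roughly, each LHS row gets its own scaling factor (max|x| / 127 in the
  // symmetric case); with asymmetric_quantize_inputs a per-row offset is
  // derived as well and written to input_offsets.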
468 | const int scaling_factor_size = GetTensorShape(scaling_factors).FlatSize(); |
469 | TF_LITE_ENSURE(context, scaling_factor_size >= num_batches_to_quantize); |
470 | float* scaling_factors_ptr = GetTensorData<float>(scaling_factors); |
471 | int32_t* input_offset_ptr = nullptr; |
472 | int32_t* row_sums_ptr = nullptr; |
473 | input_offset_ptr = GetTensorData<int32_t>(input_offsets); |
474 | row_sums_ptr = GetTensorData<int32_t>(row_sums); |
475 | if (!params->asymmetric_quantize_inputs) { |
476 | memset(input_offset_ptr, 0, input_offsets->bytes); |
477 | } |
478 | int8_t* quant_data = GetTensorData<int8_t>(input_quantized); |
479 | const int8_t* filter_data = GetTensorData<int8_t>(filter); |
480 | const float* input_ptr = GetTensorData<float>(input); |
481 | // Quantize each batch independently. |
482 | tensor_utils::BatchQuantizeFloats(input_ptr, num_batches_to_quantize, |
483 | input_size, quant_data, scaling_factors_ptr, |
484 | input_offset_ptr, |
485 | params->asymmetric_quantize_inputs); |
486 | for (int b = 0; b < num_batches_to_quantize; ++b) { |
487 | // Incorporate scaling of the filter. |
488 | scaling_factors_ptr[b] *= filter->params.scale; |
489 | } |
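  // Each scaling factor now combines the per-row input scale with the filter
  // scale, i.e. the single factor needed to map that row's int32 accumulators
  // back to float.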
490 | |
  RuntimeShape output_shape = GetTensorShape(output);
  std::fill_n(GetTensorData<float>(output), output_shape.FlatSize(), 0.0f);
497 | if (kernel_type == kGenericOptimized) { |
498 | optimized_ops::BatchMatMul( |
499 | filter_shape, filter_data, input_shape, quant_data, scaling_factors_ptr, |
500 | input_offset_ptr, row_sums_ptr, GetTensorShape(output), |
501 | GetTensorData<int32_t>(accum_scratch), GetTensorData<float>(output), |
502 | &(data->compute_row_sums), CpuBackendContext::GetFromContext(context)); |
503 | } else { |
504 | reference_ops::BatchMatMul( |
505 | filter_shape, filter_data, input_shape, quant_data, scaling_factors_ptr, |
506 | input_offset_ptr, row_sums_ptr, GetTensorShape(output), |
507 | GetTensorData<float>(output), &(data->compute_row_sums)); |
508 | } |
509 | |
510 | return kTfLiteOk; |
511 | } |
512 | |
513 | template <KernelType kernel_type> |
514 | TfLiteStatus EvalInt8Int8(TfLiteContext* context, const OpData* data, |
515 | const RuntimeShape& lhs_shape, |
516 | const TfLiteTensor* lhs, |
517 | const RuntimeShape& rhs_shape, |
518 | const TfLiteTensor* rhs, |
519 | const RuntimeShape& output_shape, |
520 | TfLiteTensor* output) { |
521 | // Reuse params struct from FullyConnected Op. |
522 | FullyConnectedParams op_params; |
523 | int32_t input_offset = -lhs->params.zero_point; |
524 | int32_t filter_offset = -rhs->params.zero_point; |
525 | int32_t output_offset = output->params.zero_point; |
526 | op_params.input_offset = input_offset; |
527 | op_params.weights_offset = filter_offset; |
528 | op_params.output_offset = output_offset; |
529 | op_params.output_multiplier = data->output_multiplier; |
530 | op_params.output_shift = data->output_shift; |
531 | op_params.quantized_activation_min = data->output_activation_min; |
532 | op_params.quantized_activation_max = data->output_activation_max; |
533 | op_params.lhs_cacheable = IsConstantTensor(lhs); |
534 | op_params.rhs_cacheable = IsConstantTensor(rhs); |
535 | |
536 | if (kernel_type == kReference) { |
537 | reference_ops::BatchMatMul<int8_t, int32_t>( |
538 | op_params, rhs_shape, GetTensorData<int8_t>(rhs), lhs_shape, |
539 | GetTensorData<int8_t>(lhs), GetTensorShape(output), |
540 | GetTensorData<int8_t>(output)); |
541 | } else { |
542 | optimized_ops::BatchMatMul(op_params, rhs_shape, GetTensorData<int8_t>(rhs), |
543 | lhs_shape, GetTensorData<int8_t>(lhs), |
544 | GetTensorShape(output), |
545 | GetTensorData<int8_t>(output), |
546 | CpuBackendContext::GetFromContext(context)); |
547 | } |
548 | return kTfLiteOk; |
549 | } |
550 | |
551 | template <KernelType kernel_type> |
552 | TfLiteStatus EvalInt8Int32(TfLiteContext* context, const OpData* data, |
553 | const RuntimeShape& lhs_shape, |
554 | const TfLiteTensor* lhs, |
555 | const RuntimeShape& rhs_shape, |
556 | const TfLiteTensor* rhs, |
557 | const RuntimeShape& output_shape, |
558 | TfLiteTensor* output) { |
559 | // Reuse params struct from FullyConnected Op. |
560 | FullyConnectedParams op_params; |
561 | int32_t input_offset = -lhs->params.zero_point; |
562 | int32_t weights_offset = -rhs->params.zero_point; |
563 | int32_t output_offset = output->params.zero_point; |
564 | op_params.input_offset = input_offset; |
565 | op_params.weights_offset = weights_offset; |
566 | op_params.output_offset = output_offset; |
567 | op_params.output_multiplier = data->output_multiplier; |
568 | op_params.output_shift = data->output_shift; |
569 | op_params.quantized_activation_min = data->output_activation_min; |
570 | op_params.quantized_activation_max = data->output_activation_max; |
571 | op_params.lhs_cacheable = IsConstantTensor(lhs); |
572 | op_params.rhs_cacheable = IsConstantTensor(rhs); |
573 | |
  // Pass the RHS (filter) as the BatchMatMul lhs argument and the LHS (input)
  // as the rhs argument; for the reason, see the comment above Eval().
  if (kernel_type == kReference) {
    reference_ops::BatchMatMul<int8_t, int8_t, int32_t>(
        rhs_shape, GetTensorData<int8_t>(rhs), lhs_shape,
        GetTensorData<int8_t>(lhs), GetTensorShape(output),
        GetTensorData<int32_t>(output));
581 | } else { |
582 | optimized_ops::BatchMatMul(op_params, rhs_shape, GetTensorData<int8_t>(rhs), |
583 | lhs_shape, GetTensorData<int8_t>(lhs), |
584 | GetTensorShape(output), |
585 | GetTensorData<int32_t>(output), |
586 | CpuBackendContext::GetFromContext(context)); |
587 | } |
588 | return kTfLiteOk; |
589 | } |
590 | |
591 | template <KernelType kernel_type> |
592 | TfLiteStatus EvalInt16(TfLiteContext* context, const OpData* data, |
593 | const RuntimeShape& lhs_shape, const TfLiteTensor* lhs, |
594 | const RuntimeShape& rhs_shape, const TfLiteTensor* rhs, |
595 | const RuntimeShape& output_shape, TfLiteTensor* output) { |
596 | // Reuse params struct from FullyConnected Op. |
597 | FullyConnectedParams op_params; |
598 | int32_t input_offset = -lhs->params.zero_point; |
599 | int32_t filter_offset = -rhs->params.zero_point; |
600 | int32_t output_offset = output->params.zero_point; |
601 | op_params.input_offset = input_offset; |
602 | op_params.weights_offset = filter_offset; |
603 | op_params.output_offset = output_offset; |
604 | op_params.output_multiplier = data->output_multiplier; |
605 | op_params.output_shift = data->output_shift; |
606 | op_params.quantized_activation_min = data->output_activation_min; |
607 | op_params.quantized_activation_max = data->output_activation_max; |
608 | |
  // optimized_ops are not yet implemented for int16_t; use reference_ops in
  // all cases.
611 | reference_ops::BatchMatMul<int16_t, int64_t>( |
612 | op_params, rhs_shape, GetTensorData<int16_t>(rhs), lhs_shape, |
613 | GetTensorData<int16_t>(lhs), GetTensorShape(output), |
614 | GetTensorData<int16_t>(output)); |
615 | return kTfLiteOk; |
616 | } |
617 | |
618 | template <KernelType kernel_type> |
619 | TfLiteStatus EvalQuantized(TfLiteContext* context, TfLiteNode* node, |
620 | OpData* data, const RuntimeShape& lhs_shape, |
621 | const TfLiteTensor* lhs, |
622 | const RuntimeShape& rhs_shape, |
623 | const TfLiteTensor* rhs, TfLiteTensor* output) { |
624 | if (lhs->type == kTfLiteFloat32 && rhs->type == kTfLiteInt8) { |
625 | TfLiteTensor* input_quantized; |
626 | TF_LITE_ENSURE_OK(context, GetTemporarySafe(context, node, /*index=*/2, |
627 | &input_quantized)); |
628 | TfLiteTensor* scaling_factors; |
629 | TF_LITE_ENSURE_OK(context, GetTemporarySafe(context, node, /*index=*/3, |
630 | &scaling_factors)); |
631 | TfLiteTensor* accum_scratch; |
632 | TF_LITE_ENSURE_OK( |
633 | context, GetTemporarySafe(context, node, /*index=*/4, &accum_scratch)); |
634 | TfLiteTensor* input_offsets; |
635 | TF_LITE_ENSURE_OK( |
636 | context, GetTemporarySafe(context, node, /*index=*/5, &input_offsets)); |
637 | TfLiteTensor* row_sums; |
638 | TF_LITE_ENSURE_OK(context, |
639 | GetTemporarySafe(context, node, /*index=*/6, &row_sums)); |
640 | return EvalHybrid<kernel_type>( |
641 | context, node, data, lhs_shape, lhs, rhs_shape, rhs, input_quantized, |
642 | scaling_factors, accum_scratch, row_sums, input_offsets, output); |
643 | } else if (lhs->type == kTfLiteInt8 && rhs->type == kTfLiteInt8) { |
644 | if (output->type == kTfLiteInt8) { |
645 | return EvalInt8Int8<kernel_type>(context, data, lhs_shape, lhs, rhs_shape, |
646 | rhs, GetTensorShape(output), output); |
647 | } else { |
648 | return EvalInt8Int32<kernel_type>(context, data, lhs_shape, lhs, |
649 | rhs_shape, rhs, GetTensorShape(output), |
650 | output); |
651 | } |
652 | } else if (lhs->type == kTfLiteInt16 && rhs->type == kTfLiteInt16) { |
653 | return EvalInt16<kernel_type>(context, data, lhs_shape, lhs, rhs_shape, rhs, |
654 | GetTensorShape(output), output); |
655 | } else { |
656 | TF_LITE_KERNEL_LOG( |
657 | context, |
658 | "Currently only hybrid, int8 and int16 quantization are supported.\n" ); |
659 | return kTfLiteError; |
660 | } |
661 | return kTfLiteOk; |
662 | } |
663 | |
664 | TfLiteTensor* GetTempRhs(TfLiteContext* context, TfLiteNode* node, |
665 | const TfLiteTensor* rhs) { |
666 | TfLiteTensor* transposed_rhs = GetTemporary(context, node, 1); |
667 | if (transposed_rhs == nullptr) { |
668 | return nullptr; |
669 | } |
670 | |
671 | if (rhs->type == kTfLiteInt8 || rhs->type == kTfLiteInt16) { |
672 | // Get the quantization params from the RHS tensor. |
673 | transposed_rhs->params.scale = rhs->params.scale; |
674 | transposed_rhs->params.zero_point = rhs->params.zero_point; |
675 | } |
676 | return transposed_rhs; |
677 | } |
678 | |
679 | TfLiteTensor* GetTempLhs(TfLiteContext* context, TfLiteNode* node, |
680 | const TfLiteTensor* lhs) { |
681 | TfLiteTensor* transposed_lhs = GetTemporary(context, node, 0); |
682 | if (transposed_lhs == nullptr) { |
683 | return nullptr; |
684 | } |
685 | |
686 | if (lhs->type == kTfLiteInt8 || lhs->type == kTfLiteInt16) { |
687 | // Get the quantization params from the LHS tensor. |
688 | transposed_lhs->params.scale = lhs->params.scale; |
689 | transposed_lhs->params.zero_point = lhs->params.zero_point; |
690 | } |
691 | return transposed_lhs; |
692 | } |
693 | |
694 | // Perform a batch matrix multiply on |
695 | // LHS <..., A, B> X RHS<..., B, C> |
696 | // where the leading dimensions of LHS and RHS obey broadcasting rules |
697 | // (this Op will apply broadcasting rules). |
698 | // We assume that LHS and RHS are both row oriented (adjacent values in memory |
699 | // are in the same row) and will output in the same memory layout. However, |
700 | // our fast GEMM libraries assume RCC layout (LHS row oriented, |
701 | // RHS column oriented, output column oriented). Therefore, we perform |
702 | // RHS <..., C, B> X LHS <..., B, A> |
// where the output is C X A column-oriented, which is equivalent to
// A X C row-oriented.
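// Concretely, for adj_x = adj_y = false below: the RHS is transposed in
// memory into scratch tensor 1, the LHS data is left untouched, and the
// shapes handed to the kernels are the row/column-swapped ones, matching the
// RHS <..., C, B> X LHS <..., B, A> formulation above.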
705 | template <KernelType kernel_type> |
706 | TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { |
707 | OpContext op_context(context, node); |
708 | OpData* op_data = reinterpret_cast<OpData*>(node->user_data); |
709 | const TfLiteTensor* lhs; |
710 | TF_LITE_ENSURE_OK(context, |
711 | GetInputSafe(context, node, kInputLHSTensor, &lhs)); |
712 | const TfLiteTensor* rhs; |
713 | TF_LITE_ENSURE_OK(context, |
714 | GetInputSafe(context, node, kInputRHSTensor, &rhs)); |
715 | TfLiteTensor* output; |
716 | TF_LITE_ENSURE_OK(context, |
717 | GetOutputSafe(context, node, kOutputTensor, &output)); |
718 | RuntimeShape orig_lhs_shape = GetTensorShape(lhs); |
719 | RuntimeShape orig_rhs_shape = GetTensorShape(rhs); |
720 | |
721 | bool adj_y = op_context.params->adj_y; |
722 | bool adj_x = op_context.params->adj_x; |
723 | |
724 | const TfLiteTensor* rhs_tensor = adj_y ? rhs : GetTempRhs(context, node, rhs); |
725 | const TfLiteTensor* lhs_tensor = adj_x ? GetTempLhs(context, node, lhs) : lhs; |
726 | if (!adj_y) { |
727 | // TODO(b/154760341) Constant tensors should already be transposed, but |
728 | // we transpose once if necessary for now. |
729 | if (!(IsConstantTensor(rhs) && op_data->rhs_transposed)) { |
730 | TransposeRowsColumns(context, rhs, GetTemporary(context, node, 1)); |
731 | op_data->rhs_transposed = true; |
732 | } |
733 | } |
734 | if (adj_x) { |
735 | TransposeRowsColumns(context, lhs, GetTemporary(context, node, 0)); |
736 | } |
737 | RuntimeShape rhs_shape = |
738 | adj_y ? orig_rhs_shape : SwapRowColumnDims(orig_rhs_shape); |
739 | RuntimeShape lhs_shape = |
740 | adj_x ? orig_lhs_shape : SwapRowColumnDims(orig_lhs_shape); |
741 | |
742 | switch (rhs->type) { |
743 | case kTfLiteFloat32: |
744 | // Note we pass RHS args first, LHS args second. See note above. |
745 | if (kernel_type == kGenericOptimized) { |
746 | optimized_ops::BatchMatMul(rhs_shape, GetTensorData<float>(rhs_tensor), |
747 | lhs_shape, GetTensorData<float>(lhs_tensor), |
748 | GetTensorShape(output), |
749 | GetTensorData<float>(output), |
750 | CpuBackendContext::GetFromContext(context)); |
751 | } else { |
752 | reference_ops::BatchMatMul(rhs_shape, GetTensorData<float>(rhs_tensor), |
753 | lhs_shape, GetTensorData<float>(lhs_tensor), |
754 | GetTensorShape(output), |
755 | GetTensorData<float>(output)); |
756 | } |
757 | break; |
758 | case kTfLiteInt8: |
759 | case kTfLiteInt16: |
760 | EvalQuantized<kernel_type>(context, node, op_data, lhs_shape, lhs_tensor, |
761 | rhs_shape, rhs_tensor, output); |
762 | break; |
763 | default: |
764 | TF_LITE_KERNEL_LOG(context, |
765 | "Currently BatchMatMul doesn't support type: %s" , |
766 | TfLiteTypeGetName(lhs->type)); |
767 | return kTfLiteError; |
768 | } |
769 | return kTfLiteOk; |
770 | } |
771 | |
772 | } // namespace batch_matmul |
773 | |
774 | TfLiteRegistration* Register_BATCH_MATMUL_REF() { |
775 | static TfLiteRegistration r = {batch_matmul::Init, batch_matmul::Free, |
776 | batch_matmul::Prepare, |
777 | batch_matmul::Eval<batch_matmul::kReference>}; |
778 | return &r; |
779 | } |
780 | |
781 | TfLiteRegistration* Register_BATCH_MATMUL_GENERIC_OPTIMIZED() { |
782 | static TfLiteRegistration r = { |
783 | batch_matmul::Init, batch_matmul::Free, batch_matmul::Prepare, |
784 | batch_matmul::Eval<batch_matmul::kGenericOptimized>}; |
785 | return &r; |
786 | } |
787 | |
788 | TfLiteRegistration* Register_BATCH_MATMUL() { |
789 | return Register_BATCH_MATMUL_GENERIC_OPTIMIZED(); |
790 | } |
791 | |
792 | } // namespace builtin |
793 | } // namespace ops |
794 | } // namespace tflite |
795 | |