/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/lite/kernels/internal/optimized/integer_ops/fully_connected.h"

#include <algorithm>
#include <cstddef>
#include <cstdint>
#include <cstring>
#include <vector>

#include "ruy/profiler/instrumentation.h"  // from @ruy
#include "tensorflow/lite/c/builtin_op_data.h"
#include "tensorflow/lite/c/common.h"
#include "tensorflow/lite/kernels/cpu_backend_context.h"
#include "tensorflow/lite/kernels/cpu_backend_threadpool.h"
#include "tensorflow/lite/kernels/internal/optimized/optimized_ops.h"
#include "tensorflow/lite/kernels/internal/optimized/sparse_ops/fully_connected.h"
#include "tensorflow/lite/kernels/internal/quantization_util.h"
#include "tensorflow/lite/kernels/internal/reference/fully_connected.h"
#include "tensorflow/lite/kernels/internal/reference/integer_ops/fully_connected.h"
#include "tensorflow/lite/kernels/internal/reference/reference_ops.h"
#include "tensorflow/lite/kernels/internal/reference/sparse_ops/fully_connected.h"
#include "tensorflow/lite/kernels/internal/tensor.h"
#include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
#include "tensorflow/lite/kernels/internal/tensor_utils.h"
#include "tensorflow/lite/kernels/internal/types.h"
#include "tensorflow/lite/kernels/kernel_util.h"

namespace tflite {
namespace ops {
namespace builtin {
namespace fully_connected {

namespace {
bool SupportedSparsityFormat(const TfLiteSparsity& sparsity) {
  if (sparsity.dim_metadata[0].format == kTfLiteDimDense &&
      sparsity.dim_metadata[1].format == kTfLiteDimSparseCSR) {
    return true;
  }

  return false;
}

static const int kDimMetadataSizeRandomSparse = 2;
static const int kDimMetadataSizeBlockSparse = 3;

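// Allocates a persistent uint8 "ledger" tensor sized to hold, for each output
// row of the CSR-encoded sparse filter, the number of non-zero blocks followed
// by their column indices.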
TfLiteStatus CreateLedgerTensor(const TfLiteSparsity* sparsity,
                                TfLiteContext* context, TfLiteTensor* ledger) {
  TF_LITE_ENSURE(context, sparsity != nullptr);
  ledger->name = "FC_ledger";
  ledger->type = kTfLiteUInt8;
  ledger->allocation_type = kTfLiteArenaRwPersistent;
  TfLiteIntArray* ledger_size = TfLiteIntArrayCreate(1);
  ledger_size->data[0] = sparsity->dim_metadata[1].array_indices->size +
                         sparsity->dim_metadata[1].array_segments->size - 1;
  return context->ResizeTensor(context, ledger, ledger_size);
}

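// Fills the ledger with the per-row block counts and block column indices of
// the CSR-encoded sparse filter. Fails if a count or index does not fit in
// uint8.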
TfLiteStatus PopulateLedgerData(const TfLiteSparsity* sparsity,
                                TfLiteContext* context, uint8_t* ledger_data) {
  TF_LITE_ENSURE(context, sparsity != nullptr);
  const auto* array_segments = sparsity->dim_metadata[1].array_segments;
  const auto* array_indices = sparsity->dim_metadata[1].array_indices;
  int output_data_ptr = 0;

  for (int i = 0; i < array_segments->size - 1; i++) {
    int row_start = array_segments->data[i];
    int row_end = array_segments->data[i + 1];
    if (row_end - row_start > UINT8_MAX) {
      return kTfLiteError;
    }
    // Copy num of non-zero blocks in row i.
    ledger_data[output_data_ptr] = static_cast<uint8_t>(row_end - row_start);
    output_data_ptr++;

    for (int j = row_start; j < row_end; j++) {
      if (array_indices->data[j] > UINT8_MAX) {
        return kTfLiteError;
      }
      // Copy indices of non-zero blocks in row i.
      ledger_data[output_data_ptr] =
          static_cast<uint8_t>(array_indices->data[j]);
      output_data_ptr++;
    }
  }
  return kTfLiteOk;
}

}  // namespace

// This file has three implementations of FullyConnected.
enum KernelType {
  kReference,
  kGenericOptimized,
  kLegacyPie,  // Legacy path used by the PIE team and related clients.
};

struct OpData {
  // The scaling factor from input to output (aka the 'real multiplier') can
  // be represented as a fixed point multiplier plus a left shift.
  int32_t output_multiplier;
  int output_shift;
  // Per channel output multiplier and shift.
  std::vector<int32_t> per_channel_output_multiplier;
  std::vector<int> per_channel_output_shift;
  // The range of the fused activation layer. For example for kNone and
  // uint8_t these would be 0 and 255.
  int32_t output_activation_min;
  int32_t output_activation_max;
  // The index of the temporary tensor where the quantized inputs are cached.
  int scratch_tensor_index;
  bool compute_row_sums = false;
  // Only used for sparse hybrid fully connected kernels.
  bool ledger_initialized;
};

constexpr int kInputTensor = 0;
constexpr int kWeightsTensor = 1;
constexpr int kBiasTensor = 2;
constexpr int kOutputTensor = 0;
constexpr int kShuffledInputWorkspaceTensor = 1;

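// Verifies that the input/filter/bias/output types form a supported
// combination for the requested weights format (float, hybrid, shuffled, or
// fully quantized).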
inline TfLiteStatus CheckTypes(TfLiteContext* context,
                               const TfLiteTensor* input,
                               const TfLiteTensor* filter,
                               const TfLiteTensor* bias, TfLiteTensor* output,
                               TfLiteFullyConnectedParams* params) {
  const bool is_quantized =
      ((filter->type == kTfLiteUInt8) || (filter->type == kTfLiteInt8));
  const bool is_hybrid = is_quantized && (input->type == kTfLiteFloat32);
  const bool is_shuffled =
      is_quantized && (params->weights_format ==
                       kTfLiteFullyConnectedWeightsFormatShuffled4x16Int8);

  // Optional bias tensor.
  const bool is_optional_bias_float = !bias || (bias->type == kTfLiteFloat32);
  const bool is_optional_bias_int =
      !bias || (bias->type == kTfLiteInt32) || (bias->type == kTfLiteInt64);

  if (is_quantized) {
    if (is_shuffled) {
      TF_LITE_ENSURE_TYPES_EQ(context, input->type, kTfLiteUInt8);
      TF_LITE_ENSURE_TYPES_EQ(context, filter->type, kTfLiteUInt8);
      TF_LITE_ENSURE_TYPES_EQ(context, output->type, kTfLiteInt16);
      TF_LITE_ENSURE_EQ(context, is_optional_bias_int, true);
    } else if (is_hybrid) {
      TF_LITE_ENSURE_TYPES_EQ(context, input->type, kTfLiteFloat32);
      TF_LITE_ENSURE_TYPES_EQ(context, output->type, kTfLiteFloat32);
      TF_LITE_ENSURE_EQ(context, is_optional_bias_float, true);
    } else {
      TF_LITE_ENSURE(context, input->type == kTfLiteUInt8 ||
                                  input->type == kTfLiteInt8 ||
                                  input->type == kTfLiteInt16);
      TF_LITE_ENSURE(context, output->type == kTfLiteUInt8 ||
                                  output->type == kTfLiteInt8 ||
                                  output->type == kTfLiteInt16);
      TF_LITE_ENSURE_EQ(context, is_optional_bias_int, true);
    }
  } else {
    // Only float32 is currently supported.
    TF_LITE_ENSURE_TYPES_EQ(context, input->type, kTfLiteFloat32);
    TF_LITE_ENSURE_TYPES_EQ(context, output->type, kTfLiteFloat32);
    TF_LITE_ENSURE_TYPES_EQ(context, filter->type, kTfLiteFloat32);
    TF_LITE_ENSURE_EQ(context, is_optional_bias_float, true);
  }

  return kTfLiteOk;
}

void* Init(TfLiteContext* context, const char* buffer, size_t length) {
  // This is a builtin op, so we don't use the contents in 'buffer', if any.
  // Instead, we allocate a new object to carry information from Prepare() to
  // Eval().
  auto* op_data = new OpData();
  context->AddTensors(context, /*tensors_to_add=*/6,
                      &op_data->scratch_tensor_index);
  return op_data;
}

void Free(TfLiteContext* context, void* buffer) {
  delete reinterpret_cast<OpData*>(buffer);
}

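// Validates the node's tensors, computes the quantization parameters needed at
// eval time, sets up the temporary tensors used by the hybrid path, and
// resizes the output tensor.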
TfLiteStatus PrepareImpl(TfLiteContext* context, TfLiteNode* node) {
  auto* params =
      reinterpret_cast<TfLiteFullyConnectedParams*>(node->builtin_data);
  OpData* data = reinterpret_cast<OpData*>(node->user_data);
  // Check we have all the inputs and outputs we need.
  TF_LITE_ENSURE(context, node->inputs->size == 2 || node->inputs->size == 3);
  // Shuffled formats need a workspace to store the shuffled input activations.
  const int expected_outputs_count =
      params->weights_format == kTfLiteFullyConnectedWeightsFormatDefault ? 1
                                                                          : 2;
  TF_LITE_ENSURE_EQ(context, node->outputs->size, expected_outputs_count);

  const TfLiteTensor* input;
  TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kInputTensor, &input));
  const TfLiteTensor* filter;
  TF_LITE_ENSURE_OK(context,
                    GetInputSafe(context, node, kWeightsTensor, &filter));
  const TfLiteTensor* bias =
      (node->inputs->size == 3)
          ? GetOptionalInputTensor(context, node, kBiasTensor)
          : nullptr;
  TfLiteTensor* output;
  TF_LITE_ENSURE_OK(context,
                    GetOutputSafe(context, node, kOutputTensor, &output));

  // Check proper datatype match among all input tensors.
  TF_LITE_ENSURE_STATUS(
      CheckTypes(context, input, filter, bias, output, params));

  // Check that all tensor parameters are consistent with each other and with
  // the input configuration.
  int input_size = 1;
  for (int i = 0; i < input->dims->size; i++) {
    input_size *= input->dims->data[i];
  }

  TF_LITE_ENSURE_EQ(context, NumDimensions(filter), 2);

  // When the second dimension size of the filter tensor is 0, we need to
  // generate the output shape early to avoid dividing by 0.
  if (filter->dims->data[1] == 0) {
    TfLiteIntArray* output_size_array;
    if (params->keep_num_dims) {
      output_size_array = TfLiteIntArrayCopy(input->dims);
      output_size_array->data[output_size_array->size - 1] =
          filter->dims->data[0];
    } else {
      output_size_array = TfLiteIntArrayCreate(2);
      // If `keep_num_dims` is false, we need to flatten the output tensor to
      // have rank 2.
      int batch_size = 1;
      for (int i = 0; i < input->dims->size - 1; ++i)
        batch_size *= input->dims->data[i];
      output_size_array->data[0] = batch_size;
      output_size_array->data[1] = filter->dims->data[0];
    }
    TF_LITE_ENSURE_OK(
        context, context->ResizeTensor(context, output, output_size_array));
    return kTfLiteOk;
  }

  const int batch_size = input_size / filter->dims->data[1];
  const int num_units = filter->dims->data[0];

  if (bias) {
    TF_LITE_ENSURE_EQ(context, NumElements(bias), SizeOfDimension(filter, 0));
  }

  // Note that quantized inference requires that all tensors have their
  // parameters set. This is usually done during quantized training.
  if (input->type == kTfLiteUInt8 || input->type == kTfLiteInt8 ||
      input->type == kTfLiteInt16) {
    // Populate scalar quantization parameters.
    double real_multiplier = 0.0;
    TF_LITE_ENSURE_STATUS(GetQuantizedConvolutionMultipler(
        context, input, filter, bias, output, &real_multiplier));
    int exponent;
    QuantizeMultiplier(real_multiplier, &data->output_multiplier, &exponent);
    data->output_shift = exponent;

    // Populate per-channel quantization parameters, if per-channel
    // quantization.
    TF_LITE_ENSURE_EQ(context, input->quantization.type,
                      kTfLiteAffineQuantization);
    TF_LITE_ENSURE_EQ(context, filter->quantization.type,
                      kTfLiteAffineQuantization);
    const auto* affine_quantization =
        reinterpret_cast<TfLiteAffineQuantization*>(
            filter->quantization.params);
    TF_LITE_ENSURE(context, affine_quantization);
    TF_LITE_ENSURE(context, affine_quantization->scale);
    const int per_channel_quantization_size = affine_quantization->scale->size;
    const bool is_per_channel = per_channel_quantization_size > 1;
    if (is_per_channel) {
      // Currently only Int8/Int16 are supported for per-channel quantization.
      TF_LITE_ENSURE(context,
                     input->type == kTfLiteInt8 || input->type == kTfLiteInt16);
      TF_LITE_ENSURE_EQ(context, filter->type, kTfLiteInt8);
      TF_LITE_ENSURE_EQ(context, affine_quantization->scale->size,
                        per_channel_quantization_size);
      TF_LITE_ENSURE_EQ(
          context, per_channel_quantization_size,
          filter->dims->data[affine_quantization->quantized_dimension]);
      // Populate multiplier and shift using affine quantization.
      const float input_scale = input->params.scale;
      const float output_scale = output->params.scale;
      const float* filter_scales = affine_quantization->scale->data;
      data->per_channel_output_multiplier.resize(
          per_channel_quantization_size);
      data->per_channel_output_shift.resize(per_channel_quantization_size);
      int32_t* per_channel_multiplier =
          data->per_channel_output_multiplier.data();
      int* per_channel_shift = data->per_channel_output_shift.data();
      for (int i = 0; i < per_channel_quantization_size; ++i) {
        const float scale = filter_scales[i];
        const double filter_scale = static_cast<double>(scale);
        const double effective_output_scale =
            static_cast<double>(input_scale) * filter_scale /
            static_cast<double>(output_scale);
        int32_t significand;
        int channel_shift;
        QuantizeMultiplier(effective_output_scale, &significand,
                           &channel_shift);
        per_channel_multiplier[i] = significand;
        per_channel_shift[i] = channel_shift;
      }
    }

    TF_LITE_ENSURE_STATUS(CalculateActivationRangeQuantized(
        context, params->activation, output, &data->output_activation_min,
        &data->output_activation_max));
  }

  if (input->type == kTfLiteInt16 && output->type == kTfLiteInt16) {
    TF_LITE_ENSURE_EQ(context, input->params.zero_point, 0);
    TF_LITE_ENSURE_EQ(context, output->params.zero_point, 0);
  }

  // If we have to perform on-the-fly quantization (with quantized weights and
  // float inputs) first we need to quantize the inputs. Allocate a temporary
  // buffer to store the intermediate quantized values.
  // Additionally, we allocate a temporary buffer to store the accumulated
  // quantized values prior to multiplication by the scaling factor.
  const bool is_hybrid =
      (input->type == kTfLiteFloat32 &&
       (filter->type == kTfLiteUInt8 || filter->type == kTfLiteInt8));
  const bool is_sparse = filter->sparsity != nullptr;
  if (is_hybrid) {
    TfLiteIntArrayFree(node->temporaries);
    data->compute_row_sums = true;
    if (is_sparse) {
      node->temporaries = TfLiteIntArrayCreate(6);
    } else {
      node->temporaries = TfLiteIntArrayCreate(5);
    }
    node->temporaries->data[0] = data->scratch_tensor_index;

    TfLiteTensor* input_quantized;
    TF_LITE_ENSURE_OK(context, GetTemporarySafe(context, node, /*index=*/0,
                                                &input_quantized));
    input_quantized->type = filter->type;
    input_quantized->allocation_type = kTfLiteArenaRw;

    TfLiteIntArray* input_quantized_size = TfLiteIntArrayCopy(input->dims);
    TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, input_quantized,
                                                     input_quantized_size));

    node->temporaries->data[1] = data->scratch_tensor_index + 1;
    TfLiteTensor* scaling_factors;
    TF_LITE_ENSURE_OK(context, GetTemporarySafe(context, node, /*index=*/1,
                                                &scaling_factors));
    scaling_factors->type = kTfLiteFloat32;
    scaling_factors->allocation_type = kTfLiteArenaRw;

    int scaling_dims[1] = {batch_size};
    if (!TfLiteIntArrayEqualsArray(scaling_factors->dims, 1, scaling_dims)) {
      TfLiteIntArray* scaling_factors_size = TfLiteIntArrayCreate(1);
      scaling_factors_size->data[0] = batch_size;
      TF_LITE_ENSURE_OK(context,
                        context->ResizeTensor(context, scaling_factors,
                                              scaling_factors_size));
    }

    node->temporaries->data[2] = data->scratch_tensor_index + 2;
    TfLiteTensor* accum_scratch;
    TF_LITE_ENSURE_OK(
        context, GetTemporarySafe(context, node, /*index=*/2, &accum_scratch));
    accum_scratch->type = kTfLiteInt32;
    accum_scratch->allocation_type = kTfLiteArenaRw;
    int accum_scratch_dims[2] = {num_units, batch_size};
    if (!TfLiteIntArrayEqualsArray(accum_scratch->dims, 2,
                                   accum_scratch_dims)) {
      TfLiteIntArray* accum_size = TfLiteIntArrayCreate(2);
      accum_size->data[0] = num_units;
      accum_size->data[1] = batch_size;
      TF_LITE_ENSURE_OK(
          context, context->ResizeTensor(context, accum_scratch, accum_size));
    }

    node->temporaries->data[3] = data->scratch_tensor_index + 3;
    TfLiteTensor* input_offsets;
    TF_LITE_ENSURE_OK(
        context, GetTemporarySafe(context, node, /*index=*/3, &input_offsets));
    input_offsets->type = kTfLiteInt32;
    input_offsets->allocation_type = kTfLiteArenaRw;
    if (!TfLiteIntArrayEqualsArray(input_offsets->dims, 1, scaling_dims)) {
      TfLiteIntArray* input_offsets_size = TfLiteIntArrayCreate(1);
      input_offsets_size->data[0] = batch_size;
      TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, input_offsets,
                                                       input_offsets_size));
    }
    node->temporaries->data[4] = data->scratch_tensor_index + 4;
    TfLiteTensor* row_sums;
    TF_LITE_ENSURE_OK(context,
                      GetTemporarySafe(context, node, /*index=*/4, &row_sums));
    row_sums->type = kTfLiteInt32;
    row_sums->allocation_type = kTfLiteArenaRwPersistent;
    int row_sums_dims[1] = {num_units};
    if (!TfLiteIntArrayEqualsArray(row_sums->dims, 1, row_sums_dims)) {
      TfLiteIntArray* row_sums_size = TfLiteIntArrayCreate(1);
      row_sums_size->data[0] = row_sums_dims[0];
      TF_LITE_ENSURE_OK(
          context, context->ResizeTensor(context, row_sums, row_sums_size));
    }

    if (is_sparse) {
      data->ledger_initialized = false;
      node->temporaries->data[5] = data->scratch_tensor_index + 5;
      TfLiteTensor* filter_ledger =
          &context->tensors[node->temporaries->data[5]];
      auto status =
          CreateLedgerTensor(filter->sparsity, context, filter_ledger);
      if (status != kTfLiteOk) return status;
    }
  }

  // Resize output.
  TfLiteIntArray* output_size_array = nullptr;
  if (params->keep_num_dims) {
    // When the number of dimensions is kept, the filter operates along the
    // last dimension. In other words, for an input tensor with shape
    // [batch_size, ..., n_inputs] and a filter of shape [n_units, n_inputs]
    // this Op produces an output of shape [batch_size, ..., n_units].
    TF_LITE_ENSURE_EQ(context, input->dims->data[input->dims->size - 1],
                      SizeOfDimension(filter, 1));
    output_size_array = TfLiteIntArrayCopy(input->dims);
    output_size_array->data[output_size_array->size - 1] = num_units;
  } else {
    // Otherwise, the output is (potentially flattened to) a 2-D matrix.
    output_size_array = TfLiteIntArrayCreate(2);
    output_size_array->data[0] = batch_size;
    output_size_array->data[1] = num_units;
  }
  TF_LITE_ENSURE_OK(context,
                    context->ResizeTensor(context, output, output_size_array));

  return kTfLiteOk;
}

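// Entry point for Prepare: validates the fused activation for the selected
// kernel type and then delegates to PrepareImpl.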
template <KernelType kernel_type>
TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
  // Check for supported activation types.
  auto* params =
      reinterpret_cast<TfLiteFullyConnectedParams*>(node->builtin_data);
  const TfLiteTensor* filter;
  TF_LITE_ENSURE_OK(context,
                    GetInputSafe(context, node, kWeightsTensor, &filter));
  const TfLiteTensor* input;
  TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kInputTensor, &input));
  const bool is_quantized =
      ((filter->type == kTfLiteUInt8) || (filter->type == kTfLiteInt8));
  const bool is_hybrid = is_quantized && (input->type == kTfLiteFloat32);
  const bool is_pie = kernel_type == kLegacyPie;

  // The Pie and hybrid paths support all kinds of fused activations; otherwise
  // only clipping activations are supported.
  if (!is_pie && !is_hybrid) {
    TF_LITE_ENSURE(context, params->activation == kTfLiteActNone ||
                                params->activation == kTfLiteActRelu ||
                                params->activation == kTfLiteActReluN1To1 ||
                                params->activation == kTfLiteActRelu6);
  }
  return PrepareImpl(context, node);
}

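// Float evaluation used by the legacy Pie kernel:
// output = activation(filter * input + bias), computed with the tensor_utils
// routines.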
TfLiteStatus EvalPie(TfLiteContext* context, TfLiteNode* node,
                     TfLiteFullyConnectedParams* params, OpData* data,
                     const TfLiteTensor* input, const TfLiteTensor* filter,
                     const TfLiteTensor* bias, TfLiteTensor* output) {
  int total_input_size = 1;
  for (int i = 0; i < input->dims->size; i++) {
    total_input_size *= input->dims->data[i];
  }

  int input_size = filter->dims->data[1];
  const int batch_size = total_input_size / filter->dims->data[1];
  const int num_units = filter->dims->data[0];

  // Output = bias if bias tensor exists.
  if (bias) {
    tensor_utils::VectorBatchVectorAssign(GetTensorData<float>(bias), num_units,
                                          batch_size,
                                          GetTensorData<float>(output));
  } else {
    std::fill_n(GetTensorData<float>(output), batch_size * num_units, 0.0f);
  }

  // Compute output += weight * input.
  tensor_utils::MatrixBatchVectorMultiplyAccumulate(
      GetTensorData<float>(filter), num_units, input_size,
      GetTensorData<float>(input), batch_size, GetTensorData<float>(output));

  // Apply activation function.
  tensor_utils::ApplyActivationToVector(
      GetTensorData<float>(output), batch_size * num_units, params->activation,
      GetTensorData<float>(output));

  return kTfLiteOk;
}

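// Hybrid (float input, quantized dense weights) evaluation: quantizes the
// input on the fly, accumulates against the int8 weights, and rescales back to
// float before applying the fused activation.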
TfLiteStatus EvalHybridDense(
    TfLiteContext* context, TfLiteNode* node,
    TfLiteFullyConnectedParams* params, OpData* data, const TfLiteTensor* input,
    const TfLiteTensor* filter, const TfLiteTensor* bias,
    TfLiteTensor* input_quantized, TfLiteTensor* scaling_factors,
    TfLiteTensor* accum_scratch, TfLiteTensor* row_sums,
    TfLiteTensor* input_offsets, TfLiteTensor* output) {
  int total_input_size = 1;
  for (int i = 0; i < input->dims->size; i++) {
    total_input_size *= input->dims->data[i];
  }

  const int input_size = filter->dims->data[1];
  const int batch_size = total_input_size / filter->dims->data[1];
  const int num_units = filter->dims->data[0];

  // Output = bias if bias tensor exists.
  if (bias) {
    tensor_utils::VectorBatchVectorAssign(GetTensorData<float>(bias), num_units,
                                          batch_size,
                                          GetTensorData<float>(output));
  } else {
    std::fill_n(GetTensorData<float>(output), batch_size * num_units, 0.0f);
  }

  // Save matrix multiplication computation for all zero input.
  if (tensor_utils::IsZeroVector(GetTensorData<float>(input),
                                 total_input_size)) {
    tensor_utils::ApplyActivationToVector(
        GetTensorData<float>(output), batch_size * num_units,
        params->activation, GetTensorData<float>(output));
    return kTfLiteOk;
  }

  // Quantize input from float to int8 + quantization params (scaling factor).
  float* scaling_factors_ptr = GetTensorData<float>(scaling_factors);
  int32_t* input_offset_ptr = nullptr;
  int32_t* row_sums_ptr = nullptr;
  if (params->asymmetric_quantize_inputs) {
    input_offset_ptr = GetTensorData<int32_t>(input_offsets);
    row_sums_ptr = GetTensorData<int32_t>(row_sums);
  }
  int8_t* quant_data = GetTensorData<int8_t>(input_quantized);
  const int8_t* filter_data = GetTensorData<int8_t>(filter);
  const float* input_ptr = GetTensorData<float>(input);
  tensor_utils::BatchQuantizeFloats(
      input_ptr, batch_size, input_size, quant_data, scaling_factors_ptr,
      input_offset_ptr, params->asymmetric_quantize_inputs);
  for (int b = 0; b < batch_size; ++b) {
    // Incorporate scaling of the filter.
    scaling_factors_ptr[b] *= filter->params.scale;
  }

  // Compute output += weight * quantized_input.
  int32_t* scratch = GetTensorData<int32_t>(accum_scratch);
  tensor_utils::MatrixBatchVectorMultiplyAccumulate(
      filter_data, num_units, input_size, quant_data, scaling_factors_ptr,
      batch_size, GetTensorData<float>(output), /*per_channel_scale=*/nullptr,
      input_offset_ptr, scratch, row_sums_ptr, &data->compute_row_sums,
      CpuBackendContext::GetFromContext(context));

  // Apply activation function to floats.
  tensor_utils::ApplyActivationToVector(
      GetTensorData<float>(output), batch_size * num_units, params->activation,
      GetTensorData<float>(output));
  return kTfLiteOk;
}

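// Per-thread body of the sparse hybrid kernel. Processes the batch slice
// [thread_start, thread_end): quantizes its inputs, multiplies them against
// the block-sparse int8 weights via the filter ledger, and applies the fused
// activation.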
void EvalSparseHybridImpl(TfLiteContext* context, TfLiteNode* node,
                          TfLiteFullyConnectedParams* params, OpData* data,
                          const TfLiteTensor* input, const TfLiteTensor* filter,
                          const TfLiteTensor* bias, int thread_start,
                          int thread_end, TfLiteTensor* input_quantized,
                          TfLiteTensor* scaling_factors,
                          TfLiteTensor* accum_scratch, TfLiteTensor* row_sums,
                          TfLiteTensor* input_offsets, TfLiteTensor* output) {
  ruy::profiler::ScopeLabel label("FullyConnected");
  ruy::profiler::ScopeLabel inner_label("Sparse Hybrid Kernel");
  const auto& input_shape = GetTensorShape(input);
  const auto& output_shape = GetTensorShape(output);
  const auto& filter_shape = GetTensorShape(filter);
  const int input_dims_count = input_shape.DimensionsCount();
  const int output_dims_count = output_shape.DimensionsCount();
  const int filter_dims_count = filter_shape.DimensionsCount();
  const int batch_size = thread_end - thread_start;
  const int input_depth = MatchingDim(filter_shape, filter_dims_count - 1,
                                      input_shape, input_dims_count - 1);
  const int output_depth = MatchingDim(filter_shape, filter_dims_count - 2,
                                       output_shape, output_dims_count - 1);
  const int per_thread_input_size = batch_size * input_depth;

  const float* per_thread_input =
      GetTensorData<float>(input) + thread_start * input_depth;
  float* per_thread_output =
      GetTensorData<float>(output) + thread_start * output_depth;

  // Output = bias if bias tensor exists.
  if (bias) {
    tensor_utils::VectorBatchVectorAssign(GetTensorData<float>(bias),
                                          output_depth, batch_size,
                                          per_thread_output);
  } else {
    std::fill_n(per_thread_output, batch_size * output_depth, 0.0f);
  }

  // Save matrix multiplication computation for all zero input.
  if (tensor_utils::IsZeroVector(per_thread_input, per_thread_input_size)) {
    tensor_utils::ApplyActivationToVector(
        per_thread_output, batch_size * output_depth, params->activation,
        per_thread_output);
    return;
  }

  // Quantize input from float to int8 + quantization params (scaling factor).
  float* scaling_factors_ptr =
      GetTensorData<float>(scaling_factors) + thread_start;
  int32_t* input_offset_ptr = nullptr;
  int32_t* row_sums_ptr = nullptr;
  if (params->asymmetric_quantize_inputs) {
    input_offset_ptr = GetTensorData<int32_t>(input_offsets) + thread_start;
    row_sums_ptr = GetTensorData<int32_t>(row_sums);
  }
  int8_t* quant_data =
      GetTensorData<int8_t>(input_quantized) + thread_start * input_depth;
  tensor_utils::BatchQuantizeFloats(per_thread_input, batch_size, input_depth,
                                    quant_data, scaling_factors_ptr,
                                    input_offset_ptr,
                                    params->asymmetric_quantize_inputs);
  for (int b = 0; b < batch_size; ++b) {
    // Incorporate scaling of the filter.
    scaling_factors_ptr[b] *= filter->params.scale;
  }

  if (params->asymmetric_quantize_inputs) {
    float* per_thread_output_ptr = per_thread_output;
    for (int b = 0; b < batch_size; ++b) {
      const float scaled_zp = scaling_factors_ptr[b] * input_offset_ptr[b];
      for (int row = 0; row < output_depth; ++row) {
        *per_thread_output_ptr++ -= scaled_zp * row_sums_ptr[row];
      }
    }
  }

  // Compute output += weight * quantized_input.
  TfLiteTensor* filter_ledger = &context->tensors[node->temporaries->data[5]];
  tensor_utils::SparseMatrixBatchVectorMultiplyAccumulate(
      GetTensorData<int8_t>(filter), GetTensorData<uint8_t>(filter_ledger),
      output_depth, input_depth, quant_data, scaling_factors_ptr, batch_size,
      per_thread_output);

  // Apply activation function to floats.
  tensor_utils::ApplyActivationToVector(per_thread_output,
                                        batch_size * output_depth,
                                        params->activation, per_thread_output);
}

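// Thread-pool task that wraps EvalSparseHybridImpl so each worker thread can
// process its own slice of the batch.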
struct SparseHybridFullyConnectedTask : cpu_backend_threadpool::Task {
  SparseHybridFullyConnectedTask(
      TfLiteContext* context, TfLiteNode* node,
      TfLiteFullyConnectedParams* params, OpData* data,
      const TfLiteTensor* input, const TfLiteTensor* filter,
      const TfLiteTensor* bias, const int thread_start, const int thread_end,
      TfLiteTensor* input_quantized, TfLiteTensor* scaling_factors,
      TfLiteTensor* accum_scratch, TfLiteTensor* row_sums,
      TfLiteTensor* input_offsets, TfLiteTensor* output)
      : context(context),
        node(node),
        params(params),
        data(data),
        input(input),
        filter(filter),
        bias(bias),
        thread_start(thread_start),
        thread_end(thread_end),
        input_quantized(input_quantized),
        scaling_factors(scaling_factors),
        accum_scratch(accum_scratch),
        row_sums(row_sums),
        input_offsets(input_offsets),
        output(output) {}

  void Run() override {
    EvalSparseHybridImpl(context, node, params, data, input, filter, bias,
                         thread_start, thread_end, input_quantized,
                         scaling_factors, accum_scratch, row_sums,
                         input_offsets, output);
  }

 private:
  TfLiteContext* context;
  TfLiteNode* node;
  TfLiteFullyConnectedParams* params;
  OpData* data;
  const TfLiteTensor* input;
  const TfLiteTensor* filter;
  const TfLiteTensor* bias;
  const int thread_start;
  const int thread_end;
  TfLiteTensor* input_quantized;
  TfLiteTensor* scaling_factors;
  TfLiteTensor* accum_scratch;
  TfLiteTensor* row_sums;
  TfLiteTensor* input_offsets;
  TfLiteTensor* output;
};

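// Dispatches hybrid evaluation: dense filters go straight to EvalHybridDense;
// sparse filters get their ledger populated once and are then evaluated in
// parallel, one batch slice per thread.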
TfLiteStatus EvalHybrid(TfLiteContext* context, TfLiteNode* node,
                        TfLiteFullyConnectedParams* params, OpData* data,
                        const TfLiteTensor* input, const TfLiteTensor* filter,
                        const TfLiteTensor* bias, TfLiteTensor* input_quantized,
                        TfLiteTensor* scaling_factors,
                        TfLiteTensor* accum_scratch, TfLiteTensor* row_sums,
                        TfLiteTensor* input_offsets, TfLiteTensor* output) {
  const auto& output_shape = GetTensorShape(output);
  CpuBackendContext* cpu_backend_context =
      CpuBackendContext::GetFromContext(context);
  const bool is_dense = filter->sparsity == nullptr;
  if (is_dense) {
    return EvalHybridDense(context, node, params, data, input, filter, bias,
                           input_quantized, scaling_factors, accum_scratch,
                           row_sums, input_offsets, output);
  }

  TfLiteTensor* filter_ledger = &context->tensors[node->temporaries->data[5]];
  if (!data->ledger_initialized) {
    PopulateLedgerData(filter->sparsity, context,
                       GetTensorData<uint8_t>(filter_ledger));
    data->ledger_initialized = true;
  }

  // The multi-threaded kernel slices the workload along the batch dimension.
  // If there are not enough batches of data, the number of threads used is
  // equal to the batch size.
  // TODO(b/173442777): If needed, we can improve this later with slicing along
  // the row dimension of the weight.
  const int max_threads = cpu_backend_context->max_num_threads();
  const int batches =
      FlatSizeSkipDim(output_shape, output_shape.DimensionsCount() - 1);
  const int thread_count = std::max(1, std::min(batches, max_threads));
  if (params->asymmetric_quantize_inputs && data->compute_row_sums) {
    // Precompute row sums.
    static const int kBlockSize = 16;
    const uint8_t* ledger_ptr = GetTensorData<uint8_t>(filter_ledger);
    const int8_t* row_ptr = GetTensorData<int8_t>(filter);
    const int output_depth = filter->dims->data[0];
    int32_t* row_sums_ptr = GetTensorData<int32_t>(row_sums);
    for (int row = 0; row < output_depth; ++row) {
      int32_t row_sum = 0;
      int num_nonzero_blocks = *ledger_ptr++;
      for (int i = 0; i < num_nonzero_blocks; ++i, ++ledger_ptr) {
        for (int c = 0; c < kBlockSize; c++) {
          row_sum += (*row_ptr++);
        }
      }
      row_sums_ptr[row] = row_sum;
    }
    data->compute_row_sums = false;
  }
  std::vector<SparseHybridFullyConnectedTask> tasks;
  tasks.reserve(thread_count);
  int thread_start = 0;
  for (int i = 0; i < thread_count; ++i) {
    // This makes sure the workload is relatively balanced when batches is not
    // a multiple of thread_count. The first mod(batches, thread_count) tasks
    // need to process one more batch than the rest.
    int thread_end = thread_start + batches / thread_count;
    if (i < batches % thread_count) thread_end++;

    tasks.emplace_back(context, node, params, data, input, filter, bias,
                       thread_start, thread_end, input_quantized,
                       scaling_factors, accum_scratch, row_sums, input_offsets,
                       output);
    thread_start = thread_end;
  }
  cpu_backend_threadpool::Execute(tasks.size(), tasks.data(),
                                  cpu_backend_context);
  return kTfLiteOk;
}

namespace {
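// Runs the fully quantized int8 kernel with a single (per-tensor) output
// multiplier and shift, using either the reference or the optimized
// implementation depending on the kernel type.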
template <KernelType kernel_type>
void FullyConnectedInt8(const OpData* data, const TfLiteTensor* input,
                        const TfLiteTensor* filter, const TfLiteTensor* bias,
                        TfLiteTensor* output,
                        CpuBackendContext* cpu_backend_context) {
  FullyConnectedParams op_params;
  op_params.input_offset = -input->params.zero_point;
  op_params.weights_offset = -filter->params.zero_point;
  op_params.output_offset = output->params.zero_point;
  op_params.output_multiplier = data->output_multiplier;
  op_params.output_shift = data->output_shift;
  op_params.quantized_activation_min = data->output_activation_min;
  op_params.quantized_activation_max = data->output_activation_max;
  op_params.lhs_cacheable = IsConstantTensor(filter);
  op_params.rhs_cacheable = IsConstantTensor(input);
  if (kernel_type == kReference) {
    reference_integer_ops::FullyConnected(
        op_params, GetTensorShape(input), GetTensorData<int8_t>(input),
        GetTensorShape(filter), GetTensorData<int8_t>(filter),
        GetTensorShape(bias), GetTensorData<int32_t>(bias),
        GetTensorShape(output), GetTensorData<int8_t>(output));
  } else {
    optimized_integer_ops::FullyConnected(
        op_params, GetTensorShape(input), GetTensorData<int8_t>(input),
        GetTensorShape(filter), GetTensorData<int8_t>(filter),
        GetTensorShape(bias), GetTensorData<int32_t>(bias),
        GetTensorShape(output), GetTensorData<int8_t>(output),
        cpu_backend_context);
  }
}

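// Runs the quantized int16 kernel (int16 activations, int8 weights) through
// the reference implementation, selecting the int64 or int32 bias overload.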
template <KernelType kernel_type>
void FullyConnectedInt16(const OpData* data, const TfLiteTensor* input,
                         const TfLiteTensor* filter, const TfLiteTensor* bias,
                         TfLiteTensor* output) {
  FullyConnectedParams op_params;
  op_params.weights_offset = -filter->params.zero_point;
  op_params.output_multiplier = data->output_multiplier;
  op_params.output_shift = data->output_shift;
  op_params.quantized_activation_min = data->output_activation_min;
  op_params.quantized_activation_max = data->output_activation_max;
  if (bias && bias->type == kTfLiteInt64) {
    reference_integer_ops::FullyConnected(
        op_params, GetTensorShape(input), GetTensorData<int16_t>(input),
        GetTensorShape(filter), GetTensorData<int8_t>(filter),
        GetTensorShape(bias), GetTensorData<int64_t>(bias),
        GetTensorShape(output), GetTensorData<int16_t>(output));
  } else {
    reference_integer_ops::FullyConnected(
        op_params, GetTensorShape(input), GetTensorData<int16_t>(input),
        GetTensorShape(filter), GetTensorData<int8_t>(filter),
        GetTensorShape(bias), GetTensorData<int32_t>(bias),
        GetTensorShape(output), GetTensorData<int16_t>(output));
  }
}

template <KernelType kernel_type>
void FullyConnectedPerChannelInt8(const OpData* data, const TfLiteTensor* input,
                                  const TfLiteTensor* filter,
                                  const TfLiteTensor* bias,
                                  TfLiteTensor* output,
                                  CpuBackendContext* cpu_backend_context) {
  // The FullyConnectedPerChannel op spec requires symmetric weights, so
  // op_params.weights_offset is not set (filter->params.zero_point is not
  // used); it is always assumed to be 0.
  FullyConnectedParams op_params;
  op_params.input_offset = -input->params.zero_point;
  op_params.output_offset = output->params.zero_point;
  op_params.quantized_activation_min = data->output_activation_min;
  op_params.quantized_activation_max = data->output_activation_max;
  op_params.lhs_cacheable = IsConstantTensor(filter);
  op_params.rhs_cacheable = IsConstantTensor(input);
  if (kernel_type == kReference) {
    reference_integer_ops::FullyConnectedPerChannel(
        op_params, data->per_channel_output_multiplier.data(),
        data->per_channel_output_shift.data(), GetTensorShape(input),
        GetTensorData<int8_t>(input), GetTensorShape(filter),
        GetTensorData<int8_t>(filter), GetTensorShape(bias),
        GetTensorData<int32_t>(bias), GetTensorShape(output),
        GetTensorData<int8_t>(output));
  } else {
    optimized_integer_ops::FullyConnectedPerChannel(
        op_params, data->per_channel_output_multiplier.data(),
        data->per_channel_output_shift.data(), GetTensorShape(input),
        GetTensorData<int8_t>(input), GetTensorShape(filter),
        GetTensorData<int8_t>(filter), GetTensorShape(bias),
        GetTensorData<int32_t>(bias), GetTensorShape(output),
        GetTensorData<int8_t>(output), cpu_backend_context);
  }
}

template <KernelType kernel_type>
void FullyConnectedPerChannelInt16(const OpData* data,
                                   const TfLiteTensor* input,
                                   const TfLiteTensor* filter,
                                   const TfLiteTensor* bias,
                                   TfLiteTensor* output) {
  // The FullyConnectedPerChannel op spec requires symmetric weights, so
  // op_params.weights_offset is not set (filter->params.zero_point is not
  // used); it is always assumed to be 0.
  FullyConnectedParams op_params;
  op_params.quantized_activation_min = data->output_activation_min;
  op_params.quantized_activation_max = data->output_activation_max;
  if (bias && bias->type == kTfLiteInt64) {
    reference_integer_ops::FullyConnectedPerChannel(
        op_params, data->per_channel_output_multiplier.data(),
        data->per_channel_output_shift.data(), GetTensorShape(input),
        GetTensorData<int16_t>(input), GetTensorShape(filter),
        GetTensorData<int8_t>(filter), GetTensorShape(bias),
        GetTensorData<int64_t>(bias), GetTensorShape(output),
        GetTensorData<int16_t>(output));
  } else {
    reference_integer_ops::FullyConnectedPerChannel(
        op_params, data->per_channel_output_multiplier.data(),
        data->per_channel_output_shift.data(), GetTensorShape(input),
        GetTensorData<int16_t>(input), GetTensorShape(filter),
        GetTensorData<int8_t>(filter), GetTensorShape(bias),
        GetTensorData<int32_t>(bias), GetTensorShape(output),
        GetTensorData<int16_t>(output));
  }
}

}  // namespace

// Verifies that sparsity values are valid given input/weight/output.
bool VerifySparsity(const RuntimeShape& weights_shape,
                    const RuntimeShape& input_shape,
                    const RuntimeShape& output_shape,
                    const TfLiteSparsity* sparsity) {
  const int weights_dims_count = weights_shape.DimensionsCount();
  const int output_dims_count = output_shape.DimensionsCount();
  const int w0_size = sparsity->dim_metadata[0].dense_size;
  const int accum_depth = weights_shape.Dims(weights_dims_count - 1);
  const int output_elements = output_shape.FlatSize();
  const int input_elements = input_shape.FlatSize();
  const int batches = FlatSizeSkipDim(output_shape, output_dims_count - 1);
  const int output_depth = MatchingDim(weights_shape, weights_dims_count - 2,
                                       output_shape, output_dims_count - 1);
  const int max_batch_index = batches - 1;
  const int max_output = max_batch_index * output_depth + w0_size;
  const int max_batch_depth = accum_depth * max_batch_index;

  // Verify that the output is large enough.
  if (output_elements < max_output) return false;

  // Verify that every index in the sparse filter refers to a valid input
  // element.
  for (int i = 0; i < sparsity->dim_metadata[1].array_indices->size; ++i) {
    if (input_elements <=
        max_batch_depth + sparsity->dim_metadata[1].array_indices->data[i])
      return false;
  }
  return true;
}

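// Evaluates quantized models (and the hybrid case with float inputs),
// dispatching on the output type to the uint8, int8 (dense or sparse), or
// int16 kernels.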
template <KernelType kernel_type>
TfLiteStatus EvalQuantized(TfLiteContext* context, TfLiteNode* node,
                           TfLiteFullyConnectedParams* params, OpData* data,
                           const TfLiteTensor* input,
                           const TfLiteTensor* filter, const TfLiteTensor* bias,
                           TfLiteTensor* output) {
  const bool is_per_channel = data->per_channel_output_multiplier.size() > 1;
  int32_t input_offset = -input->params.zero_point;
  int32_t filter_offset = -filter->params.zero_point;
  int32_t output_offset = output->params.zero_point;
  // Only the Pie path supports quantized models and float inputs/outputs.
  if (input->type == kTfLiteFloat32) {
    TfLiteTensor* input_quantized;
    TF_LITE_ENSURE_OK(context, GetTemporarySafe(context, node, /*index=*/0,
                                                &input_quantized));
    TfLiteTensor* scaling_factors;
    TF_LITE_ENSURE_OK(context, GetTemporarySafe(context, node, /*index=*/1,
                                                &scaling_factors));
    TfLiteTensor* accum_scratch;
    TF_LITE_ENSURE_OK(
        context, GetTemporarySafe(context, node, /*index=*/2, &accum_scratch));
    TfLiteTensor* input_offsets;
    TF_LITE_ENSURE_OK(
        context, GetTemporarySafe(context, node, /*index=*/3, &input_offsets));
    TfLiteTensor* row_sums;
    TF_LITE_ENSURE_OK(context,
                      GetTemporarySafe(context, node, /*index=*/4, &row_sums));
    return EvalHybrid(context, node, params, data, input, filter, bias,
                      input_quantized, scaling_factors, accum_scratch, row_sums,
                      input_offsets, output);
  } else {
    FullyConnectedParams op_params;
    op_params.input_offset = input_offset;
    op_params.weights_offset = filter_offset;
    op_params.output_offset = output_offset;
    op_params.output_multiplier = data->output_multiplier;
    op_params.output_shift = data->output_shift;
    op_params.quantized_activation_min = data->output_activation_min;
    op_params.quantized_activation_max = data->output_activation_max;
    op_params.lhs_cacheable = IsConstantTensor(filter);
    op_params.rhs_cacheable = IsConstantTensor(input);
    switch (output->type) {
      case kTfLiteUInt8:
        if (kernel_type == kReference) {
          reference_ops::FullyConnected(
              op_params, GetTensorShape(input), GetTensorData<uint8_t>(input),
              GetTensorShape(filter), GetTensorData<uint8_t>(filter),
              GetTensorShape(bias), GetTensorData<int32_t>(bias),
              GetTensorShape(output), GetTensorData<uint8_t>(output));
        } else {
          optimized_ops::FullyConnected(
              op_params, GetTensorShape(input), GetTensorData<uint8_t>(input),
              GetTensorShape(filter), GetTensorData<uint8_t>(filter),
              GetTensorShape(bias), GetTensorData<int32_t>(bias),
              GetTensorShape(output), GetTensorData<uint8_t>(output),
              CpuBackendContext::GetFromContext(context));
        }
        break;
      case kTfLiteInt8:
        if (filter->sparsity != nullptr) {
          const TfLiteSparsity& sparsity = *filter->sparsity;
          const auto input_shape = GetTensorShape(input);
          const auto filter_shape = GetTensorShape(filter);
          const auto output_shape = GetTensorShape(output);
          const auto bias_shape = GetTensorShape(bias);
          if (filter_offset != 0) {
            TF_LITE_KERNEL_LOG(context,
                               "Quantized and sparse fully-connected format "
                               "supports symmetric weight quantization only.");
            return kTfLiteError;
          }
          if (!SupportedSparsityFormat(sparsity) ||
              !VerifySparsity(filter_shape, input_shape, output_shape,
                              &sparsity)) {
            TF_LITE_KERNEL_LOG(
                context,
                "Invalid quantized and sparse fully-connected format.");
            return kTfLiteError;
          }
          if (sparsity.dim_metadata_size == kDimMetadataSizeBlockSparse &&
              sparsity.dim_metadata[2].dense_size == 16) {
            // Block sparse with block size of 1x16.
            optimized_ops::FullyConnectedSparseWeight1x16(
                sparsity, op_params, input_shape, GetTensorData<int8_t>(input),
                filter_shape, GetTensorData<int8_t>(filter), bias_shape,
                GetTensorData<int32_t>(bias), output_shape,
                GetTensorData<int8_t>(output),
                CpuBackendContext::GetFromContext(context));
          } else {
            TF_LITE_KERNEL_LOG(
                context, "Unsupported sparse fully-connected weight format.");
            return kTfLiteError;
          }
        } else {
          is_per_channel ? FullyConnectedPerChannelInt8<kernel_type>(
                               data, input, filter, bias, output,
                               CpuBackendContext::GetFromContext(context))
                         : FullyConnectedInt8<kernel_type>(
                               data, input, filter, bias, output,
                               CpuBackendContext::GetFromContext(context));
        }
        break;
      case kTfLiteInt16:
        if (input->type == kTfLiteInt16) {
          // To avoid 32-bit accumulator overflow, the RUY path is enabled only
          // when all zero points are 0.
          bool has_non_zero_point = input->params.zero_point ||
                                    filter->params.zero_point ||
                                    output->params.zero_point;
          if (kernel_type == kReference || has_non_zero_point ||
              (bias && bias->type == kTfLiteInt64)) {
            is_per_channel ? FullyConnectedPerChannelInt16<kernel_type>(
                                 data, input, filter, bias, output)
                           : FullyConnectedInt16<kernel_type>(
                                 data, input, filter, bias, output);
          } else {
            is_per_channel
                ? optimized_integer_ops::FullyConnectedPerChannel(
                      op_params, data->per_channel_output_multiplier.data(),
                      data->per_channel_output_shift.data(),
                      GetTensorShape(input), GetTensorData<int16_t>(input),
                      GetTensorShape(filter), GetTensorData<int8_t>(filter),
                      GetTensorShape(bias), GetTensorData<int32_t>(bias),
                      GetTensorShape(output), GetTensorData<int16_t>(output),
                      CpuBackendContext::GetFromContext(context))
                : optimized_integer_ops::FullyConnected(
                      op_params, GetTensorShape(input),
                      GetTensorData<int16_t>(input), GetTensorShape(filter),
                      GetTensorData<int8_t>(filter), GetTensorShape(bias),
                      GetTensorData<int32_t>(bias), GetTensorShape(output),
                      GetTensorData<int16_t>(output),
                      CpuBackendContext::GetFromContext(context));
          }
        } else if (kernel_type == kReference) {
          reference_ops::FullyConnected(
              op_params, GetTensorShape(input), GetTensorData<uint8_t>(input),
              GetTensorShape(filter), GetTensorData<uint8_t>(filter),
              GetTensorShape(bias), GetTensorData<int32_t>(bias),
              GetTensorShape(output), GetTensorData<int16_t>(output));
        } else {
          optimized_ops::FullyConnected(
              op_params, GetTensorShape(input), GetTensorData<uint8_t>(input),
              GetTensorShape(filter), GetTensorData<uint8_t>(filter),
              GetTensorShape(bias), GetTensorData<int32_t>(bias),
              GetTensorShape(output), GetTensorData<int16_t>(output),
              CpuBackendContext::GetFromContext(context));
        }
        break;
      default:
        TF_LITE_KERNEL_LOG(context,
                           "Quantized FullyConnected expects output data "
                           "type uint8, int8 or int16");
        return kTfLiteError;
    }
  }

  return kTfLiteOk;
}

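// Evaluates the shuffled 4x16 int8 weights format: uint8 input, int16 output,
// with a uint8 workspace tensor holding the shuffled input activations.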
template <KernelType kernel_type>
TfLiteStatus EvalShuffledQuantized(TfLiteContext* context, TfLiteNode* node,
                                   TfLiteFullyConnectedParams* params,
                                   OpData* data, const TfLiteTensor* input,
                                   const TfLiteTensor* filter,
                                   const TfLiteTensor* bias,
                                   TfLiteTensor* output,
                                   TfLiteTensor* shuffled_input_workspace) {
  // TODO(b/110697972) decide more consistently if / how / where we want
  // to perform this kind of runtime data type check.
  if (shuffled_input_workspace->type != kTfLiteUInt8) {
    TF_LITE_KERNEL_LOG(context, "Unexpected data type");
    return kTfLiteError;
  }

#define TF_LITE_SHUFFLED_FULLY_CONNECTED(type)                           \
  {                                                                      \
    type::ShuffledFullyConnected(                                        \
        op_params, GetTensorShape(input), GetTensorData<uint8_t>(input), \
        GetTensorShape(filter), GetTensorData<uint8_t>(filter),          \
        GetTensorShape(bias), GetTensorData<int32_t>(bias),              \
        GetTensorShape(output), GetTensorData<int16_t>(output),          \
        GetTensorData<uint8_t>(shuffled_input_workspace),                \
        CpuBackendContext::GetFromContext(context));                     \
  }
  FullyConnectedParams op_params;
  op_params.output_multiplier = data->output_multiplier;
  op_params.output_shift = data->output_shift;
  op_params.quantized_activation_min = data->output_activation_min;
  op_params.quantized_activation_max = data->output_activation_max;
  op_params.lhs_cacheable = IsConstantTensor(filter);
  op_params.rhs_cacheable = IsConstantTensor(input);
  if (kernel_type == kReference) {
    reference_ops::ShuffledFullyConnected(
        op_params, GetTensorShape(input), GetTensorData<uint8_t>(input),
        GetTensorShape(filter), GetTensorData<uint8_t>(filter),
        GetTensorShape(bias), GetTensorData<int32_t>(bias),
        GetTensorShape(output), GetTensorData<int16_t>(output),
        GetTensorData<uint8_t>(shuffled_input_workspace));
  } else {
    optimized_ops::ShuffledFullyConnected(
        op_params, GetTensorShape(input), GetTensorData<uint8_t>(input),
        GetTensorShape(filter), GetTensorData<uint8_t>(filter),
        GetTensorShape(bias), GetTensorData<int32_t>(bias),
        GetTensorShape(output), GetTensorData<int16_t>(output),
        GetTensorData<uint8_t>(shuffled_input_workspace),
        CpuBackendContext::GetFromContext(context));
  }
#undef TF_LITE_SHUFFLED_FULLY_CONNECTED

  return kTfLiteOk;
}

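// Evaluates float models. The reference and optimized paths both handle dense
// and sparse weights; the legacy Pie kernel is routed to EvalPie.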
template <KernelType kernel_type>
TfLiteStatus EvalFloat(TfLiteContext* context, TfLiteNode* node,
                       TfLiteFullyConnectedParams* params, OpData* data,
                       const TfLiteTensor* input, const TfLiteTensor* filter,
                       const TfLiteTensor* bias, TfLiteTensor* output) {
  float output_activation_min, output_activation_max;
  CalculateActivationRange(params->activation, &output_activation_min,
                           &output_activation_max);
  if (kernel_type == kReference) {
    FullyConnectedParams op_params;
    op_params.float_activation_min = output_activation_min;
    op_params.float_activation_max = output_activation_max;
    if (filter->sparsity != nullptr) {
      const auto& sparsity = *filter->sparsity;
      reference_ops::FullyConnectedSparseWeight(
          sparsity, op_params, GetTensorShape(input),
          GetTensorData<float>(input), GetTensorShape(filter),
          GetTensorData<float>(filter), GetTensorShape(bias),
          GetTensorData<float>(bias), GetTensorShape(output),
          GetTensorData<float>(output));
    } else {
      reference_ops::FullyConnected(
          op_params, GetTensorShape(input), GetTensorData<float>(input),
          GetTensorShape(filter), GetTensorData<float>(filter),
          GetTensorShape(bias), GetTensorData<float>(bias),
          GetTensorShape(output), GetTensorData<float>(output));
    }
  } else if (kernel_type == kLegacyPie) {
    return EvalPie(context, node, params, data, input, filter, bias, output);
  } else {
    FullyConnectedParams op_params;
    op_params.float_activation_min = output_activation_min;
    op_params.float_activation_max = output_activation_max;
    if (filter->sparsity != nullptr) {
      const auto& sparsity = *filter->sparsity;
      if (!SupportedSparsityFormat(sparsity)) {
        TF_LITE_KERNEL_LOG(context,
                           "Unsupported sparse fully-connected weight format.");
        return kTfLiteError;
      }
      const auto& input_shape = GetTensorShape(input);
      const auto& filter_shape = GetTensorShape(filter);
      const auto& output_shape = GetTensorShape(output);
      const auto& bias_shape = GetTensorShape(bias);
      if (!VerifySparsity(filter_shape, input_shape, output_shape, &sparsity)) {
        TF_LITE_KERNEL_LOG(context, "Invalid sparse fully-connected format.");
        return kTfLiteError;
      }

      if (sparsity.dim_metadata_size == kDimMetadataSizeRandomSparse) {
        // Random sparse.
        optimized_ops::FullyConnectedSparseWeight(
            sparsity, op_params,                         // Disable formatting
            input_shape, GetTensorData<float>(input),    // Disable formatting
            filter_shape, GetTensorData<float>(filter),  // Disable formatting
            bias_shape, GetTensorData<float>(bias),      // Disable formatting
            output_shape, GetTensorData<float>(output));
      } else if (sparsity.dim_metadata_size == kDimMetadataSizeBlockSparse &&
                 sparsity.dim_metadata[2].dense_size == 4) {
        // Block sparse with block size of 1x4.
        optimized_ops::FullyConnectedSparseWeight1x4(
            sparsity, op_params,                         // Disable formatting
            input_shape, GetTensorData<float>(input),    // Disable formatting
            filter_shape, GetTensorData<float>(filter),  // Disable formatting
            bias_shape, GetTensorData<float>(bias),      // Disable formatting
            output_shape, GetTensorData<float>(output),
            CpuBackendContext::GetFromContext(context));
      } else {
        TF_LITE_KERNEL_LOG(context,
                           "Unsupported sparse fully-connected weight format.");
        return kTfLiteError;
      }

    } else {
      op_params.lhs_cacheable = IsConstantTensor(filter);
      op_params.rhs_cacheable = IsConstantTensor(input);
      optimized_ops::FullyConnected(
          op_params, GetTensorShape(input), GetTensorData<float>(input),
          GetTensorShape(filter), GetTensorData<float>(filter),
          GetTensorShape(bias), GetTensorData<float>(bias),
          GetTensorShape(output), GetTensorData<float>(output),
          CpuBackendContext::GetFromContext(context));
    }
  }

  return kTfLiteOk;
}

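// Top-level Eval: fetches the node's tensors, handles the degenerate empty
// cases, and dispatches on the filter type (float32, uint8, or int8) and
// weights format.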
template <KernelType kernel_type>
TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
  auto* params =
      reinterpret_cast<TfLiteFullyConnectedParams*>(node->builtin_data);
  OpData* data = reinterpret_cast<OpData*>(node->user_data);

  const TfLiteTensor* input;
  TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kInputTensor, &input));
  const TfLiteTensor* filter;
  TF_LITE_ENSURE_OK(context,
                    GetInputSafe(context, node, kWeightsTensor, &filter));
  const TfLiteTensor* bias =
      (node->inputs->size == 3)
          ? GetOptionalInputTensor(context, node, kBiasTensor)
          : nullptr;
  TfLiteTensor* output;
  TF_LITE_ENSURE_OK(context,
                    GetOutputSafe(context, node, kOutputTensor, &output));
  // Do nothing if the expected output is empty.
  if (NumElements(output) == 0) {
    return kTfLiteOk;
  }

  if (filter->dims->data[1] == 0) {
    memset(output->data.data, 0, output->bytes);
    return kTfLiteOk;
  }

  switch (filter->type) {
    case kTfLiteFloat32:
      return EvalFloat<kernel_type>(context, node, params, data, input, filter,
                                    bias, output);
    case kTfLiteUInt8:
      if (params->weights_format ==
          kTfLiteFullyConnectedWeightsFormatShuffled4x16Int8) {
        TfLiteTensor* shuffled_input_workspace;
        TF_LITE_ENSURE_OK(
            context, GetOutputSafe(context, node, kShuffledInputWorkspaceTensor,
                                   &shuffled_input_workspace));
        return EvalShuffledQuantized<kernel_type>(context, node, params, data,
                                                  input, filter, bias, output,
                                                  shuffled_input_workspace);
      } else if (params->weights_format ==
                 kTfLiteFullyConnectedWeightsFormatDefault) {
        return EvalQuantized<kernel_type>(context, node, params, data, input,
                                          filter, bias, output);
      } else {
        TF_LITE_KERNEL_LOG(context,
                           "Unhandled fully-connected weights format");
        return kTfLiteError;
      }
    case kTfLiteInt8:
      if (params->weights_format ==
          kTfLiteFullyConnectedWeightsFormatDefault) {
        return EvalQuantized<kernel_type>(context, node, params, data, input,
                                          filter, bias, output);
      } else {
        TF_LITE_KERNEL_LOG(context,
                           "Unhandled fully-connected weights format");
        return kTfLiteError;
      }
    default:
      TF_LITE_KERNEL_LOG(context,
                         "Filter data type %s currently not supported.",
                         TfLiteTypeGetName(filter->type));
      return kTfLiteError;
  }
  return kTfLiteOk;
}

}  // namespace fully_connected

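// Registration entry points. Each returns a TfLiteRegistration whose
// Prepare/Eval are instantiated for a particular KernelType; the builtin op
// resolver selects one of these when building its op table. As a hypothetical
// sketch (not part of this file), a custom resolver could pick the reference
// kernel explicitly with something like:
//   resolver.AddBuiltin(BuiltinOperator_FULLY_CONNECTED,
//                       tflite::ops::builtin::Register_FULLY_CONNECTED_REF());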
TfLiteRegistration* Register_FULLY_CONNECTED_REF() {
  static TfLiteRegistration r = {
      fully_connected::Init, fully_connected::Free,
      fully_connected::Prepare<fully_connected::kReference>,
      fully_connected::Eval<fully_connected::kReference>};
  return &r;
}

TfLiteRegistration* Register_FULLY_CONNECTED_GENERIC_OPT() {
  static TfLiteRegistration r = {
      fully_connected::Init, fully_connected::Free,
      fully_connected::Prepare<fully_connected::kGenericOptimized>,
      fully_connected::Eval<fully_connected::kGenericOptimized>};
  return &r;
}

// Legacy path for PIE clients.
TfLiteRegistration* Register_FULLY_CONNECTED_PIE() {
  static TfLiteRegistration r = {
      fully_connected::Init, fully_connected::Free,
      fully_connected::Prepare<fully_connected::kLegacyPie>,
      fully_connected::Eval<fully_connected::kLegacyPie>};
  return &r;
}

TfLiteRegistration* Register_FULLY_CONNECTED() {
  return Register_FULLY_CONNECTED_GENERIC_OPT();
}

}  // namespace builtin
}  // namespace ops
}  // namespace tflite