/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// SVDF op that compresses a fully connected op via low-rank matrix
// factorization. See https://research.google.com/pubs/archive/43813.pdf for
// details.
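//
// Roughly, each input frame is projected onto num_filters feature vectors
// (weights_feature), the projections are appended to a per-batch state buffer
// that keeps the last memory_size activations of every filter, the state is
// then convolved with weights_time, and groups of `rank` filters are summed
// (plus bias and activation) to produce num_units outputs.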

#include "tensorflow/lite/kernels/internal/reference/svdf.h"

#include <cstddef>
#include <cstdint>

#include "tensorflow/lite/c/builtin_op_data.h"
#include "tensorflow/lite/c/common.h"
#include "tensorflow/lite/kernels/internal/compatibility.h"
#include "tensorflow/lite/kernels/internal/quantization_util.h"
#include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
#include "tensorflow/lite/kernels/internal/tensor_utils.h"
#include "tensorflow/lite/kernels/kernel_util.h"

namespace tflite {
namespace ops {
namespace builtin {
namespace svdf {

namespace {

struct OpData {
  int scratch_tensor_index;
  bool float_weights_time_initialized;
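  // Fixed-point multipliers (_a) and shifts (_b) for the two effective scales
  // used on the full-integer path; computed in Prepare().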
  int32_t effective_scale_1_a;
  int effective_scale_1_b;
  int32_t effective_scale_2_a;
  int effective_scale_2_b;
  bool compute_row_sums = false;
};

}  // namespace

// Input tensors.
constexpr int kInputTensor = 0;
constexpr int kWeightsFeatureTensor = 1;
constexpr int kWeightsTimeTensor = 2;
constexpr int kBiasTensor = 3;
// This is a variable tensor, and will be modified by this op.
constexpr int kStateTensor = 4;

// Output tensor.
constexpr int kOutputTensor = 0;
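
// Expected tensor shapes (enforced in Prepare):
//   input:            [batch_size, input_size]
//   weights_feature:  [num_filters, input_size]
//   weights_time:     [num_filters, memory_size]
//   bias (optional):  [num_units]
//   state (variable): [batch_size, memory_size * num_filters]
//   output:           [batch_size, num_units]
// where num_units = num_filters / rank.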

void* Init(TfLiteContext* context, const char* buffer, size_t length) {
  auto* op_data = new OpData();
  op_data->float_weights_time_initialized = false;
  // Note: the hybrid path needs 6 scratch tensors, the full-integer path 2,
  // and the float path only 1; 6 are added here to cover the worst case.
  context->AddTensors(context, /*tensors_to_add=*/6,
                      &op_data->scratch_tensor_index);
  return op_data;
}

void Free(TfLiteContext* context, void* buffer) {
  delete reinterpret_cast<OpData*>(buffer);
}

TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
  const auto* params = reinterpret_cast<TfLiteSVDFParams*>(node->builtin_data);
  OpData* op_data = reinterpret_cast<OpData*>(node->user_data);
  int scratch_tensor_index = op_data->scratch_tensor_index;

  // Check we have all the inputs and outputs we need.
  TF_LITE_ENSURE_EQ(context, node->outputs->size, 1);
  TF_LITE_ENSURE_EQ(context, node->inputs->size, 5);

  const TfLiteTensor* input;
  TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kInputTensor, &input));
  const TfLiteTensor* weights_feature;
  TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kWeightsFeatureTensor,
                                          &weights_feature));
  const TfLiteTensor* weights_time;
  TF_LITE_ENSURE_OK(
      context, GetInputSafe(context, node, kWeightsTimeTensor, &weights_time));

  TF_LITE_ENSURE(context,
                 input->type == kTfLiteFloat32 || input->type == kTfLiteInt8);
  // Check that all tensor parameters are consistent with one another and with
  // the input configuration.
  const int rank = params->rank;
  const int batch_size = input->dims->data[0];
  const int num_filters = weights_feature->dims->data[0];
  TF_LITE_ENSURE(context, rank != 0);
  TF_LITE_ENSURE_EQ(context, num_filters % rank, 0);
  const int num_units = num_filters / rank;
  const int memory_size = weights_time->dims->data[1];
  TF_LITE_ENSURE_EQ(context, input->dims->data[1],
                    weights_feature->dims->data[1]);
  TF_LITE_ENSURE_EQ(context, weights_time->dims->data[0], num_filters);

  const TfLiteTensor* bias = GetOptionalInputTensor(context, node, kBiasTensor);
  if (bias) {
    TF_LITE_ENSURE_EQ(context, bias->dims->data[0], num_units);
  }

  const TfLiteTensor* state;
  TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kStateTensor, &state));
  TfLiteTensor* output;
  TF_LITE_ENSURE_OK(context,
                    GetOutputSafe(context, node, kOutputTensor, &output));

  // Check the shape of the input state tensor.
  TF_LITE_ENSURE_EQ(context, NumDimensions(state), 2);
  TF_LITE_ENSURE_EQ(context, SizeOfDimension(state, 0), batch_size);
  TF_LITE_ENSURE_EQ(context, SizeOfDimension(state, 1),
                    memory_size * num_filters);

  // Resize output.
  TfLiteIntArray* output_size_array = TfLiteIntArrayCreate(2);
  output_size_array->data[0] = batch_size;
  output_size_array->data[1] = num_units;
  TF_LITE_ENSURE_OK(context,
                    context->ResizeTensor(context, output, output_size_array));

  // The weights are of consistent type, so it suffices to check one.
  const bool is_hybrid_op = IsHybridOp(input, weights_feature);
  const bool is_full_integer = input->type == kTfLiteInt8;
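  // Hybrid: float32 input with quantized (int8/uint8) weights; the input is
  // quantized on the fly and weights_time is dequantized once in Eval.
  // Full integer: int8 input and output, int8 weights_feature, int16
  // weights_time and state.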

  // Resize scratch.
  TfLiteIntArrayFree(node->temporaries);
  if (is_hybrid_op) {
    node->temporaries = TfLiteIntArrayCreate(6);
  } else if (is_full_integer) {
    node->temporaries = TfLiteIntArrayCreate(2);
  } else {
    node->temporaries = TfLiteIntArrayCreate(1);
  }
  node->temporaries->data[0] = scratch_tensor_index;

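  // The first scratch tensor holds the [batch_size, num_filters] result of
  // the feature matmul.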
  TfLiteIntArray* scratch_size_array = TfLiteIntArrayCreate(2);
  scratch_size_array->data[0] = batch_size;
  scratch_size_array->data[1] = num_filters;

  TfLiteTensor* scratch_tensor;
  TF_LITE_ENSURE_OK(
      context, GetTemporarySafe(context, node, /*index=*/0, &scratch_tensor));

  // The scratch buffer is int32 for the full-integer SVDF and float32 for the
  // hybrid and float cases.
  if (is_full_integer) {
    scratch_tensor->type = kTfLiteInt32;
  } else {
    scratch_tensor->type = kTfLiteFloat32;
  }
  scratch_tensor->allocation_type = kTfLiteArenaRw;
  TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, scratch_tensor,
                                                   scratch_size_array));

  if (is_hybrid_op) {
    op_data->compute_row_sums = true;
    // Tell the interpreter to allocate temporary tensors to store quantized
    // values of input tensors.
    node->temporaries->data[1] = scratch_tensor_index + 1;
    TfLiteTensor* input_quantized;
    TF_LITE_ENSURE_OK(context, GetTemporarySafe(context, node, /*index=*/1,
                                                &input_quantized));
    input_quantized->type = weights_feature->type;
    input_quantized->allocation_type = kTfLiteArenaRw;
    if (!TfLiteIntArrayEqual(input_quantized->dims, input->dims)) {
      TfLiteIntArray* input_quantized_size = TfLiteIntArrayCopy(input->dims);
      TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, input_quantized,
                                                       input_quantized_size));
    }

    // Tell the interpreter to allocate temporary tensors to store scaling
    // factors.
    node->temporaries->data[2] = scratch_tensor_index + 2;
    TfLiteTensor* scaling_factors;
    TF_LITE_ENSURE_OK(context, GetTemporarySafe(context, node, /*index=*/2,
                                                &scaling_factors));
    scaling_factors->type = kTfLiteFloat32;
    scaling_factors->allocation_type = kTfLiteArenaRw;
    int scaling_dims[1] = {batch_size};
    if (!TfLiteIntArrayEqualsArray(scaling_factors->dims, 1, scaling_dims)) {
      TfLiteIntArray* scaling_factors_size = TfLiteIntArrayCreate(1);
      scaling_factors_size->data[0] = batch_size;
      TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, scaling_factors,
                                                       scaling_factors_size));
    }

    // Used to store the dequantized weights_time matrix for hybrid computation
    // of matmul(state, weights_time), which occurs in floating point.
    node->temporaries->data[3] = scratch_tensor_index + 3;
    TfLiteTensor* float_weights_time;
    TF_LITE_ENSURE_OK(context, GetTemporarySafe(context, node, /*index=*/3,
                                                &float_weights_time));
    float_weights_time->type = kTfLiteFloat32;
    float_weights_time->name = "Svdf_float_weights_time";
    // Persistent so that we can compute the dequantized weights only once.
    float_weights_time->allocation_type = kTfLiteArenaRwPersistent;
    if (!TfLiteIntArrayEqual(float_weights_time->dims, weights_time->dims)) {
      TfLiteIntArray* float_weights_time_size =
          TfLiteIntArrayCopy(weights_time->dims);
      TF_LITE_ENSURE_OK(context,
                        context->ResizeTensor(context, float_weights_time,
                                              float_weights_time_size));
    }

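    // Per-batch zero points of the quantized input, used together with
    // row_sums below when asymmetric input quantization is enabled.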
    node->temporaries->data[4] = scratch_tensor_index + 4;
    TfLiteTensor* zero_points;
    TF_LITE_ENSURE_OK(
        context, GetTemporarySafe(context, node, /*index=*/4, &zero_points));
    zero_points->type = kTfLiteFloat32;
    zero_points->allocation_type = kTfLiteArenaRw;
    int zero_points_dims[1] = {batch_size};
    if (!TfLiteIntArrayEqualsArray(zero_points->dims, 1, zero_points_dims)) {
      TfLiteIntArray* zero_points_size = TfLiteIntArrayCreate(1);
      zero_points_size->data[0] = zero_points_dims[0];
      TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, zero_points,
                                                       zero_points_size));
    }

    node->temporaries->data[5] = scratch_tensor_index + 5;
    TfLiteTensor* row_sums;
    TF_LITE_ENSURE_OK(context,
                      GetTemporarySafe(context, node, /*index=*/5, &row_sums));
    row_sums->type = kTfLiteFloat32;
    row_sums->name = "Svdf_row_sums";
    row_sums->allocation_type = kTfLiteArenaRwPersistent;
    int row_sums_dims[1] = {num_filters};
    if (!TfLiteIntArrayEqualsArray(row_sums->dims, 1, row_sums_dims)) {
      TfLiteIntArray* row_sums_size = TfLiteIntArrayCreate(1);
      row_sums_size->data[0] = row_sums_dims[0];
      TF_LITE_ENSURE_OK(
          context, context->ResizeTensor(context, row_sums, row_sums_size));
    }
  }
  if (is_full_integer) {
    // Allocate one extra temporary [num_units, batch_size] tensor for the
    // integer path.
    TfLiteIntArray* output_temp_size_array = TfLiteIntArrayCreate(2);
    output_temp_size_array->data[0] = num_units;
    output_temp_size_array->data[1] = batch_size;
    node->temporaries->data[1] = scratch_tensor_index + 1;
    TfLiteTensor* output_temp;
    TF_LITE_ENSURE_OK(
        context, GetTemporarySafe(context, node, /*index=*/1, &output_temp));
    output_temp->type = kTfLiteInt32;
    output_temp->allocation_type = kTfLiteArenaRw;
    TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, output_temp,
                                                     output_temp_size_array));

    // Calculate effective scales.
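    // effective_scale_1 rescales the (input x weights_feature) product into
    // the state's quantized domain; effective_scale_2 rescales the
    // (state x weights_time) product into the output's domain. Each is
    // decomposed by QuantizeMultiplier into a fixed-point multiplier and
    // shift stored in OpData.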
    TF_LITE_ENSURE(context, input->quantization.type != kTfLiteNoQuantization);
    auto* input_params =
        reinterpret_cast<TfLiteAffineQuantization*>(input->quantization.params);
    TF_LITE_ENSURE(context,
                   weights_feature->quantization.type != kTfLiteNoQuantization);
    auto* weights_feature_params = reinterpret_cast<TfLiteAffineQuantization*>(
        weights_feature->quantization.params);
    TF_LITE_ENSURE(context, state->quantization.type != kTfLiteNoQuantization);
    auto* state_params =
        reinterpret_cast<TfLiteAffineQuantization*>(state->quantization.params);
    TF_LITE_ENSURE(context,
                   weights_time->quantization.type != kTfLiteNoQuantization);
    auto* weight_time_params = reinterpret_cast<TfLiteAffineQuantization*>(
        weights_time->quantization.params);
    TF_LITE_ENSURE(context, output->quantization.type != kTfLiteNoQuantization);
    auto* output_params = reinterpret_cast<TfLiteAffineQuantization*>(
        output->quantization.params);
    const double effective_scale_1 = input_params->scale->data[0] *
                                     weights_feature_params->scale->data[0] /
                                     state_params->scale->data[0];
    const double effective_scale_2 = state_params->scale->data[0] *
                                     weight_time_params->scale->data[0] /
                                     output_params->scale->data[0];
    QuantizeMultiplier(effective_scale_1, &op_data->effective_scale_1_a,
                       &op_data->effective_scale_1_b);
    QuantizeMultiplier(effective_scale_2, &op_data->effective_scale_2_a,
                       &op_data->effective_scale_2_b);
  }
  return kTfLiteOk;
}

TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
  auto* params = reinterpret_cast<TfLiteSVDFParams*>(node->builtin_data);
  OpData* op_data = reinterpret_cast<OpData*>(node->user_data);

  const TfLiteTensor* input;
  TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kInputTensor, &input));
  const TfLiteTensor* weights_feature;
  TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kWeightsFeatureTensor,
                                          &weights_feature));
  const TfLiteTensor* weights_time;
  TF_LITE_ENSURE_OK(
      context, GetInputSafe(context, node, kWeightsTimeTensor, &weights_time));
  const TfLiteTensor* bias = GetOptionalInputTensor(context, node, kBiasTensor);

  TfLiteTensor* scratch;
  TF_LITE_ENSURE_OK(context,
                    GetTemporarySafe(context, node, /*index=*/0, &scratch));

  TfLiteTensor* state = GetVariableInput(context, node, kStateTensor);
  TF_LITE_ENSURE(context, state != nullptr);
  TfLiteTensor* output;
  TF_LITE_ENSURE_OK(context,
                    GetOutputSafe(context, node, kOutputTensor, &output));

  switch (weights_feature->type) {
    case kTfLiteFloat32: {
      reference_ops::EvalFloatSVDF(
          params, GetTensorShape(input), GetTensorData<float>(input),
          GetTensorShape(weights_feature),
          GetTensorData<float>(weights_feature), GetTensorShape(weights_time),
          GetTensorData<float>(weights_time), GetTensorShape(bias),
          GetTensorData<float>(bias), GetTensorData<float>(scratch),
          GetTensorData<float>(state), GetTensorShape(output),
          GetTensorData<float>(output));
      return kTfLiteOk;
    }
    case kTfLiteUInt8:
    case kTfLiteInt8: {
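      // Quantized weights: a float32 input selects the hybrid path, an int8
      // input the full-integer path.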
      if (input->type == kTfLiteFloat32) {
        TfLiteTensor* input_quantized;
        TF_LITE_ENSURE_OK(context, GetTemporarySafe(context, node, /*index=*/1,
                                                    &input_quantized));
        TfLiteTensor* scaling_factors;
        TF_LITE_ENSURE_OK(context, GetTemporarySafe(context, node, /*index=*/2,
                                                    &scaling_factors));
        TfLiteTensor* float_weights_time;
        TF_LITE_ENSURE_OK(context, GetTemporarySafe(context, node, /*index=*/3,
                                                    &float_weights_time));
        TfLiteTensor* zero_points;
        TF_LITE_ENSURE_OK(context, GetTemporarySafe(context, node, /*index=*/4,
                                                    &zero_points));
        TfLiteTensor* row_sums;
        TF_LITE_ENSURE_OK(
            context, GetTemporarySafe(context, node, /*index=*/5, &row_sums));
        // Dequantize weights_time.
        // TODO(alanchiao): this dequantization initialization only needs to
        // happen once per model and should theoretically be placed in either
        // Init or Prepare. However, TFLite doesn't allocate float_weights_time
        // until the Eval function.
        // TODO(alanchiao): refactor logic out into dequantize function.
        if (!op_data->float_weights_time_initialized) {
          const float dequantization_scale = weights_time->params.scale;
          const int8_t* weights_time_ptr = GetTensorData<int8_t>(weights_time);
          float* float_weights_time_ptr =
              GetTensorData<float>(float_weights_time);
          for (int i = 0; i < NumElements(float_weights_time); ++i) {
            float_weights_time_ptr[i] =
                weights_time_ptr[i] * dequantization_scale;
          }
          op_data->float_weights_time_initialized = true;
        }

        int32_t* zero_points_ptr = nullptr;
        int32_t* row_sums_ptr = nullptr;
        if (params->asymmetric_quantize_inputs && row_sums != nullptr) {
          zero_points_ptr = GetTensorData<int32_t>(zero_points);
          row_sums_ptr = GetTensorData<int32_t>(row_sums);
        }

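        // When non-null, these let the hybrid kernel correct the quantized
        // matmul for per-batch input zero points (asymmetric quantization).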
        reference_ops::EvalHybridSVDF(
            params, GetTensorShape(input), GetTensorData<float>(input),
            GetTensorShape(weights_feature),
            GetTensorData<int8_t>(weights_feature),
            weights_feature->params.scale, GetTensorShape(float_weights_time),
            GetTensorData<float>(float_weights_time), GetTensorShape(bias),
            GetTensorData<float>(bias), GetTensorData<float>(scratch),
            GetTensorData<float>(scaling_factors),
            GetTensorData<int8_t>(input_quantized), GetTensorData<float>(state),
            GetTensorShape(output), GetTensorData<float>(output),
            zero_points_ptr, row_sums_ptr, &op_data->compute_row_sums);
        return kTfLiteOk;
      }
      auto* input_params = reinterpret_cast<TfLiteAffineQuantization*>(
          input->quantization.params);
      auto* output_params = reinterpret_cast<TfLiteAffineQuantization*>(
          output->quantization.params);
      TfLiteTensor* output_temp;
      TF_LITE_ENSURE_OK(
          context, GetTemporarySafe(context, node, /*index=*/1, &output_temp));

      // Currently supports only ReLU.
      // TODO(jianlijianli): support other activations.
      TF_LITE_ENSURE_EQ(context, params->activation, kTfLiteActRelu);

      reference_ops::EvalIntegerSVDF(
          params, GetTensorShape(input), GetTensorData<int8_t>(input),
          GetTensorShape(weights_feature),
          GetTensorData<int8_t>(weights_feature), GetTensorShape(weights_time),
          GetTensorData<int16_t>(weights_time), GetTensorShape(bias),
          GetTensorData<int32_t>(bias), GetTensorData<int16_t>(state),
          GetTensorShape(output), GetTensorData<int8_t>(output),
          GetTensorData<int32_t>(scratch), GetTensorData<int32_t>(output_temp),
          op_data->effective_scale_1_a, op_data->effective_scale_1_b,
          op_data->effective_scale_2_a, op_data->effective_scale_2_b,
          input_params->zero_point->data[0],
          output_params->zero_point->data[0]);
      return kTfLiteOk;
    }
    default:
      TF_LITE_KERNEL_LOG(context, "Type %s not currently supported.",
                         TfLiteTypeGetName(weights_feature->type));
      return kTfLiteError;
  }
}

}  // namespace svdf

TfLiteRegistration* Register_SVDF() {
  static TfLiteRegistration r = {svdf::Init, svdf::Free, svdf::Prepare,
                                 svdf::Eval};
  return &r;
}

}  // namespace builtin
}  // namespace ops
}  // namespace tflite