/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// SVDF op that compresses a fully connected op via low-rank matrix
// factorization. See https://research.google.com/pubs/archive/43813.pdf for
// details.
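//
// A rough sketch of the computation (not the exact kernel math): each of the
// num_filters rank-1 filters projects the current input frame through its
// weights_feature row, appends the result to a per-filter memory of
// memory_size past activations kept in the variable state tensor, and then
// reduces that memory against its weights_time row. Every `rank` consecutive
// filter outputs are summed, plus an optional bias, to form one output unit.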

#include "tensorflow/lite/kernels/internal/reference/svdf.h"

#include <cstddef>
#include <cstdint>

#include "tensorflow/lite/c/builtin_op_data.h"
#include "tensorflow/lite/c/common.h"
#include "tensorflow/lite/kernels/internal/compatibility.h"
#include "tensorflow/lite/kernels/internal/quantization_util.h"
#include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
#include "tensorflow/lite/kernels/internal/tensor_utils.h"
#include "tensorflow/lite/kernels/kernel_util.h"

namespace tflite {
namespace ops {
namespace builtin {
namespace svdf {

namespace {

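// Per-node state for the SVDF kernel. The effective_scale_*_a/_b pairs hold
// the fixed-point multiplier and shift produced by QuantizeMultiplier() for
// the two rescaling steps of the full-integer path (input -> state and
// state -> output); see the effective scale computation in Prepare() below.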
struct OpData {
  int scratch_tensor_index;
  bool float_weights_time_initialized;
  int32_t effective_scale_1_a;
  int effective_scale_1_b;
  int32_t effective_scale_2_a;
  int effective_scale_2_b;
  bool compute_row_sums = false;
};

}  // namespace

// Input tensors.
constexpr int kInputTensor = 0;
constexpr int kWeightsFeatureTensor = 1;
constexpr int kWeightsTimeTensor = 2;
constexpr int kBiasTensor = 3;
// This is a variable tensor, and will be modified by this op.
constexpr int kStateTensor = 4;

// Output tensor.
constexpr int kOutputTensor = 0;

void* Init(TfLiteContext* context, const char* buffer, size_t length) {
  auto* op_data = new OpData();
  op_data->float_weights_time_initialized = false;
  // Note: 6 scratch tensors are needed for the hybrid op, 2 for the
  // full-integer op, and only 1 otherwise.
  context->AddTensors(context, /*tensors_to_add=*/6,
                      &op_data->scratch_tensor_index);
  return op_data;
}

void Free(TfLiteContext* context, void* buffer) {
  delete reinterpret_cast<OpData*>(buffer);
}

TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
  const auto* params = reinterpret_cast<TfLiteSVDFParams*>(node->builtin_data);
  OpData* op_data = reinterpret_cast<OpData*>(node->user_data);
  int scratch_tensor_index = op_data->scratch_tensor_index;

  // Check that we have all the inputs and outputs we need.
  TF_LITE_ENSURE_EQ(context, node->outputs->size, 1);
  TF_LITE_ENSURE_EQ(context, node->inputs->size, 5);

  const TfLiteTensor* input;
  TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kInputTensor, &input));
  const TfLiteTensor* weights_feature;
  TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kWeightsFeatureTensor,
                                          &weights_feature));
  const TfLiteTensor* weights_time;
  TF_LITE_ENSURE_OK(
      context, GetInputSafe(context, node, kWeightsTimeTensor, &weights_time));

  TF_LITE_ENSURE(context,
                 input->type == kTfLiteFloat32 || input->type == kTfLiteInt8);

  // Check that all the tensor parameters are consistent with each other and
  // with the input configuration.
  const int rank = params->rank;
  const int batch_size = input->dims->data[0];
  const int num_filters = weights_feature->dims->data[0];
  TF_LITE_ENSURE(context, rank != 0);
  TF_LITE_ENSURE_EQ(context, num_filters % rank, 0);
  const int num_units = num_filters / rank;
  const int memory_size = weights_time->dims->data[1];
  TF_LITE_ENSURE_EQ(context, input->dims->data[1],
                    weights_feature->dims->data[1]);
  TF_LITE_ENSURE_EQ(context, weights_time->dims->data[0], num_filters);

  const TfLiteTensor* bias = GetOptionalInputTensor(context, node, kBiasTensor);
  if (bias) {
    TF_LITE_ENSURE_EQ(context, bias->dims->data[0], num_units);
  }

  const TfLiteTensor* state;
  TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kStateTensor, &state));
  TfLiteTensor* output;
  TF_LITE_ENSURE_OK(context,
                    GetOutputSafe(context, node, kOutputTensor, &output));

  // Check the shape of the input state tensor.
  TF_LITE_ENSURE_EQ(context, NumDimensions(state), 2);
  TF_LITE_ENSURE_EQ(context, SizeOfDimension(state, 0), batch_size);
  TF_LITE_ENSURE_EQ(context, SizeOfDimension(state, 1),
                    memory_size * num_filters);

  // Resize output.
  TfLiteIntArray* output_size_array = TfLiteIntArrayCreate(2);
  output_size_array->data[0] = batch_size;
  output_size_array->data[1] = num_units;
  TF_LITE_ENSURE_OK(context,
                    context->ResizeTensor(context, output, output_size_array));

  // The weights are of consistent type, so it suffices to check one.
  const bool is_hybrid_op = IsHybridOp(input, weights_feature);
  const bool is_full_integer = input->type == kTfLiteInt8;

  // Resize scratch.
  TfLiteIntArrayFree(node->temporaries);
  if (is_hybrid_op) {
    node->temporaries = TfLiteIntArrayCreate(6);
  } else if (is_full_integer) {
    node->temporaries = TfLiteIntArrayCreate(2);
  } else {
    node->temporaries = TfLiteIntArrayCreate(1);
  }
  node->temporaries->data[0] = scratch_tensor_index;

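  // The first scratch tensor holds the feature-matmul result: one row of
  // num_filters accumulators per batch, whichever op variant is used.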
  TfLiteIntArray* scratch_size_array = TfLiteIntArrayCreate(2);
  scratch_size_array->data[0] = batch_size;
  scratch_size_array->data[1] = num_filters;

  TfLiteTensor* scratch_tensor;
  TF_LITE_ENSURE_OK(
      context, GetTemporarySafe(context, node, /*index=*/0, &scratch_tensor));

  // The scratch buffer is of type int32 for the full-integer SVDF and of type
  // float32 for the hybrid and float cases.
  if (is_full_integer) {
    scratch_tensor->type = kTfLiteInt32;
  } else {
    scratch_tensor->type = kTfLiteFloat32;
  }
  scratch_tensor->allocation_type = kTfLiteArenaRw;
  TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, scratch_tensor,
                                                   scratch_size_array));

  if (is_hybrid_op) {
    op_data->compute_row_sums = true;
    // Tell the interpreter to allocate a temporary tensor to store the
    // quantized values of the input tensor.
    node->temporaries->data[1] = scratch_tensor_index + 1;
    TfLiteTensor* input_quantized;
    TF_LITE_ENSURE_OK(context, GetTemporarySafe(context, node, /*index=*/1,
                                                &input_quantized));
    input_quantized->type = weights_feature->type;
    input_quantized->allocation_type = kTfLiteArenaRw;
    if (!TfLiteIntArrayEqual(input_quantized->dims, input->dims)) {
      TfLiteIntArray* input_quantized_size = TfLiteIntArrayCopy(input->dims);
      TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, input_quantized,
                                                       input_quantized_size));
    }

    // Tell the interpreter to allocate a temporary tensor to store the
    // per-batch scaling factors.
    node->temporaries->data[2] = scratch_tensor_index + 2;
    TfLiteTensor* scaling_factors;
    TF_LITE_ENSURE_OK(context, GetTemporarySafe(context, node, /*index=*/2,
                                                &scaling_factors));
    scaling_factors->type = kTfLiteFloat32;
    scaling_factors->allocation_type = kTfLiteArenaRw;
    int scaling_dims[1] = {batch_size};
    if (!TfLiteIntArrayEqualsArray(scaling_factors->dims, 1, scaling_dims)) {
      TfLiteIntArray* scaling_factors_size = TfLiteIntArrayCreate(1);
      scaling_factors_size->data[0] = batch_size;
      TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, scaling_factors,
                                                       scaling_factors_size));
    }
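    // Note: the scaling factors themselves are computed during Eval, when the
    // float input is quantized one batch row at a time.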

    // Used to store the dequantized weights_time matrix for the hybrid
    // computation of matmul(state, weights_time), which occurs in floating
    // point.
    node->temporaries->data[3] = scratch_tensor_index + 3;
    TfLiteTensor* float_weights_time;
    TF_LITE_ENSURE_OK(context, GetTemporarySafe(context, node, /*index=*/3,
                                                &float_weights_time));
    float_weights_time->type = kTfLiteFloat32;
    float_weights_time->name = "Svdf_float_weights_time";
    // Persistent so that we can compute the dequantized weights only once.
    float_weights_time->allocation_type = kTfLiteArenaRwPersistent;
    if (!TfLiteIntArrayEqual(float_weights_time->dims, weights_time->dims)) {
      TfLiteIntArray* float_weights_time_size =
          TfLiteIntArrayCopy(weights_time->dims);
      TF_LITE_ENSURE_OK(context,
                        context->ResizeTensor(context, float_weights_time,
                                              float_weights_time_size));
    }

    node->temporaries->data[4] = scratch_tensor_index + 4;
    TfLiteTensor* zero_points;
    TF_LITE_ENSURE_OK(
        context, GetTemporarySafe(context, node, /*index=*/4, &zero_points));
    zero_points->type = kTfLiteFloat32;
    zero_points->allocation_type = kTfLiteArenaRw;
    int zero_points_dims[1] = {batch_size};
    if (!TfLiteIntArrayEqualsArray(zero_points->dims, 1, zero_points_dims)) {
      TfLiteIntArray* zero_points_size = TfLiteIntArrayCreate(1);
      zero_points_size->data[0] = zero_points_dims[0];
      TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, zero_points,
                                                       zero_points_size));
    }

    node->temporaries->data[5] = scratch_tensor_index + 5;
    TfLiteTensor* row_sums;
    TF_LITE_ENSURE_OK(context,
                      GetTemporarySafe(context, node, /*index=*/5, &row_sums));
    row_sums->type = kTfLiteFloat32;
    row_sums->name = "Svdf_row_sums";
    row_sums->allocation_type = kTfLiteArenaRwPersistent;
    int row_sums_dims[1] = {num_filters};
    if (!TfLiteIntArrayEqualsArray(row_sums->dims, 1, row_sums_dims)) {
      TfLiteIntArray* row_sums_size = TfLiteIntArrayCreate(1);
      row_sums_size->data[0] = row_sums_dims[0];
      TF_LITE_ENSURE_OK(
          context, context->ResizeTensor(context, row_sums, row_sums_size));
    }
  }
  if (is_full_integer) {
    // Allocate one extra temporary to hold the [num_units, batch_size] int32
    // accumulator of the time matmul before it is requantized to the output.
    TfLiteIntArray* output_temp_size_array = TfLiteIntArrayCreate(2);
    output_temp_size_array->data[0] = num_units;
    output_temp_size_array->data[1] = batch_size;
    node->temporaries->data[1] = scratch_tensor_index + 1;
    TfLiteTensor* output_temp;
    TF_LITE_ENSURE_OK(
        context, GetTemporarySafe(context, node, /*index=*/1, &output_temp));
    output_temp->type = kTfLiteInt32;
    output_temp->allocation_type = kTfLiteArenaRw;
    TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, output_temp,
                                                     output_temp_size_array));

    // Calculate effective scales.
    TF_LITE_ENSURE(context, input->quantization.type != kTfLiteNoQuantization);
    auto* input_params =
        reinterpret_cast<TfLiteAffineQuantization*>(input->quantization.params);
    TF_LITE_ENSURE(context,
                   weights_feature->quantization.type != kTfLiteNoQuantization);
    auto* weights_feature_params = reinterpret_cast<TfLiteAffineQuantization*>(
        weights_feature->quantization.params);
    TF_LITE_ENSURE(context, state->quantization.type != kTfLiteNoQuantization);
    auto* state_params =
        reinterpret_cast<TfLiteAffineQuantization*>(state->quantization.params);
    TF_LITE_ENSURE(context,
                   weights_time->quantization.type != kTfLiteNoQuantization);
    auto* weight_time_params = reinterpret_cast<TfLiteAffineQuantization*>(
        weights_time->quantization.params);
    TF_LITE_ENSURE(context, output->quantization.type != kTfLiteNoQuantization);
    auto* output_params = reinterpret_cast<TfLiteAffineQuantization*>(
        output->quantization.params);
    const double effective_scale_1 = input_params->scale->data[0] *
                                     weights_feature_params->scale->data[0] /
                                     state_params->scale->data[0];
    const double effective_scale_2 = state_params->scale->data[0] *
                                     weight_time_params->scale->data[0] /
                                     output_params->scale->data[0];
    QuantizeMultiplier(effective_scale_1, &op_data->effective_scale_1_a,
                       &op_data->effective_scale_1_b);
    QuantizeMultiplier(effective_scale_2, &op_data->effective_scale_2_a,
                       &op_data->effective_scale_2_b);
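    // A sketch of how these pairs are consumed downstream: the reference
    // kernel rescales an int32 accumulator `acc` back into the target scale
    // with a saturating fixed-point multiply and rounding shift, along the
    // lines of MultiplyByQuantizedMultiplier(acc, scale_a, scale_b).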
  }
  return kTfLiteOk;
}

TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
  auto* params = reinterpret_cast<TfLiteSVDFParams*>(node->builtin_data);
  OpData* op_data = reinterpret_cast<OpData*>(node->user_data);

  const TfLiteTensor* input;
  TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kInputTensor, &input));
  const TfLiteTensor* weights_feature;
  TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kWeightsFeatureTensor,
                                          &weights_feature));
  const TfLiteTensor* weights_time;
  TF_LITE_ENSURE_OK(
      context, GetInputSafe(context, node, kWeightsTimeTensor, &weights_time));
  const TfLiteTensor* bias = GetOptionalInputTensor(context, node, kBiasTensor);

  TfLiteTensor* scratch;
  TF_LITE_ENSURE_OK(context,
                    GetTemporarySafe(context, node, /*index=*/0, &scratch));

  TfLiteTensor* state = GetVariableInput(context, node, kStateTensor);
  TF_LITE_ENSURE(context, state != nullptr);
  TfLiteTensor* output;
  TF_LITE_ENSURE_OK(context,
                    GetOutputSafe(context, node, kOutputTensor, &output));

  switch (weights_feature->type) {
    case kTfLiteFloat32: {
      reference_ops::EvalFloatSVDF(
          params, GetTensorShape(input), GetTensorData<float>(input),
          GetTensorShape(weights_feature),
          GetTensorData<float>(weights_feature), GetTensorShape(weights_time),
          GetTensorData<float>(weights_time), GetTensorShape(bias),
          GetTensorData<float>(bias), GetTensorData<float>(scratch),
          GetTensorData<float>(state), GetTensorShape(output),
          GetTensorData<float>(output));
      return kTfLiteOk;
    }
    case kTfLiteUInt8:
    case kTfLiteInt8: {
      if (input->type == kTfLiteFloat32) {
        TfLiteTensor* input_quantized;
        TF_LITE_ENSURE_OK(context, GetTemporarySafe(context, node, /*index=*/1,
                                                    &input_quantized));
        TfLiteTensor* scaling_factors;
        TF_LITE_ENSURE_OK(context, GetTemporarySafe(context, node, /*index=*/2,
                                                    &scaling_factors));
        TfLiteTensor* float_weights_time;
        TF_LITE_ENSURE_OK(context, GetTemporarySafe(context, node, /*index=*/3,
                                                    &float_weights_time));
        TfLiteTensor* zero_points;
        TF_LITE_ENSURE_OK(context, GetTemporarySafe(context, node, /*index=*/4,
                                                    &zero_points));
        TfLiteTensor* row_sums;
        TF_LITE_ENSURE_OK(
            context, GetTemporarySafe(context, node, /*index=*/5, &row_sums));
        // Dequantize weights_time.
        // TODO(alanchiao): this dequantization initialization only needs to
        // happen once per model and should theoretically be placed in either
        // Init or Prepare. However, TFLite doesn't allocate float_weights_time
        // until the Eval function.
        // TODO(alanchiao): refactor logic out into dequantize function.
        if (!op_data->float_weights_time_initialized) {
          const float dequantization_scale = weights_time->params.scale;
          const int8_t* weights_time_ptr = GetTensorData<int8_t>(weights_time);
          float* float_weights_time_ptr =
              GetTensorData<float>(float_weights_time);
          for (int i = 0; i < NumElements(float_weights_time); ++i) {
            float_weights_time_ptr[i] =
                weights_time_ptr[i] * dequantization_scale;
          }
          op_data->float_weights_time_initialized = true;
        }

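        // With asymmetric input quantization, each int8 dot product carries an
        // offset of input_zero_point * sum(weights_row). Passing precomputed
        // row sums lets the kernel subtract that term without rescanning the
        // weights on every invocation.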
        int32_t* zero_points_ptr = nullptr;
        int32_t* row_sums_ptr = nullptr;
        if (params->asymmetric_quantize_inputs && row_sums != nullptr) {
          zero_points_ptr = GetTensorData<int32_t>(zero_points);
          row_sums_ptr = GetTensorData<int32_t>(row_sums);
        }

        reference_ops::EvalHybridSVDF(
            params, GetTensorShape(input), GetTensorData<float>(input),
            GetTensorShape(weights_feature),
            GetTensorData<int8_t>(weights_feature),
            weights_feature->params.scale, GetTensorShape(float_weights_time),
            GetTensorData<float>(float_weights_time), GetTensorShape(bias),
            GetTensorData<float>(bias), GetTensorData<float>(scratch),
            GetTensorData<float>(scaling_factors),
            GetTensorData<int8_t>(input_quantized), GetTensorData<float>(state),
            GetTensorShape(output), GetTensorData<float>(output),
            zero_points_ptr, row_sums_ptr, &op_data->compute_row_sums);
        return kTfLiteOk;
      }
      auto* input_params = reinterpret_cast<TfLiteAffineQuantization*>(
          input->quantization.params);
      auto* output_params = reinterpret_cast<TfLiteAffineQuantization*>(
          output->quantization.params);
      TfLiteTensor* output_temp;
      TF_LITE_ENSURE_OK(
          context, GetTemporarySafe(context, node, /*index=*/1, &output_temp));

      // Currently supports only ReLU.
      // TODO(jianlijianli): support other activations.
      TF_LITE_ENSURE_EQ(context, params->activation, kTfLiteActRelu);

      reference_ops::EvalIntegerSVDF(
          params, GetTensorShape(input), GetTensorData<int8_t>(input),
          GetTensorShape(weights_feature),
          GetTensorData<int8_t>(weights_feature), GetTensorShape(weights_time),
          GetTensorData<int16_t>(weights_time), GetTensorShape(bias),
          GetTensorData<int32_t>(bias), GetTensorData<int16_t>(state),
          GetTensorShape(output), GetTensorData<int8_t>(output),
          GetTensorData<int32_t>(scratch), GetTensorData<int32_t>(output_temp),
          op_data->effective_scale_1_a, op_data->effective_scale_1_b,
          op_data->effective_scale_2_a, op_data->effective_scale_2_b,
          input_params->zero_point->data[0],
          output_params->zero_point->data[0]);
      return kTfLiteOk;
    }
    default:
      TF_LITE_KERNEL_LOG(context, "Type %s not currently supported.",
                         TfLiteTypeGetName(weights_feature->type));
      return kTfLiteError;
  }
}

}  // namespace svdf

TfLiteRegistration* Register_SVDF() {
  static TfLiteRegistration r = {svdf::Init, svdf::Free, svdf::Prepare,
                                 svdf::Eval};
  return &r;
}
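
// A minimal usage sketch (assuming a MutableOpResolver-based setup) for
// registering this kernel by hand:
//   tflite::MutableOpResolver resolver;
//   resolver.AddBuiltin(tflite::BuiltinOperator_SVDF,
//                       tflite::ops::builtin::Register_SVDF());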

}  // namespace builtin
}  // namespace ops
}  // namespace tflite