1 | /* Copyright 2020 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | |
16 | #include <algorithm> |
17 | #include <cmath> |
18 | #include <cstddef> |
19 | #include <cstdint> |
20 | #include <cstring> |
21 | #include <memory> |
22 | #include <vector> |
23 | |
24 | #include "tensorflow/lite/c/builtin_op_data.h" |
25 | #include "tensorflow/lite/c/common.h" |
26 | #include "tensorflow/lite/kernels/cpu_backend_context.h" |
27 | #include "tensorflow/lite/kernels/internal/compatibility.h" |
28 | #include "tensorflow/lite/kernels/internal/kernel_utils.h" |
29 | #include "tensorflow/lite/kernels/internal/optimized/optimized_ops.h" |
30 | #include "tensorflow/lite/kernels/internal/quantization_util.h" |
31 | #include "tensorflow/lite/kernels/internal/tensor.h" |
32 | #include "tensorflow/lite/kernels/internal/tensor_ctypes.h" |
33 | #include "tensorflow/lite/kernels/internal/tensor_utils.h" |
34 | #include "tensorflow/lite/kernels/internal/types.h" |
35 | #include "tensorflow/lite/kernels/kernel_util.h" |
36 | #include "tensorflow/lite/kernels/lstm_eval.h" |
37 | #include "tensorflow/lite/kernels/lstm_shared.h" |
38 | |
39 | namespace tflite { |
40 | namespace ops { |
41 | namespace builtin { |
42 | namespace lstm { |
43 | |
44 | struct OpData { |
  // Which kernel type to use: the full kernel (24 inputs) or the basic kernel
  // (5 inputs). Note that the 20-input full kernel is deprecated and only
  // kept here for backward compatibility.
49 | TfLiteLSTMKernelType kernel_type; |
50 | |
  // Whether the LSTM uses layer normalization.
52 | bool use_layer_norm; |
53 | |
  // These fields are only used by the full kernel.
55 | int scratch_tensor_index; |
56 | lstm_eval::IntegerLstmParameter integer_lstm_param; |
57 | bool compute_row_sums; |
58 | |
59 | // Only used for sparse hybrid lstm kernels. |
60 | int ledger_index; |
61 | bool ledger_initialized; |
62 | }; |
63 | |
64 | namespace full { |
65 | namespace { |
66 | |
67 | // Named temporary tensors. |
68 | enum HybridTemporaryTensor { |
69 | kScratchBuffer = 0, |
70 | kInputQuantized = 1, |
71 | kOutputStateQuantized = 2, |
72 | kCellStateQuantized = 3, |
73 | kInputScalingFactors = 4, |
74 | kOutputStateScalingFactors = 5, |
75 | kProductScalingFactors = 6, |
76 | kRecoveredCellWeights = 7, |
77 | kAccumScratch = 8, |
78 | kInputZeroPoints = 9, |
79 | kOutputStateZeroPoints = 10, |
80 | kRowSums = 11, |
81 | kNumHybridTemporaryTensors = 12, |
82 | }; |
83 | |
84 | constexpr int kLedgersToAdd = 9; |
85 | constexpr int kInputToInputWeightsLedgerOffset = 0; |
86 | constexpr int kInputToForgetWeightsLedgerOffset = 1; |
87 | constexpr int kInputToCellWeightsLedgerOffset = 2; |
88 | constexpr int kInputToOutputWeightsLedgerOffset = 3; |
89 | constexpr int kRecurrentToInputWeightsLedgerOffset = 4; |
90 | constexpr int kRecurrentToForgetWeightsLedgerOffset = 5; |
91 | constexpr int kRecurrentToCellWeightsLedgerOffset = 6; |
92 | constexpr int kRecurrentToOutputWeightsLedgerOffset = 7; |
93 | constexpr int kProjectionWeightsLedgerOffset = 8; |
94 | |
95 | TfLiteStatus make_ledger(const TfLiteSparsity* sparsity, TfLiteContext* context, |
96 | TfLiteTensor* ledger) { |
97 | ledger->type = kTfLiteUInt8; |
98 | ledger->name = "Lstm_ledger" ; |
99 | ledger->allocation_type = kTfLiteArenaRwPersistent; |
100 | if (sparsity == nullptr) { |
101 | return kTfLiteOk; |
102 | } |
103 | TfLiteIntArray* ledger_size = TfLiteIntArrayCreate(1); |
104 | ledger_size->data[0] = sparsity->dim_metadata[1].array_indices->size + |
105 | sparsity->dim_metadata[1].array_segments->size - 1; |
106 | return context->ResizeTensor(context, ledger, ledger_size); |
107 | } |
108 | |
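// Writes the sparsity metadata into the ledger buffer. The layout is, per
// block row, one uint8 holding the count of non-zero blocks in that row,
// followed by one uint8 per non-zero block holding its column index;
// make_ledger() above sizes the buffer accordingly (one count per row plus
// one entry per index).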
109 | TfLiteStatus copy_ledger(const TfLiteSparsity* sparsity, TfLiteTensor* ledger) { |
110 | if (sparsity == nullptr) { |
111 | return kTfLiteOk; |
112 | } |
113 | |
114 | const auto* array_segments = sparsity->dim_metadata[1].array_segments; |
115 | const auto* array_indices = sparsity->dim_metadata[1].array_indices; |
116 | uint8_t* output_data = GetTensorData<uint8_t>(ledger); |
117 | int output_data_ptr = 0; |
118 | |
119 | for (int i = 0; i < array_segments->size - 1; i++) { |
120 | int row_start = array_segments->data[i]; |
121 | int row_end = array_segments->data[i + 1]; |
122 | if (row_end - row_start > UINT8_MAX) { |
123 | return kTfLiteError; |
124 | } |
    // Copy the number of non-zero blocks in row i.
126 | output_data[output_data_ptr] = static_cast<uint8_t>(row_end - row_start); |
127 | output_data_ptr++; |
128 | |
129 | for (int j = row_start; j < row_end; j++) { |
130 | if (array_indices->data[j] > UINT8_MAX) { |
131 | return kTfLiteError; |
132 | } |
133 | // Copy indices of non-zero blocks in row i. |
134 | output_data[output_data_ptr] = |
135 | static_cast<uint8_t>(array_indices->data[j]); |
136 | output_data_ptr++; |
137 | } |
138 | } |
139 | return kTfLiteOk; |
140 | } |
141 | |
142 | TfLiteStatus PopulateQuantizedLstmParams8x8_16( |
143 | TfLiteContext* context, TfLiteNode* node, |
144 | lstm_eval::IntegerLstmParameter* integer_lstm_param) { |
145 | // Calculate quantized clip for projection and cell. |
146 | const auto* params = static_cast<TfLiteLSTMParams*>(node->builtin_data); |
147 | const float cell_clip = params->cell_clip; |
148 | const float proj_clip = params->proj_clip; |
149 | |
150 | const TfLiteTensor* cell_state = |
151 | GetVariableInput(context, node, kCellStateTensor); |
152 | TF_LITE_ENSURE(context, cell_state != nullptr); |
153 | TfLiteTensor* output_tensor; |
154 | TF_LITE_ENSURE_OK( |
155 | context, GetOutputSafe(context, node, kOutputTensor, &output_tensor)); |
156 | |
157 | auto* cell_state_params = |
158 | static_cast<TfLiteAffineQuantization*>(cell_state->quantization.params); |
159 | auto* proj_params = static_cast<TfLiteAffineQuantization*>( |
160 | output_tensor->quantization.params); |
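  // The float clip thresholds are mapped into the quantized domain by
  // dividing by the corresponding scale and saturating to the target integer
  // range. For example, with cell_clip = 8.0 and a cell state scale of 2^-11,
  // the quantized cell clip is 8.0 / 2^-11 = 16384, well within int16 range.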
161 | if (cell_clip > 0.0) { |
162 | integer_lstm_param->quantized_cell_clip = static_cast<int16_t>(std::min( |
163 | std::max(cell_clip / cell_state_params->scale->data[0], -32768.0f), |
164 | 32767.0f)); |
165 | } else { |
166 | integer_lstm_param->quantized_cell_clip = 0; |
167 | } |
168 | if (proj_clip > 0.0) { |
169 | integer_lstm_param->quantized_proj_clip = static_cast<int8_t>(std::min( |
170 | std::max(proj_clip / proj_params->scale->data[0], -128.0f), 127.0f)); |
171 | } else { |
172 | integer_lstm_param->quantized_proj_clip = 0; |
173 | } |
174 | |
175 | // Calculate effective scales. |
176 | OpData* op_data = static_cast<OpData*>(node->user_data); |
177 | const bool use_layer_norm = op_data->use_layer_norm; |
178 | |
179 | const TfLiteTensor* input; |
180 | TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kInputTensor, &input)); |
181 | |
182 | const TfLiteTensor* input_to_input_weights = |
183 | GetOptionalInputTensor(context, node, kInputToInputWeightsTensor); |
184 | const TfLiteTensor* input_to_forget_weights; |
185 | TF_LITE_ENSURE_OK(context, |
186 | GetInputSafe(context, node, kInputToForgetWeightsTensor, |
187 | &input_to_forget_weights)); |
188 | const TfLiteTensor* input_to_cell_weights; |
189 | TF_LITE_ENSURE_OK(context, |
190 | GetInputSafe(context, node, kInputToCellWeightsTensor, |
191 | &input_to_cell_weights)); |
192 | const TfLiteTensor* input_to_output_weights; |
193 | TF_LITE_ENSURE_OK(context, |
194 | GetInputSafe(context, node, kInputToOutputWeightsTensor, |
195 | &input_to_output_weights)); |
196 | |
197 | const TfLiteTensor* recurrent_to_input_weights = |
198 | GetOptionalInputTensor(context, node, kRecurrentToInputWeightsTensor); |
199 | const TfLiteTensor* recurrent_to_forget_weights; |
200 | TF_LITE_ENSURE_OK(context, |
201 | GetInputSafe(context, node, kRecurrentToForgetWeightsTensor, |
202 | &recurrent_to_forget_weights)); |
203 | const TfLiteTensor* recurrent_to_cell_weights; |
204 | TF_LITE_ENSURE_OK(context, |
205 | GetInputSafe(context, node, kRecurrentToCellWeightsTensor, |
206 | &recurrent_to_cell_weights)); |
207 | const TfLiteTensor* recurrent_to_output_weights; |
208 | TF_LITE_ENSURE_OK(context, |
209 | GetInputSafe(context, node, kRecurrentToOutputWeightsTensor, |
210 | &recurrent_to_output_weights)); |
211 | |
212 | const TfLiteTensor* cell_to_input_weights = |
213 | GetOptionalInputTensor(context, node, kCellToInputWeightsTensor); |
214 | const TfLiteTensor* cell_to_forget_weights = |
215 | GetOptionalInputTensor(context, node, kCellToForgetWeightsTensor); |
216 | const TfLiteTensor* cell_to_output_weights = |
217 | GetOptionalInputTensor(context, node, kCellToOutputWeightsTensor); |
218 | |
219 | const TfLiteTensor* input_layer_norm_coefficients = |
220 | GetOptionalInputTensor(context, node, kInputLayerNormCoefficientsTensor); |
221 | const TfLiteTensor* forget_layer_norm_coefficients = |
222 | GetOptionalInputTensor(context, node, kForgetLayerNormCoefficientsTensor); |
223 | const TfLiteTensor* cell_layer_norm_coefficients = |
224 | GetOptionalInputTensor(context, node, kCellLayerNormCoefficientsTensor); |
225 | const TfLiteTensor* output_layer_norm_coefficients = |
226 | GetOptionalInputTensor(context, node, kOutputLayerNormCoefficientsTensor); |
227 | |
228 | const TfLiteTensor* projection_weights = |
229 | GetOptionalInputTensor(context, node, kProjectionWeightsTensor); |
230 | |
231 | TfLiteTensor* output_state = |
232 | GetVariableInput(context, node, kOutputStateTensor); |
233 | TF_LITE_ENSURE(context, output_state != nullptr); |
234 | |
  // Since we have already checked that the weights are either all present or
  // all absent, we can check the existence of only one to determine the mode.
237 | const bool use_cifg = (input_to_input_weights == nullptr); |
238 | const bool use_peephole = (cell_to_output_weights != nullptr); |
239 | const bool use_projection = (projection_weights != nullptr); |
240 | |
241 | // Get intermediate scales and zero points. |
242 | std::vector<float> intermediate_scale; |
243 | std::vector<int32> intermediate_zp; |
244 | for (int i = 0; i < 4; ++i) { |
245 | if (use_layer_norm) { |
246 | TfLiteTensor* intermediate; |
247 | TF_LITE_ENSURE_OK(context, |
248 | GetIntermediatesSafe(context, node, i, &intermediate)); |
249 | auto* params = static_cast<TfLiteAffineQuantization*>( |
250 | intermediate->quantization.params); |
251 | intermediate_scale.push_back(params->scale->data[0]); |
252 | intermediate_zp.push_back(params->zero_point->data[0]); |
253 | } else { |
      // Q3.12 for activation functions: 3 integer bits and 12 fractional
      // bits, i.e. a fixed scale of 2^-12 over the range [-8, 8).
255 | intermediate_scale.push_back(std::pow(2, -12)); |
256 | intermediate_zp.push_back(0); |
257 | } |
258 | } |
259 | // In the absence of projection, hidden becomes output and this intermediate |
260 | // is ignored. |
261 | TfLiteTensor* hidden; |
262 | TF_LITE_ENSURE_OK(context, GetIntermediatesSafe(context, node, 4, &hidden)); |
263 | auto* hidden_params = |
264 | static_cast<TfLiteAffineQuantization*>(hidden->quantization.params); |
265 | intermediate_scale.push_back(hidden_params->scale->data[0]); |
266 | intermediate_zp.push_back(hidden_params->zero_point->data[0]); |
267 | |
268 | // Scales. |
269 | const float default_scale = 1.0; |
270 | float input_scale = default_scale; |
271 | float input_to_input_weight_scale = default_scale; |
272 | float recurrent_to_input_weight_scale = default_scale; |
273 | float cell_to_input_weight_scale = default_scale; |
274 | float input_to_forget_weight_scale = default_scale; |
275 | float recurrent_to_forget_weight_scale = default_scale; |
276 | float cell_to_forget_weight_scale = default_scale; |
277 | float input_to_cell_weight_scale = default_scale; |
278 | float recurrent_to_cell_weight_scale = default_scale; |
279 | float input_to_output_weight_scale = default_scale; |
280 | float recurrent_to_output_weight_scale = default_scale; |
281 | float cell_to_output_weight_scale = default_scale; |
282 | float projection_weight_scale = default_scale; |
283 | float layer_norm_input_scale = default_scale; |
284 | float layer_norm_forget_scale = default_scale; |
285 | float layer_norm_cell_scale = default_scale; |
286 | float layer_norm_output_scale = default_scale; |
287 | float output_state_scale = default_scale; |
288 | int cell_scale = 1; |
289 | |
290 | // Effective scales. |
291 | float effective_input_to_input_scale = default_scale; |
292 | float effective_recurrent_to_input_scale = default_scale; |
293 | float effective_cell_to_input_scale = default_scale; |
294 | float effective_input_to_forget_scale = default_scale; |
295 | float effective_recurrent_to_forget_scale = default_scale; |
296 | float effective_cell_to_forget_scale = default_scale; |
297 | float effective_input_to_cell_scale = default_scale; |
298 | float effective_recurrent_to_cell_scale = default_scale; |
299 | float effective_input_to_output_scale = default_scale; |
300 | float effective_recurrent_to_output_scale = default_scale; |
301 | float effective_cell_to_output_scale = default_scale; |
302 | float effective_proj_scale = default_scale; |
303 | float effective_hidden_scale = default_scale; |
304 | |
305 | // Populate scales. |
306 | if (!use_cifg) { |
307 | input_to_input_weight_scale = input_to_input_weights->params.scale; |
308 | recurrent_to_input_weight_scale = recurrent_to_input_weights->params.scale; |
309 | } |
310 | |
311 | if (use_peephole) { |
312 | if (!use_cifg) { |
313 | cell_to_input_weight_scale = cell_to_input_weights->params.scale; |
314 | } |
315 | cell_to_forget_weight_scale = cell_to_forget_weights->params.scale; |
316 | cell_to_output_weight_scale = cell_to_output_weights->params.scale; |
317 | } |
318 | |
319 | if (use_layer_norm) { |
320 | if (!use_cifg) { |
321 | layer_norm_input_scale = input_layer_norm_coefficients->params.scale; |
322 | } |
323 | layer_norm_forget_scale = forget_layer_norm_coefficients->params.scale; |
324 | layer_norm_cell_scale = cell_layer_norm_coefficients->params.scale; |
325 | layer_norm_output_scale = output_layer_norm_coefficients->params.scale; |
326 | } |
327 | |
328 | if (use_projection) { |
329 | projection_weight_scale = projection_weights->params.scale; |
330 | } |
331 | output_state_scale = output_state->params.scale; |
332 | |
333 | input_to_forget_weight_scale = input_to_forget_weights->params.scale; |
334 | input_to_cell_weight_scale = input_to_cell_weights->params.scale; |
335 | input_to_output_weight_scale = input_to_output_weights->params.scale; |
336 | recurrent_to_forget_weight_scale = recurrent_to_forget_weights->params.scale; |
337 | recurrent_to_cell_weight_scale = recurrent_to_cell_weights->params.scale; |
338 | recurrent_to_output_weight_scale = recurrent_to_output_weights->params.scale; |
339 | |
  // Check the cell state (already fetched above).
341 | TF_LITE_ENSURE(context, CheckedLog2(cell_state->params.scale, &cell_scale)); |
342 | TF_LITE_ENSURE(context, cell_scale <= -9); |
343 | integer_lstm_param->cell_scale = cell_scale; |
344 | input_scale = input->params.scale; |
345 | |
346 | // Calculate effective scales. |
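  // Each effective scale here rescales the product of two quantized operands
  // into the units of its target intermediate tensor:
  //   effective_scale = weight_scale * activation_scale / intermediate_scale
  // so that the int32 accumulator of (q_activation * q_weight) can be
  // requantized onto the intermediate's grid with a single multiplier.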
347 | if (!use_cifg) { |
348 | effective_input_to_input_scale = |
349 | input_to_input_weight_scale * input_scale / intermediate_scale[0]; |
350 | effective_recurrent_to_input_scale = recurrent_to_input_weight_scale * |
351 | output_state_scale / |
352 | intermediate_scale[0]; |
353 | } |
354 | effective_input_to_forget_scale = |
355 | input_to_forget_weight_scale * input_scale / intermediate_scale[1]; |
356 | effective_recurrent_to_forget_scale = recurrent_to_forget_weight_scale * |
357 | output_state_scale / |
358 | intermediate_scale[1]; |
359 | |
360 | effective_input_to_cell_scale = |
361 | input_to_cell_weight_scale * input_scale / intermediate_scale[2]; |
362 | effective_recurrent_to_cell_scale = recurrent_to_cell_weight_scale * |
363 | output_state_scale / |
364 | intermediate_scale[2]; |
365 | |
366 | effective_input_to_output_scale = |
367 | input_to_output_weight_scale * input_scale / intermediate_scale[3]; |
368 | effective_recurrent_to_output_scale = recurrent_to_output_weight_scale * |
369 | output_state_scale / |
370 | intermediate_scale[3]; |
371 | |
372 | effective_hidden_scale = |
373 | std::pow(2, -15) / intermediate_scale[4] * std::pow(2, -15); |
374 | |
375 | effective_proj_scale = |
376 | projection_weight_scale * intermediate_scale[4] / output_state_scale; |
377 | |
378 | if (use_peephole) { |
379 | if (!use_cifg) { |
380 | effective_cell_to_input_scale = std::pow(2, cell_scale) * // NOLINT |
381 | cell_to_input_weight_scale / |
382 | intermediate_scale[0]; |
383 | } |
384 | effective_cell_to_forget_scale = std::pow(2, cell_scale) * // NOLINT |
385 | cell_to_forget_weight_scale / |
386 | intermediate_scale[1]; |
387 | effective_cell_to_output_scale = std::pow(2, cell_scale) * // NOLINT |
388 | cell_to_output_weight_scale / |
389 | intermediate_scale[3]; |
390 | } |
391 | |
392 | // Decompose scales. |
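  // QuantizeMultiplier splits each real-valued scale s into a 32-bit
  // fixed-point multiplier a and an exponent b with s ~= a * 2^(b - 31), so
  // the kernel can apply the scale as an integer multiply plus a shift.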
393 | QuantizeMultiplier(effective_input_to_input_scale, |
394 | &integer_lstm_param->effective_input_to_input_scale_a, |
395 | &integer_lstm_param->effective_input_to_input_scale_b); |
396 | QuantizeMultiplier(effective_recurrent_to_input_scale, |
397 | &integer_lstm_param->effective_recurrent_to_input_scale_a, |
398 | &integer_lstm_param->effective_recurrent_to_input_scale_b); |
399 | QuantizeMultiplier(effective_cell_to_input_scale, |
400 | &integer_lstm_param->effective_cell_to_input_scale_a, |
401 | &integer_lstm_param->effective_cell_to_input_scale_b); |
402 | QuantizeMultiplier(effective_input_to_forget_scale, |
403 | &integer_lstm_param->effective_input_to_forget_scale_a, |
404 | &integer_lstm_param->effective_input_to_forget_scale_b); |
405 | QuantizeMultiplier( |
406 | effective_recurrent_to_forget_scale, |
407 | &integer_lstm_param->effective_recurrent_to_forget_scale_a, |
408 | &integer_lstm_param->effective_recurrent_to_forget_scale_b); |
409 | QuantizeMultiplier(effective_cell_to_forget_scale, |
410 | &integer_lstm_param->effective_cell_to_forget_scale_a, |
411 | &integer_lstm_param->effective_cell_to_forget_scale_b); |
412 | QuantizeMultiplier(effective_input_to_cell_scale, |
413 | &integer_lstm_param->effective_input_to_cell_scale_a, |
414 | &integer_lstm_param->effective_input_to_cell_scale_b); |
415 | QuantizeMultiplier(effective_recurrent_to_cell_scale, |
416 | &integer_lstm_param->effective_recurrent_to_cell_scale_a, |
417 | &integer_lstm_param->effective_recurrent_to_cell_scale_b); |
418 | QuantizeMultiplier(effective_input_to_output_scale, |
419 | &integer_lstm_param->effective_input_to_output_scale_a, |
420 | &integer_lstm_param->effective_input_to_output_scale_b); |
421 | QuantizeMultiplier( |
422 | effective_recurrent_to_output_scale, |
423 | &integer_lstm_param->effective_recurrent_to_output_scale_a, |
424 | &integer_lstm_param->effective_recurrent_to_output_scale_b); |
425 | QuantizeMultiplier(effective_cell_to_output_scale, |
426 | &integer_lstm_param->effective_cell_to_output_scale_a, |
427 | &integer_lstm_param->effective_cell_to_output_scale_b); |
428 | QuantizeMultiplier(effective_proj_scale, |
429 | &integer_lstm_param->effective_proj_scale_a, |
430 | &integer_lstm_param->effective_proj_scale_b); |
431 | QuantizeMultiplier(effective_hidden_scale, |
432 | &integer_lstm_param->effective_hidden_scale_a, |
433 | &integer_lstm_param->effective_hidden_scale_b); |
434 | QuantizeMultiplier(layer_norm_input_scale, |
435 | &integer_lstm_param->layer_norm_input_scale_a, |
436 | &integer_lstm_param->layer_norm_input_scale_b); |
437 | QuantizeMultiplier(layer_norm_forget_scale, |
438 | &integer_lstm_param->layer_norm_forget_scale_a, |
439 | &integer_lstm_param->layer_norm_forget_scale_b); |
440 | QuantizeMultiplier(layer_norm_cell_scale, |
441 | &integer_lstm_param->layer_norm_cell_scale_a, |
442 | &integer_lstm_param->layer_norm_cell_scale_b); |
443 | QuantizeMultiplier(layer_norm_output_scale, |
444 | &integer_lstm_param->layer_norm_output_scale_a, |
445 | &integer_lstm_param->layer_norm_output_scale_b); |
446 | |
447 | integer_lstm_param->hidden_zp = intermediate_zp[4]; |
448 | |
449 | // 10000 is used to make sure the kernel logic does not overflow. |
450 | if (!use_cifg) { |
451 | integer_lstm_param->input_variance_guard = |
452 | std::max(1, static_cast<int32_t>(10000 * layer_norm_input_scale)); |
453 | } |
454 | integer_lstm_param->forget_variance_guard = |
455 | std::max(1, static_cast<int32_t>(10000 * layer_norm_forget_scale)); |
456 | integer_lstm_param->cell_variance_guard = |
457 | std::max(1, static_cast<int32_t>(10000 * layer_norm_cell_scale)); |
458 | integer_lstm_param->output_variance_guard = |
459 | std::max(1, static_cast<int32_t>(10000 * layer_norm_output_scale)); |
460 | |
461 | return kTfLiteOk; |
462 | } |
463 | |
464 | TfLiteStatus PopulateQuantizedLstmParams8x8_8( |
465 | TfLiteContext* context, TfLiteNode* node, |
466 | lstm_eval::IntegerLstmParameter* integer_lstm_param) { |
467 | // Get all tensors. |
468 | const TfLiteTensor* input; |
469 | TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kInputTensor, &input)); |
470 | const TfLiteTensor* input_to_input_weights = |
471 | GetOptionalInputTensor(context, node, kInputToInputWeightsTensor); |
472 | const TfLiteTensor* input_to_forget_weights; |
473 | TF_LITE_ENSURE_OK(context, |
474 | GetInputSafe(context, node, kInputToForgetWeightsTensor, |
475 | &input_to_forget_weights)); |
476 | const TfLiteTensor* input_to_cell_weights; |
477 | TF_LITE_ENSURE_OK(context, |
478 | GetInputSafe(context, node, kInputToCellWeightsTensor, |
479 | &input_to_cell_weights)); |
480 | const TfLiteTensor* input_to_output_weights; |
481 | TF_LITE_ENSURE_OK(context, |
482 | GetInputSafe(context, node, kInputToOutputWeightsTensor, |
483 | &input_to_output_weights)); |
484 | |
485 | const TfLiteTensor* recurrent_to_input_weights = |
486 | GetOptionalInputTensor(context, node, kRecurrentToInputWeightsTensor); |
487 | const TfLiteTensor* recurrent_to_forget_weights; |
488 | TF_LITE_ENSURE_OK(context, |
489 | GetInputSafe(context, node, kRecurrentToForgetWeightsTensor, |
490 | &recurrent_to_forget_weights)); |
491 | const TfLiteTensor* recurrent_to_cell_weights; |
492 | TF_LITE_ENSURE_OK(context, |
493 | GetInputSafe(context, node, kRecurrentToCellWeightsTensor, |
494 | &recurrent_to_cell_weights)); |
495 | const TfLiteTensor* recurrent_to_output_weights; |
496 | TF_LITE_ENSURE_OK(context, |
497 | GetInputSafe(context, node, kRecurrentToOutputWeightsTensor, |
498 | &recurrent_to_output_weights)); |
499 | |
500 | const TfLiteTensor* cell_to_input_weights = |
501 | GetOptionalInputTensor(context, node, kCellToInputWeightsTensor); |
502 | const TfLiteTensor* cell_to_forget_weights = |
503 | GetOptionalInputTensor(context, node, kCellToForgetWeightsTensor); |
504 | const TfLiteTensor* cell_to_output_weights = |
505 | GetOptionalInputTensor(context, node, kCellToOutputWeightsTensor); |
506 | |
507 | const TfLiteTensor* input_layer_norm_coefficients = |
508 | GetOptionalInputTensor(context, node, kInputLayerNormCoefficientsTensor); |
509 | const TfLiteTensor* forget_layer_norm_coefficients = |
510 | GetOptionalInputTensor(context, node, kForgetLayerNormCoefficientsTensor); |
511 | const TfLiteTensor* cell_layer_norm_coefficients = |
512 | GetOptionalInputTensor(context, node, kCellLayerNormCoefficientsTensor); |
513 | const TfLiteTensor* output_layer_norm_coefficients = |
514 | GetOptionalInputTensor(context, node, kOutputLayerNormCoefficientsTensor); |
515 | |
516 | const TfLiteTensor* input_gate_bias = |
517 | GetOptionalInputTensor(context, node, kInputGateBiasTensor); |
518 | const TfLiteTensor* forget_gate_bias; |
519 | TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kForgetGateBiasTensor, |
520 | &forget_gate_bias)); |
521 | const TfLiteTensor* cell_gate_bias; |
522 | TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kCellGateBiasTensor, |
523 | &cell_gate_bias)); |
524 | const TfLiteTensor* output_gate_bias; |
525 | TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kOutputGateBiasTensor, |
526 | &output_gate_bias)); |
527 | |
528 | const TfLiteTensor* projection_weights = |
529 | GetOptionalInputTensor(context, node, kProjectionWeightsTensor); |
530 | const TfLiteTensor* projection_bias = |
531 | GetOptionalInputTensor(context, node, kProjectionBiasTensor); |
532 | |
533 | TfLiteTensor* output_state = |
534 | GetVariableInput(context, node, kOutputStateTensor); |
535 | TF_LITE_ENSURE(context, output_state != nullptr); |
536 | TfLiteTensor* cell_state = GetVariableInput(context, node, kCellStateTensor); |
537 | TF_LITE_ENSURE(context, cell_state != nullptr); |
538 | |
  // Since we have already checked that the weights are either all present or
  // all absent, we can check the existence of only one to determine the mode.
541 | const bool use_cifg = (input_to_input_weights == nullptr); |
542 | const bool use_peephole = (cell_to_output_weights != nullptr); |
543 | const bool is_layer_norm_lstm = (forget_layer_norm_coefficients != nullptr); |
544 | const bool use_projection = (projection_weights != nullptr); |
545 | |
546 | // Weights and states. |
547 | int8_t* input_to_input_weight_ptr = nullptr; |
548 | int8_t* recurrent_to_input_weight_ptr = nullptr; |
549 | int8_t* cell_to_input_weight_ptr = nullptr; |
550 | int8_t* input_to_forget_weight_ptr = nullptr; |
551 | int8_t* recurrent_to_forget_weight_ptr = nullptr; |
552 | int8_t* cell_to_forget_weight_ptr = nullptr; |
553 | int8_t* input_to_cell_weight_ptr = nullptr; |
554 | int8_t* recurrent_to_cell_weight_ptr = nullptr; |
555 | int8_t* input_to_output_weight_ptr = nullptr; |
556 | int8_t* recurrent_to_output_weight_ptr = nullptr; |
557 | int8_t* cell_to_output_weight_ptr = nullptr; |
558 | int8_t* projection_weight_ptr = nullptr; |
559 | int16_t* layer_norm_input_weight_ptr = nullptr; |
560 | int16_t* layer_norm_forget_weight_ptr = nullptr; |
561 | int16_t* layer_norm_cell_weight_ptr = nullptr; |
562 | int16_t* layer_norm_output_weight_ptr = nullptr; |
563 | int32_t* input_gate_bias_ptr = nullptr; |
564 | int32_t* forget_gate_bias_ptr = nullptr; |
565 | int32_t* cell_gate_bias_ptr = nullptr; |
566 | int32_t* output_gate_bias_ptr = nullptr; |
567 | int32_t* projection_bias_ptr = nullptr; |
568 | int16_t* cell_ptr = nullptr; |
569 | int8_t* output_state_ptr = nullptr; |
570 | |
571 | // Scales. |
572 | const float default_scale = 1.0; |
573 | float input_scale = default_scale; |
574 | float input_to_input_weight_scale = default_scale; |
575 | float recurrent_to_input_weight_scale = default_scale; |
576 | float cell_to_input_weight_scale = default_scale; |
577 | float input_to_forget_weight_scale = default_scale; |
578 | float recurrent_to_forget_weight_scale = default_scale; |
579 | float cell_to_forget_weight_scale = default_scale; |
580 | float input_to_cell_weight_scale = default_scale; |
581 | float recurrent_to_cell_weight_scale = default_scale; |
582 | float input_to_output_weight_scale = default_scale; |
583 | float recurrent_to_output_weight_scale = default_scale; |
584 | float cell_to_output_weight_scale = default_scale; |
585 | float projection_weight_scale = default_scale; |
586 | float layer_norm_input_scale = default_scale; |
587 | float layer_norm_forget_scale = default_scale; |
588 | float layer_norm_cell_scale = default_scale; |
589 | float layer_norm_output_scale = default_scale; |
590 | float output_state_scale = default_scale; |
591 | |
592 | // Effective scales. |
593 | float effective_input_to_input_scale = default_scale; |
594 | float effective_recurrent_to_input_scale = default_scale; |
595 | float effective_cell_to_input_scale = default_scale; |
596 | float effective_input_to_forget_scale = default_scale; |
597 | float effective_recurrent_to_forget_scale = default_scale; |
598 | float effective_cell_to_forget_scale = default_scale; |
599 | float effective_input_to_cell_scale = default_scale; |
600 | float effective_recurrent_to_cell_scale = default_scale; |
601 | float effective_input_to_output_scale = default_scale; |
602 | float effective_recurrent_to_output_scale = default_scale; |
603 | float effective_cell_to_output_scale = default_scale; |
604 | float effective_proj_scale = default_scale; |
605 | |
606 | // Zero points |
607 | int input_zp = 0; |
608 | int output_state_zp = 0; |
609 | |
610 | // Populate all the values. |
611 | if (!use_cifg) { |
612 | input_to_input_weight_ptr = input_to_input_weights->data.int8; |
613 | recurrent_to_input_weight_ptr = recurrent_to_input_weights->data.int8; |
614 | input_gate_bias_ptr = input_gate_bias->data.i32; |
615 | input_to_input_weight_scale = input_to_input_weights->params.scale; |
616 | recurrent_to_input_weight_scale = recurrent_to_input_weights->params.scale; |
617 | } |
618 | |
619 | if (use_peephole) { |
620 | if (!use_cifg) { |
621 | cell_to_input_weight_ptr = cell_to_input_weights->data.int8; |
622 | cell_to_input_weight_scale = cell_to_input_weights->params.scale; |
623 | } |
624 | cell_to_forget_weight_ptr = cell_to_forget_weights->data.int8; |
625 | cell_to_output_weight_ptr = cell_to_output_weights->data.int8; |
626 | cell_to_forget_weight_scale = cell_to_forget_weights->params.scale; |
627 | cell_to_output_weight_scale = cell_to_output_weights->params.scale; |
628 | } |
629 | |
630 | if (is_layer_norm_lstm) { |
631 | if (!use_cifg) { |
632 | layer_norm_input_weight_ptr = input_layer_norm_coefficients->data.i16; |
633 | layer_norm_input_scale = input_layer_norm_coefficients->params.scale; |
634 | } |
635 | layer_norm_forget_weight_ptr = forget_layer_norm_coefficients->data.i16; |
636 | layer_norm_forget_scale = forget_layer_norm_coefficients->params.scale; |
637 | layer_norm_cell_weight_ptr = cell_layer_norm_coefficients->data.i16; |
638 | layer_norm_cell_scale = cell_layer_norm_coefficients->params.scale; |
639 | layer_norm_output_weight_ptr = output_layer_norm_coefficients->data.i16; |
640 | layer_norm_output_scale = output_layer_norm_coefficients->params.scale; |
641 | } |
642 | |
643 | if (use_projection) { |
644 | projection_weight_ptr = projection_weights->data.int8; |
645 | projection_weight_scale = projection_weights->params.scale; |
646 | if (projection_bias) { |
647 | projection_bias_ptr = projection_bias->data.i32; |
648 | } |
649 | } |
650 | output_state_scale = output_state->params.scale; |
651 | |
652 | input_to_forget_weight_ptr = input_to_forget_weights->data.int8; |
653 | input_to_forget_weight_scale = input_to_forget_weights->params.scale; |
654 | input_to_cell_weight_ptr = input_to_cell_weights->data.int8; |
655 | input_to_cell_weight_scale = input_to_cell_weights->params.scale; |
656 | input_to_output_weight_ptr = input_to_output_weights->data.int8; |
657 | input_to_output_weight_scale = input_to_output_weights->params.scale; |
658 | recurrent_to_forget_weight_ptr = recurrent_to_forget_weights->data.int8; |
659 | recurrent_to_forget_weight_scale = recurrent_to_forget_weights->params.scale; |
660 | recurrent_to_cell_weight_ptr = recurrent_to_cell_weights->data.int8; |
661 | recurrent_to_cell_weight_scale = recurrent_to_cell_weights->params.scale; |
662 | recurrent_to_output_weight_ptr = recurrent_to_output_weights->data.int8; |
663 | recurrent_to_output_weight_scale = recurrent_to_output_weights->params.scale; |
664 | forget_gate_bias_ptr = forget_gate_bias->data.i32; |
665 | cell_gate_bias_ptr = cell_gate_bias->data.i32; |
666 | output_gate_bias_ptr = output_gate_bias->data.i32; |
667 | output_state_ptr = output_state->data.int8; |
668 | cell_ptr = cell_state->data.i16; |
669 | input_scale = input->params.scale; |
670 | input_zp = input->params.zero_point; |
671 | output_state_zp = output_state->params.zero_point; |
672 | |
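  // The twelve intermediates are grouped three per gate (input, forget, cell,
  // output): index 3*g carries the scale/zero point of the accumulated
  // Wx + Wh result (also the peephole target), 3*g + 1 the Wx product, and
  // 3*g + 2 the Wh product; the index arithmetic below follows this layout.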
673 | std::vector<float> intermediate_scale; |
674 | for (int i = 0; i < 12; ++i) { |
675 | TfLiteTensor* intermediate = |
676 | &context->tensors[node->intermediates->data[i]]; |
677 | auto* params = reinterpret_cast<TfLiteAffineQuantization*>( |
678 | intermediate->quantization.params); |
679 | intermediate_scale.push_back(params->scale->data[0]); |
680 | integer_lstm_param->intermediate_zp[i] = params->zero_point->data[0]; |
681 | } |
682 | |
683 | // Calculate effective scales. |
684 | if (!use_cifg) { |
685 | effective_input_to_input_scale = |
686 | input_to_input_weight_scale * input_scale / intermediate_scale[1]; |
687 | effective_recurrent_to_input_scale = recurrent_to_input_weight_scale * |
688 | output_state_scale / |
689 | intermediate_scale[2]; |
690 | } |
691 | effective_input_to_forget_scale = |
692 | input_to_forget_weight_scale * input_scale / intermediate_scale[4]; |
693 | effective_recurrent_to_forget_scale = recurrent_to_forget_weight_scale * |
694 | output_state_scale / |
695 | intermediate_scale[5]; |
696 | |
697 | effective_input_to_cell_scale = |
698 | input_to_cell_weight_scale * input_scale / intermediate_scale[7]; |
699 | effective_recurrent_to_cell_scale = recurrent_to_cell_weight_scale * |
700 | output_state_scale / |
701 | intermediate_scale[8]; |
702 | |
703 | effective_input_to_output_scale = |
704 | input_to_output_weight_scale * input_scale / intermediate_scale[10]; |
705 | effective_recurrent_to_output_scale = recurrent_to_output_weight_scale * |
706 | output_state_scale / |
707 | intermediate_scale[11]; |
708 | effective_proj_scale = |
709 | projection_weight_scale * std::pow(2, -15) / output_state_scale; |
710 | |
711 | if (use_peephole) { |
712 | if (!use_cifg) { |
713 | effective_cell_to_input_scale = |
714 | std::pow(2, -15) * cell_to_input_weight_scale / intermediate_scale[0]; |
715 | } |
716 | effective_cell_to_forget_scale = |
717 | std::pow(2, -15) * cell_to_forget_weight_scale / intermediate_scale[3]; |
718 | effective_cell_to_output_scale = |
719 | std::pow(2, -15) * cell_to_output_weight_scale / intermediate_scale[9]; |
720 | } |
721 | |
  // Decompose the effective scales into integer multipliers.
723 | QuantizeMultiplier(effective_input_to_input_scale, |
724 | &integer_lstm_param->effective_input_to_input_scale_a, |
725 | &integer_lstm_param->effective_input_to_input_scale_b); |
726 | QuantizeMultiplier(effective_recurrent_to_input_scale, |
727 | &integer_lstm_param->effective_recurrent_to_input_scale_a, |
728 | &integer_lstm_param->effective_recurrent_to_input_scale_b); |
729 | QuantizeMultiplier(effective_cell_to_input_scale, |
730 | &integer_lstm_param->effective_cell_to_input_scale_a, |
731 | &integer_lstm_param->effective_cell_to_input_scale_b); |
732 | QuantizeMultiplier(effective_input_to_forget_scale, |
733 | &integer_lstm_param->effective_input_to_forget_scale_a, |
734 | &integer_lstm_param->effective_input_to_forget_scale_b); |
735 | QuantizeMultiplier( |
736 | effective_recurrent_to_forget_scale, |
737 | &integer_lstm_param->effective_recurrent_to_forget_scale_a, |
738 | &integer_lstm_param->effective_recurrent_to_forget_scale_b); |
739 | QuantizeMultiplier(effective_cell_to_forget_scale, |
740 | &integer_lstm_param->effective_cell_to_forget_scale_a, |
741 | &integer_lstm_param->effective_cell_to_forget_scale_b); |
742 | QuantizeMultiplier(effective_input_to_cell_scale, |
743 | &integer_lstm_param->effective_input_to_cell_scale_a, |
744 | &integer_lstm_param->effective_input_to_cell_scale_b); |
745 | QuantizeMultiplier(effective_recurrent_to_cell_scale, |
746 | &integer_lstm_param->effective_recurrent_to_cell_scale_a, |
747 | &integer_lstm_param->effective_recurrent_to_cell_scale_b); |
748 | QuantizeMultiplier(effective_input_to_output_scale, |
749 | &integer_lstm_param->effective_input_to_output_scale_a, |
750 | &integer_lstm_param->effective_input_to_output_scale_b); |
751 | QuantizeMultiplier( |
752 | effective_recurrent_to_output_scale, |
753 | &integer_lstm_param->effective_recurrent_to_output_scale_a, |
754 | &integer_lstm_param->effective_recurrent_to_output_scale_b); |
755 | QuantizeMultiplier(effective_cell_to_output_scale, |
756 | &integer_lstm_param->effective_cell_to_output_scale_a, |
757 | &integer_lstm_param->effective_cell_to_output_scale_b); |
758 | QuantizeMultiplier(effective_proj_scale, |
759 | &integer_lstm_param->effective_proj_scale_a, |
760 | &integer_lstm_param->effective_proj_scale_b); |
761 | QuantizeMultiplier(layer_norm_input_scale, |
762 | &integer_lstm_param->layer_norm_input_scale_a, |
763 | &integer_lstm_param->layer_norm_input_scale_b); |
764 | QuantizeMultiplier(layer_norm_forget_scale, |
765 | &integer_lstm_param->layer_norm_forget_scale_a, |
766 | &integer_lstm_param->layer_norm_forget_scale_b); |
767 | QuantizeMultiplier(layer_norm_cell_scale, |
768 | &integer_lstm_param->layer_norm_cell_scale_a, |
769 | &integer_lstm_param->layer_norm_cell_scale_b); |
770 | QuantizeMultiplier(layer_norm_output_scale, |
771 | &integer_lstm_param->layer_norm_output_scale_a, |
772 | &integer_lstm_param->layer_norm_output_scale_b); |
773 | |
774 | { |
    // The intermediates in the flatbuffer hold Wx, Wh and Wx+Wh. The
    // effective Wx and Wh scales are already captured in
    // effective_input/recurrent_to_<...>_scale, so intermediate_scale is used
    // here to hold the rescale from Wx and Wh to Wx+Wh:
    // 0: [1] -> [0]
    // 1: [2] -> [0]
    // and intermediate_zp is used as is.
781 | const float s_1_0 = intermediate_scale[1] / intermediate_scale[0]; |
782 | const float s_2_0 = intermediate_scale[2] / intermediate_scale[0]; |
783 | const float s_4_3 = intermediate_scale[4] / intermediate_scale[3]; |
784 | const float s_5_3 = intermediate_scale[5] / intermediate_scale[3]; |
785 | const float s_7_6 = intermediate_scale[7] / intermediate_scale[6]; |
786 | const float s_8_6 = intermediate_scale[8] / intermediate_scale[6]; |
787 | const float s_10_9 = intermediate_scale[10] / intermediate_scale[9]; |
788 | const float s_11_9 = intermediate_scale[11] / intermediate_scale[9]; |
789 | QuantizeMultiplier(s_1_0, &integer_lstm_param->intermediate_scale_a[0], |
790 | &integer_lstm_param->intermediate_scale_b[0]); |
791 | QuantizeMultiplier(s_2_0, &integer_lstm_param->intermediate_scale_a[1], |
792 | &integer_lstm_param->intermediate_scale_b[1]); |
793 | QuantizeMultiplier(s_4_3, &integer_lstm_param->intermediate_scale_a[2], |
794 | &integer_lstm_param->intermediate_scale_b[2]); |
795 | QuantizeMultiplier(s_5_3, &integer_lstm_param->intermediate_scale_a[3], |
796 | &integer_lstm_param->intermediate_scale_b[3]); |
797 | QuantizeMultiplier(s_7_6, &integer_lstm_param->intermediate_scale_a[4], |
798 | &integer_lstm_param->intermediate_scale_b[4]); |
799 | QuantizeMultiplier(s_8_6, &integer_lstm_param->intermediate_scale_a[5], |
800 | &integer_lstm_param->intermediate_scale_b[5]); |
801 | QuantizeMultiplier(s_10_9, &integer_lstm_param->intermediate_scale_a[6], |
802 | &integer_lstm_param->intermediate_scale_b[6]); |
803 | QuantizeMultiplier(s_11_9, &integer_lstm_param->intermediate_scale_a[7], |
804 | &integer_lstm_param->intermediate_scale_b[7]); |
805 | } |
806 | |
807 | // Calculate quantized clip for projection and cell. |
808 | const auto* params = reinterpret_cast<TfLiteLSTMParams*>(node->builtin_data); |
809 | const float cell_clip = params->cell_clip; |
810 | const float proj_clip = params->proj_clip; |
811 | |
812 | TfLiteTensor* output_tensor; |
813 | TF_LITE_ENSURE_OK( |
814 | context, GetOutputSafe(context, node, kOutputTensor, &output_tensor)); |
815 | |
816 | auto* cell_state_params = reinterpret_cast<TfLiteAffineQuantization*>( |
817 | cell_state->quantization.params); |
818 | auto* proj_params = reinterpret_cast<TfLiteAffineQuantization*>( |
819 | output_tensor->quantization.params); |
820 | TF_LITE_ENSURE_EQ(context, cell_state_params->scale->data[0], 1.0 / 32768); |
821 | if (cell_clip > 0.0 && cell_clip < 1.0) { |
822 | integer_lstm_param->quantized_cell_clip = static_cast<int16_t>(std::min( |
823 | std::max(cell_clip / cell_state_params->scale->data[0], -32768.0f), |
824 | 32767.0f)); |
825 | } else { |
826 | integer_lstm_param->quantized_cell_clip = 0; |
827 | } |
828 | if (proj_clip > 0.0) { |
829 | integer_lstm_param->quantized_proj_clip = static_cast<int8_t>(std::min( |
830 | std::max(proj_clip / proj_params->scale->data[0], -128.0f), 127.0f)); |
831 | } else { |
832 | integer_lstm_param->quantized_proj_clip = 0; |
833 | } |
834 | return kTfLiteOk; |
835 | } |
836 | |
837 | } // namespace |
838 | |
839 | void* Init(TfLiteContext* context, const char* buffer, size_t length) { |
840 | auto* op_data = new OpData(); |
841 | op_data->kernel_type = kTfLiteLSTMFullKernel; |
842 | // TODO(b/159066113): maybe just add the minimum required temp tensors? |
843 | context->AddTensors(context, kNumHybridTemporaryTensors, |
844 | &op_data->scratch_tensor_index); |
845 | // Tensors used for the sparse hybrid kernel. |
846 | context->AddTensors(context, /*tensors_to_add=*/kLedgersToAdd, |
847 | &op_data->ledger_index); |
848 | return op_data; |
849 | } |
850 | |
851 | // LINT.IfChange |
// Check that the input tensor dimensions are consistent with each other.
853 | TfLiteStatus CheckInputTensorDimensions(TfLiteContext* context, |
854 | TfLiteNode* node, int n_input, |
855 | int n_output, int n_cell, |
856 | bool use_layer_norm, bool is_integer) { |
857 | const auto* params = static_cast<TfLiteLSTMParams*>(node->builtin_data); |
858 | |
  // Make sure the clipping parameters have valid values:
  // == 0 means no clipping,
  //  > 0 means clipping.
862 | TF_LITE_ENSURE(context, params->cell_clip >= 0); |
863 | TF_LITE_ENSURE(context, params->proj_clip >= 0); |
864 | |
865 | const TfLiteTensor* input_to_forget_weights; |
866 | TF_LITE_ENSURE_OK(context, |
867 | GetInputSafe(context, node, kInputToForgetWeightsTensor, |
868 | &input_to_forget_weights)); |
869 | TF_LITE_ENSURE_EQ(context, input_to_forget_weights->dims->size, 2); |
870 | TF_LITE_ENSURE_EQ(context, input_to_forget_weights->dims->data[0], n_cell); |
871 | TF_LITE_ENSURE_EQ(context, input_to_forget_weights->dims->data[1], n_input); |
872 | TF_LITE_ENSURE(context, (input_to_forget_weights->type == kTfLiteFloat32) || |
873 | (input_to_forget_weights->type == kTfLiteUInt8) || |
874 | (input_to_forget_weights->type == kTfLiteInt8)); |
875 | |
876 | const TfLiteTensor* input_to_input_weights = |
877 | GetOptionalInputTensor(context, node, kInputToInputWeightsTensor); |
878 | const bool use_cifg = (input_to_input_weights == nullptr); |
879 | if (!use_cifg) { |
880 | TF_LITE_ENSURE_EQ(context, input_to_input_weights->dims->size, 2); |
881 | TF_LITE_ENSURE_EQ(context, input_to_input_weights->dims->data[0], n_cell); |
882 | TF_LITE_ENSURE_EQ(context, input_to_input_weights->dims->data[1], n_input); |
883 | TF_LITE_ENSURE_TYPES_EQ(context, input_to_input_weights->type, |
884 | input_to_forget_weights->type); |
885 | } |
886 | |
887 | const TfLiteTensor* input_to_cell_weights; |
888 | TF_LITE_ENSURE_OK(context, |
889 | GetInputSafe(context, node, kInputToCellWeightsTensor, |
890 | &input_to_cell_weights)); |
891 | TF_LITE_ENSURE_EQ(context, input_to_cell_weights->dims->size, 2); |
892 | TF_LITE_ENSURE_EQ(context, input_to_cell_weights->dims->data[0], n_cell); |
893 | TF_LITE_ENSURE_EQ(context, input_to_cell_weights->dims->data[1], n_input); |
894 | TF_LITE_ENSURE_TYPES_EQ(context, input_to_cell_weights->type, |
895 | input_to_forget_weights->type); |
896 | |
897 | const TfLiteTensor* recurrent_to_input_weights = |
898 | GetOptionalInputTensor(context, node, kRecurrentToInputWeightsTensor); |
899 | if (recurrent_to_input_weights != nullptr) { |
900 | TF_LITE_ENSURE_EQ(context, recurrent_to_input_weights->dims->size, 2); |
901 | TF_LITE_ENSURE_EQ(context, recurrent_to_input_weights->dims->data[0], |
902 | n_cell); |
903 | TF_LITE_ENSURE_EQ(context, recurrent_to_input_weights->dims->data[1], |
904 | n_output); |
905 | TF_LITE_ENSURE_TYPES_EQ(context, recurrent_to_input_weights->type, |
906 | input_to_forget_weights->type); |
907 | } |
908 | |
909 | const TfLiteTensor* recurrent_to_forget_weights; |
910 | TF_LITE_ENSURE_OK(context, |
911 | GetInputSafe(context, node, kRecurrentToForgetWeightsTensor, |
912 | &recurrent_to_forget_weights)); |
913 | TF_LITE_ENSURE_EQ(context, recurrent_to_forget_weights->dims->size, 2); |
914 | TF_LITE_ENSURE_EQ(context, recurrent_to_forget_weights->dims->data[0], |
915 | n_cell); |
916 | TF_LITE_ENSURE_EQ(context, recurrent_to_forget_weights->dims->data[1], |
917 | n_output); |
918 | TF_LITE_ENSURE_TYPES_EQ(context, recurrent_to_forget_weights->type, |
919 | input_to_forget_weights->type); |
920 | |
921 | const TfLiteTensor* recurrent_to_cell_weights; |
922 | TF_LITE_ENSURE_OK(context, |
923 | GetInputSafe(context, node, kRecurrentToCellWeightsTensor, |
924 | &recurrent_to_cell_weights)); |
925 | TF_LITE_ENSURE_EQ(context, recurrent_to_cell_weights->dims->size, 2); |
926 | TF_LITE_ENSURE_EQ(context, recurrent_to_cell_weights->dims->data[0], n_cell); |
927 | TF_LITE_ENSURE_EQ(context, recurrent_to_cell_weights->dims->data[1], |
928 | n_output); |
929 | TF_LITE_ENSURE_TYPES_EQ(context, recurrent_to_cell_weights->type, |
930 | input_to_forget_weights->type); |
931 | |
  // We make sure the input-gate's parameters are either both present (regular
  // LSTM) or both absent (CIFG-LSTM).
934 | const bool cifg_weights_all_or_none = |
935 | ((input_to_input_weights != nullptr) && |
936 | (recurrent_to_input_weights != nullptr)) || |
937 | ((input_to_input_weights == nullptr) && |
938 | (recurrent_to_input_weights == nullptr)); |
939 | TF_LITE_ENSURE(context, cifg_weights_all_or_none == true); |
940 | |
941 | const TfLiteTensor* cell_to_input_weights = |
942 | GetOptionalInputTensor(context, node, kCellToInputWeightsTensor); |
943 | if (cell_to_input_weights) { |
944 | TF_LITE_ENSURE_EQ(context, cell_to_input_weights->dims->size, 1); |
945 | TF_LITE_ENSURE_EQ(context, cell_to_input_weights->dims->data[0], n_cell); |
946 | TF_LITE_ENSURE_TYPES_EQ( |
947 | context, cell_to_input_weights->type, |
948 | is_integer ? kTfLiteInt16 : input_to_forget_weights->type); |
949 | } |
950 | |
951 | const TfLiteTensor* cell_to_forget_weights = |
952 | GetOptionalInputTensor(context, node, kCellToForgetWeightsTensor); |
953 | if (cell_to_forget_weights) { |
954 | TF_LITE_ENSURE_EQ(context, cell_to_forget_weights->dims->size, 1); |
955 | TF_LITE_ENSURE_EQ(context, cell_to_forget_weights->dims->data[0], n_cell); |
956 | TF_LITE_ENSURE_TYPES_EQ( |
957 | context, cell_to_forget_weights->type, |
958 | is_integer ? kTfLiteInt16 : input_to_forget_weights->type); |
959 | } |
960 | |
961 | const TfLiteTensor* cell_to_output_weights = |
962 | GetOptionalInputTensor(context, node, kCellToOutputWeightsTensor); |
963 | if (cell_to_output_weights) { |
964 | TF_LITE_ENSURE_EQ(context, cell_to_output_weights->dims->size, 1); |
965 | TF_LITE_ENSURE_EQ(context, cell_to_output_weights->dims->data[0], n_cell); |
966 | TF_LITE_ENSURE_TYPES_EQ( |
967 | context, cell_to_output_weights->type, |
968 | is_integer ? kTfLiteInt16 : input_to_forget_weights->type); |
969 | } |
970 | |
  // Make sure the peephole weights are either all present or all absent.
972 | const bool peephole_weights_all_or_none = |
973 | ((cell_to_input_weights != nullptr || use_cifg) && |
974 | (cell_to_forget_weights != nullptr) && |
975 | (cell_to_output_weights != nullptr)) || |
976 | ((cell_to_input_weights == nullptr) && |
977 | (cell_to_forget_weights == nullptr) && |
978 | (cell_to_output_weights == nullptr)); |
979 | TF_LITE_ENSURE(context, peephole_weights_all_or_none == true); |
980 | |
981 | // Make sure the input gate bias is present only when not a CIFG-LSTM. |
982 | const TfLiteTensor* input_gate_bias = |
983 | GetOptionalInputTensor(context, node, kInputGateBiasTensor); |
984 | if (use_cifg) { |
985 | TF_LITE_ENSURE_EQ(context, input_gate_bias, nullptr); |
986 | } else { |
987 | TF_LITE_ENSURE_EQ(context, input_gate_bias->dims->size, 1); |
988 | TF_LITE_ENSURE_EQ(context, input_gate_bias->dims->data[0], n_cell); |
989 | if (is_integer) { |
990 | TF_LITE_ENSURE_TYPES_EQ(context, input_gate_bias->type, kTfLiteInt32); |
991 | } else { |
992 | TF_LITE_ENSURE_TYPES_EQ(context, input_gate_bias->type, kTfLiteFloat32); |
993 | } |
994 | } |
995 | |
996 | const TfLiteTensor* forget_gate_bias; |
997 | TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kForgetGateBiasTensor, |
998 | &forget_gate_bias)); |
999 | TF_LITE_ENSURE_EQ(context, forget_gate_bias->dims->size, 1); |
1000 | TF_LITE_ENSURE_EQ(context, forget_gate_bias->dims->data[0], n_cell); |
1001 | if (is_integer) { |
1002 | TF_LITE_ENSURE_TYPES_EQ(context, forget_gate_bias->type, kTfLiteInt32); |
1003 | } else { |
1004 | TF_LITE_ENSURE_TYPES_EQ(context, forget_gate_bias->type, kTfLiteFloat32); |
1005 | } |
1006 | |
1007 | const TfLiteTensor* cell_gate_bias; |
1008 | TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kCellGateBiasTensor, |
1009 | &cell_gate_bias)); |
1010 | TF_LITE_ENSURE_EQ(context, cell_gate_bias->dims->size, 1); |
1011 | TF_LITE_ENSURE_EQ(context, cell_gate_bias->dims->data[0], n_cell); |
1012 | if (is_integer) { |
1013 | TF_LITE_ENSURE_TYPES_EQ(context, cell_gate_bias->type, kTfLiteInt32); |
1014 | } else { |
1015 | TF_LITE_ENSURE_TYPES_EQ(context, cell_gate_bias->type, kTfLiteFloat32); |
1016 | } |
1017 | |
1018 | const TfLiteTensor* output_gate_bias; |
1019 | TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kOutputGateBiasTensor, |
1020 | &output_gate_bias)); |
1021 | TF_LITE_ENSURE_EQ(context, output_gate_bias->dims->size, 1); |
1022 | TF_LITE_ENSURE_EQ(context, output_gate_bias->dims->data[0], n_cell); |
1023 | if (is_integer) { |
1024 | TF_LITE_ENSURE_TYPES_EQ(context, output_gate_bias->type, kTfLiteInt32); |
1025 | } else { |
1026 | TF_LITE_ENSURE_TYPES_EQ(context, output_gate_bias->type, kTfLiteFloat32); |
1027 | } |
1028 | |
1029 | const TfLiteTensor* projection_weights = |
1030 | GetOptionalInputTensor(context, node, kProjectionWeightsTensor); |
1031 | if (projection_weights != nullptr) { |
1032 | TF_LITE_ENSURE_EQ(context, projection_weights->dims->size, 2); |
1033 | TF_LITE_ENSURE_EQ(context, projection_weights->dims->data[0], n_output); |
1034 | TF_LITE_ENSURE_EQ(context, projection_weights->dims->data[1], n_cell); |
1035 | TF_LITE_ENSURE_TYPES_EQ(context, projection_weights->type, |
1036 | input_to_forget_weights->type); |
1037 | } |
1038 | |
1039 | const TfLiteTensor* projection_bias = |
1040 | GetOptionalInputTensor(context, node, kProjectionBiasTensor); |
1041 | if (projection_bias != nullptr) { |
1042 | TF_LITE_ENSURE_EQ(context, projection_bias->dims->size, 1); |
1043 | TF_LITE_ENSURE_EQ(context, projection_bias->dims->data[0], n_output); |
1044 | if (is_integer) { |
1045 | TF_LITE_ENSURE_TYPES_EQ(context, projection_bias->type, kTfLiteInt32); |
1046 | } else { |
1047 | TF_LITE_ENSURE_TYPES_EQ(context, projection_bias->type, kTfLiteFloat32); |
1048 | } |
1049 | } |
1050 | |
  // Make sure the projection tensors are consistent:
1052 | // 1) If projection weight is not present, then projection bias should not be |
1053 | // present. |
1054 | // 2) If projection weight is present, then projection bias is optional. |
1055 | // TODO(ghodrat): make sure this is correct. |
1056 | const bool projection_tensors_consistent = |
1057 | ((projection_weights != nullptr) || (projection_bias == nullptr)); |
1058 | TF_LITE_ENSURE(context, projection_tensors_consistent == true); |
1059 | |
1060 | if (use_layer_norm) { |
1061 | const TfLiteTensor* input_layer_norm_coefficients = GetOptionalInputTensor( |
1062 | context, node, kInputLayerNormCoefficientsTensor); |
1063 | if (use_cifg) { |
1064 | TF_LITE_ENSURE_EQ(context, input_layer_norm_coefficients, nullptr); |
1065 | } else { |
1066 | TF_LITE_ENSURE(context, input_layer_norm_coefficients != nullptr); |
1067 | TF_LITE_ENSURE_EQ(context, input_layer_norm_coefficients->dims->size, 1); |
1068 | TF_LITE_ENSURE_EQ(context, input_layer_norm_coefficients->dims->data[0], |
1069 | n_cell); |
1070 | if (is_integer) { |
1071 | TF_LITE_ENSURE_TYPES_EQ(context, input_layer_norm_coefficients->type, |
1072 | kTfLiteInt16); |
1073 | } else { |
1074 | TF_LITE_ENSURE_TYPES_EQ(context, input_layer_norm_coefficients->type, |
1075 | kTfLiteFloat32); |
1076 | } |
1077 | } |
1078 | |
1079 | const TfLiteTensor* forget_layer_norm_coefficients = GetOptionalInputTensor( |
1080 | context, node, kForgetLayerNormCoefficientsTensor); |
1081 | TF_LITE_ENSURE(context, forget_layer_norm_coefficients != nullptr); |
1082 | TF_LITE_ENSURE_EQ(context, forget_layer_norm_coefficients->dims->size, 1); |
1083 | TF_LITE_ENSURE_EQ(context, forget_layer_norm_coefficients->dims->data[0], |
1084 | n_cell); |
1085 | if (is_integer) { |
1086 | TF_LITE_ENSURE_TYPES_EQ(context, forget_layer_norm_coefficients->type, |
1087 | kTfLiteInt16); |
1088 | } else { |
1089 | TF_LITE_ENSURE_TYPES_EQ(context, forget_layer_norm_coefficients->type, |
1090 | kTfLiteFloat32); |
1091 | } |
1092 | |
1093 | const TfLiteTensor* cell_layer_norm_coefficients = |
1094 | GetOptionalInputTensor(context, node, kCellLayerNormCoefficientsTensor); |
1095 | TF_LITE_ENSURE(context, cell_layer_norm_coefficients != nullptr); |
1096 | TF_LITE_ENSURE_EQ(context, cell_layer_norm_coefficients->dims->size, 1); |
1097 | TF_LITE_ENSURE_EQ(context, cell_layer_norm_coefficients->dims->data[0], |
1098 | n_cell); |
1099 | if (is_integer) { |
1100 | TF_LITE_ENSURE_TYPES_EQ(context, cell_layer_norm_coefficients->type, |
1101 | kTfLiteInt16); |
1102 | } else { |
1103 | TF_LITE_ENSURE_TYPES_EQ(context, cell_layer_norm_coefficients->type, |
1104 | kTfLiteFloat32); |
1105 | } |
1106 | |
1107 | const TfLiteTensor* output_layer_norm_coefficients = GetOptionalInputTensor( |
1108 | context, node, kOutputLayerNormCoefficientsTensor); |
1109 | TF_LITE_ENSURE(context, output_layer_norm_coefficients != nullptr); |
1110 | TF_LITE_ENSURE_EQ(context, output_layer_norm_coefficients->dims->size, 1); |
1111 | TF_LITE_ENSURE_EQ(context, output_layer_norm_coefficients->dims->data[0], |
1112 | n_cell); |
1113 | if (is_integer) { |
1114 | TF_LITE_ENSURE_TYPES_EQ(context, output_layer_norm_coefficients->type, |
1115 | kTfLiteInt16); |
1116 | } else { |
1117 | TF_LITE_ENSURE_TYPES_EQ(context, output_layer_norm_coefficients->type, |
1118 | kTfLiteFloat32); |
1119 | } |
1120 | } |
1121 | |
1122 | return kTfLiteOk; |
1123 | } |
1124 | // LINT.ThenChange(//tensorflow/lite/tools/optimize/calibration/builtin_logging_ops/lstm.cc) |
1125 | |
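// Precomputes the zero-point correction for an asymmetric-input matmul. For a
// weight row w,
//   (q_x - zp_x) . w = q_x . w - zp_x * sum(w),
// so bias - zp_x * sum(w) can be folded into a per-row effective bias at
// prepare time; the callers below pass the negated zero point for this
// reason.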
1126 | TfLiteStatus PrecomputeZeroPointTimesWeightWithBias( |
1127 | TfLiteContext* context, int32_t zero_point, |
1128 | const TfLiteTensor* weight_tensor, const TfLiteTensor* bias_tensor, |
1129 | std::unique_ptr<int32_t[]>* output) { |
1130 | if (weight_tensor == nullptr) { |
1131 | return kTfLiteOk; |
1132 | } |
1133 | |
1134 | const RuntimeShape& weight_shape = GetTensorShape(weight_tensor); |
1135 | TF_LITE_ENSURE_EQ(context, weight_shape.DimensionsCount(), 2); |
1136 | const int row = weight_shape.Dims(0); |
1137 | const int col = weight_shape.Dims(1); |
1138 | output->reset(new int32_t[row]); |
1139 | if (bias_tensor == nullptr) { |
1140 | memset(output->get(), 0, row * sizeof(int32_t)); |
1141 | } else { |
1142 | const int32_t* bias = GetTensorData<int32_t>(bias_tensor); |
1143 | memcpy(output->get(), bias, row * sizeof(int32_t)); |
1144 | } |
1145 | if (zero_point != 0) { |
1146 | const int8_t* weight = GetTensorData<int8_t>(weight_tensor); |
1147 | tensor_utils::MatrixScalarMultiplyAccumulate(weight, zero_point, row, col, |
1148 | output->get()); |
1149 | } |
1150 | return kTfLiteOk; |
1151 | } |
1152 | |
1153 | TfLiteStatus PopulatePrecomputedZPTimesWeightsWithBias(TfLiteContext* context, |
1154 | OpData* op_data, |
1155 | TfLiteNode* node) { |
1156 | const TfLiteTensor* input; |
1157 | TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kInputTensor, &input)); |
1158 | const TfLiteTensor* output_state = |
1159 | GetVariableInput(context, node, kOutputStateTensor); |
1160 | TF_LITE_ENSURE(context, output_state != nullptr); |
1161 | |
1162 | const int32_t input_zero_point = -input->params.zero_point; |
1163 | const int32_t output_state_zero_point = -output_state->params.zero_point; |
1164 | |
1165 | const TfLiteTensor* input_to_input_weights = |
1166 | GetOptionalInputTensor(context, node, kInputToInputWeightsTensor); |
1167 | const TfLiteTensor* input_to_forget_weights; |
1168 | TF_LITE_ENSURE_OK(context, |
1169 | GetInputSafe(context, node, kInputToForgetWeightsTensor, |
1170 | &input_to_forget_weights)); |
1171 | const TfLiteTensor* input_to_cell_weights; |
1172 | TF_LITE_ENSURE_OK(context, |
1173 | GetInputSafe(context, node, kInputToCellWeightsTensor, |
1174 | &input_to_cell_weights)); |
1175 | const TfLiteTensor* input_to_output_weights; |
1176 | TF_LITE_ENSURE_OK(context, |
1177 | GetInputSafe(context, node, kInputToOutputWeightsTensor, |
1178 | &input_to_output_weights)); |
1179 | |
1180 | const TfLiteTensor* recurrent_to_input_weights = |
1181 | GetOptionalInputTensor(context, node, kRecurrentToInputWeightsTensor); |
1182 | const TfLiteTensor* recurrent_to_forget_weights; |
1183 | TF_LITE_ENSURE_OK(context, |
1184 | GetInputSafe(context, node, kRecurrentToForgetWeightsTensor, |
1185 | &recurrent_to_forget_weights)); |
1186 | const TfLiteTensor* recurrent_to_cell_weights; |
1187 | TF_LITE_ENSURE_OK(context, |
1188 | GetInputSafe(context, node, kRecurrentToCellWeightsTensor, |
1189 | &recurrent_to_cell_weights)); |
1190 | const TfLiteTensor* recurrent_to_output_weights; |
1191 | TF_LITE_ENSURE_OK(context, |
1192 | GetInputSafe(context, node, kRecurrentToOutputWeightsTensor, |
1193 | &recurrent_to_output_weights)); |
1194 | |
1195 | const TfLiteTensor* projection_weights = |
1196 | GetOptionalInputTensor(context, node, kProjectionWeightsTensor); |
1197 | const TfLiteTensor* projection_bias = |
1198 | GetOptionalInputTensor(context, node, kProjectionBiasTensor); |
1199 | |
1200 | lstm_eval::IntegerLstmParameter* integer_lstm_params = |
1201 | &op_data->integer_lstm_param; |
1202 | |
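  // Intermediate tensor 4 carries the quantization parameters of the hidden
  // state (the input to the projection); its zero point is needed to
  // precompute the projection's effective bias.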
1203 | const TfLiteTensor* intermediate = |
1204 | &context->tensors[node->intermediates->data[4]]; |
1205 | const auto* params = |
1206 | static_cast<TfLiteAffineQuantization*>(intermediate->quantization.params); |
1207 | const int32_t hidden_zp = params->zero_point->data[0]; |
1208 | |
  // Get the biases and perform the zero-point calculations.
  // With layer normalization the gate bias is applied after the
  // normalization rather than inside the matmul:
  //   y = layer_norm(w * x + w * r + w * c) + b,
  // so it must not be folded into the precomputed effective bias; the gate
  // biases are therefore passed as nullptr below in that case.
1213 | const bool is_layer_norm = op_data->use_layer_norm; |
1214 | |
1215 | // Forget gate. |
1216 | const TfLiteTensor* forget_gate_bias = |
1217 | is_layer_norm ? nullptr : GetInput(context, node, kForgetGateBiasTensor); |
1218 | TF_LITE_ENSURE_OK( |
1219 | context, |
1220 | PrecomputeZeroPointTimesWeightWithBias( |
1221 | context, input_zero_point, input_to_forget_weights, forget_gate_bias, |
1222 | &(integer_lstm_params->input_to_forget_effective_bias))); |
1223 | |
1224 | TF_LITE_ENSURE_OK( |
1225 | context, |
1226 | PrecomputeZeroPointTimesWeightWithBias( |
1227 | context, output_state_zero_point, recurrent_to_forget_weights, |
1228 | nullptr, &(integer_lstm_params->recurrent_to_forget_effective_bias))); |
1229 | |
  // Modulation (cell) gate.
1231 | const TfLiteTensor* cell_gate_bias = |
1232 | is_layer_norm ? nullptr : GetInput(context, node, kCellGateBiasTensor); |
1233 | TF_LITE_ENSURE_OK( |
1234 | context, |
1235 | PrecomputeZeroPointTimesWeightWithBias( |
1236 | context, input_zero_point, input_to_cell_weights, cell_gate_bias, |
1237 | &(integer_lstm_params->input_to_cell_effective_bias))); |
1238 | TF_LITE_ENSURE_OK( |
1239 | context, |
1240 | PrecomputeZeroPointTimesWeightWithBias( |
1241 | context, output_state_zero_point, recurrent_to_cell_weights, nullptr, |
1242 | &(integer_lstm_params->recurrent_to_cell_effective_bias))); |
1243 | |
1244 | // Output gate. |
1245 | const TfLiteTensor* output_gate_bias = |
1246 | is_layer_norm ? nullptr : GetInput(context, node, kOutputGateBiasTensor); |
1247 | TF_LITE_ENSURE_OK( |
1248 | context, |
1249 | PrecomputeZeroPointTimesWeightWithBias( |
1250 | context, input_zero_point, input_to_output_weights, output_gate_bias, |
1251 | &(integer_lstm_params->input_to_output_effective_bias))); |
1252 | |
1253 | TF_LITE_ENSURE_OK( |
1254 | context, |
1255 | PrecomputeZeroPointTimesWeightWithBias( |
1256 | context, output_state_zero_point, recurrent_to_output_weights, |
1257 | nullptr, &(integer_lstm_params->recurrent_to_output_effective_bias))); |
1258 | |
  // Input gate. The calculation is only meaningful for the non-CIFG case.
1260 | const TfLiteTensor* input_gate_bias = |
1261 | is_layer_norm ? nullptr : GetInput(context, node, kInputGateBiasTensor); |
1262 | TF_LITE_ENSURE_OK( |
1263 | context, |
1264 | PrecomputeZeroPointTimesWeightWithBias( |
1265 | context, input_zero_point, input_to_input_weights, input_gate_bias, |
1266 | &(integer_lstm_params->input_to_input_effective_bias))); |
1267 | TF_LITE_ENSURE_OK( |
1268 | context, |
1269 | PrecomputeZeroPointTimesWeightWithBias( |
1270 | context, output_state_zero_point, recurrent_to_input_weights, nullptr, |
1271 | &(integer_lstm_params->recurrent_to_input_effective_bias))); |
1272 | |
  // Projection bias. The calculation is only meaningful when projection is
  // used.
1274 | TF_LITE_ENSURE_OK(context, |
1275 | PrecomputeZeroPointTimesWeightWithBias( |
1276 | context, hidden_zp, projection_weights, projection_bias, |
1277 | &(integer_lstm_params->projection_effective_bias))); |
1278 | return kTfLiteOk; |
1279 | } |
1280 | |
// Resize the output and state tensors based on the sizes of the input
// tensors. Allocate a temporary scratch tensor. Also check that the sizes of
// the input tensors match each other.
1284 | // LINT.IfChange |
1285 | TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { |
1286 | OpData* op_data = static_cast<OpData*>(node->user_data); |
1287 | |
1288 | TF_LITE_ENSURE_EQ(context, node->outputs->size, 1); |
  // Logic for determining whether this is a regular or a layer-norm LSTM:
  //   inputs->size   forget_layer_norm_coefficients (input 20)   is_layer_norm
  //   20             N/A                                         No
  //   24             null                                        No
  //   24             not null                                    Yes
  // The 20-input LSTM is deprecated and only kept here for backward
  // compatibility.
1296 | if (node->inputs->size == 24) { |
1297 | const TfLiteTensor* forget_layer_norm_coefficients = GetOptionalInputTensor( |
1298 | context, node, kForgetLayerNormCoefficientsTensor); |
    op_data->use_layer_norm = (forget_layer_norm_coefficients != nullptr);
1304 | } else if (node->inputs->size == 20) { |
1305 | // This is deprecated and is only kept here for backward compatibility. |
1306 | op_data->use_layer_norm = false; |
1307 | } else { |
1308 | TF_LITE_KERNEL_LOG( |
1309 | context, "The LSTM Full kernel expects 20 or 24 inputs. Got %d inputs" , |
1310 | node->inputs->size); |
1311 | return kTfLiteError; |
1312 | } |
1313 | |
1314 | const bool use_layer_norm = op_data->use_layer_norm; |
1315 | |
1316 | // Inferring batch size, number of outputs and number of cells from the |
1317 | // input tensors. |
1318 | const TfLiteTensor* input; |
1319 | TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kInputTensor, &input)); |
1320 | const bool is_integer = input->type == kTfLiteInt8; |
1321 | TF_LITE_ENSURE(context, input->dims->size > 1); |
1322 | const int n_batch = input->dims->data[0]; |
1323 | const int n_input = input->dims->data[1]; |
1324 | |
1325 | const TfLiteTensor* input_to_output_weights; |
1326 | TF_LITE_ENSURE_OK(context, |
1327 | GetInputSafe(context, node, kInputToOutputWeightsTensor, |
1328 | &input_to_output_weights)); |
1329 | const int n_cell = input_to_output_weights->dims->data[0]; |
1330 | TF_LITE_ENSURE_EQ(context, input_to_output_weights->dims->size, 2); |
1331 | TF_LITE_ENSURE_EQ(context, input_to_output_weights->dims->data[1], n_input); |
1332 | |
1333 | const TfLiteTensor* recurrent_to_output_weights; |
1334 | TF_LITE_ENSURE_OK(context, |
1335 | GetInputSafe(context, node, kRecurrentToOutputWeightsTensor, |
1336 | &recurrent_to_output_weights)); |
1337 | TF_LITE_ENSURE_EQ(context, recurrent_to_output_weights->dims->size, 2); |
1338 | TF_LITE_ENSURE_EQ(context, recurrent_to_output_weights->dims->data[0], |
1339 | n_cell); |
1340 | const int n_output = recurrent_to_output_weights->dims->data[1]; |
1341 | |
  // Check that the input tensor dimensions match each other.
1343 | TF_LITE_ENSURE_OK( |
1344 | context, CheckInputTensorDimensions(context, node, n_input, n_output, |
1345 | n_cell, use_layer_norm, is_integer)); |
1346 | |
  // Get pointers to the output, output_state and cell_state tensors.
1348 | TfLiteTensor* output; |
1349 | TF_LITE_ENSURE_OK(context, |
1350 | GetOutputSafe(context, node, kOutputTensor, &output)); |
1351 | |
1352 | TfLiteTensor* output_state = |
1353 | GetVariableInput(context, node, kOutputStateTensor); |
1354 | TF_LITE_ENSURE(context, output_state != nullptr); |
1355 | TfLiteTensor* cell_state = GetVariableInput(context, node, kCellStateTensor); |
1356 | TF_LITE_ENSURE(context, cell_state != nullptr); |
1357 | |
  // Check the shape of the input state tensors.
  // These tensors may be 1D or 2D. That is fine as long as the total size is
  // correct.
1361 | TF_LITE_ENSURE_EQ(context, NumElements(output_state), n_batch * n_output); |
1362 | TF_LITE_ENSURE_EQ(context, NumElements(cell_state), n_batch * n_cell); |
1363 | |
  // Resize the output tensor.
1365 | TfLiteIntArray* output_size = TfLiteIntArrayCreate(2); |
1366 | output_size->data[0] = n_batch; |
1367 | output_size->data[1] = n_output; |
1368 | TF_LITE_ENSURE_OK(context, |
1369 | context->ResizeTensor(context, output, output_size)); |
1370 | |
1371 | // The weights are of consistent type, so it suffices to check one. |
1372 | const bool is_hybrid_op = IsHybridOp(input, input_to_output_weights); |
1373 | |
1374 | const bool is_sparse_op = (input_to_output_weights->sparsity != nullptr); |
1375 | |
  // Determine which integer LSTM variant this is from the number of
  // intermediate tensors.
1377 | const int num_intermediate_tensors = node->intermediates->size; |
1378 | if (is_integer) { |
1379 | TF_LITE_ENSURE(context, num_intermediate_tensors == 5 || |
1380 | num_intermediate_tensors == 12); |
1381 | } |
  // We use the number of intermediate tensors to distinguish the 16-bit
  // matmul output version (5 intermediates) from the 8-bit one (12).
1384 | const bool is_8x8_16 = num_intermediate_tensors == 5; |
1385 | |
1386 | TfLiteIntArrayFree(node->temporaries); |
1387 | if (is_hybrid_op) { |
1388 | if (is_sparse_op) { |
1389 | node->temporaries = |
1390 | TfLiteIntArrayCreate(kNumHybridTemporaryTensors + kLedgersToAdd); |
1391 | } else { |
1392 | node->temporaries = TfLiteIntArrayCreate(kNumHybridTemporaryTensors); |
1393 | } |
1394 | } else if (is_integer) { |
1395 | if (is_8x8_16) { |
1396 | node->temporaries = TfLiteIntArrayCreate(6); |
1397 | } else { |
1398 | node->temporaries = TfLiteIntArrayCreate(8); |
1399 | } |
1400 | } else { |
1401 | node->temporaries = TfLiteIntArrayCreate(1); |
1402 | } |
1403 | |
  // Create a scratch buffer tensor for the float and hybrid cases.
  // TODO(b/152066492): Create an is_float boolean and reorganize the temporary
  // buffer allocation logic.
1407 | if (!is_integer) { |
1408 | node->temporaries->data[kScratchBuffer] = |
1409 | op_data->scratch_tensor_index + kScratchBuffer; |
1410 | TfLiteTensor* scratch_buffer; |
1411 | TF_LITE_ENSURE_OK(context, GetTemporarySafe(context, node, kScratchBuffer, |
1412 | &scratch_buffer)); |
1413 | scratch_buffer->type = input->type; |
1414 | scratch_buffer->allocation_type = kTfLiteArenaRw; |
1415 | |
1416 | const TfLiteTensor* input_to_input_weights = |
1417 | GetOptionalInputTensor(context, node, kInputToInputWeightsTensor); |
1418 | const bool use_cifg = (input_to_input_weights == nullptr); |
1419 | TfLiteIntArray* scratch_buffer_size = TfLiteIntArrayCreate(2); |
1420 | scratch_buffer_size->data[0] = n_batch; |
1421 | if (use_cifg) { |
      // Reserve space for the Cell, Forget and Output gate activations plus a
      // scratch accumulation buffer: 4 * n_cell per batch.
1424 | scratch_buffer_size->data[1] = n_cell * 4; |
1425 | } else { |
      // Reserve space for the Input, Cell, Forget and Output gate activations
      // plus a scratch accumulation buffer: 5 * n_cell per batch.
1428 | scratch_buffer_size->data[1] = n_cell * 5; |
1429 | } |
1430 | TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, scratch_buffer, |
1431 | scratch_buffer_size)); |
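    // For example, with n_batch = 2 and n_cell = 10, a non-CIFG LSTM gets a
    // 2 x 50 float scratch buffer here: four gate buffers plus the
    // accumulation buffer per batch element.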
1432 | } |
1433 | |
1434 | if (is_hybrid_op) { |
1435 | if (!is_sparse_op) { |
1436 | op_data->compute_row_sums = true; |
1437 | } |
    // Allocate temporary tensors to store the quantized values of the input,
    // output_state and cell_state tensors.
1440 | node->temporaries->data[kInputQuantized] = |
1441 | op_data->scratch_tensor_index + kInputQuantized; |
1442 | TfLiteTensor* input_quantized; |
1443 | TF_LITE_ENSURE_OK(context, GetTemporarySafe(context, node, kInputQuantized, |
1444 | &input_quantized)); |
1445 | input_quantized->type = input_to_output_weights->type; |
1446 | input_quantized->allocation_type = kTfLiteArenaRw; |
1447 | if (!TfLiteIntArrayEqual(input_quantized->dims, input->dims)) { |
1448 | TfLiteIntArray* input_quantized_size = TfLiteIntArrayCopy(input->dims); |
1449 | TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, input_quantized, |
1450 | input_quantized_size)); |
1451 | } |
1452 | node->temporaries->data[kOutputStateQuantized] = |
1453 | op_data->scratch_tensor_index + kOutputStateQuantized; |
1454 | TfLiteTensor* output_state_quantized; |
1455 | TF_LITE_ENSURE_OK(context, |
1456 | GetTemporarySafe(context, node, kOutputStateQuantized, |
1457 | &output_state_quantized)); |
1458 | output_state_quantized->type = input_to_output_weights->type; |
1459 | output_state_quantized->allocation_type = kTfLiteArenaRw; |
1460 | if (!TfLiteIntArrayEqual(output_state_quantized->dims, |
1461 | output_state->dims)) { |
1462 | TfLiteIntArray* output_state_quantized_size = |
1463 | TfLiteIntArrayCopy(output_state->dims); |
1464 | TF_LITE_ENSURE_OK(context, |
1465 | context->ResizeTensor(context, output_state_quantized, |
1466 | output_state_quantized_size)); |
1467 | } |
1468 | node->temporaries->data[kCellStateQuantized] = |
1469 | op_data->scratch_tensor_index + kCellStateQuantized; |
1470 | TfLiteTensor* cell_state_quantized; |
1471 | TF_LITE_ENSURE_OK(context, |
1472 | GetTemporarySafe(context, node, kCellStateQuantized, |
1473 | &cell_state_quantized)); |
1474 | cell_state_quantized->type = input_to_output_weights->type; |
1475 | cell_state_quantized->allocation_type = kTfLiteArenaRw; |
1476 | if (!TfLiteIntArrayEqual(cell_state_quantized->dims, cell_state->dims)) { |
1477 | TfLiteIntArray* cell_state_quantized_size = |
1478 | TfLiteIntArrayCopy(cell_state->dims); |
1479 | TF_LITE_ENSURE_OK(context, |
1480 | context->ResizeTensor(context, cell_state_quantized, |
1481 | cell_state_quantized_size)); |
1482 | } |
    // Allocate temporary tensors to store the scaling factors and product
    // scaling factors. The latter is convenience storage that allows a vector
    // to be quantized once (producing the scaling factors) and then multiplied
    // with different matrices (which requires multiplying those scaling
    // factors by the scaling factor of each matrix).
1488 | node->temporaries->data[kInputScalingFactors] = |
1489 | op_data->scratch_tensor_index + kInputScalingFactors; |
1490 | TfLiteTensor* input_sf; |
1491 | TF_LITE_ENSURE_OK( |
1492 | context, |
1493 | GetTemporarySafe(context, node, kInputScalingFactors, &input_sf)); |
1494 | input_sf->type = kTfLiteFloat32; |
1495 | input_sf->allocation_type = kTfLiteArenaRw; |
1496 | int scaling_dims[1] = {n_batch}; |
1497 | if (!TfLiteIntArrayEqualsArray(input_sf->dims, 1, scaling_dims)) { |
1498 | TfLiteIntArray* input_sf_size = TfLiteIntArrayCreate(1); |
1499 | input_sf_size->data[0] = n_batch; |
1500 | TF_LITE_ENSURE_OK( |
1501 | context, context->ResizeTensor(context, input_sf, input_sf_size)); |
1502 | } |
1503 | node->temporaries->data[kOutputStateScalingFactors] = |
1504 | op_data->scratch_tensor_index + kOutputStateScalingFactors; |
1505 | TfLiteTensor* output_state_sf; |
1506 | TF_LITE_ENSURE_OK( |
1507 | context, GetTemporarySafe(context, node, kOutputStateScalingFactors, |
1508 | &output_state_sf)); |
1509 | output_state_sf->type = kTfLiteFloat32; |
1510 | output_state_sf->allocation_type = kTfLiteArenaRw; |
1511 | if (!TfLiteIntArrayEqualsArray(output_state_sf->dims, 1, scaling_dims)) { |
1512 | TfLiteIntArray* output_state_sf_size = TfLiteIntArrayCreate(1); |
1513 | output_state_sf_size->data[0] = n_batch; |
1514 | TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, output_state_sf, |
1515 | output_state_sf_size)); |
1516 | } |
1517 | node->temporaries->data[kProductScalingFactors] = |
1518 | op_data->scratch_tensor_index + kProductScalingFactors; |
1519 | TfLiteTensor* prod_scaling_factors; |
1520 | TF_LITE_ENSURE_OK(context, |
1521 | GetTemporarySafe(context, node, kProductScalingFactors, |
1522 | &prod_scaling_factors)); |
1523 | prod_scaling_factors->type = kTfLiteFloat32; |
1524 | prod_scaling_factors->allocation_type = kTfLiteArenaRw; |
1525 | if (!TfLiteIntArrayEqualsArray(prod_scaling_factors->dims, 1, |
1526 | scaling_dims)) { |
1527 | TfLiteIntArray* prod_scaling_factors_size = TfLiteIntArrayCreate(1); |
1528 | prod_scaling_factors_size->data[0] = n_batch; |
1529 | TF_LITE_ENSURE_OK(context, |
1530 | context->ResizeTensor(context, prod_scaling_factors, |
1531 | prod_scaling_factors_size)); |
1532 | } |
1533 | |
    // Allocate a temporary tensor to store the recovered cell weights. Since
    // these are used for diagonal matrices, only n_cell values are needed.
1536 | node->temporaries->data[kRecoveredCellWeights] = |
1537 | op_data->scratch_tensor_index + kRecoveredCellWeights; |
1538 | TfLiteTensor* recovered_cell_weights; |
1539 | TF_LITE_ENSURE_OK(context, |
1540 | GetTemporarySafe(context, node, kRecoveredCellWeights, |
1541 | &recovered_cell_weights)); |
1542 | recovered_cell_weights->type = kTfLiteFloat32; |
1543 | recovered_cell_weights->allocation_type = kTfLiteArenaRw; |
1544 | int recovered_cell_dims[1] = {n_cell}; |
1545 | if (!TfLiteIntArrayEqualsArray(recovered_cell_weights->dims, 1, |
1546 | recovered_cell_dims)) { |
1547 | TfLiteIntArray* recovered_cell_weights_size = TfLiteIntArrayCreate(1); |
1548 | recovered_cell_weights_size->data[0] = n_cell; |
1549 | TF_LITE_ENSURE_OK(context, |
1550 | context->ResizeTensor(context, recovered_cell_weights, |
1551 | recovered_cell_weights_size)); |
1552 | } |
    // Allocate a temporary tensor to store the accumulated values from the
    // matrix multiplications before they are multiplied by the scaling
    // factors.
1555 | node->temporaries->data[kAccumScratch] = |
1556 | op_data->scratch_tensor_index + kAccumScratch; |
1557 | TfLiteTensor* accum_scratch; |
1558 | TF_LITE_ENSURE_OK(context, GetTemporarySafe(context, node, kAccumScratch, |
1559 | &accum_scratch)); |
1560 | accum_scratch->type = kTfLiteInt32; |
1561 | accum_scratch->allocation_type = kTfLiteArenaRw; |
1562 | int accum_scratch_dims[2] = {n_cell, n_batch}; |
1563 | if (!TfLiteIntArrayEqualsArray(accum_scratch->dims, 2, |
1564 | accum_scratch_dims)) { |
1565 | TfLiteIntArray* accum_size = TfLiteIntArrayCreate(2); |
1566 | accum_size->data[0] = n_cell; |
1567 | accum_size->data[1] = n_batch; |
1568 | TF_LITE_ENSURE_OK( |
1569 | context, context->ResizeTensor(context, accum_scratch, accum_size)); |
1570 | } |
1571 | node->temporaries->data[kInputZeroPoints] = |
1572 | op_data->scratch_tensor_index + kInputZeroPoints; |
1573 | TfLiteTensor* input_zp; |
1574 | TF_LITE_ENSURE_OK( |
1575 | context, GetTemporarySafe(context, node, kInputZeroPoints, &input_zp)); |
1576 | input_zp->type = kTfLiteFloat32; |
1577 | input_zp->allocation_type = kTfLiteArenaRw; |
1578 | if (!TfLiteIntArrayEqualsArray(input_zp->dims, 1, scaling_dims)) { |
1579 | TfLiteIntArray* input_zp_size = TfLiteIntArrayCreate(1); |
1580 | input_zp_size->data[0] = n_batch; |
1581 | TF_LITE_ENSURE_OK( |
1582 | context, context->ResizeTensor(context, input_zp, input_zp_size)); |
1583 | } |
1584 | node->temporaries->data[kOutputStateZeroPoints] = |
1585 | op_data->scratch_tensor_index + kOutputStateZeroPoints; |
1586 | TfLiteTensor* output_state_zp; |
1587 | TF_LITE_ENSURE_OK(context, |
1588 | GetTemporarySafe(context, node, kOutputStateZeroPoints, |
1589 | &output_state_zp)); |
1590 | output_state_zp->type = kTfLiteFloat32; |
1591 | output_state_zp->allocation_type = kTfLiteArenaRw; |
1592 | if (!TfLiteIntArrayEqualsArray(output_state_zp->dims, 1, scaling_dims)) { |
1593 | TfLiteIntArray* output_state_zp_size = TfLiteIntArrayCreate(1); |
1594 | output_state_zp_size->data[0] = n_batch; |
1595 | TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, output_state_zp, |
1596 | output_state_zp_size)); |
1597 | } |
1598 | |
1599 | node->temporaries->data[kRowSums] = |
1600 | op_data->scratch_tensor_index + kRowSums; |
1601 | const TfLiteTensor* input_to_input_weights = |
1602 | GetOptionalInputTensor(context, node, kInputToInputWeightsTensor); |
1603 | const bool use_cifg = (input_to_input_weights == nullptr); |
1604 | int row_sums_rows = use_cifg ? 6 : 8; |
1605 | const TfLiteTensor* projection_weights = |
1606 | GetOptionalInputTensor(context, node, kProjectionWeightsTensor); |
1607 | if (projection_weights != nullptr) { |
      row_sums_rows += std::ceil(static_cast<float>(n_output) / n_cell);
1609 | } |
1610 | |
1611 | TfLiteTensor* row_sums; |
1612 | TF_LITE_ENSURE_OK(context, |
1613 | GetTemporarySafe(context, node, kRowSums, &row_sums)); |
1614 | row_sums->type = kTfLiteInt32; |
    row_sums->name = "Lstm_row_sums";
1616 | row_sums->allocation_type = kTfLiteArenaRwPersistent; |
1617 | const int row_sums_dims[2] = {row_sums_rows, n_cell}; |
1618 | if (!TfLiteIntArrayEqualsArray(row_sums->dims, 2, row_sums_dims)) { |
1619 | TfLiteIntArray* row_sums_size = TfLiteIntArrayCreate(2); |
1620 | row_sums_size->data[0] = row_sums_dims[0]; |
1621 | row_sums_size->data[1] = row_sums_dims[1]; |
1622 | TF_LITE_ENSURE_OK( |
1623 | context, context->ResizeTensor(context, row_sums, row_sums_size)); |
1624 | } |
1625 | |
1626 | if (is_sparse_op) { |
1627 | op_data->ledger_initialized = false; |
1628 | int offset = kNumHybridTemporaryTensors; |
1629 | { |
1630 | node->temporaries->data[offset + kInputToInputWeightsLedgerOffset] = |
1631 | op_data->ledger_index + kInputToInputWeightsLedgerOffset; |
1632 | const TfLiteTensor* input_to_input_weights = |
1633 | GetOptionalInputTensor(context, node, kInputToInputWeightsTensor); |
1634 | TfLiteTensor* input_to_input_weights_ledger = |
1635 | &context->tensors[op_data->ledger_index + |
1636 | kInputToInputWeightsLedgerOffset]; |
1637 | auto status = make_ledger(input_to_input_weights == nullptr |
1638 | ? nullptr |
1639 | : input_to_input_weights->sparsity, |
1640 | context, input_to_input_weights_ledger); |
1641 | if (status != kTfLiteOk) return status; |
1642 | } |
1643 | { |
1644 | node->temporaries->data[offset + kInputToForgetWeightsLedgerOffset] = |
1645 | op_data->ledger_index + kInputToForgetWeightsLedgerOffset; |
1646 | const TfLiteTensor* input_to_forget_weights = |
1647 | GetInput(context, node, kInputToForgetWeightsTensor); |
1648 | TfLiteTensor* input_to_forget_weights_ledger = |
1649 | &context->tensors[op_data->ledger_index + |
1650 | kInputToForgetWeightsLedgerOffset]; |
1651 | auto status = make_ledger(input_to_forget_weights->sparsity, context, |
1652 | input_to_forget_weights_ledger); |
1653 | if (status != kTfLiteOk) return status; |
1654 | } |
1655 | { |
1656 | node->temporaries->data[offset + kInputToCellWeightsLedgerOffset] = |
1657 | op_data->ledger_index + kInputToCellWeightsLedgerOffset; |
1658 | const TfLiteTensor* input_to_cell_weights = |
1659 | GetInput(context, node, kInputToCellWeightsTensor); |
1660 | TfLiteTensor* input_to_cell_weights_ledger = |
1661 | &context->tensors[op_data->ledger_index + |
1662 | kInputToCellWeightsLedgerOffset]; |
1663 | auto status = make_ledger(input_to_cell_weights->sparsity, context, |
1664 | input_to_cell_weights_ledger); |
1665 | if (status != kTfLiteOk) return status; |
1666 | } |
1667 | { |
1668 | node->temporaries->data[offset + kInputToOutputWeightsLedgerOffset] = |
1669 | op_data->ledger_index + kInputToOutputWeightsLedgerOffset; |
1670 | const TfLiteTensor* input_to_output_weights = |
1671 | GetInput(context, node, kInputToOutputWeightsTensor); |
1672 | TfLiteTensor* input_to_output_weights_ledger = |
1673 | &context->tensors[op_data->ledger_index + |
1674 | kInputToOutputWeightsLedgerOffset]; |
1675 | auto status = make_ledger(input_to_output_weights->sparsity, context, |
1676 | input_to_output_weights_ledger); |
1677 | if (status != kTfLiteOk) return status; |
1678 | } |
1679 | { |
1680 | node->temporaries->data[offset + kRecurrentToInputWeightsLedgerOffset] = |
1681 | op_data->ledger_index + kRecurrentToInputWeightsLedgerOffset; |
1682 | const TfLiteTensor* recurrent_to_input_weights = GetOptionalInputTensor( |
1683 | context, node, kRecurrentToInputWeightsTensor); |
1684 | TfLiteTensor* recurrent_to_input_weights_ledger = |
1685 | &context->tensors[op_data->ledger_index + |
1686 | kRecurrentToInputWeightsLedgerOffset]; |
1687 | auto status = make_ledger(recurrent_to_input_weights == nullptr |
1688 | ? nullptr |
1689 | : recurrent_to_input_weights->sparsity, |
1690 | context, recurrent_to_input_weights_ledger); |
1691 | if (status != kTfLiteOk) return status; |
1692 | } |
1693 | { |
1694 | node->temporaries |
1695 | ->data[offset + kRecurrentToForgetWeightsLedgerOffset] = |
1696 | op_data->ledger_index + kRecurrentToForgetWeightsLedgerOffset; |
1697 | const TfLiteTensor* recurrent_to_forget_weights = |
1698 | GetInput(context, node, kRecurrentToForgetWeightsTensor); |
1699 | TfLiteTensor* recurrent_to_forget_weights_ledger = |
1700 | &context->tensors[op_data->ledger_index + |
1701 | kRecurrentToForgetWeightsLedgerOffset]; |
1702 | auto status = make_ledger(recurrent_to_forget_weights->sparsity, |
1703 | context, recurrent_to_forget_weights_ledger); |
1704 | if (status != kTfLiteOk) return status; |
1705 | } |
1706 | { |
1707 | node->temporaries->data[offset + kRecurrentToCellWeightsLedgerOffset] = |
1708 | op_data->ledger_index + kRecurrentToCellWeightsLedgerOffset; |
1709 | const TfLiteTensor* recurrent_to_cell_weights = |
1710 | GetInput(context, node, kRecurrentToCellWeightsTensor); |
1711 | TfLiteTensor* recurrent_to_cell_weights_ledger = |
1712 | &context->tensors[op_data->ledger_index + |
1713 | kRecurrentToCellWeightsLedgerOffset]; |
1714 | auto status = make_ledger(recurrent_to_cell_weights->sparsity, context, |
1715 | recurrent_to_cell_weights_ledger); |
1716 | if (status != kTfLiteOk) return status; |
1717 | } |
1718 | { |
1719 | node->temporaries |
1720 | ->data[offset + kRecurrentToOutputWeightsLedgerOffset] = |
1721 | op_data->ledger_index + kRecurrentToOutputWeightsLedgerOffset; |
1722 | const TfLiteTensor* recurrent_to_output_weights = |
1723 | GetInput(context, node, kRecurrentToOutputWeightsTensor); |
1724 | TfLiteTensor* recurrent_to_output_weights_ledger = |
1725 | &context->tensors[op_data->ledger_index + |
1726 | kRecurrentToOutputWeightsLedgerOffset]; |
1727 | auto status = make_ledger(recurrent_to_output_weights->sparsity, |
1728 | context, recurrent_to_output_weights_ledger); |
1729 | if (status != kTfLiteOk) return status; |
1730 | } |
1731 | { |
1732 | node->temporaries->data[offset + kProjectionWeightsLedgerOffset] = |
1733 | op_data->ledger_index + kProjectionWeightsLedgerOffset; |
1734 | const TfLiteTensor* projection_weights = |
1735 | GetInput(context, node, kProjectionWeightsTensor); |
1736 | TfLiteTensor* projection_weights_ledger = |
1737 | &context->tensors[op_data->ledger_index + |
1738 | kProjectionWeightsLedgerOffset]; |
1739 | auto status = make_ledger(projection_weights->sparsity, context, |
1740 | projection_weights_ledger); |
1741 | if (status != kTfLiteOk) return status; |
1742 | } |
1743 | } |
1744 | } |
1745 | |
1746 | if (is_integer) { |
1747 | if (is_8x8_16) { |
1748 | // Integer LSTM prepare function for 8x8->16. |
1749 | // This code path needs 5 intermediate tensors per Op. |
1750 | // Populate quantization parameters. |
1751 | PopulateQuantizedLstmParams8x8_16(context, node, |
1752 | &op_data->integer_lstm_param); |
1753 | |
      // Allocate the scratch buffers: four 16-bit buffers, one 8-bit buffer
      // and one 32-bit buffer, each of size n_batch * n_cell (see the type
      // assignments in the loop below).
      //
      // The CIFG case is handled here as well, even though it could save one
      // buffer.
1759 | for (int scratch_index = 0; scratch_index < 6; ++scratch_index) { |
1760 | node->temporaries->data[scratch_index] = |
1761 | op_data->scratch_tensor_index + scratch_index; |
1762 | TfLiteTensor* scratch_tensor; |
1763 | TF_LITE_ENSURE_OK( |
1764 | context, |
1765 | GetTemporarySafe(context, node, scratch_index, &scratch_tensor)); |
1766 | scratch_tensor->type = kTfLiteInt16; |
1767 | if (scratch_index == 4) { |
1768 | scratch_tensor->type = kTfLiteInt8; |
1769 | } else if (scratch_index == 5) { |
1770 | scratch_tensor->type = kTfLiteInt32; |
1771 | } |
1772 | scratch_tensor->allocation_type = kTfLiteArenaRw; |
1773 | const int scratch_dimension[2] = {n_batch, n_cell}; |
1774 | if (!TfLiteIntArrayEqualsArray(scratch_tensor->dims, 2, |
1775 | scratch_dimension)) { |
1776 | TfLiteIntArray* scratch_buffer_size = TfLiteIntArrayCreate(2); |
1777 | scratch_buffer_size->data[0] = n_batch; |
1778 | scratch_buffer_size->data[1] = n_cell; |
1779 | TF_LITE_ENSURE_OK(context, |
1780 | context->ResizeTensor(context, scratch_tensor, |
1781 | scratch_buffer_size)); |
1782 | } |
1783 | } |
1784 | |
1785 | // Populate precomputed zp * weight. |
1786 | TF_LITE_ENSURE_OK(context, PopulatePrecomputedZPTimesWeightsWithBias( |
1787 | context, op_data, node)); |
1788 | } else { |
1789 | // Integer LSTM prepare function for 8x8->8. |
1790 | // This code path needs 12 intermediate tensors per Op. |
1791 | PopulateQuantizedLstmParams8x8_8(context, node, |
1792 | &op_data->integer_lstm_param); |
1793 | |
      // Allocate the scratch buffers: six 16-bit buffers and two 8-bit
      // buffers, each of size n_batch * n_cell.
      //
      // The CIFG case is handled here as well, even though it could save one
      // buffer.
1798 | for (int scratch_index = 0; scratch_index < 8; ++scratch_index) { |
1799 | node->temporaries->data[scratch_index] = |
1800 | op_data->scratch_tensor_index + scratch_index; |
1801 | TfLiteTensor* scratch_tensor; |
1802 | TF_LITE_ENSURE_OK( |
1803 | context, |
1804 | GetTemporarySafe(context, node, scratch_index, &scratch_tensor)); |
1805 | if (scratch_index == 0 || scratch_index == 1) { |
1806 | scratch_tensor->type = kTfLiteInt8; |
1807 | } else { |
1808 | scratch_tensor->type = kTfLiteInt16; |
1809 | } |
1810 | scratch_tensor->allocation_type = kTfLiteArenaRw; |
1811 | const int scratch_dimension[2] = {n_batch, n_cell}; |
1812 | if (!TfLiteIntArrayEqualsArray(scratch_tensor->dims, 2, |
1813 | scratch_dimension)) { |
1814 | TfLiteIntArray* scratch_buffer_size = TfLiteIntArrayCreate(2); |
1815 | scratch_buffer_size->data[0] = n_batch; |
1816 | scratch_buffer_size->data[1] = n_cell; |
1817 | TF_LITE_ENSURE_OK(context, |
1818 | context->ResizeTensor(context, scratch_tensor, |
1819 | scratch_buffer_size)); |
1820 | } |
1821 | } |
1822 | } |
1823 | } |
1824 | return kTfLiteOk; |
1825 | } |
1826 | // LINT.ThenChange(//tensorflow/lite/tools/optimize/calibration/builtin_logging_ops/lstm.cc) |
1827 | |
1828 | // LINT.IfChange |
1829 | TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { |
1830 | const auto* params = static_cast<TfLiteLSTMParams*>(node->builtin_data); |
1831 | OpData* op_data = static_cast<OpData*>(node->user_data); |
1832 | |
1833 | const TfLiteTensor* input; |
1834 | TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kInputTensor, &input)); |
1835 | |
1836 | const TfLiteTensor* input_to_input_weights = |
1837 | GetOptionalInputTensor(context, node, kInputToInputWeightsTensor); |
1838 | const TfLiteTensor* input_to_forget_weights; |
1839 | TF_LITE_ENSURE_OK(context, |
1840 | GetInputSafe(context, node, kInputToForgetWeightsTensor, |
1841 | &input_to_forget_weights)); |
1842 | const TfLiteTensor* input_to_cell_weights; |
1843 | TF_LITE_ENSURE_OK(context, |
1844 | GetInputSafe(context, node, kInputToCellWeightsTensor, |
1845 | &input_to_cell_weights)); |
1846 | const TfLiteTensor* input_to_output_weights; |
1847 | TF_LITE_ENSURE_OK(context, |
1848 | GetInputSafe(context, node, kInputToOutputWeightsTensor, |
1849 | &input_to_output_weights)); |
1850 | |
1851 | const TfLiteTensor* recurrent_to_input_weights = |
1852 | GetOptionalInputTensor(context, node, kRecurrentToInputWeightsTensor); |
1853 | const TfLiteTensor* recurrent_to_forget_weights; |
1854 | TF_LITE_ENSURE_OK(context, |
1855 | GetInputSafe(context, node, kRecurrentToForgetWeightsTensor, |
1856 | &recurrent_to_forget_weights)); |
1857 | const TfLiteTensor* recurrent_to_cell_weights; |
1858 | TF_LITE_ENSURE_OK(context, |
1859 | GetInputSafe(context, node, kRecurrentToCellWeightsTensor, |
1860 | &recurrent_to_cell_weights)); |
1861 | const TfLiteTensor* recurrent_to_output_weights; |
1862 | TF_LITE_ENSURE_OK(context, |
1863 | GetInputSafe(context, node, kRecurrentToOutputWeightsTensor, |
1864 | &recurrent_to_output_weights)); |
1865 | |
1866 | const TfLiteTensor* cell_to_input_weights = |
1867 | GetOptionalInputTensor(context, node, kCellToInputWeightsTensor); |
1868 | const TfLiteTensor* cell_to_forget_weights = |
1869 | GetOptionalInputTensor(context, node, kCellToForgetWeightsTensor); |
1870 | const TfLiteTensor* cell_to_output_weights = |
1871 | GetOptionalInputTensor(context, node, kCellToOutputWeightsTensor); |
1872 | |
1873 | const TfLiteTensor* input_layer_norm_coefficients = |
1874 | GetOptionalInputTensor(context, node, kInputLayerNormCoefficientsTensor); |
1875 | const TfLiteTensor* forget_layer_norm_coefficients = |
1876 | GetOptionalInputTensor(context, node, kForgetLayerNormCoefficientsTensor); |
1877 | const TfLiteTensor* cell_layer_norm_coefficients = |
1878 | GetOptionalInputTensor(context, node, kCellLayerNormCoefficientsTensor); |
1879 | const TfLiteTensor* output_layer_norm_coefficients = |
1880 | GetOptionalInputTensor(context, node, kOutputLayerNormCoefficientsTensor); |
1881 | |
1882 | const TfLiteTensor* input_gate_bias = |
1883 | GetOptionalInputTensor(context, node, kInputGateBiasTensor); |
1884 | const TfLiteTensor* forget_gate_bias; |
1885 | TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kForgetGateBiasTensor, |
1886 | &forget_gate_bias)); |
1887 | const TfLiteTensor* cell_gate_bias; |
1888 | TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kCellGateBiasTensor, |
1889 | &cell_gate_bias)); |
1890 | const TfLiteTensor* output_gate_bias; |
1891 | TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kOutputGateBiasTensor, |
1892 | &output_gate_bias)); |
1893 | |
1894 | const TfLiteTensor* projection_weights = |
1895 | GetOptionalInputTensor(context, node, kProjectionWeightsTensor); |
1896 | const TfLiteTensor* projection_bias = |
1897 | GetOptionalInputTensor(context, node, kProjectionBiasTensor); |
1898 | |
1899 | TfLiteTensor* output_state = |
1900 | GetVariableInput(context, node, kOutputStateTensor); |
1901 | TFLITE_DCHECK(output_state != nullptr); |
1902 | TfLiteTensor* cell_state = GetVariableInput(context, node, kCellStateTensor); |
1903 | TFLITE_DCHECK(cell_state != nullptr); |
1904 | |
1905 | TfLiteTensor* output; |
1906 | TF_LITE_ENSURE_OK(context, |
1907 | GetOutputSafe(context, node, kOutputTensor, &output)); |
1908 | |
1909 | switch (input_to_output_weights->type) { |
1910 | case kTfLiteFloat32: { |
      // Index the scratch buffer pointers into the global scratch buffer.
1912 | TfLiteTensor* scratch_buffer; |
1913 | TF_LITE_ENSURE_OK(context, |
1914 | GetTemporarySafe(context, node, 0, &scratch_buffer)); |
1915 | return lstm_eval::EvalFloat( |
1916 | input, input_to_input_weights, input_to_forget_weights, |
1917 | input_to_cell_weights, input_to_output_weights, |
1918 | recurrent_to_input_weights, recurrent_to_forget_weights, |
1919 | recurrent_to_cell_weights, recurrent_to_output_weights, |
1920 | cell_to_input_weights, cell_to_forget_weights, cell_to_output_weights, |
1921 | input_layer_norm_coefficients, forget_layer_norm_coefficients, |
1922 | cell_layer_norm_coefficients, output_layer_norm_coefficients, |
1923 | /*aux_input=*/nullptr, |
1924 | /*aux_input_to_input_weights=*/nullptr, |
1925 | /*aux_input_to_forget_weights=*/nullptr, |
1926 | /*aux_input_to_cell_weights=*/nullptr, |
1927 | /*aux_input_to_output_weights=*/nullptr, input_gate_bias, |
1928 | forget_gate_bias, cell_gate_bias, output_gate_bias, |
1929 | projection_weights, projection_bias, params, |
1930 | /*forward_sequence=*/true, |
1931 | /*time_major=*/true, |
1932 | /*output_offset=*/0, scratch_buffer, output_state, cell_state, output, |
1933 | CpuBackendContext::GetFromContext(context)); |
1934 | } |
1935 | case kTfLiteUInt8: |
1936 | case kTfLiteInt8: { |
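      // The switch is on the weights' type: quantized weights combined with
      // float activations means the op runs in hybrid mode.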
1937 | const bool is_hybrid = (input->type == kTfLiteFloat32); |
1938 | const bool is_sparse = input_to_output_weights->sparsity != nullptr; |
1939 | if (is_hybrid) { |
1940 | TfLiteTensor* row_sums; |
1941 | TF_LITE_ENSURE_OK(context, |
1942 | GetTemporarySafe(context, node, kRowSums, &row_sums)); |
1943 | const int row_sums_size = row_sums->dims->data[0]; |
1944 | if (is_sparse) { |
1945 | TfLiteTensor* input_to_input_weights_ledger = |
1946 | &context->tensors[op_data->ledger_index + |
1947 | kInputToInputWeightsLedgerOffset]; |
1948 | TfLiteTensor* input_to_forget_weights_ledger = |
1949 | &context->tensors[op_data->ledger_index + |
1950 | kInputToForgetWeightsLedgerOffset]; |
1951 | TfLiteTensor* input_to_cell_weights_ledger = |
1952 | &context->tensors[op_data->ledger_index + |
1953 | kInputToCellWeightsLedgerOffset]; |
1954 | TfLiteTensor* input_to_output_weights_ledger = |
1955 | &context->tensors[op_data->ledger_index + |
1956 | kInputToOutputWeightsLedgerOffset]; |
1957 | TfLiteTensor* recurrent_to_input_weights_ledger = |
1958 | &context->tensors[op_data->ledger_index + |
1959 | kRecurrentToInputWeightsLedgerOffset]; |
1960 | TfLiteTensor* recurrent_to_forget_weights_ledger = |
1961 | &context->tensors[op_data->ledger_index + |
1962 | kRecurrentToForgetWeightsLedgerOffset]; |
1963 | TfLiteTensor* recurrent_to_cell_weights_ledger = |
1964 | &context->tensors[op_data->ledger_index + |
1965 | kRecurrentToCellWeightsLedgerOffset]; |
1966 | TfLiteTensor* recurrent_to_output_weights_ledger = |
1967 | &context->tensors[op_data->ledger_index + |
1968 | kRecurrentToOutputWeightsLedgerOffset]; |
1969 | TfLiteTensor* projection_weights_ledger = |
1970 | &context->tensors[op_data->ledger_index + |
1971 | kProjectionWeightsLedgerOffset]; |
1972 | if (!op_data->ledger_initialized) { |
1973 | copy_ledger(input_to_input_weights == nullptr |
1974 | ? nullptr |
1975 | : input_to_input_weights->sparsity, |
1976 | input_to_input_weights_ledger); |
1977 | copy_ledger(input_to_forget_weights->sparsity, |
1978 | input_to_forget_weights_ledger); |
1979 | copy_ledger(input_to_cell_weights->sparsity, |
1980 | input_to_cell_weights_ledger); |
1981 | copy_ledger(input_to_output_weights->sparsity, |
1982 | input_to_output_weights_ledger); |
1983 | copy_ledger(recurrent_to_input_weights == nullptr |
1984 | ? nullptr |
1985 | : recurrent_to_input_weights->sparsity, |
1986 | recurrent_to_input_weights_ledger); |
1987 | copy_ledger(recurrent_to_forget_weights->sparsity, |
1988 | recurrent_to_forget_weights_ledger); |
1989 | copy_ledger(recurrent_to_cell_weights->sparsity, |
1990 | recurrent_to_cell_weights_ledger); |
1991 | copy_ledger(recurrent_to_output_weights->sparsity, |
1992 | recurrent_to_output_weights_ledger); |
1993 | copy_ledger(projection_weights->sparsity, |
1994 | projection_weights_ledger); |
1995 | op_data->ledger_initialized = true; |
1996 | } |
1997 | return lstm_eval::EvalHybrid( |
1998 | input, input_to_input_weights, input_to_input_weights_ledger, |
1999 | input_to_forget_weights, input_to_forget_weights_ledger, |
2000 | input_to_cell_weights, input_to_cell_weights_ledger, |
2001 | input_to_output_weights, input_to_output_weights_ledger, |
2002 | recurrent_to_input_weights, recurrent_to_input_weights_ledger, |
2003 | recurrent_to_forget_weights, recurrent_to_forget_weights_ledger, |
2004 | recurrent_to_cell_weights, recurrent_to_cell_weights_ledger, |
2005 | recurrent_to_output_weights, recurrent_to_output_weights_ledger, |
2006 | cell_to_input_weights, cell_to_forget_weights, |
2007 | cell_to_output_weights, input_layer_norm_coefficients, |
2008 | forget_layer_norm_coefficients, cell_layer_norm_coefficients, |
2009 | output_layer_norm_coefficients, |
2010 | /*aux_input=*/nullptr, |
2011 | /*aux_input_to_input_weights=*/nullptr, |
2012 | /*aux_input_to_forget_weights=*/nullptr, |
2013 | /*aux_input_to_cell_weights=*/nullptr, |
2014 | /*aux_input_to_output_weights=*/nullptr, input_gate_bias, |
2015 | forget_gate_bias, cell_gate_bias, output_gate_bias, |
2016 | projection_weights, projection_weights_ledger, projection_bias, |
2017 | params, |
2018 | /*forward_sequence=*/true, /*time_major=*/true, |
2019 | /*output_offset=*/0, GetTemporary(context, node, kScratchBuffer), |
2020 | GetTemporary(context, node, kInputScalingFactors), |
2021 | /*aux_input_sf=*/nullptr, |
2022 | GetTemporary(context, node, kOutputStateScalingFactors), |
2023 | GetTemporary(context, node, kProductScalingFactors), |
2024 | GetTemporary(context, node, kRecoveredCellWeights), |
2025 | GetTemporary(context, node, kInputQuantized), |
2026 | /*aux_input_quantized=*/nullptr, |
2027 | GetTemporary(context, node, kOutputStateQuantized), |
2028 | GetTemporary(context, node, kCellStateQuantized), output_state, |
2029 | cell_state, GetTemporary(context, node, kAccumScratch), output, |
2030 | GetTemporary(context, node, kInputZeroPoints), |
2031 | /*aux_input_zp=*/nullptr, |
2032 | GetTemporary(context, node, kOutputStateZeroPoints), row_sums, |
2033 | row_sums_size, &op_data->compute_row_sums, |
2034 | CpuBackendContext::GetFromContext(context)); |
2035 | } |
2036 | return lstm_eval::EvalHybrid( |
2037 | input, input_to_input_weights, |
2038 | /*input_to_input_weights_ledger*/ nullptr, input_to_forget_weights, |
2039 | /*input_to_forget_weights_ledger*/ nullptr, input_to_cell_weights, |
2040 | /*input_to_cell_weights_ledger*/ nullptr, input_to_output_weights, |
2041 | /*input_to_output_weights_ledger*/ nullptr, |
2042 | recurrent_to_input_weights, |
2043 | /*recurrent_to_input_weights_ledger*/ nullptr, |
2044 | recurrent_to_forget_weights, |
2045 | /*recurrent_to_forget_weights_ledger*/ nullptr, |
2046 | recurrent_to_cell_weights, |
2047 | /*recurrent_to_cell_weights_ledger*/ nullptr, |
2048 | recurrent_to_output_weights, |
2049 | /*recurrent_to_output_weights_ledger*/ nullptr, |
2050 | cell_to_input_weights, cell_to_forget_weights, |
2051 | cell_to_output_weights, input_layer_norm_coefficients, |
2052 | forget_layer_norm_coefficients, cell_layer_norm_coefficients, |
2053 | output_layer_norm_coefficients, /*aux_input=*/nullptr, |
2054 | /*aux_input_to_input_weights=*/nullptr, |
2055 | /*aux_input_to_forget_weights=*/nullptr, |
2056 | /*aux_input_to_cell_weights=*/nullptr, |
2057 | /*aux_input_to_output_weights=*/nullptr, input_gate_bias, |
2058 | forget_gate_bias, cell_gate_bias, output_gate_bias, |
2059 | projection_weights, /*projection_weights_ledger*/ nullptr, |
2060 | projection_bias, params, |
2061 | /*forward_sequence=*/true, /*time_major=*/true, /*output_offset=*/0, |
2062 | GetTemporary(context, node, kScratchBuffer), |
2063 | GetTemporary(context, node, kInputScalingFactors), |
2064 | /*aux_input_sf=*/nullptr, |
2065 | GetTemporary(context, node, kOutputStateScalingFactors), |
2066 | GetTemporary(context, node, kProductScalingFactors), |
2067 | GetTemporary(context, node, kRecoveredCellWeights), |
2068 | GetTemporary(context, node, kInputQuantized), |
2069 | /*aux_input_quantized=*/nullptr, |
2070 | GetTemporary(context, node, kOutputStateQuantized), |
2071 | GetTemporary(context, node, kCellStateQuantized), output_state, |
2072 | cell_state, GetTemporary(context, node, kAccumScratch), output, |
2073 | GetTemporary(context, node, kInputZeroPoints), |
2074 | /*aux_input_zp=*/nullptr, |
2075 | GetTemporary(context, node, kOutputStateZeroPoints), row_sums, |
2076 | row_sums_size, &op_data->compute_row_sums, |
2077 | CpuBackendContext::GetFromContext(context)); |
2078 | } |
2079 | const int num_intermediate_tensors = node->intermediates->size; |
2080 | TfLiteTensor* scratch0; |
2081 | TF_LITE_ENSURE_OK(context, GetTemporarySafe(context, node, 0, &scratch0)); |
2082 | TfLiteTensor* scratch1; |
2083 | TF_LITE_ENSURE_OK(context, GetTemporarySafe(context, node, 1, &scratch1)); |
2084 | TfLiteTensor* scratch2; |
2085 | TF_LITE_ENSURE_OK(context, GetTemporarySafe(context, node, 2, &scratch2)); |
2086 | TfLiteTensor* scratch3; |
2087 | TF_LITE_ENSURE_OK(context, GetTemporarySafe(context, node, 3, &scratch3)); |
2088 | TfLiteTensor* scratch4; |
2089 | TF_LITE_ENSURE_OK(context, GetTemporarySafe(context, node, 4, &scratch4)); |
2090 | TfLiteTensor* scratch5; |
2091 | TF_LITE_ENSURE_OK(context, GetTemporarySafe(context, node, 5, &scratch5)); |
2092 | if (num_intermediate_tensors == 5) { |
2093 | return lstm_eval::EvalInteger8x8_16( |
2094 | input, input_to_input_weights, input_to_forget_weights, |
2095 | input_to_cell_weights, input_to_output_weights, |
2096 | recurrent_to_input_weights, recurrent_to_forget_weights, |
2097 | recurrent_to_cell_weights, recurrent_to_output_weights, |
2098 | cell_to_input_weights, cell_to_forget_weights, |
2099 | cell_to_output_weights, input_layer_norm_coefficients, |
2100 | forget_layer_norm_coefficients, cell_layer_norm_coefficients, |
2101 | output_layer_norm_coefficients, input_gate_bias, forget_gate_bias, |
2102 | cell_gate_bias, output_gate_bias, projection_weights, |
2103 | projection_bias, params, /*forward_sequence=*/true, |
2104 | /*time_major=*/true, &op_data->integer_lstm_param, output_state, |
2105 | cell_state, output, scratch0, scratch1, scratch2, scratch3, |
2106 | scratch4, scratch5, CpuBackendContext::GetFromContext(context)); |
2107 | } |
2108 | TfLiteTensor* scratch6; |
2109 | TF_LITE_ENSURE_OK(context, GetTemporarySafe(context, node, 6, &scratch6)); |
2110 | TfLiteTensor* scratch7; |
2111 | TF_LITE_ENSURE_OK(context, GetTemporarySafe(context, node, 7, &scratch7)); |
2112 | return lstm_eval::EvalInteger8x8_8( |
2113 | input, input_to_input_weights, input_to_forget_weights, |
2114 | input_to_cell_weights, input_to_output_weights, |
2115 | recurrent_to_input_weights, recurrent_to_forget_weights, |
2116 | recurrent_to_cell_weights, recurrent_to_output_weights, |
2117 | cell_to_input_weights, cell_to_forget_weights, cell_to_output_weights, |
2118 | input_layer_norm_coefficients, forget_layer_norm_coefficients, |
2119 | cell_layer_norm_coefficients, output_layer_norm_coefficients, |
2120 | input_gate_bias, forget_gate_bias, cell_gate_bias, output_gate_bias, |
2121 | projection_weights, projection_bias, params, output_state, cell_state, |
2122 | output, &op_data->integer_lstm_param, scratch0, scratch1, scratch2, |
2123 | scratch3, scratch4, scratch5, scratch6, scratch7); |
2124 | } |
2125 | default: |
      TF_LITE_KERNEL_LOG(context, "Type %d is not currently supported.",
                         input_to_output_weights->type);
2128 | return kTfLiteError; |
2129 | } |
2130 | } |
2131 | // LINT.ThenChange(//tensorflow/lite/tools/optimize/calibration/builtin_logging_ops/lstm.cc) |
2132 | |
2133 | } // namespace full |
2134 | |
// For the basic kernel (5 inputs).
2136 | namespace basic { |
2137 | |
2138 | enum InputTensor { |
2139 | kInputData = 0, |
2140 | kInputPrevActivation = 1, |
2141 | kInputWeights = 2, |
2142 | kInputBiases = 3, |
2143 | kInputPrevState = 4, |
2144 | kInputNum = 5, |
2145 | }; |
2146 | |
2147 | enum OutputTensor { |
2148 | kOutputActivation = 0, |
2149 | kOutputState = 1, |
2150 | kOutputConcatTemp = 2, |
2151 | kOutputActivationTemp = 3, |
2152 | kOutputNum = 4, |
2153 | }; |
2154 | |
2155 | void* Init(TfLiteContext* context, const char* buffer, size_t length) { |
2156 | auto* op_data = new OpData(); |
2157 | op_data->kernel_type = kTfLiteLSTMBasicKernel; |
2158 | // `scratch_tensor_index` is unused in this kernel. |
2159 | op_data->scratch_tensor_index = -1; |
2160 | return op_data; |
2161 | } |
2162 | |
2163 | TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { |
2164 | TF_LITE_ENSURE(context, node->inputs->size == kInputNum); |
2165 | TF_LITE_ENSURE(context, node->outputs->size == kOutputNum); |
2166 | |
2167 | const TfLiteTensor* input; |
2168 | TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kInputData, &input)); |
2169 | const TfLiteTensor* prev_activation; |
2170 | TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kInputPrevActivation, |
2171 | &prev_activation)); |
2172 | const TfLiteTensor* weights; |
2173 | TF_LITE_ENSURE_OK(context, |
2174 | GetInputSafe(context, node, kInputWeights, &weights)); |
2175 | const TfLiteTensor* bias; |
2176 | TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kInputBiases, &bias)); |
2177 | const TfLiteTensor* prev_state; |
2178 | TF_LITE_ENSURE_OK(context, |
2179 | GetInputSafe(context, node, kInputPrevState, &prev_state)); |
2180 | |
2181 | TF_LITE_ENSURE_EQ(context, input->dims->size, 2); |
2182 | const int num_batches = input->dims->data[0]; |
2183 | const int input_depth = input->dims->data[1]; |
2184 | |
2185 | TF_LITE_ENSURE_EQ(context, prev_activation->dims->size, 2); |
2186 | TF_LITE_ENSURE_EQ(context, prev_activation->dims->data[0], num_batches); |
2187 | const int activation_depth = prev_activation->dims->data[1]; |
2188 | const int total_depth = input_depth + activation_depth; |
2189 | |
2190 | TF_LITE_ENSURE_EQ(context, weights->dims->size, 2); |
2191 | TF_LITE_ENSURE_EQ(context, weights->dims->data[0], 4 * activation_depth); |
2192 | TF_LITE_ENSURE_EQ(context, weights->dims->data[1], total_depth); |
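  // The basic kernel concatenates [input, prev_activation] (total_depth wide)
  // and computes all four gates with a single matmul, hence the weights shape
  // of (4 * activation_depth) x total_depth.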
2193 | |
2194 | TF_LITE_ENSURE_EQ(context, bias->dims->size, 1); |
2195 | TF_LITE_ENSURE_EQ(context, bias->dims->data[0], 4 * activation_depth); |
2196 | |
2197 | TF_LITE_ENSURE_EQ(context, prev_state->dims->size, 2); |
2198 | TF_LITE_ENSURE_EQ(context, prev_state->dims->data[0], num_batches); |
2199 | TF_LITE_ENSURE_EQ(context, prev_state->dims->data[1], activation_depth); |
2200 | |
2201 | TfLiteTensor* activation_out; |
2202 | TF_LITE_ENSURE_OK(context, GetOutputSafe(context, node, kOutputActivation, |
2203 | &activation_out)); |
2204 | TfLiteTensor* state_out; |
2205 | TF_LITE_ENSURE_OK(context, |
2206 | GetOutputSafe(context, node, kOutputState, &state_out)); |
2207 | TfLiteTensor* concat_temp; |
2208 | TF_LITE_ENSURE_OK( |
2209 | context, GetOutputSafe(context, node, kOutputConcatTemp, &concat_temp)); |
2210 | TfLiteTensor* activation_temp; |
2211 | TF_LITE_ENSURE_OK(context, GetOutputSafe(context, node, kOutputActivationTemp, |
2212 | &activation_temp)); |
2213 | |
2214 | TF_LITE_ENSURE_OK(context, context->ResizeTensor( |
2215 | context, activation_out, |
2216 | TfLiteIntArrayCopy(prev_activation->dims))); |
2217 | TF_LITE_ENSURE_OK( |
2218 | context, context->ResizeTensor(context, state_out, |
2219 | TfLiteIntArrayCopy(prev_state->dims))); |
2220 | |
2221 | TfLiteIntArray* concat_temp_size = TfLiteIntArrayCreate(2); |
2222 | concat_temp_size->data[0] = num_batches; |
2223 | concat_temp_size->data[1] = total_depth; |
2224 | TF_LITE_ENSURE_OK( |
2225 | context, context->ResizeTensor(context, concat_temp, concat_temp_size)); |
2226 | TfLiteIntArray* activation_temp_size = TfLiteIntArrayCreate(2); |
2227 | activation_temp_size->data[0] = num_batches; |
2228 | activation_temp_size->data[1] = 4 * activation_depth; |
2229 | TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, activation_temp, |
2230 | activation_temp_size)); |
2231 | |
  // Set the state tensors as persistent so their contents survive across
  // invocations; Eval copies the updated state back into them.
2233 | for (auto index : {kInputPrevActivation, kInputPrevState}) { |
2234 | TfLiteTensor* tensor = &context->tensors[node->inputs->data[index]]; |
2235 | tensor->allocation_type = kTfLiteArenaRwPersistent; |
2236 | } |
2237 | return kTfLiteOk; |
2238 | } |
2239 | |
2240 | TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { |
2241 | const TfLiteTensor* input; |
2242 | TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kInputData, &input)); |
2243 | const TfLiteTensor* prev_activation; |
2244 | TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kInputPrevActivation, |
2245 | &prev_activation)); |
2246 | const TfLiteTensor* weights; |
2247 | TF_LITE_ENSURE_OK(context, |
2248 | GetInputSafe(context, node, kInputWeights, &weights)); |
2249 | const TfLiteTensor* bias; |
2250 | TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kInputBiases, &bias)); |
2251 | const TfLiteTensor* prev_state; |
2252 | TF_LITE_ENSURE_OK(context, |
2253 | GetInputSafe(context, node, kInputPrevState, &prev_state)); |
2254 | |
2255 | TfLiteTensor* activation_out; |
2256 | TF_LITE_ENSURE_OK(context, GetOutputSafe(context, node, kOutputActivation, |
2257 | &activation_out)); |
2258 | TfLiteTensor* state_out; |
2259 | TF_LITE_ENSURE_OK(context, |
2260 | GetOutputSafe(context, node, kOutputState, &state_out)); |
2261 | TfLiteTensor* concat_temp; |
2262 | TF_LITE_ENSURE_OK( |
2263 | context, GetOutputSafe(context, node, kOutputConcatTemp, &concat_temp)); |
2264 | TfLiteTensor* activation_temp; |
2265 | TF_LITE_ENSURE_OK(context, GetOutputSafe(context, node, kOutputActivationTemp, |
2266 | &activation_temp)); |
2267 | |
2268 | if (input->type == kTfLiteFloat32 && |
2269 | prev_activation->type == kTfLiteFloat32 && |
2270 | weights->type == kTfLiteFloat32 && bias->type == kTfLiteFloat32 && |
2271 | prev_state->type == kTfLiteFloat32 && state_out->type == kTfLiteFloat32 && |
2272 | activation_out->type == kTfLiteFloat32 && |
2273 | concat_temp->type == kTfLiteFloat32 && |
2274 | activation_temp->type == kTfLiteFloat32) { |
2275 | tflite::LstmCellParams op_params; |
    // The float LSTM cell does not need any parameters to be set: leave
    // op_params untouched.
2277 | optimized_ops::LstmCell( |
2278 | op_params, |
2279 | // Inputs. |
2280 | GetTensorShape(input), GetTensorData<float>(input), |
2281 | GetTensorShape(prev_activation), GetTensorData<float>(prev_activation), |
2282 | GetTensorShape(weights), GetTensorData<float>(weights), |
2283 | GetTensorShape(bias), GetTensorData<float>(bias), |
2284 | GetTensorShape(prev_state), GetTensorData<float>(prev_state), |
2285 | // Outputs. |
2286 | GetTensorShape(state_out), GetTensorData<float>(state_out), |
2287 | GetTensorShape(activation_out), GetTensorData<float>(activation_out), |
2288 | GetTensorShape(concat_temp), GetTensorData<float>(concat_temp), |
2289 | GetTensorShape(activation_temp), GetTensorData<float>(activation_temp), |
2290 | CpuBackendContext::GetFromContext(context)); |
2291 | } else if (input->type == kTfLiteUInt8 && |
2292 | prev_activation->type == kTfLiteUInt8 && |
2293 | weights->type == kTfLiteUInt8 && bias->type == kTfLiteInt32 && |
2294 | prev_state->type == kTfLiteInt16 && |
2295 | state_out->type == kTfLiteInt16 && |
2296 | activation_out->type == kTfLiteUInt8 && |
2297 | concat_temp->type == kTfLiteUInt8 && |
2298 | activation_temp->type == kTfLiteInt16) { |
2299 | int state_scale_log2_rounded; |
2300 | if (!CheckedLog2(state_out->params.scale, &state_scale_log2_rounded)) { |
2301 | TF_LITE_KERNEL_LOG( |
2302 | context, |
2303 | "The internal state of a LSTM cell must have a power-of-two scale." ); |
2304 | return kTfLiteError; |
2305 | } |
2306 | const int state_integer_bits = 15 + state_scale_log2_rounded; |
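    // For an int16 state, integer bits = 15 + log2(scale); e.g. a scale of
    // 2^-11 gives 15 - 11 = 4 integer bits, the only configuration supported
    // by the quantized LstmCell below.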
2307 | if (state_integer_bits != 4) { |
2308 | TF_LITE_KERNEL_LOG(context, |
2309 | "The only case of quantized LstmCell currently " |
2310 | "supported is with StateIntegerBits==4" ); |
2311 | return kTfLiteError; |
2312 | } |
2313 | |
2314 | double real_accum_multiplier = 4096 * bias->params.scale; |
    int32_t accum_multiplier;
2316 | int accum_shift; |
2317 | tflite::QuantizeMultiplier(real_accum_multiplier, &accum_multiplier, |
2318 | &accum_shift); |
2319 | tflite::LstmCellParams op_params; |
2320 | op_params.weights_zero_point = weights->params.zero_point; |
2321 | op_params.accum_multiplier = accum_multiplier; |
2322 | op_params.accum_shift = accum_shift; |
2323 | optimized_ops::LstmCell<4>( |
2324 | op_params, |
2325 | // Inputs. |
2326 | GetTensorShape(input), GetTensorData<uint8_t>(input), |
2327 | GetTensorShape(prev_activation), |
2328 | GetTensorData<uint8_t>(prev_activation), GetTensorShape(weights), |
2329 | GetTensorData<uint8_t>(weights), GetTensorShape(bias), |
2330 | GetTensorData<int32_t>(bias), GetTensorShape(prev_state), |
2331 | GetTensorData<int16_t>(prev_state), |
2332 | // Outputs. |
2333 | GetTensorShape(state_out), GetTensorData<int16_t>(state_out), |
2334 | GetTensorShape(activation_out), GetTensorData<uint8_t>(activation_out), |
2335 | GetTensorShape(concat_temp), GetTensorData<uint8_t>(concat_temp), |
2336 | GetTensorShape(activation_temp), |
2337 | GetTensorData<int16_t>(activation_temp), |
2338 | CpuBackendContext::GetFromContext(context)); |
2339 | } else { |
2340 | TF_LITE_KERNEL_LOG(context, |
2341 | "Unsupported combination of data types for LstmCell" ); |
2342 | return kTfLiteError; |
2343 | } |
2344 | |
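  // Copy the new state back into the persistent previous-state inputs so the
  // next invocation starts from the updated activation and cell state.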
2345 | memcpy(prev_activation->data.raw, activation_out->data.raw, |
2346 | activation_out->bytes); |
2347 | memcpy(prev_state->data.raw, state_out->data.raw, state_out->bytes); |
2348 | |
2349 | return kTfLiteOk; |
2350 | } |
2351 | |
2352 | } // namespace basic |
2353 | |
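// Top-level dispatch: the kernel type is read from the builtin options at
// Init time, and Prepare/Eval then forward to the full (20/24-input) or
// basic (5-input) implementation recorded in OpData.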
2354 | void* Init(TfLiteContext* context, const char* buffer, size_t length) { |
2355 | const auto* params = reinterpret_cast<const TfLiteLSTMParams*>(buffer); |
2356 | switch (params->kernel_type) { |
2357 | case kTfLiteLSTMFullKernel: |
2358 | return full::Init(context, buffer, length); |
2359 | case kTfLiteLSTMBasicKernel: |
2360 | return basic::Init(context, buffer, length); |
2361 | default: |
2362 | return nullptr; |
2363 | } |
2364 | } |
2365 | void Free(TfLiteContext* context, void* buffer) { |
2366 | delete static_cast<OpData*>(buffer); |
2367 | } |
2368 | |
2369 | TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { |
2370 | const auto* op_data = static_cast<const OpData*>(node->user_data); |
2371 | switch (op_data->kernel_type) { |
2372 | case kTfLiteLSTMFullKernel: |
2373 | return full::Prepare(context, node); |
2374 | case kTfLiteLSTMBasicKernel: |
2375 | return basic::Prepare(context, node); |
2376 | default: |
2377 | return kTfLiteError; |
2378 | } |
2379 | } |
2380 | |
2381 | TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { |
2382 | const auto* op_data = static_cast<const OpData*>(node->user_data); |
2383 | switch (op_data->kernel_type) { |
2384 | case kTfLiteLSTMFullKernel: |
2385 | return full::Eval(context, node); |
2386 | case kTfLiteLSTMBasicKernel: |
2387 | return basic::Eval(context, node); |
2388 | default: |
2389 | return kTfLiteError; |
2390 | } |
2391 | } |
2392 | |
2393 | } // namespace lstm |
2394 | |
2395 | TfLiteRegistration* Register_LSTM() { |
2396 | static TfLiteRegistration r = {lstm::Init, lstm::Free, lstm::Prepare, |
2397 | lstm::Eval}; |
2398 | return &r; |
2399 | } |
2400 | |
2401 | } // namespace builtin |
2402 | } // namespace ops |
2403 | } // namespace tflite |
2404 | |